In [None]:
from transformers import BertConfig
from gnn import GNNTrainer
from movie_lens_loader import MovieLensLoader
from llm import PromptBertClassifier, VanillaBertClassifier, AddingEmbeddingsBertClassifierBase

In [None]:
# Configuration of the tiny BERT checkpoint; its hidden size is reused further down.
config = BertConfig.from_pretrained("google/bert_uncased_L-2_H-128_A-2")

KGE_DIMENSION = 2  # Output dimension of the GNN encoder.
# 256 tokens for embedding dimensions up to 8, 512 tokens above that.
model_max_length = 512 if KGE_DIMENSION > 8 else 256

movie_lens_loader = MovieLensLoader(kge_dimensions=[KGE_DIMENSION])
gnn_trainer = GNNTrainer(movie_lens_loader.data, kge_dimension=KGE_DIMENSION)

# Classifier without graph information, plus the dataset tokenized with its tokenizer function.
vanilla_bert_only_classifier = VanillaBertClassifier(
    movie_lens_loader.llm_df,
    model_max_length=model_max_length,
)
dataset_vanilla = movie_lens_loader.generate_vanilla_dataset(
    vanilla_bert_only_classifier.tokenize_function
)

# Classifier that receives the GNN embeddings through the prompt, plus its dataset.
prompt_bert_only_classifier = PromptBertClassifier(
    movie_lens_loader,
    gnn_trainer.get_embedding,
    kge_dimension=KGE_DIMENSION,
    batch_size=64,
    model_max_length=model_max_length,
)
dataset_prompt = movie_lens_loader.generate_prompt_embedding_dataset(
    prompt_bert_only_classifier.tokenize_function,
    kge_dimension=prompt_bert_only_classifier.kge_dimension,
)


In [None]:
# Second GNN whose hidden width equals BERT's hidden size.
gnn_trainer_large = GNNTrainer(
    movie_lens_loader.data,
    hidden_channels=config.hidden_size,
)
# Training and validation of the large GNN are currently switched off:
# gnn_trainer_large.train_model(movie_lens_loader.gnn_train_data, 10)
# gnn_trainer_large.validate_model(movie_lens_loader.gnn_test_data)
# NOTE(review): with training disabled the embeddings presumably come from an untrained
# (or previously persisted) model - confirm this is intended.
gnn_trainer_large.get_embeddings(movie_lens_loader)

adding_embedding_bert_only_classifier = AddingEmbeddingsBertClassifierBase(
    movie_lens_loader,
    gnn_trainer_large.get_embedding,
    kge_dimension=config.hidden_size,
    batch_size=64,
    model_max_length=model_max_length,
)
# The dataset builder receives the tokenizer's separator and padding tokens in addition to
# the tokenizer function.
dataset_adding_embedding = movie_lens_loader.generate_adding_embedding_dataset(
    adding_embedding_bert_only_classifier.tokenizer.sep_token,
    adding_embedding_bert_only_classifier.tokenizer.pad_token,
    adding_embedding_bert_only_classifier.tokenize_function,
    kge_dimension=config.hidden_size,
)

In [None]:
# Inspect the "graph_embeddings" entry of the first training example.
dataset_adding_embedding["train"][0]["graph_embeddings"]

In [None]:
# Train the adding-embedding classifier on the dataset built above for a single epoch.
adding_embedding_bert_only_classifier.train_model_on_data(dataset_adding_embedding, epochs = 1)

In [None]:
import random as rd
import ast
from typing import Optional, Union, Tuple, List

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import BertForSequenceClassification, BertModel, BertTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

In [None]:
class InsertEmbeddingBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier that can overwrite two input-token embeddings with graph embeddings.

    When ``graph_embeddings`` is passed to :meth:`forward`, the input embeddings at the two
    positions lying 3 and 1 tokens before the last attended token of every sequence are
    replaced by the supplied vectors before the encoder runs. The classification head and the
    loss selection follow the usual pooled-output -> dropout -> linear classifier path.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        graph_embeddings: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        graph_embeddings (`torch.Tensor`, *optional*):
            Vectors written into the input embeddings. The scatter index built below has shape
            `(batch_size, 2, hidden_size)`, so this tensor needs at least that many entries along every
            dimension. Ignored when `None` or empty.

        The insertion positions are derived from `attention_mask.sum(dim=1) - 1`, i.e. they assume that the
        attended tokens form a contiguous prefix of each row (right padding). Without an attention mask every
        token counts as attended and the last position of the sequence is used instead.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            # Only the word-embedding lookup: `self.bert(inputs_embeds=...)` below adds the position and
            # token-type embeddings and applies LayerNorm/dropout itself. Running the complete embedding
            # module here would apply all of that twice.
            inputs_embeds = self.bert.get_input_embeddings()(input_ids)
        if graph_embeddings is not None and len(graph_embeddings) > 0:
            device = inputs_embeds.device
            if attention_mask is not None:
                # Index of the last attended token per row; cast so the scatter index is int64
                # even for bool/float masks.
                last_token = attention_mask.to(device).long().sum(dim=1) - 1
            else:
                last_token = torch.full(
                    (inputs_embeds.size(0),),
                    inputs_embeds.size(1) - 1,
                    dtype=torch.long,
                    device=device,
                )
            # Offsets are created on the embeddings' device to avoid CPU/GPU mismatches.
            offsets = torch.tensor([3, 1], dtype=torch.long, device=device)
            positions = last_token.unsqueeze(1) - offsets  # (batch, 2)
            # scatter needs one index per hidden unit: (batch, 2, hidden_size)
            index = positions.unsqueeze(2).repeat((1, 1, self.config.hidden_size))
            source = graph_embeddings.to(device=device, dtype=inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.scatter(1, index, source)
        outputs = self.bert(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and dtype when it is not configured.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            # Tuple output: (loss?, logits, *extra encoder outputs)
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
# Load the embedding-inserting classifier and the tokenizer from the same tiny BERT checkpoint.
model = InsertEmbeddingBertForSequenceClassification.from_pretrained(
    "google/bert_uncased_L-2_H-128_A-2"
)
tokenizer = BertTokenizer.from_pretrained(
    "google/bert_uncased_L-2_H-128_A-2",
    model_max_length=128,
)
# Special tokens of the tokenizer, kept as plain strings for building the demo inputs.
sep_token, pad_token, cls_token = (
    tokenizer.sep_token,
    tokenizer.pad_token,
    tokenizer.cls_token,
)

In [None]:
# Two demo sentences, each followed by "<sep><pad><sep><pad>".
# NOTE(review): the pad tokens presumably act as the placeholders that the model's forward
# pass overwrites with graph embeddings - confirm against the tokenized output.
demo_texts = [
    f"test{sep_token}{pad_token}{sep_token}{pad_token}",
    f"test2{sep_token}{pad_token}{sep_token}{pad_token}",
]
# One pair of random graph embeddings per sentence. The batch dimension is derived from the
# number of sentences instead of being hard-coded, so the two can no longer drift apart.
embeddings = torch.rand(len(demo_texts), 2, model.config.hidden_size)
encodes = tokenizer(demo_texts, return_tensors="pt", padding="max_length", truncation=True)
embeddings.shape, encodes.input_ids.shape, encodes.attention_mask.shape

In [None]:
# Run the demo batch through the model. The module is called directly rather than via
# `.forward` so that registered hooks are executed as well.
model(**encodes, graph_embeddings=embeddings)

In [None]:
# Scratch check of the index construction: start from the per-row values 6 and 7, subtract 3,
# duplicate each value into 2 columns and then into 3 entries along a new last dimension.
mask = (
    (torch.tensor([6, 7]) - 3)
    .unsqueeze(1)
    .repeat((1, 2))
    .unsqueeze(2)
    .repeat((1, 1, 3))
)
mask

In [None]:
# Scratch check of the scatter-based insertion on small random tensors.
emb = torch.rand(2, 7, 3)  # tensor that receives the values
print("input", emb.shape)

# Hand-written index tensor; not used by the scatter call below.
mask_ = torch.tensor(
    [
        [[2, 2, 2], [4, 4, 4]],
        [[3, 3, 3], [5, 5, 5]],
    ]
)

# Index tensor built from the per-row values 6 and 7 with the column offsets 3 and 1,
# repeated 3 times along the last dimension.
mask = (
    (torch.tensor([6, 7]).unsqueeze(1).repeat((1, 2)) - torch.tensor([3, 1]))
    .unsqueeze(2)
    .repeat((1, 1, 3))
)
print("mask", mask.shape)

values = torch.rand(2, 2, 3)  # vectors written into `emb`
print("assign", values.shape)

print("mask", mask)
print("assign", values)
print("input", emb)
# Out-of-place scatter along dimension 1: `emb` itself stays untouched.
print(emb.scatter(1, mask, values))

In [None]:
# Point the scratch names at the vanilla classifier's model and tokenizer function.
# NOTE: this rebinds `model`, which an earlier cell assigned to the
# InsertEmbeddingBertForSequenceClassification instance.
model = vanilla_bert_only_classifier.model
tokenize_function = vanilla_bert_only_classifier.tokenize_function

In [None]:
# Draw one vanilla datapoint with existing=False, tokenize it and classify it.
sample = movie_lens_loader.sample_vanilla_datapoint(existing=False)
sample = tokenize_function(sample, return_pt=True)
# Call the module itself instead of `.forward` so that registered hooks run as well.
output = model(input_ids=sample["input_ids"], attention_mask=sample["attention_mask"])
output

In [None]:
# From here on `tokenize_function` refers to the prompt classifier's tokenizer function
# (the name was bound to the vanilla classifier's function in an earlier cell).
tokenize_function = prompt_bert_only_classifier.tokenize_function

In [None]:
# Draw one prompt datapoint, tokenize it, print the raw sample and run it through `model`.
# NOTE(review): `kgeg_dimension` is spelled differently from the `kge_dimension` keyword used
# everywhere else in this notebook - possibly a typo; confirm against
# MovieLensLoader.sample_prompt_datapoint.
prompt_sample = movie_lens_loader.sample_prompt_datapoint(gnn_trainer.get_embedding, kgeg_dimension=KGE_DIMENSION)
prompt_sample_tokenized = tokenize_function(prompt_sample, return_pt=True)
print(prompt_sample)
# NOTE(review): in top-to-bottom order `model` is vanilla_bert_only_classifier.model at this
# point, and the `extra_features` values are hard-coded - verify that this model accepts
# `extra_features` and that the numbers belong to the sampled datapoint.
model.forward(input_ids = prompt_sample_tokenized["input_ids"], attention_mask=prompt_sample_tokenized["attention_mask"], extra_features=torch.tensor([-1.1454768180847168, 2.0188798904418945]))

In [None]:
# Confusion matrix of the vanilla classifier on the "val" split of its dataset.
vanilla_bert_only_classifier.plot_confusion_matrix(dataset=dataset_vanilla, split = "val")

In [None]:
# Plot the "val" confusion matrix of every prompt classifier together with its matching dataset.
# A plain loop replaces the list comprehension that was used only for its side effects (and
# whose loop variables shadowed the singular notebook variables of the same name).
# NOTE(review): `prompt_bert_only_classifiers` and `datasets_prompt` (plural) are not defined in
# the visible cells - only the singular variables are; make sure they exist before running.
for _classifier, _dataset in zip(prompt_bert_only_classifiers, datasets_prompt):
    _classifier.plot_confusion_matrix(dataset=_dataset, split="val")

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer
import pandas as pd
import numpy as np

In [None]:
# One negative (existing=False) and one positive sample for each input format; the matching
# classifier's tokenizer function is handed to the sampler in every call.
prompt_negative_sample = movie_lens_loader.sample_prompt_datapoint(
    existing=False,
    get_embedding_cb=gnn_trainer.get_embedding,
    tokenize_function=prompt_bert_only_classifier.tokenize_function,
)
# NOTE(review): unlike the negative sample, no `get_embedding_cb` is passed here - confirm
# that relying on the sampler's default is intended.
prompt_positive_sample = movie_lens_loader.sample_prompt_datapoint(
    tokenize_function=prompt_bert_only_classifier.tokenize_function,
)
vanilla_negative_sample = movie_lens_loader.sample_vanilla_datapoint(
    existing=False,
    tokenize_function=vanilla_bert_only_classifier.tokenize_function,
)
vanilla_positive_sample = movie_lens_loader.sample_vanilla_datapoint(
    tokenize_function=vanilla_bert_only_classifier.tokenize_function,
)

In [None]:
# Display the sampled negative prompt datapoint.
prompt_negative_sample

# Current State
The goal here is to plot the attention not only between individual tokens, but also between the embedding part and the non-embedding part of the prompt.

In [None]:
def find_sub_list(sl, l):
    """Return the index at which the sub-list ``sl`` first occurs in ``l``.

    Args:
        sl: The sub-list to look for.
        l: The list that is searched.

    Returns:
        The start index of the first occurrence of ``sl`` in ``l``, ``0`` for an empty
        ``sl`` (an empty sub-list is contained at the very start), or ``None`` when
        ``sl`` does not occur in ``l``.
    """
    sll = len(sl)
    if sll == 0:
        # Previously an empty sub-list raised IndexError on `sl[0]`.
        return 0
    # Only start positions that leave room for the whole sub-list can match.
    for ind in range(len(l) - sll + 1):
        if l[ind:ind + sll] == sl:
            return ind
    return None

def foo(self: PromptBertClassifier, sample: dict, layer = -1):
    """Plot a heatmap of the head-summed attention weights of one encoder layer.

    Args:
        self: Classifier wrapper; its ``model`` and ``tokenizer`` attributes are used.
        sample: Tokenized input holding ``"input_ids"`` and ``"attention_mask"`` tensors.
            Only the first sequence of the batch is plotted.
        layer: Index into the per-layer attention tuple returned by the model
            (negative values count from the end).

    Prints the token list and shows the figure; nothing is returned.
    """
    self.model.eval()
    # Run the inputs on whatever device the model lives on.
    device = next(self.model.parameters()).device
    with torch.no_grad():
        outputs = self.model(
            input_ids=sample["input_ids"].to(device),
            attention_mask=sample["attention_mask"].to(device),
            output_attentions=True,
        )
    attentions = outputs.attentions  # one attention tensor per layer
    # Sum over the head dimension and keep the first sequence only, matching the tokens below.
    # Indexing with [0] (instead of squeeze) also works for batches with more than one sequence,
    # and .cpu() makes the conversion to numpy work for models on an accelerator.
    combined_attention = torch.sum(attentions[layer], dim=1)[0].detach().cpu().numpy()
    layer_index = layer % len(attentions)  # non-negative layer index for the title
    # Tokenize the text to get the token labels
    tokens = self.tokenizer.convert_ids_to_tokens(sample['input_ids'][0])
    print(tokens)
    # TODO: the start positions of the embedding sections are located but not used yet; they are
    # meant for aggregating attention between the embedding and non-embedding parts.
    starting_index_user_embeddings = find_sub_list(['user', 'em', '##bed', '##ding', ':', '['], tokens)
    starting_index_movie_embeddings = find_sub_list(['movie', 'em', '##bed', '##ding', ':', '['], tokens)

    # Plot the combined attention weights
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(combined_attention, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
    # The title reports the layer that was actually plotted instead of a fixed "Layer 1".
    ax.set_title(f'Combined Attention Weights for Layer {layer_index} After Linear Projection')
    ax.set_xlabel('Tokens')
    ax.set_ylabel('Tokens')
    plt.show()
# Plot the attention heatmap for the negative prompt sample.
# NOTE(review): `prompt_bert_only_classifiers` (plural) is not defined in the visible cells -
# confirm it exists, or pass `prompt_bert_only_classifier` instead.
foo(prompt_bert_only_classifiers[0], prompt_negative_sample)