In [None]:
from transformers import BertConfig
from gnn import GNNTrainer
from movie_lens_loader import MovieLensLoader
from llm import PromptBertClassifier, VanillaBertClassifier, AddingEmbeddingsBertClassifierBase,mean_over_ranges, avg_over_states

import itertools

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from torch.utils.data import DataLoader
import networkx as nx
import os
import pandas as pd
import ast

In [None]:
# Setup: load the small BERT config, compute GNN embeddings at two sizes, and build one
# classifier + dataset per variant (plain text, embeddings written into the prompt,
# embeddings added to the input).
config = BertConfig.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
# Small GNN embedding size used when embeddings are serialized into the prompt text.
KGE_DIMENSION_PROMPT = 4
# NOTE(review): the adding classifier/dataset below pass config.hidden_size as kge_dimension,
# so this presumably has to equal the BERT hidden size (128 for this "H-128" checkpoint) — confirm.
KGE_DIMENSION_ADDING = 128
KGE_DIMENSIONS = [KGE_DIMENSION_PROMPT, KGE_DIMENSION_ADDING] # Output Dimension of the GNN Encoder.
model_max_length = 256
movie_lens_loader = MovieLensLoader(kge_dimensions = KGE_DIMENSIONS)
# One GNN trainer per embedding size; get_embeddings is called with the loader
# (presumably to attach the embeddings to its data — confirm in the gnn module).
gnn_trainer_prompt =    GNNTrainer(movie_lens_loader.data, kge_dimension = KGE_DIMENSION_PROMPT)
gnn_trainer_prompt.get_embeddings(movie_lens_loader)
gnn_trainer_adding =    GNNTrainer(movie_lens_loader.data, hidden_channels=KGE_DIMENSION_ADDING, kge_dimension = KGE_DIMENSION_ADDING)
gnn_trainer_adding.get_embeddings(movie_lens_loader)
# Variant 1: text-only BERT classifier.
vanilla_bert_only_classifier = VanillaBertClassifier(movie_lens_loader.llm_df,model_max_length = model_max_length)
dataset_vanilla = movie_lens_loader.generate_vanilla_dataset(vanilla_bert_only_classifier.tokenize_function)
# Variant 2: GNN embeddings (KGE_DIMENSION_PROMPT-dim) included in the prompt.
prompt_bert_only_classifier = PromptBertClassifier(movie_lens_loader, gnn_trainer_prompt.get_embedding, kge_dimension=gnn_trainer_prompt.kge_dimension, batch_size=64,model_max_length = model_max_length)
dataset_prompt = movie_lens_loader.generate_prompt_embedding_dataset(prompt_bert_only_classifier.tokenize_function, kge_dimension = prompt_bert_only_classifier.kge_dimension)
# Variant 3: GNN embeddings of size config.hidden_size passed to the model separately
# (as `graph_embeddings`); the tokenizer's sep/pad tokens are handed to the dataset builder.
adding_embedding_bert_only_classifier = AddingEmbeddingsBertClassifierBase(movie_lens_loader, gnn_trainer_adding.get_embedding, kge_dimension=config.hidden_size, batch_size=64,model_max_length = model_max_length)
dataset_adding_embedding = movie_lens_loader.generate_adding_embedding_dataset(adding_embedding_bert_only_classifier.tokenizer.sep_token, adding_embedding_bert_only_classifier.tokenizer.pad_token, adding_embedding_bert_only_classifier.tokenize_function, kge_dimension = config.hidden_size)

In [None]:
# Show user_embedding_4 entries that contain an "e" — presumably floats serialized in
# scientific notation (e.g. 1e-05), which may tokenize differently — confirm.
movie_lens_loader.llm_df["user_embedding_4"][movie_lens_loader.llm_df["user_embedding_4"].apply(lambda emb: "e"  in emb)]

In [None]:
# Sanity check: every serialized user embedding must parse with ast.literal_eval
# (a malformed entry raises here).
movie_lens_loader.llm_df["user_embedding_4"].apply(lambda ebs: ast.literal_eval(ebs))

In [None]:
# Forward each dataset through its classifier using the classifiers' own
# forward_dataset_and_save_outputs method from the llm module.
# NOTE(review): assumes that method returns a (hidden_states, attentions) pair — confirm.
# prompt_hidden_states / prompt_attentions are reassigned later by the standalone
# forward_dataset_and_save_outputs_prompt call.
vanilla_hidden_states, vanilla_attentions = vanilla_bert_only_classifier.forward_dataset_and_save_outputs(dataset_vanilla)
prompt_hidden_states, prompt_attentions = prompt_bert_only_classifier.forward_dataset_and_save_outputs(dataset_prompt)
adding_embedding_hidden_states, adding_embedding_attentions = adding_embedding_bert_only_classifier.forward_dataset_and_save_outputs(dataset_adding_embedding)

In [None]:
def get_tokens_as_df_vanilla(self: VanillaBertClassifier, input_ids, all_ranges_over_batch) -> pd.DataFrame:
    """Decode the user-id, title and genres spans of each vanilla prompt into a DataFrame.

    Args:
        self: classifier whose ``tokenizer.batch_decode`` is used to turn token ids back into text.
        input_ids: token-id tensor gathered along dim 1; assumed to be
            (num_samples, seq_len) — TODO confirm.
        all_ranges_over_batch: tensor where ``[:, pos, 0]`` is the start and ``[:, pos, 1]``
            the (exclusive) end of semantic span ``pos``:
            0 = user id, 1 = title, 2 = genres.

    Returns:
        DataFrame with columns ``user_id`` (int), ``title`` (str) and ``genres``
        (parsed with ``ast.literal_eval``), one row per sample.
    """
    # (output column, label prefix stripped from the decoded span), in span order.
    semantic_fields = [
        ("user_id", "user : "),
        ("title", "title : "),
        ("genres", "genres : "),
    ]
    starts = all_ranges_over_batch[:, :, 0]
    ends = all_ranges_over_batch[:, :, 1]
    # Pad every span to the longest one. Build the offsets on the ranges' device so the
    # broadcasted addition also works when the ranges are not on the CPU.
    max_length = (ends - starts).max()
    offsets = torch.arange(max_length, device=starts.device).unsqueeze(0)

    columns = {}
    for pos, (column, prefix) in enumerate(semantic_fields):
        positions = starts[:, pos].unsqueeze(1) + offsets
        # Positions past the span end are redirected to index 0. That token is presumably
        # [CLS] and is dropped by skip_special_tokens — confirm.
        positions = positions * (positions < ends[:, pos].unsqueeze(1))
        gathered = input_ids.gather(dim=1, index=positions)
        decoded = self.tokenizer.batch_decode(gathered, skip_special_tokens=True)
        columns[column] = [text[len(prefix):] for text in decoded]

    columns["user_id"] = [int(user_id) for user_id in columns["user_id"]]
    columns["genres"] = [ast.literal_eval(genres) for genres in columns["genres"]]
    return pd.DataFrame(columns)

def get_tokens_as_df_prompt(self: PromptBertClassifier, input_ids, all_ranges_over_batch) -> tuple[pd.DataFrame, torch.Tensor]:
    """Decode the semantic spans of each embedding prompt into a DataFrame plus an embedding tensor.

    Args:
        self: classifier whose ``tokenizer.batch_decode`` is used to turn token ids back into text.
        input_ids: token-id tensor gathered along dim 1; assumed to be
            (num_samples, seq_len) — TODO confirm.
        all_ranges_over_batch: tensor where ``[:, pos, 0]`` is the start and ``[:, pos, 1]``
            the (exclusive) end of semantic span ``pos``:
            0 = user id, 1 = title, 2 = genres, 3 = user embedding, 4 = movie embedding.

    Returns:
        ``(df, graph_embeddings)``:
        - ``df`` has columns ``user_id``, ``title``, ``genres``, ``user_embeddings``
          and ``movie_embeddings``.
        - ``graph_embeddings`` stacks the parsed user and movie embeddings along a new
          leading axis: index 0 = users, 1 = movies.
    """
    # (output column, label prefix stripped from the decoded span), in span order.
    # NOTE(review): only the prefix *length* is used. "user_embeddings :" and
    # "movie_embeddings :" have the same lengths as the decoded "user embedding : " /
    # "movie embedding : " labels the prompts appear to use — confirm against the prompt builder.
    semantic_fields = [
        ("user_id", "user : "),
        ("title", "title : "),
        ("genres", "genres : "),
        ("user_embeddings", "user_embeddings :"),
        ("movie_embeddings", "movie_embeddings :"),
    ]
    starts = all_ranges_over_batch[:, :, 0]
    ends = all_ranges_over_batch[:, :, 1]
    # Pad every span to the longest one. Build the offsets on the ranges' device so the
    # broadcasted addition also works when the ranges are not on the CPU.
    max_length = (ends - starts).max()
    offsets = torch.arange(max_length, device=starts.device).unsqueeze(0)

    columns = {}
    for pos, (column, prefix) in enumerate(semantic_fields):
        positions = starts[:, pos].unsqueeze(1) + offsets
        # Positions past the span end are redirected to index 0. That token is presumably
        # [CLS] and is dropped by skip_special_tokens — confirm.
        positions = positions * (positions < ends[:, pos].unsqueeze(1))
        gathered = input_ids.gather(dim=1, index=positions)
        decoded = self.tokenizer.batch_decode(gathered, skip_special_tokens=True)
        columns[column] = [text[len(prefix):] for text in decoded]

    columns["user_id"] = [int(user_id) for user_id in columns["user_id"]]
    columns["genres"] = [ast.literal_eval(genres) for genres in columns["genres"]]
    # The tokenizer splits numbers (e.g. "- 0. 12"), so remove all spaces before parsing the lists.
    for column in ("user_embeddings", "movie_embeddings"):
        columns[column] = [ast.literal_eval(text.replace(" ", "")) for text in columns[column]]

    graph_embeddings = torch.stack([
        torch.tensor(columns["user_embeddings"]),
        torch.tensor(columns["movie_embeddings"]),
    ])
    return pd.DataFrame(columns), graph_embeddings

def forward_dataset_and_save_outputs(self: VanillaBertClassifier, dataset, splits = ["val"], batch_size = 64, epochs = 3, force_recompute = False):
    """Forward sample batches through the vanilla classifier and cache hidden states, attentions and tokens.

    For each split, a fresh un-shuffled DataLoader is built ``epochs`` times and only its
    first batch is forwarded. The outputs are averaged per semantic range with
    ``avg_over_states`` and saved to ``self.attentions_path``, ``self.hidden_states_path``
    and ``self.tokens_path``. If all three files exist and ``force_recompute`` is False,
    the cached values are loaded instead.

    Args:
        dataset: mapping from split name to a dataset accepted by ``DataLoader``.
        splits: split names to sample from (not mutated).
        batch_size: samples per forwarded batch.
        epochs: how many times the first batch of each split is forwarded.
        force_recompute: ignore any cached files.

    Returns:
        ``(averaged_hidden_states, averaged_attentions, all_tokens)``, where ``all_tokens``
        holds the decoded tokens plus ``labels`` and ``split`` columns.
    """
    cache_paths = (self.attentions_path, self.hidden_states_path, self.tokens_path)
    if not force_recompute and all(os.path.exists(path) for path in cache_paths):
        averaged_hidden_states = torch.load(self.hidden_states_path)
        averaged_attentions = torch.load(self.attentions_path)
        all_tokens = pd.read_csv(self.tokens_path)
        return averaged_hidden_states, averaged_attentions, all_tokens

    self.model.eval()
    last_hidden_states = []
    all_attentions = []
    all_ranges_over_batch = []
    input_ids = []
    labels = []
    splits_ = []
    for split in splits:
        data_collator = self._get_data_collator(split)
        for _ in range(epochs):
            # NOTE(review): DataLoader does not shuffle by default, so this takes the same first
            # batch every time; repetitions differ only if the collator is stochastic.
            data_loader = DataLoader(dataset=dataset[split], batch_size=batch_size, collate_fn=data_collator)
            batch = next(iter(data_loader))
            with torch.no_grad():
                outputs = self.model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], output_hidden_states=True, output_attentions=True)
                input_ids.append(batch["input_ids"])
                labels.extend(batch["labels"])
                # Use the real batch length: the batch is shorter than batch_size
                # when the split has fewer rows.
                splits_.extend([split] * len(batch["input_ids"]))
                last_hidden_states.append(outputs.hidden_states[-1])
                all_ranges_over_batch.append(self._get_ranges_over_batch(batch["input_ids"]))
                # Sum over the heads axis (dim=1) of each layer's attention, then stack the layers
                # into a leading axis: (layers, batch, seq, seq).
                all_attentions.append(torch.stack([torch.sum(attention, dim=1).detach() for attention in outputs.attentions]))

    last_hidden_states = torch.cat(last_hidden_states)
    all_ranges_over_batch = torch.cat(all_ranges_over_batch)
    input_ids = torch.cat(input_ids)
    labels = torch.stack(labels).tolist()
    # Move the layer axis last -> (batch, seq, seq, layers). A reshape to that shape would only
    # reinterpret memory and mix layers and positions together.
    all_attentions = torch.cat([batch_attentions.permute(1, 2, 3, 0) for batch_attentions in all_attentions])
    averaged_hidden_states, averaged_attentions = avg_over_states(all_ranges_over_batch, last_hidden_states, all_attentions)
    # Vanilla prompts only contain user id / title / genres spans, so use the vanilla decoder
    # (the prompt decoder returns a tuple and expects embedding spans).
    all_tokens = get_tokens_as_df_vanilla(self, input_ids, all_ranges_over_batch)
    all_tokens["labels"] = labels
    all_tokens["split"] = splits_

    torch.save(averaged_attentions, self.attentions_path)
    torch.save(averaged_hidden_states, self.hidden_states_path)
    all_tokens.to_csv(self.tokens_path, index=False)
    return averaged_hidden_states, averaged_attentions, all_tokens

def forward_dataset_and_save_outputs_prompt(self: PromptBertClassifier, dataset, splits = ["val"], batch_size = 64, epochs = 3, force_recompute = False):
    """Forward sample batches through the prompt classifier and cache its outputs.

    For each split, a fresh un-shuffled DataLoader is built ``epochs`` times and only its
    first batch is forwarded. Hidden states and attentions are averaged per semantic range
    with ``avg_over_states``. Tokens (including the embeddings written into the prompt) are
    decoded with ``get_tokens_as_df_prompt``. Everything is saved to ``self.attentions_path``,
    ``self.hidden_states_path``, ``self.graph_embeddings_path`` and ``self.tokens_path``.
    If all four files exist and ``force_recompute`` is False, the cached values are loaded instead.

    Args:
        dataset: mapping from split name to a dataset accepted by ``DataLoader``.
        splits: split names to sample from (not mutated).
        batch_size: samples per forwarded batch.
        epochs: how many times the first batch of each split is forwarded.
        force_recompute: ignore any cached files.

    Returns:
        ``(averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens)``.
    """
    cache_paths = (self.attentions_path, self.hidden_states_path, self.tokens_path, self.graph_embeddings_path)
    if not force_recompute and all(os.path.exists(path) for path in cache_paths):
        averaged_hidden_states = torch.load(self.hidden_states_path)
        averaged_attentions = torch.load(self.attentions_path)
        # Load from the same path the tensor is saved to below.
        graph_embeddings = torch.load(self.graph_embeddings_path)
        all_tokens = pd.read_csv(self.tokens_path)
        return averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens

    self.model.eval()
    last_hidden_states = []
    all_attentions = []
    all_ranges_over_batch = []
    input_ids = []
    labels = []
    splits_ = []
    for split in splits:
        data_collator = self._get_data_collator(split)
        for _ in range(epochs):
            # NOTE(review): DataLoader does not shuffle by default, so this takes the same first
            # batch every time; repetitions differ only if the collator is stochastic.
            data_loader = DataLoader(dataset=dataset[split], batch_size=batch_size, collate_fn=data_collator)
            batch = next(iter(data_loader))
            with torch.no_grad():
                outputs = self.model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], output_hidden_states=True, output_attentions=True)
                input_ids.append(batch["input_ids"])
                labels.extend(batch["labels"])
                # Use the real batch length: the batch is shorter than batch_size
                # when the split has fewer rows.
                splits_.extend([split] * len(batch["input_ids"]))
                last_hidden_states.append(outputs.hidden_states[-1])
                all_ranges_over_batch.append(self._get_ranges_over_batch(batch["input_ids"]))
                # Sum over the heads axis (dim=1) of each layer's attention, then stack the layers
                # into a leading axis: (layers, batch, seq, seq).
                all_attentions.append(torch.stack([torch.sum(attention, dim=1).detach() for attention in outputs.attentions]))

    last_hidden_states = torch.cat(last_hidden_states)
    all_ranges_over_batch = torch.cat(all_ranges_over_batch)
    input_ids = torch.cat(input_ids)
    labels = torch.stack(labels).tolist()
    # Move the layer axis last -> (batch, seq, seq, layers). A reshape to that shape would only
    # reinterpret memory and mix layers and positions together.
    all_attentions = torch.cat([batch_attentions.permute(1, 2, 3, 0) for batch_attentions in all_attentions])
    averaged_hidden_states, averaged_attentions = avg_over_states(all_ranges_over_batch, last_hidden_states, all_attentions)
    all_tokens, graph_embeddings = get_tokens_as_df_prompt(self, input_ids, all_ranges_over_batch)
    all_tokens["labels"] = labels
    all_tokens["split"] = splits_

    torch.save(averaged_attentions, self.attentions_path)
    torch.save(averaged_hidden_states, self.hidden_states_path)
    torch.save(graph_embeddings, self.graph_embeddings_path)
    all_tokens.to_csv(self.tokens_path, index=False)
    return averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens


def forward_dataset_and_save_outputs_adding(self: AddingEmbeddingsBertClassifierBase, dataset, splits = ["val"], batch_size = 64, epochs = 3, force_recompute = False):
    """Forward sample batches through the adding-embeddings classifier and cache its outputs.

    For each split, a fresh un-shuffled DataLoader is built ``epochs`` times and only its
    first batch is forwarded. The batch's ``graph_embeddings`` are passed to the model
    alongside the token ids. Hidden states and attentions are averaged per semantic range
    with ``avg_over_states``, and tokens are decoded with ``get_tokens_as_df_vanilla``.
    Everything is saved to ``self.attentions_path``, ``self.hidden_states_path``,
    ``self.graph_embeddings_path`` and ``self.tokens_path``. If all four files exist and
    ``force_recompute`` is False, the cached values are loaded instead.

    Args:
        dataset: mapping from split name to a dataset accepted by ``DataLoader``.
        splits: split names to sample from (not mutated).
        batch_size: samples per forwarded batch.
        epochs: how many times the first batch of each split is forwarded.
        force_recompute: ignore any cached files.

    Returns:
        ``(averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens)``.
    """
    cache_paths = (self.attentions_path, self.hidden_states_path, self.tokens_path, self.graph_embeddings_path)
    if not force_recompute and all(os.path.exists(path) for path in cache_paths):
        averaged_hidden_states = torch.load(self.hidden_states_path)
        averaged_attentions = torch.load(self.attentions_path)
        # Load from the same path the tensor is saved to below.
        graph_embeddings = torch.load(self.graph_embeddings_path)
        all_tokens = pd.read_csv(self.tokens_path)
        return averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens

    self.model.eval()
    last_hidden_states = []
    all_attentions = []
    all_ranges_over_batch = []
    input_ids = []
    labels = []
    graph_embeddings = []
    splits_ = []
    for split in splits:
        data_collator = self._get_data_collator(split)
        for _ in range(epochs):
            # NOTE(review): DataLoader does not shuffle by default, so this takes the same first
            # batch every time; repetitions differ only if the collator is stochastic.
            data_loader = DataLoader(dataset=dataset[split], batch_size=batch_size, collate_fn=data_collator)
            batch = next(iter(data_loader))
            with torch.no_grad():
                outputs = self.model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], graph_embeddings=batch["graph_embeddings"], output_hidden_states=True, output_attentions=True)
                input_ids.append(batch["input_ids"])
                graph_embeddings.append(batch["graph_embeddings"])
                labels.extend(batch["labels"])
                # Use the real batch length: the batch is shorter than batch_size
                # when the split has fewer rows.
                splits_.extend([split] * len(batch["input_ids"]))
                last_hidden_states.append(outputs.hidden_states[-1])
                all_ranges_over_batch.append(self._get_ranges_over_batch(batch["input_ids"]))
                # Sum over the heads axis (dim=1) of each layer's attention, then stack the layers
                # into a leading axis: (layers, batch, seq, seq).
                all_attentions.append(torch.stack([torch.sum(attention, dim=1).detach() for attention in outputs.attentions]))

    graph_embeddings = torch.cat(graph_embeddings)
    last_hidden_states = torch.cat(last_hidden_states)
    all_ranges_over_batch = torch.cat(all_ranges_over_batch)
    input_ids = torch.cat(input_ids)
    labels = torch.stack(labels).tolist()
    # Move the layer axis last -> (batch, seq, seq, layers). A reshape to that shape would only
    # reinterpret memory and mix layers and positions together.
    all_attentions = torch.cat([batch_attentions.permute(1, 2, 3, 0) for batch_attentions in all_attentions])
    averaged_hidden_states, averaged_attentions = avg_over_states(all_ranges_over_batch, last_hidden_states, all_attentions)
    all_tokens = get_tokens_as_df_vanilla(self, input_ids, all_ranges_over_batch)
    all_tokens["labels"] = labels
    all_tokens["split"] = splits_

    torch.save(averaged_attentions, self.attentions_path)
    torch.save(averaged_hidden_states, self.hidden_states_path)
    torch.save(graph_embeddings, self.graph_embeddings_path)
    all_tokens.to_csv(self.tokens_path, index=False)
    return averaged_hidden_states, averaged_attentions, graph_embeddings, all_tokens
# Recompute the prompt classifier's outputs for all three splits with the standalone
# forward_dataset_and_save_outputs_prompt defined above. force_recompute=True bypasses the
# on-disk cache. This overwrites the earlier prompt_hidden_states / prompt_attentions.
prompt_hidden_states, prompt_attentions, graph_embeddings, all_tokens = forward_dataset_and_save_outputs_prompt(prompt_bert_only_classifier, dataset_prompt, splits = ["train","test", "val"], force_recompute=True)
prompt_hidden_states.shape, prompt_attentions.shape


In [None]:
# Show the first training prompt to inspect the prompt layout.
dataset_prompt["train"]["prompt"][0]

In [None]:
# Sanity check: print every test prompt that lacks the expected "][SEP]movie embedding: " separator.
for prompt_text in dataset_prompt["test"]["prompt"]:
    if "][SEP]movie embedding: " in prompt_text:
        continue
    print(prompt_text)

In [None]:
# Decoded tokens with their labels and split, as returned by the standalone prompt helper above.
all_tokens

In [None]:
# Plot attention grouped by semantic range for each variant.
# vanilla_attentions / adding_embedding_attentions come from the classifiers' own methods;
# prompt_attentions comes from the standalone prompt helper run earlier.
vanilla_bert_only_classifier.plot_attention_graph(vanilla_attentions, "Vanilla Attentions Grouped by relevant Semantic Ranges.")
prompt_bert_only_classifier.plot_attention_graph(prompt_attentions, "Prompt Attentions Grouped by relevant Semantic Ranges.")
adding_embedding_bert_only_classifier.plot_attention_graph(adding_embedding_attentions, "Embedding Attentions Grouped by relevant Semantic Ranges.")

In [None]:
# Plot the layer-normalized attention matrices for each variant.
# FIXME: vanilla_attention_matrix_normalized, prompt_attention_matrix_normalized and
# adding_embedding_attention_matrix_normalized are not defined in any cell of this section,
# so this cell raises NameError on a fresh kernel. Compute them in an earlier cell first.
vanilla_bert_only_classifier.plot_attention_graph(vanilla_attention_matrix_normalized, "Vanilla Attentions Grouped by relevant Semantic Ranges and Normalized over Layers.")
prompt_bert_only_classifier.plot_attention_graph(prompt_attention_matrix_normalized, "Prompt Attentions Grouped by relevant Semantic Ranges and Normalized over Layers.")
adding_embedding_bert_only_classifier.plot_attention_graph(adding_embedding_attention_matrix_normalized, "Embedding Attentions Grouped by relevant Semantic Ranges and Normalized over Layers.")