In [None]:
from gnn import GNNTrainer
from movie_lens_loader import MovieLensLoader
from llm import PromptEncoderOnlyClassifier, VanillaEncoderOnlyClassifier

First we create the MovieLensLoader, which downloads the MovieLens dataset (https://files.grouplens.org/datasets/movielens/ml-latest-small.zip) and prepares it for use with both the GNN and the LLM (this takes approximately 30 seconds the first time).

In [None]:

# Loader for the MovieLens data; later cells read its .data, .gnn_train_data,
# .gnn_val_data and .llm_df attributes.
movie_lens_loader = MovieLensLoader()

Next we initialize the GNNTrainer, which needs the complete dataset in order to read the dataset schema. The GNNTrainer is later used to train the GNN on the link-prediction task.

In [None]:
# Built from the complete graph (movie_lens_loader.data) so the trainer can read
# the dataset schema; training on the train split happens in a later cell.
gnn_trainer = GNNTrainer(movie_lens_loader.data)

We then train and validate the model on the link-prediction task. If the model has already been trained, this step can be skipped.

In [None]:
# Set TRAIN_GNN = True to (re)train and validate the GNN on the link-prediction
# task. It defaults to False so this step is skipped when the model has already
# been trained; replaces a cell that consisted only of commented-out code.
TRAIN_GNN = False
# Second argument to train_model; presumably the number of epochs — TODO confirm.
GNN_EPOCHS = 10
if TRAIN_GNN:
    gnn_trainer.train_model(movie_lens_loader.gnn_train_data, GNN_EPOCHS)
    gnn_trainer.validate_model(movie_lens_loader.gnn_val_data)

Next we produce the user embedding and the movie embedding for every edge in the dataset. These embeddings are then used by the LLM for the link-prediction task. This step can be skipped if it has already been done once.

In [None]:
# Compute user/movie embeddings for every edge with the trained GNN.
# NOTE(review): `llm_df` is not referenced again in this notebook; later cells
# read `movie_lens_loader.llm_df` instead. Confirm that get_embeddings also
# updates the loader; otherwise this cell's return value is unused.
llm_df = gnn_trainer.get_embeddings(movie_lens_loader)


Next we initialize the vanilla encoder-only classifier. This classifier uses only the natural-language part of the prompt (no graph embeddings) to predict whether the given link exists.

In [None]:
# Text-only baseline classifier (no graph embeddings), constructed from the
# loader's llm_df frame.
vanilla_encoder_only_classifier = VanillaEncoderOnlyClassifier(movie_lens_loader.llm_df)

Next we generate a vanilla LLM dataset and tokenize it for training.

In [None]:
# Build the text-only dataset, tokenized with the vanilla classifier's own
# tokenize_function.
dataset_vanilla = movie_lens_loader.generate_vanilla_dataset(vanilla_encoder_only_classifier.tokenize_function)

Next we train the model on the generated dataset. This step can be skipped if the model has already been trained once.

In [None]:
# Fine-tune the vanilla classifier for 3 epochs; skip this cell if it has
# already been trained.
vanilla_encoder_only_classifier.train_model_on_data(dataset_vanilla, epochs=3)

Next we initialize the prompt encoder-only classifier. This classifier uses both the vanilla prompt and the graph embeddings for its link prediction.

In [None]:
# Classifier that combines the prompt text with graph embeddings. Unlike the
# vanilla classifier it receives the whole loader plus gnn_trainer.get_embedding,
# presumably as a callback for looking up embeddings — TODO confirm.
prompt_encoder_only_classifier = PromptEncoderOnlyClassifier(movie_lens_loader, gnn_trainer.get_embedding)

We also generate a prompt dataset; this time the prompts also include the 2D embeddings of the user and the movie.

In [None]:
# Build the prompt dataset (prompt text plus user/movie embeddings), tokenized
# with the prompt classifier's own tokenize_function.
dataset_prompt = movie_lens_loader.generate_prompt_embedding_dataset(prompt_encoder_only_classifier.tokenize_function)

We also train this model. This step can be skipped if it has already been done once.

In [None]:
# Fine-tune the prompt (text + graph embedding) classifier for 3 epochs on the
# prompt dataset; skip this cell if it has already been trained.
prompt_encoder_only_classifier.train_model_on_data(
    dataset_prompt,
    epochs=3,
)

In [None]:
# Draw one negative (existing=False) and one positive sample per classifier,
# each tokenized with that classifier's own tokenize_function.
# NOTE(review): only the negative prompt sample passes get_embedding_cb;
# presumably positive samples reuse stored embeddings — confirm in
# MovieLensLoader.sample_prompt_datapoint.
prompt_negative_sample = movie_lens_loader.sample_prompt_datapoint(
    existing=False,
    get_embedding_cb=gnn_trainer.get_embedding,
    tokenize_function=prompt_encoder_only_classifier.tokenize_function,
)
prompt_positive_sample = movie_lens_loader.sample_prompt_datapoint(
    tokenize_function=prompt_encoder_only_classifier.tokenize_function,
)
vanilla_negative_sample = movie_lens_loader.sample_vanilla_datapoint(
    existing=False,
    tokenize_function=vanilla_encoder_only_classifier.tokenize_function,
)
vanilla_positive_sample = movie_lens_loader.sample_vanilla_datapoint(
    tokenize_function=vanilla_encoder_only_classifier.tokenize_function,
)

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer

In [None]:
# Stand-alone tokenizer, used further down to count the tokens of the example
# prompt `test`. The checkpoint name and model_max_length are named here so
# they are easy to change.
TOKENIZER_CHECKPOINT = "google/bert_uncased_L-2_H-128_A-2"
TOKENIZER_MAX_LENGTH = 256
tokenizer = BertTokenizer.from_pretrained(
    TOKENIZER_CHECKPOINT,
    model_max_length=TOKENIZER_MAX_LENGTH,
)

In [None]:
# Example prompt text (user 0 / "Toy Story (1995)") followed by two 64-value
# embedding lists, used further down to count the prompt's tokens.
# Adjacent string literals are concatenated at compile time, so `test` is still
# one line of text. Splitting it this way avoids a physical line break inside a
# quoted string literal (a SyntaxError) and keeps the cell readable.
# NOTE(review): the ",,," separator between the two embeddings is kept verbatim;
# confirm it matches the real prompt template used by the classifiers.
test = (
    "user: 0, title: Toy Story (1995), genres: "
    "['Adventure', 'Animation', 'Children', 'Comedy', 'Fantasy'],"
    "[0.09566975384950638, 0.1871771365404129, -0.9063614010810852, -0.11498883366584778, "
    "-0.19359524548053741, 0.05777040123939514, 0.8392307758331299, 0.11021402478218079, "
    "-0.9423925876617432, -0.5996741652488708, -0.4607434868812561, 0.14498648047447205, "
    "0.24977082014083862, -0.42377233505249023, -0.30711787939071655, 0.25683215260505676, "
    "-0.44503113627433777, 0.36742305755615234, -0.07619776576757431, -0.4392299950122833, "
    "-0.5404568910598755, 0.5816857218742371, 0.17842932045459747, 0.4835047721862793, "
    "0.42097967863082886, -0.08676570653915405, 0.7208966612815857, 0.01506001502275467, "
    "0.7144961357116699, 0.8553838729858398, 0.007255390286445618, 0.03739391267299652, "
    "-0.4714123010635376, 0.5460798740386963, -0.40846264362335205, 0.23547154664993286, "
    "0.2240963578224182, -0.16536155343055725, -0.2971140146255493, -0.138551726937294, "
    "-0.07566192001104355, 0.15557360649108887, -0.44871455430984497, -0.03367704525589943, "
    "-0.6841654777526855, 0.49103862047195435, -0.07477638870477676, -0.480299711227417, "
    "-0.32982349395751953, 0.34012627601623535, 0.16804444789886475, -0.983254075050354, "
    "0.35996460914611816, -0.46269819140434265, -0.7246091365814209, -1.1533417701721191, "
    "-0.45359450578689575, 0.5924228429794312, -0.8475757241249084, 0.47164255380630493, "
    "0.29910025000572205, -0.386064350605011, -0.656150758266449, 0.0900864452123642]"
    ",,,"
    "[0.228755921125412, 0.3039299249649048, -0.3139677047729492, -0.12748712301254272, "
    "-0.049015939235687256, -0.5172000527381897, 0.31521815061569214, 0.1052461564540863, "
    "-0.41933900117874146, -0.5138099789619446, -0.08881375193595886, 0.5101838707923889, "
    "0.5606170892715454, -0.05260148644447327, -0.24679863452911377, -0.1766817569732666, "
    "-0.08767355978488922, 0.5131239295005798, 0.7313504219055176, -0.4598398208618164, "
    "-0.126076340675354, 0.10454896092414856, 0.11541566252708435, 0.4154021143913269, "
    "-0.12601269781589508, -0.1887604296207428, 0.07926511764526367, 0.1717890202999115, "
    "0.3770367205142975, -0.1629912257194519, 0.07116822898387909, -0.02958083152770996, "
    "-0.1266523003578186, -0.23133450746536255, -0.20251916348934174, 0.19798435270786285, "
    "-0.011754646897315979, 0.04234490543603897, -0.13728711009025574, -0.1959775686264038, "
    "-0.04987315833568573, 0.03847554326057434, -0.18482638895511627, 0.02168307825922966, "
    "-0.4490634500980377, -0.12788552045822144, 0.011244302615523338, -0.09192948043346405, "
    "0.3016648590564728, 0.2745092809200287, -0.22431962192058563, -0.26696911454200745, "
    "0.07516489923000336, -0.12599758803844452, -0.30530601739883423, -0.15176814794540405, "
    "-0.21515658497810364, 0.06336072087287903, -0.2754612863063812, 0.1787707507610321, "
    "-0.06976184248924255, -0.32254263758659363, 0.16390198469161987, 0.37581944465637207]"
)

In [None]:
# Tuple of two 64-value lists: the same two embedding vectors that appear inside
# the `test` prompt string (the one after the genres list and the one after ",,,").
# Only two samples, not a per-row embedding matrix.
dats = (
    [
        0.09566975384950638, 0.1871771365404129, -0.9063614010810852, -0.11498883366584778,
        -0.19359524548053741, 0.05777040123939514, 0.8392307758331299, 0.11021402478218079,
        -0.9423925876617432, -0.5996741652488708, -0.4607434868812561, 0.14498648047447205,
        0.24977082014083862, -0.42377233505249023, -0.30711787939071655, 0.25683215260505676,
        -0.44503113627433777, 0.36742305755615234, -0.07619776576757431, -0.4392299950122833,
        -0.5404568910598755, 0.5816857218742371, 0.17842932045459747, 0.4835047721862793,
        0.42097967863082886, -0.08676570653915405, 0.7208966612815857, 0.01506001502275467,
        0.7144961357116699, 0.8553838729858398, 0.007255390286445618, 0.03739391267299652,
        -0.4714123010635376, 0.5460798740386963, -0.40846264362335205, 0.23547154664993286,
        0.2240963578224182, -0.16536155343055725, -0.2971140146255493, -0.138551726937294,
        -0.07566192001104355, 0.15557360649108887, -0.44871455430984497, -0.03367704525589943,
        -0.6841654777526855, 0.49103862047195435, -0.07477638870477676, -0.480299711227417,
        -0.32982349395751953, 0.34012627601623535, 0.16804444789886475, -0.983254075050354,
        0.35996460914611816, -0.46269819140434265, -0.7246091365814209, -1.1533417701721191,
        -0.45359450578689575, 0.5924228429794312, -0.8475757241249084, 0.47164255380630493,
        0.29910025000572205, -0.386064350605011, -0.656150758266449, 0.0900864452123642,
    ],
    [
        0.228755921125412, 0.3039299249649048, -0.3139677047729492, -0.12748712301254272,
        -0.049015939235687256, -0.5172000527381897, 0.31521815061569214, 0.1052461564540863,
        -0.41933900117874146, -0.5138099789619446, -0.08881375193595886, 0.5101838707923889,
        0.5606170892715454, -0.05260148644447327, -0.24679863452911377, -0.1766817569732666,
        -0.08767355978488922, 0.5131239295005798, 0.7313504219055176, -0.4598398208618164,
        -0.126076340675354, 0.10454896092414856, 0.11541566252708435, 0.4154021143913269,
        -0.12601269781589508, -0.1887604296207428, 0.07926511764526367, 0.1717890202999115,
        0.3770367205142975, -0.1629912257194519, 0.07116822898387909, -0.02958083152770996,
        -0.1266523003578186, -0.23133450746536255, -0.20251916348934174, 0.19798435270786285,
        -0.011754646897315979, 0.04234490543603897, -0.13728711009025574, -0.1959775686264038,
        -0.04987315833568573, 0.03847554326057434, -0.18482638895511627, 0.02168307825922966,
        -0.4490634500980377, -0.12788552045822144, 0.011244302615523338, -0.09192948043346405,
        0.3016648590564728, 0.2745092809200287, -0.22431962192058563, -0.26696911454200745,
        0.07516489923000336, -0.12599758803844452, -0.30530601739883423, -0.15176814794540405,
        -0.21515658497810364, 0.06336072087287903, -0.2754612863063812, 0.1787707507610321,
        -0.06976184248924255, -0.32254263758659363, 0.16390198469161987, 0.37581944465637207,
    ],
)

In [None]:
import ast

In [None]:
# Parse every row's "user_embedding" cell (stored as text, hence literal_eval)
# back into a Python object, giving one parsed embedding per llm_df row.
embeddings = movie_lens_loader.llm_df["user_embedding"].map(ast.literal_eval).tolist()

In [None]:
from sklearn.decomposition import PCA

# Reduce `embeddings` (one parsed user embedding per llm_df row; presumably
# 64 values each) to its first principal component: one scalar per row.
# Fix: this previously called fit_transform(dats), i.e. it fit on only the two
# hand-copied sample vectors. That yields 2 rows, which cannot be stored as an
# llm_df column in the next cell.
pca = PCA(n_components=1)  # one component -> one scalar per row
compressed_embeddings = pca.fit_transform(embeddings)  # shape (n_rows, 1)
# Preview the first few projected values instead of dumping every row.
compressed_embeddings[:5, 0].tolist()

In [None]:
# Store the PCA projection as a new llm_df column. This mutates the loader's
# frame in place, so anything holding a reference to movie_lens_loader.llm_df
# (possibly the vanilla classifier, which was constructed from it) sees the new
# column too. NOTE(review): the right-hand side must have one row per llm_df row.
movie_lens_loader.llm_df["compressed_embeddings"] = compressed_embeddings

In [None]:
# Display the compressed_embeddings column of llm_df.
movie_lens_loader.llm_df.loc[:, "compressed_embeddings"]

In [None]:
# Number of token ids the stand-alone tokenizer produces for the example prompt
# `test`. NOTE(review): presumably compared against the model_max_length=256
# given to the tokenizer, to see whether such a prompt fits — confirm.
len(tokenizer.encode(test))

In [None]:
def find_sub_list(sl, l):
    """Return the index of the first occurrence of sequence `sl` inside `l`.

    Args:
        sl: the sub-sequence to look for (e.g. a list of tokens).
        l: the sequence to search in.

    Returns:
        The start index of the first match; 0 if `sl` is empty (like
        ``str.find("")``); None if `sl` does not occur in `l`.
    """
    if not sl:
        # Previously `sl[0]` raised IndexError for an empty pattern.
        return 0
    width = len(sl)
    first = sl[0]
    # Only start positions that leave room for the whole pattern can match.
    for start in range(len(l) - width + 1):
        if l[start] == first and l[start:start + width] == sl:
            return start
    return None

def foo(self: PromptEncoderOnlyClassifier, sample: dict, layer=-1) -> None:
    """Plot one layer's self-attention, summed over heads, for a single tokenized sample.

    Also prints where the "user embedding" and "movie embedding" labels start in
    the token sequence (None when a label is not found).

    Args:
        self: classifier providing ``.model`` (called with
            ``output_attentions=True``) and ``.tokenizer``.
        sample: tokenized datapoint with "input_ids" and "attention_mask".
            ``sample["input_ids"][0]`` is treated as the token sequence, so a
            leading batch dimension of size 1 is assumed — TODO confirm.
        layer: index into ``outputs.attentions``; -1 (default) is the last layer.
    """
    self.model.eval()
    with torch.no_grad():
        outputs = self.model(
            input_ids=sample["input_ids"],
            attention_mask=sample["attention_mask"],
            output_attentions=True,
        )
    attentions = outputs.attentions  # one attention tensor per layer
    # Sum over dim=1 (presumably the head axis of (batch, heads, seq, seq) —
    # TODO confirm), then squeeze away size-1 axes. .cpu() lets .numpy() work
    # when the model runs on a GPU; it is a no-op on CPU.
    combined_attention = torch.sum(attentions[layer], dim=1).squeeze().detach().cpu().numpy()
    # Normalise a negative index (e.g. -1) so the plot title shows the real layer.
    layer_index = layer % len(attentions)
    tokens = self.tokenizer.convert_ids_to_tokens(sample["input_ids"][0])

    # Fix: the movie lookup previously searched for the "user" label as well.
    # NOTE(review): assumes the prompt labels are "user embedding" / "movie
    # embedding" and tokenize as written here — confirm against the prompt template.
    user_embedding_start = find_sub_list(["user", "em", "##bed", "##ding"], tokens)
    movie_embedding_start = find_sub_list(["movie", "em", "##bed", "##ding"], tokens)
    # Replaces a bare `print` expression, which printed nothing.
    print(
        f"user embedding label starts at token {user_embedding_start}, "
        f"movie embedding label starts at token {movie_embedding_start}"
    )

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(combined_attention, xticklabels=tokens, yticklabels=tokens, cmap="viridis", ax=ax)
    # The title previously hard-coded "Layer 1" regardless of the `layer` argument.
    ax.set_title(f"Combined Attention Weights for Layer {layer_index} After Linear Projection")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Tokens")
    plt.show()


foo(prompt_encoder_only_classifier, prompt_negative_sample)

In [None]:
# Confusion matrix of the prompt (text + graph embedding) classifier on
# dataset_prompt; which split is evaluated is decided inside plot_confusion_matrix.
prompt_encoder_only_classifier.plot_confusion_matrix(dataset=dataset_prompt)

In [None]:
# Confusion matrix of the text-only vanilla classifier on dataset_vanilla, for
# comparison with the prompt classifier above.
vanilla_encoder_only_classifier.plot_confusion_matrix(dataset=dataset_vanilla)