# Training and Evaluation of GNNs and LLMs
In this notebook, we train the models on the [MovieLens dataset](https://movielens.org/), following the PyTorch Geometric tutorial on [Link Prediction](https://colab.research.google.com/drive/1xpzn1Nvai1ygd_P5Yambc_oe4VBPK_ZT?usp=sharing#scrollTo=vit8xKCiXAue).

First we import all of our dependencies.

The **GraphRepresentationGenerator** manages and trains a GNN model. Its most important interfaces are:
- **the constructor**, which defines the GNN architecture and loads a pre-trained GNN model if one has already been saved to disk;
- **the training method**, which trains the GNN model; and
- **the `get_embedding` methods**, which form the inference interface to the GNN model and return embeddings, in the dimension defined in the constructor, for given user-movie node pairs.

**The MovieLensManager** loads and manages the datasets. Its main responsibilities are **saving, (re)loading, and transforming** the datasets.

**PromptBertClassifier** and **VanillaBertClassifier** manage a **prompt LLM** and a **vanilla LLM**, respectively. Each encoder-only classifier (`ClassifierBase`) provides interfaces for training and testing an LLM.
The prompt and vanilla classifiers differ in their data collators. Data collators change the behavior of the models during training and testing and allow data points to be created at runtime. With the help of these collators, we **create non-existent edges on the fly**.
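The idea of creating non-existent edges on the fly can be sketched as follows. This is a minimal illustration under stated assumptions, not the classifiers' actual collator: `sample_false_edges` is a hypothetical helper that keeps every positive (user, movie) pair in a batch with label 1 and, for each one, draws random movies that the same user is not linked to and labels them 0.

```python
import random
from typing import List, Optional, Set, Tuple

Edge = Tuple[int, int]


def sample_false_edges(
    batch: List[Edge],
    existing: Set[Edge],
    num_targets: int,
    ratio: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[Tuple[int, int, int]]:
    """Hypothetical collator step: return (source, target, label) triples.

    Positive edges keep label 1. For each positive edge, up to round(ratio)
    distinct targets that the same source is NOT linked to are drawn and
    labelled 0.
    """
    rng = rng or random.Random()
    samples = [(s, t, 1) for s, t in batch]
    per_edge = round(ratio)
    for s, _ in batch:
        # Candidate targets: every target node without an existing edge from s.
        candidates = [t for t in range(num_targets) if (s, t) not in existing]
        k = min(per_edge, len(candidates))
        samples.extend((s, t, 0) for t in rng.sample(candidates, k))
    return samples


existing_edges = {(0, 0), (0, 1), (1, 2)}
batch = [(0, 0), (1, 2)]
print(sample_false_edges(batch, existing_edges, num_targets=4, rng=random.Random(0)))
```

Drawing from an explicit candidate list with `rng.sample` guarantees that a sampled negative never collides with a true edge and that the negatives drawn for one positive edge are distinct; a real collator would additionally turn these triples into tokenized prompts.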

In [1]:
from graph_representation_generator import GraphRepresentationGenerator
from dataset_manager import (
    MovieLensManager,
    PROMPT_KGE_DIMENSION,
    ATTENTION_KGE_DIMENSION,
)
from llm_manager import (
    PromptBertClassifier,
    VanillaBertClassifier,
    AttentionBertClassifierBase,
)

In [2]:
EPOCHS = 20
BATCH_SIZE = 1024

We define in advance the **Knowledge Graph Embedding dimension (KGE_DIMENSION)** of the GNN encoder (imported above as `PROMPT_KGE_DIMENSION` and `ATTENTION_KGE_DIMENSION`). Our goal is to find the output dimension at which the GNN encoder produces embeddings that significantly improve performance *without exceeding the context length of the LLMs*. In the original tutorial, the KGE dimension was $64$.
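As a back-of-the-envelope check of the context-length constraint, assume a prompt without embeddings already uses $b$ tokens and that each embedding value written as text costs about $t$ tokens; both numbers are assumptions that depend on the tokenizer and on how the numbers are formatted. With one source and one target embedding of dimension $d$ per prompt, the prompt fits into a context of $L$ tokens (the prompt classifier below uses `model_max_length=512`) if $b + 2\,d\,t \le L$. The hypothetical helper `max_kge_dimension` returns the largest such $d$.

```python
def max_kge_dimension(base_tokens: int, tokens_per_value: int, max_length: int = 512) -> int:
    """Largest d with base_tokens + 2 * d * tokens_per_value <= max_length.

    Assumes two embeddings (source and target) per prompt, each value rendered
    as text that costs `tokens_per_value` tokens.
    """
    if tokens_per_value <= 0:
        raise ValueError("tokens_per_value must be positive")
    remaining = max_length - base_tokens
    if remaining < 0:
        return 0  # the text part alone already exceeds the context
    return remaining // (2 * tokens_per_value)


# e.g. 64 tokens of text, ~8 tokens per printed float, 512-token context
print(max_kge_dimension(64, 8))  # 28
```

The estimate only bounds the dimension from above; it says nothing about whether a given dimension actually improves performance, which is what the experiments here measure.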

In [3]:
kg_manager = MovieLensManager()

Using existing file ml-latest-small.zip
Extracting ./data\ml-latest-small.zip


splitting LLM dataset
generate llm dataset...


The `MovieLensManager` created above downloads the MovieLens dataset (https://files.grouplens.org/datasets/movielens/ml-latest-small.zip) and prepares it for use with both the GNN and the LLMs. It is created without arguments; the embedding dimensions we train with are the `PROMPT_KGE_DIMENSION` and `ATTENTION_KGE_DIMENSION` constants imported from `dataset_manager`. The first run takes approximately 30 seconds.

In [4]:
kg_manager.data

HeteroData(
  source={ node_id=[610] },
  target={
    node_id=[9742],
    x=[9742, 20],
  },
  (source, edge, target)={ edge_index=[2, 100836] },
  (target, rev_edge, source)={ edge_index=[2, 100836] }
)
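In the `HeteroData` object above, `edge_index` stores the edges as a $2 \times E$ array (first row: source/user indices, second row: target/movie indices), and `rev_edge` holds the same 100836 edges with the two rows swapped, so that messages can also flow from movies back to users. A minimal plain-Python sketch of that relationship follows; `reverse_edge_index` is a hypothetical helper, and in PyTorch Geometric such reverse edges are typically added by a transform such as `ToUndirected`.

```python
from typing import List


def reverse_edge_index(edge_index: List[List[int]]) -> List[List[int]]:
    """Swap the two rows of a [2, E] edge index (source->target becomes target->source)."""
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError("edge_index rows must have equal length")
    return [list(targets), list(sources)]


# Toy graph: user 0 rated movies 0 and 2, user 1 rated movie 2.
edge_index = [[0, 0, 1], [0, 2, 2]]
rev = reverse_edge_index(edge_index)
print(rev)  # [[0, 2, 2], [0, 0, 1]]
```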

Next, we initialize the GNN trainers (on CUDA when available), one for each KGE dimension.
Each GNN trainer manages a model that consists of an **encoder** and a **classifier**.

**The encoder** is a parameterized *Graph Convolutional Network (GCN)* with a *2-layer GNN computation graph* and a single *ReLU* activation function in between.

**The classifier** computes the dot product between the source and destination KGEs to derive edge-level predictions.
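The classifier step can be sketched in plain Python: for each candidate edge, take the source (user) and target (movie) KGEs and compute their dot product; a sigmoid turns the raw score into a link probability. This is only an illustration of the idea described above; the function names and toy embeddings are hypothetical, and the real model operates on batched tensors.

```python
import math
from typing import List, Sequence


def edge_scores(
    source_emb: Sequence[Sequence[float]],
    target_emb: Sequence[Sequence[float]],
    edge_label_index: Sequence[Sequence[int]],
) -> List[float]:
    """Dot product between the source and target embeddings of each candidate edge."""
    src_idx, dst_idx = edge_label_index
    return [
        sum(a * b for a, b in zip(source_emb[s], target_emb[t]))
        for s, t in zip(src_idx, dst_idx)
    ]


def link_probability(score: float) -> float:
    """Sigmoid: maps a raw dot-product score to the interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-score))


users = [[1.0, 0.0], [0.0, 1.0]]
movies = [[2.0, 1.0], [-1.0, 3.0]]
scores = edge_scores(users, movies, [[0, 1], [0, 1]])  # edges (u0, m0) and (u1, m1)
print(scores)  # [2.0, 3.0]
```

The PyTorch Geometric tutorial trains on the raw scores with `binary_cross_entropy_with_logits`, which applies the sigmoid internally, so the explicit sigmoid is only needed when a probability is wanted.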

In [5]:
graph_representation_generator_prompt = GraphRepresentationGenerator(
    kg_manager.data,
    kg_manager.gnn_train_data,
    kg_manager.gnn_val_data,
    kg_manager.gnn_test_data,
    kge_dimension=PROMPT_KGE_DIMENSION,
)
graph_representation_generator_attention = GraphRepresentationGenerator(
    kg_manager.data,
    kg_manager.gnn_train_data,
    kg_manager.gnn_val_data,
    kg_manager.gnn_test_data,
    hidden_channels=ATTENTION_KGE_DIMENSION,
    kge_dimension=ATTENTION_KGE_DIMENSION,
)

Device: 'cuda'
Device: 'cuda'


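We then train both models on the link prediction task and evaluate each one on the test split (`gnn_test_data`); the reported metric is the ROC AUC, printed as "Validation AUC".

If the models are already trained, this part can be skipped.
Training both models can take up to 5 minutes.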
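The printed AUC is the area under the ROC curve: the probability that a randomly chosen positive edge receives a higher score than a randomly chosen negative edge, with ties counting one half. The dependency-free sketch below only illustrates what the number means; it is not the notebook's implementation (the tutorial uses a library function, scikit-learn's `roc_auc_score`), and `roc_auc` is a hypothetical name.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Pairwise ROC AUC: fraction of (positive, negative) pairs ranked correctly.

    O(P * N), fine for illustration; library versions use a sort-based method.
    """
    pos = [s for label, s in zip(labels, scores) if label == 1]
    neg = [s for label, s in zip(labels, scores) if label == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count half
    return wins / (len(pos) * len(neg))


print(roc_auc([1, 1, 0, 0], [0.9, 0.4, 0.5, 0.1]))  # 0.75
```

Under this reading, an AUC of about 0.92 to 0.93 means that roughly 92 to 93 percent of (true edge, false edge) pairs are ranked in the right order.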

In [6]:
print("Prompt Training")
graph_representation_generator_prompt.train_model(
    kg_manager.gnn_train_data, EPOCHS, BATCH_SIZE
)
graph_representation_generator_prompt.validate_model(kg_manager.gnn_test_data)
print("Attention Training")
graph_representation_generator_attention.train_model(
    kg_manager.gnn_train_data, EPOCHS, BATCH_SIZE
)
graph_representation_generator_attention.validate_model(kg_manager.gnn_test_data)


Prompt Training


100%|██████████| 24/24 [00:01<00:00, 13.88it/s]


Epoch: 001, Loss: 0.6128


100%|██████████| 24/24 [00:01<00:00, 15.54it/s]


Epoch: 002, Loss: 0.4825


100%|██████████| 24/24 [00:01<00:00, 15.14it/s]


Epoch: 003, Loss: 0.3986


100%|██████████| 24/24 [00:01<00:00, 15.17it/s]


Epoch: 004, Loss: 0.3621


100%|██████████| 24/24 [00:01<00:00, 14.20it/s]


Epoch: 005, Loss: 0.3471


100%|██████████| 24/24 [00:01<00:00, 15.15it/s]


Epoch: 006, Loss: 0.3372


100%|██████████| 24/24 [00:01<00:00, 15.31it/s]


Epoch: 007, Loss: 0.3288


100%|██████████| 24/24 [00:01<00:00, 12.79it/s]


Epoch: 008, Loss: 0.3201


100%|██████████| 24/24 [00:02<00:00, 11.97it/s]


Epoch: 009, Loss: 0.3161


100%|██████████| 24/24 [00:01<00:00, 15.39it/s]


Epoch: 010, Loss: 0.3108


100%|██████████| 24/24 [00:01<00:00, 14.18it/s]


Epoch: 011, Loss: 0.3070


100%|██████████| 24/24 [00:01<00:00, 12.11it/s]


Epoch: 012, Loss: 0.3044


100%|██████████| 24/24 [00:01<00:00, 12.91it/s]


Epoch: 013, Loss: 0.3060


100%|██████████| 24/24 [00:01<00:00, 13.41it/s]


Epoch: 014, Loss: 0.3023


100%|██████████| 24/24 [00:01<00:00, 15.65it/s]


Epoch: 015, Loss: 0.2967


100%|██████████| 24/24 [00:01<00:00, 13.90it/s]


Epoch: 016, Loss: 0.2945


100%|██████████| 24/24 [00:01<00:00, 12.34it/s]


Epoch: 017, Loss: 0.2911


100%|██████████| 24/24 [00:01<00:00, 14.26it/s]


Epoch: 018, Loss: 0.2919


100%|██████████| 24/24 [00:01<00:00, 13.99it/s]


Epoch: 019, Loss: 0.2903


100%|██████████| 24/24 [00:02<00:00, 11.49it/s]


Epoch: 020, Loss: 0.2872


100%|██████████| 53/53 [00:03<00:00, 17.39it/s]



Validation AUC: 0.9237
Attention Training


100%|██████████| 24/24 [00:02<00:00, 11.74it/s]


Epoch: 001, Loss: 0.5823


100%|██████████| 24/24 [00:01<00:00, 12.77it/s]


Epoch: 002, Loss: 0.3952


100%|██████████| 24/24 [00:01<00:00, 14.15it/s]


Epoch: 003, Loss: 0.3464


100%|██████████| 24/24 [00:01<00:00, 12.62it/s]


Epoch: 004, Loss: 0.3216


100%|██████████| 24/24 [00:01<00:00, 15.03it/s]


Epoch: 005, Loss: 0.3098


100%|██████████| 24/24 [00:02<00:00, 11.89it/s]


Epoch: 006, Loss: 0.2966


100%|██████████| 24/24 [00:01<00:00, 12.47it/s]


Epoch: 007, Loss: 0.2844


100%|██████████| 24/24 [00:01<00:00, 15.61it/s]


Epoch: 008, Loss: 0.2745


100%|██████████| 24/24 [00:01<00:00, 13.33it/s]


Epoch: 009, Loss: 0.2658


100%|██████████| 24/24 [00:02<00:00, 11.73it/s]


Epoch: 010, Loss: 0.2597


100%|██████████| 24/24 [00:01<00:00, 12.26it/s]


Epoch: 011, Loss: 0.2490


100%|██████████| 24/24 [00:01<00:00, 12.08it/s]


Epoch: 012, Loss: 0.2407


100%|██████████| 24/24 [00:01<00:00, 14.92it/s]


Epoch: 013, Loss: 0.2356


100%|██████████| 24/24 [00:01<00:00, 13.22it/s]


Epoch: 014, Loss: 0.2266


100%|██████████| 24/24 [00:02<00:00, 11.58it/s]


Epoch: 015, Loss: 0.2221


100%|██████████| 24/24 [00:02<00:00, 11.93it/s]


Epoch: 016, Loss: 0.2145


100%|██████████| 24/24 [00:02<00:00, 11.90it/s]


Epoch: 017, Loss: 0.2076


100%|██████████| 24/24 [00:01<00:00, 13.38it/s]


Epoch: 018, Loss: 0.2009


100%|██████████| 24/24 [00:01<00:00, 15.53it/s]


Epoch: 019, Loss: 0.1968


100%|██████████| 24/24 [00:01<00:00, 14.58it/s]


Epoch: 020, Loss: 0.1948


100%|██████████| 53/53 [00:03<00:00, 16.64it/s]


Validation AUC: 0.9297





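Next, we produce the KGEs for every edge in the dataset, that is, one embedding for the source node and one for the target node of each edge. These embeddings are then used by the LLMs for the link prediction task. If the embeddings have already been saved, they are loaded from disk instead of being recomputed.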
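The next cell follows a load-or-compute pattern: it tries `get_saved_embeddings` first and only calls `generate_embeddings` (and saves the result) when nothing is cached. A generic sketch of that pattern, using a hypothetical JSON file cache and a hypothetical `load_or_compute` helper rather than the project's own storage format:

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Tuple

Embeddings = Dict[str, List[float]]


def load_or_compute(path: Path, compute: Callable[[], Embeddings]) -> Tuple[Embeddings, bool]:
    """Return (embeddings, computed). Load from `path` if it exists, else compute and save."""
    if path.exists():
        return json.loads(path.read_text()), False
    embeddings = compute()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(embeddings))
    return embeddings, True


with tempfile.TemporaryDirectory() as tmp:
    cache_file = Path(tmp) / "embeddings" / "prompt.json"
    first, computed_first = load_or_compute(cache_file, lambda: {"0": [0.5, -1.0]})
    # The second call hits the cache, so this compute function is never called.
    second, computed_second = load_or_compute(cache_file, lambda: {"0": [9.9]})
print(computed_first, computed_second)  # True False
```

Note that the notebook regenerates both embedding sets if either one is missing, which keeps the prompt and attention embeddings consistent with each other.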

In [7]:
prompt_embeddings = graph_representation_generator_prompt.get_saved_embeddings("prompt")
attention_embeddings = graph_representation_generator_attention.get_saved_embeddings(
    "attention"
)
save = False
if prompt_embeddings is None or attention_embeddings is None:
    prompt_embeddings = graph_representation_generator_prompt.generate_embeddings(
        kg_manager.llm_df
    )
    attention_embeddings = graph_representation_generator_attention.generate_embeddings(
        kg_manager.llm_df
    )
    save = True

kg_manager.append_prompt_graph_embeddings(prompt_embeddings, save=save)
kg_manager.append_attention_graph_embeddings(attention_embeddings, save=save)


Computing embeddings for embedding dimension 4.


In [None]:
kg_manager.add_false_edges(
    1.0,
    graph_representation_generator_prompt.get_embedding,
    graph_representation_generator_attention.get_embedding,
    splits=["val", "test"],
)

Adding 34284 false edges for val.


KeyboardInterrupt: 

In [None]:
kg_manager.llm_df["labels"].unique()

Next, we initialize the vanilla encoder-only classifier. This classifier uses only the natural-language part of the prompt (no KGEs) to predict whether the given link exists.
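Judging from the `prompt` column shown further below (e.g. `609[SEP]9462[SEP]Get Out (2017)[SEP]['Horror']`), a vanilla prompt joins the source ID, the target ID, the movie title, and the genre list with the `[SEP]` token. A sketch that reproduces this format; the helper name `vanilla_prompt` is hypothetical.

```python
from typing import List


def vanilla_prompt(
    source_id: int, target_id: int, title: str, genres: List[str], sep: str = "[SEP]"
) -> str:
    """Build the text-only prompt: source ID, target ID, title, genres (no KGEs)."""
    return sep.join([str(source_id), str(target_id), title, str(genres)])


print(vanilla_prompt(609, 9462, "Get Out (2017)", ["Horror"]))
# 609[SEP]9462[SEP]Get Out (2017)[SEP]['Horror']
```

Using `str(genres)` reproduces the Python-list formatting of the genres seen in the data.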

In [None]:
vanilla_bert_classifier = VanillaBertClassifier(
    kg_manager.llm_df, kg_manager.source_df, kg_manager.target_df
)

Next, we generate the vanilla LLM dataset and tokenize it for training.

In [None]:
dataset_vanilla = kg_manager.generate_vanilla_dataset(
    vanilla_bert_classifier.tokenize_function
)

Next, we train the model on the generated dataset. This step can be skipped if the model has already been trained.

In [None]:
vanilla_bert_classifier.train_model_on_data(
    dataset_vanilla, epochs=EPOCHS, batch_size=BATCH_SIZE
)

Next, we initialize the prompt encoder-only classifier. This classifier uses both the vanilla prompt and the KGEs for link prediction.

In [None]:
prompt_bert_classifier = PromptBertClassifier(
    kg_manager,
    graph_representation_generator_prompt.get_embedding,
    model_max_length=512,
)

We also generate a prompt dataset; this time, the prompts include the KGEs.
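The exact way the KGEs are written into the prompt is not shown in this notebook. Purely as an assumption, one could append the source and target embeddings, rounded to a few decimals, as two extra `[SEP]`-separated segments; rounding keeps the token count, and therefore the context-length budget, under control. `embedding_prompt` and its `decimals` parameter are hypothetical; the sample values are the first two components of the Toy Story row's prompt embeddings shown further below.

```python
from typing import Sequence


def embedding_prompt(
    base_prompt: str,
    source_embedding: Sequence[float],
    target_embedding: Sequence[float],
    decimals: int = 3,
    sep: str = "[SEP]",
) -> str:
    """Hypothetical: append rounded source and target KGEs to a vanilla prompt."""

    def fmt(vec: Sequence[float]) -> str:
        # Fixed number of decimals so every value costs a similar number of tokens.
        return "[" + ", ".join(f"{v:.{decimals}f}" for v in vec) + "]"

    return sep.join([base_prompt, fmt(source_embedding), fmt(target_embedding)])


print(embedding_prompt("0[SEP]0[SEP]Toy Story (1995)", [-1.0325675, -2.0165305], [0.5134283, -0.3093917]))
# 0[SEP]0[SEP]Toy Story (1995)[SEP][-1.033, -2.017][SEP][0.513, -0.309]
```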

In [None]:
dataset_prompt = kg_manager.generate_prompt_embedding_dataset(
    prompt_bert_classifier.tokenize_function,
)

We then train this model as well. This step can be skipped if it has already been trained.

In [None]:
prompt_bert_classifier.train_model_on_data(
    dataset_prompt, epochs=EPOCHS, batch_size=BATCH_SIZE
)

In [None]:
attention_bert_classifier = AttentionBertClassifierBase(
    kg_manager,
    graph_representation_generator_attention.get_embedding,
)

In [None]:
dataset_embedding = kg_manager.generate_attention_embedding_dataset(
    attention_bert_classifier.tokenizer.sep_token,
    attention_bert_classifier.tokenizer.pad_token,
    attention_bert_classifier.tokenize_function,
)

In [None]:
attention_bert_classifier.train_model_on_data(
    dataset_embedding, epochs=EPOCHS, batch_size=BATCH_SIZE
)

In [None]:
kg_manager.add_false_edges(
    1.0,
    graph_representation_generator_prompt.get_embedding,
    graph_representation_generator_attention.get_embedding,
    splits=["train"],
)

Adding 112938 false edges for train.


Unnamed: 0,source_id,target_id,id_x,id_y,prompt_feature_title,prompt_feature_genres,labels,split,prompt,prompt_source_embedding,prompt_target_embedding,attention_source_embedding,attention_target_embedding
0,0,0,0,0,Toy Story (1995),"['Adventure', 'Animation', 'Children', 'Comedy...",1,train,"0[SEP]0[SEP]Toy Story (1995)[SEP]['Adventure',...","[-1.0325675010681152, -2.0165305137634277, 0.1...","[0.5134283304214478, -0.30939173698425293, -0....","[0.08172091841697693, -0.06714070588350296, -0...","[-0.06923831254243851, -0.3653912842273712, 0...."
1,0,2,0,2,Grumpier Old Men (1995),"['Comedy', 'Romance']",1,train,0[SEP]2[SEP]Grumpier Old Men (1995)[SEP]['Come...,"[-1.0108203887939453, -1.9447051286697388, 0.1...","[0.8926394581794739, 0.005460023880004883, -1....","[0.230760395526886, -0.06058453395962715, -0.1...","[0.08294467628002167, 0.4050610065460205, 0.05..."
2,0,5,0,5,Heat (1995),"['Action', 'Crime', 'Thriller']",1,train,"0[SEP]5[SEP]Heat (1995)[SEP]['Action', 'Crime'...","[-0.8665281534194946, -1.7003031969070435, 0.2...","[0.6732535362243652, -0.9730542898178101, -1.9...","[0.07108935713768005, 0.008010722696781158, -0...","[-0.12152087688446045, 0.18416042625904083, 0...."
3,0,43,0,43,Seven (a.k.a. Se7en) (1995),"['Mystery', 'Thriller']",1,train,0[SEP]43[SEP]Seven (a.k.a. Se7en) (1995)[SEP][...,"[-1.0160322189331055, -2.0640125274658203, 0.1...","[1.2576618194580078, -1.0972590446472168, -1.4...","[-0.006847113370895386, 0.03658245503902435, -...","[0.31170788407325745, 0.10432334989309311, 0.3..."
4,0,46,0,46,"Usual Suspects, The (1995)","['Crime', 'Mystery', 'Thriller']",1,rest,"0[SEP]46[SEP]Usual Suspects, The (1995)[SEP]['...","[-0.9705848693847656, -1.8131654262542725, 0.1...","[1.1073698997497559, -1.117932915687561, -2.01...","[0.07750901579856873, -0.06391477584838867, -0...","[-0.12741614878177643, 0.42395302653312683, 0...."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
100831,609,9434,609,9434,Split (2017),"['Drama', 'Horror', 'Thriller']",1,val,"609[SEP]9434[SEP]Split (2017)[SEP]['Drama', 'H...","[0.03186917304992676, 0.45992347598075867, -0....","[2.241410732269287, 0.41119661927223206, -0.81...","[0.2923734486103058, -0.06350480020046234, 0.8...","[0.1981135904788971, 0.09776249527931213, 0.28..."
100832,609,9461,609,9461,John Wick: Chapter Two (2017),"['Action', 'Crime', 'Thriller']",1,test,609[SEP]9461[SEP]John Wick: Chapter Two (2017)...,"[0.14328372478485107, 0.7543345093727112, 0.06...","[2.413191318511963, 0.1366730034351349, -1.020...","[0.39458659291267395, 0.07710234820842743, 0.6...","[-0.1703999936580658, 0.4222191870212555, 0.24..."
100833,609,9462,609,9462,Get Out (2017),['Horror'],1,train,609[SEP]9462[SEP]Get Out (2017)[SEP]['Horror'],"[0.06981316208839417, 0.5853475332260132, -0.2...","[2.272693157196045, 0.5579444766044617, -1.008...","[0.2807372510433197, -0.13780882954597473, 0.7...","[-0.14196661114692688, -0.034432072192430496, ..."
100834,609,9463,609,9463,Logan (2017),"['Action', 'Sci-Fi']",1,train,"609[SEP]9463[SEP]Logan (2017)[SEP]['Action', '...","[0.12884068489074707, 0.8596121072769165, -0.2...","[1.514894962310791, 0.22600948810577393, -0.94...","[0.4778950810432434, 0.04597288370132446, 0.59...","[-0.6509568691253662, -0.21049413084983826, 0...."


In [None]:
kg_manager.llm_df["labels"].unique()

array([1], dtype=int64)

In [None]:
vanilla_bert_classifier = VanillaBertClassifier(
    kg_manager.llm_df,
    kg_manager.source_df,
    kg_manager.target_df,
    false_ratio=-1,
)
prompt_bert_classifier = PromptBertClassifier(
    kg_manager,
    graph_representation_generator_prompt.get_embedding,
    model_max_length=512,
    false_ratio=-1,
)
attention_bert_classifier = AttentionBertClassifierBase(
    kg_manager,
    graph_representation_generator_attention.get_embedding,
    false_ratio=-1,
)


In [None]:
dataset_vanilla_fixed = kg_manager.generate_vanilla_dataset(
    vanilla_bert_classifier.tokenize_function, suffix="_fixed", force_recompute=True
)
dataset_prompt_fixed = kg_manager.generate_prompt_embedding_dataset(
    prompt_bert_classifier.tokenize_function,
    suffix="_fixed",
    force_recompute=True,
)
dataset_attention_fixed = kg_manager.generate_attention_embedding_dataset(
    attention_bert_classifier.tokenizer.sep_token,
    attention_bert_classifier.tokenizer.pad_token,
    attention_bert_classifier.tokenize_function,
    suffix="_fixed",
    force_recompute=True,
)

Map:   0%|          | 0/56469 [00:00<?, ? examples/s]

Map:   0%|          | 0/17142 [00:00<?, ? examples/s]

Map:   0%|          | 0/17142 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/56469 [00:00<?, ? examples/s]

OSError: [Errno 22] Invalid argument: 'c:/Users/MARS/Ahmad/Hauptprojekt/data/llm/vanilla_dataset_fixed/train/data-00000-of-00001.arrow'

In [None]:
vanilla_df = vanilla_bert_classifier.forward_dataset_and_save_outputs(
    dataset_vanilla_fixed,
    kg_manager.get_vanilla_tokens_as_df,
    epochs=1,
    batch_size=BATCH_SIZE,
    force_recompute=False,
)
prompt_df = prompt_bert_classifier.forward_dataset_and_save_outputs(
    dataset_prompt_fixed,
    kg_manager.get_prompt_tokens_as_df,
    epochs=1,
    force_recompute=False,
)
attention_df = attention_bert_classifier.forward_dataset_and_save_outputs(
    dataset_attention_fixed,
    kg_manager.get_vanilla_tokens_as_df,
    epochs=1,
    batch_size=BATCH_SIZE,
    force_recompute=False,
)

In [None]:
dataset = kg_manager.generate_huggingface_dataset(vanilla_df, prompt_df, attention_df)

In [None]:
dataset.save_to_disk("./data/dataset.hf")

In [None]:
dataset.push_to_hub("AhmadPython/MovieLens_KGE")


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


CommitInfo(commit_url='https://huggingface.co/datasets/AhmadPython/MovieLens_KGE/commit/9ed636581248eb8991a9f6730a9b62a86287affd', commit_message='Upload dataset', commit_description='', oid='9ed636581248eb8991a9f6730a9b62a86287affd', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['source_id', 'target_id', 'id_x', 'id_y', 'prompt_feature_title', 'prompt_feature_genres', 'labels', 'split', 'prompt', 'prompt_source_embedding', 'prompt_target_embedding', 'attention_source_embedding', 'attention_target_embedding', 'vanilla_attentions', 'vanilla_hidden_states', 'vanilla_attentions_original_shape', 'vanilla_hidden_states_original_shape', 'prompt_attentions', 'prompt_hidden_states', 'prompt_attentions_original_shape', 'prompt_hidden_states_original_shape', 'attention_attentions', 'attention_hidden_states', 'attention_attentions_original_shape', 'attention_hidden_states_original_shape'],
        num_rows: 56469
    })
    val: Dataset({
        features: ['source_id', 'target_id', 'id_x', 'id_y', 'prompt_feature_title', 'prompt_feature_genres', 'labels', 'split', 'prompt', 'prompt_source_embedding', 'prompt_target_embedding', 'attention_source_embedding', 'attention_target_embedding', 'vanilla_attentions', 'vanilla_h