# Training and Evaluation of GNNs and LLMs
In this notebook, we train the models on the [MovieLens Dataset](https://movielens.org/), following the PyTorch Geometric tutorial on [Link Prediction](https://colab.research.google.com/drive/1xpzn1Nvai1ygd_P5Yambc_oe4VBPK_ZT?usp=sharing#scrollTo=vit8xKCiXAue).

First we import all of our dependencies.

The **GraphRepresentationGenerator** manages and trains a GNN model. Its most important interfaces are:
**the constructor**, which defines the GNN architecture and loads the pre-trained GNN model if it already exists on disk;
**the training method**, which trains the GNN model; and
**the get_embedding methods**, which form the inference interface of the GNN model and, for given user-movie node pairs, return the corresponding embeddings in the dimension defined in the constructor.
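The load-or-train behaviour of the constructor can be sketched as follows. This is a minimal, framework-free illustration under assumptions: `load_or_train` and `dummy_train` are hypothetical names, the parameters are stored as JSON instead of a PyTorch checkpoint (a real GNN would save and load its `state_dict()` with `torch.save`/`torch.load`), and `force_recompute` mirrors the constructor argument of the same name used further below.

```python
import json
import os
import tempfile


def load_or_train(path, train_fn, force_recompute=False):
    """Return saved parameters from `path` if present; otherwise train, save and return them.

    `train_fn` is a hypothetical callable that returns JSON-serialisable parameters.
    """
    if os.path.exists(path) and not force_recompute:
        print("loading pretrained model")
        with open(path) as f:
            return json.load(f)
    params = train_fn()
    with open(path, "w") as f:
        json.dump(params, f)
    return params


# Usage: the second call finds the saved file and skips training
calls = []


def dummy_train():
    calls.append(1)
    return {"weight": [0.5, -0.25]}


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.json")
    first = load_or_train(model_path, dummy_train)
    second = load_or_train(model_path, dummy_train)
```

Passing `force_recompute=True` would retrain even when a saved model exists.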

**The MovieLensManager** loads and manages the datasets. Its most important tasks are **saving, (re)loading, and transforming** the datasets.

**PromptEncoderOnlyClassifier** and **VanillaEncoderOnlyClassifier** manage a **prompt (model) LLM** and a **vanilla (model) LLM**, respectively. An EncoderOnlyClassifier (ClassifierBase) provides the interfaces for training and testing an LLM.
PromptEncoder and VanillaEncoder differ in their data collators (DataCollators). A data collator assembles the batches the model receives during training and testing, and it can create data points at runtime. With the help of these collators, we **create non-existent edges on the fly**.
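The idea of creating non-existent edges on the fly can be sketched as a collate function that pairs every existing (positive) edge in a batch with randomly sampled negative edges. This is a simplified sketch, not the project's collator: it works on raw `(source_id, target_id)` pairs rather than tokenized prompts, it uses uniform negative sampling, and `make_negative_sampling_collator` is a hypothetical name.

```python
import random


def make_negative_sampling_collator(existing_edges, num_targets, neg_ratio=1, seed=None, max_tries=1000):
    """Build a collate function that adds `neg_ratio` non-existent edges per positive edge.

    existing_edges: set of (source_id, target_id) pairs that exist in the graph
    num_targets: number of target nodes; a negative keeps the positive's source
                 and draws a random target that is not connected to it
    """
    rng = random.Random(seed)

    def collate(batch):
        # Positive examples: the edges in the batch, labelled 1
        examples = [(s, t, 1) for s, t in batch]
        for s, _ in batch:
            for _ in range(neg_ratio):
                # Rejection sampling: redraw until the pair is not an existing edge
                for _ in range(max_tries):
                    t_neg = rng.randrange(num_targets)
                    if (s, t_neg) not in existing_edges:
                        break
                else:
                    raise RuntimeError(f"no non-existent edge found for source {s}")
                examples.append((s, t_neg, 0))
        return examples

    return collate


# Usage: source 0 is unconnected only to target 2, so its negative must be (0, 2)
edges = {(0, 0), (0, 1), (1, 2)}
collate = make_negative_sampling_collator(edges, num_targets=3, seed=0)
batch_out = collate([(0, 0), (1, 2)])
print(batch_out)
```

Because the negatives are drawn per batch, the model sees different non-existent edges in every epoch; in a real pipeline the collator would also turn each pair into a prompt and tokenize it.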

In [1]:
from graph_representation_generator.graph_representation_generator import (
    GraphRepresentationGenerator,
)
from dataset_manager.movie_lens_manager import MovieLensManager
from dataset_manager.kg_manager import ROOT
from llm_manager.graph_prompter_hf.classifier import GraphPrompterHF

In [2]:
MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
EPOCHS = 20
BATCH_SIZE_LLM = 256

We define in advance the **Knowledge Graph Embedding Dimension (KGE_DIMENSION)**, i.e. the output dimension of the GNN encoder. We want to determine the output dimension at which the GNN encoder starts to produce embeddings that significantly improve performance *without exceeding the context length of the LLM*. In the original tutorial, the KGE_DIMENSION was $64$; this notebook uses $128$ (the `kge_dimension` argument below).

In [3]:
kg_manager = MovieLensManager()

# Attach each target node's title and genres to the LLM dataframe via the target_id column
llm_df = kg_manager.llm_df.merge(
    kg_manager.target_df[["id", "prompt_feature_title", "prompt_feature_genres"]].rename(
        columns={"id": "target_id"}
    ),
    on="target_id",
)
llm_df

Instantiating the `MovieLensManager` (cell above) downloads the MovieLens dataset (https://files.grouplens.org/datasets/movielens/ml-32m.zip) and prepares it for use with both the GNN and the LLM. The first run takes approximately 30 seconds.

In [4]:
kg_manager.data

HeteroData(
  source={ node_id=[610] },
  target={
    node_id=[9742],
    x=[9742, 20],
  },
  (source, edge, target)={ edge_index=[2, 100836] },
  (target, rev_edge, source)={ edge_index=[2, 100836] }
)
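In this `HeteroData` object, the 610 `source` nodes are users and the 9,742 `target` nodes are movies; only the movies carry a 20-dimensional feature matrix `x` (in the PyG tutorial these are multi-hot genre indicators). Each of the 100,836 ratings is one column of a `[2, num_edges]` `edge_index`, with source ids in row 0 and target ids in row 1. The `rev_edge` relation holds the same edges with the two rows swapped, so messages can also flow from movies back to users; the tutorial creates it with `T.ToUndirected()`, and we assume the same here. A minimal pure-Python sketch of that reversal follows (`reverse_edge_index` is a hypothetical helper; on a PyTorch tensor this is `edge_index.flip([0])`):

```python
def reverse_edge_index(edge_index):
    """Swap the rows of a [2, num_edges] edge index: (source -> target) becomes (target -> source)."""
    sources, targets = edge_index
    return [list(targets), list(sources)]


# Three user -> movie edges: user 0 rated movies 0 and 1, user 1 rated movie 2
edge_index = [[0, 0, 1], [0, 1, 2]]
rev_edge_index = reverse_edge_index(edge_index)
print(rev_edge_index)  # [[0, 1, 2], [0, 0, 1]]
```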

Next, we initialize the GNN trainer (here on CUDA). In general there is one trainer per KGE_DIMENSION; this notebook uses a single one with a KGE_DIMENSION of 128.
A GNN trainer manages a model that consists of an **encoder** and a **classifier** part.

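**The encoder** is a parameterized graph neural network with a *2-layer GNN computation graph* and a single *ReLU* activation function in between. Judging by the parameter names in the state dict printed at the end of this notebook (`lin_l`, `lin_r`), its layers are PyG `SAGEConv` (GraphSAGE) layers, as in the tutorial, rather than GCN layers.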
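To make the encoder's computation graph concrete, here is a minimal pure-Python sketch of two GraphSAGE-style mean-aggregation layers with one ReLU in between. It simplifies under assumptions: it treats the graph as homogeneous, whereas the real model has separate weights per edge type (`source__edge__target` and `target__rev_edge__source` in the state dict at the end of the notebook). In each layer, `W_l`/`b_l` play the role of `lin_l` (applied to the mean of the neighbours) and `W_r` the role of `lin_r` (applied to the node itself). All function names are hypothetical.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def vec_add(a, b):
    return [x + y for x, y in zip(a, b)]


def mean(vectors, dim):
    # Mean of a list of vectors; a node without neighbours aggregates to the zero vector
    if not vectors:
        return [0.0] * dim
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def sage_layer(x, neighbors, W_l, b_l, W_r):
    """One GraphSAGE-style layer: out_i = W_l @ mean_{j in N(i)} x_j + b_l + W_r @ x_i."""
    out = []
    for i, x_i in enumerate(x):
        agg = mean([x[j] for j in neighbors[i]], len(x_i))
        out.append(vec_add(vec_add(matvec(W_l, agg), b_l), matvec(W_r, x_i)))
    return out


def encode(x, neighbors, layer1, layer2):
    """Two message-passing layers with a single ReLU in between (no activation after layer 2)."""
    h = [relu(h_i) for h_i in sage_layer(x, neighbors, *layer1)]
    return sage_layer(h, neighbors, *layer2)


# Toy path graph 0 - 1 - 2 with identity weights and zero bias
I = [[1.0, 0.0], [0.0, 1.0]]
zero = [0.0, 0.0]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = [[1], [0, 2], [1]]
z = encode(x, neighbors, (I, zero, I), (I, zero, I))
print(z)  # [[2.0, 2.5], [2.0, 3.0], [2.0, 3.5]]
```

After two layers, each node's embedding depends on its 2-hop neighbourhood, which is what the 2-layer computation graph refers to.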

**The classifier** computes the dot product between the source and target KGEs to derive edge-level predictions.
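The dot-product classifier fits in a few lines. In this sketch, `classify_edges` is a hypothetical helper: it looks up the source and target KGE for each queried pair in `edge_label_index` and returns one logit per pair; the PyG tutorial does the same with tensor indexing and `(src * dst).sum(dim=-1)`. A sigmoid turns a logit into a link probability (assuming, as in the tutorial, that training uses binary cross-entropy on the logits).

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def classify_edges(source_emb, target_emb, edge_label_index):
    """One logit per queried edge: dot product of its source KGE and its target KGE."""
    source_idx, target_idx = edge_label_index
    return [dot(source_emb[s], target_emb[t]) for s, t in zip(source_idx, target_idx)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Two users and two movies with 2-dimensional KGEs
source_emb = [[1.0, 2.0], [0.0, 1.0]]
target_emb = [[3.0, 4.0], [1.0, 0.0]]
# Query the pairs (user 0, movie 0), (user 1, movie 0), (user 0, movie 1)
edge_label_index = ([0, 1, 0], [0, 0, 1])
logits = classify_edges(source_emb, target_emb, edge_label_index)
probs = [sigmoid(z) for z in logits]
print(logits)  # [11.0, 4.0, 1.0]
```

The classifier itself has no trainable parameters; all learning happens in the encoder that produces the KGEs.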

In [5]:
graph_representation_generator_graph_prompter_hf = GraphRepresentationGenerator(
    kg_manager.data.to("cuda"),
    kg_manager.gnn_train_data.to("cuda"),
    kg_manager.gnn_val_data.to("cuda"),
    kg_manager.gnn_test_data.to("cuda"),
    hidden_channels=128,
    kge_dimension=128,
    force_recompute=False,
    device="cuda",
)

loading pretrained model
Device: 'cuda'


The GNN provides the KGEs for the source and target node of every edge in the dataset. The LLM then uses these embeddings for the link-prediction task.

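Next we initialize the `GraphPrompterHF` classifier. In addition to the BERT model (`MODEL_NAME`), it receives the GNN's `get_embeddings` callback and the GNN parameters, so its predictions use the KGEs as well as the text of the prompt, and the GNN is trained jointly with the LLM (note "Adding GNN parameters to optimizer" in the training log below).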

In [6]:
graph_prompter_hf_bert_classifier = GraphPrompterHF(
    kge_manager=kg_manager,
    get_embeddings_cb=graph_representation_generator_graph_prompter_hf.get_embeddings,
    model_name=MODEL_NAME,
    vanilla_model_path=f"{ROOT}/llm/graph_prompter_hf_frozen/training/best",
    gnn_parameters=list(
        graph_representation_generator_graph_prompter_hf.model.parameters()
    ),
)

device cuda
5
5
5
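`GraphPrompterHF` receives the GNN's `get_embeddings` callback, so the KGEs have to enter the BERT input somewhere. We have not reproduced its internals here; the sketch below shows one common pattern, under the assumption that the prompt contains placeholder tokens whose input embeddings are replaced by the KGEs. The chosen BERT model has hidden size 128 (the `H-128` in its name), which equals the `kge_dimension` of 128 used above, so such a replacement would need no projection layer. `splice_kges`, `PLACEHOLDER_ID` and the toy values are hypothetical.

```python
def splice_kges(token_embeddings, token_ids, placeholder_id, kges):
    """Return a copy of token_embeddings whose placeholder rows are replaced by KGEs.

    token_embeddings: one vector per input token
    token_ids: token ids of the prompt, aligned with token_embeddings
    placeholder_id: id of the (hypothetical) token that marks a KGE slot
    kges: KGEs in the order their placeholders appear, e.g. [source_kge, target_kge]
    """
    positions = [i for i, tok in enumerate(token_ids) if tok == placeholder_id]
    if len(positions) != len(kges):
        raise ValueError(f"expected {len(positions)} KGEs, got {len(kges)}")
    out = [list(vec) for vec in token_embeddings]  # copy so the input is not modified
    for pos, kge in zip(positions, kges):
        # A KGE must have the same width as a token embedding (the model's hidden size)
        if len(kge) != len(out[pos]):
            raise ValueError("KGE width must equal the token embedding width")
        out[pos] = list(kge)
    return out


# Toy prompt "[CLS] <user> [SEP] <movie> [SEP]" with 2-dimensional embeddings
PLACEHOLDER_ID = 99
token_ids = [101, PLACEHOLDER_ID, 102, PLACEHOLDER_ID, 102]
token_embeddings = [[0.1, 0.1] for _ in token_ids]
user_kge, movie_kge = [1.0, 2.0], [3.0, 4.0]
inputs = splice_kges(token_embeddings, token_ids, PLACEHOLDER_ID, [user_kge, movie_kge])
```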


In [7]:
dataset_embedding = kg_manager.generate_graph_prompter_hf_embedding_dataset(
    graph_prompter_hf_bert_classifier.tokenizer.sep_token,
    graph_prompter_hf_bert_classifier.tokenizer.pad_token,
    graph_prompter_hf_bert_classifier.tokenize_function,
)

In [8]:
graph_prompter_hf_bert_classifier.model.device

device(type='cuda', index=0)

In [9]:
graph_prompter_hf_bert_classifier.train_model_on_data(
    dataset_embedding,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE_LLM,
)
graph_representation_generator_graph_prompter_hf.save_model()

Adding GNN parameters to optimizer


{'loss': 0.355, 'grad_norm': 1.0741111040115356, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05}
{'loss': 0.3517, 'grad_norm': 1.876395344734192, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.09}
{'loss': 0.3609, 'grad_norm': 0.9085135459899902, 'learning_rate': 3e-06, 'epoch': 0.14}
{'loss': 0.3533, 'grad_norm': 2.4121317863464355, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.18}
{'loss': 0.3715, 'grad_norm': 1.5719934701919556, 'learning_rate': 5e-06, 'epoch': 0.23}
{'loss': 0.3641, 'grad_norm': 3.3062729835510254, 'learning_rate': 6e-06, 'epoch': 0.27}
{'loss': 0.3394, 'grad_norm': 1.954710602760315, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.32}
{'loss': 0.3445, 'grad_norm': 0.876276969909668, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.36}
{'loss': 0.348, 'grad_norm': 3.0029659271240234, 'learning_rate': 9e-06, 'epoch': 0.41}
{'loss': 0.3726, 'grad_norm': 1.1106078624725342, 'learning_rate': 1e-05, 'epoch': 0.45}
{'loss': 0.3467, 'grad_norm': 2.5

{'eval_loss': 0.40815994143486023, 'eval_accuracy': 0.8198948725577705, 'eval_runtime': 25.1849, 'eval_samples_per_second': 800.716, 'eval_steps_per_second': 3.137, 'epoch': 1.0}
{'loss': 0.321, 'grad_norm': 1.1981068849563599, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.04}
{'loss': 0.3285, 'grad_norm': 1.0269941091537476, 'learning_rate': 2.4e-05, 'epoch': 1.09}
{'loss': 0.3332, 'grad_norm': 1.42103910446167, 'learning_rate': 2.5e-05, 'epoch': 1.13}
{'loss': 0.3398, 'grad_norm': 1.0261763334274292, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.18}
{'loss': 0.3434, 'grad_norm': 1.1420198678970337, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.22}
{'loss': 0.3436, 'grad_norm': 1.1744475364685059, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.27}
{'loss': 0.341, 'grad_norm': 0.8749231696128845, 'learning_rate': 2.9e-05, 'epoch': 1.31}
{'loss': 0.3241, 'grad_norm': 0.862678587436676, 'learning_rate': 3e-05, 'epoch': 1.36}
{'loss': 0.327, 'grad_norm': 0.991242706

{'eval_loss': 0.39680215716362, 'eval_accuracy': 0.8296142021223842, 'eval_runtime': 25.0548, 'eval_samples_per_second': 804.875, 'eval_steps_per_second': 3.153, 'epoch': 2.0}
{'loss': 0.3373, 'grad_norm': 0.7364561557769775, 'learning_rate': 4.5e-05, 'epoch': 2.04}
{'loss': 0.3266, 'grad_norm': 1.0501465797424316, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.08}
{'loss': 0.3261, 'grad_norm': 0.9130332469940186, 'learning_rate': 4.7e-05, 'epoch': 2.13}
{'loss': 0.3221, 'grad_norm': 0.8499203324317932, 'learning_rate': 4.8e-05, 'epoch': 2.17}
{'loss': 0.3307, 'grad_norm': 0.8420438170433044, 'learning_rate': 4.9e-05, 'epoch': 2.22}
{'loss': 0.3338, 'grad_norm': 0.9202855825424194, 'learning_rate': 5e-05, 'epoch': 2.26}
{'loss': 0.3218, 'grad_norm': 1.2867136001586914, 'learning_rate': 4.987244897959184e-05, 'epoch': 2.31}
{'loss': 0.3363, 'grad_norm': 0.9354099631309509, 'learning_rate': 4.974489795918368e-05, 'epoch': 2.35}
{'loss': 0.3181, 'grad_norm': 0.8162130117416382, 'learn

{'eval_loss': 0.38974353671073914, 'eval_accuracy': 0.8294158484578003, 'eval_runtime': 25.1772, 'eval_samples_per_second': 800.964, 'eval_steps_per_second': 3.138, 'epoch': 3.0}
{'loss': 0.3238, 'grad_norm': 1.3200784921646118, 'learning_rate': 4.783163265306123e-05, 'epoch': 3.03}
{'loss': 0.3049, 'grad_norm': 0.9575657248497009, 'learning_rate': 4.7704081632653066e-05, 'epoch': 3.08}
{'loss': 0.3147, 'grad_norm': 0.6802219152450562, 'learning_rate': 4.7576530612244904e-05, 'epoch': 3.12}
{'loss': 0.3097, 'grad_norm': 0.9902365207672119, 'learning_rate': 4.744897959183674e-05, 'epoch': 3.17}
{'loss': 0.3209, 'grad_norm': 1.1015164852142334, 'learning_rate': 4.732142857142857e-05, 'epoch': 3.21}
{'loss': 0.3071, 'grad_norm': 0.8271260261535645, 'learning_rate': 4.719387755102041e-05, 'epoch': 3.26}
{'loss': 0.3249, 'grad_norm': 0.8030689358711243, 'learning_rate': 4.706632653061225e-05, 'epoch': 3.3}
{'loss': 0.3114, 'grad_norm': 0.7284666299819946, 'learning_rate': 4.6938775510204086

{'eval_loss': 0.39990898966789246, 'eval_accuracy': 0.8269364276505008, 'eval_runtime': 25.1034, 'eval_samples_per_second': 803.317, 'eval_steps_per_second': 3.147, 'epoch': 4.0}
{'loss': 0.3131, 'grad_norm': 0.9976614713668823, 'learning_rate': 4.502551020408164e-05, 'epoch': 4.03}
{'loss': 0.3061, 'grad_norm': 0.8106124997138977, 'learning_rate': 4.4897959183673474e-05, 'epoch': 4.07}
{'loss': 0.2996, 'grad_norm': 1.612487554550171, 'learning_rate': 4.477040816326531e-05, 'epoch': 4.12}
{'loss': 0.3309, 'grad_norm': 1.1184052228927612, 'learning_rate': 4.464285714285715e-05, 'epoch': 4.16}
{'loss': 0.3209, 'grad_norm': 1.0129389762878418, 'learning_rate': 4.451530612244898e-05, 'epoch': 4.21}
{'loss': 0.3172, 'grad_norm': 0.9195023775100708, 'learning_rate': 4.438775510204082e-05, 'epoch': 4.25}
{'loss': 0.324, 'grad_norm': 1.047583818435669, 'learning_rate': 4.4260204081632656e-05, 'epoch': 4.3}
{'loss': 0.3171, 'grad_norm': 1.625493049621582, 'learning_rate': 4.4132653061224493e-05

{'eval_loss': 0.3860858678817749, 'eval_accuracy': 0.8289199642963404, 'eval_runtime': 25.1535, 'eval_samples_per_second': 801.717, 'eval_steps_per_second': 3.141, 'epoch': 5.0}
{'loss': 0.3109, 'grad_norm': 0.8640084862709045, 'learning_rate': 4.2219387755102045e-05, 'epoch': 5.02}
{'loss': 0.3075, 'grad_norm': 1.278978705406189, 'learning_rate': 4.209183673469388e-05, 'epoch': 5.07}
{'loss': 0.3112, 'grad_norm': 0.7452021241188049, 'learning_rate': 4.196428571428572e-05, 'epoch': 5.11}
{'loss': 0.3115, 'grad_norm': 1.0173640251159668, 'learning_rate': 4.183673469387756e-05, 'epoch': 5.16}
{'loss': 0.2995, 'grad_norm': 1.6588777303695679, 'learning_rate': 4.170918367346939e-05, 'epoch': 5.2}
{'loss': 0.3092, 'grad_norm': 0.7470638155937195, 'learning_rate': 4.1581632653061226e-05, 'epoch': 5.25}
{'loss': 0.3176, 'grad_norm': 1.3117096424102783, 'learning_rate': 4.1454081632653064e-05, 'epoch': 5.29}
{'loss': 0.295, 'grad_norm': 0.9412911534309387, 'learning_rate': 4.13265306122449e-05

{'eval_loss': 0.3716958463191986, 'eval_accuracy': 0.8376971139541803, 'eval_runtime': 25.0838, 'eval_samples_per_second': 803.946, 'eval_steps_per_second': 3.149, 'epoch': 6.0}
{'loss': 0.3249, 'grad_norm': 0.845267653465271, 'learning_rate': 3.9413265306122446e-05, 'epoch': 6.02}
{'loss': 0.3042, 'grad_norm': 0.7267133593559265, 'learning_rate': 3.928571428571429e-05, 'epoch': 6.06}
{'loss': 0.3059, 'grad_norm': 1.2131600379943848, 'learning_rate': 3.915816326530613e-05, 'epoch': 6.11}
{'loss': 0.2997, 'grad_norm': 1.5444597005844116, 'learning_rate': 3.9030612244897965e-05, 'epoch': 6.15}
{'loss': 0.3154, 'grad_norm': 0.7504377961158752, 'learning_rate': 3.8903061224489796e-05, 'epoch': 6.2}
{'loss': 0.3212, 'grad_norm': 0.9384281039237976, 'learning_rate': 3.8775510204081634e-05, 'epoch': 6.24}
{'loss': 0.3047, 'grad_norm': 0.7232848405838013, 'learning_rate': 3.864795918367347e-05, 'epoch': 6.29}
{'loss': 0.3101, 'grad_norm': 0.7312148213386536, 'learning_rate': 3.852040816326531e

{'eval_loss': 0.39719563722610474, 'eval_accuracy': 0.8281761380541506, 'eval_runtime': 25.1109, 'eval_samples_per_second': 803.077, 'eval_steps_per_second': 3.146, 'epoch': 7.0}
{'loss': 0.335, 'grad_norm': 1.0525916814804077, 'learning_rate': 3.6607142857142853e-05, 'epoch': 7.01}
{'loss': 0.315, 'grad_norm': 0.7825809121131897, 'learning_rate': 3.64795918367347e-05, 'epoch': 7.06}
{'loss': 0.3012, 'grad_norm': 0.7532649040222168, 'learning_rate': 3.6352040816326536e-05, 'epoch': 7.1}
{'loss': 0.3211, 'grad_norm': 0.9102079272270203, 'learning_rate': 3.622448979591837e-05, 'epoch': 7.15}
{'loss': 0.308, 'grad_norm': 0.7587213516235352, 'learning_rate': 3.609693877551021e-05, 'epoch': 7.19}
{'loss': 0.3172, 'grad_norm': 0.6602509617805481, 'learning_rate': 3.596938775510204e-05, 'epoch': 7.24}
{'loss': 0.3214, 'grad_norm': 1.143277883529663, 'learning_rate': 3.584183673469388e-05, 'epoch': 7.29}
{'loss': 0.3201, 'grad_norm': 0.8287230730056763, 'learning_rate': 3.571428571428572e-05, 

{'eval_loss': 0.3808134198188782, 'eval_accuracy': 0.8333333333333334, 'eval_runtime': 25.1981, 'eval_samples_per_second': 800.298, 'eval_steps_per_second': 3.135, 'epoch': 8.0}
{'loss': 0.3156, 'grad_norm': 0.873103678226471, 'learning_rate': 3.380102040816326e-05, 'epoch': 8.01}
{'loss': 0.3079, 'grad_norm': 0.643046498298645, 'learning_rate': 3.36734693877551e-05, 'epoch': 8.05}
{'loss': 0.294, 'grad_norm': 0.663425624370575, 'learning_rate': 3.354591836734694e-05, 'epoch': 8.1}
{'loss': 0.3314, 'grad_norm': 0.7626339793205261, 'learning_rate': 3.341836734693878e-05, 'epoch': 8.14}
{'loss': 0.3105, 'grad_norm': 0.8124741911888123, 'learning_rate': 3.329081632653062e-05, 'epoch': 8.19}
{'loss': 0.2906, 'grad_norm': 1.0091280937194824, 'learning_rate': 3.316326530612245e-05, 'epoch': 8.24}
{'loss': 0.3239, 'grad_norm': 1.1109000444412231, 'learning_rate': 3.303571428571429e-05, 'epoch': 8.28}
{'loss': 0.2889, 'grad_norm': 0.6811895966529846, 'learning_rate': 3.2908163265306125e-05, 'e

{'eval_loss': 0.39749082922935486, 'eval_accuracy': 0.8306059704453039, 'eval_runtime': 25.1714, 'eval_samples_per_second': 801.147, 'eval_steps_per_second': 3.138, 'epoch': 9.0}
{'loss': 0.2928, 'grad_norm': 0.9235753417015076, 'learning_rate': 3.0994897959183676e-05, 'epoch': 9.0}
{'loss': 0.3083, 'grad_norm': 0.8418534398078918, 'learning_rate': 3.086734693877551e-05, 'epoch': 9.05}
{'loss': 0.305, 'grad_norm': 0.7816949486732483, 'learning_rate': 3.073979591836735e-05, 'epoch': 9.1}
{'loss': 0.3038, 'grad_norm': 1.0521022081375122, 'learning_rate': 3.061224489795919e-05, 'epoch': 9.14}
{'loss': 0.3065, 'grad_norm': 0.7270749807357788, 'learning_rate': 3.0484693877551023e-05, 'epoch': 9.19}
{'loss': 0.3057, 'grad_norm': 0.9143624901771545, 'learning_rate': 3.0357142857142857e-05, 'epoch': 9.23}
{'loss': 0.2908, 'grad_norm': 1.1849206686019897, 'learning_rate': 3.0229591836734695e-05, 'epoch': 9.28}
{'loss': 0.2814, 'grad_norm': 0.7129260301589966, 'learning_rate': 3.0102040816326533

{'eval_loss': 0.4004717171192169, 'eval_accuracy': 0.8248041257562233, 'eval_runtime': 25.2227, 'eval_samples_per_second': 799.517, 'eval_steps_per_second': 3.132, 'epoch': 10.0}
{'loss': 0.2975, 'grad_norm': 0.9807771444320679, 'learning_rate': 2.8061224489795918e-05, 'epoch': 10.05}
{'loss': 0.2868, 'grad_norm': 1.053775429725647, 'learning_rate': 2.7933673469387756e-05, 'epoch': 10.09}
{'loss': 0.284, 'grad_norm': 0.7571933269500732, 'learning_rate': 2.7806122448979593e-05, 'epoch': 10.14}
{'loss': 0.3133, 'grad_norm': 0.9678820967674255, 'learning_rate': 2.767857142857143e-05, 'epoch': 10.18}
{'loss': 0.3062, 'grad_norm': 1.0475945472717285, 'learning_rate': 2.7551020408163265e-05, 'epoch': 10.23}
{'loss': 0.2932, 'grad_norm': 0.8248042464256287, 'learning_rate': 2.7423469387755103e-05, 'epoch': 10.27}
{'loss': 0.2784, 'grad_norm': 0.9090235829353333, 'learning_rate': 2.729591836734694e-05, 'epoch': 10.32}
{'loss': 0.2872, 'grad_norm': 0.8328433632850647, 'learning_rate': 2.7168367

{'eval_loss': 0.4028257429599762, 'eval_accuracy': 0.817812159079639, 'eval_runtime': 25.2621, 'eval_samples_per_second': 798.272, 'eval_steps_per_second': 3.127, 'epoch': 11.0}
{'loss': 0.2907, 'grad_norm': 0.5590888261795044, 'learning_rate': 2.5255102040816326e-05, 'epoch': 11.04}
{'loss': 0.2894, 'grad_norm': 0.81894451379776, 'learning_rate': 2.5127551020408164e-05, 'epoch': 11.09}
{'loss': 0.2925, 'grad_norm': 0.6611219644546509, 'learning_rate': 2.5e-05, 'epoch': 11.13}
{'loss': 0.3004, 'grad_norm': 0.6888067126274109, 'learning_rate': 2.487244897959184e-05, 'epoch': 11.18}
{'loss': 0.3059, 'grad_norm': 0.6519743800163269, 'learning_rate': 2.4744897959183673e-05, 'epoch': 11.22}
{'loss': 0.2914, 'grad_norm': 0.6286521553993225, 'learning_rate': 2.461734693877551e-05, 'epoch': 11.27}
{'loss': 0.2896, 'grad_norm': 0.6743403673171997, 'learning_rate': 2.448979591836735e-05, 'epoch': 11.31}
{'loss': 0.2969, 'grad_norm': 0.7790269255638123, 'learning_rate': 2.4362244897959186e-05, 'e

{'eval_loss': 0.3995839059352875, 'eval_accuracy': 0.8248537141723693, 'eval_runtime': 25.1366, 'eval_samples_per_second': 802.258, 'eval_steps_per_second': 3.143, 'epoch': 12.0}
{'loss': 0.2791, 'grad_norm': 0.9568220973014832, 'learning_rate': 2.2448979591836737e-05, 'epoch': 12.04}
{'loss': 0.2953, 'grad_norm': 1.023318886756897, 'learning_rate': 2.2321428571428575e-05, 'epoch': 12.08}
{'loss': 0.2973, 'grad_norm': 0.7331791520118713, 'learning_rate': 2.219387755102041e-05, 'epoch': 12.13}
{'loss': 0.278, 'grad_norm': 0.6360389590263367, 'learning_rate': 2.2066326530612247e-05, 'epoch': 12.17}
{'loss': 0.2918, 'grad_norm': 0.6958456039428711, 'learning_rate': 2.193877551020408e-05, 'epoch': 12.22}
{'loss': 0.3074, 'grad_norm': 0.670377254486084, 'learning_rate': 2.181122448979592e-05, 'epoch': 12.26}
{'loss': 0.2849, 'grad_norm': 0.8154102563858032, 'learning_rate': 2.1683673469387756e-05, 'epoch': 12.31}
{'loss': 0.2877, 'grad_norm': 0.657863199710846, 'learning_rate': 2.1556122448

{'eval_loss': 0.3990321159362793, 'eval_accuracy': 0.8236635921848656, 'eval_runtime': 25.0958, 'eval_samples_per_second': 803.562, 'eval_steps_per_second': 3.148, 'epoch': 13.0}
{'loss': 0.2845, 'grad_norm': 0.834611713886261, 'learning_rate': 1.9642857142857145e-05, 'epoch': 13.03}
{'loss': 0.2967, 'grad_norm': 0.7282406091690063, 'learning_rate': 1.9515306122448983e-05, 'epoch': 13.08}
{'loss': 0.2714, 'grad_norm': 0.7536174058914185, 'learning_rate': 1.9387755102040817e-05, 'epoch': 13.12}
{'loss': 0.2951, 'grad_norm': 0.893649160861969, 'learning_rate': 1.9260204081632655e-05, 'epoch': 13.17}
{'loss': 0.2763, 'grad_norm': 0.8349677324295044, 'learning_rate': 1.913265306122449e-05, 'epoch': 13.21}
{'loss': 0.2933, 'grad_norm': 0.696458101272583, 'learning_rate': 1.9005102040816326e-05, 'epoch': 13.26}
{'loss': 0.2924, 'grad_norm': 0.7754443287849426, 'learning_rate': 1.8877551020408164e-05, 'epoch': 13.3}
{'loss': 0.273, 'grad_norm': 0.7297311425209045, 'learning_rate': 1.875000000

{'eval_loss': 0.38524892926216125, 'eval_accuracy': 0.8324407418427056, 'eval_runtime': 25.1305, 'eval_samples_per_second': 802.451, 'eval_steps_per_second': 3.144, 'epoch': 14.0}
{'loss': 0.2731, 'grad_norm': 0.6630458235740662, 'learning_rate': 1.683673469387755e-05, 'epoch': 14.03}
{'loss': 0.2677, 'grad_norm': 0.6397427320480347, 'learning_rate': 1.670918367346939e-05, 'epoch': 14.07}
{'loss': 0.2965, 'grad_norm': 0.7462073564529419, 'learning_rate': 1.6581632653061225e-05, 'epoch': 14.12}
{'loss': 0.3007, 'grad_norm': 0.8417897820472717, 'learning_rate': 1.6454081632653062e-05, 'epoch': 14.16}
{'loss': 0.276, 'grad_norm': 0.7011064887046814, 'learning_rate': 1.6326530612244897e-05, 'epoch': 14.21}
{'loss': 0.2662, 'grad_norm': 0.9347846508026123, 'learning_rate': 1.6198979591836734e-05, 'epoch': 14.25}
{'loss': 0.2757, 'grad_norm': 0.7841396927833557, 'learning_rate': 1.6071428571428572e-05, 'epoch': 14.3}
{'loss': 0.2759, 'grad_norm': 0.8096339106559753, 'learning_rate': 1.594387

{'eval_loss': 0.3939274251461029, 'eval_accuracy': 0.8283249033025886, 'eval_runtime': 25.2645, 'eval_samples_per_second': 798.195, 'eval_steps_per_second': 3.127, 'epoch': 15.0}
{'loss': 0.2789, 'grad_norm': 0.7240311503410339, 'learning_rate': 1.4030612244897959e-05, 'epoch': 15.02}
{'loss': 0.2742, 'grad_norm': 1.0795177221298218, 'learning_rate': 1.3903061224489797e-05, 'epoch': 15.07}
{'loss': 0.2828, 'grad_norm': 1.1326555013656616, 'learning_rate': 1.3775510204081633e-05, 'epoch': 15.11}
{'loss': 0.2862, 'grad_norm': 0.76488196849823, 'learning_rate': 1.364795918367347e-05, 'epoch': 15.16}
{'loss': 0.2758, 'grad_norm': 0.8261799812316895, 'learning_rate': 1.3520408163265308e-05, 'epoch': 15.2}
{'loss': 0.2804, 'grad_norm': 0.6380460858345032, 'learning_rate': 1.3392857142857144e-05, 'epoch': 15.25}
{'loss': 0.2817, 'grad_norm': 0.594569981098175, 'learning_rate': 1.3265306122448982e-05, 'epoch': 15.29}
{'loss': 0.2762, 'grad_norm': 0.7707346677780151, 'learning_rate': 1.31377551

{'eval_loss': 0.3908120095729828, 'eval_accuracy': 0.8306059704453039, 'eval_runtime': 25.3512, 'eval_samples_per_second': 795.466, 'eval_steps_per_second': 3.116, 'epoch': 16.0}
{'loss': 0.2744, 'grad_norm': 0.7249284386634827, 'learning_rate': 1.1224489795918369e-05, 'epoch': 16.02}
{'loss': 0.3042, 'grad_norm': 1.027177333831787, 'learning_rate': 1.1096938775510205e-05, 'epoch': 16.06}
{'loss': 0.2711, 'grad_norm': 0.5578724145889282, 'learning_rate': 1.096938775510204e-05, 'epoch': 16.11}
{'loss': 0.2724, 'grad_norm': 0.7209962606430054, 'learning_rate': 1.0841836734693878e-05, 'epoch': 16.15}
{'loss': 0.2703, 'grad_norm': 0.7440674901008606, 'learning_rate': 1.0714285714285714e-05, 'epoch': 16.2}
{'loss': 0.286, 'grad_norm': 0.7638757228851318, 'learning_rate': 1.0586734693877552e-05, 'epoch': 16.24}
{'loss': 0.2832, 'grad_norm': 0.7951112389564514, 'learning_rate': 1.045918367346939e-05, 'epoch': 16.29}
{'loss': 0.2786, 'grad_norm': 0.6571719646453857, 'learning_rate': 1.03316326

{'eval_loss': 0.3907434344291687, 'eval_accuracy': 0.8301100862838441, 'eval_runtime': 25.2059, 'eval_samples_per_second': 800.052, 'eval_steps_per_second': 3.134, 'epoch': 17.0}
{'loss': 0.2789, 'grad_norm': 0.7972033619880676, 'learning_rate': 8.418367346938775e-06, 'epoch': 17.01}
{'loss': 0.2857, 'grad_norm': 0.6180434226989746, 'learning_rate': 8.290816326530612e-06, 'epoch': 17.06}
{'loss': 0.2632, 'grad_norm': 0.7769445180892944, 'learning_rate': 8.163265306122448e-06, 'epoch': 17.1}
{'loss': 0.2875, 'grad_norm': 0.7252843379974365, 'learning_rate': 8.035714285714286e-06, 'epoch': 17.15}
{'loss': 0.264, 'grad_norm': 0.6928356885910034, 'learning_rate': 7.908163265306124e-06, 'epoch': 17.19}
{'loss': 0.2748, 'grad_norm': 0.8823881149291992, 'learning_rate': 7.78061224489796e-06, 'epoch': 17.24}
{'loss': 0.2732, 'grad_norm': 0.6545863151550293, 'learning_rate': 7.653061224489797e-06, 'epoch': 17.29}
{'loss': 0.2827, 'grad_norm': 0.7143202424049377, 'learning_rate': 7.5255102040816

{'eval_loss': 0.39922094345092773, 'eval_accuracy': 0.8283249033025886, 'eval_runtime': 24.6248, 'eval_samples_per_second': 818.931, 'eval_steps_per_second': 3.208, 'epoch': 18.0}
{'loss': 0.2776, 'grad_norm': 0.7784190773963928, 'learning_rate': 5.612244897959184e-06, 'epoch': 18.01}
{'loss': 0.2789, 'grad_norm': 0.6871457695960999, 'learning_rate': 5.48469387755102e-06, 'epoch': 18.05}
{'loss': 0.2708, 'grad_norm': 1.122105360031128, 'learning_rate': 5.357142857142857e-06, 'epoch': 18.1}
{'loss': 0.267, 'grad_norm': 0.6641886830329895, 'learning_rate': 5.229591836734695e-06, 'epoch': 18.14}
{'loss': 0.2917, 'grad_norm': 0.7204041481018066, 'learning_rate': 5.102040816326531e-06, 'epoch': 18.19}
{'loss': 0.2441, 'grad_norm': 0.888738214969635, 'learning_rate': 4.9744897959183674e-06, 'epoch': 18.24}
{'loss': 0.2654, 'grad_norm': 0.6214226484298706, 'learning_rate': 4.846938775510204e-06, 'epoch': 18.28}
{'loss': 0.2696, 'grad_norm': 0.7031832933425903, 'learning_rate': 4.7193877551020

{'eval_loss': 0.38471725583076477, 'eval_accuracy': 0.8325399186749975, 'eval_runtime': 25.1758, 'eval_samples_per_second': 801.008, 'eval_steps_per_second': 3.138, 'epoch': 19.0}
{'loss': 0.2811, 'grad_norm': 0.7393852472305298, 'learning_rate': 2.806122448979592e-06, 'epoch': 19.0}
{'loss': 0.2738, 'grad_norm': 1.002983570098877, 'learning_rate': 2.6785714285714285e-06, 'epoch': 19.05}
{'loss': 0.2818, 'grad_norm': 0.7903978824615479, 'learning_rate': 2.5510204081632653e-06, 'epoch': 19.1}
{'loss': 0.272, 'grad_norm': 0.5809009671211243, 'learning_rate': 2.423469387755102e-06, 'epoch': 19.14}
{'loss': 0.2502, 'grad_norm': 0.721860945224762, 'learning_rate': 2.295918367346939e-06, 'epoch': 19.19}
{'loss': 0.2747, 'grad_norm': 0.9174560308456421, 'learning_rate': 2.1683673469387757e-06, 'epoch': 19.23}
{'loss': 0.2612, 'grad_norm': 0.7184321284294128, 'learning_rate': 2.040816326530612e-06, 'epoch': 19.28}
{'loss': 0.2779, 'grad_norm': 0.7707904577255249, 'learning_rate': 1.91326530612

{'eval_loss': 0.39049065113067627, 'eval_accuracy': 0.8331349796687494, 'eval_runtime': 24.9873, 'eval_samples_per_second': 807.051, 'eval_steps_per_second': 3.162, 'epoch': 20.0}
{'train_runtime': 2293.9303, 'train_samples_per_second': 492.334, 'train_steps_per_second': 1.927, 'train_loss': 0.2993068355240973, 'epoch': 20.0}
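Over the 20 epochs, validation accuracy stays between about 0.818 (epoch 11) and 0.838 (epoch 6), and epoch 6 also has the lowest validation loss (0.372), while the training loss keeps falling from about 0.35 in the first epoch to roughly 0.25 to 0.28 in the last. Picking the best epoch from such a text log can be sketched as below; `best_eval` is a hypothetical helper, and the excerpt copies three evaluation entries from the log above with the runtime keys omitted.

```python
import ast


def best_eval(log_text, metric="eval_accuracy", higher_is_better=True):
    """Return the evaluation entry with the best `metric` from a Trainer-style text log."""
    evals = []
    for line in log_text.splitlines():
        line = line.strip()
        # Evaluation entries are printed as Python dict literals starting with 'eval_loss';
        # truncated training lines never match and are skipped
        if line.startswith("{'eval_loss'"):
            evals.append(ast.literal_eval(line))
    if not evals:
        raise ValueError("no evaluation entries found")
    pick = max if higher_is_better else min
    return pick(evals, key=lambda entry: entry[metric])


# Excerpt of the log above (epochs 5 to 7, runtime keys omitted)
log_excerpt = """
{'eval_loss': 0.3860858678817749, 'eval_accuracy': 0.8289199642963404, 'epoch': 5.0}
{'eval_loss': 0.3716958463191986, 'eval_accuracy': 0.8376971139541803, 'epoch': 6.0}
{'eval_loss': 0.39719563722610474, 'eval_accuracy': 0.8281761380541506, 'epoch': 7.0}
"""
print(best_eval(log_excerpt)["epoch"])  # 6.0
print(best_eval(log_excerpt, "eval_loss", higher_is_better=False)["epoch"])  # 6.0
```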


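The logged learning rates also follow a recognisable pattern. Judging from the log alone (we have not checked the classifier's training arguments), they match a linear warmup over 500 steps to a peak of 5e-5, followed by a linear decay to zero at step 4,420 (20 epochs of 221 steps, with a log entry every 10 steps). This is the shape of the Hugging Face `Trainer`'s default `linear` scheduler. The sketch below reimplements that schedule and checks it against four logged values; `linear_warmup_decay` is a hypothetical name, and the step numbers are inferred from the logged epochs.

```python
import math


def linear_warmup_decay(step, peak_lr=5e-5, warmup_steps=500, total_steps=4420):
    """Learning rate at `step`: linear warmup from 0 to peak_lr, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


# Logged learning rates, keyed by the global step inferred from the logged epoch
logged = {
    10: 1.0000000000000002e-06,  # epoch 0.05
    500: 5e-05,  # epoch 2.26
    510: 4.987244897959184e-05,  # epoch 2.31
    4200: 2.806122448979592e-06,  # epoch 19.0
}
matches = {step: math.isclose(linear_warmup_decay(step), lr) for step, lr in logged.items()}
print(matches)
```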
In [10]:
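import torch

# Compare the GNN weights after joint training with the backed-up GNN checkpoint.
# Note: .to("cpu") moves the model itself to the CPU, not a copy.
tens_1 = graph_representation_generator_graph_prompter_hf.model.to(
    device="cpu"
).state_dict()

# map_location="cpu" keeps both sides on one device even if the backup was saved from the GPU
tens_2 = torch.load("./data/gnn/backup/model_128.pth", map_location="cpu")
for key, value in tens_1.items():
    if isinstance(value, torch.Tensor):
        # torch.equal is True only if the shapes and all elements match
        if torch.equal(value, tens_2[key]):
            print(key, "is same")
        else:
            print(key, "is not same")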

target_lin.weight is not same
target_lin.bias is not same
source_emb.weight is not same
target_emb.weight is not same
gnn.conv1.source__edge__target.lin_l.weight is not same
gnn.conv1.source__edge__target.lin_l.bias is not same
gnn.conv1.source__edge__target.lin_r.weight is not same
gnn.conv1.target__rev_edge__source.lin_l.weight is not same
gnn.conv1.target__rev_edge__source.lin_l.bias is not same
gnn.conv1.target__rev_edge__source.lin_r.weight is not same
gnn.conv2.source__edge__target.lin_l.weight is not same
gnn.conv2.source__edge__target.lin_l.bias is not same
gnn.conv2.source__edge__target.lin_r.weight is not same
gnn.conv2.target__rev_edge__source.lin_l.weight is not same
gnn.conv2.target__rev_edge__source.lin_l.bias is not same
gnn.conv2.target__rev_edge__source.lin_r.weight is not same
