# Training and Evaluation of GNNs and LLMs
In this notebook, we train the models on the [MovieLens Dataset](https://movielens.org/), following the PyTorch Geometric tutorial on [Link Prediction](https://colab.research.google.com/drive/1xpzn1Nvai1ygd_P5Yambc_oe4VBPK_ZT?usp=sharing#scrollTo=vit8xKCiXAue).

First we import all of our dependencies.

The **GNNTrainer** manages and trains a GNN model. Its most important interfaces are:

- **the constructor**, which defines the GNN architecture and loads the pre-trained GNN model if it is already on disk,
- **the training method**, which trains the GNN model, and
- **the `get_embedding`/`get_embeddings` methods**, which form the inference interface to the GNN model: for given user-movie node pairs, they return the corresponding embeddings in the dimension defined in the constructor.

**The MovieLensLoader** loads and manages the datasets. Its main tasks are **saving, (re)loading, and transforming** the datasets.

**PromptEncoderOnlyClassifier** and **VanillaEncoderOnlyClassifier** manage a **prompt LLM** and a **vanilla LLM**, respectively. This notebook imports the BERT-based classes `PromptBertClassifier` and `VanillaBertClassifier`. An EncoderOnlyClassifier (ClassifierBase) provides interfaces for training and testing an LLM.
The prompt and vanilla classifiers differ in their DataCollators. A DataCollator changes how batches are built while the model is trained and tested, and it allows data points to be created at runtime. With the help of these collators, we **create non-existent edges on the fly**.
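Creating non-existent edges on the fly is essentially negative sampling inside the collator. The following is a minimal, framework-free sketch of the idea. It assumes that each batch is a list of positive `(user, movie)` index pairs and that the set of all known edges is available. The name `collate_with_negatives` and the one-negative-per-positive ratio are illustrative choices, not the project's actual DataCollator.

```python
import random


def collate_with_negatives(batch, known_edges, num_movies, rng=None):
    """Return (user, movie, label) triples: each positive pair plus one sampled
    movie per user that has no known edge to that user (label 0).

    Assumes every user has at least one movie they have not interacted with,
    otherwise the rejection-sampling loop would not terminate.
    """
    rng = rng or random.Random()
    examples = []
    for user, movie in batch:
        examples.append((user, movie, 1))  # existing edge -> label 1
        # Rejection-sample a movie that is not connected to this user.
        while True:
            candidate = rng.randrange(num_movies)
            if (user, candidate) not in known_edges:
                break
        examples.append((user, candidate, 0))  # non-existent edge -> label 0
    rng.shuffle(examples)
    return examples


known = {(0, 0), (0, 2), (0, 5), (1, 2)}
batch = [(0, 0), (1, 2)]
out = collate_with_negatives(batch, known, num_movies=10, rng=random.Random(42))
print(out)
```

A real collator would also tokenize the prompts of the sampled pairs. With Hugging Face, such a function can be passed to `Trainer` as `data_collator`.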

In [1]:
from gnn import GNNTrainer
from movie_lens_loader import MovieLensLoader
from llm import PromptBertClassifier, VanillaBertClassifier, AddingEmbeddingsBertClassifierBase

from transformers import AutoConfig

We define the **Knowledge Graph Embedding Dimension (KGE_DIMENSION)**, the output dimension of the GNN encoder, in advance. We want to find the output dimension at which the GNN encoder produces embeddings that significantly improve performance *without exceeding the context length of the LLMs*. In the original tutorial, the KGE_DIMENSION was $64$. Here we compare a small dimension ($4$) with the hidden size of the LLM ($128$ for the chosen BERT model).
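A rough back-of-the-envelope check shows why the dimension matters when KGEs are written into the prompt as text. The sketch assumes that both the user and the movie embedding are serialized into one prompt and that each number costs a fixed number of tokens. `TOKENS_PER_NUMBER` is a hypothetical estimate, because the real count depends on the tokenizer and on how the floats are formatted. The limit of 512 positions is the usual BERT value. For the actual model, it can be read from `AutoConfig.from_pretrained(MODEL_NAME).max_position_embeddings`.

```python
def kge_token_cost(kge_dimension, tokens_per_number, embeddings_per_prompt=2):
    """Tokens added to one prompt when user and movie KGEs are written out as text."""
    return embeddings_per_prompt * kge_dimension * tokens_per_number


MAX_POSITIONS = 512    # usual BERT limit (assumption; check the model config)
TOKENS_PER_NUMBER = 4  # hypothetical average for a float such as "-1.100"

for dim in (4, 64, 128):
    cost = kge_token_cost(dim, TOKENS_PER_NUMBER)
    print(f"KGE dim {dim:>3}: ~{cost} tokens, {cost / MAX_POSITIONS:.0%} of the context")
```

Under these assumptions, the serialized KGEs alone would fill the whole context once the dimension reaches 64. That is why the trade-off between embedding size and prompt length matters.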

In [2]:
MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
MODEL_HIDDEN_SIZE = AutoConfig.from_pretrained(MODEL_NAME).hidden_size
SMALL_KGE_DIMENSION = 4
LARGE_KGE_DIMENSION = MODEL_HIDDEN_SIZE
KGE_DIMENSIONS = [SMALL_KGE_DIMENSION, LARGE_KGE_DIMENSION] # Output Dimension of the GNN Encoder.
EPOCHS = 20
BATCH_SIZE = 256

Next, we instantiate the MovieLensLoader. It downloads the MovieLens dataset (https://files.grouplens.org/datasets/movielens/ml-latest-small.zip) and prepares it for both the GNN and the LLMs. We also pass the embedding dimensions we intend to train with. The first run takes approx. 30 seconds.
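The `mappedUserId` and `mappedMovieId` columns in the output below are contiguous indices. Graph libraries need such indices to address nodes, while the raw MovieLens ids are sparse. The following is a minimal stdlib sketch of such a mapping. It assumes that ratings arrive as `(userId, movieId)` pairs as in the MovieLens `ratings.csv`. The helper `build_index` is hypothetical and is not the loader's actual code.

```python
def build_index(raw_ids):
    """Map each distinct raw id to a contiguous index 0..n-1 (ordered by raw id)."""
    return {raw: idx for idx, raw in enumerate(sorted(set(raw_ids)))}


# Toy (userId, movieId) rating pairs; real MovieLens ids are sparse like these.
ratings = [(1, 1), (1, 3), (1, 6), (5, 3), (5, 47)]
user_index = build_index(u for u, _ in ratings)
movie_index = build_index(m for _, m in ratings)

# Edge list in mapped ids, i.e. the form used as graph edges.
edges = [(user_index[u], movie_index[m]) for u, m in ratings]
print(edges)
```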

In [3]:

movie_lens_loader = MovieLensLoader(kge_dimensions=KGE_DIMENSIONS)

In [4]:
movie_lens_loader.llm_df.head()

| | mappedUserId | mappedMovieId | title | genres | prompt | split | user_embedding_4 | movie_embedding_4 | user_embedding_128 | movie_embedding_128 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | Toy Story (1995) | ['Adventure', 'Animation', 'Children', 'Comedy... | user: 0[SEP]title: Toy Story (1995)[SEP]genres... | val | [-1.1001030206680298, 1.8915419578552246, 1.12... | [-0.5872145891189575, 0.7743701934814453, 1.61... | [0.08414871990680695, -0.8586543798446655, -0.... | [0.16821661591529846, -0.092796690762043, -0.3... |
| 1 | 0 | 2 | Grumpier Old Men (1995) | ['Comedy', 'Romance'] | user: 0[SEP]title: Grumpier Old Men (1995)[SEP... | train | [-1.1077896356582642, 2.1780264377593994, 1.23... | [0.09496331214904785, 0.2714781165122986, 1.88... | [0.05774740129709244, -0.9950929284095764, -0.... | [0.490773469209671, 0.3398897647857666, -0.477... |
| 2 | 0 | 5 | Heat (1995) | ['Action', 'Crime', 'Thriller'] | user: 0[SEP]title: Heat (1995)[SEP]genres: ['A... | train | [-1.2368007898330688, 1.965051293373108, 1.000... | [-0.31836411356925964, -0.3369845151901245, 2.... | [0.12886668741703033, -0.9324935674667358, -0.... | [0.04016011953353882, 0.29606175422668457, -0.... |
| 3 | 0 | 43 | Seven (a.k.a. Se7en) (1995) | ['Mystery', 'Thriller'] | user: 0[SEP]title: Seven (a.k.a. Se7en) (1995)... | train | [-1.1827130317687988, 2.0680408477783203, 1.21... | [-1.0212271213531494, 0.4136617183685303, 2.13... | [0.007488217204809189, -0.8364493250846863, -0... | [0.08211947977542877, 0.01935999095439911, -0.... |
| 4 | 0 | 46 | Usual Suspects, The (1995) | ['Crime', 'Mystery', 'Thriller'] | user: 0[SEP]title: Usual Suspects, The (1995)[... | test | [-1.0296814441680908, 2.030970811843872, 1.094... | [-0.7692033052444458, 0.5881525874137878, 1.53... | [0.07757449150085449, -0.9171958565711975, -0.... | [0.1306394338607788, 0.0786743313074112, -0.00... |


Next, we initialize the GNN trainers (on CUDA, if available), one for each KGE_DIMENSION.
A GNN trainer manages a model, and each model consists of an **encoder** and a **classifier**.

**The encoder** is a parameterized *Graph Convolutional Network (GCN)* with a *2-layer GNN computation graph* and a single *ReLU* activation function in between.

**The classifier** computes the dot product between the source and destination KGEs to derive edge-level predictions.
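The encoder/classifier split can be illustrated with a dependency-free sketch on a tiny homogeneous graph. This is not the project's model. The real encoder works on the heterogeneous user/movie graph with learned weights, whereas this sketch uses fixed weight matrices and plain mean aggregation instead of the exact GCN normalization. It shows the two message-passing layers with a single ReLU in between, followed by the dot-product classifier.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mean_vector(vectors, dim):
    """Element-wise mean of the vectors; a zero vector if there are none."""
    if not vectors:
        return [0.0] * dim
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def conv(h, neighbors, w_self, w_neigh):
    """One layer: h_v' = W_self h_v + W_neigh mean(h_u for u in N(v))."""
    out = {}
    for v, h_v in h.items():
        agg = mean_vector([h[u] for u in neighbors[v]], len(h_v))
        out[v] = [a + b for a, b in zip(matvec(w_self, h_v), matvec(w_neigh, agg))]
    return out


def relu(h):
    return {v: [max(0.0, x) for x in vec] for v, vec in h.items()}


def encode(x, neighbors, layer1, layer2):
    """2-layer encoder with a single ReLU between the layers."""
    return conv(relu(conv(x, neighbors, *layer1)), neighbors, *layer2)


def score(z, src, dst):
    """Dot-product classifier: a higher score means a more likely edge."""
    return sum(a * b for a, b in zip(z[src], z[dst]))


# Toy graph: user u0 is connected to movie m0 but not m1; identity weights.
x = {"u0": [1.0, 0.0], "m0": [0.0, 1.0], "m1": [1.0, 1.0]}
neighbors = {"u0": ["m0"], "m0": ["u0"], "m1": []}
identity = [[1.0, 0.0], [0.0, 1.0]]
z = encode(x, neighbors, (identity, identity), (identity, identity))
print(score(z, "u0", "m0"), score(z, "u0", "m1"))
```

Even with untrained weights, the connected pair scores higher here, because message passing pulls the embeddings of neighbouring nodes together.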

In [5]:
gnn_trainer = GNNTrainer(movie_lens_loader.data, kge_dimension=SMALL_KGE_DIMENSION)
gnn_trainer_large = GNNTrainer(movie_lens_loader.data, hidden_channels=MODEL_HIDDEN_SIZE, kge_dimension=MODEL_HIDDEN_SIZE)

loading pretrained model
Device: 'cuda'
loading pretrained model
Device: 'cuda'


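We then train the models on the link prediction task and evaluate them. Note that `validate_model` is called on the test split (`gnn_test_data`), although the printed metric is labeled "Validation AUC".

If the models are already trained, the training step can be skipped, which is why the `train_model` calls below are commented out. Training both models can take up to 5 minutes.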
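Skipping training works because the constructor loads a model that is already on disk. The following is a minimal sketch of that load-or-train pattern. The checkpoint file name is hypothetical, and JSON stands in for model weights. With PyTorch, one would typically use `torch.save(model.state_dict(), path)` and `model.load_state_dict(torch.load(path))` instead.

```python
import json
import os
import tempfile


def load_or_train(checkpoint_path, train_fn):
    """Return cached parameters if the checkpoint exists; otherwise train and cache them."""
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            return json.load(f), "loaded"
    params = train_fn()
    with open(checkpoint_path, "w") as f:
        json.dump(params, f)
    return params, "trained"


path = os.path.join(tempfile.mkdtemp(), "gnn_kge_4.json")  # hypothetical file name
first = load_or_train(path, lambda: {"kge_dimension": 4})
# The second call finds the checkpoint, so its train_fn is never executed.
second = load_or_train(path, lambda: {"kge_dimension": -1})
print(first, second)
```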

In [6]:

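# Small model (KGE_DIMENSION = 4). Uncomment train_model to retrain instead of
# evaluating the pretrained weights loaded by the constructor.
print(gnn_trainer.kge_dimension)
#gnn_trainer.train_model(movie_lens_loader.gnn_train_data, EPOCHS)
gnn_trainer.validate_model(movie_lens_loader.gnn_test_data)
# Large model (KGE_DIMENSION = MODEL_HIDDEN_SIZE).
print(gnn_trainer_large.kge_dimension)
#gnn_trainer_large.train_model(movie_lens_loader.gnn_train_data, EPOCHS)
gnn_trainer_large.validate_model(movie_lens_loader.gnn_test_data)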


4


100%|██████████| 79/79 [00:04<00:00, 17.09it/s]



Validation AUC: 0.9342
128


100%|██████████| 79/79 [00:04<00:00, 17.10it/s]


Validation AUC: 0.9280
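The metric above is the ROC AUC: the probability that a randomly chosen existing edge receives a higher score than a randomly chosen non-existent one. It is commonly computed with scikit-learn's `roc_auc_score`. The following dependency-free sketch computes the same quantity by pairwise comparison, with ties counting as one half. It is meant for sanity checks on small inputs only.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count 0.5. This is O(P * N), fine for a sanity check but not for
    large evaluation sets.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 1, 0, 0]
scores = [0.9, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))
```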





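Next, we compute the KGEs for every edge in the dataset. The LLMs can then use these embeddings for the link prediction task. Note that the same user (user 0) has slightly different user embeddings for different movies, so the embeddings are computed per user-movie pair rather than looked up once per node.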

In [7]:
gnn_trainer.get_embeddings(movie_lens_loader)
gnn_trainer_large.get_embeddings(movie_lens_loader)
movie_lens_loader.llm_df.head()

| | mappedUserId | mappedMovieId | title | genres | prompt | split | user_embedding_4 | movie_embedding_4 | user_embedding_128 | movie_embedding_128 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | Toy Story (1995) | ['Adventure', 'Animation', 'Children', 'Comedy... | user: 0[SEP]title: Toy Story (1995)[SEP]genres... | val | [-1.1001030206680298, 1.8915419578552246, 1.12... | [-0.5872145891189575, 0.7743701934814453, 1.61... | [0.08414871990680695, -0.8586543798446655, -0.... | [0.16821661591529846, -0.092796690762043, -0.3... |
| 1 | 0 | 2 | Grumpier Old Men (1995) | ['Comedy', 'Romance'] | user: 0[SEP]title: Grumpier Old Men (1995)[SEP... | train | [-1.1077896356582642, 2.1780264377593994, 1.23... | [0.09496331214904785, 0.2714781165122986, 1.88... | [0.05774740129709244, -0.9950929284095764, -0.... | [0.490773469209671, 0.3398897647857666, -0.477... |
| 2 | 0 | 5 | Heat (1995) | ['Action', 'Crime', 'Thriller'] | user: 0[SEP]title: Heat (1995)[SEP]genres: ['A... | train | [-1.2368007898330688, 1.965051293373108, 1.000... | [-0.31836411356925964, -0.3369845151901245, 2.... | [0.12886668741703033, -0.9324935674667358, -0.... | [0.04016011953353882, 0.29606175422668457, -0.... |
| 3 | 0 | 43 | Seven (a.k.a. Se7en) (1995) | ['Mystery', 'Thriller'] | user: 0[SEP]title: Seven (a.k.a. Se7en) (1995)... | train | [-1.1827130317687988, 2.0680408477783203, 1.21... | [-1.0212271213531494, 0.4136617183685303, 2.13... | [0.007488217204809189, -0.8364493250846863, -0... | [0.08211947977542877, 0.01935999095439911, -0.... |
| 4 | 0 | 46 | Usual Suspects, The (1995) | ['Crime', 'Mystery', 'Thriller'] | user: 0[SEP]title: Usual Suspects, The (1995)[... | test | [-1.0296814441680908, 2.030970811843872, 1.094... | [-0.7692033052444458, 0.5881525874137878, 1.53... | [0.07757449150085449, -0.9171958565711975, -0.... | [0.1306394338607788, 0.0786743313074112, -0.00... |


Next, we initialize the vanilla encoder-only classifier. This classifier uses only the natural-language part of the prompt (no KGEs) to predict whether the given link exists.
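The layout of the text-only prompt can be read off the `prompt` column shown earlier: `user: 0[SEP]title: Heat (1995)[SEP]genres: ['A...`. The following sketch rebuilds that layout. Because the column is truncated in the output, it assumes that nothing follows the genres field. The function `vanilla_prompt` is hypothetical and is not part of the project.

```python
def vanilla_prompt(user_id, title, genres, sep="[SEP]"):
    """Text-only prompt in the layout visible in the `prompt` column (no KGEs).

    `genres` is a list; formatting it with an f-string yields its Python repr,
    e.g. "['Action', 'Crime']", matching the displayed column.
    """
    return f"user: {user_id}{sep}title: {title}{sep}genres: {genres}"


p = vanilla_prompt(0, "Heat (1995)", ["Action", "Crime", "Thriller"])
print(p)
```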

In [8]:
vanilla_bert_classifier = VanillaBertClassifier(movie_lens_loader.llm_df, batch_size=BATCH_SIZE, model_name=MODEL_NAME)

Next, we generate the vanilla LLM dataset and tokenize it for training.

In [9]:
dataset_vanilla = movie_lens_loader.generate_vanilla_dataset(vanilla_bert_classifier.tokenize_function)

Next, we train the model on this dataset. This step can be skipped if the model has already been trained (the call below is commented out).

In [10]:
#vanilla_bert_classifier.train_model_on_data(dataset_vanilla, epochs=EPOCHS)

Next, we initialize the prompt encoder-only classifier. This classifier uses both the vanilla prompt and the KGEs for link prediction.

In [11]:
prompt_bert_classifier = PromptBertClassifier(movie_lens_loader, gnn_trainer.get_embedding, kge_dimension=SMALL_KGE_DIMENSION, batch_size=BATCH_SIZE, model_name=MODEL_NAME)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We also generate a prompt dataset. This time, the prompts also include the KGEs.
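The exact layout in which `generate_prompt_embedding_dataset` writes the KGEs into the prompt is not visible here. The following sketch therefore uses a hypothetical layout: fixed-precision floats appended after the vanilla prompt, separated by `[SEP]`. It shows how a KGE becomes prompt text and why the prompt grows linearly with the KGE dimension. The names `format_kge` and `kge_prompt` and the field labels are illustrative assumptions.

```python
def format_kge(values, precision=3):
    """Serialize an embedding as text with a fixed number of decimals."""
    return "[" + ", ".join(f"{v:.{precision}f}" for v in values) + "]"


def kge_prompt(vanilla, user_kge, movie_kge, sep="[SEP]"):
    """Append user and movie KGEs to a vanilla prompt (hypothetical layout)."""
    return (f"{vanilla}{sep}user embedding: {format_kge(user_kge)}"
            f"{sep}movie embedding: {format_kge(movie_kge)}")


base = "user: 0[SEP]title: Heat (1995)[SEP]genres: ['Action', 'Crime', 'Thriller']"
p4 = kge_prompt(base, [0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8])
p128 = kge_prompt(base, [0.0] * 128, [0.0] * 128)
print(p4)
print(len(p4), len(p128))  # the 128-dim prompt is much longer
```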

In [12]:
dataset_prompt = movie_lens_loader.generate_prompt_embedding_dataset(prompt_bert_classifier.tokenize_function, kge_dimension=prompt_bert_classifier.kge_dimension)

Next, we train the model. This step can be skipped if the model has already been trained.

In [13]:
prompt_bert_classifier.train_model_on_data(dataset_prompt, epochs=EPOCHS)

  0%|          | 0/4420 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 0.679, 'grad_norm': 2.0133321285247803, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05}
{'loss': 0.6778, 'grad_norm': 2.2803542613983154, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.09}
{'loss': 0.673, 'grad_norm': 1.5807524919509888, 'learning_rate': 3e-06, 'epoch': 0.14}
{'loss': 0.6673, 'grad_norm': 1.5122690200805664, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.18}
{'loss': 0.6604, 'grad_norm': 1.2263641357421875, 'learning_rate': 5e-06, 'epoch': 0.23}
{'loss': 0.654, 'grad_norm': 1.9036740064620972, 'learning_rate': 6e-06, 'epoch': 0.27}
{'loss': 0.6471, 'grad_norm': 1.2255926132202148, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.32}
{'loss': 0.6484, 'grad_norm': 0.3636559844017029, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.36}
{'loss': 0.6453, 'grad_norm': 0.5044724941253662, 'learning_rate': 9e-06, 'epoch': 0.41}
{'loss': 0.6391, 'grad_norm': 0.8040212988853455, 'learning_rate': 1e-05, 'epoch': 0.45}
{'loss': 0.6327, 'grad_norm': 0

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.6177912354469299, 'eval_accuracy': 0.6661999766654999, 'eval_runtime': 166.6224, 'eval_samples_per_second': 102.879, 'eval_steps_per_second': 0.402, 'epoch': 1.0}
{'loss': 0.625, 'grad_norm': 0.40551435947418213, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.04}
{'loss': 0.6142, 'grad_norm': 0.41211628913879395, 'learning_rate': 2.4e-05, 'epoch': 1.09}
{'loss': 0.6158, 'grad_norm': 0.9771807193756104, 'learning_rate': 2.5e-05, 'epoch': 1.13}
{'loss': 0.6055, 'grad_norm': 0.33162227272987366, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.18}
{'loss': 0.6205, 'grad_norm': 0.8342613577842712, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.22}
{'loss': 0.622, 'grad_norm': 0.8178257942199707, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.27}
{'loss': 0.6343, 'grad_norm': 0.9120437502861023, 'learning_rate': 2.9e-05, 'epoch': 1.31}
{'loss': 0.6142, 'grad_norm': 0.8860794901847839, 'learning_rate': 3e-05, 'epoch': 1.36}
{'loss': 0.6226, 'grad_norm': 0.56

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.5217620134353638, 'eval_accuracy': 0.7382452455956131, 'eval_runtime': 168.6952, 'eval_samples_per_second': 101.615, 'eval_steps_per_second': 0.397, 'epoch': 2.0}
{'loss': 0.5541, 'grad_norm': 1.1645301580429077, 'learning_rate': 4.5e-05, 'epoch': 2.04}
{'loss': 0.5285, 'grad_norm': 2.0737507343292236, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.08}
{'loss': 0.5417, 'grad_norm': 1.7233505249023438, 'learning_rate': 4.7e-05, 'epoch': 2.13}
{'loss': 0.515, 'grad_norm': 5.963149547576904, 'learning_rate': 4.8e-05, 'epoch': 2.17}
{'loss': 0.5224, 'grad_norm': 2.123399019241333, 'learning_rate': 4.9e-05, 'epoch': 2.22}
{'loss': 0.5018, 'grad_norm': 1.1121021509170532, 'learning_rate': 5e-05, 'epoch': 2.26}
{'loss': 0.5205, 'grad_norm': 1.1940871477127075, 'learning_rate': 4.987244897959184e-05, 'epoch': 2.31}
{'loss': 0.497, 'grad_norm': 1.5084012746810913, 'learning_rate': 4.974489795918368e-05, 'epoch': 2.35}
{'loss': 0.5032, 'grad_norm': 1.2528804540634155, 'learni

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.419197142124176, 'eval_accuracy': 0.8065569945163925, 'eval_runtime': 167.6274, 'eval_samples_per_second': 102.263, 'eval_steps_per_second': 0.4, 'epoch': 3.0}
{'loss': 0.4483, 'grad_norm': 2.6346092224121094, 'learning_rate': 4.783163265306123e-05, 'epoch': 3.03}
{'loss': 0.4699, 'grad_norm': 2.4391942024230957, 'learning_rate': 4.7704081632653066e-05, 'epoch': 3.08}
{'loss': 0.4622, 'grad_norm': 1.942014217376709, 'learning_rate': 4.7576530612244904e-05, 'epoch': 3.12}
{'loss': 0.4768, 'grad_norm': 1.635702133178711, 'learning_rate': 4.744897959183674e-05, 'epoch': 3.17}
{'loss': 0.4475, 'grad_norm': 1.3766586780548096, 'learning_rate': 4.732142857142857e-05, 'epoch': 3.21}
{'loss': 0.4543, 'grad_norm': 1.2722172737121582, 'learning_rate': 4.719387755102041e-05, 'epoch': 3.26}
{'loss': 0.4439, 'grad_norm': 1.9239388704299927, 'learning_rate': 4.706632653061225e-05, 'epoch': 3.3}
{'loss': 0.4415, 'grad_norm': 2.896433115005493, 'learning_rate': 4.6938775510204086e-05, 

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.3925262689590454, 'eval_accuracy': 0.8168825107922063, 'eval_runtime': 168.2904, 'eval_samples_per_second': 101.86, 'eval_steps_per_second': 0.398, 'epoch': 4.0}
{'loss': 0.4493, 'grad_norm': 1.1924556493759155, 'learning_rate': 4.502551020408164e-05, 'epoch': 4.03}
{'loss': 0.4475, 'grad_norm': 1.7662426233291626, 'learning_rate': 4.4897959183673474e-05, 'epoch': 4.07}
{'loss': 0.4268, 'grad_norm': 1.5061163902282715, 'learning_rate': 4.477040816326531e-05, 'epoch': 4.12}
{'loss': 0.4318, 'grad_norm': 1.804823398590088, 'learning_rate': 4.464285714285715e-05, 'epoch': 4.16}
{'loss': 0.4332, 'grad_norm': 2.086886405944824, 'learning_rate': 4.451530612244898e-05, 'epoch': 4.21}
{'loss': 0.4314, 'grad_norm': 2.6553871631622314, 'learning_rate': 4.438775510204082e-05, 'epoch': 4.25}
{'loss': 0.4231, 'grad_norm': 1.379791259765625, 'learning_rate': 4.4260204081632656e-05, 'epoch': 4.3}
{'loss': 0.4266, 'grad_norm': 1.9438687562942505, 'learning_rate': 4.4132653061224493e-05

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.37018483877182007, 'eval_accuracy': 0.8317582545793957, 'eval_runtime': 167.3209, 'eval_samples_per_second': 102.45, 'eval_steps_per_second': 0.4, 'epoch': 5.0}
{'loss': 0.4099, 'grad_norm': 2.0873093605041504, 'learning_rate': 4.2219387755102045e-05, 'epoch': 5.02}
{'loss': 0.4082, 'grad_norm': 4.340262413024902, 'learning_rate': 4.209183673469388e-05, 'epoch': 5.07}
{'loss': 0.4023, 'grad_norm': 2.3191778659820557, 'learning_rate': 4.196428571428572e-05, 'epoch': 5.11}
{'loss': 0.4196, 'grad_norm': 2.7521753311157227, 'learning_rate': 4.183673469387756e-05, 'epoch': 5.16}
{'loss': 0.409, 'grad_norm': 1.5038940906524658, 'learning_rate': 4.170918367346939e-05, 'epoch': 5.2}
{'loss': 0.4159, 'grad_norm': 1.8419475555419922, 'learning_rate': 4.1581632653061226e-05, 'epoch': 5.25}
{'loss': 0.4081, 'grad_norm': 2.533839225769043, 'learning_rate': 4.1454081632653064e-05, 'epoch': 5.29}
{'loss': 0.4164, 'grad_norm': 1.6267328262329102, 'learning_rate': 4.13265306122449e-05, 

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.3563588261604309, 'eval_accuracy': 0.8370668533426672, 'eval_runtime': 168.6328, 'eval_samples_per_second': 101.653, 'eval_steps_per_second': 0.397, 'epoch': 6.0}
{'loss': 0.3889, 'grad_norm': 1.6959741115570068, 'learning_rate': 3.9413265306122446e-05, 'epoch': 6.02}
{'loss': 0.4175, 'grad_norm': 3.336632013320923, 'learning_rate': 3.928571428571429e-05, 'epoch': 6.06}
{'loss': 0.3972, 'grad_norm': 1.762070655822754, 'learning_rate': 3.915816326530613e-05, 'epoch': 6.11}
{'loss': 0.3838, 'grad_norm': 1.4784995317459106, 'learning_rate': 3.9030612244897965e-05, 'epoch': 6.15}
{'loss': 0.383, 'grad_norm': 1.5206083059310913, 'learning_rate': 3.8903061224489796e-05, 'epoch': 6.2}
{'loss': 0.3851, 'grad_norm': 1.5879034996032715, 'learning_rate': 3.8775510204081634e-05, 'epoch': 6.24}
{'loss': 0.3905, 'grad_norm': 1.6178771257400513, 'learning_rate': 3.864795918367347e-05, 'epoch': 6.29}
{'loss': 0.3833, 'grad_norm': 3.3765218257904053, 'learning_rate': 3.852040816326531e-

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.34895142912864685, 'eval_accuracy': 0.8432504958581263, 'eval_runtime': 166.8057, 'eval_samples_per_second': 102.766, 'eval_steps_per_second': 0.402, 'epoch': 7.0}
{'loss': 0.366, 'grad_norm': 1.9183104038238525, 'learning_rate': 3.6607142857142853e-05, 'epoch': 7.01}
{'loss': 0.3961, 'grad_norm': 1.4231370687484741, 'learning_rate': 3.64795918367347e-05, 'epoch': 7.06}
{'loss': 0.3851, 'grad_norm': 1.85813307762146, 'learning_rate': 3.6352040816326536e-05, 'epoch': 7.1}
{'loss': 0.384, 'grad_norm': 3.181644916534424, 'learning_rate': 3.622448979591837e-05, 'epoch': 7.15}
{'loss': 0.4002, 'grad_norm': 1.6260340213775635, 'learning_rate': 3.609693877551021e-05, 'epoch': 7.19}
{'loss': 0.3711, 'grad_norm': 1.5887869596481323, 'learning_rate': 3.596938775510204e-05, 'epoch': 7.24}
{'loss': 0.3816, 'grad_norm': 4.565511226654053, 'learning_rate': 3.584183673469388e-05, 'epoch': 7.29}
{'loss': 0.3827, 'grad_norm': 2.1093719005584717, 'learning_rate': 3.571428571428572e-05, '

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.33202528953552246, 'eval_accuracy': 0.8497258196243146, 'eval_runtime': 166.7377, 'eval_samples_per_second': 102.808, 'eval_steps_per_second': 0.402, 'epoch': 8.0}
{'loss': 0.3895, 'grad_norm': 1.75002920627594, 'learning_rate': 3.380102040816326e-05, 'epoch': 8.01}
{'loss': 0.3566, 'grad_norm': 3.6180458068847656, 'learning_rate': 3.36734693877551e-05, 'epoch': 8.05}
{'loss': 0.3943, 'grad_norm': 2.6973016262054443, 'learning_rate': 3.354591836734694e-05, 'epoch': 8.1}
{'loss': 0.3452, 'grad_norm': 2.408785104751587, 'learning_rate': 3.341836734693878e-05, 'epoch': 8.14}
{'loss': 0.388, 'grad_norm': 4.599977970123291, 'learning_rate': 3.329081632653062e-05, 'epoch': 8.19}
{'loss': 0.3629, 'grad_norm': 1.7526192665100098, 'learning_rate': 3.316326530612245e-05, 'epoch': 8.24}
{'loss': 0.365, 'grad_norm': 3.8998208045959473, 'learning_rate': 3.303571428571429e-05, 'epoch': 8.28}
{'loss': 0.3747, 'grad_norm': 2.53245210647583, 'learning_rate': 3.2908163265306125e-05, 'epo

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.32349950075149536, 'eval_accuracy': 0.8561428071403571, 'eval_runtime': 167.6532, 'eval_samples_per_second': 102.247, 'eval_steps_per_second': 0.4, 'epoch': 9.0}
{'loss': 0.3554, 'grad_norm': 2.5443665981292725, 'learning_rate': 3.0994897959183676e-05, 'epoch': 9.0}
{'loss': 0.3689, 'grad_norm': 1.851567029953003, 'learning_rate': 3.086734693877551e-05, 'epoch': 9.05}
{'loss': 0.3611, 'grad_norm': 2.3390023708343506, 'learning_rate': 3.073979591836735e-05, 'epoch': 9.1}
{'loss': 0.358, 'grad_norm': 2.5125954151153564, 'learning_rate': 3.061224489795919e-05, 'epoch': 9.14}
{'loss': 0.3631, 'grad_norm': 1.6115801334381104, 'learning_rate': 3.0484693877551023e-05, 'epoch': 9.19}
{'loss': 0.3758, 'grad_norm': 2.3114757537841797, 'learning_rate': 3.0357142857142857e-05, 'epoch': 9.23}
{'loss': 0.3675, 'grad_norm': 1.5353615283966064, 'learning_rate': 3.0229591836734695e-05, 'epoch': 9.28}
{'loss': 0.3697, 'grad_norm': 3.259267807006836, 'learning_rate': 3.0102040816326533e-0

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.31438156962394714, 'eval_accuracy': 0.8595846458989617, 'eval_runtime': 169.1472, 'eval_samples_per_second': 101.344, 'eval_steps_per_second': 0.396, 'epoch': 10.0}
{'loss': 0.3569, 'grad_norm': 1.371767282485962, 'learning_rate': 2.8061224489795918e-05, 'epoch': 10.05}
{'loss': 0.3793, 'grad_norm': 2.1237711906433105, 'learning_rate': 2.7933673469387756e-05, 'epoch': 10.09}
{'loss': 0.3456, 'grad_norm': 2.1659793853759766, 'learning_rate': 2.7806122448979593e-05, 'epoch': 10.14}
{'loss': 0.3583, 'grad_norm': 2.1240532398223877, 'learning_rate': 2.767857142857143e-05, 'epoch': 10.18}
{'loss': 0.3535, 'grad_norm': 4.539743423461914, 'learning_rate': 2.7551020408163265e-05, 'epoch': 10.23}
{'loss': 0.3531, 'grad_norm': 1.9370315074920654, 'learning_rate': 2.7423469387755103e-05, 'epoch': 10.27}
{'loss': 0.3499, 'grad_norm': 2.4600889682769775, 'learning_rate': 2.729591836734694e-05, 'epoch': 10.32}
{'loss': 0.3666, 'grad_norm': 3.2179789543151855, 'learning_rate': 2.71683

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.31501761078834534, 'eval_accuracy': 0.860926379652316, 'eval_runtime': 168.2349, 'eval_samples_per_second': 101.893, 'eval_steps_per_second': 0.398, 'epoch': 11.0}
{'loss': 0.3552, 'grad_norm': 2.161956548690796, 'learning_rate': 2.5255102040816326e-05, 'epoch': 11.04}
{'loss': 0.3386, 'grad_norm': 2.5855789184570312, 'learning_rate': 2.5127551020408164e-05, 'epoch': 11.09}
{'loss': 0.3542, 'grad_norm': 1.6112260818481445, 'learning_rate': 2.5e-05, 'epoch': 11.13}
{'loss': 0.3683, 'grad_norm': 2.7125349044799805, 'learning_rate': 2.487244897959184e-05, 'epoch': 11.18}
{'loss': 0.3504, 'grad_norm': 1.5861746072769165, 'learning_rate': 2.4744897959183673e-05, 'epoch': 11.22}
{'loss': 0.3483, 'grad_norm': 2.1374118328094482, 'learning_rate': 2.461734693877551e-05, 'epoch': 11.27}
{'loss': 0.3383, 'grad_norm': 1.7424129247665405, 'learning_rate': 2.448979591836735e-05, 'epoch': 11.31}
{'loss': 0.3657, 'grad_norm': 1.5313876867294312, 'learning_rate': 2.4362244897959186e-05,

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.31382784247398376, 'eval_accuracy': 0.8614514059036286, 'eval_runtime': 167.5427, 'eval_samples_per_second': 102.314, 'eval_steps_per_second': 0.4, 'epoch': 12.0}
{'loss': 0.3507, 'grad_norm': 3.4585635662078857, 'learning_rate': 2.2448979591836737e-05, 'epoch': 12.04}
{'loss': 0.3504, 'grad_norm': 1.3815233707427979, 'learning_rate': 2.2321428571428575e-05, 'epoch': 12.08}
{'loss': 0.3442, 'grad_norm': 2.639545440673828, 'learning_rate': 2.219387755102041e-05, 'epoch': 12.13}
{'loss': 0.3644, 'grad_norm': 4.172611236572266, 'learning_rate': 2.2066326530612247e-05, 'epoch': 12.17}
{'loss': 0.3344, 'grad_norm': 2.56211519241333, 'learning_rate': 2.193877551020408e-05, 'epoch': 12.22}
{'loss': 0.3467, 'grad_norm': 1.9352631568908691, 'learning_rate': 2.181122448979592e-05, 'epoch': 12.26}
{'loss': 0.3442, 'grad_norm': 2.5657474994659424, 'learning_rate': 2.1683673469387756e-05, 'epoch': 12.31}
{'loss': 0.3547, 'grad_norm': 2.4472920894622803, 'learning_rate': 2.1556122448

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.3130281865596771, 'eval_accuracy': 0.8616847509042119, 'eval_runtime': 167.6683, 'eval_samples_per_second': 102.238, 'eval_steps_per_second': 0.4, 'epoch': 13.0}
{'loss': 0.3263, 'grad_norm': 2.4004828929901123, 'learning_rate': 1.9642857142857145e-05, 'epoch': 13.03}
{'loss': 0.3466, 'grad_norm': 1.8615409135818481, 'learning_rate': 1.9515306122448983e-05, 'epoch': 13.08}
{'loss': 0.3295, 'grad_norm': 2.857135772705078, 'learning_rate': 1.9387755102040817e-05, 'epoch': 13.12}
{'loss': 0.3404, 'grad_norm': 2.4170479774475098, 'learning_rate': 1.9260204081632655e-05, 'epoch': 13.17}
{'loss': 0.3444, 'grad_norm': 1.6658514738082886, 'learning_rate': 1.913265306122449e-05, 'epoch': 13.21}
{'loss': 0.3471, 'grad_norm': 1.6664605140686035, 'learning_rate': 1.9005102040816326e-05, 'epoch': 13.26}
{'loss': 0.3446, 'grad_norm': 1.8387547731399536, 'learning_rate': 1.8877551020408164e-05, 'epoch': 13.3}
{'loss': 0.3355, 'grad_norm': 1.854002833366394, 'learning_rate': 1.87500000

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.3073578178882599, 'eval_accuracy': 0.8659432971648583, 'eval_runtime': 166.9382, 'eval_samples_per_second': 102.685, 'eval_steps_per_second': 0.401, 'epoch': 14.0}
{'loss': 0.3441, 'grad_norm': 4.481644153594971, 'learning_rate': 1.683673469387755e-05, 'epoch': 14.03}
{'loss': 0.319, 'grad_norm': 2.5402743816375732, 'learning_rate': 1.670918367346939e-05, 'epoch': 14.07}
{'loss': 0.3073, 'grad_norm': 2.4481611251831055, 'learning_rate': 1.6581632653061225e-05, 'epoch': 14.12}
{'loss': 0.3349, 'grad_norm': 2.031440019607544, 'learning_rate': 1.6454081632653062e-05, 'epoch': 14.16}
{'loss': 0.3443, 'grad_norm': 1.6451114416122437, 'learning_rate': 1.6326530612244897e-05, 'epoch': 14.21}
{'loss': 0.362, 'grad_norm': 2.9352810382843018, 'learning_rate': 1.6198979591836734e-05, 'epoch': 14.25}
{'loss': 0.3492, 'grad_norm': 1.5065643787384033, 'learning_rate': 1.6071428571428572e-05, 'epoch': 14.3}
{'loss': 0.3501, 'grad_norm': 2.8169615268707275, 'learning_rate': 1.594387755

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.3080599308013916, 'eval_accuracy': 0.8591762921479408, 'eval_runtime': 166.3749, 'eval_samples_per_second': 103.032, 'eval_steps_per_second': 0.403, 'epoch': 15.0}
{'loss': 0.3271, 'grad_norm': 1.6559295654296875, 'learning_rate': 1.4030612244897959e-05, 'epoch': 15.02}
{'loss': 0.3312, 'grad_norm': 1.5949742794036865, 'learning_rate': 1.3903061224489797e-05, 'epoch': 15.07}
{'loss': 0.343, 'grad_norm': 3.2012698650360107, 'learning_rate': 1.3775510204081633e-05, 'epoch': 15.11}
{'loss': 0.3486, 'grad_norm': 3.549734115600586, 'learning_rate': 1.364795918367347e-05, 'epoch': 15.16}
{'loss': 0.3397, 'grad_norm': 3.4084696769714355, 'learning_rate': 1.3520408163265308e-05, 'epoch': 15.2}
{'loss': 0.3302, 'grad_norm': 4.328952789306641, 'learning_rate': 1.3392857142857144e-05, 'epoch': 15.25}
{'loss': 0.3361, 'grad_norm': 1.7581273317337036, 'learning_rate': 1.3265306122448982e-05, 'epoch': 15.29}
{'loss': 0.3462, 'grad_norm': 2.7893874645233154, 'learning_rate': 1.3137755

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.29574859142303467, 'eval_accuracy': 0.8702601796756505, 'eval_runtime': 170.6292, 'eval_samples_per_second': 100.463, 'eval_steps_per_second': 0.393, 'epoch': 16.0}
{'loss': 0.3344, 'grad_norm': 2.137279748916626, 'learning_rate': 1.1224489795918369e-05, 'epoch': 16.02}
{'loss': 0.3236, 'grad_norm': 1.4718824625015259, 'learning_rate': 1.1096938775510205e-05, 'epoch': 16.06}
{'loss': 0.3252, 'grad_norm': 1.6820440292358398, 'learning_rate': 1.096938775510204e-05, 'epoch': 16.11}
{'loss': 0.3387, 'grad_norm': 2.5989990234375, 'learning_rate': 1.0841836734693878e-05, 'epoch': 16.15}
{'loss': 0.3318, 'grad_norm': 2.3536739349365234, 'learning_rate': 1.0714285714285714e-05, 'epoch': 16.2}
{'loss': 0.3307, 'grad_norm': 2.654719352722168, 'learning_rate': 1.0586734693877552e-05, 'epoch': 16.24}
{'loss': 0.3243, 'grad_norm': 1.679227352142334, 'learning_rate': 1.045918367346939e-05, 'epoch': 16.29}
{'loss': 0.3573, 'grad_norm': 1.5329948663711548, 'learning_rate': 1.0331632653

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.29735779762268066, 'eval_accuracy': 0.8696768171741921, 'eval_runtime': 167.6554, 'eval_samples_per_second': 102.245, 'eval_steps_per_second': 0.4, 'epoch': 17.0}
{'loss': 0.3252, 'grad_norm': 2.0683369636535645, 'learning_rate': 8.418367346938775e-06, 'epoch': 17.01}
{'loss': 0.3362, 'grad_norm': 3.434779644012451, 'learning_rate': 8.290816326530612e-06, 'epoch': 17.06}
{'loss': 0.3298, 'grad_norm': 1.6070014238357544, 'learning_rate': 8.163265306122448e-06, 'epoch': 17.1}
{'loss': 0.3381, 'grad_norm': 3.573881149291992, 'learning_rate': 8.035714285714286e-06, 'epoch': 17.15}
{'loss': 0.3225, 'grad_norm': 1.8341189622879028, 'learning_rate': 7.908163265306124e-06, 'epoch': 17.19}
{'loss': 0.3309, 'grad_norm': 1.9570201635360718, 'learning_rate': 7.78061224489796e-06, 'epoch': 17.24}
{'loss': 0.3157, 'grad_norm': 2.583948850631714, 'learning_rate': 7.653061224489797e-06, 'epoch': 17.29}
{'loss': 0.3317, 'grad_norm': 4.233783721923828, 'learning_rate': 7.525510204081633e

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.29272645711898804, 'eval_accuracy': 0.8714852409287132, 'eval_runtime': 170.2087, 'eval_samples_per_second': 100.712, 'eval_steps_per_second': 0.394, 'epoch': 18.0}
{'loss': 0.326, 'grad_norm': 1.4779797792434692, 'learning_rate': 5.612244897959184e-06, 'epoch': 18.01}
{'loss': 0.3435, 'grad_norm': 1.8217583894729614, 'learning_rate': 5.48469387755102e-06, 'epoch': 18.05}
{'loss': 0.338, 'grad_norm': 1.8670899868011475, 'learning_rate': 5.357142857142857e-06, 'epoch': 18.1}
{'loss': 0.3447, 'grad_norm': 2.4358131885528564, 'learning_rate': 5.229591836734695e-06, 'epoch': 18.14}
{'loss': 0.3281, 'grad_norm': 2.4266974925994873, 'learning_rate': 5.102040816326531e-06, 'epoch': 18.19}
{'loss': 0.3206, 'grad_norm': 3.318112373352051, 'learning_rate': 4.9744897959183674e-06, 'epoch': 18.24}
{'loss': 0.3292, 'grad_norm': 6.297029495239258, 'learning_rate': 4.846938775510204e-06, 'epoch': 18.28}
{'loss': 0.3276, 'grad_norm': 2.546241044998169, 'learning_rate': 4.71938775510204

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.2934187054634094, 'eval_accuracy': 0.8725352934313383, 'eval_runtime': 168.023, 'eval_samples_per_second': 102.022, 'eval_steps_per_second': 0.399, 'epoch': 19.0}
{'loss': 0.3211, 'grad_norm': 1.892685055732727, 'learning_rate': 2.806122448979592e-06, 'epoch': 19.0}
{'loss': 0.3317, 'grad_norm': 2.6029622554779053, 'learning_rate': 2.6785714285714285e-06, 'epoch': 19.05}
{'loss': 0.3508, 'grad_norm': 2.1942896842956543, 'learning_rate': 2.5510204081632653e-06, 'epoch': 19.1}
{'loss': 0.351, 'grad_norm': 2.260469913482666, 'learning_rate': 2.423469387755102e-06, 'epoch': 19.14}
{'loss': 0.3403, 'grad_norm': 2.1681981086730957, 'learning_rate': 2.295918367346939e-06, 'epoch': 19.19}
{'loss': 0.3134, 'grad_norm': 2.7360994815826416, 'learning_rate': 2.1683673469387757e-06, 'epoch': 19.23}
{'loss': 0.3426, 'grad_norm': 3.0878658294677734, 'learning_rate': 2.040816326530612e-06, 'epoch': 19.28}
{'loss': 0.3226, 'grad_norm': 1.561689853668213, 'learning_rate': 1.9132653061224

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.28821152448654175, 'eval_accuracy': 0.8773188659432971, 'eval_runtime': 169.1877, 'eval_samples_per_second': 101.319, 'eval_steps_per_second': 0.396, 'epoch': 20.0}
{'train_runtime': 14044.0232, 'train_samples_per_second': 80.417, 'train_steps_per_second': 0.315, 'train_loss': 0.39407227238918324, 'epoch': 20.0}
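The logged learning rates are consistent with a linear warmup over 500 steps to a peak of $5 \times 10^{-5}$, followed by a linear decay to zero at step 4420. That total corresponds to 20 epochs of 221 steps. 221 steps per epoch would match ceil(56469 / 256), if the training split has the 56469 examples shown in the progress bar of cell [14]. These settings are inferred from the log, not read from the training code. Their shape matches the linear schedule that Hugging Face `Trainer` uses by default (`lr_scheduler_type="linear"`). A short sketch reproduces the logged values:

```python
def linear_warmup_lr(step, peak_lr=5e-5, warmup_steps=500, total_steps=4420):
    """Linear warmup from 0 to peak_lr, then linear decay to 0 at total_steps.

    The default arguments are inferred from the training log above.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


# Steps 10 and 230 are the first logged steps of epochs 0 and 1 (logged every 10 steps).
for step in (10, 230, 500, 510, 4420):
    print(step, linear_warmup_lr(step))
```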


In [14]:
adding_embedding_bert_only_classifier = AddingEmbeddingsBertClassifierBase(movie_lens_loader, gnn_trainer_large.get_embedding, kge_dimension=MODEL_HIDDEN_SIZE, batch_size=BATCH_SIZE, model_name=MODEL_NAME)
dataset_adding_embedding = movie_lens_loader.generate_adding_embedding_dataset(adding_embedding_bert_only_classifier.tokenizer.sep_token, adding_embedding_bert_only_classifier.tokenizer.pad_token, adding_embedding_bert_only_classifier.tokenize_function, kge_dimension=MODEL_HIDDEN_SIZE)


Some weights of InsertEmbeddingBertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/56469 [00:00<?, ? examples/s]

Map:   0%|          | 0/17142 [00:00<?, ? examples/s]

Map:   0%|          | 0/17142 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/56469 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/17142 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/17142 [00:00<?, ? examples/s]
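The progress-bar totals of the training run that follows can be derived from the split sizes shown above. With 56,469 training rows and 17,142 rows in each of the other two splits, a batch size of 256 gives exactly the `0/4420` training steps and `0/67` evaluation batches. `BATCH_SIZE` is not printed in this part of the notebook, so 256 is an inferred assumption, as are a single device, no gradient accumulation, and keeping the last partial batch.

```python
import math

TRAIN_EXAMPLES = 56469  # rows in the training split (from the progress bars above)
EVAL_EXAMPLES = 17142   # rows in each of the two smaller splits
EPOCHS = 20             # epochs reached in the training logs
BATCH_SIZE = 256        # assumption: inferred, not shown in this section

# A partial last batch still counts as one step, hence ceil.
steps_per_epoch = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
eval_steps = math.ceil(EVAL_EXAMPLES / BATCH_SIZE)
print(steps_per_epoch, total_steps, eval_steps)  # 221 4420 67
```

The same 221 steps per epoch explain the fractional `epoch` values in the logs: logging every 10 steps gives 10 / 221 ≈ 0.05, 20 / 221 ≈ 0.09, and so on.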

In [15]:
adding_embedding_bert_only_classifier.train_model_on_data(dataset_adding_embedding, epochs=EPOCHS)

  0%|          | 0/4420 [00:00<?, ?it/s]

{'loss': 0.7091, 'grad_norm': 2.509500026702881, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05}
{'loss': 0.7048, 'grad_norm': 2.799487590789795, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.09}
{'loss': 0.7035, 'grad_norm': 2.143673896789551, 'learning_rate': 3e-06, 'epoch': 0.14}
{'loss': 0.6954, 'grad_norm': 2.0925886631011963, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.18}
{'loss': 0.6853, 'grad_norm': 1.866171956062317, 'learning_rate': 5e-06, 'epoch': 0.23}
{'loss': 0.6781, 'grad_norm': 2.5431997776031494, 'learning_rate': 6e-06, 'epoch': 0.27}
{'loss': 0.6691, 'grad_norm': 1.8883814811706543, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.32}
{'loss': 0.6606, 'grad_norm': 0.9473239779472351, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.36}
{'loss': 0.6513, 'grad_norm': 0.37495705485343933, 'learning_rate': 9e-06, 'epoch': 0.41}
{'loss': 0.6463, 'grad_norm': 1.2949178218841553, 'learning_rate': 1e-05, 'epoch': 0.45}
{'loss': 0.6336, 'grad_norm': 0

  0%|          | 0/67 [00:00<?, ?it/s]

{'eval_loss': 0.4853430688381195, 'eval_accuracy': 0.7964648232411621, 'eval_runtime': 148.5075, 'eval_samples_per_second': 115.429, 'eval_steps_per_second': 0.451, 'epoch': 1.0}
{'loss': 0.517, 'grad_norm': 0.9431381821632385, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.04}
{'loss': 0.4844, 'grad_norm': 1.1250807046890259, 'learning_rate': 2.4e-05, 'epoch': 1.09}
{'loss': 0.4854, 'grad_norm': 1.1480079889297485, 'learning_rate': 2.5e-05, 'epoch': 1.13}
{'loss': 0.4693, 'grad_norm': 1.0225285291671753, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.18}
{'loss': 0.4675, 'grad_norm': 1.0478017330169678, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.22}
{'loss': 0.458, 'grad_norm': 1.0481812953948975, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.27}
{'loss': 0.4643, 'grad_norm': 1.0488988161087036, 'learning_rate': 2.9e-05, 'epoch': 1.31}
{'loss': 0.4418, 'grad_norm': 0.9810708165168762, 'learning_rate': 3e-05, 'epoch': 1.36}
{'loss': 0.4526, 'grad_norm': 1.03173

{'eval_loss': 0.3956083059310913, 'eval_accuracy': 0.8163574845408937, 'eval_runtime': 150.7453, 'eval_samples_per_second': 113.715, 'eval_steps_per_second': 0.444, 'epoch': 2.0}
{'loss': 0.3869, 'grad_norm': 0.9085557460784912, 'learning_rate': 4.5e-05, 'epoch': 2.04}
{'loss': 0.3968, 'grad_norm': 1.0455820560455322, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.08}
{'loss': 0.4208, 'grad_norm': 0.9333491921424866, 'learning_rate': 4.7e-05, 'epoch': 2.13}
{'loss': 0.3986, 'grad_norm': 1.1259772777557373, 'learning_rate': 4.8e-05, 'epoch': 2.17}
{'loss': 0.4114, 'grad_norm': 0.8827069401741028, 'learning_rate': 4.9e-05, 'epoch': 2.22}
{'loss': 0.4059, 'grad_norm': 1.017343282699585, 'learning_rate': 5e-05, 'epoch': 2.26}
{'loss': 0.3952, 'grad_norm': 1.0563033819198608, 'learning_rate': 4.987244897959184e-05, 'epoch': 2.31}
{'loss': 0.3979, 'grad_norm': 0.9619131088256836, 'learning_rate': 4.974489795918368e-05, 'epoch': 2.35}
{'loss': 0.3888, 'grad_norm': 1.0475866794586182, 'lea

{'eval_loss': 0.34688109159469604, 'eval_accuracy': 0.8445338933613348, 'eval_runtime': 148.9407, 'eval_samples_per_second': 115.093, 'eval_steps_per_second': 0.45, 'epoch': 3.0}
{'loss': 0.3553, 'grad_norm': 0.9349470138549805, 'learning_rate': 4.783163265306123e-05, 'epoch': 3.03}
{'loss': 0.3645, 'grad_norm': 0.9631088972091675, 'learning_rate': 4.7704081632653066e-05, 'epoch': 3.08}
{'loss': 0.3763, 'grad_norm': 0.9411998391151428, 'learning_rate': 4.7576530612244904e-05, 'epoch': 3.12}
{'loss': 0.3873, 'grad_norm': 1.0161653757095337, 'learning_rate': 4.744897959183674e-05, 'epoch': 3.17}
{'loss': 0.3697, 'grad_norm': 0.9823721647262573, 'learning_rate': 4.732142857142857e-05, 'epoch': 3.21}
{'loss': 0.3716, 'grad_norm': 0.8645493984222412, 'learning_rate': 4.719387755102041e-05, 'epoch': 3.26}
{'loss': 0.373, 'grad_norm': 1.1809028387069702, 'learning_rate': 4.706632653061225e-05, 'epoch': 3.3}
{'loss': 0.3727, 'grad_norm': 0.9025824666023254, 'learning_rate': 4.6938775510204086e

{'eval_loss': 0.3236522674560547, 'eval_accuracy': 0.8577178858942948, 'eval_runtime': 149.0229, 'eval_samples_per_second': 115.029, 'eval_steps_per_second': 0.45, 'epoch': 4.0}
{'loss': 0.379, 'grad_norm': 1.0309009552001953, 'learning_rate': 4.502551020408164e-05, 'epoch': 4.03}
{'loss': 0.3623, 'grad_norm': 0.9340695738792419, 'learning_rate': 4.4897959183673474e-05, 'epoch': 4.07}
{'loss': 0.341, 'grad_norm': 0.8402586579322815, 'learning_rate': 4.477040816326531e-05, 'epoch': 4.12}
{'loss': 0.3363, 'grad_norm': 1.3143240213394165, 'learning_rate': 4.464285714285715e-05, 'epoch': 4.16}
{'loss': 0.3448, 'grad_norm': 1.1387217044830322, 'learning_rate': 4.451530612244898e-05, 'epoch': 4.21}
{'loss': 0.3639, 'grad_norm': 1.3493255376815796, 'learning_rate': 4.438775510204082e-05, 'epoch': 4.25}
{'loss': 0.3401, 'grad_norm': 0.9757453799247742, 'learning_rate': 4.4260204081632656e-05, 'epoch': 4.3}
{'loss': 0.3457, 'grad_norm': 0.9458795189857483, 'learning_rate': 4.4132653061224493e-0

{'eval_loss': 0.3170113265514374, 'eval_accuracy': 0.8601680084004201, 'eval_runtime': 149.7218, 'eval_samples_per_second': 114.492, 'eval_steps_per_second': 0.447, 'epoch': 5.0}
{'loss': 0.337, 'grad_norm': 0.9424701929092407, 'learning_rate': 4.2219387755102045e-05, 'epoch': 5.02}
{'loss': 0.3417, 'grad_norm': 1.1752114295959473, 'learning_rate': 4.209183673469388e-05, 'epoch': 5.07}
{'loss': 0.3176, 'grad_norm': 1.514762282371521, 'learning_rate': 4.196428571428572e-05, 'epoch': 5.11}
{'loss': 0.346, 'grad_norm': 1.2150661945343018, 'learning_rate': 4.183673469387756e-05, 'epoch': 5.16}
{'loss': 0.3486, 'grad_norm': 1.1392451524734497, 'learning_rate': 4.170918367346939e-05, 'epoch': 5.2}
{'loss': 0.339, 'grad_norm': 0.8045358061790466, 'learning_rate': 4.1581632653061226e-05, 'epoch': 5.25}
{'loss': 0.3367, 'grad_norm': 1.031054973602295, 'learning_rate': 4.1454081632653064e-05, 'epoch': 5.29}
{'loss': 0.3439, 'grad_norm': 1.3351845741271973, 'learning_rate': 4.13265306122449e-05, 

{'eval_loss': 0.3035866916179657, 'eval_accuracy': 0.8657682884144208, 'eval_runtime': 151.083, 'eval_samples_per_second': 113.461, 'eval_steps_per_second': 0.443, 'epoch': 6.0}
{'loss': 0.3302, 'grad_norm': 0.7767389416694641, 'learning_rate': 3.9413265306122446e-05, 'epoch': 6.02}
{'loss': 0.3226, 'grad_norm': 1.0604954957962036, 'learning_rate': 3.928571428571429e-05, 'epoch': 6.06}
{'loss': 0.3138, 'grad_norm': 0.9361637234687805, 'learning_rate': 3.915816326530613e-05, 'epoch': 6.11}
{'loss': 0.2979, 'grad_norm': 0.8221874833106995, 'learning_rate': 3.9030612244897965e-05, 'epoch': 6.15}
{'loss': 0.3196, 'grad_norm': 1.008080005645752, 'learning_rate': 3.8903061224489796e-05, 'epoch': 6.2}
{'loss': 0.3207, 'grad_norm': 0.9232077598571777, 'learning_rate': 3.8775510204081634e-05, 'epoch': 6.24}
{'loss': 0.326, 'grad_norm': 1.0624943971633911, 'learning_rate': 3.864795918367347e-05, 'epoch': 6.29}
{'loss': 0.3243, 'grad_norm': 1.1224843263626099, 'learning_rate': 3.852040816326531e-

{'eval_loss': 0.3006895184516907, 'eval_accuracy': 0.8706685334266714, 'eval_runtime': 149.6139, 'eval_samples_per_second': 114.575, 'eval_steps_per_second': 0.448, 'epoch': 7.0}
{'loss': 0.2978, 'grad_norm': 1.117543339729309, 'learning_rate': 3.6607142857142853e-05, 'epoch': 7.01}
{'loss': 0.3284, 'grad_norm': 1.08034348487854, 'learning_rate': 3.64795918367347e-05, 'epoch': 7.06}
{'loss': 0.3144, 'grad_norm': 1.2560765743255615, 'learning_rate': 3.6352040816326536e-05, 'epoch': 7.1}
{'loss': 0.3177, 'grad_norm': 0.7268418073654175, 'learning_rate': 3.622448979591837e-05, 'epoch': 7.15}
{'loss': 0.3142, 'grad_norm': 1.5668540000915527, 'learning_rate': 3.609693877551021e-05, 'epoch': 7.19}
{'loss': 0.3114, 'grad_norm': 0.8042719960212708, 'learning_rate': 3.596938775510204e-05, 'epoch': 7.24}
{'loss': 0.3165, 'grad_norm': 0.8799934983253479, 'learning_rate': 3.584183673469388e-05, 'epoch': 7.29}
{'loss': 0.3198, 'grad_norm': 0.9216762781143188, 'learning_rate': 3.571428571428572e-05,

{'eval_loss': 0.2804384231567383, 'eval_accuracy': 0.8787189359467973, 'eval_runtime': 149.4534, 'eval_samples_per_second': 114.698, 'eval_steps_per_second': 0.448, 'epoch': 8.0}
{'loss': 0.3227, 'grad_norm': 0.9453676342964172, 'learning_rate': 3.380102040816326e-05, 'epoch': 8.01}
{'loss': 0.299, 'grad_norm': 0.8656010031700134, 'learning_rate': 3.36734693877551e-05, 'epoch': 8.05}
{'loss': 0.3205, 'grad_norm': 1.021600604057312, 'learning_rate': 3.354591836734694e-05, 'epoch': 8.1}
{'loss': 0.2928, 'grad_norm': 0.9979525208473206, 'learning_rate': 3.341836734693878e-05, 'epoch': 8.14}
{'loss': 0.3078, 'grad_norm': 1.4387305974960327, 'learning_rate': 3.329081632653062e-05, 'epoch': 8.19}
{'loss': 0.3053, 'grad_norm': 0.9978251457214355, 'learning_rate': 3.316326530612245e-05, 'epoch': 8.24}
{'loss': 0.2996, 'grad_norm': 1.1429896354675293, 'learning_rate': 3.303571428571429e-05, 'epoch': 8.28}
{'loss': 0.3053, 'grad_norm': 0.9999520778656006, 'learning_rate': 3.2908163265306125e-05,

{'eval_loss': 0.2776373624801636, 'eval_accuracy': 0.8807023684517559, 'eval_runtime': 148.9537, 'eval_samples_per_second': 115.083, 'eval_steps_per_second': 0.45, 'epoch': 9.0}
{'loss': 0.2928, 'grad_norm': 0.868794322013855, 'learning_rate': 3.0994897959183676e-05, 'epoch': 9.0}
{'loss': 0.3029, 'grad_norm': 0.9612505435943604, 'learning_rate': 3.086734693877551e-05, 'epoch': 9.05}
{'loss': 0.3028, 'grad_norm': 0.7910257577896118, 'learning_rate': 3.073979591836735e-05, 'epoch': 9.1}
{'loss': 0.3188, 'grad_norm': 0.8692062497138977, 'learning_rate': 3.061224489795919e-05, 'epoch': 9.14}
{'loss': 0.2888, 'grad_norm': 1.0114737749099731, 'learning_rate': 3.0484693877551023e-05, 'epoch': 9.19}
{'loss': 0.3022, 'grad_norm': 0.9708380699157715, 'learning_rate': 3.0357142857142857e-05, 'epoch': 9.23}
{'loss': 0.292, 'grad_norm': 1.088259220123291, 'learning_rate': 3.0229591836734695e-05, 'epoch': 9.28}
{'loss': 0.311, 'grad_norm': 0.9922879338264465, 'learning_rate': 3.0102040816326533e-05

{'eval_loss': 0.26560232043266296, 'eval_accuracy': 0.88723602846809, 'eval_runtime': 151.1049, 'eval_samples_per_second': 113.444, 'eval_steps_per_second': 0.443, 'epoch': 10.0}
{'loss': 0.2905, 'grad_norm': 1.1739002466201782, 'learning_rate': 2.8061224489795918e-05, 'epoch': 10.05}
{'loss': 0.2983, 'grad_norm': 1.164442539215088, 'learning_rate': 2.7933673469387756e-05, 'epoch': 10.09}
{'loss': 0.2918, 'grad_norm': 1.5632154941558838, 'learning_rate': 2.7806122448979593e-05, 'epoch': 10.14}
{'loss': 0.2922, 'grad_norm': 1.3233723640441895, 'learning_rate': 2.767857142857143e-05, 'epoch': 10.18}
{'loss': 0.2973, 'grad_norm': 1.1120280027389526, 'learning_rate': 2.7551020408163265e-05, 'epoch': 10.23}
{'loss': 0.2693, 'grad_norm': 1.7096308469772339, 'learning_rate': 2.7423469387755103e-05, 'epoch': 10.27}
{'loss': 0.2916, 'grad_norm': 0.7552413940429688, 'learning_rate': 2.729591836734694e-05, 'epoch': 10.32}
{'loss': 0.2915, 'grad_norm': 0.8566994667053223, 'learning_rate': 2.716836

{'eval_loss': 0.26583555340766907, 'eval_accuracy': 0.8857192859642982, 'eval_runtime': 149.5195, 'eval_samples_per_second': 114.647, 'eval_steps_per_second': 0.448, 'epoch': 11.0}
{'loss': 0.2975, 'grad_norm': 0.9776728749275208, 'learning_rate': 2.5255102040816326e-05, 'epoch': 11.04}
{'loss': 0.2945, 'grad_norm': 0.8840038776397705, 'learning_rate': 2.5127551020408164e-05, 'epoch': 11.09}
{'loss': 0.2903, 'grad_norm': 1.086629867553711, 'learning_rate': 2.5e-05, 'epoch': 11.13}
{'loss': 0.2913, 'grad_norm': 0.8852521181106567, 'learning_rate': 2.487244897959184e-05, 'epoch': 11.18}
{'loss': 0.2802, 'grad_norm': 0.9304704070091248, 'learning_rate': 2.4744897959183673e-05, 'epoch': 11.22}
{'loss': 0.2924, 'grad_norm': 0.9730851650238037, 'learning_rate': 2.461734693877551e-05, 'epoch': 11.27}
{'loss': 0.2792, 'grad_norm': 0.8322763442993164, 'learning_rate': 2.448979591836735e-05, 'epoch': 11.31}
{'loss': 0.2852, 'grad_norm': 1.1767518520355225, 'learning_rate': 2.4362244897959186e-05

{'eval_loss': 0.26594629883766174, 'eval_accuracy': 0.8867110022167775, 'eval_runtime': 149.4395, 'eval_samples_per_second': 114.709, 'eval_steps_per_second': 0.448, 'epoch': 12.0}
{'loss': 0.2954, 'grad_norm': 1.0461323261260986, 'learning_rate': 2.2448979591836737e-05, 'epoch': 12.04}
{'loss': 0.3002, 'grad_norm': 0.8327432870864868, 'learning_rate': 2.2321428571428575e-05, 'epoch': 12.08}
{'loss': 0.2972, 'grad_norm': 1.2142733335494995, 'learning_rate': 2.219387755102041e-05, 'epoch': 12.13}
{'loss': 0.2985, 'grad_norm': 0.9825642704963684, 'learning_rate': 2.2066326530612247e-05, 'epoch': 12.17}
{'loss': 0.2731, 'grad_norm': 0.7729843258857727, 'learning_rate': 2.193877551020408e-05, 'epoch': 12.22}
{'loss': 0.2867, 'grad_norm': 1.0139495134353638, 'learning_rate': 2.181122448979592e-05, 'epoch': 12.26}
{'loss': 0.2953, 'grad_norm': 1.1485016345977783, 'learning_rate': 2.1683673469387756e-05, 'epoch': 12.31}
{'loss': 0.2911, 'grad_norm': 0.9452078342437744, 'learning_rate': 2.1556

{'eval_loss': 0.26677829027175903, 'eval_accuracy': 0.8859526309648815, 'eval_runtime': 149.3501, 'eval_samples_per_second': 114.777, 'eval_steps_per_second': 0.449, 'epoch': 13.0}
{'loss': 0.2778, 'grad_norm': 1.0942413806915283, 'learning_rate': 1.9642857142857145e-05, 'epoch': 13.03}
{'loss': 0.2804, 'grad_norm': 0.8793929815292358, 'learning_rate': 1.9515306122448983e-05, 'epoch': 13.08}
{'loss': 0.276, 'grad_norm': 0.9027805924415588, 'learning_rate': 1.9387755102040817e-05, 'epoch': 13.12}
{'loss': 0.2975, 'grad_norm': 0.9403008222579956, 'learning_rate': 1.9260204081632655e-05, 'epoch': 13.17}
{'loss': 0.28, 'grad_norm': 1.0986301898956299, 'learning_rate': 1.913265306122449e-05, 'epoch': 13.21}
{'loss': 0.2802, 'grad_norm': 0.911440372467041, 'learning_rate': 1.9005102040816326e-05, 'epoch': 13.26}
{'loss': 0.2958, 'grad_norm': 0.9988669157028198, 'learning_rate': 1.8877551020408164e-05, 'epoch': 13.3}
{'loss': 0.2827, 'grad_norm': 1.0921072959899902, 'learning_rate': 1.8750000

{'eval_loss': 0.260738730430603, 'eval_accuracy': 0.8914362384785905, 'eval_runtime': 149.5723, 'eval_samples_per_second': 114.607, 'eval_steps_per_second': 0.448, 'epoch': 14.0}
{'loss': 0.283, 'grad_norm': 1.175284504890442, 'learning_rate': 1.683673469387755e-05, 'epoch': 14.03}
{'loss': 0.2906, 'grad_norm': 0.8211909532546997, 'learning_rate': 1.670918367346939e-05, 'epoch': 14.07}
{'loss': 0.2632, 'grad_norm': 0.7064717411994934, 'learning_rate': 1.6581632653061225e-05, 'epoch': 14.12}
{'loss': 0.2721, 'grad_norm': 1.3405791521072388, 'learning_rate': 1.6454081632653062e-05, 'epoch': 14.16}
{'loss': 0.2769, 'grad_norm': 0.9880358576774597, 'learning_rate': 1.6326530612244897e-05, 'epoch': 14.21}
{'loss': 0.2964, 'grad_norm': 1.0329054594039917, 'learning_rate': 1.6198979591836734e-05, 'epoch': 14.25}
{'loss': 0.2931, 'grad_norm': 1.0176931619644165, 'learning_rate': 1.6071428571428572e-05, 'epoch': 14.3}
{'loss': 0.2901, 'grad_norm': 1.2898308038711548, 'learning_rate': 1.59438775

{'eval_loss': 0.2619209289550781, 'eval_accuracy': 0.8861859759654649, 'eval_runtime': 152.028, 'eval_samples_per_second': 112.756, 'eval_steps_per_second': 0.441, 'epoch': 15.0}
{'loss': 0.275, 'grad_norm': 0.88801509141922, 'learning_rate': 1.4030612244897959e-05, 'epoch': 15.02}
{'loss': 0.2891, 'grad_norm': 1.2506626844406128, 'learning_rate': 1.3903061224489797e-05, 'epoch': 15.07}
{'loss': 0.2702, 'grad_norm': 0.9626736044883728, 'learning_rate': 1.3775510204081633e-05, 'epoch': 15.11}
{'loss': 0.2971, 'grad_norm': 1.1206727027893066, 'learning_rate': 1.364795918367347e-05, 'epoch': 15.16}
{'loss': 0.2914, 'grad_norm': 1.7500827312469482, 'learning_rate': 1.3520408163265308e-05, 'epoch': 15.2}
{'loss': 0.2836, 'grad_norm': 0.9941381216049194, 'learning_rate': 1.3392857142857144e-05, 'epoch': 15.25}
{'loss': 0.271, 'grad_norm': 1.041023850440979, 'learning_rate': 1.3265306122448982e-05, 'epoch': 15.29}
{'loss': 0.2719, 'grad_norm': 1.3764194250106812, 'learning_rate': 1.3137755102

{'eval_loss': 0.25324130058288574, 'eval_accuracy': 0.891261229728153, 'eval_runtime': 153.3863, 'eval_samples_per_second': 111.757, 'eval_steps_per_second': 0.437, 'epoch': 16.0}
{'loss': 0.2637, 'grad_norm': 1.03707754611969, 'learning_rate': 1.1224489795918369e-05, 'epoch': 16.02}
{'loss': 0.2669, 'grad_norm': 0.8107073307037354, 'learning_rate': 1.1096938775510205e-05, 'epoch': 16.06}
{'loss': 0.2743, 'grad_norm': 0.877745509147644, 'learning_rate': 1.096938775510204e-05, 'epoch': 16.11}
{'loss': 0.2788, 'grad_norm': 1.2814749479293823, 'learning_rate': 1.0841836734693878e-05, 'epoch': 16.15}
{'loss': 0.2842, 'grad_norm': 1.2220592498779297, 'learning_rate': 1.0714285714285714e-05, 'epoch': 16.2}
{'loss': 0.2836, 'grad_norm': 0.9780862927436829, 'learning_rate': 1.0586734693877552e-05, 'epoch': 16.24}
{'loss': 0.2953, 'grad_norm': 1.0277347564697266, 'learning_rate': 1.045918367346939e-05, 'epoch': 16.29}
{'loss': 0.2921, 'grad_norm': 0.9057775735855103, 'learning_rate': 1.03316326

{'eval_loss': 0.25959229469299316, 'eval_accuracy': 0.887994399719986, 'eval_runtime': 152.8109, 'eval_samples_per_second': 112.178, 'eval_steps_per_second': 0.438, 'epoch': 17.0}
{'loss': 0.275, 'grad_norm': 1.00331711769104, 'learning_rate': 8.418367346938775e-06, 'epoch': 17.01}
{'loss': 0.2782, 'grad_norm': 0.9383623600006104, 'learning_rate': 8.290816326530612e-06, 'epoch': 17.06}
{'loss': 0.275, 'grad_norm': 1.1063990592956543, 'learning_rate': 8.163265306122448e-06, 'epoch': 17.1}
{'loss': 0.2793, 'grad_norm': 0.8430660963058472, 'learning_rate': 8.035714285714286e-06, 'epoch': 17.15}
{'loss': 0.2764, 'grad_norm': 0.8457693457603455, 'learning_rate': 7.908163265306124e-06, 'epoch': 17.19}
{'loss': 0.2898, 'grad_norm': 1.022377371788025, 'learning_rate': 7.78061224489796e-06, 'epoch': 17.24}
{'loss': 0.2545, 'grad_norm': 1.096496343612671, 'learning_rate': 7.653061224489797e-06, 'epoch': 17.29}
{'loss': 0.291, 'grad_norm': 1.1034842729568481, 'learning_rate': 7.525510204081633e-0

{'eval_loss': 0.2536170184612274, 'eval_accuracy': 0.8919029284797573, 'eval_runtime': 152.5274, 'eval_samples_per_second': 112.386, 'eval_steps_per_second': 0.439, 'epoch': 18.0}
{'loss': 0.2626, 'grad_norm': 1.0358753204345703, 'learning_rate': 5.612244897959184e-06, 'epoch': 18.01}
{'loss': 0.2939, 'grad_norm': 0.8462299108505249, 'learning_rate': 5.48469387755102e-06, 'epoch': 18.05}
{'loss': 0.2691, 'grad_norm': 0.8137834668159485, 'learning_rate': 5.357142857142857e-06, 'epoch': 18.1}
{'loss': 0.269, 'grad_norm': 0.9100592136383057, 'learning_rate': 5.229591836734695e-06, 'epoch': 18.14}
{'loss': 0.2873, 'grad_norm': 1.3644741773605347, 'learning_rate': 5.102040816326531e-06, 'epoch': 18.19}
{'loss': 0.2681, 'grad_norm': 1.3821933269500732, 'learning_rate': 4.9744897959183674e-06, 'epoch': 18.24}
{'loss': 0.268, 'grad_norm': 1.9112229347229004, 'learning_rate': 4.846938775510204e-06, 'epoch': 18.28}
{'loss': 0.2714, 'grad_norm': 0.9123752117156982, 'learning_rate': 4.719387755102

{'eval_loss': 0.24508707225322723, 'eval_accuracy': 0.8957531209893828, 'eval_runtime': 152.5009, 'eval_samples_per_second': 112.406, 'eval_steps_per_second': 0.439, 'epoch': 19.0}
{'loss': 0.2616, 'grad_norm': 1.0053839683532715, 'learning_rate': 2.806122448979592e-06, 'epoch': 19.0}
{'loss': 0.2901, 'grad_norm': 0.9918547868728638, 'learning_rate': 2.6785714285714285e-06, 'epoch': 19.05}
{'loss': 0.2841, 'grad_norm': 0.9265093207359314, 'learning_rate': 2.5510204081632653e-06, 'epoch': 19.1}
{'loss': 0.281, 'grad_norm': 0.8319127559661865, 'learning_rate': 2.423469387755102e-06, 'epoch': 19.14}
{'loss': 0.2923, 'grad_norm': 1.1985414028167725, 'learning_rate': 2.295918367346939e-06, 'epoch': 19.19}
{'loss': 0.2723, 'grad_norm': 0.8562483787536621, 'learning_rate': 2.1683673469387757e-06, 'epoch': 19.23}
{'loss': 0.2864, 'grad_norm': 1.2020838260650635, 'learning_rate': 2.040816326530612e-06, 'epoch': 19.28}
{'loss': 0.269, 'grad_norm': 0.8529525995254517, 'learning_rate': 1.913265306

{'eval_loss': 0.24707971513271332, 'eval_accuracy': 0.8940030334850075, 'eval_runtime': 152.7639, 'eval_samples_per_second': 112.212, 'eval_steps_per_second': 0.439, 'epoch': 20.0}
{'train_runtime': 13225.0559, 'train_samples_per_second': 85.397, 'train_steps_per_second': 0.334, 'train_loss': 0.3273738372379838, 'epoch': 20.0}
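Validation accuracy for this run peaks at epoch 19 (`0.8957531209893828`, also the lowest `eval_loss`) and dips slightly at epoch 20 (`0.8940030334850075`), so the final checkpoint is not the best one. A small helper can pull the evaluation records out of such a log and pick the best epoch. The `eval_records` function and the embedded excerpt (three evaluation lines and one training line copied from the log above) are illustrative and not part of the notebook's modules; the helper skips progress bars and the truncated lines.

```python
import ast

# Excerpt copied verbatim from the training log above.
LOG = """\
{'eval_loss': 0.2536170184612274, 'eval_accuracy': 0.8919029284797573, 'eval_runtime': 152.5274, 'eval_samples_per_second': 112.386, 'eval_steps_per_second': 0.439, 'epoch': 18.0}
{'loss': 0.2626, 'grad_norm': 1.0358753204345703, 'learning_rate': 5.612244897959184e-06, 'epoch': 18.01}
{'eval_loss': 0.24508707225322723, 'eval_accuracy': 0.8957531209893828, 'eval_runtime': 152.5009, 'eval_samples_per_second': 112.406, 'eval_steps_per_second': 0.439, 'epoch': 19.0}
  0%|          | 0/67 [00:00<?, ?it/s]
{'loss': 0.269, 'grad_norm': 0.8529525995254517, 'learning_rate': 1.913265306
{'eval_loss': 0.24707971513271332, 'eval_accuracy': 0.8940030334850075, 'eval_runtime': 152.7639, 'eval_samples_per_second': 112.212, 'eval_steps_per_second': 0.439, 'epoch': 20.0}
"""


def eval_records(log_text: str) -> list[dict]:
    """Parse every complete dict line and keep only the evaluation records."""
    records = []
    for line in log_text.splitlines():
        line = line.strip()
        if not (line.startswith("{") and line.endswith("}")):
            continue  # progress bars and truncated lines are not complete dicts
        record = ast.literal_eval(line)
        if "eval_accuracy" in record:
            records.append(record)
    return records


records = eval_records(LOG)
best = max(records, key=lambda r: r["eval_accuracy"])
print(best["epoch"], best["eval_accuracy"])  # 19.0 0.8957531209893828
```

Using `ast.literal_eval` rather than `eval` keeps the parsing restricted to Python literals, which is all these log lines contain.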
