In [1]:
from llm_manager import (
    SequenceClassifierOutputOverRanges,
    EmbeddingBasedClassifier,
    ClassifierBase,
    ID2LABEL,
    LABEL2ID,
    MODEL_NAME,
)
from dataset_manager import ROOT, MovieLensManager, INPUT_EMBEDS_REPLACE_KGE_DIMENSION
from graph_representation_generator import GraphRepresentationGenerator
from typing import Optional, Union, Tuple
import os

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import BertForSequenceClassification

# Customize Input Embeds Replace Classifier
In our first experiments we noticed that replacing the placeholders of the input embeds with the KGEs breakes up the gradient propagation all the way back. To make it even, we will freeze this entire step so that only the attention headers are trained. 

In [2]:
class GraphPrompterHFFrozenBertForSequenceClassification(
    BertForSequenceClassification
):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        graph_embeddings: Optional[torch.Tensor] = None,
        token_type_ranges: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputOverRanges]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if inputs_embeds is None:
            # CHANGES: we freeze the gradients when producing input embeddings and detach them before passing them on just to be save
            with torch.no_grad():
                inputs_embeds = self.bert.embeddings(input_ids).detach()
            assert isinstance(inputs_embeds, torch.Tensor)
        if graph_embeddings is not None and len(graph_embeddings) > 0:
            if attention_mask is not None:
                mask = (
                    (
                        (attention_mask.to(self.device).sum(dim=1) - 1)
                        .unsqueeze(1)
                        .repeat((1, 2))
                        - torch.tensor([3, 1], device=self.device)
                    )
                    .unsqueeze(2)
                    .repeat((1, 1, self.config.hidden_size))
                )  # basically a mask finding the last positions between the sep tokens (reshaped so they can be used in scatter)
                inputs_embeds = inputs_embeds.to(
                    self.device
                ).scatter(
                    1, mask.to(self.device), graph_embeddings.to(self.device)
                )  # replace the input embeds at the place holder positions with the KGEs.
        outputs = self.bert(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # feed forward the input embeds to the attention model

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputOverRanges(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            token_type_ranges=token_type_ranges,
        )

Next we define the new Classifier the same way the original was defined, with the exception of the model we are loading, which is the new frozen model.

In [3]:
class GraphPrompterHFFrozenClassifier(EmbeddingBasedClassifier):
    def __init__(
        self,
        kge_manager,
        get_embedding_cb,
        root_path,
        model_name=MODEL_NAME,
        model_max_length=256,
        false_ratio=1.0,
        force_recompute=False,
    ) -> None:
        training_path = f"{root_path}/training"
        model_path = f"{training_path}/best"

        if os.path.exists(model_path) and not force_recompute:
            model = (
                GraphPrompterHFFrozenBertForSequenceClassification.from_pretrained(
                    model_path,
                    num_labels=2,
                    id2label=ID2LABEL,
                    label2id=LABEL2ID,
                )
            )
        else:
            model = (
                GraphPrompterHFFrozenBertForSequenceClassification.from_pretrained(
                    model_name, num_labels=2, id2label=ID2LABEL, label2id=LABEL2ID
                )
            )
        assert isinstance(model, BertForSequenceClassification)
        super().__init__(
            kge_manager,
            get_embedding_cb,
            root_path,
            model,
            model_name,
            model_max_length,
            false_ratio,
            force_recompute,
        )

    def plot_training_loss_and_accuracy(self):
        model_type = "Input Embeds Replace"
        self._plot_training_loss_and_accuracy(model_type)

Now comes the whole training procedure (see training_models.ipynb)

In [4]:
kg_manager = MovieLensManager()
EPOCHS = 20
BATCH_SIZE = 256
graph_representation_generator_graph_prompter_hf = GraphRepresentationGenerator(
    kg_manager.data,
    kg_manager.gnn_train_data,
    kg_manager.gnn_val_data,
    kg_manager.gnn_test_data,
    hidden_channels=INPUT_EMBEDS_REPLACE_KGE_DIMENSION,
    kge_dimension=INPUT_EMBEDS_REPLACE_KGE_DIMENSION,
)
graph_prompter_hf_embeddings = (
    graph_representation_generator_graph_prompter_hf.get_saved_embeddings(
        "graph_prompter_hf"
    )
)
save = False
if graph_prompter_hf_embeddings is None:
    graph_prompter_hf_embeddings = (
        graph_representation_generator_graph_prompter_hf.generate_embeddings(
            kg_manager.llm_df
        )
    )
    save = True
kg_manager.append_graph_prompter_hf_graph_embeddings(
    graph_prompter_hf_embeddings, save=save
)

loading pretrained model
Device: 'cuda'


In [5]:
INPUT_EMBEDS_REPLACE_FROZEN_ROOT = (
    f"{ROOT}/llm/graph_prompter_hf_frozen"  # make sure the dir exists
)
if not os.path.exists(INPUT_EMBEDS_REPLACE_FROZEN_ROOT):
    os.makedirs(INPUT_EMBEDS_REPLACE_FROZEN_ROOT)
graph_prompter_hf_frozen_bert_classifier = GraphPrompterHFFrozenClassifier(
    kg_manager,
    graph_representation_generator_graph_prompter_hf.get_embedding,
    root_path=INPUT_EMBEDS_REPLACE_FROZEN_ROOT,
)
dataset_graph_prompter_hf = (
    kg_manager.generate_graph_prompter_hf_embedding_dataset(
        graph_prompter_hf_frozen_bert_classifier.tokenizer.sep_token,
        graph_prompter_hf_frozen_bert_classifier.tokenizer.pad_token,
        graph_prompter_hf_frozen_bert_classifier.tokenize_function,
        force_recompute=False,
    )
)
dataset_graph_prompter_hf

Some weights of GraphPrompterHFFrozenBertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DatasetDict({
    train: Dataset({
        features: ['source_id', 'target_id', 'id_x', 'id_y', 'prompt_feature_title', 'prompt_feature_genres', 'labels', 'split', 'prompt', 'gnn_feature_(no genres listed)', 'gnn_feature_Action', 'gnn_feature_Adventure', 'gnn_feature_Animation', 'gnn_feature_Children', 'gnn_feature_Comedy', 'gnn_feature_Crime', 'gnn_feature_Documentary', 'gnn_feature_Drama', 'gnn_feature_Fantasy', 'gnn_feature_Film-Noir', 'gnn_feature_Horror', 'gnn_feature_IMAX', 'gnn_feature_Musical', 'gnn_feature_Mystery', 'gnn_feature_Romance', 'gnn_feature_Sci-Fi', 'gnn_feature_Thriller', 'gnn_feature_War', 'gnn_feature_Western', 'graph_prompter_hf_source_embedding', 'graph_prompter_hf_target_embedding', 'graph_embeddings', 'input_ids', 'attention_mask', 'token_type_ranges'],
        num_rows: 56469
    })
    val: Dataset({
        features: ['source_id', 'target_id', 'id_x', 'id_y', 'prompt_feature_title', 'prompt_feature_genres', 'labels', 'split', 'prompt', 'gnn_feature_(no gen

In [6]:
graph_prompter_hf_frozen_bert_classifier.train_model_on_data(
    dataset_graph_prompter_hf, epochs=EPOCHS, batch_size=BATCH_SIZE
)
# now we don't have to generate new non-existing edges, because they were generated in the beginning anyways.
graph_prompter_hf_frozen_bert_classifier = GraphPrompterHFFrozenClassifier(
    kg_manager,
    graph_representation_generator_graph_prompter_hf.get_embedding,
    root_path=INPUT_EMBEDS_REPLACE_FROZEN_ROOT,
    false_ratio=-1.0,  # init with false_ratio of -1 so no new false edges are produced on the fly
)
graph_prompter_hf_frozen_df = (
    graph_prompter_hf_frozen_bert_classifier.forward_dataset_and_save_outputs(
        dataset_graph_prompter_hf,
        kg_manager.get_vanilla_tokens_as_df,
        epochs=1,
        batch_size=BATCH_SIZE,
        force_recompute=True,
    )
)

  0%|          | 0/4420 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 0.7055, 'grad_norm': 0.6920956969261169, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.05}
{'loss': 0.7042, 'grad_norm': 1.1834245920181274, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.09}
{'loss': 0.7011, 'grad_norm': 0.6951578855514526, 'learning_rate': 3e-06, 'epoch': 0.14}
{'loss': 0.7024, 'grad_norm': 0.6460822820663452, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.18}
{'loss': 0.7003, 'grad_norm': 0.737121045589447, 'learning_rate': 5e-06, 'epoch': 0.23}
{'loss': 0.6973, 'grad_norm': 1.1496777534484863, 'learning_rate': 6e-06, 'epoch': 0.27}
{'loss': 0.6954, 'grad_norm': 0.5198498368263245, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.32}
{'loss': 0.695, 'grad_norm': 0.4533461630344391, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.36}
{'loss': 0.6926, 'grad_norm': 0.4349064230918884, 'learning_rate': 9e-06, 'epoch': 0.41}
{'loss': 0.6885, 'grad_norm': 0.5290090441703796, 'learning_rate': 1e-05, 'epoch': 0.45}
{'loss': 0.6837, 'grad_norm': 

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.5785201787948608, 'eval_accuracy': 0.7287947730719869, 'eval_runtime': 31.0727, 'eval_samples_per_second': 1103.347, 'eval_steps_per_second': 4.312, 'epoch': 1.0}
{'loss': 0.6016, 'grad_norm': 0.6119468212127686, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.04}
{'loss': 0.6047, 'grad_norm': 0.5500653386116028, 'learning_rate': 2.4e-05, 'epoch': 1.09}
{'loss': 0.5903, 'grad_norm': 0.7359963059425354, 'learning_rate': 2.5e-05, 'epoch': 1.13}
{'loss': 0.5866, 'grad_norm': 0.9616073369979858, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.18}
{'loss': 0.5692, 'grad_norm': 0.8594094514846802, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.22}
{'loss': 0.5854, 'grad_norm': 0.8118565082550049, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.27}
{'loss': 0.5623, 'grad_norm': 0.7552378177642822, 'learning_rate': 2.9e-05, 'epoch': 1.31}
{'loss': 0.567, 'grad_norm': 0.5573002696037292, 'learning_rate': 3e-05, 'epoch': 1.36}
{'loss': 0.5685, 'grad_norm': 0.5983

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.44817566871643066, 'eval_accuracy': 0.7996441488741104, 'eval_runtime': 30.6204, 'eval_samples_per_second': 1119.646, 'eval_steps_per_second': 4.376, 'epoch': 2.0}
{'loss': 0.4969, 'grad_norm': 1.1562436819076538, 'learning_rate': 4.5e-05, 'epoch': 2.04}
{'loss': 0.4671, 'grad_norm': 0.8428548574447632, 'learning_rate': 4.600000000000001e-05, 'epoch': 2.08}
{'loss': 0.4675, 'grad_norm': 0.7837862968444824, 'learning_rate': 4.7e-05, 'epoch': 2.13}
{'loss': 0.4484, 'grad_norm': 2.057143449783325, 'learning_rate': 4.8e-05, 'epoch': 2.17}
{'loss': 0.447, 'grad_norm': 0.8299083113670349, 'learning_rate': 4.9e-05, 'epoch': 2.22}
{'loss': 0.4455, 'grad_norm': 1.5028512477874756, 'learning_rate': 5e-05, 'epoch': 2.26}
{'loss': 0.4182, 'grad_norm': 1.727279543876648, 'learning_rate': 4.987244897959184e-05, 'epoch': 2.31}
{'loss': 0.4397, 'grad_norm': 0.6622062921524048, 'learning_rate': 4.974489795918368e-05, 'epoch': 2.35}
{'loss': 0.4254, 'grad_norm': 0.679668128490448, 'learn

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.344199538230896, 'eval_accuracy': 0.8500466690001167, 'eval_runtime': 31.0972, 'eval_samples_per_second': 1102.479, 'eval_steps_per_second': 4.309, 'epoch': 3.0}
{'loss': 0.3803, 'grad_norm': 0.8910186886787415, 'learning_rate': 4.783163265306123e-05, 'epoch': 3.03}
{'loss': 0.3743, 'grad_norm': 0.8793795704841614, 'learning_rate': 4.7704081632653066e-05, 'epoch': 3.08}
{'loss': 0.3847, 'grad_norm': 0.8121848702430725, 'learning_rate': 4.7576530612244904e-05, 'epoch': 3.12}
{'loss': 0.3776, 'grad_norm': 0.7177539467811584, 'learning_rate': 4.744897959183674e-05, 'epoch': 3.17}
{'loss': 0.369, 'grad_norm': 1.1468653678894043, 'learning_rate': 4.732142857142857e-05, 'epoch': 3.21}
{'loss': 0.3769, 'grad_norm': 0.8686121106147766, 'learning_rate': 4.719387755102041e-05, 'epoch': 3.26}
{'loss': 0.3783, 'grad_norm': 0.9090591669082642, 'learning_rate': 4.706632653061225e-05, 'epoch': 3.3}
{'loss': 0.3545, 'grad_norm': 0.861648678779602, 'learning_rate': 4.6938775510204086e-0

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.31359803676605225, 'eval_accuracy': 0.8663808190409521, 'eval_runtime': 31.0598, 'eval_samples_per_second': 1103.806, 'eval_steps_per_second': 4.314, 'epoch': 4.0}
{'loss': 0.3451, 'grad_norm': 1.0744616985321045, 'learning_rate': 4.502551020408164e-05, 'epoch': 4.03}
{'loss': 0.3487, 'grad_norm': 1.2852296829223633, 'learning_rate': 4.4897959183673474e-05, 'epoch': 4.07}
{'loss': 0.382, 'grad_norm': 1.281740665435791, 'learning_rate': 4.477040816326531e-05, 'epoch': 4.12}
{'loss': 0.3606, 'grad_norm': 0.8177499175071716, 'learning_rate': 4.464285714285715e-05, 'epoch': 4.16}
{'loss': 0.3532, 'grad_norm': 0.773481547832489, 'learning_rate': 4.451530612244898e-05, 'epoch': 4.21}
{'loss': 0.3415, 'grad_norm': 0.6825517416000366, 'learning_rate': 4.438775510204082e-05, 'epoch': 4.25}
{'loss': 0.3449, 'grad_norm': 0.9061127305030823, 'learning_rate': 4.4260204081632656e-05, 'epoch': 4.3}
{'loss': 0.3442, 'grad_norm': 0.6647849082946777, 'learning_rate': 4.4132653061224493e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.29713767766952515, 'eval_accuracy': 0.8734978415587447, 'eval_runtime': 30.545, 'eval_samples_per_second': 1122.408, 'eval_steps_per_second': 4.387, 'epoch': 5.0}
{'loss': 0.3475, 'grad_norm': 0.8401666283607483, 'learning_rate': 4.2219387755102045e-05, 'epoch': 5.02}
{'loss': 0.3402, 'grad_norm': 0.6340931057929993, 'learning_rate': 4.209183673469388e-05, 'epoch': 5.07}
{'loss': 0.3463, 'grad_norm': 0.9996523857116699, 'learning_rate': 4.196428571428572e-05, 'epoch': 5.11}
{'loss': 0.3364, 'grad_norm': 0.9060143828392029, 'learning_rate': 4.183673469387756e-05, 'epoch': 5.16}
{'loss': 0.3234, 'grad_norm': 0.677783191204071, 'learning_rate': 4.170918367346939e-05, 'epoch': 5.2}
{'loss': 0.3215, 'grad_norm': 0.8394181132316589, 'learning_rate': 4.1581632653061226e-05, 'epoch': 5.25}
{'loss': 0.3377, 'grad_norm': 0.8322839736938477, 'learning_rate': 4.1454081632653064e-05, 'epoch': 5.29}
{'loss': 0.3432, 'grad_norm': 0.7714694142341614, 'learning_rate': 4.13265306122449e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.2820236384868622, 'eval_accuracy': 0.8826566328316415, 'eval_runtime': 30.4241, 'eval_samples_per_second': 1126.869, 'eval_steps_per_second': 4.404, 'epoch': 6.0}
{'loss': 0.3188, 'grad_norm': 1.1288292407989502, 'learning_rate': 3.9413265306122446e-05, 'epoch': 6.02}
{'loss': 0.3305, 'grad_norm': 0.8152766227722168, 'learning_rate': 3.928571428571429e-05, 'epoch': 6.06}
{'loss': 0.3301, 'grad_norm': 0.844337522983551, 'learning_rate': 3.915816326530613e-05, 'epoch': 6.11}
{'loss': 0.327, 'grad_norm': 1.0029500722885132, 'learning_rate': 3.9030612244897965e-05, 'epoch': 6.15}
{'loss': 0.3234, 'grad_norm': 0.7166342735290527, 'learning_rate': 3.8903061224489796e-05, 'epoch': 6.2}
{'loss': 0.3117, 'grad_norm': 0.8200662732124329, 'learning_rate': 3.8775510204081634e-05, 'epoch': 6.24}
{'loss': 0.3262, 'grad_norm': 0.9191544055938721, 'learning_rate': 3.864795918367347e-05, 'epoch': 6.29}
{'loss': 0.3219, 'grad_norm': 1.213193416595459, 'learning_rate': 3.852040816326531e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.26922616362571716, 'eval_accuracy': 0.8882277447205693, 'eval_runtime': 30.4332, 'eval_samples_per_second': 1126.535, 'eval_steps_per_second': 4.403, 'epoch': 7.0}
{'loss': 0.3214, 'grad_norm': 0.8906382322311401, 'learning_rate': 3.6607142857142853e-05, 'epoch': 7.01}
{'loss': 0.3162, 'grad_norm': 0.7677427530288696, 'learning_rate': 3.64795918367347e-05, 'epoch': 7.06}
{'loss': 0.319, 'grad_norm': 0.7352374196052551, 'learning_rate': 3.6352040816326536e-05, 'epoch': 7.1}
{'loss': 0.2861, 'grad_norm': 1.0871706008911133, 'learning_rate': 3.622448979591837e-05, 'epoch': 7.15}
{'loss': 0.3105, 'grad_norm': 0.8124300241470337, 'learning_rate': 3.609693877551021e-05, 'epoch': 7.19}
{'loss': 0.3386, 'grad_norm': 0.7092180252075195, 'learning_rate': 3.596938775510204e-05, 'epoch': 7.24}
{'loss': 0.3229, 'grad_norm': 0.9064545631408691, 'learning_rate': 3.584183673469388e-05, 'epoch': 7.29}
{'loss': 0.3159, 'grad_norm': 0.8729361891746521, 'learning_rate': 3.571428571428572e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.26424461603164673, 'eval_accuracy': 0.8914945747287364, 'eval_runtime': 30.7094, 'eval_samples_per_second': 1116.399, 'eval_steps_per_second': 4.363, 'epoch': 8.0}
{'loss': 0.3196, 'grad_norm': 0.9444246292114258, 'learning_rate': 3.380102040816326e-05, 'epoch': 8.01}
{'loss': 0.2984, 'grad_norm': 1.0085216760635376, 'learning_rate': 3.36734693877551e-05, 'epoch': 8.05}
{'loss': 0.3217, 'grad_norm': 0.7169920206069946, 'learning_rate': 3.354591836734694e-05, 'epoch': 8.1}
{'loss': 0.3056, 'grad_norm': 0.8081768751144409, 'learning_rate': 3.341836734693878e-05, 'epoch': 8.14}
{'loss': 0.2998, 'grad_norm': 0.9646367430686951, 'learning_rate': 3.329081632653062e-05, 'epoch': 8.19}
{'loss': 0.3116, 'grad_norm': 1.0990928411483765, 'learning_rate': 3.316326530612245e-05, 'epoch': 8.24}
{'loss': 0.3018, 'grad_norm': 1.1029318571090698, 'learning_rate': 3.303571428571429e-05, 'epoch': 8.28}
{'loss': 0.3234, 'grad_norm': 0.7210940718650818, 'learning_rate': 3.2908163265306125e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.26646938920021057, 'eval_accuracy': 0.8905903628514759, 'eval_runtime': 31.529, 'eval_samples_per_second': 1087.379, 'eval_steps_per_second': 4.25, 'epoch': 9.0}
{'loss': 0.305, 'grad_norm': 0.7841144800186157, 'learning_rate': 3.0994897959183676e-05, 'epoch': 9.0}
{'loss': 0.2916, 'grad_norm': 1.1914547681808472, 'learning_rate': 3.086734693877551e-05, 'epoch': 9.05}
{'loss': 0.2826, 'grad_norm': 0.8577204346656799, 'learning_rate': 3.073979591836735e-05, 'epoch': 9.1}
{'loss': 0.291, 'grad_norm': 0.8156721591949463, 'learning_rate': 3.061224489795919e-05, 'epoch': 9.14}
{'loss': 0.3158, 'grad_norm': 1.7646825313568115, 'learning_rate': 3.0484693877551023e-05, 'epoch': 9.19}
{'loss': 0.3149, 'grad_norm': 0.7422825694084167, 'learning_rate': 3.0357142857142857e-05, 'epoch': 9.23}
{'loss': 0.3187, 'grad_norm': 0.8189684152603149, 'learning_rate': 3.0229591836734695e-05, 'epoch': 9.28}
{'loss': 0.2878, 'grad_norm': 0.7954614162445068, 'learning_rate': 3.0102040816326533e-

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.26763173937797546, 'eval_accuracy': 0.8889277797223194, 'eval_runtime': 30.9527, 'eval_samples_per_second': 1107.627, 'eval_steps_per_second': 4.329, 'epoch': 10.0}
{'loss': 0.3039, 'grad_norm': 0.8635313510894775, 'learning_rate': 2.8061224489795918e-05, 'epoch': 10.05}
{'loss': 0.3021, 'grad_norm': 0.7064277529716492, 'learning_rate': 2.7933673469387756e-05, 'epoch': 10.09}
{'loss': 0.2911, 'grad_norm': 0.7334499359130859, 'learning_rate': 2.7806122448979593e-05, 'epoch': 10.14}
{'loss': 0.2925, 'grad_norm': 0.7827435731887817, 'learning_rate': 2.767857142857143e-05, 'epoch': 10.18}
{'loss': 0.2797, 'grad_norm': 0.6792486906051636, 'learning_rate': 2.7551020408163265e-05, 'epoch': 10.23}
{'loss': 0.3092, 'grad_norm': 0.7230169177055359, 'learning_rate': 2.7423469387755103e-05, 'epoch': 10.27}
{'loss': 0.2895, 'grad_norm': 0.8957627415657043, 'learning_rate': 2.729591836734694e-05, 'epoch': 10.32}
{'loss': 0.2718, 'grad_norm': 0.7441991567611694, 'learning_rate': 2.716

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.2562359571456909, 'eval_accuracy': 0.8963656516159141, 'eval_runtime': 30.7099, 'eval_samples_per_second': 1116.383, 'eval_steps_per_second': 4.363, 'epoch': 11.0}
{'loss': 0.3025, 'grad_norm': 0.7654920816421509, 'learning_rate': 2.5255102040816326e-05, 'epoch': 11.04}
{'loss': 0.2835, 'grad_norm': 0.7954716682434082, 'learning_rate': 2.5127551020408164e-05, 'epoch': 11.09}
{'loss': 0.3014, 'grad_norm': 0.9197812080383301, 'learning_rate': 2.5e-05, 'epoch': 11.13}
{'loss': 0.3028, 'grad_norm': 1.0141725540161133, 'learning_rate': 2.487244897959184e-05, 'epoch': 11.18}
{'loss': 0.2969, 'grad_norm': 0.895582914352417, 'learning_rate': 2.4744897959183673e-05, 'epoch': 11.22}
{'loss': 0.2905, 'grad_norm': 0.7877312898635864, 'learning_rate': 2.461734693877551e-05, 'epoch': 11.27}
{'loss': 0.3079, 'grad_norm': 0.9836503267288208, 'learning_rate': 2.448979591836735e-05, 'epoch': 11.31}
{'loss': 0.3058, 'grad_norm': 0.7494577169418335, 'learning_rate': 2.4362244897959186e-05,

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.2553376257419586, 'eval_accuracy': 0.8956364484890911, 'eval_runtime': 30.8722, 'eval_samples_per_second': 1110.513, 'eval_steps_per_second': 4.34, 'epoch': 12.0}
{'loss': 0.2815, 'grad_norm': 0.9327031373977661, 'learning_rate': 2.2448979591836737e-05, 'epoch': 12.04}
{'loss': 0.2965, 'grad_norm': 0.9915962815284729, 'learning_rate': 2.2321428571428575e-05, 'epoch': 12.08}
{'loss': 0.2792, 'grad_norm': 0.8836171627044678, 'learning_rate': 2.219387755102041e-05, 'epoch': 12.13}
{'loss': 0.3033, 'grad_norm': 0.8192165493965149, 'learning_rate': 2.2066326530612247e-05, 'epoch': 12.17}
{'loss': 0.2849, 'grad_norm': 0.6669965982437134, 'learning_rate': 2.193877551020408e-05, 'epoch': 12.22}
{'loss': 0.3054, 'grad_norm': 0.7564768195152283, 'learning_rate': 2.181122448979592e-05, 'epoch': 12.26}
{'loss': 0.3034, 'grad_norm': 1.02680504322052, 'learning_rate': 2.1683673469387756e-05, 'epoch': 12.31}
{'loss': 0.2938, 'grad_norm': 0.8154210448265076, 'learning_rate': 2.15561224

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.2545433044433594, 'eval_accuracy': 0.8954906078637265, 'eval_runtime': 30.5774, 'eval_samples_per_second': 1121.22, 'eval_steps_per_second': 4.382, 'epoch': 13.0}
{'loss': 0.2893, 'grad_norm': 0.8150110840797424, 'learning_rate': 1.9642857142857145e-05, 'epoch': 13.03}
{'loss': 0.2867, 'grad_norm': 1.1148560047149658, 'learning_rate': 1.9515306122448983e-05, 'epoch': 13.08}
{'loss': 0.3053, 'grad_norm': 0.881665825843811, 'learning_rate': 1.9387755102040817e-05, 'epoch': 13.12}
{'loss': 0.2783, 'grad_norm': 0.9508885741233826, 'learning_rate': 1.9260204081632655e-05, 'epoch': 13.17}
{'loss': 0.2929, 'grad_norm': 0.9716324210166931, 'learning_rate': 1.913265306122449e-05, 'epoch': 13.21}
{'loss': 0.2942, 'grad_norm': 0.9872744679450989, 'learning_rate': 1.9005102040816326e-05, 'epoch': 13.26}
{'loss': 0.2981, 'grad_norm': 0.8986158967018127, 'learning_rate': 1.8877551020408164e-05, 'epoch': 13.3}
{'loss': 0.3011, 'grad_norm': 1.0113205909729004, 'learning_rate': 1.875000

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.25441882014274597, 'eval_accuracy': 0.8949072453622681, 'eval_runtime': 31.0462, 'eval_samples_per_second': 1104.288, 'eval_steps_per_second': 4.316, 'epoch': 14.0}
{'loss': 0.2956, 'grad_norm': 1.0056380033493042, 'learning_rate': 1.683673469387755e-05, 'epoch': 14.03}
{'loss': 0.2878, 'grad_norm': 0.7249377369880676, 'learning_rate': 1.670918367346939e-05, 'epoch': 14.07}
{'loss': 0.2842, 'grad_norm': 0.7160003185272217, 'learning_rate': 1.6581632653061225e-05, 'epoch': 14.12}
{'loss': 0.3078, 'grad_norm': 1.0268734693527222, 'learning_rate': 1.6454081632653062e-05, 'epoch': 14.16}
{'loss': 0.2857, 'grad_norm': 0.7211300134658813, 'learning_rate': 1.6326530612244897e-05, 'epoch': 14.21}
{'loss': 0.2891, 'grad_norm': 0.8397911787033081, 'learning_rate': 1.6198979591836734e-05, 'epoch': 14.25}
{'loss': 0.3031, 'grad_norm': 0.8738505244255066, 'learning_rate': 1.6071428571428572e-05, 'epoch': 14.3}
{'loss': 0.2681, 'grad_norm': 0.8145748972892761, 'learning_rate': 1.5943

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.2504860460758209, 'eval_accuracy': 0.8959864659899661, 'eval_runtime': 29.7193, 'eval_samples_per_second': 1153.592, 'eval_steps_per_second': 4.509, 'epoch': 15.0}
{'loss': 0.2802, 'grad_norm': 0.8941984176635742, 'learning_rate': 1.4030612244897959e-05, 'epoch': 15.02}
{'loss': 0.2929, 'grad_norm': 1.029395341873169, 'learning_rate': 1.3903061224489797e-05, 'epoch': 15.07}
{'loss': 0.2926, 'grad_norm': 0.8466260433197021, 'learning_rate': 1.3775510204081633e-05, 'epoch': 15.11}
{'loss': 0.2897, 'grad_norm': 0.9022239446640015, 'learning_rate': 1.364795918367347e-05, 'epoch': 15.16}
{'loss': 0.2801, 'grad_norm': 0.8380320072174072, 'learning_rate': 1.3520408163265308e-05, 'epoch': 15.2}
{'loss': 0.2906, 'grad_norm': 0.6658509969711304, 'learning_rate': 1.3392857142857144e-05, 'epoch': 15.25}
{'loss': 0.3018, 'grad_norm': 0.8966707587242126, 'learning_rate': 1.3265306122448982e-05, 'epoch': 15.29}
{'loss': 0.2914, 'grad_norm': 0.7762507796287537, 'learning_rate': 1.31377

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.25454995036125183, 'eval_accuracy': 0.8948489091121222, 'eval_runtime': 30.8703, 'eval_samples_per_second': 1110.582, 'eval_steps_per_second': 4.341, 'epoch': 16.0}
{'loss': 0.3161, 'grad_norm': 0.9813271164894104, 'learning_rate': 1.1224489795918369e-05, 'epoch': 16.02}
{'loss': 0.2779, 'grad_norm': 0.7870693802833557, 'learning_rate': 1.1096938775510205e-05, 'epoch': 16.06}
{'loss': 0.2768, 'grad_norm': 0.7121210694313049, 'learning_rate': 1.096938775510204e-05, 'epoch': 16.11}
{'loss': 0.2796, 'grad_norm': 0.9340022802352905, 'learning_rate': 1.0841836734693878e-05, 'epoch': 16.15}
{'loss': 0.2777, 'grad_norm': 0.938133180141449, 'learning_rate': 1.0714285714285714e-05, 'epoch': 16.2}
{'loss': 0.2828, 'grad_norm': 0.7995957136154175, 'learning_rate': 1.0586734693877552e-05, 'epoch': 16.24}
{'loss': 0.2827, 'grad_norm': 0.7310016751289368, 'learning_rate': 1.045918367346939e-05, 'epoch': 16.29}
{'loss': 0.2911, 'grad_norm': 0.8229255676269531, 'learning_rate': 1.03316

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.25311940908432007, 'eval_accuracy': 0.8955197759887994, 'eval_runtime': 30.9692, 'eval_samples_per_second': 1107.035, 'eval_steps_per_second': 4.327, 'epoch': 17.0}
{'loss': 0.281, 'grad_norm': 0.7476648092269897, 'learning_rate': 8.418367346938775e-06, 'epoch': 17.01}
{'loss': 0.2831, 'grad_norm': 0.6786288022994995, 'learning_rate': 8.290816326530612e-06, 'epoch': 17.06}
{'loss': 0.2938, 'grad_norm': 0.9614419937133789, 'learning_rate': 8.163265306122448e-06, 'epoch': 17.1}
{'loss': 0.288, 'grad_norm': 0.9583922624588013, 'learning_rate': 8.035714285714286e-06, 'epoch': 17.15}
{'loss': 0.2902, 'grad_norm': 0.7512609958648682, 'learning_rate': 7.908163265306124e-06, 'epoch': 17.19}
{'loss': 0.28, 'grad_norm': 0.9792643189430237, 'learning_rate': 7.78061224489796e-06, 'epoch': 17.24}
{'loss': 0.2943, 'grad_norm': 0.938090443611145, 'learning_rate': 7.653061224489797e-06, 'epoch': 17.29}
{'loss': 0.278, 'grad_norm': 0.7395123243331909, 'learning_rate': 7.525510204081633e

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.25360018014907837, 'eval_accuracy': 0.8948780772371951, 'eval_runtime': 30.4679, 'eval_samples_per_second': 1125.25, 'eval_steps_per_second': 4.398, 'epoch': 18.0}
{'loss': 0.2727, 'grad_norm': 0.9142851829528809, 'learning_rate': 5.612244897959184e-06, 'epoch': 18.01}
{'loss': 0.2872, 'grad_norm': 0.9257677793502808, 'learning_rate': 5.48469387755102e-06, 'epoch': 18.05}
{'loss': 0.289, 'grad_norm': 0.9115723967552185, 'learning_rate': 5.357142857142857e-06, 'epoch': 18.1}
{'loss': 0.3064, 'grad_norm': 0.8956109285354614, 'learning_rate': 5.229591836734695e-06, 'epoch': 18.14}
{'loss': 0.2813, 'grad_norm': 0.7900171279907227, 'learning_rate': 5.102040816326531e-06, 'epoch': 18.19}
{'loss': 0.2962, 'grad_norm': 0.8971437215805054, 'learning_rate': 4.9744897959183674e-06, 'epoch': 18.24}
{'loss': 0.283, 'grad_norm': 0.6934467554092407, 'learning_rate': 4.846938775510204e-06, 'epoch': 18.28}
{'loss': 0.2898, 'grad_norm': 0.9459550976753235, 'learning_rate': 4.719387755102

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.24998049437999725, 'eval_accuracy': 0.8953447672383619, 'eval_runtime': 30.5257, 'eval_samples_per_second': 1123.121, 'eval_steps_per_second': 4.39, 'epoch': 19.0}
{'loss': 0.2799, 'grad_norm': 0.9792028069496155, 'learning_rate': 2.806122448979592e-06, 'epoch': 19.0}
{'loss': 0.2888, 'grad_norm': 0.9682706594467163, 'learning_rate': 2.6785714285714285e-06, 'epoch': 19.05}
{'loss': 0.2852, 'grad_norm': 0.8230661749839783, 'learning_rate': 2.5510204081632653e-06, 'epoch': 19.1}
{'loss': 0.2655, 'grad_norm': 0.9372957348823547, 'learning_rate': 2.423469387755102e-06, 'epoch': 19.14}
{'loss': 0.2802, 'grad_norm': 0.8907191753387451, 'learning_rate': 2.295918367346939e-06, 'epoch': 19.19}
{'loss': 0.2954, 'grad_norm': 1.0428447723388672, 'learning_rate': 2.1683673469387757e-06, 'epoch': 19.23}
{'loss': 0.2824, 'grad_norm': 0.7810648679733276, 'learning_rate': 2.040816326530612e-06, 'epoch': 19.28}
{'loss': 0.2777, 'grad_norm': 0.9300900101661682, 'learning_rate': 1.91326530

  0%|          | 0/134 [00:00<?, ?it/s]

{'eval_loss': 0.25163358449935913, 'eval_accuracy': 0.895315599113289, 'eval_runtime': 30.4632, 'eval_samples_per_second': 1125.424, 'eval_steps_per_second': 4.399, 'epoch': 20.0}
{'train_runtime': 9292.4387, 'train_samples_per_second': 121.538, 'train_steps_per_second': 0.476, 'train_loss': 0.34099182462260735, 'epoch': 20.0}
./data/llm/graph_prompter_hf_frozen/attentions.npy ./data/llm/graph_prompter_hf_frozen/hidden_states.npy ./data/llm/graph_prompter_hf_frozen/tokens.csv
train Forward Epoch 1 from 1




test Forward Epoch 1 from 1
val Forward Epoch 1 from 1


# Post-Training
After Training, we load all attentions and hidden states from before and append the new fields to the dataset.

In [7]:
vanilla_df = ClassifierBase.read_forward_dataset("./data/llm/vanilla")
prompt_df = ClassifierBase.read_forward_dataset("./data/llm/prompt")
graph_prompter_hf_df = ClassifierBase.read_forward_dataset(
    "./data/llm/graph_prompter_hf"
)

In [10]:
kg_manager.llm_df

Unnamed: 0,source_id,target_id,id_x,id_y,prompt_feature_title,prompt_feature_genres,labels,split,prompt,gnn_feature_(no genres listed),...,gnn_feature_IMAX,gnn_feature_Musical,gnn_feature_Mystery,gnn_feature_Romance,gnn_feature_Sci-Fi,gnn_feature_Thriller,gnn_feature_War,gnn_feature_Western,graph_prompter_hf_source_embedding,graph_prompter_hf_target_embedding
0,0,5,0.0,5.0,Heat (1995),"['Action', 'Crime', 'Thriller']",1,train,"0[SEP]5[SEP]Heat (1995)[SEP]['Action', 'Crime'...",,...,,,,,,,,,"[0.6693655252456665, 0.4152011275291443, -0.31...","[0.5493870973587036, -0.5692430734634399, -0.6..."
1,0,89,0.0,89.0,Bottle Rocket (1996),"['Adventure', 'Comedy', 'Crime', 'Romance']",1,train,0[SEP]89[SEP]Bottle Rocket (1996)[SEP]['Advent...,,...,,,,,,,,,"[0.6030164957046509, 0.37829291820526123, -0.3...","[0.08589676022529602, -0.4689454734325409, 0.0..."
2,0,97,0.0,97.0,Braveheart (1995),"['Action', 'Drama', 'War']",1,train,"0[SEP]97[SEP]Braveheart (1995)[SEP]['Action', ...",,...,,,,,,,,,"[0.6387443542480469, 0.3706628084182739, -0.32...","[0.04940161854028702, -0.47019216418266296, -0..."
3,0,184,0.0,184.0,Billy Madison (1995),['Comedy'],1,train,0[SEP]184[SEP]Billy Madison (1995)[SEP]['Comedy'],,...,,,,,,,,,"[0.49271589517593384, 0.4562503397464752, -0.2...","[0.6636490821838379, -0.3091672956943512, 0.01..."
4,0,224,0.0,224.0,Star Wars: Episode IV - A New Hope (1977),"['Action', 'Adventure', 'Sci-Fi']",1,train,0[SEP]224[SEP]Star Wars: Episode IV - A New Ho...,,...,,,,,,,,,"[0.6498588919639587, 0.43191879987716675, -0.3...","[-0.08320911973714828, -0.060460180044174194, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
125032,79,8846,,,Lovesick (2014),"['Comedy', 'Romance']",0,val,"79[SEP]8846[SEP]Lovesick (2014)[SEP]['Comedy',...",0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,"[-0.5735506415367126, -0.044633716344833374, -...","[0.47228026390075684, -0.911340594291687, 0.00..."
125033,300,2595,,,Dersu Uzala (1975),"['Adventure', 'Drama']",0,val,300[SEP]2595[SEP]Dersu Uzala (1975)[SEP]['Adve...,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.006543345749378204, -0.3807082176208496, 0....","[0.36955565214157104, -0.21707147359848022, -0..."
125034,321,59,,,Lawnmower Man 2: Beyond Cyberspace (1996),"['Action', 'Sci-Fi', 'Thriller']",0,val,321[SEP]59[SEP]Lawnmower Man 2: Beyond Cybersp...,0.0,...,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,"[-0.3852957785129547, 0.02625414729118347, -0....","[0.38845473527908325, -0.27163609862327576, -0..."
125035,388,2639,,,All the Vermeers in New York (1990),"['Comedy', 'Drama', 'Romance']",0,val,388[SEP]2639[SEP]All the Vermeers in New York ...,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,"[0.26095038652420044, 0.08665543794631958, -0....","[1.6701771020889282, -0.6786400079727173, 0.87..."


In [8]:
dataset = kg_manager.generate_huggingface_dataset(
    [vanilla_df, prompt_df, graph_prompter_hf_df, graph_prompter_hf_frozen_df],
    ["vanilla", "prompt", "graph_prompter_hf", "graph_prompter_hf_frozen_df"],
)

AssertionError: 

In [None]:
dataset

NameError: name 'dataset' is not defined

In [None]:
dataset.save_to_disk("./data/dataset.hf")

Saving the dataset (0/10 shards):   0%|          | 0/56469 [00:00<?, ? examples/s]

Saving the dataset (0/6 shards):   0%|          | 0/34284 [00:00<?, ? examples/s]

Saving the dataset (0/6 shards):   0%|          | 0/34284 [00:00<?, ? examples/s]

In [None]:
dataset.push_to_hub("AhmadPython/MovieLens_KGE")

Uploading the dataset shards:   0%|          | 0/10 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/2.80k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


CommitInfo(commit_url='https://huggingface.co/datasets/AhmadPython/MovieLens_KGE/commit/70c96f660429150e55b6437d59ee39da76891de7', commit_message='Upload dataset', commit_description='', oid='70c96f660429150e55b6437d59ee39da76891de7', pr_url=None, pr_revision=None, pr_num=None)