In [1]:
from graph_representation_generator import GraphRepresentationGenerator
from dataset_manager import (
    MovieLensManager,
    PROMPT_KGE_DIMENSION,
    INPUT_EMBEDS_REPLACE_KGE_DIMENSION,
    ROOT,
)
from llm_manager import (
    VanillaBertClassifier,
    GraphPrompterHFClassifier,
)

In [2]:
# Entry point to the MovieLens knowledge graph. Later cells read its `.data`,
# `.gnn_{train,val,test}_data`, `.llm_df`, `.source_df` and `.target_df`.
kg_manager = MovieLensManager()

In [3]:
# Two graph representation generators built over the same graph and GNN
# train/val/test splits. They differ only in `kge_dimension` and, for the
# graph_prompter_hf variant, an explicit `hidden_channels`.
_gnn_inputs = (
    kg_manager.data,
    kg_manager.gnn_train_data,
    kg_manager.gnn_val_data,
    kg_manager.gnn_test_data,
)

graph_representation_generator_prompt = GraphRepresentationGenerator(
    *_gnn_inputs,
    kge_dimension=PROMPT_KGE_DIMENSION,
)
graph_representation_generator_graph_prompter_hf = GraphRepresentationGenerator(
    *_gnn_inputs,
    hidden_channels=INPUT_EMBEDS_REPLACE_KGE_DIMENSION,
    kge_dimension=INPUT_EMBEDS_REPLACE_KGE_DIMENSION,
)

loading pretrained model
Device: 'cpu'
loading pretrained model
Device: 'cpu'


In [4]:
def _load_or_generate_embeddings(generator, name, llm_df):
    """Return ``(embeddings, needs_save)`` for one graph representation generator.

    Tries ``generator.get_saved_embeddings(name)`` first. If that returns
    ``None``, embeddings are computed with ``generator.generate_embeddings(llm_df)``
    and ``needs_save`` is ``True``, so that only freshly generated embeddings are
    passed on with ``save=True``.

    Args:
        generator: a ``GraphRepresentationGenerator`` instance.
        name: key under which the embeddings were saved (e.g. ``"prompt"``).
        llm_df: the frame to generate embeddings for when none are saved.
    """
    embeddings = generator.get_saved_embeddings(name)
    if embeddings is None:
        return generator.generate_embeddings(llm_df), True
    return embeddings, False


prompt_embeddings, save_prompt = _load_or_generate_embeddings(
    graph_representation_generator_prompt, "prompt", kg_manager.llm_df
)
graph_prompter_hf_embeddings, save_graph_prompter_hf = _load_or_generate_embeddings(
    graph_representation_generator_graph_prompter_hf,
    "graph_prompter_hf",
    kg_manager.llm_df,
)

kg_manager.append_prompt_graph_embeddings(prompt_embeddings, save=save_prompt)
kg_manager.append_graph_prompter_hf_graph_embeddings(
    graph_prompter_hf_embeddings, save=save_graph_prompter_hf
)


In [5]:
# Per-variant root paths, all under ROOT/llm. The vanilla and graph_prompter_hf
# paths are passed to the classifiers as `root_path=`. PROMPT_ROOT is not used
# by the cells shown here.
LLM_ROOT = f"{ROOT}/llm"
VANILLA_ROOT = f"{LLM_ROOT}/vanilla"
PROMPT_ROOT = f"{LLM_ROOT}/prompt"
INPUT_EMBEDS_REPLACE_ROOT = f"{LLM_ROOT}/graph_prompter_hf"

In [6]:
# Both classifiers receive the same `false_ratio`. Its meaning is defined in
# llm_manager (TODO confirm what -1 selects).
FALSE_RATIO = -1

# Baseline classifier built from the LLM, source and target frames.
vanilla_bert_classifier = VanillaBertClassifier(
    kg_manager.llm_df, kg_manager.source_df, kg_manager.target_df,
    root_path=VANILLA_ROOT,
    false_ratio=FALSE_RATIO,
)

# Graph-aware classifier. It is given the graph_prompter_hf generator's
# `get_embedding` so that it can look up graph embeddings.
graph_prompter_hf_bert_classifier = GraphPrompterHFClassifier(
    kg_manager,
    graph_representation_generator_graph_prompter_hf.get_embedding,
    root_path=INPUT_EMBEDS_REPLACE_ROOT,
    false_ratio=FALSE_RATIO,
)


Some weights of BertForSequenceClassificationRanges were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassificationRanges were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GraphPrompterHFBertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Tokenized datasets for each classifier. The graph_prompter_hf dataset is
# built with that classifier's separator and padding tokens as well as its
# tokenize function.
dataset_vanilla = kg_manager.generate_vanilla_dataset(vanilla_bert_classifier.tokenize_function)

graph_prompter_hf_tokenizer = graph_prompter_hf_bert_classifier.tokenizer
dataset_graph_prompter_hf = kg_manager.generate_graph_prompter_hf_embedding_dataset(
    graph_prompter_hf_tokenizer.sep_token,
    graph_prompter_hf_tokenizer.pad_token,
    graph_prompter_hf_bert_classifier.tokenize_function,
)

In [9]:
# Run the validation split through both classifiers and save the requested
# outputs. The log below shows one forward pass per masked token group.
BATCH_SIZE = 64


def _forward_options():
    """Keyword options shared by both forward passes.

    A new dict (with new lists) is built on every call, so the two calls never
    share mutable arguments.
    """
    return dict(
        force_recompute=True,
        batch_size=BATCH_SIZE,
        load_fields=["hidden_states", "logits", "attentions"],
        is_test=True,
        splits=["val"],
    )


# NOTE(review): both calls pass `kg_manager.get_vanilla_tokens_as_df`. Confirm
# this is intended for the graph_prompter_hf dataset as well.
vanilla_bert_classifier.forward_dataset_and_save_outputs(
    dataset_vanilla, kg_manager.get_vanilla_tokens_as_df, **_forward_options()
)
graph_prompter_hf_bert_classifier.forward_dataset_and_save_outputs(
    dataset_graph_prompter_hf, kg_manager.get_vanilla_tokens_as_df, **_forward_options()
)

Forward val where ('cls',) are masked
Forward val where ('user_id',) are masked
Forward val where ('movie_id',) are masked
Forward val where ('title',) are masked
Forward val where ('genres',) are masked
Forward val where ('seps',) are masked
Forward val where ('cls', 'user_id') are masked
Forward val where ('cls', 'movie_id') are masked
Forward val where ('cls', 'title') are masked
Forward val where ('cls', 'genres') are masked
Forward val where ('cls', 'seps') are masked
Forward val where ('user_id', 'movie_id') are masked
Forward val where ('user_id', 'title') are masked
Forward val where ('user_id', 'genres') are masked
Forward val where ('user_id', 'seps') are masked
Forward val where ('movie_id', 'title') are masked
Forward val where ('movie_id', 'genres') are masked
Forward val where ('movie_id', 'seps') are masked
Forward val where ('title', 'genres') are masked
Forward val where ('title', 'seps') are masked
Forward val where ('genres', 'seps') are masked
Forward val where ('cl