# PrimeKG Subgraph Construction (Multi-Modal)


First of all, we need to import necessary libraries as follows:

In [1]:
# Import necessary libraries
import os
os.environ["OPENAI_API_KEY"] = "xxx"
os.environ["NVCF_RUN_KEY"] = "xxx"

import numpy as np
import pandas as pd
import networkx as nx
import pickle
from tqdm import tqdm
import torch 
from torch_geometric.utils import from_networkx
import sys
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.datasets.starkqa_primekg import StarkQAPrimeKG
from aiagents4pharma.talk2knowledgegraphs.datasets.biobridge_primekg import BioBridgePrimeKG
from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama
from aiagents4pharma.talk2knowledgegraphs.utils import kg_utils

# # Set the logging level for httpx to WARNING to suppress INFO messages
import logging
logging.getLogger("httpx").setLevel(logging.WARNING)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define biobridge primekg data by providing a local directory where the data is stored
biobridge_data = BioBridgePrimeKG(primekg_dir="../../../../data/primekg/",
                                  local_dir="../../../../data/biobridge_primekg/")

# Invoke a method to load the data
biobridge_data.load_data()

# Get the node information of the BioBridge PrimeKG
biobridge_node_info = biobridge_data.get_node_info_dict()
biobridge_node_info.keys()

Loading PrimeKG dataset...
Loading nodes of PrimeKG dataset ...
../../../../data/primekg/primekg_nodes.tsv.gz already exists. Loading the data from the local directory.
Loading edges of PrimeKG dataset ...
../../../../data/primekg/primekg_edges.tsv.gz already exists. Loading the data from the local directory.
Loading data config file of BioBridgePrimeKG...
File data_config.json already exists in ../../../../data/biobridge_primekg/.
Building node embeddings...
Building full triplets...
Building train-test split...
Building negative triplets...


dict_keys(['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease'])

### Training

In [3]:
# Training configuration
hidden_dim=768 # the hidden dimension of the transformation model
n_layer=6 # the number of transformer layers
batch_size=1024 # the training batch size
learning_rate=1.6e-3 # the learning ratesss
n_epoch=10 # the number of training epochs
weight_decay=1e-4 # the weight decay
eval_steps=10 # the number of steps to evaluate the |model
save_dir="./checkpoints/model-1" # the directory to save the model
dataloader_num_workers=4 # the number of workers for data loading
use_wandb=False, # whether to use wandb

In [4]:
os.makedirs(save_dir, exist_ok=True)
save_dir

'./checkpoints/model-1'

In [5]:
class TrainDataset(torch.utils.data.Dataset):
    """Use the train set for contrastive learning with InfoNCE loss.
    """
    def __init__(self, triplet, node):
        self.triplet = triplet
        self.node = node

    def __getitem__(self, index):
        return self.triplet.iloc[index]

    def __len__(self):
        return len(self.triplet)

In [6]:
class ValDataset(torch.utils.data.Dataset):
    """Use the test set to evaluate the retrieval performance of the model. The evaluation is performed in this way:
    1. get the raw embeddings of all nodes in `self.node_all`.
    2. for each relation type, transform head node embedding to tail node embedding using the transformation model.
    3. match the transformed embedding with the raw embedding of the target node.

    Args:
        node_test (pd.DataFrame): the test node dataframe
        triplet (pd.DataFrame): the **all** triplet dataframe
        node_all (pd.DataFrame): the **all** node dataframe. Need to encode them all when evaluating.
        target_node_type_index (int): the target node type to consider for evaluation and prediction.
            Defaults to None and use all node types.
        target_relation (int): the `display relation` type to consider for evaluation. 
        frequent_threshold (int, optional): the tail node that appears less than this threshold will be removed in the evaluation.
            Defaults to None and use all nodes.
    """
    def __init__(self,
                 node_test,
                 triplet_all,
                 node_all,
                 target_node_type_index=None,
                 target_relation=None,
                 frequent_threshold=None,
                 ):
        self.target_relation = target_relation
        self.frequent_threshold = frequent_threshold
        self.target_node_type_index = target_node_type_index
        if target_relation is not None:

            # filter the triplet and node dataframe by the relation
            # only maintain triplets with the relation in the relation list
            # only maintain the nodes that appear in the triplets
            # only maintain the test nodes that appear in the triplets
            triplet_all = triplet_all[triplet_all['display_relation'].isin([target_relation])].reset_index(drop=True).copy()
            all_node_index  = pd.concat([triplet_all["head_index"], triplet_all["tail_index"]]).unique()
            node_all = node_all[node_all["node_index"].isin(all_node_index)].reset_index(drop=True).copy()
            node_test = node_test[node_test["node_index"].isin(all_node_index)].reset_index(drop=True).copy()

        if target_node_type_index is not None:
            # filter the triplet that has head_type equal to the target node type
            triplet_all = triplet_all[triplet_all["head_type"] == target_node_type_index].reset_index(drop=True).copy()
            # only choose the test nodes that are the target node type
            node_test = node_test[node_test["node_type"] == target_node_type_index].reset_index(drop=True).copy()
            # only choose node_all that are the head node in node_test and the tail node in triplet_all
            all_node_index = pd.concat([node_test["node_index"], triplet_all["tail_index"]]).unique()
            node_all = node_all[node_all["node_index"].isin(all_node_index)].reset_index(drop=True).copy()

        # filter out the target node in the triplet that is not frequent enough
        if self.frequent_threshold is not None:
            val_counts = triplet_all["tail_index"].value_counts()
            frequent_node_index = val_counts[val_counts >= self.frequent_threshold].index
            triplet_all = triplet_all[triplet_all["tail_index"].isin(frequent_node_index)].reset_index(drop=True).copy()
            all_node_index = pd.concat([node_test["node_index"], triplet_all["tail_index"]]).unique()
            node_all = node_all[node_all["node_index"].isin(all_node_index)].reset_index(drop=True).copy()

        # filter out the test node that does have a tail node in the triplet
        node_test_new = node_test[node_test["node_index"].isin(triplet_all["head_index"])].reset_index(drop=True).copy()
        if len(node_test_new) != len(node_test):
            print(f"Warning: {len(node_test) - len(node_test_new)} test nodes are removed because they do not have a tail node in the triplet.")
            # find the difference between the two dataframes
            diff_index = node_test["node_index"][~node_test["node_index"].isin(node_test_new["node_index"])]
            # filter out node all
            node_all = node_all[~node_all["node_index"].isin(diff_index)].reset_index(drop=True).copy()
            node_test = node_test_new

        # save the filtered dataframes
        self.triplet = triplet_all
        self.node = node_test
        self.node_all = node_all   
        self.tail_node_types = self.triplet["tail_type"].unique()

    def __getitem__(self, index):
        # get the positive tail_index and all candidate tail_index from the same type
        row = self.node.iloc[index]
        triplet = self.triplet[self.triplet["head_index"] == row["node_index"]]
        outputs = {
            "head_index": row["node_index"],
            "head_type": row["node_type"],
            "tail_index": triplet["tail_index"].tolist(),
            "tail_type": triplet["tail_type"].tolist(),
            "display_relation": triplet["display_relation"].tolist(),
            "relation": triplet["relation"].tolist(),
        }
        return outputs
    
    def __len__(self):
        return len(self.node)
    
    def get_all_node(self):
        return self.node_all
    
    def get_all_triplet(self):
        return self.triplet

In [7]:
biobridge_data.get_train_test_split()["test"]

Unnamed: 0,head_index,head_name,head_source,head_id,head_type,tail_index,tail_name,tail_source,tail_id,tail_type,display_relation,relation
0,8,MT1A,NCBI,4489,1,1785,TP53,NCBI,7157,1,3,protein_protein
1,12,CD7,NCBI,924,1,7681,SFXN5,NCBI,94097,1,3,protein_protein
2,16,SNRPD2,NCBI,6633,1,3235,PRPF4,NCBI,9128,1,3,protein_protein
3,19,VAV3,NCBI,10451,1,3005,ZRANB1,NCBI,54764,1,3,protein_protein
4,16,SNRPD2,NCBI,6633,1,216,NCSTN,NCBI,23385,1,3,protein_protein
...,...,...,...,...,...,...,...,...,...,...,...,...
393675,125342,myosin V complex,GO,31475,7,9639,DYNLL2,NCBI,140735,1,2,cellcomp_protein
393676,55608,extracellular membrane-bounded organelle,GO,65010,7,57129,PHOSPHO1,NCBI,162466,1,2,cellcomp_protein
393677,124243,axonemal outer doublet,GO,97545,7,59351,CFAP100,NCBI,348807,1,2,cellcomp_protein
393678,124243,axonemal outer doublet,GO,97545,7,59352,CFAP73,NCBI,387885,1,2,cellcomp_protein


In [8]:
biobridge_data.get_primekg_triplets_negative()

Unnamed: 0,head_index,head_name,head_source,head_id,head_type,tail_index,tail_name,tail_source,tail_id,tail_type,display_relation,relation,negative_tail_index
0,0,PHYHIP,NCBI,9796,1,8889,KIF15,NCBI,56992,1,3,protein_protein,"[1533, 13199, 3392, 58453, 2320, 5335, 5931, 6..."
1,1,GPANK1,NCBI,7918,1,2798,PNMA1,NCBI,9240,1,3,protein_protein,"[3703, 3058, 12245, 77327, 1523, 11417, 8180, ..."
2,2,ZRSR2,NCBI,8233,1,5646,TTC33,NCBI,23548,1,3,protein_protein,"[57364, 4827, 3618, 4619, 13537, 2283, 13604, ..."
3,3,NRF1,NCBI,4899,1,11592,MAN1B1,NCBI,11253,1,3,protein_protein,"[9075, 14006, 57630, 58767, 59599, 566, 9093, ..."
4,4,PI4KA,NCBI,5297,1,2122,RGS20,NCBI,8601,1,3,protein_protein,"[1104, 9454, 11225, 6657, 13626, 5516, 12844, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
3510925,124473,longitudinal sarcoplasmic reticulum,GO,14801,7,58744,DHRS7C,NCBI,201140,1,2,cellcomp_protein,"[274, 58856, 34400, 9268, 6714, 1526, 58741, 1..."
3510926,55747,myofilament,GO,36379,7,57367,MYBPHL,NCBI,343263,1,2,cellcomp_protein,"[58257, 56530, 521, 823, 2377, 1444, 5686, 117..."
3510927,126945,lateral wall of outer hair cell,GO,120249,7,22033,SLC26A5,NCBI,375611,1,2,cellcomp_protein,"[59304, 57686, 10627, 12313, 6187, 2347, 1572,..."
3510928,125456,Swi5-Swi2 complex,GO,34974,7,57415,SWI5,NCBI,375757,1,2,cellcomp_protein,"[7770, 4549, 59268, 11118, 10631, 2777, 6701, ..."


In [9]:
biobridge_data.get_primekg_triplets()

Unnamed: 0,head_index,head_name,head_source,head_id,head_type,tail_index,tail_name,tail_source,tail_id,tail_type,display_relation,relation
0,0,PHYHIP,NCBI,9796,1,8889,KIF15,NCBI,56992,1,3,protein_protein
1,1,GPANK1,NCBI,7918,1,2798,PNMA1,NCBI,9240,1,3,protein_protein
2,2,ZRSR2,NCBI,8233,1,5646,TTC33,NCBI,23548,1,3,protein_protein
3,3,NRF1,NCBI,4899,1,11592,MAN1B1,NCBI,11253,1,3,protein_protein
4,4,PI4KA,NCBI,5297,1,2122,RGS20,NCBI,8601,1,3,protein_protein
...,...,...,...,...,...,...,...,...,...,...,...,...
3904605,52855,B cell receptor transport into membrane raft,GO,32597,0,34572,CD24,NCBI,100133941,1,2,bioprocess_protein
3904606,113352,chemokine receptor transport out of membrane raft,GO,32600,0,34572,CD24,NCBI,100133941,1,2,bioprocess_protein
3904607,42264,negative regulation of cytoskeleton organization,GO,51494,0,57675,IQCJ-SCHIP1,NCBI,100505385,1,2,bioprocess_protein
3904608,109904,mesendoderm migration,GO,90133,0,58770,APELA,NCBI,100506013,1,2,bioprocess_protein


In [10]:
train_split = biobridge_data.get_primekg_triplets_negative()
test_split = biobridge_data.get_train_test_split()["test"]
node_train_split = biobridge_data.get_train_test_split()["node_train"]
node_test_split = biobridge_data.get_train_test_split()["node_test"]

df_all = biobridge_data.get_primekg_triplets()
df_node_all = pd.concat([node_train_split, node_test_split], axis=0).reset_index(drop=True)

# drop duplicate nodes and triples
train_split = train_split.drop_duplicates(subset=["head_index", "tail_index", "display_relation"]).reset_index(drop=True)
# test_split = test_split.drop_duplicates(subset=["head_index", "tail_index", "display_relation"]).reset_index(drop=True)
node_train_split = node_train_split.drop_duplicates(subset=["node_index"]).reset_index(drop=True)
node_test_split = node_test_split.drop_duplicates(subset=["node_index"]).reset_index(drop=True)
df_all = df_all.drop_duplicates(subset=["head_index", "tail_index", "display_relation"]).reset_index(drop=True)
df_node_all = df_node_all.drop_duplicates(subset=["node_index"]).reset_index(drop=True)

split_data = {
    "train": train_split,
    "test": test_split,
    "node_train": node_train_split,
    "node_test": node_test_split,
    "all": df_all,
    "node_all": df_node_all,
}

In [11]:
def build_model_config(data_config):
    # build model config
    model_config = {
        "n_node": len(data_config["node_type"]),
        "n_relation": len(data_config["relation_type"]),
        }
    proj_dim = {}
    for node_type, dim in data_config["emb_dim"].items():
        proj_dim[data_config["node_type"][node_type]] = dim
    model_config["proj_dim"] = proj_dim
    return model_config

In [12]:
train_data = TrainDataset(**{"triplet":split_data["train"], 
                             "node":split_data["node_train"]})

In [13]:
split_data["node_all"]

Unnamed: 0,node_index,node_type
0,0,1
1,1,1
2,2,1
3,3,1
4,4,1
...,...,...
84976,127404,7
84977,127415,7
84978,127421,7
84979,127425,7


In [14]:
val_data = ValDataset(**{"triplet_all":split_data["all"], 
                         "node_test":split_data["node_test"],
                         "node_all":split_data["node_all"],
                         "target_relation": 2, # only consider the evaluation on one relation, 2: `interact with`
                         "target_node_type_index": 1, # the index of the target node type: protein/gene is 1
                         "frequent_threshold": 50, # the threshold of the frequent node
                         })



In [15]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [16]:
data_config = biobridge_data.get_data_config()
data_config

{'node_type': {'biological_process': 0,
  'gene/protein': 1,
  'disease': 2,
  'effect/phenotype': 3,
  'anatomy': 4,
  'molecular_function': 5,
  'drug': 6,
  'cellular_component': 7,
  'pathway': 8,
  'exposure': 9},
 'relation_type': {'expression present': 0,
  'synergistic interaction': 1,
  'interacts with': 2,
  'ppi': 3,
  'phenotype present': 4,
  'parent-child': 5,
  'associated with': 6,
  'side effect': 7,
  'contraindication': 8,
  'expression absent': 9,
  'target': 10,
  'indication': 11,
  'enzyme': 12,
  'transporter': 13,
  'off-label use': 14,
  'linked to': 15,
  'phenotype absent': 16,
  'carrier': 17},
 'emb_dim': {'biological_process': 768,
  'cellular_component': 768,
  'disease': 768,
  'drug': 512,
  'molecular_function': 768,
  'gene/protein': 2560}}

In [17]:
# import urllib.request
# urls = ['https://raw.githubusercontent.com/RyanWangZf/BioBridge/refs/heads/main/src/losses.py',
#         'https://raw.githubusercontent.com/RyanWangZf/BioBridge/refs/heads/main/src/model.py',
#         'https://raw.githubusercontent.com/RyanWangZf/BioBridge/refs/heads/main/src/trainer.py',
#         'https://raw.githubusercontent.com/RyanWangZf/BioBridge/refs/heads/main/src/schema.py',
#         'https://raw.githubusercontent.com/RyanWangZf/BioBridge/refs/heads/main/src/collator.py']
# os.makedirs("biobridge", exist_ok=True)
# for url in urls:
#        filename = url.split('/')[-1]
#        urllib.request.urlretrieve(url, "biobridge/"+filename)

# init_file_path = os.path.join("biobridge", "__init__.py")    
# with open(init_file_path, "w", encoding="utf-8") as f:
#       f.write("")

In [18]:
import json
from docs.notebooks.talk2knowledgegraphs.biobridge.model import BindingModel


# build the model
print("### Model Configuration ###")
# build model config
model_config = build_model_config(data_config)
model_config["hidden_dim"] = hidden_dim
model_config["n_layer"] = n_layer
print(json.dumps(model_config, indent=4))
model = BindingModel(**model_config)
model.to(device)

### Model Configuration ###
{
    "n_node": 10,
    "n_relation": 18,
    "proj_dim": {
        "0": 768,
        "7": 768,
        "2": 768,
        "6": 512,
        "5": 768,
        "1": 2560
    },
    "hidden_dim": 768,
    "n_layer": 6
}


BindingModel(
  (paired_loss_fn): InfoNCE()
  (unpaired_loss_fn): InfoNCE()
  (node_type_embed): Sequential(
    (0): Embedding(10, 768)
    (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (relation_type_embed): Sequential(
    (0): Embedding(18, 768)
    (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (proj_layer): ModuleDict(
    (0): Linear(in_features=768, out_features=768, bias=False)
    (7): Linear(in_features=768, out_features=768, bias=False)
    (2): Linear(in_features=768, out_features=768, bias=False)
    (6): Linear(in_features=512, out_features=768, bias=False)
    (5): Linear(in_features=768, out_features=768, bias=False)
    (1): Linear(in_features=2560, out_features=768, bias=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
 

In [19]:
# save model config to the save directory
with open(os.path.join(save_dir, "model_config.json"), "w") as f:
    json.dump(model_config, f, indent=4)

In [20]:
from transformers import TrainingArguments
from transformers.trainer_utils import speed_metrics
from transformers.debug_utils import DebugOption
from transformers.trainer_utils import (
    EvalPrediction,
)

In [21]:
# build trainer
train_args = TrainingArguments(
    output_dir=save_dir,
    overwrite_output_dir=True,
    num_train_epochs=n_epoch,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=2, # every node corresponds to multiple tail nodes
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    logging_steps=10,
    save_steps=10,
    save_total_limit=5,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    max_grad_norm=1.0, # gradient clipping
    warmup_ratio=0.1,
    dataloader_num_workers=dataloader_num_workers, # number of processes to use for dataloading
    report_to="wandb" if use_wandb else "none",
    )



In [None]:
pip install transformers[torch]
pip install accelerate==1.4.0
pip install wandb==0.19.8
pip install lightning==2.5.0.post0

SyntaxError: invalid syntax (52660047.py, line 1)

In [23]:
print("### Training Arguments ###")
print(json.dumps(train_args.to_dict(), indent=4))

print("### Number of Trainable Parameters ###")
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

### Training Arguments ###
{
    "output_dir": "./checkpoints/model-1",
    "overwrite_output_dir": true,
    "do_train": false,
    "do_eval": true,
    "do_predict": false,
    "eval_strategy": "steps",
    "prediction_loss_only": false,
    "per_device_train_batch_size": 1024,
    "per_device_eval_batch_size": 2,
    "per_gpu_train_batch_size": null,
    "per_gpu_eval_batch_size": null,
    "gradient_accumulation_steps": 1,
    "eval_accumulation_steps": null,
    "eval_delay": 0,
    "torch_empty_cache_steps": null,
    "learning_rate": 0.0016,
    "weight_decay": 0.0001,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-08,
    "max_grad_norm": 1.0,
    "num_train_epochs": 10,
    "max_steps": -1,
    "lr_scheduler_type": "linear",
    "lr_scheduler_kwargs": {},
    "warmup_ratio": 0.1,
    "warmup_steps": 0,
    "log_level": "passive",
    "log_on_each_node": true,
    "logging_dir": "./checkpoints/model-1/runs/Mar13_12-59-40_awmulyadi",
    "logging_strategy

In [30]:
from typing import Dict, List, Optional
from collections import defaultdict
from transformers.trainer_utils import (
    EvalPrediction,
)

def compute_metrics(inputs: EvalPrediction) -> Dict:
    """Compute the metrics for the prediction."""
    metrics = defaultdict(list)
    predictions = inputs.predictions
    # print(inputs.predictions)
    # print(predictions)
    # print(predictions['prediction'])
    num_samples = len(predictions["node_index"])
    node_types = predictions["node_type"]
    all_tail_types = list(predictions['prediction'].keys())
    for tail_type in all_tail_types:
        preds, labels = predictions['prediction'][tail_type], predictions['label'][tail_type]
        for i in range(num_samples):
            # compute r@k, k = 5, 10, 20
            # compute ndcg@k, k = 5
            # recall: tp / (tp+fn)
            node_types_i = node_types[i]
            pred, label = preds[i], labels[i]
            label = label[label!=-100]
            if len(label) > 0:
                # only consider the case where the label is not empty
                rec_5 = len(set(pred[0][:5].tolist()).intersection(set(label.tolist()))) / len(label)
                metrics[f"head_{node_types_i}_tail_{tail_type}_rec@5"].append(rec_5)
                rec_10 = len(set(pred[0][:10].tolist()).intersection(set(label.tolist()))) / len(label)
                metrics[f"head_{node_types_i}_tail_{tail_type}_rec@10"].append(rec_10)
                rec_20 = len(set(pred[0][:20].tolist()).intersection(set(label.tolist()))) / len(label)
                metrics[f"head_{node_types_i}_tail_{tail_type}_rec@20"].append(rec_20)

    # compute the sample average
    new_metrics = {}
    for k, v in metrics.items():
        new_metrics[k] = np.mean(v)

    # TODO: average over all tail types if more than one tail type
    if len(all_tail_types) > 1:
        pass

    return new_metrics

In [25]:
# from typing import Dict, List, Optional
# from collections import defaultdict
# from transformers.trainer_utils import (
#     EvalPrediction,
# )


# def compute_metrics(inputs: EvalPrediction) -> Dict:
#     """Compute the metrics for the prediction."""
#     metrics = defaultdict(list)
#     predictions_list = inputs.predictions[0]
#     print("Type of predictions_list:", type(predictions_list))
#     print("First element in predictions_list:", predictions_list)
    
#     num_samples = len(predictions_list)
#     print("Number of samples:", num_samples)
    
#     predictions_dict = {
#         "node_index": [],
#         "node_type": [],
#         "prediction": defaultdict(list),
#         "label": defaultdict(list)
#     }

    
#     for prediction in predictions_list:
#         print("Prediction", prediction)
#         predictions_dict["node_index"].append(prediction["node_index"])
#         predictions_dict["node_type"].append(prediction["node_type"])
#         for key in prediction["prediction"]:
#             predictions_dict["prediction"][key].append(prediction["prediction"][key])
#         for key in prediction["label"]:
#             predictions_dict["label"][key].append(prediction["label"][key])
    
#     node_types = predictions_dict["node_type"]
#     all_tail_types = list(predictions_dict['prediction'].keys())

#     for tail_type in all_tail_types:
#         preds, labels = predictions_dict['prediction'][tail_type], predictions_dict['label'][tail_type]
        
#         for i in range(num_samples):
#             node_types_i = node_types[i]
#             pred, label = preds[i], labels[i]
#             label = label[label != -100]
            
#             if len(label) > 0:
#                 rec_5 = len(set(pred[:5]).intersection(set(label))) / len(label)
#                 metrics[f"head_{node_types_i}_tail_{tail_type}_rec@5"].append(rec_5)
#                 rec_10 = len(set(pred[:10]).intersection(set(label))) / len(label)
#                 metrics[f"head_{node_types_i}_tail_{tail_type}_rec@10"].append(rec_10)
#                 rec_20 = len(set(pred[:20]).intersection(set(label))) / len(label)
#                 metrics[f"head_{node_types_i}_tail_{tail_type}_rec@20"].append(rec_20)

#     new_metrics = {k: np.mean(v) for k, v in metrics.items()}
    
#     return new_metrics

In [26]:
embedding_dict = biobridge_data.get_node_embeddings()

In [31]:
import sys
sys.path.append("..")
from biobridge.trainer import BindingTrainer
from biobridge.collator import TrainCollator, ValCollator

# build trainer
trainer = BindingTrainer(
    args=train_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=TrainCollator(embedding_dict),
    test_data_collator=ValCollator(embedding_dict),
    compute_metrics=compute_metrics,
    )

In [32]:
# WandDB API Key
# abee8b0cb03cd169a5bfc6bb2f5e9dd483980f4e

In [33]:
# train the model
trainer.train()

Step,Training Loss,Validation Loss,Head 1 Tail 5 Rec@5,Head 1 Tail 5 Rec@10,Head 1 Tail 5 Rec@20,Head 1 Tail 7 Rec@5,Head 1 Tail 7 Rec@10,Head 1 Tail 7 Rec@20,Head 1 Tail 0 Rec@5,Head 1 Tail 0 Rec@10,Head 1 Tail 0 Rec@20
10,1.7695,No log,0.406749,0.432175,0.534897,0.007398,0.051285,0.084021,0.008913,0.017347,0.037576
20,1.7591,No log,0.460967,0.512326,0.549922,0.013132,0.052904,0.236511,0.009102,0.029819,0.046011
30,1.7253,No log,0.435522,0.531744,0.552096,0.01003,0.179308,0.409008,0.005681,0.019328,0.058344
40,1.7026,No log,0.43906,0.53208,0.553004,0.054861,0.188183,0.412271,0.005315,0.022202,0.054238
50,1.6772,No log,0.440259,0.487042,0.552985,0.107789,0.209164,0.471282,0.006052,0.018103,0.056348
60,1.6552,No log,0.441084,0.48812,0.552985,0.159105,0.214983,0.475639,0.003717,0.019454,0.051394
70,1.643,No log,0.43628,0.489048,0.556655,0.171552,0.376764,0.417422,0.005339,0.030969,0.051612
80,1.6311,No log,0.43628,0.526593,0.554914,0.291318,0.380528,0.425163,0.004964,0.013112,0.025164
90,1.6117,No log,0.434846,0.526593,0.554827,0.291318,0.38826,0.584831,0.00518,0.015869,0.033112
100,1.6058,No log,0.435891,0.526593,0.555132,0.250192,0.443707,0.584473,0.005034,0.017768,0.036017


INFO:biobridge.trainer:***** Running Node Encoding for Evaluation *****
INFO:biobridge.trainer:  Num examples = 2677
INFO:biobridge.trainer:***** Running Evaluation *****
INFO:biobridge.trainer:  Num examples = 1832
INFO:biobridge.trainer:  Batch size = 2
Prediction: 100%|██████████| 916/916 [00:12<00:00, 70.47it/s] 
INFO:biobridge.trainer:***** Running Node Encoding for Evaluation *****
INFO:biobridge.trainer:  Num examples = 2677
INFO:biobridge.trainer:***** Running Evaluation *****
INFO:biobridge.trainer:  Num examples = 1832
INFO:biobridge.trainer:  Batch size = 2
Prediction: 100%|██████████| 916/916 [00:13<00:00, 66.82it/s] 
INFO:biobridge.trainer:***** Running Node Encoding for Evaluation *****
INFO:biobridge.trainer:  Num examples = 2677
INFO:biobridge.trainer:***** Running Evaluation *****
INFO:biobridge.trainer:  Num examples = 1832
INFO:biobridge.trainer:  Batch size = 2
Prediction: 100%|██████████| 916/916 [00:13<00:00, 68.55it/s] 
INFO:biobridge.trainer:***** Running Node E

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7827c81ae600>> (for post_run_cell), with arguments args (<ExecutionResult object at 7827a8d12750, execution_count=33 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7827a8d12ed0, raw_cell="# train the model
trainer.train()" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/awmulyadi/Repositories/temp/office2/AIAgents4Pharma/docs/notebooks/talk2knowledgegraphs/tutorial_biobridge_primekg_training.ipynb#X43sZmlsZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
# save the model    
trainer.save_model(save_dir)