In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from tqdm import tqdm
import torch

import sys
sys.path.append('../')

from redkg.dataloader import TrainDataset, get_info
from redkg.kge import KGEModel
from redkg.config import Config
from redkg.utils import AttributeDict
from redkg.train import train_kge_model


from ogb.linkproppred import LinkPropPredDataset, Evaluator

  from .autonotebook import tqdm as notebook_tqdm


## BioKG
<img src="biokg_logo2.png" width="400">

The **ogbl-biokg** dataset is a Knowledge Graph (KG) that was created using data from a large number of biomedical data repositories. It contains 5 types of entities: diseases (10,687 nodes), proteins (17,499 nodes), drugs (10,533 nodes), side effects (9,969 nodes), and protein functions (45,085 nodes). There are 51 types of directed relations connecting two types of entities, including 38 kinds of drug-drug interactions and 8 kinds of protein-protein interactions, as well as drug-protein, drug-side effect, and function-function relations. All relations are modeled as directed edges, among which the relations connecting the same entity types (e.g., protein-protein, drug-drug, function-function) are always symmetric, i.e., the edges are bi-directional.

This dataset is relevant to both biomedical and fundamental ML research. On the biomedical side, the dataset allows us to get better insights into human biology and generate predictions that can guide downstream biomedical research. On the fundamental ML side, the dataset presents challenges in handling a noisy, incomplete KG with possibly contradictory observations. This is because the ogbl-biokg dataset involves heterogeneous interactions that span from the molecular scale (e.g., protein-protein interactions within a cell) to whole populations (e.g., reports of unwanted side effects experienced by patients in a particular country). Further, triplets in the KG come from sources with a variety of confidence levels, including experimental readouts, human-curated annotations, and automatically extracted metadata.

In [3]:
# Name of the OGB link-prediction dataset; the same name is used for the Evaluator.
dataset_name = "ogbl-biokg"

# Build the dataset object, with ../data as its root directory.
dataset = LinkPropPredDataset(name=dataset_name, root="../data")

# Pull the train / validation / test triples out of the edge split.
split_edge = dataset.get_edge_split()
train_triples, valid_triples, test_triples = (
    split_edge[part] for part in ("train", "valid", "test")
)

# Summary info computed from the dataset and the training triples; its
# 'nentity' and 'nrelation' entries are read when the model is configured.
info = get_info(dataset, train_triples)

In [5]:
# Evaluator for the same dataset name; it is handed to the model below.
evaluator = Evaluator(name=dataset_name)

# TransE model sized from the entity / relation counts stored in `info`.
# NOTE(review): both double_*_embedding flags are switched on for TransE —
# confirm this is intended.
kge_model = KGEModel(
    model_name="TransE",
    nentity=info["nentity"],
    nrelation=info["nrelation"],
    hidden_dim=128,
    gamma=12.0,
    double_entity_embedding=True,
    double_relation_embedding=True,
    evaluator=evaluator,
)

## Training model

In [26]:
# Container for the options passed to train_kge_model. The exact meaning of
# each option is defined in redkg.train, which is not shown in this notebook.
train_pars = AttributeDict()

# Device selection.
train_pars.cuda = False

# Loss / negative-sampling options.
train_pars.uni_weight = True
train_pars.negative_adversarial_sampling = True
train_pars.regularization = 0.0
train_pars.adversarial_temperature = 1.0

# Batch sizes, learning rate and worker count for training.
train_pars.train_batch_size = 128
train_pars.negative_sample_size = 128
train_pars.learning_rate = 0.001
train_pars.cpu_num = 10

# Evaluation-related options.
train_pars.negative_mode = "full"
train_pars.neg_size_eval_train = 500
train_pars.test_log_steps = 1000
train_pars.test_batch_size = 128

# Graph sizes, copied from the dataset info.
train_pars.nentity = info["nentity"]
train_pars.nrelation = info["nrelation"]

train_pars.do_test = True

In [46]:
# Run training; the call returns the training-loss records and the test metrics.
# NOTE(review): `test_params` is not defined in any cell visible in this
# notebook (only `train_pars` is) — confirm where it comes from, otherwise
# this cell raises NameError on a fresh kernel.
training_logs, test_logs = train_kge_model(
    kge_model,
    train_pars,
    test_params,
    info,
    train_triples,
    valid_triples,
    test_triples,
)

Training...


## Results visualization

In [64]:
# Show the evaluation records (second value returned by train_kge_model).
test_logs

[{'hits@1_list': 0.0033673858270049095,
  'hits@3_list': 0.009205824695527554,
  'hits@10_list': 0.029947325587272644,
  'mrr_list': 0.0193993728607893},
 {'hits@1_list': 0.003305993042886257,
  'hits@3_list': 0.009122944436967373,
  'hits@10_list': 0.029704824090003967,
  'mrr_list': 0.01936676912009716},
 {'hits@1_list': 0.00337966438382864,
  'hits@3_list': 0.009211963973939419,
  'hits@10_list': 0.029950395226478577,
  'mrr_list': 0.019408494234085083},
 {'hits@1_list': 0.0033213412389159203,
  'hits@3_list': 0.009098388254642487,
  'hits@10_list': 0.0294254869222641,
  'mrr_list': 0.019350754097104073},
 {'hits@1_list': 0.0032752968836575747,
  'hits@3_list': 0.009138292632997036,
  'hits@10_list': 0.02936716563999653,
  'mrr_list': 0.01935487426817417},
 {'hits@1_list': 0.003315201960504055,
  'hits@3_list': 0.009098388254642487,
  'hits@10_list': 0.029385583475232124,
  'mrr_list': 0.01936601661145687},
 {'hits@1_list': 0.0033029234036803246,
  'hits@3_list': 0.00890500098466873

In [63]:
# Show the training-loss records (first value returned by train_kge_model).
training_logs

[{'positive_sample_loss': 5.781314849853516,
  'negative_sample_loss': 0.14729920029640198,
  'loss': 2.9643070697784424},
 {'positive_sample_loss': 5.593881607055664,
  'negative_sample_loss': 0.04241343215107918,
  'loss': 2.8181474208831787},
 {'positive_sample_loss': 5.6944122314453125,
  'negative_sample_loss': 0.059163790196180344,
  'loss': 2.8767879009246826},
 {'positive_sample_loss': 5.279773235321045,
  'negative_sample_loss': 0.09559981524944305,
  'loss': 2.6876864433288574},
 {'positive_sample_loss': 5.365677356719971,
  'negative_sample_loss': 0.18596382439136505,
  'loss': 2.77582049369812},
 {'positive_sample_loss': 5.616865634918213,
  'negative_sample_loss': 0.24613013863563538,
  'loss': 2.931497812271118},
 {'positive_sample_loss': 5.645710468292236,
  'negative_sample_loss': 0.15487168729305267,
  'loss': 2.9002909660339355},
 {'positive_sample_loss': 5.437168598175049,
  'negative_sample_loss': 0.08542022109031677,
  'loss': 2.761294364929199},
 {'positive_sample