In [32]:
%load_ext autoreload
%autoreload 2
%matplotlib inline


from collections import defaultdict
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

import sys
sys.path.append('../')

from redkg.dataloader import read_kg, BidirectionalOneShotIterator  #TrainDataset, 
from redkg.kge import KGEModel
from redkg.config import Config
from redkg.utils import AttributeDict

from ogb.linkproppred import LinkPropPredDataset, Evaluator

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """Training triples paired with uniformly sampled, type-constrained negatives.

    Each item is one positive (head, relation, tail) triple whose head/tail ids are
    shifted into a single global entity id space via ``entity_dict``, plus
    ``negative_sample_size`` random entity ids drawn from the id range of the
    entity type being corrupted (heads for 'head-batch', tails for 'tail-batch').
    Negatives are not filtered, so they may coincide with the true entity.

    Args:
        triples: mapping with parallel, index-aligned sequences under the keys
            'head', 'relation', 'tail', 'head_type' and 'tail_type'. Head/tail ids
            are local to their entity type.
        nentity: total number of entities (stored; not used for sampling).
        nrelation: number of relation ids (stored; not used for sampling).
        negative_sample_size: number of negatives drawn per positive triple.
        mode: 'head-batch' (corrupt heads) or 'tail-batch' (corrupt tails).
        count: mapping used for subsampling weights, keyed by
            (head, relation, head_type) for the forward direction and
            (tail, -relation - 1, tail_type) for the reverse direction.
        true_head, true_tail: stored for callers; not used for sampling.
        entity_dict: entity type -> (start, end) half-open range of global ids.

    Raises:
        ValueError: if ``mode`` is not 'head-batch' or 'tail-batch'.
    """

    _MODES = ('head-batch', 'tail-batch')

    def __init__(self, triples, nentity, nrelation, negative_sample_size, mode, count, true_head, true_tail, entity_dict):
        # Fail fast in the main process instead of inside DataLoader workers.
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.len = len(triples['head'])
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = count
        self.true_head = true_head
        self.true_tail = true_tail
        self.entity_dict = entity_dict

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive_sample, negative_sample, subsampling_weight, mode)`` for triple ``idx``.

        positive_sample is a LongTensor [global_head, relation, global_tail];
        negative_sample holds ``negative_sample_size`` global ids of the corrupted
        side's type; subsampling_weight is a 1-element tensor 1/sqrt(freq).
        """
        head, relation, tail = self.triples['head'][idx], self.triples['relation'][idx], self.triples['tail'][idx]
        head_type, tail_type = self.triples['head_type'][idx], self.triples['tail_type'][idx]
        positive_sample = [head + self.entity_dict[head_type][0], relation, tail + self.entity_dict[tail_type][0]]

        # Frequency of the (entity, relation) pattern in both directions; the reverse
        # direction is encoded with relation id -relation - 1.
        subsampling_weight = self.count[(head, relation, head_type)] + self.count[(tail, -relation-1, tail_type)]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        if self.mode == 'head-batch':
            corrupt_type = head_type
        elif self.mode == 'tail-batch':
            corrupt_type = tail_type
        else:
            # Only reachable if self.mode was changed after construction.
            raise ValueError(f"mode must be one of {self._MODES}, got {self.mode!r}")
        low, high = self.entity_dict[corrupt_type]
        # torch.randint's upper bound is exclusive, matching the half-open type range.
        negative_sample = torch.randint(low, high, (self.negative_sample_size,))
        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode

    @staticmethod
    def collate_fn(data):
        """Stack per-item tensors into a batch; all items in a batch share one mode."""
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode
    

<img src="biokg_logo.png" width="400">
The ogbl-biokg dataset is a Knowledge Graph (KG) created by the OGB team using data from a large number of biomedical data repositories. It contains 5 types of entities: diseases (10,687 nodes), proteins (17,499 nodes), drugs (10,533 nodes), side effects (9,969 nodes), and protein functions (45,085 nodes). There are 51 types of directed relations connecting two types of entities, including 38 kinds of drug-drug interactions, 8 kinds of protein-protein interactions, and drug-protein, drug-side effect, and function-function relations. All relations are modeled as directed edges; relations connecting entities of the same type (e.g., protein-protein, drug-drug, function-function) are always symmetric, i.e., the edges are bi-directional.

This dataset is relevant to both biomedical and fundamental ML research. On the biomedical side, the dataset allows us to get better insights into human biology and generate predictions that can guide downstream biomedical research. On the fundamental ML side, the dataset presents challenges in handling a noisy, incomplete KG with possibly contradictory observations. This is because the ogbl-biokg dataset involves heterogeneous interactions that span from the molecular scale (e.g., protein-protein interactions within a cell) to whole populations (e.g., reports of unwanted side effects experienced by patients in a particular country). Further, triplets in the KG come from sources with a variety of confidence levels, including experimental readouts, human-curated annotations, and automatically extracted metadata.

In [10]:
# OGB link-prediction dataset to train on, read from / stored under data_root.
dataset_name = 'ogbl-biokg'
data_root = '../data'

dataset = LinkPropPredDataset(root=data_root, name=dataset_name)

In [16]:
# Official train / valid / test triple splits provided by OGB.
split_edge = dataset.get_edge_split()
train_triples = split_edge["train"]
valid_triples = split_edge["valid"]
test_triples = split_edge["test"]

In [18]:
# nrelation is one more than the largest relation id seen in the training split.
nrelation = int(max(train_triples['relation'])) + 1

# Lay the entity types out back-to-back in one global id space:
# entity_dict[type] = (first_global_id, one_past_last_global_id).
num_nodes_by_type = dataset[0]['num_nodes_dict']
entity_dict = {}
offset = 0
for entity_type, type_size in num_nodes_by_type.items():
    entity_dict[entity_type] = (offset, offset + type_size)
    offset += type_size
nentity = sum(num_nodes_by_type.values())

evaluator = Evaluator(name=dataset_name)

In [19]:
# Pattern frequencies used by TrainDataset for subsampling weights. Every key starts
# at 4 (the defaultdict factory) before occurrences are added — presumably a
# smoothing constant carried over from a reference KGE implementation; confirm.
# The reverse direction of a triple is keyed with relation id -relation - 1.
train_count = defaultdict(lambda: 4)
# (relation, tail) -> heads seen in training, and (head, relation) -> tails seen.
train_true_head = defaultdict(list)
train_true_tail = defaultdict(list)

columns = ('head', 'relation', 'tail', 'head_type', 'tail_type')
rows = zip(*(train_triples[column] for column in columns))
for head, relation, tail, head_type, tail_type in tqdm(rows, total=len(train_triples['head'])):
    train_count[(head, relation, head_type)] += 1
    train_count[(tail, -relation-1, tail_type)] += 1
    train_true_head[(relation, tail)].append(head)
    train_true_tail[(head, relation)].append(tail)

100%|██████████| 4762678/4762678 [00:16<00:00, 290269.91it/s]


In [21]:
# Model hyper-parameters.
kge_model_name = "TransE"
hidden_dim = 128
gamma = 12.0
# Must be defined before the optimizer is built below, otherwise a fresh-kernel
# Restart & Run All fails with NameError.
current_learning_rate = 0.001

# `kge_model` refers only to the model object; its name string lives in kge_model_name.
# NOTE(review): doubled entity/relation embeddings are typically used by
# RotatE/ComplEx-style scorers — confirm doubling is intended for TransE.
kge_model = KGEModel(
        model_name=kge_model_name,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=hidden_dim,
        gamma=gamma,
        double_entity_embedding=True,
        double_relation_embedding=True,
        evaluator=evaluator
    )

# Optimize only trainable parameters.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, kge_model.parameters()),
    lr=current_learning_rate
)

## Training the model

In [44]:
negative_sample_size = 128
batch_size = 128
cpu_num = 10
# NOTE: the optimizer is constructed in the model cell, which runs before this one,
# so editing this value here does not change that optimizer's learning rate.
current_learning_rate = 0.001


def _make_train_dataloader(mode):
    """Build a shuffled training DataLoader for one corruption mode.

    Args:
        mode: 'head-batch' (negatives replace heads) or 'tail-batch'
            (negatives replace tails), forwarded to TrainDataset.

    Returns:
        A DataLoader whose batches are assembled by TrainDataset.collate_fn.
    """
    dataset_for_mode = TrainDataset(
        train_triples, nentity, nrelation,
        negative_sample_size, mode,
        train_count, train_true_head, train_true_tail,
        entity_dict,
    )
    return DataLoader(
        dataset_for_mode,
        batch_size=batch_size,
        shuffle=True,
        num_workers=max(1, cpu_num // 2),
        collate_fn=TrainDataset.collate_fn,
    )


train_dataloader_head = _make_train_dataloader('head-batch')
train_dataloader_tail = _make_train_dataloader('tail-batch')
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

In [60]:
# Options passed to kge_model.train_step in the training loop (applied in this order).
_train_settings = {
    'cuda': False,
    'uni_weight': True,
    'negative_adversarial_sampling': True,
    'regularization': 0.0,
    'adversarial_temperature': 1.0,
}
train_pars = AttributeDict()
for _name, _value in _train_settings.items():
    setattr(train_pars, _name, _value)

# Options passed to kge_model.test_step; the evaluation batch size reuses batch_size.
_test_settings = {
    'cuda': False,
    'nentity': nentity,
    'nrelation': nrelation,
    'neg_size_eval_train': 500,
    'test_log_steps': 1000,
    'test_batch_size': batch_size,
    'cpu_num': cpu_num,
}
test_params = AttributeDict()
for _name, _value in _test_settings.items():
    setattr(test_params, _name, _value)

In [61]:
training_logs = []
# NOTE: despite the name, these metrics are computed on valid_triples.
test_logs = []

max_steps = 10
# Evaluate on the validation split every `valid_every` steps and always after the
# final step. 1 reproduces the original cadence (evaluate after every step); raise it
# because each evaluation scores the whole validation split.
valid_every = 1

# Training loop
for step in range(max_steps):
    log = kge_model.train_step(kge_model, optimizer, train_iterator, train_pars)
    training_logs.append(log)

    is_last_step = step == max_steps - 1
    if (step + 1) % valid_every == 0 or is_last_step:
        metrics = kge_model.test_step(kge_model, valid_triples, test_params, entity_dict)
        test_logs.append(metrics)

In [64]:
# Validation metrics (hits@1/3/10 and MRR) appended by the training loop, one dict per
# evaluation; despite the name they are computed on valid_triples, not the test split.
test_logs

[{'hits@1_list': 0.0033673858270049095,
  'hits@3_list': 0.009205824695527554,
  'hits@10_list': 0.029947325587272644,
  'mrr_list': 0.0193993728607893},
 {'hits@1_list': 0.003305993042886257,
  'hits@3_list': 0.009122944436967373,
  'hits@10_list': 0.029704824090003967,
  'mrr_list': 0.01936676912009716},
 {'hits@1_list': 0.00337966438382864,
  'hits@3_list': 0.009211963973939419,
  'hits@10_list': 0.029950395226478577,
  'mrr_list': 0.019408494234085083},
 {'hits@1_list': 0.0033213412389159203,
  'hits@3_list': 0.009098388254642487,
  'hits@10_list': 0.0294254869222641,
  'mrr_list': 0.019350754097104073},
 {'hits@1_list': 0.0032752968836575747,
  'hits@3_list': 0.009138292632997036,
  'hits@10_list': 0.02936716563999653,
  'mrr_list': 0.01935487426817417},
 {'hits@1_list': 0.003315201960504055,
  'hits@3_list': 0.009098388254642487,
  'hits@10_list': 0.029385583475232124,
  'mrr_list': 0.01936601661145687},
 {'hits@1_list': 0.0033029234036803246,
  'hits@3_list': 0.00890500098466873

In [63]:
# Per-step loss dicts returned by kge_model.train_step
# (positive_sample_loss, negative_sample_loss, and combined loss).
training_logs

[{'positive_sample_loss': 5.781314849853516,
  'negative_sample_loss': 0.14729920029640198,
  'loss': 2.9643070697784424},
 {'positive_sample_loss': 5.593881607055664,
  'negative_sample_loss': 0.04241343215107918,
  'loss': 2.8181474208831787},
 {'positive_sample_loss': 5.6944122314453125,
  'negative_sample_loss': 0.059163790196180344,
  'loss': 2.8767879009246826},
 {'positive_sample_loss': 5.279773235321045,
  'negative_sample_loss': 0.09559981524944305,
  'loss': 2.6876864433288574},
 {'positive_sample_loss': 5.365677356719971,
  'negative_sample_loss': 0.18596382439136505,
  'loss': 2.77582049369812},
 {'positive_sample_loss': 5.616865634918213,
  'negative_sample_loss': 0.24613013863563538,
  'loss': 2.931497812271118},
 {'positive_sample_loss': 5.645710468292236,
  'negative_sample_loss': 0.15487168729305267,
  'loss': 2.9002909660339355},
 {'positive_sample_loss': 5.437168598175049,
  'negative_sample_loss': 0.08542022109031677,
  'loss': 2.761294364929199},
 {'positive_sample