In this colab, we'll learn how to train the RotatE algorithm on a knowledge graph, walking through every step from loading the dataset to setting up the model and making predictions.

First, we install the `ogb` and `tensorboardX` modules. The `ogb` module, short for [Open Graph Benchmark](https://ogb.stanford.edu/docs/home/), contains graph datasets that are widely used as benchmarks in graph ML research. The `tensorboardX` module provides logging utilities.

The code of this colab can be found at [Graph_Neural_Network_for_biological_predictions](https://github.com/Zahra-Bakhtiari/Graph_Neural_Network_for_biological_predictions.git) and is based on the following repositories: [snap-stanford](https://github.com/snap-stanford/ogb.git) and [KnowledgeGraphEmbedding](https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding). The original code is more complete and customizable; here, we simplify it to make it easier for readers to train their first knowledge graph embedding model on a heterogeneous graph. You may achieve similar results by running the CLI commands described in the `KnowledgeGraphEmbedding` repository; in this colab, we extract the relevant code to expose the training logic behind the scenes.

In [None]:
import torch
import os

!pip install ogb
!pip install tensorboardX

Collecting ogb
  Downloading ogb-1.3.2-py3-none-any.whl (78 kB)
     |████████████████████████████████| 78 kB 2.4 MB/s 
Collecting outdated>=0.2.0
  Downloading outdated-0.2.1-py3-none-any.whl (7.5 kB)
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... done
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7048 sha256=f38454f4c1c71428f7c2d93e24ca86d7bcc4fb2d7a9d1a3827bf902bb373a04b
  Stored in directory: /root/.cache/pip/wheels/d6/64/cd/32819b511a488e4993f2fab909a95330289c3f4e0f6ef4676d
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.2 ogb-1.3.2 outdated-0.2.1
Collecting tensorboardX
  Downloading tensorboardX-2.4.1-py2.py3-none-any.whl (124 kB)
     |████████████████████████████████| 124 kB 2.7 MB/s 
Installing collected packages: tensorbo

## Data


With the required modules installed, the first step is to load and understand the data.

Let's use the `TrainDataset` and `TestDataset` classes defined in the [snap-stanford dataloader](https://github.com/snap-stanford/ogb/blob/9d7a3f5e89640938f93b52388cfe259ae8fe3855/examples/linkproppred/biokg/dataloader.py). These helper classes retrieve positive and negative samples from our original dataset. They inherit from `torch.utils.data.Dataset`, which you can read more about in the [official docs](https://pytorch.org/docs/stable/data.html).

In [None]:
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    def __init__(self, triples, nentity, nrelation, negative_sample_size, mode, count, true_head, true_tail, entity_dict):
        self.len = len(triples['head'])
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = count
        self.true_head = true_head
        self.true_tail = true_tail
        self.entity_dict = entity_dict
        
    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        head, relation, tail = self.triples['head'][idx], self.triples['relation'][idx], self.triples['tail'][idx]
        head_type, tail_type = self.triples['head_type'][idx], self.triples['tail_type'][idx]
        positive_sample = [head + self.entity_dict[head_type][0], relation, tail + self.entity_dict[tail_type][0]]

        subsampling_weight = self.count[(head, relation, head_type)] + self.count[(tail, -relation-1, tail_type)]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        if self.mode == 'head-batch':
            negative_sample = torch.randint(self.entity_dict[head_type][0], self.entity_dict[head_type][1], (self.negative_sample_size,))
        elif self.mode == 'tail-batch':
            negative_sample = torch.randint(self.entity_dict[tail_type][0], self.entity_dict[tail_type][1], (self.negative_sample_size,))
        else:
            raise ValueError('mode %s not supported' % self.mode)
        positive_sample = torch.LongTensor(positive_sample)
            
        return positive_sample, negative_sample, subsampling_weight, self.mode
    
    @staticmethod
    def collate_fn(data):
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode
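To make the sampling logic in `__getitem__` concrete, here is a minimal stdlib-only sketch of what one training example contains. The toy `entity_dict`, the triples and the helper `make_train_example` are hypothetical; they only mimic the global ID offsets, the frequency-based subsampling weight and the type-constrained negative sampling of `TrainDataset`, with `count` starting at 4 as in the `run` function further down.

```python
import math
import random
from collections import defaultdict

# Hypothetical toy layout: [start, end) global ID range per entity type
entity_dict = {'disease': (0, 3), 'drug': (3, 8)}

# Hypothetical training triples in local (per-type) IDs:
# (head, relation, tail, head_type, tail_type)
triples = [
    (0, 0, 1, 'drug', 'disease'),
    (0, 0, 2, 'drug', 'disease'),
    (4, 0, 1, 'drug', 'disease'),
]

# Frequencies of (entity, relation) pairs, starting at 4 as in run()
count = defaultdict(lambda: 4)
for head, rel, tail, head_type, tail_type in triples:
    count[(head, rel, head_type)] += 1
    count[(tail, -rel - 1, tail_type)] += 1


def make_train_example(idx, mode, negative_sample_size, rng):
    """Hypothetical helper mirroring TrainDataset.__getitem__ without torch."""
    head, rel, tail, head_type, tail_type = triples[idx]
    # Shift local IDs into the global entity ID space
    positive = (head + entity_dict[head_type][0], rel, tail + entity_dict[tail_type][0])
    # Rare (entity, relation) pairs get larger weights
    weight = 1 / math.sqrt(count[(head, rel, head_type)] + count[(tail, -rel - 1, tail_type)])
    # Corrupt the head or the tail, drawing only entities of the same type
    if mode == 'head-batch':
        start, end = entity_dict[head_type]
    elif mode == 'tail-batch':
        start, end = entity_dict[tail_type]
    else:
        raise ValueError('mode %s not supported' % mode)
    negatives = [rng.randrange(start, end) for _ in range(negative_sample_size)]
    return positive, negatives, weight


positive, negatives, weight = make_train_example(0, 'tail-batch', 5, random.Random(0))
print(positive)  # (3, 0, 1)
```

As in `TrainDataset`, a random negative may occasionally coincide with a true entity; the sampling does not filter these out.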
    

In [None]:
class TestDataset(Dataset):
    def __init__(self, triples, args, mode, random_sampling, entity_dict):
        self.len = len(triples['head'])
        self.triples = triples
        self.nentity = args['nentity']
        self.nrelation = args['nrelation']
        self.mode = mode
        self.random_sampling = random_sampling
        if random_sampling:
            self.neg_size = args['neg_size_eval_train']
        self.entity_dict = entity_dict

    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        head, relation, tail = self.triples['head'][idx], self.triples['relation'][idx], self.triples['tail'][idx]
        head_type, tail_type = self.triples['head_type'][idx], self.triples['tail_type'][idx]
        positive_sample = torch.LongTensor((head + self.entity_dict[head_type][0], relation, tail + self.entity_dict[tail_type][0]))

        if self.mode == 'head-batch':
            if not self.random_sampling:
                negative_sample = torch.cat([torch.LongTensor([head + self.entity_dict[head_type][0]]), 
                        torch.from_numpy(self.triples['head_neg'][idx] + self.entity_dict[head_type][0])])
            else:
                negative_sample = torch.cat([torch.LongTensor([head + self.entity_dict[head_type][0]]), 
                        torch.randint(self.entity_dict[head_type][0], self.entity_dict[head_type][1], size=(self.neg_size,))])
        elif self.mode == 'tail-batch':
            if not self.random_sampling:
                negative_sample = torch.cat([torch.LongTensor([tail + self.entity_dict[tail_type][0]]), 
                        torch.from_numpy(self.triples['tail_neg'][idx] + self.entity_dict[tail_type][0])])
            else:
                negative_sample = torch.cat([torch.LongTensor([tail + self.entity_dict[tail_type][0]]), 
                        torch.randint(self.entity_dict[tail_type][0], self.entity_dict[tail_type][1], size=(self.neg_size,))])

        return positive_sample, negative_sample, self.mode
    
    @staticmethod
    def collate_fn(data):
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        mode = data[0][2]

        return positive_sample, negative_sample, mode
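At evaluation time, each row of `negative_sample` starts with the true entity followed by the negatives, so that `test_step` (in the model section) can split the scores into `score[:, 0]` and `score[:, 1:]`. The sketch below shows how a rank, the reciprocal rank and Hits@k can be derived from one such row. The helper `rank_metrics` is hypothetical, and averaging ties between the optimistic and pessimistic rank is our assumption; check the `ogb` `Evaluator` source for the exact rule it applies.

```python
def rank_metrics(scores, ks=(1, 3, 10)):
    """Rank the true entity (index 0) against the negatives (indices 1:).

    Ties are resolved by averaging the optimistic and pessimistic rank
    (an assumption about how the official evaluator handles them).
    """
    pos, negs = scores[0], scores[1:]
    optimistic = sum(1 for s in negs if s > pos)
    pessimistic = sum(1 for s in negs if s >= pos)
    rank = 0.5 * (optimistic + pessimistic) + 1
    metrics = {'mrr': 1.0 / rank}
    for k in ks:
        metrics['hits@%d' % k] = 1.0 if rank <= k else 0.0
    return rank, metrics


# Scores for [true tail, negative 1, negative 2, negative 3]
rank, metrics = rank_metrics([5.0, 7.0, 5.0, 1.0])
print(rank, metrics)
```

`test_step` then averages such per-triple values over both the head-batch and tail-batch dataloaders.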
    

We also extract the `BidirectionalOneShotIterator` class, which alternates between the head-batch and tail-batch training dataloaders and restarts each one whenever it is exhausted. This way, successive training steps alternately corrupt the heads and the tails of the positive triples.

In [None]:
class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0
        
    def __next__(self):
        self.step += 1
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        return data
    
    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch Dataloader into python iterator
        '''
        while True:
            for data in dataloader:
                yield data
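The alternation is easy to check without any data: because `step` is incremented before the parity test, the first call draws from the tail dataloader and the second from the head dataloader. The sketch below repeats the class so it runs on its own and uses plain lists as hypothetical stand-ins for the two `DataLoader`s.

```python
class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        self.step += 1
        # Odd steps draw tail batches, even steps draw head batches
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        # Restart the dataloader every time it is exhausted
        while True:
            for data in dataloader:
                yield data


# Plain lists stand in for the head-batch and tail-batch DataLoaders
it = BidirectionalOneShotIterator(['h1', 'h2'], ['t1'])
batches = [next(it) for _ in range(6)]
print(batches)  # ['t1', 'h1', 't1', 'h2', 't1', 'h1']
```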

In this colab, we'll be using the [`ogbl-biokg` dataset](https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg), which was introduced by [Hu et al. (2020)](https://arxiv.org/pdf/2005.00687.pdf) and contains data from a large number of biomedical data repositories.

To download and explore the dataset, we may use the `ogb` module we previously installed.

In [None]:
from ogb.linkproppred import LinkPropPredDataset

dataset = LinkPropPredDataset(name='ogbl-biokg', root='drive/MyDrive/data')

Downloading http://snap.stanford.edu/ogb/data/linkproppred/biokg.zip


Downloaded 0.90 GB: 100%|██████████| 920/920 [06:31<00:00,  2.35it/s]


Extracting drive/MyDrive/data/biokg.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 3563.55it/s]

Saving...





This is a dataset with only one graph:

In [None]:
graph = dataset[0]

The `ogbl-biokg` graph contains 5 types of entities: diseases (10,687 nodes), proteins (17,499 nodes), drugs (10,533 nodes), side effects (9,969 nodes), and protein functions (45,085 nodes).

We can verify this by checking its `num_nodes_dict`:

In [None]:
graph['num_nodes_dict']

{'disease': 10687,
 'drug': 10533,
 'function': 45085,
 'protein': 17499,
 'sideeffect': 9969}
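Because the node IDs of each type start at 0, the training code further down concatenates all types into one global ID space. The sketch below reproduces that offset computation (the loop over `num_nodes_dict` in `run`) on the counts printed above; the dictionary literal is copied from that output, and we assume its key order matches the order in which `ogb` returns it.

```python
# Node counts copied from graph['num_nodes_dict'] above
num_nodes_dict = {'disease': 10687, 'drug': 10533, 'function': 45085,
                  'protein': 17499, 'sideeffect': 9969}

# Assign each entity type a contiguous [start, end) range of global IDs
entity_dict = {}
cur_idx = 0
for key, n in num_nodes_dict.items():
    entity_dict[key] = (cur_idx, cur_idx + n)
    cur_idx += n
nentity = sum(num_nodes_dict.values())

print(entity_dict['protein'])  # (66305, 83804)
print(nentity)  # 93773

# Local ID 42 of type 'drug' maps to global ID 10687 + 42
print(42 + entity_dict['drug'][0])  # 10729
```

If the key order were different, the ranges would shift, but they would stay consistent within a single run because the same dictionary drives both training and evaluation.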

We may also explore what kinds of relations this knowledge graph has. According to the dataset description, there are 51 types of directed relations in total. We can check this by printing the keys of its `edge_reltype` dict:

In [None]:
list(graph['edge_reltype'].keys())

[('disease', 'disease-protein', 'protein'),
 ('drug', 'drug-disease', 'disease'),
 ('drug', 'drug-drug_acquired_metabolic_disease', 'drug'),
 ('drug', 'drug-drug_bacterial_infectious_disease', 'drug'),
 ('drug', 'drug-drug_benign_neoplasm', 'drug'),
 ('drug', 'drug-drug_cancer', 'drug'),
 ('drug', 'drug-drug_cardiovascular_system_disease', 'drug'),
 ('drug', 'drug-drug_chromosomal_disease', 'drug'),
 ('drug', 'drug-drug_cognitive_disorder', 'drug'),
 ('drug', 'drug-drug_cryptorchidism', 'drug'),
 ('drug', 'drug-drug_developmental_disorder_of_mental_health', 'drug'),
 ('drug', 'drug-drug_endocrine_system_disease', 'drug'),
 ('drug', 'drug-drug_fungal_infectious_disease', 'drug'),
 ('drug', 'drug-drug_gastrointestinal_system_disease', 'drug'),
 ('drug', 'drug-drug_hematopoietic_system_disease', 'drug'),
 ('drug', 'drug-drug_hematopoietic_system_diseases', 'drug'),
 ('drug', 'drug-drug_hypospadias', 'drug'),
 ('drug', 'drug-drug_immune_system_disease', 'drug'),
 ('drug', 'drug-drug_inheri
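To summarise the relation types by the entity types they connect, one option is to count the keys per `(head type, tail type)` pair. The sketch below does this with `collections.Counter` on the first few keys printed above; on the real graph you would pass `graph['edge_reltype'].keys()` instead. The helper `count_by_type_pair` is hypothetical.

```python
from collections import Counter

# A few (head_type, relation, tail_type) keys copied from the output above
reltype_keys = [
    ('disease', 'disease-protein', 'protein'),
    ('drug', 'drug-disease', 'disease'),
    ('drug', 'drug-drug_acquired_metabolic_disease', 'drug'),
    ('drug', 'drug-drug_bacterial_infectious_disease', 'drug'),
    ('drug', 'drug-drug_benign_neoplasm', 'drug'),
]


def count_by_type_pair(keys):
    """Count relation types per (head_type, tail_type) pair."""
    return Counter((head_type, tail_type) for head_type, _, tail_type in keys)


counts = count_by_type_pair(reltype_keys)
print(counts[('drug', 'drug')])  # 3
```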

## Model


We'll explore the `RotatE` and `pRotatE` models ([Sun et al. (2019)](https://arxiv.org/abs/1902.10197)). For more details on these models, refer to the linked paper and to our [blog post](https://medium.com/@seshwan2/de5acf0553ac).

We'll use the `KGEModel` class defined in [KnowledgeGraphEmbedding](https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/blob/2e440e0f9c687314d5ff67ead68ce985dc446e3a/codes/model.py).

This class defines a knowledge-graph oriented [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). It also implements the scoring functions of TransE, RotatE and pRotatE, as well as the training and evaluation steps. Take your time to understand how the code below is structured.
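For reference, the three scoring functions in the class below can be written as follows, transcribed from the code rather than from the paper. Here $\gamma$ is `gamma`, $d$ is `hidden_dim`, $\epsilon_r = (\gamma + 2)/d$ is `embedding_range` (with `epsilon = 2.0`), and relation (and, for pRotatE, entity) embeddings are turned into phases $\theta = \pi\,x/\epsilon_r$. For RotatE, each entity stores $2d$ real values, read as the real and imaginary parts of $d$ complex numbers; for pRotatE, $m$ is the learnable `modulus`, initialised to $0.5\,\epsilon_r$.

$$
\begin{aligned}
f_{\text{TransE}}(h, r, t) &= \gamma - \lVert \mathbf{h} + \mathbf{r} - \mathbf{t} \rVert_1 \\
f_{\text{RotatE}}(h, r, t) &= \gamma - \sum_{k=1}^{d} \left| h_k \, e^{\mathrm{i}\,\theta_{r,k}} - t_k \right|, \qquad h_k, t_k \in \mathbb{C} \\
f_{\text{pRotatE}}(h, r, t) &= \gamma - m \sum_{k=1}^{d} \left| \sin\left(\theta_{h,k} + \theta_{r,k} - \theta_{t,k}\right) \right|
\end{aligned}
$$

A higher score means a more plausible triple; a triple whose tail is exactly the rotated head reaches the maximal RotatE score $\gamma$.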

In [None]:
#!/usr/bin/python3

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from collections import defaultdict

from ogb.linkproppred import Evaluator


class KGEModel(nn.Module):
    def __init__(self, model_name, nentity, nrelation, hidden_dim, gamma, evaluator,
                 double_entity_embedding=False, double_relation_embedding=False):
        super(KGEModel, self).__init__()
        self.model_name = model_name
        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0

        self.gamma = nn.Parameter(
            torch.Tensor([gamma]),
            requires_grad=False
        )

        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
            requires_grad=False
        )

        self.entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
        self.relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

        self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim))
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-self.embedding_range.item(),
            b=self.embedding_range.item()
        )

        if model_name == 'pRotatE':
            self.modulus = nn.Parameter(torch.Tensor([[0.5 * self.embedding_range.item()]]))

        if model_name not in ['TransE', 'pRotatE', 'RotatE']:
            raise ValueError('model %s not supported' % model_name)

        if model_name == 'RotatE' and (not double_entity_embedding or double_relation_embedding):
            raise ValueError('RotatE should use --double_entity_embedding')

        self.evaluator = evaluator

    def forward(self, sample, mode='single'):
        '''
        Forward function that calculates the score of a batch of triples.
        In the 'single' mode, sample is a batch of triples.
        In the 'head-batch' or 'tail-batch' mode, sample consists of two parts:
        the first part is usually the positive sample,
        and the second part contains the entities of the negative samples.
        This is because negative samples and positive samples usually share two elements
        of their triple ((head, relation) or (relation, tail)).
        '''

        if mode == 'single':
            batch_size, negative_sample_size = sample.size(0), 1

            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=sample[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=sample[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=sample[:, 2]
            ).unsqueeze(1)

        elif mode == 'head-batch':
            tail_part, head_part = sample
            batch_size, negative_sample_size = head_part.size(0), head_part.size(1)

            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=head_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=tail_part[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=tail_part[:, 2]
            ).unsqueeze(1)

        elif mode == 'tail-batch':
            head_part, tail_part = sample
            batch_size, negative_sample_size = tail_part.size(0), tail_part.size(1)

            head = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=head_part[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                self.relation_embedding,
                dim=0,
                index=head_part[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                self.entity_embedding,
                dim=0,
                index=tail_part.view(-1)
            ).view(batch_size, negative_sample_size, -1)

        else:
            raise ValueError('mode %s not supported' % mode)

        model_func = {
            'TransE': self.TransE,
            'pRotatE': self.pRotatE,
            'RotatE': self.RotatE,
        }

        if self.model_name in model_func:
            score = model_func[self.model_name](head, relation, tail, mode)
        else:
            raise ValueError('model %s not supported' % self.model_name)

        return score

    def TransE(self, head, relation, tail, mode):
        if mode == 'head-batch':
            score = head + (relation - tail)
        else:
            score = (head + relation) - tail

        score = self.gamma.item() - torch.norm(score, p=1, dim=2)
        return score

    def RotatE(self, head, relation, tail, mode):

        pi = 3.14159265358979323846

        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)

        # Make phases of relations uniformly distributed in [-pi, pi]

        phase_relation = relation / (self.embedding_range.item() / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        if mode == 'head-batch':
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0)

        score = self.gamma.item() - score.sum(dim=2)
        return score

    def pRotatE(self, head, relation, tail, mode):
        pi = 3.14159265358979323846

        # Make phases of entities and relations uniformly distributed in [-pi, pi]

        phase_head = head / (self.embedding_range.item() / pi)
        phase_relation = relation / (self.embedding_range.item() / pi)
        phase_tail = tail / (self.embedding_range.item() / pi)

        if mode == 'head-batch':
            score = phase_head + (phase_relation - phase_tail)
        else:
            score = (phase_head + phase_relation) - phase_tail

        score = torch.sin(score)
        score = torch.abs(score)

        score = self.gamma.item() - score.sum(dim=2) * self.modulus
        return score

    @staticmethod
    def train_step(model, optimizer, train_iterator, args):
        '''
        A single train step. Apply back-propagation and return the loss
        '''

        model.train()
        optimizer.zero_grad()
        positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)

        positive_sample = positive_sample.cuda()
        negative_sample = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()

        negative_score = model((positive_sample, negative_sample), mode=mode)
        if args['negative_adversarial_sampling']:
            # In self-adversarial sampling, we do not apply back-propagation on the sampling weight
            negative_score = (F.softmax(negative_score * args['adversarial_temperature'], dim=1).detach()
                              * F.logsigmoid(-negative_score)).sum(dim=1)
        else:
            negative_score = F.logsigmoid(-negative_score).mean(dim=1)

        positive_score = model(positive_sample)
        positive_score = F.logsigmoid(positive_score).squeeze(dim=1)
        
        positive_sample_loss = - (subsampling_weight * positive_score).sum() / subsampling_weight.sum()
        negative_sample_loss = - (subsampling_weight * negative_score).sum() / subsampling_weight.sum()

        loss = (positive_sample_loss + negative_sample_loss) / 2
        regularization_log = {}

        loss.backward()

        optimizer.step()

        log = {
            **regularization_log,
            'positive_sample_loss': positive_sample_loss.item(),
            'negative_sample_loss': negative_sample_loss.item(),
            'loss': loss.item()
        }

        return log

    @staticmethod
    def test_step(model, test_triples, args, entity_dict, random_sampling=False):
        '''
        Evaluate the model on test or valid datasets
        '''

        model.eval()

        # Prepare dataloader for evaluation
        test_dataloader_head = DataLoader(
            TestDataset(
                test_triples,
                args,
                'head-batch',
                random_sampling,
                entity_dict
            ),
            batch_size=args['test_batch_size'],
            num_workers=max(1, args['cpu_num'] // 2),
            collate_fn=TestDataset.collate_fn
        )

        test_dataloader_tail = DataLoader(
            TestDataset(
                test_triples,
                args,
                'tail-batch',
                random_sampling,
                entity_dict
            ),
            batch_size=args['test_batch_size'],
            num_workers=max(1, args['cpu_num'] // 2),
            collate_fn=TestDataset.collate_fn
        )

        test_dataset_list = [test_dataloader_head, test_dataloader_tail]

        test_logs = defaultdict(list)

        step = 0
        total_steps = sum([len(dataset) for dataset in test_dataset_list])

        with torch.no_grad():
            for test_dataset in test_dataset_list:
                for positive_sample, negative_sample, mode in test_dataset:
                    positive_sample = positive_sample.cuda()
                    negative_sample = negative_sample.cuda()

                    batch_size = positive_sample.size(0)
                    score = model((positive_sample, negative_sample), mode)

                    batch_results = model.evaluator.eval({'y_pred_pos': score[:, 0],
                                                          'y_pred_neg': score[:, 1:]})
                    for metric in batch_results:
                        test_logs[metric].append(batch_results[metric])

                    if step % args['test_log_steps'] == 0:
                        print('Evaluating the model... (%d/%d)' % (step, total_steps))

                    step += 1

            metrics = {}
            for metric in test_logs:
                metrics[metric] = torch.cat(test_logs[metric]).mean().item()

        return metrics
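To see why `RotatE` handles `'head-batch'` differently, note that a unit-modulus rotation is invertible: $|h \circ r - t| = |\bar{r} \circ t - h|$, which is exactly what the `'head-batch'` branch computes with its real and imaginary parts. The stdlib-only sketch below re-implements the per-triple score with Python complex numbers (the function `rotate_score` and the toy vectors are hypothetical) and checks that both modes agree and that a tail obtained by exactly rotating the head scores $\gamma$.

```python
import cmath
import math


def rotate_score(head, phases, tail, gamma, mode='tail-batch'):
    """Per-triple RotatE score, mirroring KGEModel.RotatE with complex numbers.

    head, tail: lists of complex numbers (one per dimension)
    phases: relation phases in radians (the code maps them to [-pi, pi] at init)
    """
    total = 0.0
    for h, theta, t in zip(head, phases, tail):
        r = cmath.exp(1j * theta)  # unit-modulus rotation
        if mode == 'head-batch':
            diff = r.conjugate() * t - h  # rotate the tail back
        else:
            diff = h * r - t  # rotate the head forward
        total += abs(diff)
    return gamma - total


gamma = 20.0
head = [1 + 2j, -0.5 + 0.3j, 0.7 - 1j]
phases = [math.pi / 2, -math.pi / 3, 0.25]
tail = [0.2 - 1j, 1.5 + 0.1j, -0.4 + 0.4j]

tail_mode = rotate_score(head, phases, tail, gamma, 'tail-batch')
head_mode = rotate_score(head, phases, tail, gamma, 'head-batch')
print(abs(tail_mode - head_mode) < 1e-9)  # True

# A tail that is exactly the rotated head gets the maximal score gamma
perfect_tail = [h * cmath.exp(1j * p) for h, p in zip(head, phases)]
print(rotate_score(head, phases, perfect_tail, gamma))  # 20.0
```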

We now have our data and our models defined. We can proceed to train our models and get some results!

The code in the training section is also based on the [KnowledgeGraphEmbedding](https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/) repository. As mentioned earlier, the original code allows for many customizations and hyperparameters, which we summarize in the `args` dictionary below. Feel free to change these values to see how they impact the convergence and accuracy of the models!

For the full list of arguments, please refer to the [original file](https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/blob/master/codes/run.py) where these are defined.

The parameters of the `args` dictionary are listed below, with a short explanation:

*   `dataset`: The dataset name.
*   `model`: The model being trained (`TransE`, `RotatE` or `pRotatE`).
*   `hidden_dim`: The embedding dimension.
*   `gamma`: The margin `gamma` used in the scoring functions.
*   `batch_size`: The number of positive triples per training batch.
*   `negative_sample_size`: The number of negative samples drawn for each positive triple.
*   `adversarial_temperature`: The sampling temperature `alpha` of the self-adversarial negative sampling in Sun et al. (2019).
*   `learning_rate`: The initial learning rate.
*   `max_steps`: The maximum number of training steps.
*   `cpu_num`: The number of CPUs; the dataloaders use `max(1, cpu_num // 2)` workers.
*   `test_batch_size`: The batch size for the validation and test sets.
*   `neg_size_eval_train`: The number of random negative samples used when evaluating on the training set.
*   `save_checkpoint_steps`: The number of steps between model checkpoints.
*   `valid_steps`: The number of steps between evaluations on the validation set.
*   `log_steps`: The number of steps between training logs.
*   `double_entity_embedding`: Whether entity embeddings store twice `hidden_dim` values (real and imaginary parts, required by RotatE).
*   `double_relation_embedding`: Whether relation embeddings store twice `hidden_dim` values.
*   `negative_adversarial_sampling`: Whether to use self-adversarial negative sampling.
*   `do_train`, `do_valid`, `do_test`: Whether to train the model and evaluate it on the validation and test sets.
*   `evaluate_train`: Whether to evaluate the final model on a random subset of the training triples.
*   `ntriples_eval_train`: The number of training triples sampled for that evaluation.
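The two sampling options map onto the negative part of the loss in `train_step`. With `negative_adversarial_sampling`, each negative's log-sigmoid term is weighted by a softmax over the negative scores multiplied by `adversarial_temperature`, and these weights are treated as constants (hence `.detach()` in the code); otherwise the terms are simply averaged. The stdlib-only sketch below computes that term for a single positive triple; `log_sigmoid` and `negative_loss_term` are hypothetical names.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def negative_loss_term(neg_scores, temperature=1.0, adversarial=True):
    """Negative-sample loss for one positive triple, mirroring train_step."""
    terms = [log_sigmoid(-s) for s in neg_scores]
    if adversarial:
        # Softmax of temperature-scaled scores: harder negatives weigh more
        m = max(s * temperature for s in neg_scores)
        exps = [math.exp(s * temperature - m) for s in neg_scores]
        z = sum(exps)
        weights = [e / z for e in exps]
    else:
        weights = [1.0 / len(neg_scores)] * len(neg_scores)
    return -sum(w * t for w, t in zip(weights, terms))


neg_scores = [2.0, -1.0, -3.0]
uniform = negative_loss_term(neg_scores, adversarial=False)
adversarial = negative_loss_term(neg_scores, temperature=1.0)
print(uniform, adversarial)
```

Here the adversarial loss is larger than the uniform one because the highest-scoring (hardest) negative dominates the weights. In `train_step`, this term and the positive term are each averaged over the batch with `subsampling_weight`, and the loss is their mean.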

## Training


We're now ready to train our models. This code is based on the [snap-stanford ogb repository](https://github.com/snap-stanford/ogb/blob/9d7a3f5e89640938f93b52388cfe259ae8fe3855/examples/linkproppred/biokg/run.py).

In [None]:
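from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from ogb.linkproppred import Evaluator, LinkPropPredDataset
from tqdm import tqdm
from tensorboardX import SummaryWriter

def log_metrics(mode, step, metrics, writer):
    '''
    Print the evaluation logs and write them to TensorBoard
    '''
    for metric in metrics:
        print('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))
        writer.add_scalar("_".join([mode, metric]), metrics[metric], step)

def save_model(model, optimizer, save_variable_list, args):
    '''
    Save the model and optimizer states together with the training variables
    (step, learning rate, ...), plus the raw embeddings as .npy files
    '''
    torch.save({
        **save_variable_list,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()},
        os.path.join(args['save_path'], 'checkpoint')
    )
    np.save(os.path.join(args['save_path'], 'entity_embedding'),
            model.entity_embedding.detach().cpu().numpy())
    np.save(os.path.join(args['save_path'], 'relation_embedding'),
            model.relation_embedding.detach().cpu().numpy())

def run(model, double_entity_embedding, hidden_dim=1000, gamma=20):
  dataset_name = 'ogbl-biokg'
  dataset = LinkPropPredDataset(name=dataset_name, root='drive/MyDrive/data')

  args = {'dataset': dataset_name,
          'model': model,
          'hidden_dim': hidden_dim,
          'gamma': gamma,
          'double_entity_embedding': double_entity_embedding,
          'batch_size': 512,
          'negative_sample_size': 128,
          'adversarial_temperature': 1.0,
          'learning_rate': 0.0001,
          'max_steps': 300000,
          'cpu_num': 2,
          'test_batch_size': 32,
          'neg_size_eval_train': 500,
          'ntriples_eval_train': 200000,  # size of the training subset evaluated at the end
          'save_checkpoint_steps': 10000,
          'valid_steps': 10000,
          'log_steps': 100,
          'double_relation_embedding': False,
          'negative_adversarial_sampling': True,
          'do_train': True,
          'do_test': True,
          'do_valid': True,
          'evaluate_train': True}
  save_path = 'log/%s/%s/%s-%s/%s' % (args['dataset'],
                                      args['model'],
                                      args['hidden_dim'],
                                      args['gamma'],
                                      time.time())
  os.makedirs(save_path, exist_ok=True)
  args['save_path'] = save_path  # used by save_model
  writer = SummaryWriter(save_path)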

  split_edge = dataset.get_edge_split()
  train_triples, valid_triples, test_triples = split_edge["train"], split_edge["valid"], split_edge["test"]
  nrelation = int(max(train_triples['relation'])) + 1
  entity_dict = dict()
  cur_idx = 0
  for key in dataset[0]['num_nodes_dict']:
      entity_dict[key] = (cur_idx, cur_idx + dataset[0]['num_nodes_dict'][key])
      cur_idx += dataset[0]['num_nodes_dict'][key]
  nentity = sum(dataset[0]['num_nodes_dict'].values())

  args['nentity'] = nentity
  args['nrelation'] = nrelation

  evaluator = Evaluator(name=dataset_name)
  train_count, train_true_head, train_true_tail = defaultdict(lambda: 4), defaultdict(list), defaultdict(list)

  for i in tqdm(range(len(train_triples['head']))):
      head, relation, tail = train_triples['head'][i], train_triples['relation'][i], train_triples['tail'][i]
      head_type, tail_type = train_triples['head_type'][i], train_triples['tail_type'][i]
      train_count[(head, relation, head_type)] += 1
      train_count[(tail, -relation - 1, tail_type)] += 1
      train_true_head[(relation, tail)].append(head)
      train_true_tail[(head, relation)].append(tail)

  kge_model = KGEModel(
      model_name=args['model'],
      nentity=nentity,
      nrelation=nrelation,
      hidden_dim=args['hidden_dim'],
      gamma=args['gamma'],
      double_entity_embedding=args['double_entity_embedding'],
      double_relation_embedding=args['double_relation_embedding'],
      evaluator=evaluator
  )
  
  kge_model = kge_model.cuda()

  if args['do_train']:
      train_dataloader_head = DataLoader(
          TrainDataset(train_triples, nentity, nrelation,
                        args['negative_sample_size'], 'head-batch',
                        train_count, train_true_head, train_true_tail,
                        entity_dict),
          batch_size=args['batch_size'],
          shuffle=True,
          num_workers=max(1, args['cpu_num'] // 2),
          collate_fn=TrainDataset.collate_fn
      )

      train_dataloader_tail = DataLoader(
          TrainDataset(train_triples, nentity, nrelation,
                        args['negative_sample_size'], 'tail-batch',
                        train_count, train_true_head, train_true_tail,
                        entity_dict),
          batch_size=args['batch_size'],
          shuffle=True,
          num_workers=max(1, args['cpu_num'] // 2),
          collate_fn=TrainDataset.collate_fn
      )

      train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

      current_learning_rate = args['learning_rate']
      optimizer = torch.optim.Adam(
          filter(lambda p: p.requires_grad, kge_model.parameters()),
          lr=current_learning_rate
      )
      warm_up_steps = args['max_steps'] // 2

      print('Randomly Initializing %s Model...' % args['model'])
      init_step = 0

  step = init_step

  print('Start Training...')
  print('init_step = %d' % init_step)
  print('batch_size = %d' % args['batch_size'])
  print('negative_adversarial_sampling = %d' % args['negative_adversarial_sampling'])
  print('hidden_dim = %d' % args['hidden_dim'])
  print('gamma = %f' % args['gamma'])
  print('negative_adversarial_sampling = %s' % str(args['negative_adversarial_sampling']))
  print('adversarial_temperature = %f' % args['adversarial_temperature'])

  # Set valid dataloader as it would be evaluated during training

  if args['do_train']:
      print('learning_rate = %f' % current_learning_rate)

      training_logs = []

      # Training Loop
      for step in range(init_step, args['max_steps']):

          log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
          training_logs.append(log)

          if step >= warm_up_steps:
              current_learning_rate = current_learning_rate / 10
              print('Change learning_rate to %f at step %d' % (current_learning_rate, step))
              optimizer = torch.optim.Adam(
                  filter(lambda p: p.requires_grad, kge_model.parameters()),
                  lr=current_learning_rate
              )
              warm_up_steps = warm_up_steps * 3

          if step % args['save_checkpoint_steps'] == 0 and step > 0:  # ~ 41 seconds/saving
              save_variable_list = {
                  'step': step,
                  'current_learning_rate': current_learning_rate,
                  'warm_up_steps': warm_up_steps,
                  'entity_dict': entity_dict
              }
              save_model(kge_model, optimizer, save_variable_list, args)

          if step % args['log_steps'] == 0:
              metrics = {}
              for metric in training_logs[0].keys():
                  metrics[metric] = sum([log[metric] for log in training_logs]) / len(training_logs)
              log_metrics('Train', step, metrics, writer)
              training_logs = []

          if args['do_valid'] and step % args['valid_steps'] == 0 and step > 0:
              print('Evaluating on Valid Dataset...')
              metrics = kge_model.test_step(kge_model, valid_triples, args, entity_dict)
              log_metrics('Valid', step, metrics, writer)

      save_variable_list = {
          'step': step,
          'current_learning_rate': current_learning_rate,
          'warm_up_steps': warm_up_steps
      }
      save_model(kge_model, optimizer, save_variable_list, args)

  if args['do_valid']:
      print('Evaluating on Valid Dataset...')
      metrics = kge_model.test_step(kge_model, valid_triples, args, entity_dict)
      log_metrics('Valid', step, metrics, writer)

  if args['do_test']:
      print('Evaluating on Test Dataset...')
      metrics = kge_model.test_step(kge_model, test_triples, args, entity_dict)
      log_metrics('Test', step, metrics, writer)

  if args['evaluate_train']:
      print('Evaluating on Training Dataset...')
      small_train_triples = {}
      indices = np.random.choice(len(train_triples['head']), args['ntriples_eval_train'], replace=False)
      for i in train_triples:
          if 'type' in i:
              small_train_triples[i] = [train_triples[i][x] for x in indices]
          else:
              small_train_triples[i] = train_triples[i][indices]
      metrics = kge_model.test_step(kge_model, small_train_triples, args, entity_dict, random_sampling=True)
      log_metrics('Train', step, metrics, writer)

We can finally start our training:

In [None]:
run(model = 'RotatE', double_entity_embedding = True)

100%|██████████| 4762678/4762678 [00:15<00:00, 302814.74it/s]


Ramdomly Initializing RotatE Model...
Start Training...
init_step = 0
batch_size = 512
negative_adversarial_sampling = 1
hidden_dim = 1000
gamma = 20.000000
negative_adversarial_sampling = True
adversarial_temperature = 1.000000
learning_rate = 0
Train positive_sample_loss at step 0: 2.989598
Train negative_sample_loss at step 0: 0.058889
Train loss at step 0: 1.524244


KeyboardInterrupt: ignored

In [None]:
run(model = 'pRotatE', double_entity_embedding = False)

100%|██████████| 4762678/4762678 [00:18<00:00, 259242.00it/s]


Ramdomly Initializing pRotatE Model...
Start Training...
init_step = 0
batch_size = 512
negative_adversarial_sampling = 1
hidden_dim = 1000
gamma = 20.000000
negative_adversarial_sampling = True
adversarial_temperature = 1.000000
learning_rate = 0
Train positive_sample_loss at step 0: 0.000002
Train negative_sample_loss at step 0: 13.008783
Train loss at step 0: 6.504393
Train positive_sample_loss at step 100: 0.000218
Train negative_sample_loss at step 100: 9.806888
Train loss at step 100: 4.903553
Train positive_sample_loss at step 200: 0.090560
Train negative_sample_loss at step 200: 3.627894
Train loss at step 200: 1.859227
Train positive_sample_loss at step 300: 0.615183
Train negative_sample_loss at step 300: 0.763102
Train loss at step 300: 0.689143
Train positive_sample_loss at step 400: 0.647756
Train negative_sample_loss at step 400: 0.665541
Train loss at step 400: 0.656649
Train positive_sample_loss at step 500: 0.641103
Train negative_sample_loss at step 500: 0.655639
Trai

NameError: ignored

We discuss the final results of training in our [blog post](https://medium.com/@seshwan2/de5acf0553ac). The final output is a set of entity and relation embeddings with `hidden_dim = 1000`: with RotatE, each entity is a 1000-dimensional complex vector (stored as 2000 real values), while with pRotatE each entity is a vector of 1000 phases. Combined with the model's scoring function, these embeddings can be used to make link predictions and perform graph completion tasks.
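As an illustration of how the learned embeddings could be used, the sketch below ranks candidate tails for a query (head, relation, ?) with the RotatE score, considering only candidates of the expected entity type, as the training code does for negatives. The toy complex embeddings, the candidate IDs and the helper `predict_tails` are hypothetical. With the trained model, the vectors would come from `kge_model.entity_embedding` (first half real part, second half imaginary part) and the phases from `kge_model.relation_embedding` scaled by π / `embedding_range`.

```python
import cmath


def rotate_distance(h, phases, t):
    # Sum over dimensions of |h * e^{i*theta} - t|, as in KGEModel.RotatE
    return sum(abs(hk * cmath.exp(1j * th) - tk) for hk, th, tk in zip(h, phases, t))


def predict_tails(head_vec, phases, candidates, gamma=20.0, top_k=3):
    """Return the top_k (entity_id, score) pairs, best first."""
    scored = [(eid, gamma - rotate_distance(head_vec, phases, vec))
              for eid, vec in candidates.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Hypothetical 2-dimensional complex embeddings
head_vec = [1 + 0j, 0 + 1j]
phases = [cmath.pi / 2, 0.0]
candidates = {
    10: [0 + 1j, 0 + 1j],   # exactly the rotated head
    11: [1 + 0j, 0 + 1j],
    12: [-1 + 0j, 1 + 0j],
}
predictions = predict_tails(head_vec, phases, candidates)
print(predictions)
```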