## Install DGL-KE
Before training the model, we need to install the DGL and DGL-KE packages along with their dependencies.

In [1]:
# !pip3 install torch
# !pip3 install dgl==0.4.3post2 
# !pip3 install dglke

Collecting dgl==0.4.3post2
  Downloading dgl-0.4.3.post2-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
Installing collected packages: dgl
Successfully installed dgl-0.4.3.post2
Collecting dglke
  Downloading dglke-0.1.2-py3-none-any.whl (78 kB)
Installing collected packages: dglke
Successfully installed dglke-0.1.2


## Prepare train/valid/test set
Before training, we need to split the original knowledge graph (here, Hetionet) into train/valid/test sets in a 90:5:5 ratio.

In [10]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(1, './utils')
from utils import download_and_extract
download_and_extract()
drkg_file = './hetionet/hetionet.tsv'

df = pd.read_csv(drkg_file, sep=",")
triples = df.values.tolist()

len(triples)

2250196

We get 2,250,196 triples. Now we will split them into three files.

In [11]:
num_triples = len(triples)
num_triples

2250196

In [5]:
# Please make sure the output directory exists.
seed = np.arange(num_triples)
np.random.shuffle(seed)

train_cnt = int(num_triples * 0.9)
valid_cnt = int(num_triples * 0.05)
train_set = seed[:train_cnt]
train_set = train_set.tolist()
valid_set = seed[train_cnt:train_cnt+valid_cnt].tolist()
test_set = seed[train_cnt+valid_cnt:].tolist()

# Write each split as tab-separated (head, relation, tail) lines.
splits = {"train": train_set, "valid": valid_set, "test": test_set}
for name, indices in splits.items():
    with open("./hetionet/hetionet_{}.tsv".format(name), 'w') as f:
        for idx in indices:
            f.write("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))
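The split above relies on NumPy's global random state, so the three files change on every run. As a minimal sketch, the same 90/5/5 index split can be made reproducible with a seeded generator. The function name split_indices and the seed value are illustrative and are not part of the original notebook.

```python
import numpy as np

def split_indices(num_triples, train_frac=0.9, valid_frac=0.05, seed=0):
    """Return shuffled (train, valid, test) index lists for num_triples items."""
    rng = np.random.default_rng(seed)        # seeded, so the split is repeatable
    order = rng.permutation(num_triples)     # shuffled indices 0..num_triples-1
    train_cnt = int(num_triples * train_frac)
    valid_cnt = int(num_triples * valid_frac)
    train = order[:train_cnt].tolist()
    valid = order[train_cnt:train_cnt + valid_cnt].tolist()
    test = order[train_cnt + valid_cnt:].tolist()  # remainder goes to test
    return train, valid, test

train_set, valid_set, test_set = split_indices(2250196)
print(len(train_set), len(valid_set), len(test_set))  # 2025176 112509 112511
```

The test set takes whatever remains after truncation, which is why it ends up two triples larger than the validation set.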



In [6]:
# with open("./yago/yago_train.tsv", 'w+') as f:
#     for idx in train_set:
#         f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))
        
# with open("./yago/yago_valid.tsv", 'w+') as f:
#     for idx in valid_set:
#         f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))

# with open("./yago/yago_test.tsv", 'w+') as f:
#     for idx in test_set:
#         f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))

## Training TransE_l2 model
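We can train the TransE_l2 model simply by using the DGL-KE command line. For more information about using DGL-KE, please refer to https://github.com/awslabs/dgl-ke.

Here we train on an AWS p3.16xlarge instance, which has 8 GPUs. The commands below run 8 training processes (--num_proc 8) across GPUs 0 to 3 (--gpu 0 1 2 3). The second command trains a RotatE model with the same settings except for a smaller batch size (512) and negative sample size (128).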
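For intuition about what TransE_l2 optimizes, here is a minimal NumPy sketch of its scoring function, score = gamma - ||h + r - t||_2, where a higher score means a more plausible triple. The vectors are random toy values and transe_l2_score is an illustrative name, not the DGL-KE implementation. The defaults mirror the --gamma 12.0 and --hidden_dim 200 flags in the training command.

```python
import numpy as np

def transe_l2_score(head, rel, tail, gamma=12.0):
    """TransE score with L2 distance: gamma - ||head + rel - tail||_2."""
    return gamma - np.linalg.norm(head + rel - tail, ord=2, axis=-1)

rng = np.random.default_rng(0)
h = rng.normal(size=200)         # toy head embedding
r = rng.normal(size=200)         # toy relation embedding
t_true = h + r                   # tail that exactly satisfies h + r = t
t_random = rng.normal(size=200)  # unrelated tail

print(transe_l2_score(h, r, t_true))    # exact translation scores gamma
print(transe_l2_score(h, r, t_random))  # mismatched tail scores lower
```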

In [14]:
!DGLBACKEND=pytorch dglke_train --dataset hetionet --data_path ./hetionet --data_files hetionet_train.tsv hetionet_valid.tsv hetionet_test.tsv --format 'raw_udd_hrt' --model_name TransE_l2 --batch_size 2048 \
--neg_sample_size 256 --hidden_dim 200 --gamma 12.0 --lr 0.1 --max_step 100000 --log_interval 1000 --batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --gpu 0 1 2 3 --num_proc 8 --neg_sample_size_eval 10000 --async_update

!DGLBACKEND=pytorch dglke_train --dataset hetionet --data_path ./hetionet --data_files hetionet_train.tsv hetionet_valid.tsv hetionet_test.tsv --format 'raw_udd_hrt' --model_name RotatE --batch_size 512 \
--neg_sample_size 128 --hidden_dim 200 --gamma 12.0 --lr 0.1 --max_step 100000 --log_interval 1000 --batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --gpu 0 1 2 3 --num_proc 8 --neg_sample_size_eval 10000 --async_update


# !DGLBACKEND=pytorch dglke_train --dataset yago --data_path ./yago --data_files yago_train.tsv yago_valid.tsv yago_test.tsv --format 'raw_udd_hrt' --model_name TransE_l2 --batch_size 2048 \
# --neg_sample_size 256 --hidden_dim 100 --gamma 12.0 --lr 0.1 --max_step 100000 --log_interval 1000 --batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --gpu 0 1 2 3 --num_proc 8 --neg_sample_size_eval 10000 --async_update



Using backend: pytorch
Reading train triples....
Finished. Read 2025176 train triples.
Reading valid triples....
Finished. Read 112509 valid triples.
Reading test triples....
Finished. Read 112511 test triples.
|Train|: 2025176
random partition 2025176 edges into 8 parts
part 0 has 253147 edges
part 1 has 253147 edges
part 2 has 253147 edges
part 3 has 253147 edges
part 4 has 253147 edges
part 5 has 253147 edges
part 6 has 253147 edges
part 7 has 253147 edges
|valid|: 112509
|test|: 112511
Total initialize time 5.560 seconds
^C
Traceback (most recent call last):
  File "/workspace/anaconda3/envs/bio/bin/dglke_train", line 8, in <module>
Process Process-1:1:
Process Process-3:1:
Process Process-4:1:
Process Process-2:1:
    sys.exit(main())
  File "/workspace/anaconda3/envs/bio/lib/python3.7/site-packages/dglke/train.py", line 281, in main
    proc.join()
  File "/workspace/anaconda3/envs/bio/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout

## Get Entity and Relation Embeddings
The resulting model, i.e., the entity and relation embeddings, can be found under ./ckpts. (Please refer to the first line of the training log for the specific location.)

The overall process will generate 4 important files:

  - Entity embedding: ./ckpts/<model\_name>_<dataset\_name>_<run\_id>/xxx\_entity.npy
  - Relation embedding: ./ckpts/<model\_name>_<dataset\_name>_<run\_id>/xxx\_relation.npy
  - The entity id mapping, formatted as <entity\_name> <entity\_id> pairs: <data\_path>/entities.tsv
  - The relation id mapping, formatted as <relation\_name> <relation\_id> pairs: <data\_path>/relations.tsv
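The id mapping files tie each row of an embedding matrix back to a name. As a rough sketch, the lookup could work as follows. It assumes tab-separated lines in the name-then-id order given in the list above, which you should verify against your own entities.tsv (swap the two fields if the id comes first). The entity names and file names here are made-up toy values written to a temporary directory.

```python
import csv
import os
import tempfile
import numpy as np

def load_id_map(path):
    """Read a tab-separated <name> <id> file into a {name: int_id} dict."""
    with open(path, newline='', encoding='utf-8') as f:
        return {name: int(idx) for name, idx in csv.reader(f, delimiter='\t')}

# Toy stand-ins for <data_path>/entities.tsv and the entity .npy file.
tmp = tempfile.mkdtemp()
map_path = os.path.join(tmp, 'entities.tsv')
emb_path = os.path.join(tmp, 'toy_entity.npy')
with open(map_path, 'w', encoding='utf-8') as f:
    f.write("Gene::1\t0\nCompound::DB00001\t1\n")
np.save(emb_path, np.arange(8, dtype=np.float32).reshape(2, 4))

entity_ids = load_id_map(map_path)
entity_emb = np.load(emb_path)

# Row i of the embedding matrix belongs to the entity whose id is i.
vec = entity_emb[entity_ids['Compound::DB00001']]
print(vec.shape)
```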

In [None]:
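# Paths follow --dataset hetionet and --data_path ./hetionet; run id 0 assumes the first run.
!ls ./ckpts/TransE_l2_hetionet_0/
!ls ./hetionet/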

## A Glance at the Entity and Relation Embeddings

In [None]:
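# File names follow --dataset hetionet; adjust the run id if this is not the first run.
node_emb = np.load('./ckpts/TransE_l2_hetionet_0/hetionet_TransE_l2_entity.npy')
relation_emb = np.load('./ckpts/TransE_l2_hetionet_0/hetionet_TransE_l2_relation.npy')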

print(node_emb.shape)
print(relation_emb.shape)