# Training CTKG Using TransE_l2
Adapted from the notebook of DRKG: https://github.com/gnn4dr/DRKG
This notebook shows how to train CTKG embeddings using TransE_l2.

## Install DGL-KE
Before training the model, we need to install the dgl and dglke packages along with their dependencies. The cell below uses GPUtil to list the GPUs that are currently available.

In [9]:
import GPUtil
GPUtil.getAvailable()

[]

## Prepare train/valid/test set
Before training, we need to split the original CTKG into train/valid/test sets in a 9:0.5:0.5 ratio (90% / 5% / 5%).

In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(1, '../utils')
ctkg_file = '../rawdata/ctkg.tsv'

df = pd.read_csv(ctkg_file, sep="\t")
triples = df.values.tolist()

In [2]:
num_triples = len(triples)
num_triples

7322993

In [3]:
# Please make sure the output directory exists.
seed = np.arange(num_triples)
np.random.shuffle(seed)

train_cnt = int(num_triples * 0.9)
valid_cnt = int(num_triples * 0.05)
train_set = seed[:train_cnt]
train_set = train_set.tolist()
valid_set = seed[train_cnt:train_cnt+valid_cnt].tolist()
test_set = seed[train_cnt+valid_cnt:].tolist()

# Write each split as tab-separated (head, relation, tail) lines.
with open("../train/ctkg_train.tsv", 'w+') as f:
    for idx in train_set:
        f.write("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))

with open("../train/ctkg_valid.tsv", 'w+') as f:
    for idx in valid_set:
        f.write("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))

with open("../train/ctkg_test.tsv", 'w+') as f:
    for idx in test_set:
        f.write("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))
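
The split sizes follow directly from the 0.9 and 0.05 fractions, with the test set taking whatever remains. The sketch below reproduces that arithmetic for the 7322993 triples counted above; the function name `split_indices` and the fixed seed are illustrative assumptions, not part of the original notebook, which shuffles without a seed.

```python
import numpy as np

def split_indices(num_triples, train_frac=0.9, valid_frac=0.05, seed=0):
    """Return shuffled, disjoint train/valid/test index arrays."""
    # A fixed seed is an assumption added here so the split is reproducible.
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_triples)
    train_cnt = int(num_triples * train_frac)
    valid_cnt = int(num_triples * valid_frac)
    train = order[:train_cnt]
    valid = order[train_cnt:train_cnt + valid_cnt]
    test = order[train_cnt + valid_cnt:]  # the remainder goes to the test set
    return train, valid, test

train_idx, valid_idx, test_idx = split_indices(7322993)
print(len(train_idx), len(valid_idx), len(test_idx))  # 6590693 366149 366151
```

These counts are the ones DGL-KE reports when it reads the three files during training.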

## Training TransE_l2 model
We can train the TransE_l2 model simply by using the DGL-KE command line. For more information about using DGL-KE, please refer to https://github.com/awslabs/dgl-ke.

In [4]:
import dgl
import dglke

In [5]:
!DGLBACKEND=pytorch  ~/.local/bin/dglke_train --dataset CTKG --data_path ../train --data_files \
    ctkg_train.tsv ctkg_valid.tsv ctkg_test.tsv --format 'raw_udd_hrt' --model_name TransE_l2 --batch_size 2048 \
--neg_sample_size 256 --hidden_dim 200 --gamma 12.0 --lr 0.1 --max_step 200 --log_interval 1000 \
--batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --num_proc 16 \
--neg_sample_size_eval 1000

Reading train triples....
Finished. Read 6590693 train triples.
Reading valid triples....
Finished. Read 366149 valid triples.
Reading test triples....
Finished. Read 366151 test triples.
|Train|: 6590693
random partition 6590693 edges into 16 parts
part 0 has 411919 edges
part 1 has 411919 edges
part 2 has 411919 edges
part 3 has 411919 edges
part 4 has 411919 edges
part 5 has 411919 edges
part 6 has 411919 edges
part 7 has 411919 edges
part 8 has 411919 edges
part 9 has 411919 edges
part 10 has 411919 edges
part 11 has 411919 edges
part 12 has 411919 edges
part 13 has 411919 edges
part 14 has 411919 edges
part 15 has 411908 edges
|valid|: 366149
|test|: 366151
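
For orientation, TransE treats a relation as a translation in embedding space, and the TransE_l2 variant scores a triple (h, r, t) as gamma minus the L2 distance between h + r and t, where gamma corresponds to the `--gamma` value passed on the command line. The following is a minimal NumPy sketch with made-up vectors; the function name `transe_l2_score` and the toy 3-dimensional inputs are assumptions for illustration, not DGL-KE's internal code.

```python
import numpy as np

def transe_l2_score(head, rel, tail, gamma=12.0):
    """Score a triple as gamma - ||head + rel - tail||_2 (higher is more plausible)."""
    return gamma - np.linalg.norm(head + rel - tail, ord=2, axis=-1)

# Toy vectors, chosen only to make the arithmetic easy to follow.
head = np.array([1.0, 0.0, 2.0])
rel = np.array([0.5, 1.0, -1.0])
good_tail = head + rel                            # perfect translation, distance 0
bad_tail = good_tail + np.array([3.0, 4.0, 0.0])  # offset with L2 length 5

print(transe_l2_score(head, rel, good_tail))  # 12.0
print(transe_l2_score(head, rel, bad_tail))   # 7.0
```

A tail that sits exactly at h + r receives the maximum score gamma, and the score drops by one for every unit of L2 distance from that point.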


## Get Entity and Relation Embeddings
The resulting model, i.e., the entity and relation embeddings, can be found under ./ckpts. (Please refer to the first line of the training log for the specific location.)

The overall process will generate 4 important files:

  - Entity embedding: ./ckpts/<model\_name>_<dataset\_name>_<run\_id>/xxx\_entity.npy
  - Relation embedding: ./ckpts/<model\_name>_<dataset\_name>_<run\_id>/xxx\_relation.npy
  - The entity id mapping, formatted as <entity\_name> <entity\_id> pairs: <data\_path>/entities.tsv
  - The relation id mapping, formatted as <relation\_name> <relation\_id> pairs: <data\_path>/relations.tsv
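
The id mapping files can be read into a dictionary so that names resolve to embedding rows. The sketch below is one way to do this under stated assumptions: the helper name `load_id_map` is hypothetical, the file is assumed to be tab-separated with two columns, and the demo file uses placeholder entity names. The column order defaults to name first, as described above; if your file lists the id first, pass `id_column=0`.

```python
import csv
import os
import tempfile

def load_id_map(path, id_column=1):
    """Read a two-column TSV id mapping and return {name: integer id}."""
    name_column = 1 - id_column
    mapping = {}
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            if len(row) < 2:
                continue  # skip blank or malformed lines
            mapping[row[name_column]] = int(row[id_column])
    return mapping

# Small stand-in file with placeholder names, so the sketch runs on its own.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "entities.tsv")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("entity_a\t0\nentity_b\t1\n")
    entity_map = load_id_map(demo_path)

print(entity_map)  # {'entity_a': 0, 'entity_b': 1}
```

Because the ids index rows of the saved arrays, the vector for a given name would be the row of the entity embedding array at `entity_map[name]`.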

In [6]:
!ls ./ckpts

TransE_l2_CTKG_0  TransE_l2_CTKG_2  TransE_l2_CTKG_4
TransE_l2_CTKG_1  TransE_l2_CTKG_3


In [9]:
!ls -la ./ckpts/TransE_l2_CTKG_*

./ckpts/TransE_l2_CTKG_0:
total 8
drwxrwxr-x 2 nobody nobody 4096 Dec 19 21:54 .
drwxrwxr-x 7 nobody nobody 4096 Dec 20 18:55 ..

./ckpts/TransE_l2_CTKG_1:
total 8
drwxrwxr-x 2 nobody nobody 4096 Dec 19 21:55 .
drwxrwxr-x 7 nobody nobody 4096 Dec 20 18:55 ..

./ckpts/TransE_l2_CTKG_2:
total 8
drwxrwxr-x 2 nobody nobody 4096 Dec 20 17:59 .
drwxrwxr-x 7 nobody nobody 4096 Dec 20 18:55 ..

./ckpts/TransE_l2_CTKG_3:
total 8
drwxrwxr-x 2 nobody nobody 4096 Dec 20 18:29 .
drwxrwxr-x 7 nobody nobody 4096 Dec 20 18:55 ..

./ckpts/TransE_l2_CTKG_4:
total 8
drwxrwxr-x 2 nobody nobody 4096 Dec 20 18:55 .
drwxrwxr-x 7 nobody nobody 4096 Dec 20 18:55 ..


In [8]:
!ls ../train/

ctkg_test.tsv  ctkg_train.tsv  ctkg_valid.tsv  entities.tsv  relations.tsv


## A Glance at the Entity and Relation Embeddings

In [8]:
# Load the embeddings saved by the CTKG run. Change the run index (0 here)
# to a checkpoint directory that actually contains the .npy files.
node_emb = np.load('./ckpts/TransE_l2_CTKG_0/CTKG_TransE_l2_entity.npy')
relation_emb = np.load('./ckpts/TransE_l2_CTKG_0/CTKG_TransE_l2_relation.npy')

print(node_emb.shape)
print(relation_emb.shape)
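
Since each training run gets its own numbered directory and some of them may be empty, it can help to locate the embedding files programmatically instead of hard-coding a path. The sketch below is a hypothetical helper (`find_embeddings` is not part of DGL-KE) that assumes only the `*_entity.npy` and `*_relation.npy` naming described earlier; the demo builds a throwaway directory tree with empty placeholder files.

```python
import glob
import os
import tempfile

def find_embeddings(ckpt_root="./ckpts"):
    """Return (entity_path, relation_path) from the newest run directory
    that contains both files, or None if no such directory exists."""
    run_dirs = sorted(
        (d for d in glob.glob(os.path.join(ckpt_root, "*")) if os.path.isdir(d)),
        key=os.path.getmtime,
        reverse=True,  # most recently modified run first
    )
    for run_dir in run_dirs:
        entity = glob.glob(os.path.join(run_dir, "*_entity.npy"))
        relation = glob.glob(os.path.join(run_dir, "*_relation.npy"))
        if entity and relation:
            return entity[0], relation[0]
    return None

# Demo: one empty run directory and one with placeholder files.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "run_0"))
    os.makedirs(os.path.join(root, "run_1"))
    for name in ("demo_entity.npy", "demo_relation.npy"):
        open(os.path.join(root, "run_1", name), "w").close()
    found = find_embeddings(root)

print([os.path.basename(p) for p in found])  # ['demo_entity.npy', 'demo_relation.npy']
```

The returned paths can be passed to `np.load`. A `None` result means no run directory holds saved embeddings yet, which is the situation shown by the empty checkpoint directories listed earlier.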