In [1]:
import argparse
import os.path as osp

import torch
import torch.optim as optim

from pyg_dataloader import KGData
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE

# Registry of supported knowledge-graph embedding models, keyed by the
# lowercase name that `model_name` selects from.
model_map = dict(
    transe=TransE,
    complex=ComplEx,
    distmult=DistMult,
    rotate=RotatE,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
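# Debugging aids, disabled by default. Uncomment to import ipdb and to have
# IPython start the debugger automatically after any uncaught exception:
# import ipdb
# %pdb on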

In [3]:
# Experiment configuration: everything a reader may want to tune lives here.
path = '../dataset/FB15K237/'  # dataset root handed to KGData
# Use the first GPU when one is present; fall back to CPU instead of crashing
# on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model_name = 'rotate'  # one of: 'transe', 'complex', 'distmult', 'rotate'

# Seed torch's RNGs (CPU and all CUDA devices) so torch-based randomness
# (e.g. parameter initialisation, presumably also loader shuffling) repeats
# across Restart & Run All. Some GPU kernels may still be non-deterministic.
seed = 0
torch.manual_seed(seed)

In [4]:
def _load_split(split):
    """Load the first (presumably only) graph of `split` from `path` onto `device`."""
    return KGData(path, split=split)[0].to(device)


train_data = _load_split('train')
val_data = _load_split('val')
test_data = _load_split('test')

In [5]:
# Extra constructor kwargs for specific models; only RotatE gets one here.
model_arg_map = {'rotate': {'margin': 9.0}}

model_cls = model_map[model_name]
extra_kwargs = model_arg_map.get(model_name, {})
model = model_cls(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,  # embedding size passed to the PyG model
    **extra_kwargs,
)
model = model.to(device)

In [6]:
# Shuffled mini-batch iterator over the training triples; the training loop
# unpacks each batch as (head_index, rel_type, tail_index).
train_heads = train_data.edge_index[0]
train_tails = train_data.edge_index[1]
loader = model.loader(
    head_index=train_heads,
    rel_type=train_data.edge_type,
    tail_index=train_tails,
    batch_size=1000,
    shuffle=True,
)

# Per-model optimizer *factories*: only the optimizer for `model_name` is
# actually constructed. Building all four eagerly would waste memory, because
# torch.optim.Adagrad allocates an accumulator the size of every parameter in
# its constructor, i.e. a full copy of the embeddings on `device`.
optimizer_factories = {
    'transe': lambda params: optim.Adam(params, lr=0.01),
    'complex': lambda params: optim.Adagrad(params, lr=0.001, weight_decay=1e-6),
    'distmult': lambda params: optim.Adam(params, lr=0.0001, weight_decay=1e-6),
    'rotate': lambda params: optim.Adam(params, lr=1e-3),
}
optimizer = optimizer_factories[model_name](model.parameters())

In [7]:
def train():
    """Run one pass over the global `loader`, stepping `optimizer` per batch.

    Returns:
        The mean loss per example: each batch loss is weighted by the number
        of head indices in that batch.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for heads, rels, tails in loader:
        optimizer.zero_grad()
        batch_loss = model.loss(heads, rels, tails)
        batch_loss.backward()
        optimizer.step()
        batch_size = heads.numel()
        loss_sum += float(batch_loss) * batch_size
        seen += batch_size
    return loss_sum / seen


@torch.no_grad()
def test(data):
    """Evaluate the global `model` on every triple in `data` without gradients.

    Puts the model in eval mode and returns whatever `model.test` returns for
    k_list=[1, 3, 5, 10]; callers in this notebook unpack it as
    (mean_rank, mrr, hits@1, hits@3, hits@5, hits@10).
    """
    model.eval()
    edge_index = data.edge_index
    return model.test(
        head_index=edge_index[0],
        rel_type=data.edge_type,
        tail_index=edge_index[1],
        batch_size=20000,
        k_list=[1, 3, 5, 10],
    )


num_epochs = 100
eval_every = 25  # run validation every this many epochs


def _format_metrics(split, rank, mrr, hits_at_1, hits_at_3, hits_at_5, hits_at_10):
    """Return one log line of ranking metrics, every field prefixed with `split`.

    Uses adjacent string literals instead of a backslash continuation inside
    an f-string, which would embed the next line's indentation in the output.
    Deriving every label from the single `split` argument keeps Val/Test
    labels from drifting apart.
    """
    return (
        f'{split} Mean Rank: {rank:.2f}, '
        f'{split} Mean Reciprocal Rank: {mrr:.4f}, '
        f'{split} Hits@1: {hits_at_1:.4f}, '
        f'{split} Hits@3: {hits_at_3:.4f}, '
        f'{split} Hits@5: {hits_at_5:.4f}, '
        f'{split} Hits@10: {hits_at_10:.4f}'
    )


for epoch in range(1, num_epochs + 1):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % eval_every == 0:
        val_metrics = test(val_data)
        print(f'Epoch: {epoch:03d}, ' + _format_metrics('Val', *val_metrics))

# Final evaluation uses the weights after the last epoch (no checkpoint is
# selected on validation performance).
rank, mrr, hits_at_1, hits_at_3, hits_at_5, hits_at_10 = test(test_data)
print(_format_metrics('Test', rank, mrr, hits_at_1, hits_at_3, hits_at_5, hits_at_10))

Epoch: 001, Loss: 4.1396
Epoch: 002, Loss: 3.5518
Epoch: 003, Loss: 2.9820
Epoch: 004, Loss: 2.4954
Epoch: 005, Loss: 2.1152
Epoch: 006, Loss: 1.8179
Epoch: 007, Loss: 1.5833
Epoch: 008, Loss: 1.3974
Epoch: 009, Loss: 1.2480
Epoch: 010, Loss: 1.1284
Epoch: 011, Loss: 1.0316
Epoch: 012, Loss: 0.9511
Epoch: 013, Loss: 0.8834
Epoch: 014, Loss: 0.8270
Epoch: 015, Loss: 0.7790
Epoch: 016, Loss: 0.7408
Epoch: 017, Loss: 0.7062
Epoch: 018, Loss: 0.6777
Epoch: 019, Loss: 0.6537
Epoch: 020, Loss: 0.6329
Epoch: 021, Loss: 0.6155
Epoch: 022, Loss: 0.5998
Epoch: 023, Loss: 0.5865
Epoch: 024, Loss: 0.5742
Epoch: 025, Loss: 0.5628


100%|██████████| 17535/17535 [00:10<00:00, 1744.27it/s]


Epoch: 025, Val Mean Rank: 3186.07, Val Mean Reciprocal Rank: 0.0161,               Val Hits@1: 0.0074, Val Hits@3: 0.0122, Val Hits@5: 0.0164, Val Hits@10: 0.0270
Epoch: 026, Loss: 0.5527
Epoch: 027, Loss: 0.5419
Epoch: 028, Loss: 0.5325
Epoch: 029, Loss: 0.5234
Epoch: 030, Loss: 0.5135
Epoch: 031, Loss: 0.5036
Epoch: 032, Loss: 0.4937
Epoch: 033, Loss: 0.4842
Epoch: 034, Loss: 0.4742
Epoch: 035, Loss: 0.4640
Epoch: 036, Loss: 0.4536
Epoch: 037, Loss: 0.4435
Epoch: 038, Loss: 0.4337
Epoch: 039, Loss: 0.4235
Epoch: 040, Loss: 0.4128
Epoch: 041, Loss: 0.4026
Epoch: 042, Loss: 0.3924
Epoch: 043, Loss: 0.3818
Epoch: 044, Loss: 0.3726
Epoch: 045, Loss: 0.3625
Epoch: 046, Loss: 0.3530
Epoch: 047, Loss: 0.3433
Epoch: 048, Loss: 0.3340
Epoch: 049, Loss: 0.3247
Epoch: 050, Loss: 0.3156


100%|██████████| 17535/17535 [00:10<00:00, 1595.86it/s]


Epoch: 050, Val Mean Rank: 1191.29, Val Mean Reciprocal Rank: 0.1211,               Val Hits@1: 0.0668, Val Hits@3: 0.1241, Val Hits@5: 0.1654, Val Hits@10: 0.2323
Epoch: 051, Loss: 0.3065
Epoch: 052, Loss: 0.2979
Epoch: 053, Loss: 0.2893
Epoch: 054, Loss: 0.2811
Epoch: 055, Loss: 0.2736
Epoch: 056, Loss: 0.2658
Epoch: 057, Loss: 0.2588
Epoch: 058, Loss: 0.2513
Epoch: 059, Loss: 0.2450
Epoch: 060, Loss: 0.2379
Epoch: 061, Loss: 0.2316
Epoch: 062, Loss: 0.2251
Epoch: 063, Loss: 0.2195
Epoch: 064, Loss: 0.2143
Epoch: 065, Loss: 0.2091
Epoch: 066, Loss: 0.2041
Epoch: 067, Loss: 0.1996
Epoch: 068, Loss: 0.1942
Epoch: 069, Loss: 0.1906
Epoch: 070, Loss: 0.1858
Epoch: 071, Loss: 0.1825
Epoch: 072, Loss: 0.1780
Epoch: 073, Loss: 0.1741
Epoch: 074, Loss: 0.1705
Epoch: 075, Loss: 0.1676


100%|██████████| 17535/17535 [00:10<00:00, 1731.15it/s]


Epoch: 075, Val Mean Rank: 402.34, Val Mean Reciprocal Rank: 0.2204,               Val Hits@1: 0.1411, Val Hits@3: 0.2370, Val Hits@5: 0.2942, Val Hits@10: 0.3845
Epoch: 076, Loss: 0.1648
Epoch: 077, Loss: 0.1612
Epoch: 078, Loss: 0.1585
Epoch: 079, Loss: 0.1558
Epoch: 080, Loss: 0.1535
Epoch: 081, Loss: 0.1514
Epoch: 082, Loss: 0.1495
Epoch: 083, Loss: 0.1469
Epoch: 084, Loss: 0.1442
Epoch: 085, Loss: 0.1418
Epoch: 086, Loss: 0.1407
Epoch: 087, Loss: 0.1385
Epoch: 088, Loss: 0.1376
Epoch: 089, Loss: 0.1355
Epoch: 090, Loss: 0.1341
Epoch: 091, Loss: 0.1320
Epoch: 092, Loss: 0.1311
Epoch: 093, Loss: 0.1300
Epoch: 094, Loss: 0.1287
Epoch: 095, Loss: 0.1276
Epoch: 096, Loss: 0.1253
Epoch: 097, Loss: 0.1253
Epoch: 098, Loss: 0.1241
Epoch: 099, Loss: 0.1237
Epoch: 100, Loss: 0.1221


100%|██████████| 17535/17535 [00:10<00:00, 1740.65it/s]


Epoch: 100, Val Mean Rank: 248.46, Val Mean Reciprocal Rank: 0.2448,               Val Hits@1: 0.1623, Val Hits@3: 0.2617, Val Hits@5: 0.3212, Val Hits@10: 0.4160


100%|██████████| 20466/20466 [00:11<00:00, 1849.59it/s]

Test Mean Rank: 265.86, Test Mean Reciprocal Rank: 0.2397, Test Hits@1: 0.1573,      Test Hits@3: 0.2558, Val Hits@5: 0.3165, Test Hits@10: 0.4109



