In [14]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet

import wandb
from easydict import EasyDict

wandb.login(key="8df071c79082d7ec99e9da99802221c4edef7d8c")

CFG = EasyDict()
CFG.project = "temporal-knowledge-base-completion"
CFG.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CFG.epochs = 21
CFG.seq_len = 5
CFG.model = "Renet"
CFG.tags = "Baseline"
CFG.batch_size = 2048
CFG.hidden_channels = 100


def wandb_init():
    config = {k: v for k, v in CFG.items() if '__' not in k}
    run = wandb.init(
        project=CFG.project,
        name=f"{CFG.model}-epoch-{CFG.epochs}",
        tags=CFG.tags,
        config=config,
        save_code=True
    )
    return run



In [15]:
__file__ = "./"
# Load the dataset and precompute history objects.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ICEWS18')
train_dataset = ICEWS18(path, pre_transform=RENet.pre_transform(CFG.seq_len))
test_dataset = ICEWS18(path, split='test')
train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, follow_batch=['h_sub', 'h_obj'])
test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, follow_batch=['h_sub', 'h_obj'])

In [19]:
from torch_scatter import scatter_mean

print(train_loader.__iter__())
for data in train_loader:
    print(data.h_obj_t)
    batch_size, seq_len = data.sub.size(0), CFG.seq_len
    h_sub_t = data.h_sub_t + data.h_sub_batch * seq_len
    print(data.h_sub_t, data.h_sub_batch, h_sub_t)
    ent = torch.Tensor(train_dataset.num_nodes, 100)
    h_sub = scatter_mean(ent[data.h_sub], h_sub_t, dim=0, dim_size=batch_size * seq_len).view(batch_size, seq_len, -1)
    print(h_sub.size())
    h_obj_t = data.h_obj_t + data.h_obj_batch * seq_len
    break

<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x0000028136815240>
tensor([0, 0, 0,  ..., 3, 3, 4])
tensor([4, 1, 2,  ..., 3, 3, 4]) tensor([   4,    5,    5,  ..., 2045, 2045, 2046]) tensor([   24,    26,    27,  ..., 10228, 10228, 10234])
torch.Size([2048, 5, 100])


In [None]:
import numpy as np

# Initialize model and optimizer.
model = RENet(train_dataset.num_nodes, train_dataset.num_rels, hidden_channels=CFG.hidden_channels, seq_len=seq_len,
              dropout=0.2, ).to(CFG.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00001)
run = wandb_init()
wandb.watch(model, log='all')
Total_params = 0
Trainable_params = 0
for param in model.parameters():
    mulValue = np.prod(param.size())
    Total_params += mulValue
    if param.requires_grad:
        Trainable_params += mulValue
wandb.log({"Total params": Total_params, "Trainable params": Trainable_params})


def train():
    model.train()
    # Train model via multi-class classification against the corresponding
    # object and subject entities.
    for data in train_loader:
        data = data.to(CFG.device)
        optimizer.zero_grad()
        log_prob_obj, log_prob_sub = model(data)
        loss_obj = F.nll_loss(log_prob_obj, data.obj)
        loss_sub = F.nll_loss(log_prob_sub, data.sub)
        loss = loss_obj + loss_sub
        loss.backward()
        optimizer.step()


def test(loader):
    model.eval()
    # Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.
    result = torch.tensor([0, 0, 0, 0], dtype=torch.float)
    for data in loader:
        data = data.to(CFG.device)
        with torch.no_grad():
            log_prob_obj, log_prob_sub = model(data)
        result += model.test(log_prob_obj, data.obj) * data.obj.size(0)
        result += model.test(log_prob_sub, data:.sub) *data.sub.size(0)
    result = result / (2 * len(loader.dataset))
    return result.tolist()  #%%

In [24]:
for epoch in range(1, 21):
    wandb.log({"epoch": epoch})
    train()
    mrr, hits1, hits3, hits10 = test(test_loader)
    wandb.log({"MRR": mrr, "Hits@1": hits1, "Hits@3": hits3, "Hits@10": hits10})
torch.save(model.state_dict(), "a.pt")
run.finish()

0,1
Hits@1,▁▄▅▆▆▇▇▇▇███████████
Hits@10,▁▄▅▆▇▇▇▇████████████
Hits@3,▁▄▅▆▇▇▇▇████████████
MRR,▁▄▅▆▇▇▇▇████████████
Total params,▁
Trainable params,▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██

0,1
Hits@1,0.1822
Hits@10,0.46237
Hits@3,0.31611
MRR,0.27751
Total params,16435966.0
Trainable params,16435966.0
epoch,20.0
