In [1]:
from google.colab import drive

# Mount Google Drive so the project files under MyDrive are reachable from the runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
! pip install -r drive/MyDrive/Diploma/requirements.txt >& /dev/null

In [None]:
# Query the NVIDIA driver to see whether a GPU is attached to this runtime.
# If the command fails, no GPU is available and `device` below falls back to the CPU.
! nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [None]:
import numpy as np
import torch
from torch import nn

import wandb

from tqdm.notebook import tqdm

from drive.MyDrive.Diploma.environment import KGEnv
from drive.MyDrive.Diploma.utils import (
    create_model, LabelSmoothingCrossEntropy,
    load_config
)
from drive.MyDrive.Diploma.dataset import KGDataset
from drive.MyDrive.Diploma.train import (
    run_episode, pretrain,
    train_emb, evaluate_emb
)
from drive.MyDrive.Diploma.beam_search import get_ranks
from drive.MyDrive.Diploma.embed_model import ComplEx
from drive.MyDrive.Diploma.metrics import *

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Locations of the WN18RR files on the mounted drive.
# Triplet files: train split and test split.
train_triplets_path, test_triplets_path = (
    'drive/MyDrive/Diploma/wn18rr/train.txt',
    'drive/MyDrive/Diploma/wn18rr/test.txt',
)
# Vocabulary files: entity list and relation list.
entities_path, relations_path = (
    'drive/MyDrive/Diploma/wn18rr/entities.txt',
    'drive/MyDrive/Diploma/wn18rr/relations.txt',
)

In [None]:
# Build an environment from the training triplets; it is used here only to
# read the vocabulary sizes needed to construct the embedding model.
env = KGEnv(
    train_triplets_path,
    entities_path,
    relations_path,
)

entities_num, relations_num = env.entities_num, env.relations_num

In [None]:
# Separate environments for the train and test splits. Both use the same
# entity and relation vocabularies.
train_env = KGEnv(train_triplets_path, entities_path, relations_path)
# `relations_path` must be passed before the `train=False` keyword:
# a positional argument after a keyword argument is a SyntaxError.
test_env = KGEnv(test_triplets_path, entities_path, relations_path, train=False)

# ComplEx embedding model sized by the vocabulary, trained with Adam and a
# label-smoothed cross-entropy loss.
emb_model = ComplEx(entities_num, relations_num, hid_dim=256).to(device)
optimizer = torch.optim.Adam(emb_model.parameters(), lr=0.001)
criterion = LabelSmoothingCrossEntropy()

In [None]:
# Build a fresh Adam optimizer over the embedding model's parameters
# (replaces any previously created `optimizer` object).
optimizer = torch.optim.Adam(
    params=emb_model.parameters(),
    lr=1e-3,
)

In [None]:
# Start a Weights & Biases run; the metrics logged during training go to it.
wandb.init(
    name="Embed model",
    project="RL4KGQA",
)

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


In [None]:
# Shuffle the triplets of each split in place (train first, then test).
for kg_env in (train_env, test_env):
    np.random.shuffle(kg_env.triplets)

In [None]:
NUM_EPOCHES = 100
CHECKPOINT_EVERY = 100  # write a checkpoint on the last epoch of every such block

# Each epoch: one training pass over the train triplets, one evaluation pass
# over the test triplets, then log the losses, accuracies and rank-based
# metrics of both splits to Weights & Biases.
for epoch in tqdm(range(NUM_EPOCHES)):
    train_loss, train_accuracy, train_ranks = train_emb(
        emb_model, train_env.triplets, optimizer, criterion, batch_size=128
    )
    test_loss, test_accuracy, test_ranks = evaluate_emb(
        emb_model, test_env.triplets, criterion, batch_size=64
    )

    wandb.log({"Train CE loss": train_loss,
               "Train accuracy": train_accuracy,
               "Test CE loss": test_loss,
               "Test accuracy": test_accuracy,
               "Train MMR": mmr(train_ranks),
               "Test MMR": mmr(test_ranks),
               "Train HIT@10": hit_k(train_ranks, k=10),
               "Test HIT@10": hit_k(test_ranks, k=10)})

    # Periodic checkpoint, named after the 0-based epoch index. It goes to the
    # same 'models/' directory that the final-save and load cells use
    # (this path previously pointed at 'model/', which nothing else reads).
    if epoch % CHECKPOINT_EVERY == CHECKPOINT_EVERY - 1:
        torch.save(emb_model.state_dict(), 'drive/MyDrive/Diploma/models/emb_model{}.pt'.format(epoch))


In [None]:
# Persist the current embedding-model weights to the drive.
final_weights_path = 'drive/MyDrive/Diploma/models/emb_model.pt'
torch.save(emb_model.state_dict(), final_weights_path)

In [None]:
# Restore previously saved weights into the embedding model.
# `map_location=device` remaps the stored tensors onto the device selected
# for this session, so a checkpoint written on a GPU still loads on a
# CPU-only runtime.
# NOTE(review): 'emb_model599.pt' is not written by the 100-epoch loop in this
# notebook; presumably it comes from an earlier, longer run — confirm.
emb_model.load_state_dict(
    torch.load('drive/MyDrive/Diploma/models/emb_model599.pt', map_location=device)
)

<All keys matched successfully>

In [None]:
# Evaluate the current embedding model on the test triplets and unpack the
# three returned values.
eval_outputs = evaluate_emb(
    emb_model,
    test_env.triplets,
    criterion,
    batch_size=256,
)
loss, accuracy, ranks = eval_outputs

In [None]:
# Summarize the test ranks: hit_k at k=1 and k=10, plus the mmr score.
hit_at_1 = hit_k(ranks, k=1)
hit_at_10 = hit_k(ranks, k=10)
mmr_value = mmr(ranks)

report_template = 'Test  | HIT@1: {:.3},\t HIT@10: {:.3},\t MMR: {:.3}'
print(report_template.format(hit_at_1, hit_at_10, mmr_value))

Test  | HIT@1: 0.219,	 HIT@10: 0.853,	 MMR: 0.411


In [None]:
# Same report as the previous cell, recomputed from whatever `ranks`
# currently holds.
test_metrics = (
    hit_k(ranks, k=1),
    hit_k(ranks, k=10),
    mmr(ranks),
)
print('Test  | HIT@1: {:.3},\t HIT@10: {:.3},\t MMR: {:.3}'.format(*test_metrics))

Test  | HIT@1: 0.0742,	 HIT@10: 0.363,	 MMR: 0.165
