In [None]:
! pip install -r drive/MyDrive/Diploma/requirements.txt >& /dev/null

In [None]:
# Record which GPU (if any) this runtime was allocated.
! nvidia-smi

Mon May 17 13:51:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import numpy as np
import torch
from torch import nn

import wandb

from tqdm.notebook import tqdm

from drive.MyDrive.Diploma.environment import KGEnv
from drive.MyDrive.Diploma.utils import create_model, load_config
from drive.MyDrive.Diploma.dataset import KGDataset
from drive.MyDrive.Diploma.train_rl import run_episode, train
from drive.MyDrive.Diploma.beam_search import get_ranks
from drive.MyDrive.Diploma.embed_model import ComplEx
from drive.MyDrive.Diploma.metrics import *
from drive.MyDrive.Diploma.model import GraphSearchPolicy

from drive.MyDrive.Diploma.load_config import config

%load_ext autoreload
%autoreload 2

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Display the run configuration imported from drive.MyDrive.Diploma.load_config.
config

{'add_reverse_relations': True,
 'batch_size': 512,
 'beam_size': 128,
 'entities_file': 'drive/MyDrive/Diploma/fb15k-237/entities.txt',
 'entropy_coef': 0.0,
 'gamma': 0.99,
 'len_penalty': True,
 'lr': 0.0001,
 'max_rollout': 20,
 'normalize_reward': False,
 'num_beam_steps': 10,
 'num_epochs': 1000,
 'num_rollouts': 100,
 'only_relations': False,
 'relations_file': 'drive/MyDrive/Diploma/fb15k-237/relations.txt',
 'rl_method': 'REINFORCE',
 'rs_coef': 0.2,
 'seed': 2441406995705867619,
 'static_relation': 'stop',
 'test_triplets_path': 'drive/MyDrive/Diploma/fb15k-237/test.txt',
 'train_triplets_path': 'drive/MyDrive/Diploma/fb15k-237/train.txt'}

In [None]:
# Throwaway environment: only the graph sizes are needed here, to build the
# embedding model. `env` is rebuilt further down with the embedding model attached.
env = KGEnv(config['train_triplets_path'])
entities_num, relations_num = env.entities_num, env.relations_num

In [None]:
# Pretrained ComplEx embeddings; hid_dim must match the saved checkpoint or
# load_state_dict (strict by default) will reject it.
EMB_MODEL_PATH = 'drive/MyDrive/Diploma/models/emb_model_fb15k.pt'

emb_model = ComplEx(entities_num, relations_num, hid_dim=256).to(device)
# map_location remaps tensors onto `device`, so a checkpoint saved on GPU still
# loads when `device` fell back to CPU.
emb_model.load_state_dict(torch.load(EMB_MODEL_PATH, map_location=device))

<All keys matched successfully>

In [None]:
import random

batch_size = config['batch_size']

# Seed every RNG before building the training environment and initialising the
# agent's weights; config['seed'] is not applied anywhere else in this notebook.
# numpy only accepts seeds in [0, 2**32), so the configured seed is reduced for it.
SEED = config['seed']
random.seed(SEED)
np.random.seed(SEED % 2**32)
torch.manual_seed(SEED)

env = KGEnv(config['train_triplets_path'], batch_size=batch_size, emb_model=emb_model)

# Input vocabularies are enlarged by the special tokens noted below;
# the output space is the plain relation set.
agent = create_model(
    entity_input_dim=env.entities_num + 1,  # + pad
    relation_input_dim=env.relations_num + 2,  # + cls + pad
    output_dim=env.relations_num,
    entity_pad_idx=env.e_pad_idx,
    relation_pad_idx=env.r_pad_idx,
    hid_dim=128,
    enc_pf_dim=256,
    device=device
)

In [None]:
# Adam over every trainable parameter of the agent, learning rate taken from the config.
optimizer = torch.optim.Adam(params=agent.parameters(), lr=config['lr'])

In [None]:
# Name the run after the configured RL method rather than a hard-coded string,
# and attach the full config at init so it is recorded with the run.
wandb.init(project="RL4KGQA", name=config['rl_method'], config=config)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Report the seed torch's default generator was initialised with. It equals
# config['seed'] only if torch.manual_seed(config['seed']) was called earlier
# in this session.
torch.initial_seed()

569581447802162076

In [None]:
# Single-query (batch_size=1) environments used for ranking evaluation:
# the training split, and the test split opened with train=False.
train_env = KGEnv(config['train_triplets_path'], batch_size=1)
test_env = KGEnv(config['test_triplets_path'], batch_size=1, train=False)

In [None]:
EVAL_EVERY = 500  # epochs between intermediate test-set evaluations

agent.train()

for epoch in tqdm(range(config['num_epochs'])):
    # One batched rollout followed by one policy update.
    episode_info = run_episode(env, agent)
    loss = train(agent, optimizer, episode_info)

    # Element 0 of every per-step probability vector is tracked as the "stop"
    # probability (presumably the static 'stop' relation — TODO confirm index).
    flat_stops = [step_probs[0] for probs in episode_info['probs'] for step_probs in probs]
    sum_reward = sum(r for rewards in episode_info['rewards'] for r in rewards)

    step_metrics = {
        'Loss': loss,
        'Accuracy': episode_info['num_correct'] / batch_size,
        'Steps': episode_info['num_steps'],
        'Mean steps': episode_info['mean_steps'],
        # NaN instead of ZeroDivisionError if an episode produced no probabilities.
        'Stop_prob': sum(flat_stops) / len(flat_stops) if flat_stops else float('nan'),
        'Mean_reward': sum_reward / batch_size,
        'Len_penalty': episode_info['len_pen'],
    }

    if (epoch + 1) % EVAL_EVERY == 0:
        # Evaluate in eval mode without building autograd graphs, then restore
        # training mode so later epochs are unaffected by the evaluation.
        agent.eval()
        with torch.no_grad():
            eval_ranks = get_ranks(agent, train_env, test_env)
        agent.train()
        # Same filtering as the final evaluation: ranks equal to env.entities_num
        # are excluded (looks like a sentinel value — TODO confirm its meaning).
        eval_ranks = eval_ranks[eval_ranks != env.entities_num]
        eval_hit1 = hit_k(eval_ranks, k=1)
        eval_hit10 = hit_k(eval_ranks, k=10)
        eval_mrr = mmr(eval_ranks)
        print('Epoch {} | Test HIT@1: {:.3},\t HIT@10: {:.3},\t MRR: {:.3}'.format(
            epoch + 1, eval_hit1, eval_hit10, eval_mrr
        ))
        step_metrics.update({
            'Test HIT@1': eval_hit1,
            'Test HIT@10': eval_hit10,
            'Test MRR': eval_mrr,
        })

    wandb.log(step_metrics)

In [None]:
# Final evaluation: switch modules such as dropout to inference behaviour and
# skip autograd bookkeeping. train_env is passed as the first environment for
# both splits (presumably the graph the agent walks — TODO confirm in get_ranks).
agent.eval()
with torch.no_grad():
    train_ranks = get_ranks(agent, train_env, train_env)
    test_ranks = get_ranks(agent, train_env, test_env)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [None]:
# Drop every rank equal to env.entities_num from both splits. NOTE(review): this
# looks like a sentinel (e.g. target not reachable) — TODO confirm; excluding
# these queries changes the denominator of HIT@k and MRR.
train_ranks, test_ranks = (
    ranks[ranks != env.entities_num] for ranks in (train_ranks, test_ranks)
)

In [None]:
# Print HIT@1, HIT@10 and MRR per split ('{:.3}' = 3 significant digits).
# The split label is left-aligned to width 5 so the '|' columns line up.
for split_name, split_ranks in (('Train', train_ranks), ('Test', test_ranks)):
    print('{:<5} | HIT@1: {:.3},\t HIT@10: {:.3},\t MRR: {:.3}'.format(
        split_name, hit_k(split_ranks, k=1), hit_k(split_ranks, k=10), mmr(split_ranks)
    ))