In [None]:
# Optional: uncomment the lines below to download the original WN18RR
# train / valid / test split files.
# ! wget https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/original/train.txt
# ! wget https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/original/valid.txt
# ! wget https://raw.githubusercontent.com/villmow/datasets_knowledge_embedding/master/WN18RR/original/test.txt

In [None]:
# Colab only: mount Google Drive at /content/drive. The project code and data
# used by the following cells are read from drive/MyDrive/Diploma.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! pip install -r drive/MyDrive/Diploma/requirements.txt >& /dev/null

OK


In [None]:
# Print the GPU / driver status of the current runtime.
! nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [None]:
import numpy as np
import torch
from torch import nn

import wandb

from tqdm.notebook import tqdm

from drive.MyDrive.Diploma.environment import KGEnv
from drive.MyDrive.Diploma.utils import (
    create_test_dataset, create_random_dataset,
    create_model, read_from_file, LabelSmoothingCrossEntropy,
    load_config
)
from drive.MyDrive.Diploma.dataset import KGDataset
from drive.MyDrive.Diploma.train import (
    run_episode, train, evaluate, pretrain,
    train_emb, evaluate_emb
)
from drive.MyDrive.Diploma.beam_search import get_ranks
from drive.MyDrive.Diploma.embed_model import ComplEx
from drive.MyDrive.Diploma.metrics import *

# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Seed the NumPy and PyTorch global RNGs so that weight initialisation and any
# NumPy/PyTorch-based sampling are repeatable across kernel restarts.
# NOTE(review): the project helpers may also draw from Python's `random`
# module — confirm, and seed it as well if they do.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Directory holding the WN18RR files on the mounted Drive (relative to the
# working directory). Kept in one place so the dataset location can be changed
# with a single edit.
WN18RR_DIR = 'drive/MyDrive/Diploma/wn18rr'

# Triplet files for the train / test splits and the entity / relation vocabularies.
train_triplets_path = f'{WN18RR_DIR}/train.txt'
test_triplets_path = f'{WN18RR_DIR}/test.txt'
entities_path = f'{WN18RR_DIR}/entities.txt'
relations_path = f'{WN18RR_DIR}/relations.txt'

In [None]:
# Environment built from the training triplets; it is used here to read the
# entity and relation counts.
env = KGEnv(train_triplets_path, entities_path, relations_path)

# Vocabulary sizes, later passed to the model constructor.
entities_num, relations_num = env.entities_num, env.relations_num

In [None]:
# Separate environments over the train and test triplets, both created with
# batch_size=1; the test one is additionally created with train=False.
train_env = KGEnv(
    train_triplets_path,
    entities_path,
    relations_path,
    batch_size=1,
)
test_env = KGEnv(
    test_triplets_path,
    entities_path,
    relations_path,
    train=False,
    batch_size=1,
)

In [None]:
# Build the test dataset from the two environments; the helper is also given
# the path of an output file on the Drive.
test_dataset = create_test_dataset(
    train_env,
    test_env,
    out_file='drive/MyDrive/Diploma/test_dataset.txt',
)

In [None]:
# Build a training dataset of the requested size from the train environment;
# the helper is also given the path of an output file on the Drive.
train_dataset = create_random_dataset(
    train_env,
    out_file='drive/MyDrive/Diploma/train_dataset.txt',
    size=20000,
)

In [None]:
# Reload both datasets from the files passed as `out_file` to the
# dataset-creation cells. The paths are spelled exactly as they are there
# (relative to the working directory), so the reader and the writers always
# refer to the same files.
train_dataset = read_from_file('drive/MyDrive/Diploma/train_dataset.txt', train_env)
test_dataset = read_from_file('drive/MyDrive/Diploma/test_dataset.txt', test_env)

In [None]:
# Wrap the loaded samples in KGDataset objects, each paired with its own
# environment and with shuffling enabled.
train_kg_dataset = KGDataset(
    train_dataset,
    train_env,
    shuffle=True,
)
# train_iter = CustomIterator(train_kg_dataset, batch_size=512, device=torch.device('cuda'), repeat=False, train=True, shuffle=True)

test_kg_dataset = KGDataset(
    test_dataset,
    test_env,
    shuffle=True,
)
# test_iter = CustomIterator(test_kg_dataset, batch_size=64, device=torch.device('cuda'), repeat=False, train=True, shuffle=True)

In [None]:
# Build the model. The input vocabularies are enlarged to make room for the
# special tokens: one extra entity index (pad) and two extra relation
# indices (cls and pad).
model = create_model(
    entity_input_dim=entities_num + 1,
    relation_input_dim=relations_num + 2,
    output_dim=relations_num,
    entity_pad_idx=train_env.e_pad_idx,
    relation_pad_idx=train_env.r_pad_idx,
    hid_dim=128,
    enc_pf_dim=256,
    device=device,
)

# Cross-entropy objective, optimised with Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Start a Weights & Biases run named "Pretrain" in the "RL4KGQA" project.
wandb.init(
    project="RL4KGQA",
    name="Pretrain",
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Number of pretraining epochs.
NUM_EPOCHES = 100

# Each epoch: one pretraining pass over the train dataset, one evaluation pass
# over the test dataset, then log the four resulting metrics to W&B.
for epoch in tqdm(range(NUM_EPOCHES)):
    loss, accuracy = pretrain(
        model, train_kg_dataset, optimizer, criterion, batch_size=128
    )
    test_loss, test_accuracy = evaluate(
        model, test_kg_dataset, criterion, batch_size=64
    )

    epoch_metrics = {
        "Train CE loss": loss,
        "Train accuracy": accuracy,
        "Test CE loss": test_loss,
        "Test accuracy": test_accuracy,
    }
    wandb.log(epoch_metrics)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [None]:
# Persist the model's weights (state dict only) to the Drive.
torch.save(
    obj=model.state_dict(),
    f='drive/MyDrive/Diploma/pretrained_agent.pt',
)