In [2]:
import os
# Every relative path used later in this notebook ('data/...',
# 'pretrained_models/...', 'checkpoints/...') is resolved against this directory.
# NOTE(review): this is a hard-coded absolute path (looks like a Google Drive
# mount under /content — confirm); it must be edited to run anywhere else.
os.chdir("/content/drive/MyDrive/DataScience/PROJECT/2022_KGQA/paper")

In [3]:
import numpy as np
from tqdm.auto import tqdm
import torch
from torch.optim.lr_scheduler import ExponentialLR

from dataloader import DatasetMetaQA, DataLoaderMetaQA
from model import RelationExtractor

In [4]:
def _read_name_to_embedding(dict_path, embeddings):
    """Read a ``<id>\\t<name>`` file and return ``{name: embeddings[id]}``."""
    name_to_embedding = {}
    # `with` guarantees the handle is closed even if a malformed line raises.
    with open(dict_path, 'r') as dict_file:
        for raw_line in dict_file:
            raw_line = raw_line.strip()
            if not raw_line:
                # Tolerate blank lines instead of failing on int('').
                continue
            fields = raw_line.split('\t')
            name_to_embedding[fields[1]] = embeddings[int(fields[0])]
    return name_to_embedding


def preprocess_entities_relations(entity_dict, relation_dict, entities, relations):
    """Build ``{entity name: embedding}`` and ``{relation name: embedding}`` maps.

    Args:
        entity_dict: Path to a text file with one ``<integer id>\\t<entity name>``
            pair per line.
        relation_dict: Path to a file in the same format for relations.
        entities: Indexable collection of entity embeddings; the integer id
            from ``entity_dict`` is used as the index.
        relations: Indexable collection of relation embeddings, indexed by the
            integer id from ``relation_dict``.

    Returns:
        A tuple ``(e, r)`` where ``e`` maps each entity name to ``entities[id]``
        and ``r`` maps each relation name to ``relations[id]``.

    Raises:
        ValueError: If the first field of a non-blank line is not an integer.
        IndexError: If a non-blank line has no tab-separated name, or an id is
            out of range for the embedding collection.
    """
    e = _read_name_to_embedding(entity_dict, entities)
    r = _read_name_to_embedding(relation_dict, relations)
    return e, r

def prepare_embeddings(embedding_dict):
    """Turn a ``{name: embedding}`` dict into index lookups plus a row list.

    Names are stripped of surrounding whitespace before being used as keys.
    Indices follow the iteration order of ``embedding_dict``.

    Returns:
        ``(entity2idx, idx2entity, embedding_matrix)`` where ``entity2idx`` maps
        stripped name -> row index, ``idx2entity`` maps row index -> stripped
        name, and ``embedding_matrix`` is the list of embeddings in row order.
    """
    entity2idx, idx2entity, embedding_matrix = {}, {}, []
    for row, (raw_name, vector) in enumerate(embedding_dict.items()):
        name = raw_name.strip()
        entity2idx[name] = row
        idx2entity[row] = name
        embedding_matrix.append(vector)
    return entity2idx, idx2entity, embedding_matrix

def process_text_file(text_file, split=False):
    """Parse a QA file into ``[head, question, answers]`` records.

    Each non-blank line is expected to look like
    ``<question text with the [head entity] in brackets>\\t<ans1>|<ans2>|...``.
    The bracketed head entity is replaced by the literal token ``NE`` in the
    returned question text.

    Args:
        text_file: Path to the QA text file.
        split: When falsy, each record keeps the full list of answers. When
            truthy, a question with several answers is expanded into one
            ``[head, question, answer]`` record per answer.

    Returns:
        A list of ``[head, question, answers]`` records (``answers`` is a list
        of strings), or a list of ``[head, question, answer]`` records when
        ``split`` is truthy.

    Raises:
        IndexError: If a non-blank line lacks the tab separator or the
            ``[head]`` brackets.
    """
    data_array = []
    # Stream the file inside `with` so the handle is always released.
    with open(text_file, 'r') as data_file:
        for raw_line in data_file:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            fields = raw_line.split('\t')
            bracket_parts = fields[0].split('[')
            prefix = bracket_parts[0]
            head_and_suffix = bracket_parts[1].split(']')
            head = head_and_suffix[0].strip()
            suffix = head_and_suffix[1]
            # The head entity mention is masked with the placeholder token 'NE'.
            question = (prefix + 'NE' + suffix).strip()
            answers = fields[1].split('|')
            data_array.append([head, question, answers])

    if not split:
        return data_array
    return [
        [head, question, tail]
        for head, question, tails in data_array
        for tail in tails
    ]

def get_vocab(data):
    """Build the question-word vocabulary from ``[head, question, ...]`` records.

    Only the question text (element 1 of each record) is tokenised, by
    whitespace; head and answer fields are ignored. Word ids are assigned in
    order of first appearance.

    Returns:
        ``(word_to_ix, idx2word, max_length)`` where ``max_length`` is the
        largest number of whitespace-separated words in any question.
    """
    vocab = {}
    reverse_vocab = {}
    longest = 0
    for record in data:
        tokens = record[1].split()
        for token in tokens:
            if token in vocab:
                continue
            next_id = len(vocab)
            vocab[token] = next_id
            reverse_vocab[next_id] = token
        longest = max(longest, len(tokens))
    return vocab, reverse_vocab, longest

def data_generator(data, word2ix, entity2idx):
    """Yield encoded samples from ``[head, question, answer(s)]`` records.

    Args:
        data: Sequence of records as produced by ``process_text_file``; the
            third element is either one answer string or an iterable of
            answer strings.
        word2ix: Mapping from question word to integer id.
        entity2idx: Mapping from (stripped) entity name to integer id.

    Yields:
        ``(head, question, ans, question_len, question_text)`` where ``head``,
        ``question`` and ``question_len`` are ``torch.long`` tensors, ``ans`` is
        an int (single string answer) or a list of ints, and ``question_text``
        is the record's original question string.

    Raises:
        KeyError: If a head, answer or question word is missing from the
            supplied mappings.
    """
    for sample in data:
        head_name, question_text, answer = sample[0], sample[1], sample[2]
        head = entity2idx[head_name.strip()]
        # Tokenise with split() (not split(' ')) so the tokens match the ones
        # get_vocab() indexed; split(' ') yields '' for repeated spaces, which
        # is never in the vocabulary.
        encoded_question = [word2ix[token] for token in question_text.split()]
        if isinstance(answer, str):
            # Strip like the list branch: entity2idx keys are stripped names.
            ans = entity2idx[answer.strip()]
        else:
            ans = [entity2idx[entity.strip()] for entity in answer]

        yield (
            torch.tensor(head, dtype=torch.long),
            torch.tensor(encoded_question, dtype=torch.long),
            ans,
            torch.tensor(len(encoded_question), dtype=torch.long),
            question_text,
        )



In [5]:
def _set_bn_eval(module):
    """Switch a BatchNorm1d layer to eval mode so its running stats stay fixed."""
    if isinstance(module, torch.nn.BatchNorm1d):
        module.eval()


def _snapshot_state_dict(model):
    """Return an independent copy of the model's current state dict.

    ``model.state_dict()`` hands back references to the live tensors, so the
    values must be cloned for the snapshot to survive further training.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def train(data_path, entity_path, relation_path, entity_dict, relation_dict, neg_batch_size, batch_size, shuffle, num_workers, nb_epochs, embedding_dim, hidden_dim, relation_dim, gpu, use_cuda,patience, freeze, validate_every, num_hops, lr, entdrop, reldrop, scoredrop, l3_reg, model_name, decay, ls, w_matrix, bn_list, valid_data_path=None, test_data_path=None):
    """Train a RelationExtractor on a QA file with periodic validation.

    Each outer "epoch" runs ``validate_every`` full passes over the training
    data followed by one validation pass. Whenever validation accuracy
    improves by more than 1e-4 the current weights are checkpointed to
    ``checkpoints/MetaQA/<model_name>_<num_hops>[_frozen].pt``. When patience is
    exhausted or the last epoch finishes, the best weights seen so far are
    written to ``checkpoints/MetaQA/best_score_model.pt`` and the function
    returns.

    Args:
        data_path: Training QA file (format of ``process_text_file``).
        entity_path, relation_path: ``.npy`` files with pretrained entity /
            relation embeddings.
        entity_dict, relation_dict: ``<id>\\t<name>`` files matching the
            embedding rows.
        neg_batch_size, gpu: Accepted for interface compatibility; not used.
        batch_size, shuffle, num_workers: Forwarded to ``DataLoaderMetaQA``.
        nb_epochs: Number of outer epochs (see above).
        patience: Number of non-improving validations tolerated before
            stopping.
        freeze: Forwarded to the model; when truthy, BatchNorm1d layers are
            also kept in eval mode during training.
        validate_every: Training passes per validation pass.
        num_hops: Only used to name the checkpoint and in log messages.
        lr: Adam learning rate.
        decay: ``gamma`` for ``ExponentialLR``, stepped once per training pass.
        embedding_dim, hidden_dim, relation_dim, entdrop, reldrop, scoredrop,
        l3_reg, model_name, ls, w_matrix, bn_list: Forwarded unchanged to
            ``RelationExtractor``.
        valid_data_path: Validation QA file (required).
        test_data_path: Optional test QA file scored whenever validation
            improves. When omitted, a module-level ``test_data_path`` global is
            used if one exists (the behaviour of earlier revisions); if neither
            is available the test evaluation is skipped.

    Returns:
        None.

    Raises:
        ValueError: If ``valid_data_path`` is not provided.
    """
    if valid_data_path is None:
        # Fail before spending time on training; validation cannot run without it.
        raise ValueError("valid_data_path is required for validation")
    if test_data_path is None:
        # Backward compatibility: this function used to read the notebook global.
        test_data_path = globals().get('test_data_path')

    # Map entities & relations to their pretrained embeddings.
    entities = np.load(entity_path)
    relations = np.load(relation_path)
    e, r = preprocess_entities_relations(entity_dict, relation_dict, entities, relations)
    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)

    # Question/answer preprocessing and question-word vocabulary.
    data = process_text_file(data_path, split=False)
    word2ix, _, _ = get_vocab(data)
    hops = str(num_hops)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Per the original author's notes, the dataset encodes each record as
    # question ids / head id / multi-hot answer ids and the loader pads per
    # batch (implemented in the `dataloader` module — not verified here).
    dataset = DatasetMetaQA(data=data, word2ix=word2ix, relations=r, entities=e, entity2idx=entity2idx)
    data_loader = DataLoaderMetaQA(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    model = RelationExtractor(embedding_dim=embedding_dim, hidden_dim=hidden_dim, vocab_size=len(word2ix), num_entities = len(idx2entity), relation_dim=relation_dim, pretrained_embeddings=embedding_matrix, freeze=freeze, device=device, entdrop = entdrop, reldrop = reldrop, scoredrop = scoredrop, l3_reg = l3_reg, model = model_name, ls = ls, w_matrix = w_matrix, bn_list=bn_list)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()

    # Checkpoint locations are fixed up front so every exit path can use them.
    checkpoint_path = 'checkpoints/MetaQA/'
    suffix = '_frozen' if freeze else ''
    checkpoint_file_name = checkpoint_path + model_name + '_' + hops + suffix + ".pt"
    best_model_file_name = checkpoint_path + "best_score_model.pt"
    os.makedirs(checkpoint_path, exist_ok=True)

    eps = 0.0001
    best_score = -float("inf")
    best_model = _snapshot_state_dict(model)
    no_update = 0

    # `validate_every` training passes, then one validation pass.
    phases = ['train'] * validate_every + ['valid']

    for epoch in range(nb_epochs):
        for phase in phases:
            if phase == 'train':
                model.train()
                if freeze:
                    model.apply(_set_bn_eval)
                loader = tqdm(data_loader, total=len(data_loader), unit="batches")
                running_loss = 0
                for i_batch, batch in enumerate(loader):
                    model.zero_grad()
                    question = batch[0].to(device)
                    sent_len = batch[1].to(device)
                    positive_head = batch[2].to(device)
                    positive_tail = batch[3].to(device)

                    loss = model(sentence=question, p_head=positive_head, p_tail=positive_tail, question_len=sent_len)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss/((i_batch+1)*batch_size), Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    # No manual loader.update(): iterating the tqdm wrapper
                    # already advances the bar once per batch.

                scheduler.step()

            else:
                model.eval()
                _, score = validate(model=model, data_path= valid_data_path, word2idx= word2ix, entity2idx= entity2idx, device=device, model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = _snapshot_state_dict(model)
                    print(hops + " hop Validation accuracy increased from previous epoch", score)
                    if test_data_path is not None:
                        _, test_score = validate(model=model, data_path= test_data_path, word2idx= word2ix, entity2idx= entity2idx, device=device, model_name=model_name)
                        print('Test score for best valid so far:', test_score)
                    print('Saving checkpoint to ', checkpoint_file_name)
                    torch.save(model.state_dict(), checkpoint_file_name)
                elif no_update < patience:
                    no_update +=1
                    print("Validation accuracy decreases to %f from %f, %d more epoch to check"%(score, best_score, patience-no_update))
                else:
                    print("Model has exceed patience. Saving best model and exiting")
                    torch.save(best_model, best_model_file_name)
                    # Return instead of exit(): exit() would take the whole
                    # notebook kernel down with it.
                    return
                if epoch == nb_epochs-1:
                    print("Final Epoch has reached. Stopping and saving model.")
                    torch.save(best_model, best_model_file_name)
                    return

def validate(data_path, device, model, word2idx, entity2idx, model_name):
    """Score the model on a QA file using top-1 (hits@1) accuracy.

    For every question the model's ranked candidates are fetched with
    ``model.get_score_ranked``; the best-ranked entity that is not the head
    entity itself is taken as the prediction and counted as correct when it
    is one of the gold answers.

    Args:
        data_path: QA file in the format read by ``process_text_file``.
        device: Torch device the inputs are moved to.
        model: Model exposing ``eval()`` and ``get_score_ranked(head=...,
            sentence=..., sent_len=...)``.
        word2idx: Question-word vocabulary.
        entity2idx: Entity-name-to-id mapping.
        model_name: Unused; kept for interface compatibility.

    Returns:
        ``(answers, accuracy)``. ``answers`` holds one
        ``"<question>\\t<predicted entity id>\\t<0|1>"`` string per successfully
        scored sample; ``accuracy`` is correct predictions divided by the total
        number of samples (failed samples count as wrong), or 0.0 for an empty
        file.
    """
    model.eval()
    data = process_text_file(data_path)
    answers = []
    total_correct = 0
    error_count = 0
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for sample in tqdm(data):
            try:
                # Encode each sample with its own generator: a generator that
                # has raised is finished, so sharing one would make every
                # sample after the first failure fail as well.
                d = next(data_generator(data=[sample], word2ix=word2idx, entity2idx=entity2idx))
                head = d[0].to(device)
                question = d[1].to(device)
                ans = d[2]
                ques_len = d[3].unsqueeze(0)

                # Ranked candidates; element [1] is read as the candidate
                # indices and its first row is used (assumes a batch of one —
                # TODO confirm against RelationExtractor.get_score_ranked).
                top_2 = model.get_score_ranked(head=head, sentence=question, sent_len=ques_len)
                top_2_idx = top_2[1].tolist()[0]
                head_idx = head.tolist()
                # Never predict the head entity itself; fall back to rank 2.
                if top_2_idx[0] == head_idx:
                    pred_ans = top_2_idx[1]
                else:
                    pred_ans = top_2_idx[0]
                if isinstance(ans, int):
                    ans = [ans]
                is_correct = 0
                if pred_ans in ans:
                    total_correct += 1
                    is_correct = 1

                # Record question + prediction + correctness.
                q_text = d[-1]
                answers.append(q_text + '\t' + str(pred_ans) + '\t' + str(is_correct))
            except Exception:
                # Best-effort evaluation: count the failure and move on, but
                # let KeyboardInterrupt / SystemExit propagate.
                error_count += 1

    print(error_count)
    accuracy = total_correct/len(data) if data else 0.0
    return answers, accuracy

In [6]:
# --- Experiment selection -------------------------------------------------
# `hops` picks the QA files (qa_*_<hops>hop.txt) and is passed to train() as
# num_hops; `model_name` and `kg_type` pick the pretrained embedding folder.
hops = 2
model_name = 'ComplEx'
kg_type = 'half'
# --- Data loading -----------------------------------------------------------
# NOTE: neg_batch_size (and `gpu` below) are passed to train() but train()'s
# body does not read them.
neg_batch_size = 128
batch_size = 128
shuffle_data = True
num_workers = 2
# --- Training schedule ------------------------------------------------------
# train() runs `validate_every` training passes per outer epoch, so the total
# number of passes over the training data is nb_epochs * validate_every.
nb_epochs = 10
embedding_dim = 256
hidden_dim = 256
relation_dim = 200
gpu = 0
use_cuda = True
patience = 5
validate_every = 5
# Falsy: forwarded to RelationExtractor(freeze=...) and disables the
# batch-norm freezing branch in train().
freeze = 0
lr = 0.0005
# Dropout / regularisation settings forwarded unchanged to RelationExtractor.
entdrop = 0.1
reldrop = 0.2
scoredrop = 0.2
l3_reg = 0.0
# NOTE(review): duplicates `model_name`; this is the value passed as
# train(model_name=model). The name `model` is later rebound to the
# RelationExtractor instance in the test section.
model = 'ComplEx'
# gamma for ExponentialLR; 1.0 leaves the learning rate unchanged.
decay = 1.0
ls = 0.0

In [7]:
# Training questions for the configured hop count.
data_path = f'data/QA_data/MetaQA/qa_train_{hops}hop.txt'

# Dev/test file names never carry an '_old' marker, so it is removed from the
# hop tag (a no-op for a plain integer `hops`).
hops_without_old = f"{hops}hop".replace('_old', '')
valid_data_path = f'data/QA_data/MetaQA/qa_dev_{hops_without_old}.txt'
test_data_path = f'data/QA_data/MetaQA/qa_test_{hops_without_old}.txt'

# Pretrained KG-embedding artefacts for the chosen model / KG variant.
embedding_folder = f'pretrained_models/embeddings/{model_name}_MetaQA_{kg_type}'
entity_embedding_path = f'{embedding_folder}/E.npy'
relation_embedding_path = f'{embedding_folder}/R.npy'
entity_dict = f'{embedding_folder}/entities.dict'
relation_dict = f'{embedding_folder}/relations.dict'
w_matrix = f'{embedding_folder}/W.npy'

In [8]:
# Objects loaded from bn0.npy .. bn2.npy; handed to RelationExtractor(bn_list=...).
bn_list = []

# SECURITY: allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code — only load these files from a trusted source.
for i in range(3):
    bn = np.load(embedding_folder + '/bn' + str(i) + '.npy', allow_pickle=True)
    # .item() unwraps the single Python object stored in the array
    # (presumably a dict of batch-norm parameters — TODO confirm).
    bn_list.append(bn.item())

# train

In [11]:
# Launch training with the configuration defined in the cells above.
# - model_name receives the `model` string from the config cell (same value
#   as `model_name`).
# - test_data_path is not passed here; train() picks it up from the
#   notebook's global namespace.
train(data_path=data_path, 
    entity_path=entity_embedding_path, 
    relation_path=relation_embedding_path,
    entity_dict=entity_dict, 
    relation_dict=relation_dict, 
    neg_batch_size=neg_batch_size, 
    batch_size=batch_size,
    shuffle=shuffle_data, 
    num_workers=num_workers,
    nb_epochs=nb_epochs, 
    embedding_dim=embedding_dim, 
    hidden_dim=hidden_dim, 
    relation_dim=relation_dim, 
    gpu=gpu, 
    use_cuda=use_cuda, 
    valid_data_path=valid_data_path,
    patience=patience,
    validate_every=validate_every,
    freeze=freeze,
    num_hops=hops,
    lr=lr,
    entdrop=entdrop,
    reldrop=reldrop,
    scoredrop = scoredrop,
    l3_reg = l3_reg,
    model_name=model,
    decay=decay,
    ls=ls,
    w_matrix=w_matrix,
    bn_list=bn_list)

Model is ComplEx
Frozen: 0


  self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pretrained_embeddings), freeze=self.freeze)


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()    Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers

Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>Traceback (most recent call last):
if w.is_alive():

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
        self._shutdown_workers()assert self._parent_pid == os.getpid(), 'ca

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.8286713286713286


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.8384884346422808
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9041823561054331


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9042495965572889
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9306750941366326


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9291285637439484
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>    
self._shutdown_workers()
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():

  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():
AssertionError: 
can only test a child process
  File "/usr/lib/pytho

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.940357719203873


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9404922001075847
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9457369553523399


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9445938676707908
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive

  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError  File "/usr/lib/python3.7/multiprocessing/process.py"

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9520575578267886


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9500403442711135
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9522592791823561


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9509817105970952
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9534023668639053


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9539402904787521
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

  0%|          | 0/1812 [00:00<?, ?batches/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fa16d93eef0>  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__

Traceback (most recent call last):
    self._shutdown_workers()  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__

  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    
assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    Exception ignored in: if w.is_alive():


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9542764927380312


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9563609467455622
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt


  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/1812 [00:00<?, ?batches/s]

  0%|          | 0/14872 [00:00<?, ?it/s]

0
2 hop Validation accuracy increased from previous epoch 0.9553523399677246


  0%|          | 0/14872 [00:00<?, ?it/s]

0
Test score for best valid so far: 0.9562264658418504
Saving checkpoint to  checkpoints/MetaQA/ComplEx_2.pt
Final Epoch has reached. Stopping and saving model.


# test

In [24]:
# Rebuild, at notebook level, the same lookups train() builds internally, so
# the model below gets identical vocabulary and entity indices.
# NOTE(review): the saved execution counts in this section (24, 18, 19, 38)
# are out of order; on a fresh kernel this cell must run before the
# model-construction cell.
entities = np.load(entity_embedding_path)
relations = np.load(relation_embedding_path)
e,r = preprocess_entities_relations(entity_dict, relation_dict, entities, relations)
entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)

# Question - Answer preprocessing
data = process_text_file(data_path, split=False)
test_data = process_text_file(test_data_path, split=False)

# The vocabulary is built from the training questions only.
word2ix,idx2word, max_len = get_vocab(data)
# Rebinds `hops` from int to str for the rest of the session.
hops = str(hops)

In [18]:
device = torch.device("cuda" if use_cuda else "cpu")
# Same constructor arguments as in train(), so the checkpoint's parameter
# shapes line up. NOTE: this rebinds `model`, which previously held the
# string 'ComplEx' from the config cell.
model = RelationExtractor(embedding_dim=embedding_dim, hidden_dim=hidden_dim, vocab_size=len(word2ix), num_entities = len(idx2entity), relation_dim=relation_dim, pretrained_embeddings=embedding_matrix, freeze=freeze, device=device, entdrop = entdrop, reldrop = reldrop, scoredrop = scoredrop, l3_reg = l3_reg, model = model_name, ls = ls, w_matrix = w_matrix, bn_list=bn_list)
model.to(device)

# Weights written by train() when it stops (patience exhausted or final epoch).
model.load_state_dict(torch.load("checkpoints/MetaQA/best_score_model.pt"))

Model is ComplEx
Frozen: 0


<All keys matched successfully>

In [19]:
# Score the restored model on the test split.
test_pred, test_score = validate(model=model, data_path= test_data_path, word2idx= word2ix, entity2idx= entity2idx, device=device, model_name=model_name)
# Each entry is "<question>\t<predicted entity id>\t<0|1>"; keep the id field.
# NOTE: validate() only records samples it scored without error, so pred_ids
# lines up index-for-index with test_data only when its printed error count is 0.
pred_ids = [i.split('\t')[1] for i in test_pred]

  0%|          | 0/14872 [00:00<?, ?it/s]

0


In [38]:
# Spot-check test samples 5..9: map each predicted entity id back to its name
# and show it next to the gold answer list.
for i in range(5,10):
    print(f"Head: {test_data[i][0]} / Qestion: {test_data[i][1]} / Pred: {idx2entity[int(pred_ids[i])]} / Answer: {test_data[i][2]}")

Head: Jerry Lewis / Qestion: what are the genres of the films directed by NE / Pred: Comedy / Answer: ['Drama', 'Comedy', 'War']
Head: Angie Everhart / Qestion: who appeared in the same movie with NE / Pred: Erika Eleniak / Answer: ['Erika Eleniak', 'Dennis Miller']
Head: Mike Nichols / Qestion: what are the genres of the films directed by NE / Pred: Drama / Answer: ['Drama', 'Horror', 'Comedy', 'War', 'Thriller']
Head: John Travis / Qestion: who are the actors in the films written by NE / Pred: Chace Crawford / Answer: ['Haley Bennett', 'Chace Crawford', 'Jake Weber']
Head: Clifford Rose / Qestion: the movies starred by NE were written by who / Pred: Peter Weiss / Answer: ['Adrian Mitchell', 'Geoffrey Skelton', 'Peter Weiss']
