In [112]:
import argparse
from typing import Dict
import logging
import torch
from torch import optim
import pickle
import numpy as np

from qa_models import QA_model, QA_model_Only_Embeddings, QA_model_BERT, QA_model_EaE, QA_model_EmbedKGQA, \
    QA_model_EaE_replace, QA_model_EmbedKGQA_complex
from qa_datasets import QA_Dataset, QA_Dataset_model1, QA_Dataset_EaE, QA_Dataset_EmbedKGQA, QA_Dataset_EaE_replace
from torch.utils.data import Dataset, DataLoader
import utils
from tqdm import tqdm
from utils import loadTkbcModel, loadTkbcModel_complex
from collections import defaultdict
from datetime import datetime
from collections import OrderedDict


In [113]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


In [114]:
dataset_name = 'wikidata_big'
tkbc_model_file = 'tcomplex_17dec.ckpt'
tkbc_model = loadTkbcModel_complex('models/{dataset_name}/kg_embeddings/{tkbc_model_file}'.format(
    dataset_name=dataset_name, tkbc_model_file=tkbc_model_file
))

Loading tkbc model from models/wikidata_big/kg_embeddings/tcomplex_17dec.ckpt
Number ent,rel,ts from loaded model: 125726 406 9621
Loaded tkbc model


In [115]:
class Args:
    lm_frozen = 1
    frozen = 1

    multi_label=0
    combine_all_ents = 'None'
    attention = False


def openFileAsDict(filename):
    f = open(filename, 'r', errors='ignore')
    out = {}
    for line in f:
        line = line[:-1].split('\t')  # can't strip() since name can be whitespace
        out[line[0]] = line[1]
    return out


def convertToDataPoint(question_text, entities, times, answer_type='entity', answers=set()):
    question = {}
    question['question'] = question_text
    #     question['answers'] = answers
    question['answers'] = set(['Q888504'])
    question['answer_type'] = answer_type
    question['entities'] = set(entities)
    question['times'] = set(times)
    entFile = 'data/wikidata_big/kg/wd_id2entity_text.txt'
    id2ent = openFileAsDict(entFile)
    paraphrase = question_text
    for e in entities:
        paraphrase = paraphrase.replace(e, id2ent[e])
    question['paraphrases'] = [paraphrase]
    return question



In [116]:
entFile = './data/wikidata_big/kg/wd_id2entity_text.txt'
id2ent = openFileAsDict(entFile)

In [117]:
args = Args()
qa_model = QA_model_EmbedKGQA(tkbc_model, args)
filename = 'models/{dataset_name}/qa_models/{model_file}.ckpt'.format(
    dataset_name=dataset_name,
    # model_file='embedkgqa_dual_frozen_lm_fix_order_ce'
    model_file='temp'
)
print('Loading model from', filename)
qa_model.load_state_dict(torch.load(filename))
print('Loaded qa model from ', filename)
qa_model = qa_model.cuda()

Freezing LM params
Freezing entity/time embeddings
Loading model from models/wikidata_big/qa_models/temp.ckpt
Loaded qa model from  models/wikidata_big/qa_models/temp.ckpt


In [118]:
valid_dataset = QA_Dataset_EmbedKGQA(split='valid', dataset_name=dataset_name)
original_dataset = QA_Dataset_EmbedKGQA(split='valid', dataset_name=dataset_name)

Total questions =  30000
Preparing data for split valid
Total questions =  30000
Preparing data for split valid


In [119]:
def predict(qa_model, dataset, batch_size = 128, split='valid', k=10):
    num_workers = 4
    qa_model.eval()
    eval_log = []
    k_for_reporting = k # not change name in fn signature since named param used in places
    # k_list = [1, 3, 10]
    # k_list = [1, 10]
    k_list = [1, 5]
    max_k = max(k_list)
    eval_log.append("Split %s" % (split))
    print('Evaluating split', split)

    # id = 13799
    ids = [0]
    prepared_data = {}
    for k, v in dataset.prepared_data.items():
        prepared_data[k] = [v[i] for i in ids]
    dataset.prepared_data = prepared_data
    dataset.data = [dataset.data[i] for i in ids]

    # dataset.print_prepared_data()

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=dataset._collate_fn)
    topk_answers = []
    topk_scores = []
    total_loss = 0
    loader = tqdm(data_loader, total=len(data_loader), unit="batches")


    for i_batch, a in enumerate(loader):
        # if size of split is multiple of batch size, we need this
        # todo: is there a more elegant way?
        if i_batch * batch_size == len(dataset.data):
            break
        answers_khot = a[-1] # last one assumed to be target
        scores = qa_model.forward(a)
        sm = torch.nn.Softmax(dim=1)

        scores = sm(scores)
        # scores = torch.nn.functional.normalize(scores, p=2, dim=1)

        for s in scores:
            pred_s, pred = dataset.getAnswersFromScoresWithScores(s, k=max_k)
            topk_answers.append(pred)
            topk_scores.append(pred_s)
        loss = qa_model.loss(scores, answers_khot.cuda().long())
        total_loss += loss.item()
    eval_log.append('Loss %f' % total_loss)
    eval_log.append('Eval batch size %d' % batch_size)

    for i in range(len(dataset.data)):
        question = dataset.data[i]
        predicted_answers = topk_answers[i]
        predicted_scores = topk_scores[i]
        actual_answers = question['answers']

        if question['answer_type'] == 'entity':
            actual_answers = [dataset.getEntityToText(x) for x in actual_answers]
            pa = []
            aa = []
            for a in predicted_answers:
                if 'Q' in str(a): # TODO: hack to check whether entity or time predicted
                    pa.append(dataset.getEntityToText(a))
                else:
                    pa.append(a)
            predicted_answers = pa

            for a in actual_answers:
                if 'Q' in str(a): # TODO: hack to check whether entity or time predicted
                    aa.append(dataset.getEntityToText(a))
                else:
                    aa.append(a)
            actual_answers = aa


        # print(question['paraphrases'][0])
        # print('Actual answers', actual_answers)
        # print('Predicted answers', predicted_answers)
        # print()
        print(question['paraphrases'][0])
        print(question['question'])
        answers_with_scores_text = []
        for pa, ps in zip(predicted_answers, predicted_scores):
            formatted = '{answer} ({score})'.format(answer = pa, score=ps)
            answers_with_scores_text.append(formatted)
        print('Predicted:', ', '.join(answers_with_scores_text))
        print('Actual:', ', '.join([str(x) for x in actual_answers]))
        print()


In [120]:
def getEntities(question_text):
    words = question_text.split(' ')
    entities = []
    for word in words:
        if word[0] == 'Q': # TODO: hack
            entities.append(word)
    return entities

In [140]:
question_text = 'What is the name of the first team that Q1487425 was part of'
entities = getEntities(question_text)
times = []
dataPoint = convertToDataPoint(question_text, entities, times)
data = [dataPoint]
valid_dataset.data = data
valid_dataset.prepared_data = valid_dataset.prepare_data_(data)
# print(valid_dataset[:100])
predict(qa_model, valid_dataset)

Evaluating split valid


100%|██████████| 1/1 [00:07<00:00,  7.05s/batches]

The team which Martin Peters played for in 1972 was
The team which Q505028 played for in 1972 was
Predicted: Kingdom of Hungary (0.08478637039661407), Kingdom of France (0.032511934638023376), Sheffield United F.C. (0.0291510708630085), Ottoman Empire (0.014648595824837685), Kingdom of Bohemia (0.009740407578647137)
Actual: Bobby Clarke






In [122]:
id2ent['Q1543']

'A.C. Milan'

In [123]:
original_dataset.data[7:10]

[{'question': 'When did Q5220937 play their first game',
  'answers': {1997},
  'answer_type': 'time',
  'template': 'When did {head} play their first game',
  'entities': {'Q5220937'},
  'times': set(),
  'relations': {'P54'},
  'type': 'first_last',
  'annotation': {'head': 'Q5220937', 'adj': 'first'},
  'uniq_id': 3015,
  'paraphrases': ['When did Danny Williams play their first game']},
 {'question': "Who last held Q30524718's position",
  'answers': {'Q10652',
   'Q14948811',
   'Q153454',
   'Q16190632',
   'Q16190712',
   'Q16934040',
   'Q17180679',
   'Q17306267',
   'Q180589',
   'Q18211142',
   'Q18921442',
   'Q189947',
   'Q197894',
   'Q19831',
   'Q19871785',
   'Q19871819',
   'Q19871931',
   'Q19874405',
   'Q19874642',
   'Q19880278',
   'Q19882875',
   'Q19882967',
   'Q19883117',
   'Q19956858',
   'Q19957769',
   'Q19957802',
   'Q19957992',
   'Q19979355',
   'Q19979409',
   'Q20054083',
   'Q20055561',
   'Q20127929',
   'Q20128115',
   'Q20647738',
   'Q20647740

In [124]:
valid_dataset.prepared_data

{'question_text': ['What year was Andrea Russotto playing in S.S.C. Napoli'],
 'head': [43265],
 'tail': [47556],
 'time': [125726],
 'answers_arr': [[120143]]}

In [125]:
original_dataset.prepared_data.keys()

dict_keys(['question_text', 'head', 'tail', 'time', 'answers_arr'])

In [126]:
original_dataset.data[0]

{'question': 'What was the award that was awarded to Q980677 for the first time ever',
 'answers': {'Q1967210'},
 'answer_type': 'entity',
 'template': 'What was the award that was awarded to {head} for the first time ever',
 'entities': {'Q980677'},
 'times': set(),
 'relations': {'P166'},
 'type': 'first_last',
 'annotation': {'head': 'Q980677', 'adj': 'first'},
 'uniq_id': 28188,
 'paraphrases': ['What was the award that was awarded to Lage Raho Munna Bhai for the first time ever']}

In [127]:
original_dataset.prepared_data['head'][0]

124930

In [128]:
original_dataset.prepared_data['tail'][0]

124930

In [129]:
original_dataset.prepared_data['time'][0]

125726

In [130]:
original_dataset.prepared_data['answers_arr'][0]

[32445]