In [1]:
import os
import pickle
from collections import defaultdict

from tqdm.notebook import tqdm


In [2]:
# Root data directory; only used to build the 3-hop question path a few cells down.
# Set the KGC_DATA_DIR environment variable to override the hard-coded
# machine-specific default instead of editing this cell.
folder = os.environ.get('KGC_DATA_DIR', '/scratche/home/apoorv/transformer-kgc/data')

In [3]:
# Subgraph directory under `folder`; combined with `folder` to locate the 3-hop test questions.
subgraph_split = 'MetaQA_half'

In [4]:
def readTriples(fname):
    """Read (head, relation, tail) triples from a KGC-formatted text file.

    Kept lines look like ``predict tail: <head> | <rel> |\\t<tail>`` (the
    spacing around ``|`` may vary). Lines starting with ``predict head:`` are
    skipped, and so are triples whose relation is ``noop``.

    Args:
        fname: path of the file to read.

    Returns:
        list of ``[head, rel, tail]`` string lists, in file order.

    Raises:
        IndexError if a line has no tab-separated tail; ValueError if the
        head/relation part does not split into exactly three ``|`` fields.
    """
    triples = []
    # `with` guarantees the handle is closed even when a malformed line raises.
    with open(fname) as f:
        for line in f:
            if line.startswith('predict head:'):
                continue
            line = line.rstrip().replace('predict tail: ', '')
            fields = line.split('\t')
            head_rel, tail = fields[0], fields[1]
            # normalise ' | ' and ' |' separators to a bare '|'
            head_rel = head_rel.replace(' | ', '|').replace(' |', '|')
            head, rel, _ = head_rel.split('|')
            if rel != 'noop':
                triples.append([head, rel, tail])
    return triples

def getEntities(triples):
    """Collect every entity that occurs as a head or a tail.

    Args:
        triples: iterable of ``[head, rel, tail]`` sequences.

    Returns:
        set of head and tail values; relations are ignored.
    """
    return {triple[pos] for triple in triples for pos in (0, 2)}

def readLines(fname):
    """Return the lines of `fname` with trailing whitespace stripped.

    Args:
        fname: path of a text file.

    Returns:
        list of str, one entry per line; blank lines are kept as ''.
    """
    # The handle used to be left open; `with` closes it deterministically.
    with open(fname) as f:
        return [line.rstrip() for line in f]

In [5]:
# Names of the available subgraph splits.
# NOTE(review): not referenced by any visible cell; presumably kept for looping
# over splits — confirm or remove.
subgraph_splits = ['MetaQA_half', 'MetaQA_half2', 'MetaQA_half3', 'MetaQA_half_allents',
                  'MetaQA_half_allents2', 'MetaQA_half_allents3']

In [6]:
def getHeadFromQuestion(q):
    """Return the topic entity written in square brackets inside question `q`.

    Takes the text after the first '[' up to the next '[' or ']', e.g.
    'what movies did [Tom Hanks] star in' -> 'Tom Hanks'. Raises IndexError
    when `q` contains no '['.
    """
    return q.split('[')[1].split(']')[0]

def readQuestions(fname):
    """Read question/answer pairs from a tab-separated file.

    Each line is ``<question containing [head]>\\t<ans1>|<ans2>|...``.

    Args:
        fname: path of the questions file.

    Returns:
        list of ``(head, answers)`` tuples in file order, where `head` is the
        bracketed entity extracted by `getHeadFromQuestion` and `answers` is a
        list of str.

    Raises:
        ValueError if a line does not contain exactly one tab.
    """
    head_answer_pairs = []
    # `with` closes the file; previously the handle was never closed.
    with open(fname) as f:
        for line in f:
            question, answer_string = line.rstrip().split('\t')
            head_answer_pairs.append((getHeadFromQuestion(question), answer_string.split('|')))
    return head_answer_pairs
        

In [117]:
# Load the 3-hop test questions of `subgraph_split` as (head, answers) pairs.
# NOTE(review): both names are reassigned by the 1-hop loading cell further
# down, so this result is only inspected interactively.
q_file_name = os.path.join(folder, subgraph_split, 'qa_test_3hop.txt')
ha_pairs = readQuestions(q_file_name)

In [118]:
# get templates for questions
# then from templates get reasoning paths
# need to do this to get number of answerable questions for 2,3 hop
# just seeing reachability not enough for 2,3 hop
import copy

def split_qtype(qtype):
    """Split a question-type label on '_to_' into its chain of entity types.

    Example: 'actor_to_movie_to_actor' -> ['actor', 'movie', 'actor'].
    """
    return qtype.split('_to_')

def relation_from_ent_type(ent):
    """Map an entity type to the forward KG relation that reaches it from a movie.

    'movie' itself and any unrecognised type map to the sentinel 'error'.
    """
    relation_for_type = dict(
        director='directed by',
        actor='starred actors',
        genre='has genre',
        language='in language',
        tags='has tags',
        tag='has tags',
        writer='written by',
        year='release year',
    )
    # 'movie' is deliberately absent: there is no movie -> movie relation
    return relation_for_type[ent] if ent in relation_for_type else 'error'

def qtype_to_path(qtype):
    """Translate a question-type label into the sequence of KG relations to follow.

    For each consecutive pair ``A_to_B``:
      * if ``A`` is 'movie', use the forward relation named after ``B``;
      * otherwise use the reverse of the relation named after ``A``
        (i.e. walk back from ``A`` to the movie that points at it).

    Example: 'actor_to_movie' -> ['starred actors reverse'].

    Returns:
        list of relation strings; [] when `qtype` has fewer than two types.
    """
    ent_types = split_qtype(qtype)
    path = []
    # strings are immutable, so the per-hop values need no copying
    for start_ent, end_ent in zip(ent_types, ent_types[1:]):
        if start_ent == 'movie':
            path.append(relation_from_ent_type(end_ent))
        else:
            path.append(relation_from_ent_type(start_ent) + ' reverse')
    return path
        
def follow_path_1hop(head, graph, rel):
    """Return the neighbours of `head` along relation `rel`.

    `graph` is a nested mapping ``graph[entity][relation] -> neighbours``.
    The stored object itself (not a copy) is returned; a missing entity or
    relation raises KeyError.
    """
    relations_of_head = graph[head]
    return relations_of_head[rel]

# NOTE: the start entity is never returned; this is fine because a question's
# head entity is never one of its answers.
def follow_path(head, graph, path):
    """Return the entities reachable from `head` by following `path`.

    Starting from ``{head}``, each relation in `path` replaces the current
    frontier with the union of its neighbours along that relation (via
    `follow_path_1hop`). The final frontier is returned with `head` removed.

    Args:
        head: start entity (a key of `graph`).
        graph: nested mapping ``graph[entity][relation] -> iterable of entities``.
        path: sequence of relation names; an empty path yields an empty set.

    Returns:
        a new set of reached entities (never aliases a set stored in `graph`).

    Raises:
        KeyError if a frontier entity or a relation is missing from `graph`.
    """
    frontier = {head}
    for rel in path:
        next_frontier = set()
        for ent in frontier:
            next_frontier.update(follow_path_1hop(ent, graph, rel))
        frontier = next_frontier
    # drop the head if the path loops back to it
    frontier.discard(head)
    return frontier

In [119]:
# Load the training triples of the MetaQA_half_allents split; the trailing
# expression displays the number of triples kept.
triples_file = 'data/MetaQA_half_allents/train.txt'
train = readTriples(triples_file)
len(train)

70844

In [120]:
# NOTE(review): `questions_folder` is not referenced by any later cell.
questions_folder = 'data/MetaQA_half'
# show one parsed triple: [head, relation, tail]
train[0]

['Lawless Heart', 'directed by', 'Tom Hunsinger']

In [121]:
# distinct entities vs. number of triples in the training subgraph
len(getEntities(train)), len(train)

(43234, 70844)

In [122]:
# Entity set of the full MetaQA KG; the graph cell allocates a node for every
# entity in this set, even those absent from `train`.
fname = os.path.join('data/MetaQA', 'train.txt')
full_triples = readTriples(fname)
entities = getEntities(full_triples)
len(entities)

43234

In [123]:
# NOTE(review): identical to the earlier cell that loads `train`; re-running
# it is redundant.
triples_file = 'data/MetaQA_half_allents/train.txt'
train = readTriples(triples_file)
len(train)

70844

In [124]:
# Every relation name that occurs in the full KG (forward direction only).
relations = {triple[1] for triple in full_triples}

In [125]:
# Build the relational graph by hand:
#   graph[entity][relation] -> set of neighbour entities
# Every relation in `relations` also gets a '<rel> reverse' counterpart.
# NOTE(review): `unidecode` is not imported in any visible cell; it must come
# from a `from unidecode import unidecode` executed earlier in the session.
triples_for_graph = train


def _graph_key(ent):
    """Normalise an entity name into a graph node key (ASCII-folded, right-stripped)."""
    return unidecode(ent).rstrip()


graph = {}
for e in entities:
    key = _graph_key(e)
    graph[key] = {}
    for r in relations:
        graph[key][r] = set()
        graph[key][r + ' reverse'] = set()

for head, rel, tail in triples_for_graph:
    # Node keys were created normalised above, so they must be looked up
    # normalised too (raw lookups raise KeyError for any entity that
    # unidecode/rstrip changes). Stored neighbours stay raw so they still
    # compare equal to the raw answer strings.
    # NOTE(review): multi-hop paths use these raw neighbours as keys for the
    # next hop — TODO confirm the data contains no non-ASCII intermediates.
    graph[_graph_key(head)][rel].add(tail)
    graph[_graph_key(tail)][rel + ' reverse'].add(head)
    

In [126]:
# sanity check: number of distinct entities in the triples used to build the graph
len(getEntities(triples_for_graph))

43234

In [145]:
# Load the per-question type labels and the matching question/answer file for
# one hop count and split. The two files are aligned line by line.
qtype_folder = 'data/metaqa_qtype'
hops = '1'
qtype_split = 'test'
qtype_file = os.path.join(qtype_folder, hops + '-hop', 'qa_' + qtype_split + '_qtype.txt')
qtypes = readLines(qtype_file)

q_file_name = os.path.join('data/MetaQA_half', 'qa_{}_{}hop.txt'.format(qtype_split, hops))
ha_pairs = readQuestions(q_file_name)

# Labels are paired with questions purely by position; a length mismatch would
# silently shift every per-type statistic computed later.
assert len(qtypes) == len(ha_pairs), (len(qtypes), len(ha_pairs))

# (head, answers, qtype) triples; heads are normalised like the graph's node keys.
haq_triples = [
    (unidecode(head).rstrip(), answers, qt)
    for (head, answers), qt in zip(ha_pairs, qtypes)
]

In [146]:
# Inspect one example: its (head, answers, qtype) triple and the relation path
# implied by the qtype. `example_idx` replaces `id`, which shadowed the builtin.
example_idx = 0
haq_triples[example_idx], qtype_to_path(haq_triples[example_idx][2])

(('Gregoire Colin', ['Before the Rain'], 'actor_to_movie'),
 ['starred actors reverse'])

In [147]:
# For each question, check whether the training subgraph alone can answer it:
# follow the qtype's relation path from the head and compare the reached
# entities with the gold answers. `defaultdict` is imported at the top.
count = 0
qtype_answerable_dict = defaultdict(list)
for h, a, qt in haq_triples:
    path = qtype_to_path(qt)
    h = unidecode(h).rstrip()
    results = follow_path(h, graph, path)
    gold = set(a)
    # Answerable means the path reaches at least one gold answer AND does not
    # reach more entities than there are gold answers (sanity check against
    # over-generation).
    is_answerable = int(bool(results & gold) and len(results) <= len(gold))
    count += is_answerable
    qtype_answerable_dict[qt].append(is_answerable)
count, count / len(haq_triples)

(6739, 0.677490700713783)

In [148]:
# Per qtype: fraction answerable from the graph, and share of all questions.
for qt_label, flags in qtype_answerable_dict.items():
    answerable_rate = round(sum(flags) / len(flags), 2)
    share_of_questions = round(len(flags) / len(haq_triples), 2)
    print(qt_label, answerable_rate, share_of_questions)

actor_to_movie 0.96 0.09
director_to_movie 0.84 0.06
movie_to_actor 0.79 0.11
movie_to_director 0.52 0.13
movie_to_genre 0.48 0.11
movie_to_language 0.49 0.03
movie_to_tags 0.72 0.09
movie_to_writer 0.66 0.11
movie_to_year 0.46 0.14
tag_to_movie 1.0 0.04
writer_to_movie 0.88 0.09


In [149]:
# Number of questions per qtype.
for qt_label in qtype_answerable_dict:
    n_questions = len(qtype_answerable_dict[qt_label])
    print(qt_label, n_questions)

actor_to_movie 879
director_to_movie 553
movie_to_actor 1105
movie_to_director 1301
movie_to_genre 1143
movie_to_language 294
movie_to_tags 846
movie_to_writer 1091
movie_to_year 1420
tag_to_movie 411
writer_to_movie 904


In [150]:
# get model predictions
import pickle
fname = 'scores/scores_test_1hop_half_allents.pickle'
# fname = 'scores_500_base_trie|.pickle'
scores_data = pickle.load(open(fname, 'rb'))

In [151]:
# number of scored inputs; the scoring cell pairs these with `qtypes` by position
len(scores_data['input_strings'])

9947

In [152]:
# candidate predictions for the first input
scores_data['prediction_strings'][0]

['Before the Rain',
 'Before the Rains',
 'The Princess and the Frog',
 'Between the Rains']

In [153]:
# fields stored in the scores pickle
scores_data.keys()

dict_keys(['prediction_strings', 'scores', 'target_strings', 'input_strings'])

In [154]:
# qtype label of the first question; the scoring cell assumes it belongs to the first scored input
qtypes[0]

'actor_to_movie'

In [155]:
# raw model input for the first question (format parsed by getHeadFromInputString)
scores_data['input_strings'][0]

'predict answer: Gregoire Colin | what does NE appear in |'

In [156]:
def getHeadFromInputString(s):
    """Extract the head entity from a model input string.

    Inputs look like ``'predict answer: Gregoire Colin | what does NE appear in |'``.
    The head is the text before the first '|', with a leading
    ``'predict answer: '`` prefix and any trailing whitespace removed.
    Stripping whitespace (rather than blindly dropping one character) keeps
    the last letter when there is no space before the '|'.
    """
    prefix = 'predict answer: '
    head = s.split('|')[0]
    # only strip the prefix at the start, never from inside the entity name
    if head.startswith(prefix):
        head = head[len(prefix):]
    return head.rstrip()

In [157]:
# Score the model per qtype. Take the top-scoring prediction; if it is the
# question head or not a KG entity, fall back to the runner-up.
# NOTE(review): only a single fallback is tried (an invalid runner-up is still
# used) — kept as-is so the reported accuracy stays reproducible.
n_scored = len(scores_data['input_strings'])
# zip() below would silently truncate to the shorter input, while accuracy is
# divided by len(qtypes); the two must line up one-to-one.
assert n_scored == len(qtypes), (n_scored, len(qtypes))

qtype_correct_dict = defaultdict(list)
correct = 0
for preds, scores, input_string, actuals, qt in zip(scores_data['prediction_strings'],
                                                    scores_data['scores'],
                                                    scores_data['input_strings'],
                                                    scores_data['target_strings'],
                                                    qtypes):
    # .item() turns each score into a Python number (scores presumably are
    # tensor/ndarray scalars — confirm against the scoring script)
    ps_pairs = sorted(((p, s.item()) for p, s in zip(preds, scores)),
                      key=lambda x: x[1], reverse=True)
    head = getHeadFromInputString(input_string)
    pred = ps_pairs[0][0]
    # with a single candidate there is no runner-up to fall back to
    if (pred == head or pred not in entities) and len(ps_pairs) > 1:
        pred = ps_pairs[1][0]
    is_correct = int(pred in actuals)
    correct += is_correct
    qtype_correct_dict[qt].append(is_correct)
correct / len(qtypes)

0.7305720317683724

In [158]:
# Per qtype: model accuracy, and share of all questions.
n_total = len(haq_triples)
for qt_label in qtype_correct_dict:
    flags = qtype_correct_dict[qt_label]
    print(qt_label, round(sum(flags) / len(flags), 2), round(len(flags) / n_total, 2))

actor_to_movie 0.95 0.09
director_to_movie 0.92 0.06
movie_to_actor 0.77 0.11
movie_to_director 0.64 0.13
movie_to_genre 0.63 0.11
movie_to_language 0.63 0.03
movie_to_tags 0.7 0.09
movie_to_writer 0.8 0.11
movie_to_year 0.45 0.14
tag_to_movie 0.96 0.04
writer_to_movie 0.94 0.09


In [159]:
# NOTE(review): same code as the earlier answerability-rate cell; re-printed
# only to sit next to the accuracy table.
for k, v in qtype_answerable_dict.items():
    print(k, round(sum(v)/len(v), 2), round(len(v)/len(haq_triples), 2))

actor_to_movie 0.96 0.09
director_to_movie 0.84 0.06
movie_to_actor 0.79 0.11
movie_to_director 0.52 0.13
movie_to_genre 0.48 0.11
movie_to_language 0.49 0.03
movie_to_tags 0.72 0.09
movie_to_writer 0.66 0.11
movie_to_year 0.46 0.14
tag_to_movie 1.0 0.04
writer_to_movie 0.88 0.09


In [160]:
# Summary table: per qtype, graph answerability, model accuracy, and the
# percentage of all questions that the qtype accounts for.
print('Question type    \tAnswerable\tAcc\tPct of total qn')
for qt_label in qtype_answerable_dict.keys():
    answerable_flags = qtype_answerable_dict[qt_label]
    correct_flags = qtype_correct_dict[qt_label]
    graph_answerable = sum(answerable_flags) / len(answerable_flags)
    correct_pct = sum(correct_flags) / len(correct_flags)
    pct_of_questions = len(correct_flags) / len(haq_triples) * 100
    # labels shorter than 32 chars get a fixed 4-space pad before the tab
    padded_label = qt_label + '    ' if len(qt_label) < 32 else qt_label
    print('{}\t{}\t{}\t{}'.format(padded_label,
                                  round(graph_answerable, 2),
                                  round(correct_pct, 2),
                                  round(pct_of_questions, 2)))

Question type    	Answerable	Acc	Pct of total qn
actor_to_movie    	0.96	0.95	8.84
director_to_movie    	0.84	0.92	5.56
movie_to_actor    	0.79	0.77	11.11
movie_to_director    	0.52	0.64	13.08
movie_to_genre    	0.48	0.63	11.49
movie_to_language    	0.49	0.63	2.96
movie_to_tags    	0.72	0.7	8.51
movie_to_writer    	0.66	0.8	10.97
movie_to_year    	0.46	0.45	14.28
tag_to_movie    	1.0	0.96	4.13
writer_to_movie    	0.88	0.94	9.09


In [61]:
# Scratch check (note the out-of-order execution count): length of a 3-hop
# qtype label, presumably to size the padding threshold in the summary table.
len('actor_to_movie_to_actor')

23