In [1]:
import gensim
from nltk import word_tokenize
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict


# Function to get the cosine similarity between a relation and query
# Note: Be sure to prepend the relation with ns:
word2vec_model = gensim.models.Word2Vec.load('word2vec_train_dev.dat')
def get_rel_score_word2vecbase(model, rel, query):
    if rel not in word2vec_model.wv:
        return 0.0
    words = word_tokenize(query.lower())
    w_embs = []
    for w in words:
        if w in word2vec_model.wv:
            w_embs.append(word2vec_model.wv[w])
    return np.mean(cosine_similarity(w_embs, [word2vec_model.wv[rel]]))


# Function to load the graph from file
def load_graph(file):
    # Preparing the graph
    graph = defaultdict(list)
    for line in open(file):
        line = eval(line[:-1])
        graph[line[0]].append([line[1], line[2]])
    return graph


# Function to load the queries from file
# Preparing the queries
def load_queries(file):
    queries = []
    for line in open(file):
        line = eval(line[:-1])
        queries.append(line)
    return queries

In [2]:
import gensim
from collections import Counter
# from utils import load_graph, load_queries, get_rel_score_word2vecbase


def find_answer_bfs(model, graph, query, root, theta):
    model = model
    answers = set()
    visited = []
    queue = [root]
    query = query
    while len(queue) != 0:
        first_out = queue.pop(0)
        visited.append(first_out)
        # print('first out : ', first_out)
        neighbours_list = graph.get(first_out[0].strip('ns:'))
        if neighbours_list is not None:
            for neighbour in neighbours_list:
                relation = 'ns:' + neighbour[0]
                # print(relation)
                relevance_score = get_rel_score_word2vecbase(model, relation, query)
                # print(relevance_score)
                if neighbour[1] not in answers and neighbour[1] not in visited and relevance_score > theta:
                    # print('neighbours : ', neighbours_list)
                    queue.append(neighbour[1])
        else:
            answers.add(first_out)
            # print('answer node: ', first_out)
    return answers


def get_topic_tag(topic_code, graph):
    for out_nodes in graph.values():
        for node in out_nodes:
            if node[1] == topic_code:
                return node
    print('Can\'t retrieve node from graph')
    return ''


# Python program to illustrate the intersection
# of two lists in most simple way
def intersection(lst1, lst2):
    lst3 = [value for value in lst1 if value in lst2]
    return lst3


if __name__ == '__main__':
    graph = load_graph('graph.txt');
    queries = load_queries('annotations.txt')
    # for k, v in graph.items():
    #     print('source : ', k,  ' | destinations : ', v)
    word2vec_model = gensim.models.Word2Vec.load('word2vec_train_dev.dat')

    overlap_len = 0
    my_estimate_pool_len = 0
    ground_truth_pool_len = 0
    for query in queries:
        my_estimate_pool = []
        ground_truth_pool = []

        root = query[3][0]
        question = query[1]
        print('\n' + question)
        print('root : ', root)
        my_answers = find_answer_bfs(word2vec_model, graph, question, root, 0.3)
        print('My answers : ', my_answers)
        my_estimate_pool.extend(my_answers)
        real_answer = [q.get('AnswerArgument') for q in query[5]]
        print('Real answers : ', real_answer, '\n')
        ground_truth_pool.extend(real_answer)

        overlap_len += len(intersection(my_estimate_pool, ground_truth_pool))
        my_estimate_pool_len += len(my_estimate_pool)
        ground_truth_pool_len += len(ground_truth_pool)

    precision = overlap_len / my_estimate_pool_len
    print('Precision : ', precision)
    recall = overlap_len / ground_truth_pool_len
    print('Recall : ', recall)
    f1_score = 2 * ((precision * recall) / (precision + recall))
    print('F1_score : ', f1_score)


what time zones are there in the us
root :  ['ns:m.09c7w0', 'ns:location.location.time_zones', '?x']
My answers :  {'m.02lcrv', 'm.027wjl3', 'm.027wj2_', 'm.02hcv8', 'm.02fqwt', 'm.042g7t', 'm.02hczc', 'm.02lcqs', 'm.02lctm'}
Real answers :  ['m.027wj2_', 'm.027wjl3', 'm.02fqwt', 'm.02hcv8', 'm.02hczc', 'm.02lcqs', 'm.02lcrv', 'm.02lctm', 'm.042g7t'] 


what are major exports of the usa
root :  ['ns:m.09c7w0', 'ns:location.statistical_region.major_exports', '?y']
My answers :  {'m.04g4s8w', 'm.04g4s8q', 'm.04g4s90', 'm.04g4s8k'}
Real answers :  ['m.015smg', 'm.03q9wp2', 'm.03qtd_n', 'm.03qtf10'] 


what time is right now in texas
root :  ['ns:m.07b_l', 'ns:location.location.time_zones', '?x']
My answers :  {'m.09c7w0', 'm.02fqwt', 'm.02hczc'}
Real answers :  ['m.02fqwt', 'm.02hczc'] 


what war was george washington associated with
root :  ['ns:m.034rd', 'ns:military.military_commander.military_commands', '?y']
My answers :  {'m.04yykw2', 'm.064xssc', 'm.049xrhp', 'm.04yvvkf', 'm.049y