In [1]:
# Make BuboQA importable as a package by creating (or touching) an empty
# BuboQA/__init__.py. Pure Python instead of `!touch` so it also works where
# no `touch` binary exists (e.g. Windows).
from pathlib import Path
Path('BuboQA/__init__.py').touch()

In [44]:
# Put the BuboQA module directories on sys.path so the modules inside them
# can be imported (appended in this order, after any existing entries).
import sys
sys.path.extend([
    './BuboQA/entity_detection/nn',
    './BuboQA/entity_linking/',
    './BuboQA/relation_prediction/nn',
    './BuboQA/evidence_integration',
])

In [3]:
from BuboQA.entity_detection.nn.args import get_args

In [63]:
import math
import os
import pickle
import random
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torchtext import data
from tqdm import tqdm

from BuboQA.entity_detection.nn.args import get_args
from BuboQA.entity_detection.nn.evaluation import evaluation
from BuboQA.entity_detection.nn.sq_entity_dataset import SQdataset
from BuboQA.relation_prediction.nn.sq_relation_dataset import SQdataset as relSQdataset
from BuboQA.entity_linking import entity_linking
from BuboQA.relation_prediction.nn import relation_prediction as rel_rp
from BuboQA.evidence_integration import evidence_integration as ei
from BuboQA.evidence_integration.util import clean_uri, processed_text, www2fb, rdf2fb


In [5]:
# Print NumPy arrays in full instead of summarising them with "...".
# sys.maxsize is the supported way to disable truncation; recent NumPy
# releases reject threshold=np.nan with a ValueError.
np.set_printoptions(threshold=sys.maxsize)

In [67]:
class Args:
    """Hard-coded pipeline configuration, used in place of BuboQA's ``get_args()``.

    Attributes are grouped by pipeline stage. Every stage reads its own
    attributes from the module-level ``args`` instance below.
    """
    # --- shared ---
    seed = 42
    cuda = True
    gpu = 0
    batch_size = 32
    data_dir = './BuboQA/data/processed_simplequestions_dataset'

    # --- entity detection ---
    dataset = 'EntityDetection'
    entity_detection_mode = 'LSTM'
    trained_model = './BuboQA/entity_detection/nn/saved_checkpoints/lstm/id1_best_model.pt'
    results_path = './BuboQA/entity_detection/nn/query_text'

    # --- relation prediction ---
    rel_dataset = 'RelationPrediction'
    relation_prediction_mode = 'CNN'
    rel_trained_model = './BuboQA/relation_prediction/nn/saved_checkpoints/cnn/id1_best_model.pt'
    rel_results_path = './BuboQA/results'

    # --- entity linking ---
    model_type = 'lstm'
    index_ent = 'BuboQA/indexes/entity_2M.pkl'
    query_dir = 'BuboQA/entity_detection/nn/query_text/lstm'
    hits = 100
    el_output_dir = 'BuboQA/entity_linking/results'

    # --- evidence integration ---
    ent_type = 'lstm'
    rel_type = 'cnn'
    index_reachpath = 'BuboQA/indexes/reachability_2M.pkl'
    index_degreespath = 'BuboQA/indexes/degrees_2M.pkl'
    data_path = 'BuboQA/data/processed_simplequestions_dataset/amt_test.txt'
    # NOTE(review): ent_path / rel_path name the standard "test" result files
    # (test-h100.txt / test.txt), while this notebook writes amt_test outputs
    # (rel_predict writes "<data_name>.txt"). Confirm these are the intended inputs.
    ent_path = 'BuboQA/entity_linking/results/lstm/test-h100.txt'
    rel_path = 'BuboQA/relation_prediction/nn/results/cnn/test.txt'
    wiki_path = 'BuboQA/data/fb2w.nt'
    hits_ent = 50
    hits_rel = 5
    heuristics = True
    ei_output_dir = 'BuboQA/evidence_integration/results'

    # ``output_dir`` used to be assigned twice (entity linking, then evidence
    # integration) and only the evidence-integration value ever took effect.
    # It is kept as that alias for existing readers. Entity-linking output
    # should use ``el_output_dir``.
    output_dir = ei_output_dir


args = Args()

In [7]:
# Seed Python's, NumPy's and torch's (CPU) random generators so runs are
# reproducible. The seeds are independent, so their order does not matter.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [8]:
# A gpu id of -1 marks "CPU only" when CUDA is disabled in the config.
if not args.cuda:
    args.gpu = -1

if torch.cuda.is_available():
    if args.cuda:
        print("Note: You are using GPU for training")
        torch.cuda.manual_seed(args.seed)
    else:
        print("Warning: You have Cuda but not use it. You are using CPU for training.")

# torchtext fields: lower-cased question tokens, entity-detection tags, and a
# single non-sequential relation label per question.
TEXT = data.Field(lower=True)
ED = data.Field()
RELATION = data.Field(sequential=False)

Note: You are using GPU for training


In [9]:
# Load the pickled TEXT / ED torchtext fields (with their vocabularies),
# replacing the fresh fields defined above.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('BuboQA/entity_detection/nn/TEXT.pickle', 'rb') as text_pickle:
    TEXT = pickle.load(text_pickle)
with open('BuboQA/entity_detection/nn/ED.pickle', 'rb') as ed_pickle:
    ED = pickle.load(ed_pickle)

train, dev, test = SQdataset.splits(TEXT, ED, path=args.data_dir, test='amt_test.txt')
# This can also be loaded via pickle. But for now keeping it static.
# NOTE(review): these are the entity-detection splits; RELATION's vocabulary is
# rebuilt from the relation-prediction splits before rel_predict runs.
RELATION.build_vocab(train, dev)

In [10]:
# Vocabularies come from the pickled TEXT/ED fields loaded above; the build_vocab calls below are kept for reference only.

# TEXT.build_vocab(train, dev, test)
# ED.build_vocab(train, dev, test)

In [11]:
# One-pass, unsorted iterators over the entity-detection splits. Only the
# training iterator shuffles; dev/test keep file order so predictions line up
# with the line-id files.
ed_iter_kwargs = dict(batch_size=args.batch_size, device="cuda", repeat=False, sort=False)
train_iter = data.Iterator(train, train=True, shuffle=True, **ed_iter_kwargs)
dev_iter = data.Iterator(dev, train=False, shuffle=False, **ed_iter_kwargs)
test_iter = data.Iterator(test, train=False, shuffle=False, **ed_iter_kwargs)

In [12]:

def _load_on_gpu(checkpoint_path):
    """Load a whole pickled model, moving every storage onto GPU ``args.gpu``.

    NOTE: torch.load unpickles arbitrary objects; only use trusted checkpoints.
    """
    return torch.load(checkpoint_path,
                      map_location=lambda storage, location: storage.cuda(args.gpu))


model = _load_on_gpu(args.trained_model)
rel_model = _load_on_gpu(args.rel_trained_model)
print(model)
print(rel_model)

EntityDetection(
  (embed): Embedding(61332, 300)
  (lstm): LSTM(300, 300, num_layers=2, dropout=0.3, bidirectional=True)
  (dropout): Dropout(p=0.3)
  (relu): ReLU()
  (hidden2tag): Sequential(
    (0): Linear(in_features=600, out_features=600, bias=True)
    (1): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3)
    (4): Linear(in_features=600, out_features=7, bias=True)
  )
)
RelationPrediction(
  (embed): Embedding(61332, 300)
  (conv1): Conv2d(1, 300, kernel_size=(2, 300), stride=(1, 1), padding=(1, 0))
  (conv2): Conv2d(1, 300, kernel_size=(3, 300), stride=(1, 1), padding=(2, 0))
  (conv3): Conv2d(1, 300, kernel_size=(4, 300), stride=(1, 1), padding=(3, 0))
  (dropout): Dropout(p=0.5)
  (fc1): Linear(in_features=900, out_features=1698, bias=True)
)


In [13]:
# Only the entity-detection dataset is supported here. Raise instead of
# calling exit(), which would shut down the notebook kernel.
if args.dataset != 'EntityDetection':
    raise ValueError("Wrong Dataset: {}".format(args.dataset))

# index -> string lookup tables used to decode model output.
index2word = np.array(TEXT.vocab.itos)
index2tag = np.array(ED.vocab.itos)

# Output directories; makedirs(exist_ok=True) is a no-op if they already exist.
results_path = os.path.join(args.results_path, args.entity_detection_mode.lower())
os.makedirs(results_path, exist_ok=True)

rel_results_path = os.path.join(args.rel_results_path, args.relation_prediction_mode.lower())
os.makedirs(rel_results_path, exist_ok=True)

In [14]:
def convert(fileName, idFile, outputFile):
    """Turn tagged predictions into per-question entity query strings.

    Args:
        fileName: file whose lines are ``question<TAB>predicted_tags<TAB>...``
            (the format written by ``predict``).
        idFile: file with one question id per line, aligned with ``fileName``.
        outputFile: destination. Each line is ``id %%%% query1 %%%% query2 ...``.

    Consecutive tokens tagged ``I`` form one query span, and a token tagged
    ``O`` closes the current span. Tokens with any other tag are ignored.
    ``<pad>`` tokens are dropped. If a question yields no span, the whole
    (unpadded) question is used as its single query.
    """
    def drop_pad(tokens):
        return [token for token in tokens if token != '<pad>']

    # Context managers guarantee the output is flushed and closed before the
    # entity-linking step reads it.
    with open(fileName) as fin, open(idFile) as fid, open(outputFile, "w") as fout:
        # NOTE: zip stops at the shorter file; both files are expected to hold
        # one line per question of the same split.
        for line, line_id in tqdm(zip(fin, fid)):
            fields = line.strip().split('\t')
            sent = fields[0].strip().split()
            pred = fields[1].strip().split()

            query_list = []
            query_text = []
            for token, label in zip(sent, pred):
                if label == 'I':
                    query_text.append(token)
                elif label == 'O':
                    span = drop_pad(query_text)
                    if span:
                        query_list.append(" ".join(span))
                    query_text = []

            # Flush a span that runs to the end of the question.
            span = drop_pad(query_text)
            if span:
                query_list.append(" ".join(span))

            if not query_list:
                # No entity predicted: fall back to the whole question.
                query_list.append(" ".join(drop_pad(sent)))

            fout.write(" %%%% ".join([line_id.strip()] + query_list) + "\n")


def predict(dataset_iter=test_iter, dataset=test, data_name="test"):
    """Run the entity-detection model on a split and write entity-linking queries.

    Tags every question token, prints span-level precision/recall/F1 against
    the gold tags (via ``evaluation``), and converts the predicted ``I`` spans
    into queries written to ``<results_path>/query.<data_name>`` (see ``convert``).

    Args:
        dataset_iter: non-shuffling torchtext iterator whose batches carry
            ``text`` and ``ed`` fields. Order must match
            ``lineids_<data_name>.txt``.
        dataset: unused; kept for backward compatibility.
        data_name: split name; selects the line-id file and the output name.

    Raises:
        ValueError: if ``args.dataset`` is not ``'EntityDetection'``.
    """
    if args.dataset != 'EntityDetection':
        raise ValueError("Wrong Dataset: {}".format(args.dataset))

    print("Dataset: {}".format(data_name))
    # Decode tags with the entity-detection vocabulary itself. The global
    # ``index2tag`` is rebound to relation labels by a later cell.
    tag_names = np.array(ED.vocab.itos)
    model.eval()
    dataset_iter.init_epoch()

    temp_file = 'tmp{}.txt'.format(data_name)
    gold_list = []
    pred_list = []
    try:
        with torch.no_grad(), open(temp_file, 'w') as results_file:
            for data_batch in dataset_iter:
                scores = model(data_batch)
                # Best tag id per token, reshaped to the layout of the gold tags.
                pred_ids = torch.max(scores, 1)[1].view(data_batch.ed.size())
                # Transpose so each row is one question (assumes torchtext's
                # default (seq_len, batch) layout -- TODO confirm).
                pred_rows = np.transpose(pred_ids.cpu().numpy())
                gold_rows = np.transpose(data_batch.ed.cpu().numpy())
                question_rows = index2word[np.transpose(data_batch.text.cpu().numpy())]
                pred_list.append(pred_rows)
                gold_list.append(gold_rows)
                for question, label, gold in zip(question_rows, tag_names[pred_rows], tag_names[gold_rows]):
                    results_file.write("{}\t{}\t{}\n".format(" ".join(question), " ".join(label), " ".join(gold)))

        P, R, F = evaluation(gold_list, pred_list, tag_names, type=False)
        print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
            data_name, 100. * P, 100. * R, 100. * F))
        convert(temp_file,
                os.path.join(args.data_dir, "lineids_{}.txt".format(data_name)),
                os.path.join(results_path, "query.{}".format(data_name)))
    finally:
        # Never leave the intermediate tagged file behind, even on failure.
        if os.path.exists(temp_file):
            os.remove(temp_file)

##### Running the entity detection module

In [15]:
# Tag the AMT test questions with the entity-detection model; predict() writes
# the resulting queries to <results_path>/query.amt_test.
predict(data_name="amt_test", dataset=test, dataset_iter=test_iter)

Dataset: amt_test


  outputs, (ht, ct) = self.lstm(x)
4487it [00:00, 39712.85it/s]

Dev Precision:  92.333456% Recall:  93.029168% F1 Score:  92.680006%


21687it [00:00, 68986.60it/s]


#### Running the entity linking module

In [16]:
# Entity-linking output belongs under the entity-linking results root, the
# tree args.ent_path (read by evidence integration) points into.
# args.output_dir is the evidence-integration results directory
# ('BuboQA/evidence_integration/results'), so it must not be used here.
model_type = args.model_type.lower()
el_results_root = getattr(args, 'el_output_dir', 'BuboQA/entity_linking/results')
output_dir = os.path.join(el_results_root, model_type)
os.makedirs(output_dir, exist_ok=True)
entity_linking.get_stat_inverted_index(args.index_ent)

Total type of text: 4796519
Max Length of entry is 249631, text is ,


In [17]:
# Link the detected entity queries of the AMT test split against the entity
# index, keeping the top args.hits candidates per question.
el_query_file = os.path.join(args.query_dir, "query.amt_test")
el_gold_file = os.path.join(args.data_dir, "amt_test.txt")
el_output_file = os.path.join(output_dir, "test-h{}.txt".format(args.hits))
entity_linking.entity_linking("amt_test", el_query_file, el_gold_file, args.hits, el_output_file)

0it [00:00, ?it/s]

Source : BuboQA/entity_detection/nn/query_text/lstm/query.amt_test


21687it [01:17, 279.75it/s]

amt_test
Top1 Entity Linking Accuracy: 0.6616406141928344
Top3 Entity Linking Accuracy: 0.7760409461889611
Top5 Entity Linking Accuracy: 0.8102088808963895
Top10 Entity Linking Accuracy: 0.8484806566145617
Top20 Entity Linking Accuracy: 0.8759625582145986
Top50 Entity Linking Accuracy: 0.9026144694978558
Top100 Entity Linking Accuracy: 0.9197214921381472





#### Running the relation prediction module

In [40]:
def rel_predict(dataset_iter=None, dataset=None, data_name="test", result_path='./BuboQA/result'):
    """Score relations for every question and write the top-k candidates.

    Writes ``<result_path>/<data_name>.txt`` with one line per (question,
    candidate): ``lineid %%%% relation %%%% label %%%% score``. ``label`` is 1
    when the candidate equals the gold relation, otherwise 0. Candidates of a
    question are written best-first. Prints top-1 accuracy ("Precision") and
    the share of questions whose gold relation is among the candidates.

    Args:
        dataset_iter: non-shuffling iterator over the relation-prediction split.
            Defaults to the global ``test_iter`` looked up at call time, so the
            default no longer captures the entity-detection iterator that
            existed when this cell was defined.
        dataset: the split being scored. Defaults to the global ``test``.
        data_name: split name; selects the line-id file and the output name.
        result_path: output directory (created if missing). Previously ignored.

    Raises:
        ValueError: if ``args.rel_dataset`` is not ``'RelationPrediction'``.
    """
    if args.rel_dataset != 'RelationPrediction':
        raise ValueError("Wrong Dataset: {}".format(args.rel_dataset))
    if dataset_iter is None:
        dataset_iter = test_iter
    if dataset is None:
        dataset = test

    print("Dataset: {}".format(data_name))
    # Decode with the relation vocabulary directly instead of the shared global index2tag.
    relation_names = np.array(RELATION.vocab.itos)
    rel_model.eval()
    dataset_iter.init_epoch()

    with open(os.path.join(args.data_dir, "lineids_{}.txt".format(data_name))) as fid:
        sent_id = [x.strip() for x in fid]

    os.makedirs(result_path, exist_ok=True)
    n_correct = 0
    n_retrieved = 0
    offset = 0  # dataset index of the first example in the current batch
    with torch.no_grad(), open(os.path.join(result_path, "{}.txt".format(data_name)), 'w') as results_file:
        for data_batch in dataset_iter:
            scores = rel_model(data_batch)
            n_correct += torch.sum(
                torch.max(scores, 1)[1].view(data_batch.relation.size()) == data_batch.relation).item()
            # args.hits may exceed the number of relation classes; topk would fail.
            k = min(args.hits, scores.size(1))
            top_k_scores, top_k_indices = torch.topk(scores, k=k, dim=1, sorted=True)
            top_k_relations = relation_names[top_k_indices.cpu().numpy()]
            top_k_scores = top_k_scores.cpu().numpy()
            for i, (relations_row, scores_row) in enumerate(zip(top_k_relations, top_k_scores)):
                index = offset + i
                example = data_batch.dataset.examples[index]
                for rel, score in zip(relations_row, scores_row):
                    label = 1 if rel == example.relation else 0
                    n_retrieved += label
                    results_file.write(
                        "{} %%%% {} %%%% {} %%%% {}\n".format(sent_id[index], rel, label, score))
            offset += len(top_k_relations)

    total = len(dataset)
    if total == 0:
        print("{}: empty dataset, nothing to score".format(data_name))
        return
    P = 1. * n_correct / total
    print("{} Precision: {:10.6f}%".format(data_name, 100. * P))
    print("no. retrieved: {} out of {}".format(n_retrieved, total))
    retrieval_rate = 100. * n_retrieved / total
    print("{} Retrieval Rate {:10.6f}".format(data_name, retrieval_rate))
        

In [41]:
# Sanity check before relation prediction: rel_predict only supports this mode.
assert args.rel_dataset == 'RelationPrediction', args.rel_dataset
args.rel_dataset

'RelationPrediction'

In [42]:
# Re-split the data with the TEXT and RELATION fields for relation prediction
# and rebuild the relation vocabulary from train + dev.
train, dev, test = relSQdataset.splits(TEXT, RELATION, path=args.data_dir, test='amt_test.txt')
RELATION.build_vocab(train, dev)
# NOTE: this rebinds index2tag (previously the entity-detection tag table).
index2tag = np.array(RELATION.vocab.itos)

rel_iter_kwargs = dict(batch_size=args.batch_size, device="cuda", repeat=False, sort=False)
train_iter = data.Iterator(train, train=True, shuffle=True, **rel_iter_kwargs)
dev_iter = data.Iterator(dev, train=False, shuffle=False, **rel_iter_kwargs)
test_iter = data.Iterator(test, train=False, shuffle=False, **rel_iter_kwargs)

In [43]:
# Score relations for the AMT test questions using the relation-prediction
# iterators built in the previous cell.
rel_predict(data_name="amt_test", dataset=test, dataset_iter=test_iter)

Dataset: amt_test
amt_test Precision:  82.122931%
no. retrieved: 21390 out of 21687
amt_test Retrieval Rate  98.630516


##### Evidence integration

In [50]:
# Output directory <args.output_dir>/<ent_type>-<rel_type>, plus the indexes
# evidence integration uses to filter and rank (entity, relation) pairs.
ent_type, rel_type = args.ent_type.lower(), args.rel_type.lower()
output_dir = os.path.join(args.output_dir, "-".join((ent_type, rel_type)))
os.makedirs(output_dir, exist_ok=True)

index_reach = ei.load_index(args.index_reachpath)
index_degrees = ei.load_index(args.index_degreespath)
mid2wiki = ei.get_mid2wiki(args.wiki_path)

Loading index map from BuboQA/indexes/reachability_2M.pkl
Loading index map from BuboQA/indexes/degrees_2M.pkl
Loading Wiki


In [64]:
# Load predicted MIDs and relations for each question in valid/test set
def get_mids(filename, hits):
    """Load the top ``hits`` candidate entities for each question.

    Each line of ``filename`` is ``lineid %%%% cand1 %%%% cand2 ...`` where a
    candidate is ``mid<TAB>name<TAB>type<TAB>score``.

    Returns:
        defaultdict(list) mapping lineid -> [(mid, name, type, float(score)), ...].
        Lines without candidates add no entry.
    """
    print("Entity Source : {}".format(filename))
    id2mids = defaultdict(list)
    with open(filename) as fin:
        for line in fin:
            lineid, *cand_mids = line.strip().split(' %%%% ')
            for mid_entry in cand_mids[:hits]:
                mid, mid_name, mid_type, score = mid_entry.split('\t')
                id2mids[lineid].append((mid, mid_name, mid_type, float(score)))
    return id2mids

def get_rels(filename, hits):
    """Load up to ``hits`` predicted relations per question.

    Each line of ``filename`` is ``lineid %%%% relation %%%% label %%%% score``
    (the format ``rel_predict`` writes, best-first per question). Relations
    are normalised with ``www2fb``. Blank lines are skipped.

    Returns:
        defaultdict(list) mapping lineid -> [(relation, label, float(score)), ...].
    """
    print("Relation Source : {}".format(filename))
    id2rels = defaultdict(list)
    with open(filename) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            items = line.split(' %%%% ')
            lineid = items[0].strip()
            rel = www2fb(items[1].strip())
            label = items[2].strip()
            score = items[3].strip()
            # Indexing the defaultdict registers lineid even when hits == 0.
            if len(id2rels[lineid]) < hits:
                id2rels[lineid].append((rel, label, float(score)))
    return id2rels


def get_questions(filename):
    """Read gold questions from a tab-separated SimpleQuestions file.

    Uses columns 0 (line id), 1 (gold mid), 3 (gold relation) and
    5 (question text). Blank lines are skipped.

    Returns:
        (id2questions, id2goldmids): lineid -> (question, relation) and
        lineid -> gold mid.
    """
    print("getting questions ...")
    id2questions = {}
    id2goldmids = {}
    with open(filename) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            items = line.split('\t')
            lineid = items[0].strip()
            id2questions[lineid] = (items[5].strip(), items[3].strip())
            id2goldmids[lineid] = items[1].strip()
    return id2questions, id2goldmids

def get_mid2wiki(filename):
    """Collect the Freebase mids that have a Wikidata link.

    The first tab-separated column of each line is converted with
    ``rdf2fb(clean_uri(...))``. Blank lines are skipped.

    Returns:
        defaultdict(bool): mid -> True for listed mids, False for any other key.
    """
    print("Loading Wiki")
    mid2wiki = defaultdict(bool)
    with open(filename) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            sub = rdf2fb(clean_uri(line.split('\t')[0]))
            mid2wiki[sub] = True
    return mid2wiki

def evidence_integration(data_path, ent_path, rel_path, output_dir, index_reach, index_degrees, mid2wiki, is_heuristics, HITS_ENT, HITS_REL):
    """Rank candidate (entity, relation) pairs per question and report accuracy.

    Reads the gold questions (``data_path``), the top ``HITS_ENT`` linked
    entities (``ent_path``) and the top ``HITS_REL`` predicted relations
    (``rel_path``). Writes one ranked answer line per question to
    ``<output_dir>/<basename(data_path)>``, then prints the share of questions
    found and answered correctly within the top 1/2/3.

    With ``is_heuristics``, every (mid, rel) pair whose relation is in
    ``index_reach[mid]`` is scored as
    ``mid_score**0.6 * exp(rel_log_score)**0.1``. Answers are sorted
    descending by (combined score, mid_type, has-wiki, degree). Without
    heuristics, the single answer is (top mid, top relation).

    Returns:
        defaultdict(list) mapping line id to its answers. With heuristics each
        answer is ``(mid, rel, mid_name, mid_type, mid_score, rel_score,
        comb_score, has_wiki, degree)``; without, one ``(mid, rel)`` pair.
    """
    id2questions, id2goldmids = get_questions(data_path)
    id2mids = get_mids(ent_path, HITS_ENT)
    id2rels = get_rels(rel_path, HITS_REL)
    out_file = os.path.join(output_dir, os.path.basename(data_path))

    id2answers = defaultdict(list)
    found = 0
    retrieved_top1, retrieved_top2, retrieved_top3 = 0, 0, 0

    with open(out_file, 'w') as fout:
        for line_id in id2goldmids:
            # Questions lacking entity or relation candidates cannot be answered.
            if line_id not in id2mids or line_id not in id2rels:
                continue

            found += 1
            _, truth_rel = id2questions[line_id]
            truth_rel = www2fb(truth_rel)
            truth_mid = id2goldmids[line_id]
            mids = id2mids[line_id]
            rels = id2rels[line_id]

            if is_heuristics:
                answers = id2answers[line_id]
                for (mid, mid_name, mid_type, mid_score) in mids:
                    for (rel, _rel_label, rel_log_score) in rels:
                        # Keep only pairs whose relation is reachable from this mid.
                        if rel in index_reach[mid]:
                            rel_score = math.exp(float(rel_log_score))
                            comb_score = (float(mid_score) ** 0.6) * (rel_score ** 0.1)
                            answers.append((mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score,
                                            int(mid2wiki[mid]), int(index_degrees[mid][0])))
                answers.sort(key=lambda t: (t[6], t[3], t[7], t[8]), reverse=True)
            elif mids and rels:
                id2answers[line_id] = [(mids[0][0], rels[0][0])]

            answers = id2answers[line_id]
            fout.write("{}".format(line_id))
            for answer in answers:
                if len(answer) == 2:
                    # Non-heuristic answers carry only (mid, rel).
                    fout.write(" %%%% {}\t{}".format(*answer))
                else:
                    mid, rel, mid_name, _mid_type, mid_score, rel_score = answer[:6]
                    # comb_score is intentionally not written (existing output format).
                    fout.write(" %%%% {}\t{}\t{}\t{}\t{}".format(mid, rel, mid_name, mid_score, rel_score))
            fout.write('\n')

            # 0-based position of the first answer matching the gold pair among the top 3.
            rank = next((pos for pos, answer in enumerate(answers[:3])
                         if answer[0] == truth_mid and answer[1] == truth_rel), None)
            if rank is not None:
                retrieved_top3 += 1
                if rank <= 1:
                    retrieved_top2 += 1
                if rank == 0:
                    retrieved_top1 += 1

    total = len(id2goldmids) or 1  # avoid ZeroDivisionError on an empty question file
    print()
    print("found:              {}".format(found / total * 100.0))
    print("retrieved at top 1: {}".format(retrieved_top1 / total * 100.0))
    print("retrieved at top 2: {}".format(retrieved_top2 / total * 100.0))
    print("retrieved at top 3: {}".format(retrieved_top3 / total * 100.0))
    return id2answers

In [68]:
# NOTE: this calls the packaged BuboQA implementation (ei.evidence_integration),
# not the evidence_integration function defined in the notebook above.
test_answers = ei.evidence_integration(
    args.data_path, args.ent_path, args.rel_path, output_dir,
    index_reach, index_degrees, mid2wiki,
    args.heuristics, args.hits_ent, args.hits_rel,
)

getting questions ...
Entity Source : BuboQA/entity_linking/results/lstm/test-h100.txt
Relation Source : BuboQA/relation_prediction/nn/results/cnn/test.txt

found:              99.5112279245631
retrieved at top 1: 74.68529533822105
retrieved at top 2: 80.91483377138377
retrieved at top 3: 82.73620141098354
