In [1]:
import pandas as pd
import numpy as np
import os
from collections import Counter

from tools import train_test

from tqdm import tqdm_notebook as tqdm

%load_ext autoreload
%autoreload 2

In [2]:
# Which trained run to inspect: model name, benchmark dataset and the
# timestamp that names the run's results directory.
model, dataset, timestamp = 'TransE', 'FB13', '1527033688'

# Other runs that can be selected instead:
# model, dataset, timestamp = 'TransE', 'NELL186', '1526711822'
# model, dataset, timestamp = 'Analogy', 'NELL186', '1526567410'

# Benchmark files live under the home directory, results on a separate disk.
# NOTE(review): both locations are machine-specific.
results_root = '/home/andrey/hdd/proj/XKEc/results'
benchmark_dataset_path = os.path.join('~/proj', 'XKEc', 'benchmarks', dataset)
model_path = os.path.join(results_root, dataset, model, timestamp)

# Echo the resolved results directory.
model_path

'/home/andrey/hdd/proj/XKEc/results/FB13/TransE/1527033688'

In [3]:
def _read_benchmark_table(file_name, sep, names):
    """Read one benchmark file from ``benchmark_dataset_path`` into a DataFrame.

    file_name -- file inside the benchmark directory
    sep       -- column separator used by that file
    names     -- column names to assign (the files carry no header row)

    The first line of each file is skipped (``skiprows=1``); presumably it
    holds the record count of the benchmark format -- TODO confirm.
    """
    return pd.read_csv(os.path.join(benchmark_dataset_path, file_name),
                       skiprows=1,
                       sep=sep,
                       names=names)


# Vocabulary files are tab-separated, triple files are space-separated.
entities = _read_benchmark_table('entity2id.txt', '\t', ['ent', 'id'])
relations = _read_benchmark_table('relation2id.txt', '\t', ['rel', 'id'])

# Positive validation triples get label 1.0, negative ones label 0.0.
valid_positive = _read_benchmark_table('valid2id.txt', ' ', ['e1', 'e2', 'rel'])
valid_positive['label'] = 1.0

valid_negative = _read_benchmark_table('valid2id_neg.txt', ' ', ['e1', 'e2', 'rel'])
valid_negative['label'] = 0.0

# pd.concat instead of DataFrame.append: append is deprecated and no longer
# exists in pandas 2.0. Positives come first, then negatives, with a fresh
# 0..n-1 index.
valid_set = pd.concat([valid_positive, valid_negative], ignore_index=True)

# Alternative: evaluate on the training triples instead of the validation set.
# valid_set = pd.read_csv(os.path.join(benchmark_dataset_path, 'train2id.txt'),
#                        skiprows = 1,
#                        sep=' ',
#                        names = ['e1', 'e2', 'rel'])

# valid_set['label'] = 1.0

valid_set.shape

(11816, 4)

In [4]:
# Read the stored description of the selected run (hyper-parameters and the
# metrics recorded for it) from the run directory and display it.
model_info = train_test.read_model_info(model_path)
model_info

{'Unnamed: 0': 0,
 'acc': 0.8254329562187195,
 'bern': 1,
 'cuda_device': 0,
 'dataset_name': 'FB13',
 'ent_neg_rate': 1,
 'hits_10_filter': 0.3477857708930969,
 'hits_10_raw': 0.3447309732437134,
 'hits_1_filter': 0.19449710845947266,
 'hits_1_raw': 0.18996754288673398,
 'hits_3_filter': 0.2726793885231018,
 'hits_3_raw': 0.2689293622970581,
 'k': 100,
 'learning_rate': 0.001,
 'learning_time': 773.4418108463287,
 'log_on': 1,
 'log_print': True,
 'log_type': 'epoch',
 'margin': 1.0,
 'model_name': 'TransE',
 'mr_filter': 8762.4619140625,
 'mr_raw': 13547.830078125,
 'mrr_filter': 0.2488585412502289,
 'mrr_raw': 0.2449391633272171,
 'n_batches': 100,
 'n_epochs': 1000,
 'note': 'following NMM paper, no l2 constr in rel',
 'opt_method': 'RMSProp',
 'rel_neg_rate': 0,
 'score_norm': 'l2',
 'shuffle': 1,
 'test_link_prediction': True,
 'test_triple_class': True,
 'testing_time': 1087.2861058712006,
 'timestamp': '1527033688',
 'work_threads': 8}

In [5]:
# Restore the trained model saved under model_path; `con` is the object used
# by every scoring / ranking call in the cells below.
con = train_test.restore_model(model_path)

INFO:tensorflow:Restoring parameters from /home/andrey/hdd/proj/XKEc/results/FB13/TransE/1527033688/tf_model/model.vec.tf


In [6]:
# Work on a copy so that the columns added in later cells do not end up in
# valid_set itself.
t_set = valid_set.copy()

# Plain arrays of heads, tails, relations and gold labels, in row order.
h, t, r, l = (np.array(t_set[column_name])
              for column_name in ('e1', 'e2', 'rel', 'label'))

In [7]:
# con.init_triple_classification()
# Score every validation triple with the restored model. `res` is stored as
# one value per row of t_set in the next cell and compared against the
# per-relation threshold further down (a lower score counts as positive).
res = con.test_step(h, t, r)

In [8]:
# Keep the model score next to the triple it belongs to.
t_set['res'] = res

In [9]:
# Mapping from relation id to its decision threshold (used below both with
# Series.map and with thresholds[rel]).
# NOTE(review): how the model derives these thresholds is not visible here.
thresholds = con.get_thresholds_dict(r)

In [10]:
# Attach to every triple the threshold of its relation.
t_set['thres'] = t_set['rel'].map(thresholds)

In [11]:
# A triple is predicted positive (1) when its score is strictly below the
# threshold of its relation, negative (0) otherwise.
scores_below_threshold = t_set['res'] < t_set['thres']
t_set['pred'] = scores_below_threshold.astype(int)

In [12]:
# Reduce the float label (1.0 / 0.0) to its leading digit as a string.
label_digit = t_set['label'].astype(str).str[0]
pred_digit = t_set['pred'].astype(str)

t_set['label'] = label_digit
# Two-character code "<label><prediction>", e.g. '10' = positive triple
# predicted negative.
t_set['check'] = label_digit + pred_digit

In [13]:
# Confusion counts keyed by the "<label><prediction>" code:
# '00' / '11' are correct, '01' false positives, '10' false negatives.
count = Counter(t_set['check'])
count

Counter({'00': 5024, '01': 884, '10': 1157, '11': 4751})

In [14]:
# Accuracy in percent: codes '00' and '11' are the rows where label and
# prediction agree.
correct = count['00'] + count['11']
new_acc = 100.00 * correct / len(t_set)

In [15]:
# Show the recomputed accuracy (a percentage).
new_acc

82.72681110358836

In [16]:
# Accuracy as computed by the model object itself, for comparison with
# new_acc (note the scale: this one is a fraction, new_acc a percentage).
con.validation_acc()

0.827268123626709

In [18]:
# final_test = []
# for i in range(len(pred)):
#     final_test.append(str(int(l[i])) + '_' + str(int(pred[i])))
    

In [19]:
# Counter(final_test)

In [17]:
# Peek at the first rows with score, threshold, prediction and check code.
t_set.head(10)

Unnamed: 0,e1,e2,rel,label,res,thres,pred,check
0,23041,67451,5,1,0.03846,0.032075,0,10
1,14257,68833,3,1,0.522386,0.547903,1,11
2,29865,67534,6,1,0.570176,0.580372,1,11
3,38636,67574,3,1,0.553331,0.547903,0,10
4,12548,67408,6,1,0.560282,0.580372,1,11
5,43602,69400,3,1,0.56667,0.547903,0,10
6,20362,67532,6,1,0.576972,0.580372,1,11
7,12860,67426,6,1,0.568382,0.580372,1,11
8,16190,67497,0,1,1.010737,1.059594,1,11
9,56126,67402,5,1,0.028022,0.032075,1,11


In [18]:
# Number of entities; used as k so that rankings cover every candidate.
n_ent = len(entities)

# Row position of the validation triple inspected in the next cells.
triple = 1

# Columns are addressed by name rather than by position, so the lookup keeps
# working if columns are added to or reordered in t_set.
head = t_set['e1'].iloc[triple]
tail = t_set['e2'].iloc[triple]
rel = t_set['rel'].iloc[triple]

# The chosen head / relation repeated once per entity.
heads = [head] * n_ent
rels = [rel] * n_ent

print('Evaluating triple ({}, {}, {}). The label is {} and the prediction is {}'.format(
    head, rel, tail, t_set['label'].iloc[triple], t_set['pred'].iloc[triple]))

Evaluating triple (14257, 3, 68833). The label is 1 and the prediction is 1


In [19]:
# Rank all candidate tails for (head, rel). The call returns a pair that the
# loop cell below unpacks as (rank, positives); the first value matches the
# 1-based position of the gold tail computed two cells further down.
# NOTE(review): `positives` is presumably the number of candidates whose
# score passes the relation threshold -- confirm in the model code.
con.get_true_tails(head, rel, tail, thresholds[rel], k = n_ent)

(46, 524)

In [20]:
# Full list of predicted tail entities for (head, rel), k = n_ent candidates.
pred1 = list(con.predict_tail_entity(head, rel, k = n_ent))
# Evaluate every (head, candidate tail, rel) triple in that order.
# NOTE(review): the returned floats look like model scores in ranking order
# -- confirm what calculate_true_triples returns.
con.calculate_true_triples([head]*n_ent, pred1, [rel]*n_ent)


array([0.4942315 , 0.4955931 , 0.49575502, ..., 0.6625423 , 0.6644705 ,
       0.69270134], dtype=float32)

In [21]:
# 1-based position of the gold tail inside the predicted list.
pred1.index(tail) + 1

46

In [22]:
# k = n_ent: rank against every entity.
n_ent = len(entities)

# Per-triple results, in the row order of t_set.
link_prediction_tail, positive_tails = [], []
link_prediction_head, positive_heads = [], []

for head, tail, rel in tqdm(zip(h, t, r), total=len(t_set)):
    threshold = thresholds[rel]

    # Tail side: gold tail ranked among all candidate tails for (head, rel).
    tail_rank, tail_positives = con.get_true_tails(head, rel, tail, threshold, k=n_ent)
    link_prediction_tail.append(tail_rank)
    positive_tails.append(tail_positives)

    # Head side: gold head ranked among all candidate heads for (tail, rel).
    # NOTE(review): argument order mirrors the tail call with head and tail
    # swapped -- confirm it matches get_true_heads' signature.
    head_rank, head_positives = con.get_true_heads(tail, rel, head, threshold, k=n_ent)
    link_prediction_head.append(head_rank)
    positive_heads.append(head_positives)


HBox(children=(IntProgress(value=0, max=11816), HTML(value=u'')))

KeyboardInterrupt: 

In [26]:
# Total of the per-triple `positives` values on the tail side.
np.sum(positive_tails)

814978

In [27]:
# Total of the per-triple `positives` values on the head side.
np.sum(positive_heads)

358446635

In [28]:
# Distribution of the gold head's rank over all validation triples.
pd.Series(link_prediction_head).describe()

count    11816.000000
mean     30664.295785
std      21030.874436
min          1.000000
25%      11948.750000
50%      28203.500000
75%      47353.250000
max      75037.000000
dtype: float64

In [29]:
# Attach the ranking results to the triples they were computed for; the lists
# were filled in t_set row order.
ranking_columns = [
    ('lp_tail', link_prediction_tail),
    ('pos_tail', positive_tails),
    ('lp_head', link_prediction_head),
    ('pos_head', positive_heads),
]
for column_name, column_values in ranking_columns:
    t_set[column_name] = column_values

In [30]:
# Inspect the first rows together with the new ranking columns.
t_set.head(50)

Unnamed: 0,e1,e2,rel,label,res,thres,pred,check,lp_tail,pos_tail,lp_head,pos_head
0,23041,67451,5,1,0.03846,0.032075,0,10,2,1,43310,1844
1,14257,68833,3,1,0.522386,0.547903,1,11,46,524,42,27644
2,29865,67534,6,1,0.570176,0.580372,1,11,4,40,315,20052
3,38636,67574,3,1,0.553331,0.547903,0,10,32,21,65801,54870
4,12548,67408,6,1,0.560282,0.580372,1,11,1,30,2124,59423
5,43602,69400,3,1,0.56667,0.547903,0,10,140,33,45905,7175
6,20362,67532,6,1,0.576972,0.580372,1,11,5,8,15908,30069
7,12860,67426,6,1,0.568382,0.580372,1,11,1,30,19152,65556
8,16190,67497,0,1,1.010737,1.059594,1,11,1,6,54719,75013
9,56126,67402,5,1,0.028022,0.032075,1,11,1,1,51480,69119


In [31]:
# Summary statistics of all numeric columns over the whole validation set.
t_set.describe()

Unnamed: 0,e1,e2,rel,res,thres,pred,lp_tail,pos_tail,lp_head,pos_head
count,11816.0,11816.0,11816.0,11816.0,11816.0,11816.0,11816.0,11816.0,11816.0,11816.0
mean,31522.586154,68701.658768,4.61696,0.509281,0.510258,0.476896,175.993653,68.97241,30664.295785,30335.700322
std,19294.487694,1964.394809,2.462351,0.338664,0.34032,0.499487,1723.769054,665.682339,21030.874436,29671.629028
min,3.0,67393.0,0.0,0.019329,0.032075,0.0,1.0,0.0,1.0,0.0
25%,14828.25,67431.0,3.0,0.449587,0.464917,0.0,2.0,1.0,11948.75,1844.0
50%,30256.0,67583.0,5.0,0.553366,0.547903,0.0,14.0,17.0,28203.5,13800.0
75%,48230.25,69324.0,6.0,0.584838,0.580372,1.0,107.0,46.0,47353.25,64908.0
max,67389.0,75032.0,12.0,1.411805,1.319657,1.0,68741.0,41681.0,75037.0,75013.0


In [32]:
# Summary statistics restricted to negative triples (label '0') of relation 0.
is_negative = t_set['label'] == '0'
is_relation_zero = t_set['rel'] == 0
t_set[is_negative & is_relation_zero].describe()

Unnamed: 0,e1,e2,rel,res,thres,pred,lp_tail,pos_tail,lp_head,pos_head
count,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0
mean,29937.950495,70570.006601,0.0,1.078833,1.059594,0.158416,58.039604,42.653465,27263.333333,8428.722772
std,19221.300969,2301.088923,0.0,0.023646,0.0,0.365734,36.339255,302.05561,21095.254,18997.335463
min,370.0,67393.0,0.0,0.971118,1.059594,0.0,1.0,2.0,22.0,266.0
25%,12337.0,68209.5,0.0,1.07019,1.059594,0.0,27.0,9.0,9350.0,863.0
50%,28425.0,70244.0,0.0,1.08219,1.059594,0.0,57.0,11.0,20940.0,1554.0
75%,46198.5,72370.0,0.0,1.094376,1.059594,0.0,85.0,15.0,42804.0,4131.0
max,67383.0,74790.0,0.0,1.13899,1.059594,1.0,172.0,3727.0,74920.0,75013.0


In [33]:
# Name the export after the configured dataset and model so that switching the
# run configuration cannot overwrite or mislabel another run's results. For the
# FB13 / TransE configuration this yields the same file name as before.
output_file = 'false_positive_evaluation_{}_{}.csv'.format(dataset, model)
t_set.to_csv(output_file, sep='\t', index=False)

In [23]:
# Show the first ten positive validation triples. The previous cell text was
# an unterminated string literal and could not be executed.
valid_positive.head(10)

Unnamed: 0,e1,e2,rel,label
0,23041,67451,5,1.0
1,14257,68833,3,1.0
2,29865,67534,6,1.0
3,38636,67574,3,1.0
4,12548,67408,6,1.0
5,43602,69400,3,1.0
6,20362,67532,6,1.0
7,12860,67426,6,1.0
8,16190,67497,0,1.0
9,56126,67402,5,1.0
