### Train the inductive link prediction model

In [1]:
# Name of the dataset folder; the train/valid/test paths below are built from it
data_name = 'WN18RR_v4'
# Identifier of this model variant; it is embedded in the names used for saving
model_id = 'SiaLP_6_new'
# Bounds handed to the path-finding step -- presumably the minimum path length and the
# maximum lengths for the path-based and the subgraph-based model (TODO confirm at call sites)
lower_bound = 1
upper_bound_path = 10
upper_bound_subg = 3

In [2]:
#define the names for saving: every name is '<prefix>_<model_id>_<data_name>'
model_name = f'Model_{model_id}_{data_name}'
one_hop_model_name = f'One_hop_model_{model_id}_{data_name}'
ids_name = f'IDs_{model_id}_{data_name}'

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle
from sys import getsizeof

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    """Load knowledge-graph triples from a text file into id-based structures.

    Every non-empty line of the file is expected to hold ``source relation target``
    separated by whitespace. Entities and relations are mapped to integer ids, and
    for every relation an inverse relation named ``'_inverse_' + relation`` is
    registered directly after it.
    """

    def __init__(self):

        # placeholder attribute, kept for backward compatibility; not read by this class
        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """Read the triples in ``data_path`` and fill the given containers in place.

        Ids are handed out in order of first appearance in the file, so the mapping
        is reproducible from run to run (it does not depend on string hashing).

        Args:
            data_path: path of the triple file.
            one_hop: dict, entity id -> collection of ``(relation id, neighbour id)``.
                Both directions of every triple are stored (the backward direction
                with the inverse relation id). On return every value is a list.
                A dict already filled by a previous call may be passed again.
            data: set receiving the ``(source id, relation id, target id)`` triples
                (forward direction only).
            s_t_r: dict, ``(source id, target id)`` -> set of relation ids; filled
                for both directions.
            entity2id, id2entity: entity name <-> id dictionaries; extended with the
                entities not seen before.
            relation2id, id2relation: relation name <-> id dictionaries; extended
                with the relations not seen before. A new relation gets the next
                free id and its inverse the id after it, so initial relations have
                even ids as long as ``relation2id`` held an even number of entries
                on entry (e.g. was empty or filled only by this method).

        Returns:
            None. All results are written into the arguments.

        Raises:
            ValueError: if a non-empty line has fewer than three fields.
        """

        #load the triples in file order; a dict is used as an ordered set so that
        #duplicates are dropped but the order of first appearance is kept
        triples = dict()

        with open(data_path, 'r') as f:

            for line_no, line in enumerate(f, start=1):

                fields = line.split()

                #skip blank lines (e.g. a trailing newline at the end of the file)
                if not fields:
                    continue

                if len(fields) < 3:
                    raise ValueError('line %d of %s has fewer than three fields: %r'
                                     % (line_no, data_path, line))

                #only the first three fields are used: source, relation, target
                triples[tuple(fields[:3])] = None

        ####relation dict#################
        index = len(relation2id)

        for _, relation, _ in triples:

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation

                #the inverse relation always takes the following id
                iv_r = '_inverse_' + relation
                relation2id[iv_r] = index + 1
                id2relation[index + 1] = iv_r

                index += 2

        ####entity dict###################
        index = len(entity2id)

        for source, _, target in triples:

            for entity in (source, target):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        #a one_hop dict filled by a previous call holds lists; turn them back into
        #sets so that new neighbours can be added without creating duplicates
        for e in one_hop:

            one_hop[e] = set(one_hop[e])

        #create the set of triples using id instead of string
        for source, relation, target in triples:

            s = entity2id[source]
            r = relation2id[relation]
            t = entity2id[target]

            #id of the inverse relation: relations are registered in
            #(initial, inverse) pairs, initial = even id, inverse = following odd id
            r_inv = r + 1 if r % 2 == 0 else r - 1

            data.add((s, r, t))

            s_t_r.setdefault((s, t), set()).add(r)
            s_t_r.setdefault((t, s), set()).add(r_inv)

            one_hop.setdefault(s, set()).add((r, t))
            one_hop.setdefault(t, set()).add((r_inv, s))

        #change each set in one_hop to list (the path finder indexes into it)
        for e in one_hop:

            one_hop[e] = list(one_hop[e])

In [5]:
class ObtainPathsByDynamicProgramming:
    """Enumerate relation paths leaving an entity by bounded, randomised recursion.

    The search never revisits an entity that is already on the current path, and
    it is limited by three bounds set at construction time.
    """

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):

        #at most this many (r, t) tuples of one_hop[node] are tried per visited node
        self.amount_bd = amount_bd

        #at most this many paths are stored for one target entity
        self.size_bd = size_bd

        #at most this many expansion attempts are counted per current path length
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        """Find paths that start at entity ``s``, using recursion.

        One may refer to LeetCode Problem 797 for the idea:
            https://leetcode.com/problems/all-paths-from-source-to-target/

        Args:
            mode: which end entities are recorded --
                'direct_neighbour': only the entities listed in ``one_hop[s]``;
                'target_specified': only ``t_input``;
                'any_target': every key of ``one_hop``.
            s: start entity; must be a key of ``one_hop``.
            t_input: the target entity, only read in 'target_specified' mode.
            lower_bd, upper_bd: inclusive bounds on the length of recorded paths;
                both must be of type int, >= 1, and lower_bd <= upper_bd.
            one_hop: dict, entity -> indexable sequence of (relation, neighbour).

        Returns:
            ``(res, count_dict)``: ``res`` is a defaultdict mapping an end entity to
            the set of paths (tuples of relations) found from ``s`` to it;
            ``count_dict`` is a defaultdict mapping a path length to the number of
            expansion attempts made from paths of that length.

        Raises:
            TypeError: on an invalid bound setting.
            ValueError: if ``s`` is not in ``one_hop`` or ``mode`` is unknown.
        """

        if type(lower_bd) is not int or lower_bd < 1:
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")

        if type(upper_bd) is not int or upper_bd < 1:
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")

        if lower_bd > upper_bd:
            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #qualified_t: the end entities whose paths are added to the result
        if mode == 'direct_neighbour':
            qualified_t = {pair[1] for pair in one_hop[s]}
        elif mode == 'target_specified':
            qualified_t = {t_input}
        elif mode == 'any_target':
            qualified_t = set(one_hop)
        else:
            raise ValueError('not a valid mode')

        #res: end entity -> set of paths from s (a direct relation or a multi-hop path)
        res = defaultdict(set)

        #count_dict: path length -> number of expansion attempts at that length
        count_dict = defaultdict(int)

        def expand(node, path, visited):
            #``path`` is the relation sequence from s to ``node`` and ``visited``
            #holds s, the intermediate entities and ``node`` itself

            depth = len(path)

            #record the path when its length is within the bounds, the node is a
            #qualified end entity and the node is not yet full w.r.t. size_bd
            if lower_bd <= depth <= upper_bd and node in qualified_t and (
                    len(res[node]) < self.size_bd):
                res[node].add(tuple(path))

            #stop when the path already has the maximal length, or when the number
            #of attempts performed on this length has passed the threshold
            if depth >= upper_bd or count_dict[depth] > self.threshold:
                return

            #visit the neighbours in random order ...
            order = list(range(len(one_hop[node])))
            random.shuffle(order)

            #... and take at most amount_bd of them
            for idx in order[:self.amount_bd]:

                pair = one_hop[node][idx]
                r, t = pair[0], pair[1]

                #the attempt is counted even if this step does not proceed
                count_dict[depth] += 1

                #no revisit of an on-path entity, no work beyond the threshold
                if t in visited or count_dict[depth] > self.threshold:
                    continue

                expand(t, path + [r], visited | {t})

        expand(s, [], {s})

        return res, count_dict

In [6]:
# triple files of the selected dataset, located under ../data/<data_name>/
train_path = f'../data/{data_name}/train.txt'
valid_path = f'../data/{data_name}/valid.txt'
test_path = f'../data/{data_name}/test.txt'

In [7]:
#instantiate the helper classes
#Class_1 loads the triple files into the id-based dictionaries and sets
Class_1 = LoadKG()
#Class_2 finds paths between entities; the default bounds of its constructor are used
Class_2 = ObtainPathsByDynamicProgramming()

In [8]:
#containers for the training graph; the loader fills them in place
one_hop = {}               #entity id -> list of (relation id, neighbour id)
data = set()               #(source id, relation id, target id) triples
s_t_r = defaultdict(set)   #(source id, target id) -> set of relation ids

#the name <-> id dictionaries, shared by the train, valid and test loading
entity2id, id2entity = {}, {}
relation2id, id2relation = {}, {}

#fill in the sets and dicts from the training file
Class_1.load_train_data(
    train_path, one_hop, data, s_t_r,
    entity2id, id2entity, relation2id, id2relation,
)

In [9]:
#containers for the validation graph; the loader fills them in place
one_hop_valid = {}
data_valid = set()
s_t_r_valid = defaultdict(set)

#load the validation file, extending the shared name <-> id dictionaries
Class_1.load_train_data(
    valid_path, one_hop_valid, data_valid, s_t_r_valid,
    entity2id, id2entity, relation2id, id2relation,
)

In [10]:
#containers for the test graph; the loader fills them in place
one_hop_test = {}
data_test = set()
s_t_r_test = defaultdict(set)

#load the test file, extending the shared name <-> id dictionaries
Class_1.load_train_data(
    test_path, one_hop_test, data_test, s_t_r_test,
    entity2id, id2entity, relation2id, id2relation,
)

#### Build the path-based siamese neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [11]:
# Input layers, using integers to represent each relation type.
# fst_path / scd_path / thd_path are the three path inputs (sequences of relation ids);
# the sequence length is left unspecified (None).
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding (hence the +1),
# which holds the spaces if the initial lengths of the paths are not the same.
# The same embedding layer is shared by the three path inputs.
# NOTE(review): no masking is configured here, so the space-holder steps seem to be
# processed by the LSTMs and the max-pooling like real relations -- confirm intended.
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Embed each integer in a 300-dimensional vector as output
# (a separate embedding table from the one used for the paths)
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM; each layer object is applied to all three paths,
#so its weights are shared between the three channels
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#reduce max over axis 1 (the step axis of the LSTM output sequence)
fst_reduce_max = tf.reduce_max(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_max(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_max(thd_lstm_out, axis=1)

#concatenate the output vectors from the three siamese channels: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#add dropout on top of the concatenation from all channels
dropout = layers.Dropout(0.25)(path_concat)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product of the two unit vectors, i.e. their cosine similarity
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model: three paths + one relation id in, one similarity score out
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-05-15 13:37:33.944270: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
#config the Adam optimizer
# NOTE(review): `decay` is a legacy optimizer argument -- verify that the installed
# Keras version still accepts it.
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
# NOTE(review): the model output is a dot product of unit vectors, which can be negative,
# while binary cross-entropy is normally fed probabilities -- confirm this is intended.
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

#### Build the subgraph-based siamese neural network

In [13]:
#each input is a sequence of integer ids of unspecified length (None):
#six inputs for the source entity and six inputs for the target entity
# NOTE(review): judging by the names these are relation-id paths reaching out of the
# source / target entity -- confirm against the batch-building code.
source_path_1 = keras.Input(shape=(None,), dtype="int32")
source_path_2 = keras.Input(shape=(None,), dtype="int32")
source_path_3 = keras.Input(shape=(None,), dtype="int32")
source_path_4 = keras.Input(shape=(None,), dtype="int32")
source_path_5 = keras.Input(shape=(None,), dtype="int32")
source_path_6 = keras.Input(shape=(None,), dtype="int32")

target_path_1 = keras.Input(shape=(None,), dtype="int32")
target_path_2 = keras.Input(shape=(None,), dtype="int32")
target_path_3 = keras.Input(shape=(None,), dtype="int32")
target_path_4 = keras.Input(shape=(None,), dtype="int32")
target_path_5 = keras.Input(shape=(None,), dtype="int32")
target_path_6 = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela_ = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding (hence the +1),
# which holds the spaces if the initial lengths of the paths are not the same.
# One embedding layer is shared by all source and target inputs.
in_embd_var_ = layers.Embedding(len(relation2id)+1, 300)

# Obtain the source embeddings
source_embd_1 = in_embd_var_(source_path_1)
source_embd_2 = in_embd_var_(source_path_2)
source_embd_3 = in_embd_var_(source_path_3)
source_embd_4 = in_embd_var_(source_path_4)
source_embd_5 = in_embd_var_(source_path_5)
source_embd_6 = in_embd_var_(source_path_6)

#Obtain the target embeddings
target_embd_1 = in_embd_var_(target_path_1)
target_embd_2 = in_embd_var_(target_path_2)
target_embd_3 = in_embd_var_(target_path_3)
target_embd_4 = in_embd_var_(target_path_4)
target_embd_5 = in_embd_var_(target_path_5)
target_embd_6 = in_embd_var_(target_path_6)

# Embed each integer in a 300-dimensional vector as output
# (a separate embedding table from the one used for the inputs)
rela_embd_ = layers.Embedding(len(relation2id)+1, 300)(id_rela_)

#add 2 layer bi-directional LSTM network; the same two layer objects are applied to
#every source and target channel, so their weights are shared
lstm_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

###source lstm implementation########
#first LSTM layer
source_mid_1 = lstm_1(source_embd_1)
source_mid_2 = lstm_1(source_embd_2)
source_mid_3 = lstm_1(source_embd_3)
source_mid_4 = lstm_1(source_embd_4)
source_mid_5 = lstm_1(source_embd_5)
source_mid_6 = lstm_1(source_embd_6)

#second LSTM layer
source_out_1 = lstm_2(source_mid_1)
source_out_2 = lstm_2(source_mid_2)
source_out_3 = lstm_2(source_mid_3)
source_out_4 = lstm_2(source_mid_4)
source_out_5 = lstm_2(source_mid_5)
source_out_6 = lstm_2(source_mid_6)

#reduce max over axis 1 (the step axis of the LSTM output sequence)
source_max_1 = tf.reduce_max(source_out_1, axis=1)
source_max_2 = tf.reduce_max(source_out_2, axis=1)
source_max_3 = tf.reduce_max(source_out_3, axis=1)
source_max_4 = tf.reduce_max(source_out_4, axis=1)
source_max_5 = tf.reduce_max(source_out_5, axis=1)
source_max_6 = tf.reduce_max(source_out_6, axis=1)

#concatenate the output vectors of the six source channels: (Batch, 1800)
source_concat = layers.concatenate([source_max_1, 
                                    source_max_2, 
                                    source_max_3,
                                    source_max_4,
                                    source_max_5,
                                    source_max_6], axis=-1)

#add dropout on top of the concatenation from all channels
source_dropout = layers.Dropout(0.25)(source_concat)

###target lstm implementation########
#first LSTM layer
target_mid_1 = lstm_1(target_embd_1)
target_mid_2 = lstm_1(target_embd_2)
target_mid_3 = lstm_1(target_embd_3)
target_mid_4 = lstm_1(target_embd_4)
target_mid_5 = lstm_1(target_embd_5)
target_mid_6 = lstm_1(target_embd_6)

#second LSTM layer
target_out_1 = lstm_2(target_mid_1)
target_out_2 = lstm_2(target_mid_2)
target_out_3 = lstm_2(target_mid_3)
target_out_4 = lstm_2(target_mid_4)
target_out_5 = lstm_2(target_mid_5)
target_out_6 = lstm_2(target_mid_6)

#reduce max over axis 1 (the step axis of the LSTM output sequence)
target_max_1 = tf.reduce_max(target_out_1, axis=1)
target_max_2 = tf.reduce_max(target_out_2, axis=1)
target_max_3 = tf.reduce_max(target_out_3, axis=1)
target_max_4 = tf.reduce_max(target_out_4, axis=1)
target_max_5 = tf.reduce_max(target_out_5, axis=1)
target_max_6 = tf.reduce_max(target_out_6, axis=1)

#concatenate the output vectors of the six target channels: (Batch, 1800)
target_concat = layers.concatenate([target_max_1, 
                                    target_max_2, 
                                    target_max_3,
                                    target_max_4,
                                    target_max_5,
                                    target_max_6], axis=-1)

#add dropout on top of the concatenation from all channels
target_dropout = layers.Dropout(0.25)(target_concat)

#further concatenate source and target output embeddings: (Batch, 3600)
final_concat = layers.concatenate([source_dropout, target_dropout], axis=-1)

#multiply into output embd size by dense layer: (Batch, 300)
out_vect = layers.Dense(300, activation='tanh')(final_concat)

#remove the time dimension from the output embd since there is only one step
rela_out_embd_ = tf.reduce_sum(rela_embd_, axis=1)

# Normalize the vectors to have unit length
out_vect_norm = tf.math.l2_normalize(out_vect, axis=-1)
rela_out_embd_norm_ = tf.math.l2_normalize(rela_out_embd_, axis=-1)

# Calculate the dot product of the two unit vectors, i.e. their cosine similarity
dot_product_ = layers.Dot(axes=-1)([out_vect_norm, rela_out_embd_norm_])

#put together the model: six source inputs, six target inputs and one relation id in,
#one similarity score out
model_2 = keras.Model([source_path_1, source_path_2, source_path_3, source_path_4,
                       source_path_5, source_path_6,
                       target_path_1, target_path_2, target_path_3, target_path_4, 
                       target_path_5, target_path_6,
                       id_rela_], dot_product_)

In [14]:
#config the Adam optimizer (a separate optimizer instance for the subgraph-based model)
# NOTE(review): `decay` is a legacy optimizer argument -- verify that the installed
# Keras version still accepts it.
opt_ = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
# NOTE(review): the model output is a dot product of unit vectors, which can be negative,
# while binary cross-entropy is normally fed probabilities -- confirm this is intended.
model_2.compile(loss='binary_crossentropy', optimizer=opt_, metrics=['binary_accuracy'])

### Build the big-batch for path-based model
We will build the big-batch for the path-based model training. That is, we will build three lists to store the three paths, respectively.

In order to reduce computational complexity, we will run the path-finding algorithm for each entity e in the dataset before the training. That is, for each entity e, we will have two dictionaries. Dict 1 stores the paths between e and any other entities in the dataset, while Dict 2 stores the paths between e and its direct neighbors. The two dicts will be used, and stay invariant, throughout the training.

* At each step, three different paths between two entities s and t are selected. Each path is appended to one of the lists. 
* If this step is for positive samples, the existing relation r will be selected between s and t. If there are more than one relation from s to t, we randomly choose one. Also, the label list will be appended 1.
* If this step is for negative samples, one relation that does not exist between s and t will be selected randomly and append to the relation list. Also, the label list will be appended 0.
* In practice, the positive step is always followed by a negative step. The same paths in the positive step will be used in the next negative step, while the relation is a negative one chosen in the above way.
* We do this until the length limit is reached.

**For relation prediction, we will only need to train using the (s,r,t) triple. (t,r-1,s) is not necessary and hence not included in training.**

In [15]:
#function to build the big batches for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    """Append positive/negative training samples for the path-based model.

    For every entity s in ``one_hop`` the paths to its direct neighbours are found
    once with ``Class_2.obtain_paths``; then, ten times over, every neighbour t
    with at least three paths contributes one positive sample (three sampled paths
    and an existing forward relation between s and t, label 1.) directly followed
    by one negative sample (the same paths and a forward relation that does not
    exist between s and t, label 0.).

    Args:
        lower_bd, upper_bd: path-length bounds passed to the path finder; paths
            are right-padded to length ``upper_bd`` with the space-holder id.
        data: not used by this function.
        one_hop: dict, entity id -> sequence of (relation id, neighbour id).
        s_t_r: dict, (source id, target id) -> set of relation ids.
        x_p_list: dict with the keys '1', '2', '3'; each value is a list that
            receives the padded paths of one channel.
        x_r_list: list receiving ``[relation id]`` per sample.
        y_list: list receiving the labels (1. or 0.).
        relation2id, entity2id: not used by this function.
        id2relation: dict, relation id -> name; its keys must be 0..len-1.
        id2entity: dict, entity id -> name (only used in an error message).

    Returns:
        None. The samples are appended to ``x_p_list``, ``x_r_list``, ``y_list``.

    Raises:
        ValueError: if the relation ids are not contiguous, if the number of
            even ids is not half of all ids, or if a neighbour pair (s, t) has no
            relation in ``s_t_r``.
    """

    num_r = len(id2relation)

    #the relation ids have to be exactly 0, 1, ..., num_r - 1
    if any(rid not in id2relation for rid in range(num_r)):
        raise ValueError('error when generaing id2relation')

    #the set of all initial relations: their id is always an even number
    ini_r_id_set = {rid for rid in range(num_r) if rid % 2 == 0}
    num_ini_r = len(ini_r_id_set)

    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')

    def append_paths(paths):
        #one path per channel; the space holder id (num_r) fills up the end of
        #every path that is shorter than upper_bd
        for channel, path in zip(('1', '2', '3'), paths):
            x_p_list[channel].append(list(path) + [num_r]*abs(len(path)-upper_bd))

    #not all entities in entity2id need to be in one_hop,
    #so only the entities that are indeed in one_hop are visited, in random order
    existing_ids = list({entity for entity in one_hop})
    random.shuffle(existing_ids)

    total = len(existing_ids)

    for count, s in enumerate(existing_ids, start=1):

        #run the path finding algorithm to find paths between s and its neighbours
        result, _ = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)

        for _ in range(10):

            for t in result:

                relations_s_t = s_t_r[(s,t)]

                if len(relations_s_t) == 0:
                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #only the forward links (even ids) matter in relation prediction
                ini_r_list = [r for r in relations_s_t if r % 2 == 0]

                #skip t unless there are at least three paths between s and t,
                #an initial connection between s and t exists, and not every
                #initial relation exists between s and t (although this is rare)
                if not (len(result[t]) >= 3 and 0 < len(ini_r_list) < int(num_ini_r)):
                    continue

                #pick three of the paths from s to t
                chosen_paths = random.sample(list(result[t]), 3)

                #####positive#####################
                append_paths(chosen_paths)
                x_r_list.append([random.choice(ini_r_list)])
                y_list.append(1.)

                #####negative#####################
                #same paths, but with a relation that does not hold between s and t
                append_paths(chosen_paths)
                neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                x_r_list.append([random.choice(neg_r_list)])
                y_list.append(0.)

        if count % 100 == 0:
            print('generating big-batches for path-based model', count, total)

### Build the big-batch for the subgraph-based network training

Again, to reduce computational complexity, we store the subgraph of each entity e at the beginning.

* At each step, we will select one triple (s,r,t) from the dataset. Then, reaching out paths of s and t is generated respectively according to their out-going relations.
* We will select six paths for each of the source and target entities, and add them to the corresponding lists.
* If this is a positive sample step, the id of relation r is appended to the relation list.
* If this is a negative sample step, the id of a random relation is appended to the relation list.
* Similarly, one negative sample step always follows one positive step. The paths from the previous positive step are used again for the negative step.

In [16]:
#It is too slow to run the path-finding algorithm again and again on the complete dataset.
#Instead, the paths reaching out of each entity (its "subgraph") are found once here;
#the subgraph-based training then reuses the stored paths multiple times.
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    """Return a dict mapping every entity of ``one_hop`` to a sample of its out-going paths.

    For each entity s, ``Class_2.obtain_paths`` is run in 'any_target' mode, the
    paths found to all end entities are pooled (duplicates removed), and at most
    100 of them are drawn at random.

    Args:
        lower_bd, upper_bd: path-length bounds passed to the path finder.
        one_hop: dict, entity id -> sequence of (relation id, neighbour id).
        id2relation: dict, relation id -> name; its keys must be 0..len-1.
        data, s_t_r, relation2id, entity2id, id2entity: not used by this function.

    Returns:
        dict, entity id -> list of at most 100 paths (tuples of relation ids).

    Raises:
        ValueError: if the relation ids in ``id2relation`` are not contiguous.
    """

    #the relation ids have to be exactly 0, 1, ..., len(id2relation) - 1
    if any(rid not in id2relation for rid in range(len(id2relation))):
        raise ValueError('error when generaing id2relation')

    #the ids to start path finding from: the entities that are indeed in one_hop,
    #visited in random order
    existing_ids = list({entity for entity in one_hop})
    random.shuffle(existing_ids)

    total = len(existing_ids)

    #stores the subgraph (sampled paths) for each entity
    subgraphs = {}

    for count, s in enumerate(existing_ids, start=1):

        result, _ = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        #pool the paths of all end entities, dropping duplicates
        path_list = list({path for paths in result.values() for path in paths})

        #keep at most 100 randomly drawn paths
        path_select = random.sample(path_list, min(len(path_list), 100))

        subgraphs[s] = deepcopy(path_select)

        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, total)

    return subgraphs

In [17]:
#helpers and entry point to build the big-batch for the subgraph (one-hop neighbor) training
def _pad_path(path, upper_bd, pad_id):
    """Return ``path`` as a list with ``pad_id`` appended ``abs(len(path) - upper_bd)`` times."""
    return list(path) + [pad_id] * abs(len(path) - upper_bd)


def _append_subgraph_sample(x_s_list, x_t_list, x_r_list, y_list,
                            s_paths, t_paths, r, label, upper_bd, pad_id):
    """Append one (source paths, target paths, relation, label) sample to the big-batch lists."""
    for slot, (s_path, t_path) in enumerate(zip(s_paths, t_paths), start=1):
        key = str(slot)  #the path slots are keyed '1', '2', ...
        x_s_list[key].append(_pad_path(s_path, upper_bd, pad_id))
        x_t_list[key].append(_pad_path(t_path, upper_bd, pad_id))
    x_r_list.append([r])
    y_list.append(label)


def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity,
                      num_paths=6, num_iterations=10):
    """Fill the big-batch lists used to train the subgraph-based model.

    For every triple ``(s, r, t)`` of ``data`` whose two entities both own at
    least ``num_paths`` stored paths in ``Dict``, six samples are appended, in
    this order (label in brackets):

    1. true paths of ``s`` and ``t`` with the true relation ``r``          [1.]
    2. the same paths with a random other initial relation                 [0.]
    3. freshly sampled true paths with ``r``                               [1.]
    4. paths of a random qualified entity as source, target from step 3    [0.]
    5. freshly sampled true paths with ``r``                               [1.]
    6. source from step 5, paths of a random qualified entity as target    [0.]

    Every path is padded at the end with the place-holder id
    ``len(id2relation)`` (see ``_pad_path``); paths are expected to be no
    longer than ``upper_bd``.

    Parameters
    ----------
    upper_bd : int
        Target length of the padded paths.
    data : iterable of (s, r, t)
        Triples given as ids.
    x_s_list, x_t_list : dict
        Map the slot keys ``'1' .. str(num_paths)`` to lists; extended in place.
    x_r_list, y_list : list
        Relation ids (each wrapped in a one-element list) and labels;
        extended in place.
    Dict : dict
        Maps an entity id to its stored paths.
    id2relation : dict
        Keys must be ``0 .. len-1``; even ids are the initial relations and
        odd ids their inverses.
    lower_bd, one_hop, s_t_r, relation2id, entity2id, id2entity
        Unused; kept so that existing callers keep working.
    num_paths : int, optional
        Number of paths sampled per entity and sample (default 6).
    num_iterations : int, optional
        Number of shuffled passes over ``data`` (default 10).

    Raises
    ------
    ValueError
        If the keys of ``id2relation`` are not contiguous or their number is odd.

    NOTE(review): the random entities are drawn from all qualified entities, so
    a "negative" entity can coincide with the true one; likewise
    ``random.choice`` raises IndexError when ``r`` is the only initial relation.
    Both behaviours are unchanged from the previous implementation.
    """
    num_r = len(id2relation)

    #the set of all initial relations
    ini_r_id_set = set()

    for i in range(num_r):

        if i not in id2relation:
            raise ValueError('error when generaing id2relation')

        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)

    if len(ini_r_id_set) != num_r // 2:
        raise ValueError('error when generating id2relation')

    #an entity is qualified if it has at least num_paths out-stretching paths
    qualified = set()
    for e in Dict:
        if len(Dict[e]) >= num_paths:
            qualified.add(e)
    #the set gives O(1) membership tests, the list is needed by random.choice
    qualified_list = list(qualified)

    def emit(s_paths, t_paths, relation, label):
        _append_subgraph_sample(x_s_list, x_t_list, x_r_list, y_list,
                                s_paths, t_paths, relation, label, upper_bd, num_r)

    data = list(data)

    for iteration in range(num_iterations):

        data = shuffle(data)

        for i_0, triple in enumerate(data):

            s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs

            if s in qualified and t in qualified:

                #obtain the path list for true entities
                path_s, path_t = list(Dict[s]), list(Dict[t])

                #####positive step###########
                s_paths = random.sample(path_s, num_paths)
                t_paths = random.sample(path_t, num_paths)
                emit(s_paths, t_paths, r, 1.)

                #####negative step for relation: same paths, wrong relation#####
                neg_r_list = list(ini_r_id_set.difference({r}))
                emit(s_paths, t_paths, random.choice(neg_r_list), 0.)

                #randomly choose two negative sampled entities
                s_ran = random.choice(qualified_list)
                t_ran = random.choice(qualified_list)

                #obtain the path list for random entities
                path_s_ran, path_t_ran = list(Dict[s_ran]), list(Dict[t_ran])

                #####positive step###########
                s_paths = random.sample(path_s, num_paths)
                t_paths = random.sample(path_t, num_paths)
                emit(s_paths, t_paths, r, 1.)

                #####negative for source entity: target paths are re-used#####
                emit(random.sample(path_s_ran, num_paths), t_paths, r, 0.)

                #####positive step###########
                s_paths = random.sample(path_s, num_paths)
                t_paths = random.sample(path_t, num_paths)
                emit(s_paths, t_paths, r, 1.)

                #####negative for target entity: source paths are re-used#####
                emit(s_paths, random.sample(path_t_ran, num_paths), r, 0.)

            if i_0 % 2000 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [18]:
model_name

'Model_SiaLP_6_new_WN18RR_v4'

In [19]:
one_hop_model_name

'One_hop_model_SiaLP_6_new_WN18RR_v4'

In [20]:
ids_name

'IDs_SiaLP_6_new_WN18RR_v4'

In [21]:
#first, we save the relation and ids
#one dictionary bundles everything the evaluation notebook has to reload
Dict = {
    #training data
    'one_hop': one_hop,
    'data': data,
    's_t_r': s_t_r,
    #valid data
    'one_hop_valid': one_hop_valid,
    'data_valid': data_valid,
    's_t_r_valid': s_t_r_valid,
    #test data
    'one_hop_test': one_hop_test,
    'data_test': data_test,
    's_t_r_test': s_t_r_test,
    #shared dictionaries
    'entity2id': entity2id,
    'id2entity': id2entity,
    'relation2id': relation2id,
    'id2relation': id2relation,
}

#this is the first cell that writes into ../weight_bin, so make sure the
#folder exists (the model-saving cells below create it only later)
os.makedirs('../weight_bin', exist_ok=True)

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [22]:
###train the path-based model
lower_bd, upper_bd = lower_bound, upper_bound_path
num_epoch, batch_size = 10, 32

#big-batch containers: one list per path slot ('1'..'3'), the relations and the labels
train_p_list, train_r_list, train_y_list = {k: [] for k in '123'}, [], []
valid_p_list, valid_r_list, valid_y_list = {k: [] for k in '123'}, [], []

#######################################
###build the big-batches###############

#training big-batch
build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#validation big-batch
build_big_batches_path(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_p_list, valid_r_list, valid_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################
#The validation graph can be too small or too sparse to yield enough samples.
#With fewer than 100 validation samples we instead hold out the last 20% of
#the training big-batch for validation.
if len(valid_y_list) >= 100:
    #training arrays: the complete training big-batch
    x_train_1, x_train_2, x_train_3 = (
        np.asarray(train_p_list[k], dtype='int') for k in '123')
    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #validation arrays: the complete validation big-batch
    x_valid_1, x_valid_2, x_valid_3 = (
        np.asarray(valid_p_list[k], dtype='int') for k in '123')
    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)

    #training arrays: first 80% of the training big-batch
    x_train_1, x_train_2, x_train_3 = (
        np.asarray(train_p_list[k][:split], dtype='int') for k in '123')
    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #validation arrays: remaining 20% of the training big-batch
    x_valid_1, x_valid_2, x_valid_3 = (
        np.asarray(train_p_list[k][split:], dtype='int') for k in '123')
    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training
model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
          validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

#save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')
os.makedirs(save_dir, exist_ok=True)

model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del model

generating big-batches for path-based model 100 3861
generating big-batches for path-based model 200 3861
generating big-batches for path-based model 300 3861
generating big-batches for path-based model 400 3861
generating big-batches for path-based model 500 3861
generating big-batches for path-based model 600 3861
generating big-batches for path-based model 700 3861
generating big-batches for path-based model 800 3861
generating big-batches for path-based model 900 3861
generating big-batches for path-based model 1000 3861
generating big-batches for path-based model 1100 3861
generating big-batches for path-based model 1200 3861
generating big-batches for path-based model 1300 3861
generating big-batches for path-based model 1400 3861
generating big-batches for path-based model 1500 3861
generating big-batches for path-based model 1600 3861
generating big-batches for path-based model 1700 3861
generating big-batches for path-based model 1800 3861
generating big-batches for path-based

In [23]:
###train the subgraph-based model
lower_bd, upper_bd = lower_bound, upper_bound_subg
num_epoch, batch_size = 10, 32

#sample and store the paths of every entity once, for the training and the validation graph
Dict_train = store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity)

Dict_valid = store_subgraph_dicts(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                         relation2id, entity2id, id2relation, id2entity)

#big-batch containers: one list per path slot ('1'..'6') for source and target,
#plus the relations and the labels
train_s_list, train_t_list = {k: [] for k in '123456'}, {k: [] for k in '123456'}
train_r_list, train_y_list = [], []

valid_s_list, valid_t_list = {k: [] for k in '123456'}, {k: [] for k in '123456'}
valid_r_list, valid_y_list = [], []

#######################################
###build the big-batches###############

#training big-batch
build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train,
                      relation2id, entity2id, id2relation, id2entity)

#validation big-batch
build_big_batches_subgraph(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_s_list, valid_t_list, valid_r_list, valid_y_list, Dict_valid,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################
#The validation graph can be too small or too sparse to yield enough samples.
#With fewer than 100 validation samples we instead hold out the last 20% of
#the training big-batch for validation.
if len(valid_y_list) >= 100:
    #training arrays: the complete training big-batch
    (x_train_s_1, x_train_s_2, x_train_s_3,
     x_train_s_4, x_train_s_5, x_train_s_6) = (
        np.asarray(train_s_list[k], dtype='int') for k in '123456')
    (x_train_t_1, x_train_t_2, x_train_t_3,
     x_train_t_4, x_train_t_5, x_train_t_6) = (
        np.asarray(train_t_list[k], dtype='int') for k in '123456')
    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #validation arrays: the complete validation big-batch
    (x_valid_s_1, x_valid_s_2, x_valid_s_3,
     x_valid_s_4, x_valid_s_5, x_valid_s_6) = (
        np.asarray(valid_s_list[k], dtype='int') for k in '123456')
    (x_valid_t_1, x_valid_t_2, x_valid_t_3,
     x_valid_t_4, x_valid_t_5, x_valid_t_6) = (
        np.asarray(valid_t_list[k], dtype='int') for k in '123456')
    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)

    #training arrays: first 80% of the training big-batch
    (x_train_s_1, x_train_s_2, x_train_s_3,
     x_train_s_4, x_train_s_5, x_train_s_6) = (
        np.asarray(train_s_list[k][:split], dtype='int') for k in '123456')
    (x_train_t_1, x_train_t_2, x_train_t_3,
     x_train_t_4, x_train_t_5, x_train_t_6) = (
        np.asarray(train_t_list[k][:split], dtype='int') for k in '123456')
    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #validation arrays: remaining 20% of the training big-batch
    (x_valid_s_1, x_valid_s_2, x_valid_s_3,
     x_valid_s_4, x_valid_s_5, x_valid_s_6) = (
        np.asarray(train_s_list[k][split:], dtype='int') for k in '123456')
    (x_valid_t_1, x_valid_t_2, x_valid_t_3,
     x_valid_t_4, x_valid_t_5, x_valid_t_6) = (
        np.asarray(train_t_list[k][split:], dtype='int') for k in '123456')
    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training: six source-path inputs, six target-path inputs, then the relation
model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, x_train_s_4, x_train_s_5, x_train_s_6,
             x_train_t_1, x_train_t_2, x_train_t_3, x_train_t_4, x_train_t_5, x_train_t_6,
             x_train_r], y_train,
          validation_data=([x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_s_4, x_valid_s_5, x_valid_s_6,
                            x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_t_4, x_valid_t_5, x_valid_t_6,
                            x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

#save model and weights
one_hop_add_h5 = one_hop_model_name + '.h5'
one_hop_save_dir = os.path.join(os.getcwd(), '../weight_bin')
os.makedirs(one_hop_save_dir, exist_ok=True)

one_hop_model_path = os.path.join(one_hop_save_dir, one_hop_add_h5)
model_2.save(one_hop_model_path)
print('Save model')
del model_2, Dict_train, Dict_valid

generating and storing paths for the path-based model 100 3861
generating and storing paths for the path-based model 200 3861
generating and storing paths for the path-based model 300 3861
generating and storing paths for the path-based model 400 3861
generating and storing paths for the path-based model 500 3861
generating and storing paths for the path-based model 600 3861
generating and storing paths for the path-based model 700 3861
generating and storing paths for the path-based model 800 3861
generating and storing paths for the path-based model 900 3861
generating and storing paths for the path-based model 1000 3861
generating and storing paths for the path-based model 1100 3861
generating and storing paths for the path-based model 1200 3861
generating and storing paths for the path-based model 1300 3861
generating and storing paths for the path-based model 1400 3861
generating and storing paths for the path-based model 1500 3861
generating and storing paths for the path-based m

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
#experiment configuration: dataset name and model tag (both end up in the saved file names)
data_name, model_id = 'WN18RR_v4', 'SiaLP_6_new'
#path-length bounds: a shared lower bound and one upper bound each for the
#path-based and the subgraph-based model
lower_bound = 1
upper_bound_path, upper_bound_subg = 10, 3

In [2]:
#define the file names (without extension) used for saving and loading
model_name = ''.join(('Model_', model_id, '_', data_name))
one_hop_model_name = ''.join(('One_hop_model_', model_id, '_', data_name))
ids_name = ''.join(('IDs_', model_id, '_', data_name))

In [3]:
ids_name

'IDs_SiaLP_6_new_WN18RR_v4'

In [4]:
one_hop_model_name

'One_hop_model_SiaLP_6_new_WN18RR_v4'

In [5]:
model_name

'Model_SiaLP_6_new_WN18RR_v4'

In [6]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [7]:
class LoadKG:
    """Load knowledge-graph triples from a text file into id-based structures."""

    def __init__(self):

        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """Read ``source relation target`` lines and update the given containers in place.

        Every container may already hold content from an earlier call; new
        relations and entities simply receive the next free ids. Ids are
        assigned in order of first appearance in the file, so the result is
        reproducible for a given file.

        Parameters
        ----------
        data_path : str
            Text file with one whitespace-separated triple per line. Blank
            lines are ignored.
        one_hop : dict
            Entity id -> list of ``(relation id, neighbour id)``, holding both
            the original and the inverse direction of every triple. Values
            are lists again when the method returns.
        data : set
            Receives the ``(s, r, t)`` id triples (original direction only).
        s_t_r : dict
            ``(s, t)`` -> set of relation ids between the two entities
            (both directions are recorded).
        entity2id, id2entity : dict
            Entity name <-> id.
        relation2id, id2relation : dict
            Relation name <-> id. A relation gets the next free id and its
            inverse ``'_inverse_' + name`` the id after it, so, starting from
            an empty (or even-sized) dict, original relations have even ids
            and inverse relations odd ids.

        Raises
        ------
        IndexError
            If a non-blank line has fewer than three fields.
        """
        #a dict de-duplicates the triples like a set does, but keeps file order
        triples = dict()

        ####load the train, valid and test set##########
        with open(data_path, 'r') as f:

            for line in f:

                fields = tuple(line.split())

                if fields: #skip blank lines (e.g. at the end of the file)

                    triples[fields] = None

        ####relation dict#################
        index = len(relation2id)

        for key in triples:

            relation = key[1]

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation
                index += 1

                #the inverse relation directly follows its original relation
                iv_r = '_inverse_' + relation

                relation2id[iv_r] = index
                id2relation[index] = iv_r
                index += 1

        ####entity dict###################
        index = len(entity2id)

        for key in triples:

            for entity in (key[0], key[2]):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        def neighbours(e):
            #Return the neighbour *set* of e. A list left over from an earlier
            #call is converted first, so the same one_hop can be filled from
            #several files.
            current = one_hop.get(e)

            if not isinstance(current, set):

                current = set() if current is None else set(current)
                one_hop[e] = current

            return current

        #create the set of triples using id instead of string
        for key in triples:

            s = entity2id[key[0]]
            r = relation2id[key[1]]
            t = entity2id[key[2]]

            #even id: original relation, its inverse is the next id;
            #odd id: inverse relation, its original is the previous id
            r_inv = r + 1 if r % 2 == 0 else r - 1

            data.add((s, r, t))

            s_t_r.setdefault((s, t), set()).add(r)
            neighbours(s).add((r, t))

            s_t_r.setdefault((t, s), set()).add(r_inv)
            neighbours(t).add((r_inv, s))

        #change each set in one_hop to list (the path finder indexes into it)
        for e in one_hop:

            one_hop[e] = list(one_hop[e])

In [8]:
class ObtainPathsByDynamicProgramming:
    """Enumerate relation paths of bounded length that start at a given entity.

    The search is a randomised depth-first recursion over ``one_hop`` which
    never visits an entity twice on the same path. Three limits keep it cheap:

    * ``amount_bd``: at most this many neighbours of a node are followed,
    * ``size_bd``: at most this many paths are stored per target entity,
    * ``threshold``: at most this many expansion steps are counted per
      path length.
    """

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):

        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion

        self.size_bd = size_bd #size bound limit the number of paths to a target entity t

        #number of times paths with specific length been performed for recursion
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        """Find paths from ``s`` to other entities, using recursion.

        One may refer to LeetCode Problem 797 for details:
            https://leetcode.com/problems/all-paths-from-source-to-target/

        Parameters
        ----------
        mode : str
            Which targets are recorded: ``'direct_neighbour'`` (only the
            direct neighbours of ``s``), ``'target_specified'`` (only
            ``t_input``) or ``'any_target'`` (every key of ``one_hop``).
        s : hashable
            Start entity; must be a key of ``one_hop``.
        t_input : hashable
            Target entity, only read in ``'target_specified'`` mode.
        lower_bd, upper_bd : int
            Inclusive bounds on the length of the recorded paths, with
            ``1 <= lower_bd <= upper_bd``.
        one_hop : dict
            Entity -> indexable sequence of ``(relation, neighbour)`` tuples.

        Returns
        -------
        tuple
            ``(res, count_dict)``: ``res`` maps a target entity to the set of
            paths (tuples of relations) leading to it; ``count_dict`` maps a
            path length to the number of expansion steps counted there.

        Raises
        ------
        TypeError
            If a bound is not an ``int``, is smaller than 1, or
            ``lower_bd > upper_bd`` (exception type kept for compatibility).
        ValueError
            If ``s`` is not in ``one_hop`` or ``mode`` is unknown.
        """
        #exact type check on purpose: bool is rejected, as it always was
        if type(lower_bd) is not int or lower_bd < 1:

            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")

        if type(upper_bd) is not int or upper_bd < 1:

            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")

        if lower_bd > upper_bd:

            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:

            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #qualified_t contains the entities whose paths are added to the result
        if mode == 'direct_neighbour':

            qualified_t = {pair[1] for pair in one_hop[s]}

        elif mode == 'target_specified':

            qualified_t = {t_input}

        elif mode == 'any_target':

            qualified_t = set(one_hop)

        else:

            raise ValueError('not a valid mode')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)

        #number of expansion steps performed so far, per current path length
        count_dict = defaultdict(int)

        def helper(node, path, on_path_en):
            #Arrived at node via the relations in path; on_path_en holds the
            #entities already on that path (including s and node).
            depth = len(path)

            #record the path if its length is within the bounds, node is a
            #wanted target and node is not yet full w.r.t. size_bd
            if (lower_bd <= depth <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):

                res[node].add(tuple(path))

            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (depth < upper_bd) and (count_dict[depth] <= self.threshold):

                #visit the neighbours in random order ...
                order = list(range(len(one_hop[node])))
                random.shuffle(order)

                #... and follow at most amount_bd of them
                for i in order[:self.amount_bd]:

                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]

                    #add to count_dict even if eventually this step not proceed
                    count_dict[depth] += 1

                    #if t not on the path and we not exceed the computation threshold,
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[depth] <= self.threshold):

                        helper(t, path + [r], on_path_en | {t})

        helper(s, [], {s})

        return (res, count_dict)

In [9]:
#instantiate the helper classes defined above:
#Class_1 is the knowledge-graph loader, Class_2 is the path finder used by the scoring functions
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [10]:
#read back the pickled dictionary holding the data splits and the id mappings
#NOTE: pickle can execute arbitrary code while loading; only open trusted files
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)

#training split
one_hop, data, s_t_r = Dict['one_hop'], Dict['data'], Dict['s_t_r']

#validation split
one_hop_valid, data_valid, s_t_r_valid = (Dict['one_hop_valid'],
                                          Dict['data_valid'],
                                          Dict['s_t_r_valid'])

#test split
one_hop_test, data_test, s_t_r_test = (Dict['one_hop_test'],
                                       Dict['data_test'],
                                       Dict['s_t_r_test'])

#id dictionaries shared by all splits
entity2id, id2entity = Dict['entity2id'], Dict['id2entity']
relation2id, id2relation = Dict['relation2id'], Dict['id2relation']

#independent snapshots of the dictionaries as they were stored,
#taken before the inductive data is loaded into the shared ones
entity2id_ini, id2entity_ini = deepcopy(entity2id), deepcopy(id2entity)
relation2id_ini, id2relation_ini = deepcopy(relation2id), deepcopy(id2relation)

#number of relation ids; the scoring functions also use it as the padding id
num_r = len(id2relation)
num_r

18

In [11]:
#display the file stem of the pickle that was loaded above
ids_name

'IDs_SiaLP_6_new_WN18RR_v4'

In [12]:
#display the file stem of the saved model weights
model_name

'Model_SiaLP_6_new_WN18RR_v4'

In [13]:
#load the saved model from ../weight_bin;
#this is the model handed to the path-based scoring functions
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-05-15 19:21:21.662337: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
#load the one-hop neighbor model from ../weight_bin;
#this is the model handed to the subgraph scoring functions
model_2 = keras.models.load_model('../weight_bin/' + one_hop_model_name + '.h5')

In [15]:
#locations of the train / valid / test files of the inductive split
ind_train_path = f'../data/{data_name}_ind/train.txt'
ind_valid_path = f'../data/{data_name}_ind/valid.txt'
ind_test_path = f'../data/{data_name}_ind/test.txt'

In [16]:
#load the training triples of the inductive split
one_hop_ind, data_ind, s_t_r_ind = dict(), set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#populate the containers; the shared id dictionaries are handed to the loader as well
Class_1.load_train_data(ind_train_path,
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dictionary must not have grown
if len_1 != len_0:
    raise ValueError('unseen relation!')

In [17]:
#entity count before and after loading, and the number of loaded inductive train triples
print(size_0, size_1, len(data_ind))

3861 10945 12334


In [18]:
#load the test triples of the inductive split
one_hop_ind_test, data_ind_test, s_t_r_ind_test = dict(), set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#populate the containers; the shared id dictionaries are handed to the loader as well
Class_1.load_train_data(ind_test_path,
                        one_hop_ind_test, data_ind_test, s_t_r_ind_test,
                        entity2id, id2entity, relation2id, id2relation)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dictionary must not have grown
if len_1 != len_0:
    raise ValueError('unseen relation!')

In [19]:
#entity count before and after loading, and the number of loaded inductive test triples
print(size_0, size_1, len(data_ind_test))

10945 10945 1429


In [20]:
#load the validation triples of the inductive split;
#they are needed to filter out existing triples when ranking
one_hop_ind_valid, data_ind_valid, s_t_r_ind_valid = dict(), set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#populate the containers; the shared id dictionaries are handed to the loader as well
Class_1.load_train_data(ind_valid_path,
                        one_hop_ind_valid, data_ind_valid, s_t_r_ind_valid,
                        entity2id, id2entity, relation2id, id2relation)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dictionary must not have grown
if len_1 != len_0:
    raise ValueError('unseen relation!')

In [21]:
#entity count before and after loading, and the number of loaded inductive valid triples
print(size_0, size_1, len(data_ind_valid))

10945 10945 1394


In [22]:
#current size of the entity dictionary versus the snapshot taken before the inductive data was loaded
print(len(entity2id), len(entity2id_ini))

10945 3861


In [23]:
#split the entity ids into those already present in the initial snapshot
#and those that were added afterwards
all_ent_set = set(id2entity)
ini_ent_set = all_ent_set & set(id2entity_ini)
new_ent_set = all_ent_set - ini_ent_set

print(len(ini_ent_set), len(new_ent_set), len(all_ent_set))

3861 7084 10945


In [24]:
#count the inductive test triples whose head or tail id
#is already a key of the initial entity dictionary
overlapping = sum(
    1 for ele in data_ind_test
    if ele[0] in id2entity_ini or ele[2] in id2entity_ini
)

overlapping

0

In [25]:
#count the inductive valid triples whose head or tail id
#is already a key of the initial entity dictionary
overlapping = sum(
    1 for ele in data_ind_valid
    if ele[0] in id2entity_ini or ele[2] in id2entity_ini
)

overlapping

0

In [26]:
#count the inductive train triples whose head or tail id
#is already a key of the initial entity dictionary
overlapping = sum(
    1 for ele in data_ind
    if ele[0] in id2entity_ini or ele[2] in id2entity_ini
)

overlapping

0

In [27]:
#the function to do path-based relation scoring
def path_based_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    """Score every even relation id for the pair (s, t) from sampled s->t paths.

    Paths between s and t are gathered with three 'target_specified' calls to
    Class_2.obtain_paths. If at least three distinct paths are found, ten
    random path triples are drawn; each triple is paired with every even
    relation id and scored by `model`. The scores are averaged over the ten
    draws.

    Parameters:
        s, t: source and target entity ids.
        lower_bd, upper_bd: path length bounds handed to obtain_paths;
            upper_bd is also the length every path is padded to.
        one_hop: one-hop neighbourhood dictionary handed to obtain_paths.
        id2relation: relation dictionary; its keys must be 0..len-1.
        model: model whose predict() takes [path_1, path_2, path_3, relation].

    Returns:
        defaultdict(float) mapping each even relation id to its mean score.
        It is empty when fewer than three paths were found.

    Raises:
        ValueError: if some id in range(len(id2relation)) is not a key of
            id2relation (only checked when scoring actually happens).

    Relies on the notebook globals Class_2 and num_r (the padding id).
    """
    num_rounds = 10

    def pad(path):
        #fill up to upper_bd with the padding id; paths are expected to be
        #at most upper_bd long
        return list(path) + [num_r] * (upper_bd - len(path))

    #union of the paths found over several searches
    #(presumably the search is randomized, so repeated calls can differ)
    path_holder = set()
    for _ in range(3):
        result, _ = Class_2.obtain_paths('target_specified',
                                         s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            path_holder.update(result[t])

    path_holder = list(path_holder)
    random.shuffle(path_holder)

    score_dict = defaultdict(float)

    if len(path_holder) >= 3:

        #validate the relation ids once instead of once per round
        for i in range(len(id2relation)):
            if i not in id2relation:
                raise ValueError ('error when generating id2relation')

        #only care about initial relations, which have the even ids
        relation_ids = [i for i in range(len(id2relation)) if i % 2 == 0]

        list_1, list_2, list_3, list_r = [], [], [], []

        for _ in range(num_rounds):
            path_1, path_2, path_3 = random.sample(path_holder, 3)
            row_1, row_2, row_3 = pad(path_1), pad(path_2), pad(path_3)

            #the same path triple is scored against every candidate relation
            for rel in relation_ids:
                list_1.append(row_1)
                list_2.append(row_2)
                list_3.append(row_3)
                list_r.append([rel])

        #one batched prediction: rows are ordered round by round,
        #and relation by relation inside each round
        pred = model.predict([np.array(list_1), np.array(list_2),
                              np.array(list_3), np.array(list_r)], verbose = 0)

        for row in range(pred.shape[0]):
            rel = relation_ids[row % len(relation_ids)]
            #.item() extracts the single score without the deprecated
            #float(<1-element array>) conversion
            score_dict[rel] += float(np.asarray(pred[row]).item())

        #average the score over the rounds
        for rel in relation_ids:
            score_dict[rel] = score_dict[rel] / float(num_rounds)

    print(len(score_dict), len(path_holder))

    return score_dict

In [28]:
#the function to do path-based triple scoring: input one triple
def path_based_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model):
    """Score the single triple (s, r, t) from sampled s->t paths.

    Paths between s and t are gathered with three 'target_specified' calls to
    Class_2.obtain_paths. If at least three distinct paths are found, ten
    random path triples are drawn, each is scored together with relation r in
    one batched call to `model`, and the ten scores are averaged.

    Parameters:
        s, r, t: source entity id, relation id and target entity id.
        lower_bd, upper_bd: path length bounds handed to obtain_paths;
            upper_bd is also the length every path is padded to.
        one_hop: one-hop neighbourhood dictionary handed to obtain_paths.
        id2relation: unused here; kept for a signature parallel to the
            relation scoring function.
        model: model whose predict() takes [path_1, path_2, path_3, relation].

    Returns:
        The mean score as a float, or 0. when fewer than three paths were found.

    Relies on the notebook globals Class_2 and num_r (the padding id).
    """
    num_rounds = 10

    def pad(path):
        #fill up to upper_bd with the padding id; paths are expected to be
        #at most upper_bd long
        return list(path) + [num_r] * (upper_bd - len(path))

    #union of the paths found over several searches
    #(presumably the search is randomized, so repeated calls can differ)
    path_holder = set()
    for _ in range(3):
        result, _ = Class_2.obtain_paths('target_specified',
                                         s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            path_holder.update(result[t])

    path_holder = list(path_holder)
    random.shuffle(path_holder)

    score = 0.

    if len(path_holder) >= 3:

        list_1, list_2, list_3, list_r = [], [], [], []

        #one row per random draw of three paths
        for _ in range(num_rounds):
            path_1, path_2, path_3 = random.sample(path_holder, 3)

            list_1.append(pad(path_1))
            list_2.append(pad(path_2))
            list_3.append(pad(path_3))
            list_r.append([r])

        pred = model.predict([np.array(list_1), np.array(list_2),
                              np.array(list_3), np.array(list_r)], verbose = 0)

        for i in range(pred.shape[0]):
            #.item() extracts the single score without the deprecated
            #float(<1-element array>) conversion
            score += float(np.asarray(pred[i]).item())

        #average the score
        score = score/float(num_rounds)

    return score

In [29]:
#subgraph based relation scoring
def subgraph_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    """Score every even relation id for (s, t) from paths leaving s and leaving t.

    Paths starting at s and paths starting at t are gathered with three
    'any_target' calls each to Class_2.obtain_paths. If both sides have at
    least six distinct paths, ten rounds are run: each round draws six paths
    per side, pairs them with every even relation id and scores the rows with
    `model_2`. The scores are averaged over the ten rounds.

    Parameters:
        s, t: source and target entity ids.
        lower_bd, upper_bd: path length bounds handed to obtain_paths;
            upper_bd is also the length every path is padded to.
        one_hop: one-hop neighbourhood dictionary handed to obtain_paths.
        id2relation: relation dictionary; its keys must be 0..len-1.
        model_2: model whose predict() takes six source-path inputs, six
            target-path inputs and the relation input, in that order.

    Returns:
        defaultdict(float) mapping each even relation id to its mean score.
        It is empty when either side has fewer than six paths.

    Raises:
        ValueError: if some id in range(len(id2relation)) is not a key of
            id2relation (only checked when scoring actually happens).

    Relies on the notebook globals Class_2 and num_r (the padding id).
    """
    num_rounds = 10
    paths_per_side = 6

    def pad(path):
        #fill up to upper_bd with the padding id; paths are expected to be
        #at most upper_bd long
        return list(path) + [num_r] * (upper_bd - len(path))

    path_s, path_t = set(), set() #sets holding all the paths from s or t

    for _ in range(3):

        #obtain the paths out from s or t by "any target" mode
        result_s, _ = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, _ = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path set, whatever entity they end at
        for e in result_s:
            path_s.update(result_s[e])
        for e in result_t:
            path_t.update(result_t[e])

    #final output: the score dict
    score_dict = defaultdict(float)

    #both sides need at least six paths, since six are sampled per side
    if len(path_s) >= paths_per_side and len(path_t) >= paths_per_side:

        #change to lists so that random.sample can draw from them
        path_s, path_t = list(path_s), list(path_t)

        #validate the relation ids once instead of once per round
        for i in range(len(id2relation)):
            if i not in id2relation:
                raise ValueError ('error when generating id2relation')

        #only the initial (forward) relations, which have the even ids
        relation_ids = [i for i in range(len(id2relation)) if i % 2 == 0]

        #one column per model input: six for the source side, then six for the target side
        columns = [[] for _ in range(2 * paths_per_side)]
        list_r = []

        for _ in range(num_rounds):

            #randomly obtain six paths per side (source first, then target)
            sampled = random.sample(path_s, paths_per_side) + random.sample(path_t, paths_per_side)
            padded = [pad(path) for path in sampled]

            #the same twelve paths are scored against every candidate relation
            for rel in relation_ids:
                for column, row in zip(columns, padded):
                    column.append(row)
                list_r.append([rel])

        #one batched prediction: rows are ordered round by round,
        #and relation by relation inside each round
        inputs = [np.array(column) for column in columns] + [np.array(list_r)]
        pred = model_2.predict(inputs, verbose = 0)

        for row in range(pred.shape[0]):
            rel = relation_ids[row % len(relation_ids)]
            #.item() extracts the single score without the deprecated
            #float(<1-element array>) conversion
            score_dict[rel] += float(np.asarray(pred[row]).item())

        #average the score over the rounds
        for rel in relation_ids:
            score_dict[rel] = score_dict[rel] / float(num_rounds)

    print(len(score_dict), len(path_s), len(path_t))

    return score_dict

In [30]:
#subgraph based triple scoring
def subgraph_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    """Score the single triple (s, r, t) from paths leaving s and leaving t.

    Paths starting at s and paths starting at t are gathered with three
    'any_target' calls each to Class_2.obtain_paths. If both sides have at
    least six distinct paths, ten rows are built, each from six random paths
    per side plus relation r. They are scored in one batched call to
    `model_2` and the ten scores are averaged.

    Parameters:
        s, r, t: source entity id, relation id and target entity id.
        lower_bd, upper_bd: path length bounds handed to obtain_paths;
            upper_bd is also the length every path is padded to.
        one_hop: one-hop neighbourhood dictionary handed to obtain_paths.
        id2relation: unused here; kept for a signature parallel to the
            relation scoring function.
        model_2: model whose predict() takes six source-path inputs, six
            target-path inputs and the relation input, in that order.

    Returns:
        The mean score as a float, or 0. when either side has fewer than six paths.

    Relies on the notebook globals Class_2 and num_r (the padding id).
    """
    num_rounds = 10
    paths_per_side = 6

    def pad(path):
        #fill up to upper_bd with the padding id; paths are expected to be
        #at most upper_bd long
        return list(path) + [num_r] * (upper_bd - len(path))

    path_s, path_t = set(), set() #sets holding all the paths from s or t

    for _ in range(3):

        #obtain the paths out from s or t by "any target" mode
        result_s, _ = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, _ = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path set, whatever entity they end at
        for e in result_s:
            path_s.update(result_s[e])
        for e in result_t:
            path_t.update(result_t[e])

    #final output: the score
    score = 0.

    #both sides need at least six paths, since six are sampled per side
    if len(path_s) >= paths_per_side and len(path_t) >= paths_per_side:

        #change to lists so that random.sample can draw from them
        path_s, path_t = list(path_s), list(path_t)

        #one column per model input: six for the source side, then six for the target side
        columns = [[] for _ in range(2 * paths_per_side)]
        list_r = []

        for _ in range(num_rounds):

            #randomly obtain six paths per side (source first, then target)
            sampled = random.sample(path_s, paths_per_side) + random.sample(path_t, paths_per_side)

            #append the paths: the padding id fills the end of the shorter paths
            for column, path in zip(columns, sampled):
                column.append(pad(path))

            list_r.append([r])

        #change to arrays
        inputs = [np.array(column) for column in columns] + [np.array(list_r)]
        pred = model_2.predict(inputs, verbose = 0)

        for i in range(pred.shape[0]):
            #.item() extracts the single score without the deprecated
            #float(<1-element array>) conversion
            score += float(np.asarray(pred[i]).item())

        #average the score
        score = score/float(num_rounds)

    return score

#### Not fine tuned 

In [31]:
########################################################
#obtain the Hits@N for relation prediction##############

#evaluate on every triple of the inductive test set
selected = list(data_ind_test)

#all the triple sets loaded in this notebook; a higher-ranked relation that forms
#one of these triples with (s, t) does not count against the true relation
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#running totals over the triples evaluated so far
Hits_at_1, Hits_at_3, Hits_at_10 = 0, 0, 0
MRR_raw = 0.

for i, triple in enumerate(selected):

    s_true, r_true, t_true = triple[0], triple[1], triple[2]

    #path-based scores of all candidate relations
    score_dict_path = path_based_relation_scoring(s_true, t_true, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    #one-hop neighbourhood (subgraph) based scores of all candidate relations
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #final score: the sum of both scores
    score_dict = defaultdict(float)

    for partial in (score_dict_path, score_dict_subg):
        for r in partial:
            score_dict[r] += partial[r]

    #[... [score, r], ...] for the even (initial) relation ids;
    #a relation without any score gets 0.0
    temp_list = list()

    for r in id2relation:
        if r % 2 == 0:
            temp_list.append([score_dict.get(r, 0.0), r])

    #best score first; sorted() is stable, so ties keep the id2relation order
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)

    #walk down the ranking until the true relation is reached, counting
    #the higher-ranked relations that form an already known triple
    p = 0
    exist_tri = 0

    while p < len(sorted_list) and sorted_list[p][1] != r_true:

        candidate = (s_true, sorted_list[p][1], t_true)

        if any(candidate in triples for triples in known_triple_sets):
            exist_tri += 1

        p += 1

    #zero-based rank after removing the known triples
    rank = p - exist_tri

    Hits_at_1 += int(rank == 0)
    Hits_at_3 += int(rank < 3)
    Hits_at_10 += int(rank < 10)

    MRR_raw += 1./float(rank + 1.)

    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', rank,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

9 27
9 54 65
checkcorrect 4 4 real score 1.8632494926452636 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 1429
9 23
9 45 47
checkcorrect 4 4 real score 1.631666561961174 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 1429
0 0
0 5 11
checkcorrect 0 0 real score 0.0 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 2 1429
9 150
9 42 58
checkcorrect 4 4 real score 1.4165811516344546 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 3 1429
9 16
9 6 10
checkcorrect 0 0 real score 1.1101392440497875 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 4 1429
9 57
9 41 20
checkcorrect 4 4 real score 1.3651697531342506 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 5 1429
0 1
9 6 20
checkcorrect 4 4 real score 0.7655016362667084 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 6 1429
9 5
9

0 0
0 68 2
checkcorrect 6 6 real score 0.0 Hits@1 0.86 Hits@3 0.96 Hits@10 1.0 MRR 0.9141666666666666 cur_rank 3 abs_cur_rank 3 total_num 49 1429
9 139
9 35 46
checkcorrect 4 4 real score 1.725605207681656 Hits@1 0.8627450980392157 Hits@3 0.9607843137254902 Hits@10 1.0 MRR 0.9158496732026142 cur_rank 0 abs_cur_rank 0 total_num 50 1429
9 9
9 29 50
checkcorrect 0 0 real score 0.7738050640793517 Hits@1 0.8653846153846154 Hits@3 0.9615384615384616 Hits@10 1.0 MRR 0.9174679487179487 cur_rank 0 abs_cur_rank 0 total_num 51 1429
0 1
9 10 25
checkcorrect 4 4 real score 0.7295501708984375 Hits@1 0.8679245283018868 Hits@3 0.9622641509433962 Hits@10 1.0 MRR 0.9190251572327043 cur_rank 0 abs_cur_rank 0 total_num 52 1429
9 25
9 15 12
checkcorrect 0 0 real score 1.9064445436000823 Hits@1 0.8703703703703703 Hits@3 0.9629629629629629 Hits@10 1.0 MRR 0.9205246913580246 cur_rank 0 abs_cur_rank 0 total_num 53 1429
0 1
9 45 69
checkcorrect 4 4 real score 0.6537943989038467 Hits@1 0.8727272727272727 Hits@3 

9 10 8
checkcorrect 4 4 real score 0.3013900235295296 Hits@1 0.8315789473684211 Hits@3 0.9789473684210527 Hits@10 1.0 MRR 0.9004385964912279 cur_rank 1 abs_cur_rank 1 total_num 94 1429
0 1
9 11 32
checkcorrect 4 4 real score 0.7040649354457855 Hits@1 0.8333333333333334 Hits@3 0.9791666666666666 Hits@10 1.0 MRR 0.9014756944444443 cur_rank 0 abs_cur_rank 0 total_num 95 1429
9 5
9 33 12
checkcorrect 4 4 real score 1.681504249572754 Hits@1 0.8350515463917526 Hits@3 0.979381443298969 Hits@10 1.0 MRR 0.9024914089347078 cur_rank 0 abs_cur_rank 0 total_num 96 1429
0 0
9 25 9
checkcorrect 0 0 real score 0.674074387550354 Hits@1 0.8367346938775511 Hits@3 0.9795918367346939 Hits@10 1.0 MRR 0.903486394557823 cur_rank 0 abs_cur_rank 0 total_num 97 1429
9 5
9 24 20
checkcorrect 4 4 real score 1.7405902445316315 Hits@1 0.8383838383838383 Hits@3 0.9797979797979798 Hits@10 1.0 MRR 0.9044612794612794 cur_rank 0 abs_cur_rank 0 total_num 98 1429
0 0
9 21 35
checkcorrect 4 4 real score 0.2227696470916271 H

9 5
9 9 29
checkcorrect 0 0 real score 1.1221760164946317 Hits@1 0.8345323741007195 Hits@3 0.9784172661870504 Hits@10 1.0 MRR 0.9013788968824938 cur_rank 0 abs_cur_rank 0 total_num 138 1429
9 150
9 14 27
checkcorrect 0 0 real score 0.9088845839723945 Hits@1 0.8357142857142857 Hits@3 0.9785714285714285 Hits@10 1.0 MRR 0.9020833333333331 cur_rank 0 abs_cur_rank 0 total_num 139 1429
0 0
9 14 33
checkcorrect 0 0 real score 0.323210134729743 Hits@1 0.8297872340425532 Hits@3 0.9787234042553191 Hits@10 1.0 MRR 0.8992316784869975 cur_rank 1 abs_cur_rank 1 total_num 140 1429
0 0
0 11 5
checkcorrect 0 0 real score 0.0 Hits@1 0.8309859154929577 Hits@3 0.9788732394366197 Hits@10 1.0 MRR 0.8999413145539904 cur_rank 0 abs_cur_rank 0 total_num 141 1429
9 150
9 37 33
checkcorrect 4 4 real score 1.6051250144839286 Hits@1 0.8321678321678322 Hits@3 0.9790209790209791 Hits@10 1.0 MRR 0.9006410256410253 cur_rank 0 abs_cur_rank 0 total_num 142 1429
0 1
9 44 18
checkcorrect 4 4 real score 0.672482642531395 H

0 1
9 29 34
checkcorrect 4 4 real score 0.7152851179242135 Hits@1 0.8469945355191257 Hits@3 0.9836065573770492 Hits@10 1.0 MRR 0.9096083788706738 cur_rank 0 abs_cur_rank 0 total_num 182 1429
0 1
9 14 7
checkcorrect 4 4 real score 0.6547450482845306 Hits@1 0.8478260869565217 Hits@3 0.9836956521739131 Hits@10 1.0 MRR 0.9100996376811593 cur_rank 0 abs_cur_rank 0 total_num 183 1429
0 1
9 6 7
checkcorrect 4 4 real score 0.6316552221775055 Hits@1 0.8486486486486486 Hits@3 0.9837837837837838 Hits@10 1.0 MRR 0.9105855855855854 cur_rank 0 abs_cur_rank 0 total_num 184 1429
9 21
9 29 17
checkcorrect 4 4 real score 1.742071169614792 Hits@1 0.8494623655913979 Hits@3 0.9838709677419355 Hits@10 1.0 MRR 0.9110663082437275 cur_rank 0 abs_cur_rank 0 total_num 185 1429
9 51
9 81 34
checkcorrect 4 4 real score 1.829626852273941 Hits@1 0.8502673796791443 Hits@3 0.983957219251337 Hits@10 1.0 MRR 0.9115418894830658 cur_rank 0 abs_cur_rank 0 total_num 186 1429
9 10
9 21 10
checkcorrect 4 4 real score 1.126156

9 148
9 46 37
checkcorrect 0 0 real score 1.4157205879688264 Hits@1 0.8634361233480177 Hits@3 0.986784140969163 Hits@10 1.0 MRR 0.9197870778267253 cur_rank 0 abs_cur_rank 0 total_num 226 1429
9 5
9 40 45
checkcorrect 4 4 real score 1.7149086594581604 Hits@1 0.8640350877192983 Hits@3 0.9868421052631579 Hits@10 1.0 MRR 0.9201388888888888 cur_rank 0 abs_cur_rank 0 total_num 227 1429
0 2
9 46 19
checkcorrect 4 4 real score 0.7239090502262115 Hits@1 0.8646288209606987 Hits@3 0.9868995633187773 Hits@10 1.0 MRR 0.9204876273653566 cur_rank 0 abs_cur_rank 0 total_num 228 1429
9 150
9 44 46
checkcorrect 0 0 real score 1.0801763579249382 Hits@1 0.8652173913043478 Hits@3 0.9869565217391304 Hits@10 1.0 MRR 0.9208333333333333 cur_rank 0 abs_cur_rank 0 total_num 229 1429
0 0
9 18 10
checkcorrect 0 0 real score 0.0812322661280632 Hits@1 0.8614718614718615 Hits@3 0.9826839826839827 Hits@10 1.0 MRR 0.9179292929292929 cur_rank 3 abs_cur_rank 3 total_num 230 1429
0 1
9 18 14
checkcorrect 4 4 real score 0.

9 102
9 54 36
checkcorrect 4 4 real score 1.189938947558403 Hits@1 0.8597785977859779 Hits@3 0.981549815498155 Hits@10 1.0 MRR 0.9165129151291513 cur_rank 0 abs_cur_rank 0 total_num 270 1429
0 2
9 6 13
checkcorrect 4 4 real score 0.6912766218185424 Hits@1 0.8602941176470589 Hits@3 0.9816176470588235 Hits@10 1.0 MRR 0.9168198529411765 cur_rank 0 abs_cur_rank 0 total_num 271 1429
0 1
9 34 12
checkcorrect 4 4 real score 0.7208852291107177 Hits@1 0.8608058608058609 Hits@3 0.9816849816849816 Hits@10 1.0 MRR 0.9171245421245421 cur_rank 0 abs_cur_rank 0 total_num 272 1429
9 21
9 29 53
checkcorrect 4 4 real score 1.2607582449913024 Hits@1 0.8613138686131386 Hits@3 0.9817518248175182 Hits@10 1.0 MRR 0.9174270072992701 cur_rank 0 abs_cur_rank 0 total_num 273 1429
0 1
9 52 15
checkcorrect 4 4 real score 0.6686168670654297 Hits@1 0.8618181818181818 Hits@3 0.9818181818181818 Hits@10 1.0 MRR 0.9177272727272727 cur_rank 0 abs_cur_rank 0 total_num 274 1429
0 1
9 31 11
checkcorrect 4 4 real score 0.908

9 23 16
checkcorrect 8 8 real score 0.9172757565975189 Hits@1 0.8598726114649682 Hits@3 0.9840764331210191 Hits@10 1.0 MRR 0.9167993630573247 cur_rank 0 abs_cur_rank 0 total_num 313 1429
0 1
9 16 39
checkcorrect 4 4 real score 0.6382976397871971 Hits@1 0.8571428571428571 Hits@3 0.9841269841269841 Hits@10 1.0 MRR 0.9154761904761903 cur_rank 1 abs_cur_rank 1 total_num 314 1429
9 4
9 24 15
checkcorrect 0 0 real score 1.6789037764072416 Hits@1 0.8575949367088608 Hits@3 0.9841772151898734 Hits@10 1.0 MRR 0.9157436708860758 cur_rank 0 abs_cur_rank 0 total_num 315 1429
0 0
0 33 3
checkcorrect 0 0 real score 0.0 Hits@1 0.8580441640378549 Hits@3 0.9842271293375394 Hits@10 1.0 MRR 0.9160094637223973 cur_rank 0 abs_cur_rank 0 total_num 316 1429
0 0
0 12 4
checkcorrect 0 0 real score 0.0 Hits@1 0.8584905660377359 Hits@3 0.9842767295597484 Hits@10 1.0 MRR 0.9162735849056602 cur_rank 0 abs_cur_rank 0 total_num 317 1429
9 136
9 57 68
checkcorrect 4 4 real score 1.6265767350792886 Hits@1 0.85893416927

9 91
9 29 31
checkcorrect 4 4 real score 1.6491469502449037 Hits@1 0.8603351955307262 Hits@3 0.9860335195530726 Hits@10 1.0 MRR 0.9181797020484169 cur_rank 0 abs_cur_rank 0 total_num 357 1429
0 0
9 14 57
checkcorrect 2 2 real score 0.015129832131788135 Hits@1 0.8579387186629527 Hits@3 0.9832869080779945 Hits@10 1.0 MRR 0.9163184772516246 cur_rank 3 abs_cur_rank 3 total_num 358 1429
9 12
9 23 7
checkcorrect 0 0 real score 0.869003812968731 Hits@1 0.8583333333333333 Hits@3 0.9833333333333333 Hits@10 1.0 MRR 0.9165509259259257 cur_rank 0 abs_cur_rank 0 total_num 359 1429
9 6
9 13 53
checkcorrect 2 2 real score 0.10590810067951681 Hits@1 0.8559556786703602 Hits@3 0.9833795013850416 Hits@10 1.0 MRR 0.9149353647276083 cur_rank 2 abs_cur_rank 2 total_num 360 1429
9 133
9 49 37
checkcorrect 4 4 real score 1.710992556810379 Hits@1 0.856353591160221 Hits@3 0.9834254143646409 Hits@10 1.0 MRR 0.9151703499079187 cur_rank 0 abs_cur_rank 0 total_num 361 1429
0 1
9 19 82
checkcorrect 4 4 real score 0.

9 28 20
checkcorrect 4 4 real score 0.2638947606086731 Hits@1 0.8482587064676617 Hits@3 0.9800995024875622 Hits@10 1.0 MRR 0.9099986180210057 cur_rank 0 abs_cur_rank 0 total_num 401 1429
0 1
9 6 12
checkcorrect 4 4 real score 0.8127825677394866 Hits@1 0.8486352357320099 Hits@3 0.9801488833746899 Hits@10 1.0 MRR 0.9102219465122687 cur_rank 0 abs_cur_rank 0 total_num 402 1429
0 0
9 31 6
checkcorrect 4 4 real score 0.029520758893340826 Hits@1 0.8465346534653465 Hits@3 0.9777227722772277 Hits@10 1.0 MRR 0.9084639713971393 cur_rank 4 abs_cur_rank 4 total_num 403 1429
9 48
9 62 38
checkcorrect 4 4 real score 1.4003364026546476 Hits@1 0.8469135802469135 Hits@3 0.9777777777777777 Hits@10 1.0 MRR 0.9086899862825785 cur_rank 0 abs_cur_rank 0 total_num 404 1429
9 4
0 26 3
checkcorrect 0 0 real score 0.7285836160182952 Hits@1 0.8472906403940886 Hits@3 0.9778325123152709 Hits@10 1.0 MRR 0.9089148877941977 cur_rank 0 abs_cur_rank 0 total_num 405 1429
9 118
9 45 45
checkcorrect 6 6 real score 0.46502

9 17 12
checkcorrect 4 4 real score 0.44903892278671265 Hits@1 0.8408071748878924 Hits@3 0.9798206278026906 Hits@10 1.0 MRR 0.9062468858993518 cur_rank 0 abs_cur_rank 0 total_num 445 1429
9 3
0 30 4
checkcorrect 4 4 real score 0.9998532116413117 Hits@1 0.8411633109619687 Hits@3 0.9798657718120806 Hits@10 1.0 MRR 0.906456624409644 cur_rank 0 abs_cur_rank 0 total_num 446 1429
0 1
9 47 21
checkcorrect 2 2 real score 0.3596714034676552 Hits@1 0.8392857142857143 Hits@3 0.9799107142857143 Hits@10 1.0 MRR 0.9055493551587297 cur_rank 1 abs_cur_rank 1 total_num 447 1429
0 0
0 21 5
checkcorrect 0 0 real score 0.0 Hits@1 0.8396436525612472 Hits@3 0.9799554565701559 Hits@10 1.0 MRR 0.9057597129423406 cur_rank 0 abs_cur_rank 0 total_num 448 1429
0 0
0 1 13
checkcorrect 4 4 real score 0.0 Hits@1 0.8377777777777777 Hits@3 0.98 Hits@10 1.0 MRR 0.9044876543209871 cur_rank 2 abs_cur_rank 2 total_num 449 1429
0 1
9 9 21
checkcorrect 4 4 real score 0.7459640622138977 Hits@1 0.8381374722838137 Hits@3 0.980

9 124
9 25 56
checkcorrect 0 0 real score 1.425166016817093 Hits@1 0.8367346938775511 Hits@3 0.9816326530612245 Hits@10 1.0 MRR 0.904461451247165 cur_rank 0 abs_cur_rank 0 total_num 489 1429
0 1
9 14 9
checkcorrect 4 4 real score 0.584816861152649 Hits@1 0.8370672097759674 Hits@3 0.9816700610997964 Hits@10 1.0 MRR 0.9046560307761932 cur_rank 0 abs_cur_rank 0 total_num 490 1429
9 7
9 38 8
checkcorrect 0 0 real score 0.9505188897252084 Hits@1 0.8353658536585366 Hits@3 0.9817073170731707 Hits@10 1.0 MRR 0.9038335591689245 cur_rank 1 abs_cur_rank 1 total_num 491 1429
0 1
9 26 27
checkcorrect 4 4 real score 0.5105748057365418 Hits@1 0.8336713995943205 Hits@3 0.9817444219066938 Hits@10 1.0 MRR 0.9030144241604683 cur_rank 1 abs_cur_rank 1 total_num 492 1429
0 1
9 7 6
checkcorrect 4 4 real score 0.6970747828483581 Hits@1 0.8340080971659919 Hits@3 0.9817813765182186 Hits@10 1.0 MRR 0.9032107512370665 cur_rank 0 abs_cur_rank 0 total_num 493 1429
0 0
9 19 20
checkcorrect 0 0 real score 0.10883176

9 84
9 36 58
checkcorrect 0 0 real score 1.168383901566267 Hits@1 0.8314606741573034 Hits@3 0.9775280898876404 Hits@10 1.0 MRR 0.9004317519766952 cur_rank 0 abs_cur_rank 0 total_num 533 1429
9 53
9 20 47
checkcorrect 4 4 real score 1.507726323604584 Hits@1 0.8317757009345794 Hits@3 0.9775700934579439 Hits@10 1.0 MRR 0.9006178608515051 cur_rank 0 abs_cur_rank 0 total_num 534 1429
9 150
9 38 61
checkcorrect 4 4 real score 1.739845710992813 Hits@1 0.832089552238806 Hits@3 0.9776119402985075 Hits@10 1.0 MRR 0.900803275290215 cur_rank 0 abs_cur_rank 0 total_num 535 1429
0 1
9 15 50
checkcorrect 4 4 real score 0.7588182151317596 Hits@1 0.8324022346368715 Hits@3 0.9776536312849162 Hits@10 1.0 MRR 0.9009879991723562 cur_rank 0 abs_cur_rank 0 total_num 536 1429
9 65
9 44 53
checkcorrect 4 4 real score 1.6874440252780913 Hits@1 0.8327137546468402 Hits@3 0.9776951672862454 Hits@10 1.0 MRR 0.9011720363486156 cur_rank 0 abs_cur_rank 0 total_num 537 1429
9 13
9 45 40
checkcorrect 4 4 real score 1.74

9 78
9 33 28
checkcorrect 4 4 real score 1.808294379711151 Hits@1 0.8356401384083045 Hits@3 0.9775086505190311 Hits@10 1.0 MRR 0.902821030372933 cur_rank 0 abs_cur_rank 0 total_num 577 1429
9 8
9 38 14
checkcorrect 0 0 real score 1.7148607790470123 Hits@1 0.8359240069084629 Hits@3 0.9775474956822107 Hits@10 1.0 MRR 0.9029888696987138 cur_rank 0 abs_cur_rank 0 total_num 578 1429
9 3
9 28 29
checkcorrect 4 4 real score 1.682643210887909 Hits@1 0.8362068965517241 Hits@3 0.9775862068965517 Hits@10 1.0 MRR 0.9031561302681987 cur_rank 0 abs_cur_rank 0 total_num 579 1429
0 1
9 32 11
checkcorrect 4 4 real score 0.7658750385046005 Hits@1 0.8364888123924269 Hits@3 0.9776247848537005 Hits@10 1.0 MRR 0.9033228150698026 cur_rank 0 abs_cur_rank 0 total_num 580 1429
9 11
9 25 58
checkcorrect 0 0 real score 0.8738673649728299 Hits@1 0.8367697594501718 Hits@3 0.9776632302405498 Hits@10 1.0 MRR 0.9034889270714008 cur_rank 0 abs_cur_rank 0 total_num 581 1429
9 28
9 19 27
checkcorrect 2 2 real score 1.902

9 81
9 41 35
checkcorrect 4 4 real score 1.7141697973012926 Hits@1 0.8279742765273312 Hits@3 0.977491961414791 Hits@10 1.0 MRR 0.8984414076455874 cur_rank 0 abs_cur_rank 0 total_num 621 1429
9 8
9 30 11
checkcorrect 0 0 real score 1.6010964900255202 Hits@1 0.8282504012841091 Hits@3 0.9775280898876404 Hits@10 1.0 MRR 0.898604423042625 cur_rank 0 abs_cur_rank 0 total_num 622 1429
0 1
9 8 13
checkcorrect 4 4 real score 0.7640736043453217 Hits@1 0.8285256410256411 Hits@3 0.9775641025641025 Hits@10 1.0 MRR 0.8987669159544157 cur_rank 0 abs_cur_rank 0 total_num 623 1429
0 2
9 89 18
checkcorrect 2 2 real score 0.3338980682194233 Hits@1 0.8288 Hits@3 0.9776 Hits@10 1.0 MRR 0.8989288888888887 cur_rank 0 abs_cur_rank 0 total_num 624 1429
9 19
9 39 21
checkcorrect 4 4 real score 1.657530742883682 Hits@1 0.829073482428115 Hits@3 0.9776357827476039 Hits@10 1.0 MRR 0.899090344337948 cur_rank 0 abs_cur_rank 0 total_num 625 1429
9 18
9 40 33
checkcorrect 4 4 real score 1.6568560004234314 Hits@1 0.8293

0 2
9 45 17
checkcorrect 2 2 real score 0.2499485120177269 Hits@1 0.8290854572713643 Hits@3 0.9775112443778111 Hits@10 1.0 MRR 0.8992604888032172 cur_rank 1 abs_cur_rank 1 total_num 666 1429
0 0
0 1 7
checkcorrect 0 0 real score 0.0 Hits@1 0.8293413173652695 Hits@3 0.9775449101796407 Hits@10 1.0 MRR 0.8994112964547094 cur_rank 0 abs_cur_rank 0 total_num 667 1429
9 150
9 27 48
checkcorrect 0 0 real score 1.0442537892609836 Hits@1 0.8295964125560538 Hits@3 0.9775784753363229 Hits@10 1.0 MRR 0.8995616532612046 cur_rank 0 abs_cur_rank 0 total_num 668 1429
9 125
9 55 31
checkcorrect 8 8 real score 1.1000709906220436 Hits@1 0.8298507462686567 Hits@3 0.9776119402985075 Hits@10 1.0 MRR 0.8997115612414118 cur_rank 0 abs_cur_rank 0 total_num 669 1429
9 150
9 39 56
checkcorrect 4 4 real score 1.6837796062231063 Hits@1 0.8301043219076006 Hits@3 0.977645305514158 Hits@10 1.0 MRR 0.8998610224020058 cur_rank 0 abs_cur_rank 0 total_num 670 1429
9 3
9 16 10
checkcorrect 4 4 real score 1.572136026620865

9 19 42
checkcorrect 4 4 real score 0.5967566281557083 Hits@1 0.829817158931083 Hits@3 0.9774964838255977 Hits@10 1.0 MRR 0.8997907039046278 cur_rank 0 abs_cur_rank 0 total_num 710 1429
0 1
9 12 30
checkcorrect 4 4 real score 0.6966431319713593 Hits@1 0.8300561797752809 Hits@3 0.9775280898876404 Hits@10 1.0 MRR 0.8999314472980201 cur_rank 0 abs_cur_rank 0 total_num 711 1429
9 134
9 22 41
checkcorrect 0 0 real score 1.4254086762666702 Hits@1 0.8302945301542777 Hits@3 0.9775596072931276 Hits@10 1.0 MRR 0.9000717958992852 cur_rank 0 abs_cur_rank 0 total_num 712 1429
9 149
9 32 44
checkcorrect 4 4 real score 1.7871483981609344 Hits@1 0.8305322128851541 Hits@3 0.9775910364145658 Hits@10 1.0 MRR 0.9002117513672133 cur_rank 0 abs_cur_rank 0 total_num 713 1429
9 127
9 36 49
checkcorrect 4 4 real score 1.5442125380039216 Hits@1 0.8307692307692308 Hits@3 0.9776223776223776 Hits@10 1.0 MRR 0.9003513153513152 cur_rank 0 abs_cur_rank 0 total_num 714 1429
0 1
9 15 13
checkcorrect 4 4 real score 0.40

9 28
9 27 27
checkcorrect 0 0 real score 1.1407517850399018 Hits@1 0.8306878306878307 Hits@3 0.9761904761904762 Hits@10 1.0 MRR 0.8999133912824387 cur_rank 0 abs_cur_rank 0 total_num 755 1429
9 105
9 30 28
checkcorrect 4 4 real score 1.739414131641388 Hits@1 0.8309114927344782 Hits@3 0.9762219286657859 Hits@10 1.0 MRR 0.9000456060891991 cur_rank 0 abs_cur_rank 0 total_num 756 1429
9 15
9 42 33
checkcorrect 4 4 real score 1.8954342544078826 Hits@1 0.8311345646437994 Hits@3 0.9762532981530343 Hits@10 1.0 MRR 0.9001774720442265 cur_rank 0 abs_cur_rank 0 total_num 757 1429
0 1
9 26 10
checkcorrect 4 4 real score 0.7404068663716317 Hits@1 0.8313570487483531 Hits@3 0.9762845849802372 Hits@10 1.0 MRR 0.9003089905263817 cur_rank 0 abs_cur_rank 0 total_num 758 1429
9 13
9 16 16
checkcorrect 4 4 real score 1.520648044347763 Hits@1 0.8315789473684211 Hits@3 0.9763157894736842 Hits@10 1.0 MRR 0.900440162907268 cur_rank 0 abs_cur_rank 0 total_num 759 1429
0 1
9 16 12
checkcorrect 4 4 real score 0.6

9 27 41
checkcorrect 4 4 real score 0.6776199519634247 Hits@1 0.8285356695869838 Hits@3 0.9712140175219024 Hits@10 1.0 MRR 0.8976235174921032 cur_rank 0 abs_cur_rank 0 total_num 798 1429
0 2
9 14 13
checkcorrect 4 4 real score 0.68983413875103 Hits@1 0.82875 Hits@3 0.97125 Hits@10 1.0 MRR 0.8977514880952381 cur_rank 0 abs_cur_rank 0 total_num 799 1429
0 1
0 5 6
checkcorrect 4 4 real score 0.0 Hits@1 0.8277153558052435 Hits@3 0.9712858926342073 Hits@10 1.0 MRR 0.8970468462041497 cur_rank 2 abs_cur_rank 2 total_num 800 1429
9 150
9 26 46
checkcorrect 4 4 real score 1.57409108877182 Hits@1 0.827930174563591 Hits@3 0.9713216957605985 Hits@10 1.0 MRR 0.8971752167201046 cur_rank 0 abs_cur_rank 0 total_num 801 1429
9 7
9 14 24
checkcorrect 4 4 real score 1.296482503414154 Hits@1 0.8281444582814446 Hits@3 0.9713574097135741 Hits@10 1.0 MRR 0.897303267508747 cur_rank 0 abs_cur_rank 0 total_num 802 1429
9 38
9 20 22
checkcorrect 0 0 real score 1.4472857847809792 Hits@1 0.8283582089552238 Hits@3 

9 5
9 15 12
checkcorrect 4 4 real score 1.6786578595638275 Hits@1 0.8268090154211151 Hits@3 0.970344009489917 Hits@10 1.0 MRR 0.8966799412528951 cur_rank 0 abs_cur_rank 0 total_num 842 1429
0 1
0 4 5
checkcorrect 4 4 real score 0.0 Hits@1 0.8258293838862559 Hits@3 0.9703791469194313 Hits@10 1.0 MRR 0.8960124689686303 cur_rank 2 abs_cur_rank 2 total_num 843 1429
9 41
9 64 35
checkcorrect 4 4 real score 1.3647105962038042 Hits@1 0.8260355029585799 Hits@3 0.9704142011834319 Hits@10 1.0 MRR 0.8961355311355312 cur_rank 0 abs_cur_rank 0 total_num 844 1429
9 10
9 26 13
checkcorrect 0 0 real score 1.0863813223317265 Hits@1 0.8262411347517731 Hits@3 0.9704491725768322 Hits@10 1.0 MRR 0.8962583023753238 cur_rank 0 abs_cur_rank 0 total_num 845 1429
0 2
9 32 23
checkcorrect 0 0 real score 0.10537425372749568 Hits@1 0.8252656434474617 Hits@3 0.9693034238488784 Hits@10 1.0 MRR 0.8954953055602407 cur_rank 3 abs_cur_rank 3 total_num 846 1429
0 1
9 12 23
checkcorrect 4 4 real score 0.5994098097085953 H

9 33 22
checkcorrect 0 0 real score 0.9510286958888172 Hits@1 0.8261851015801355 Hits@3 0.9706546275395034 Hits@10 1.0 MRR 0.8965212834569496 cur_rank 1 abs_cur_rank 1 total_num 885 1429
9 80
9 71 50
checkcorrect 10 10 real score 1.2066470134072005 Hits@1 0.826381059751973 Hits@3 0.9706877113866967 Hits@10 1.0 MRR 0.8966379449186667 cur_rank 0 abs_cur_rank 0 total_num 886 1429
0 1
9 16 29
checkcorrect 4 4 real score 0.5764733493328095 Hits@1 0.8265765765765766 Hits@3 0.9707207207207207 Hits@10 1.0 MRR 0.8967543436293438 cur_rank 0 abs_cur_rank 0 total_num 887 1429
9 108
9 64 28
checkcorrect 2 2 real score 0.01713173426687717 Hits@1 0.8256467941507312 Hits@3 0.9707536557930259 Hits@10 1.0 MRR 0.8961205742139378 cur_rank 2 abs_cur_rank 2 total_num 888 1429
0 1
9 10 34
checkcorrect 4 4 real score 0.7124002903699875 Hits@1 0.8258426966292135 Hits@3 0.9707865168539326 Hits@10 1.0 MRR 0.8962372926698772 cur_rank 0 abs_cur_rank 0 total_num 889 1429
0 0
0 10 5
checkcorrect 4 4 real score 0.0 H

9 48
9 36 45
checkcorrect 6 6 real score 0.6561888117343186 Hits@1 0.821505376344086 Hits@3 0.9698924731182795 Hits@10 1.0 MRR 0.8926117938214718 cur_rank 1 abs_cur_rank 1 total_num 929 1429
9 68
9 15 53
checkcorrect 0 0 real score 1.0175977969542145 Hits@1 0.8216970998925887 Hits@3 0.9699248120300752 Hits@10 1.0 MRR 0.8927271409817065 cur_rank 0 abs_cur_rank 0 total_num 930 1429
0 1
9 13 38
checkcorrect 4 4 real score 0.6746913671493531 Hits@1 0.8218884120171673 Hits@3 0.9699570815450643 Hits@10 1.0 MRR 0.8928422406158463 cur_rank 0 abs_cur_rank 0 total_num 931 1429
9 146
9 66 71
checkcorrect 6 6 real score 1.1481050536036492 Hits@1 0.8220793140407289 Hits@3 0.969989281886388 Hits@10 1.0 MRR 0.892957093519795 cur_rank 0 abs_cur_rank 0 total_num 932 1429
9 49
9 27 24
checkcorrect 4 4 real score 1.8153861463069916 Hits@1 0.8222698072805139 Hits@3 0.9700214132762313 Hits@10 1.0 MRR 0.8930717004860479 cur_rank 0 abs_cur_rank 0 total_num 933 1429
9 93
9 28 47
checkcorrect 0 0 real score 0.

9 34 66
checkcorrect 4 4 real score 1.8368425369262695 Hits@1 0.8242548818088387 Hits@3 0.9701952723535457 Hits@10 1.0 MRR 0.8941030848790361 cur_rank 0 abs_cur_rank 0 total_num 972 1429
0 1
9 28 9
checkcorrect 4 4 real score 0.6353953957557679 Hits@1 0.824435318275154 Hits@3 0.9702258726899384 Hits@10 1.0 MRR 0.8942118086111932 cur_rank 0 abs_cur_rank 0 total_num 973 1429
9 22
9 24 26
checkcorrect 0 0 real score 1.4722416326403618 Hits@1 0.8246153846153846 Hits@3 0.9702564102564103 Hits@10 1.0 MRR 0.8943203093203099 cur_rank 0 abs_cur_rank 0 total_num 974 1429
9 130
9 58 52
checkcorrect 4 4 real score 1.6232303112745285 Hits@1 0.8247950819672131 Hits@3 0.9702868852459017 Hits@10 1.0 MRR 0.8944285876919079 cur_rank 0 abs_cur_rank 0 total_num 975 1429
9 5
9 55 37
checkcorrect 4 4 real score 1.699029278755188 Hits@1 0.8249744114636642 Hits@3 0.970317297850563 Hits@10 1.0 MRR 0.8945366444087023 cur_rank 0 abs_cur_rank 0 total_num 976 1429
9 150
9 31 41
checkcorrect 4 4 real score 1.754437

9 3
9 41 23
checkcorrect 4 4 real score 1.7360943794250487 Hits@1 0.8190757128810227 Hits@3 0.967551622418879 Hits@10 1.0 MRR 0.8906599709697061 cur_rank 0 abs_cur_rank 0 total_num 1016 1429
9 7
9 47 9
checkcorrect 2 2 real score 1.7945489108562471 Hits@1 0.8192534381139489 Hits@3 0.9675834970530451 Hits@10 1.0 MRR 0.8907673776779873 cur_rank 0 abs_cur_rank 0 total_num 1017 1429
9 58
9 33 10
checkcorrect 6 6 real score 0.45977366007864473 Hits@1 0.8184494602551521 Hits@3 0.9676153091265947 Hits@10 1.0 MRR 0.8902203373989446 cur_rank 2 abs_cur_rank 2 total_num 1018 1429
9 70
9 38 53
checkcorrect 4 4 real score 1.4885104969143867 Hits@1 0.8186274509803921 Hits@3 0.9676470588235294 Hits@10 1.0 MRR 0.8903279645191416 cur_rank 0 abs_cur_rank 0 total_num 1019 1429
9 118
9 30 39
checkcorrect 0 0 real score 1.1256459638476373 Hits@1 0.8188050930460333 Hits@3 0.9676787463271302 Hits@10 1.0 MRR 0.8904353808124628 cur_rank 0 abs_cur_rank 0 total_num 1020 1429
9 97
9 30 24
checkcorrect 4 4 real sc

9 5
9 20 14
checkcorrect 4 4 real score 1.7767064690589904 Hits@1 0.8188679245283019 Hits@3 0.9679245283018868 Hits@10 1.0 MRR 0.8906405360886499 cur_rank 0 abs_cur_rank 0 total_num 1059 1429
9 17
9 67 39
checkcorrect 4 4 real score 1.6800382524728774 Hits@1 0.819038642789821 Hits@3 0.9679547596606974 Hits@10 1.0 MRR 0.8907436081564268 cur_rank 0 abs_cur_rank 0 total_num 1060 1429
0 1
9 11 24
checkcorrect 4 4 real score 0.5468323320150376 Hits@1 0.8192090395480226 Hits@3 0.967984934086629 Hits@10 1.0 MRR 0.8908464861148483 cur_rank 0 abs_cur_rank 0 total_num 1061 1429
0 0
0 36 5
checkcorrect 4 4 real score 0.0 Hits@1 0.8184383819379115 Hits@3 0.9680150517403575 Hits@10 1.0 MRR 0.8903220146635017 cur_rank 2 abs_cur_rank 2 total_num 1062 1429
9 3
9 12 25
checkcorrect 4 4 real score 1.4437153726816176 Hits@1 0.818609022556391 Hits@3 0.9680451127819549 Hits@10 1.0 MRR 0.8904250954767878 cur_rank 0 abs_cur_rank 0 total_num 1063 1429
0 1
9 27 24
checkcorrect 4 4 real score 0.767524915933609 

9 11 22
checkcorrect 0 0 real score 0.646690982580185 Hits@1 0.8179347826086957 Hits@3 0.9692028985507246 Hits@10 1.0 MRR 0.8904700799401893 cur_rank 0 abs_cur_rank 0 total_num 1103 1429
0 0
9 24 9
checkcorrect 0 0 real score 0.009046766348183155 Hits@1 0.8171945701357466 Hits@3 0.9692307692307692 Hits@10 1.0 MRR 0.8901167133520081 cur_rank 1 abs_cur_rank 1 total_num 1104 1429
9 25
9 17 40
checkcorrect 8 8 real score 1.957986867427826 Hits@1 0.8173598553345389 Hits@3 0.969258589511754 Hits@10 1.0 MRR 0.8902160653290859 cur_rank 0 abs_cur_rank 0 total_num 1105 1429
9 3
9 17 11
checkcorrect 4 4 real score 1.2223079867661 Hits@1 0.8175248419150858 Hits@3 0.969286359530262 Hits@10 1.0 MRR 0.8903152378084634 cur_rank 0 abs_cur_rank 0 total_num 1106 1429
0 1
9 12 34
checkcorrect 4 4 real score 0.7312681376934052 Hits@1 0.8176895306859205 Hits@3 0.9693140794223827 Hits@10 1.0 MRR 0.8904142312761454 cur_rank 0 abs_cur_rank 0 total_num 1107 1429
9 4
9 8 13
checkcorrect 2 2 real score 0.70157738

9 11
9 28 38
checkcorrect 4 4 real score 1.6629471600055696 Hits@1 0.8160418482999128 Hits@3 0.969485614646905 Hits@10 1.0 MRR 0.8895995073414436 cur_rank 0 abs_cur_rank 0 total_num 1146 1429
9 102
9 27 35
checkcorrect 4 4 real score 1.8047187387943269 Hits@1 0.8162020905923345 Hits@3 0.9695121951219512 Hits@10 1.0 MRR 0.8896956750179754 cur_rank 0 abs_cur_rank 0 total_num 1147 1429
9 8
9 20 8
checkcorrect 0 0 real score 0.9669776022899896 Hits@1 0.8163620539599652 Hits@3 0.9695387293298521 Hits@10 1.0 MRR 0.8897916753008144 cur_rank 0 abs_cur_rank 0 total_num 1148 1429
9 60
9 14 26
checkcorrect 4 4 real score 1.6149019420146944 Hits@1 0.8165217391304348 Hits@3 0.9695652173913043 Hits@10 1.0 MRR 0.8898875086266398 cur_rank 0 abs_cur_rank 0 total_num 1149 1429
9 149
9 90 41
checkcorrect 4 4 real score 0.9206566199660301 Hits@1 0.8166811468288445 Hits@3 0.9695916594265855 Hits@10 1.0 MRR 0.8899831754306133 cur_rank 0 abs_cur_rank 0 total_num 1150 1429
0 1
9 11 9
checkcorrect 4 4 real sco

9 32
9 36 33
checkcorrect 4 4 real score 1.2063804229721429 Hits@1 0.819327731092437 Hits@3 0.9705882352941176 Hits@10 1.0 MRR 0.8917680405495539 cur_rank 0 abs_cur_rank 0 total_num 1189 1429
9 76
9 36 26
checkcorrect 0 0 real score 1.3405774191021917 Hits@1 0.8194794290512175 Hits@3 0.9706129303106633 Hits@10 1.0 MRR 0.8918589154105534 cur_rank 0 abs_cur_rank 0 total_num 1190 1429
9 73
9 28 34
checkcorrect 4 4 real score 1.612517276406288 Hits@1 0.8196308724832215 Hits@3 0.9706375838926175 Hits@10 1.0 MRR 0.891949637796954 cur_rank 0 abs_cur_rank 0 total_num 1191 1429
0 1
9 11 43
checkcorrect 6 6 real score 0.5591028153896331 Hits@1 0.8197820620284996 Hits@3 0.9706621961441744 Hits@10 1.0 MRR 0.8920402080921787 cur_rank 0 abs_cur_rank 0 total_num 1192 1429
0 0
9 40 12
checkcorrect 0 0 real score 0.7278498470783233 Hits@1 0.8199329983249581 Hits@3 0.9706867671691792 Hits@10 1.0 MRR 0.8921306266783662 cur_rank 0 abs_cur_rank 0 total_num 1193 1429
0 1
9 19 29
checkcorrect 0 0 real score 

9 147
9 63 54
checkcorrect 4 4 real score 1.3737499237060546 Hits@1 0.819140308191403 Hits@3 0.9708029197080292 Hits@10 1.0 MRR 0.8917126894012545 cur_rank 0 abs_cur_rank 0 total_num 1232 1429
9 32
9 45 14
checkcorrect 2 2 real score 0.0578967057634145 Hits@1 0.8184764991896273 Hits@3 0.9700162074554295 Hits@10 1.0 MRR 0.8911926629106538 cur_rank 3 abs_cur_rank 3 total_num 1233 1429
9 20
9 15 39
checkcorrect 4 4 real score 1.5246949642896652 Hits@1 0.8186234817813766 Hits@3 0.9700404858299595 Hits@10 1.0 MRR 0.8912807660176087 cur_rank 0 abs_cur_rank 0 total_num 1234 1429
9 3
0 4 30
checkcorrect 4 4 real score 0.9998216271400452 Hits@1 0.8187702265372169 Hits@3 0.9700647249190939 Hits@10 1.0 MRR 0.8913687265629019 cur_rank 0 abs_cur_rank 0 total_num 1235 1429
0 0
9 21 14
checkcorrect 4 4 real score 0.08064824407920242 Hits@1 0.8181083265966047 Hits@3 0.9700889248181084 Hits@10 1.0 MRR 0.8910523411736029 cur_rank 1 abs_cur_rank 1 total_num 1236 1429
0 1
9 15 8
checkcorrect 4 4 real scor

9 5
9 12 25
checkcorrect 4 4 real score 1.7149268865585328 Hits@1 0.8167580266249022 Hits@3 0.9694596711041503 Hits@10 1.0 MRR 0.8900274701371024 cur_rank 0 abs_cur_rank 0 total_num 1276 1429
0 1
9 33 11
checkcorrect 6 6 real score 0.7385632663965225 Hits@1 0.8169014084507042 Hits@3 0.9694835680751174 Hits@10 1.0 MRR 0.8901135206299529 cur_rank 0 abs_cur_rank 0 total_num 1277 1429
0 0
9 17 22
checkcorrect 4 4 real score 0.7900406360626221 Hits@1 0.8170445660672401 Hits@3 0.9695074276778733 Hits@10 1.0 MRR 0.8901994365637841 cur_rank 0 abs_cur_rank 0 total_num 1278 1429
0 1
9 16 9
checkcorrect 4 4 real score 0.7747986495494843 Hits@1 0.8171875 Hits@3 0.96953125 Hits@10 1.0 MRR 0.8902852182539686 cur_rank 0 abs_cur_rank 0 total_num 1279 1429
0 0
0 5 3
checkcorrect 0 0 real score 0.0 Hits@1 0.8173302107728337 Hits@3 0.9695550351288056 Hits@10 1.0 MRR 0.8903708660148945 cur_rank 0 abs_cur_rank 0 total_num 1280 1429
0 0
9 7 29
checkcorrect 4 4 real score 0.8056516170501709 Hits@1 0.81747269

9 9
9 14 8
checkcorrect 4 4 real score 1.7926499605178834 Hits@1 0.8205904617713853 Hits@3 0.9704769114307343 Hits@10 1.0 MRR 0.8924287756990258 cur_rank 0 abs_cur_rank 0 total_num 1320 1429
0 2
9 14 15
checkcorrect 4 4 real score 0.6891565561294556 Hits@1 0.8207261724659607 Hits@3 0.970499243570348 Hits@10 1.0 MRR 0.8925101457627935 cur_rank 0 abs_cur_rank 0 total_num 1321 1429
9 21
9 33 6
checkcorrect 0 0 real score 0.669335699826479 Hits@1 0.8201058201058201 Hits@3 0.9705215419501134 Hits@10 1.0 MRR 0.892213463868793 cur_rank 1 abs_cur_rank 1 total_num 1322 1429
9 77
9 38 45
checkcorrect 4 4 real score 1.6844490081071855 Hits@1 0.8202416918429003 Hits@3 0.9705438066465257 Hits@10 1.0 MRR 0.8922948736392847 cur_rank 0 abs_cur_rank 0 total_num 1323 1429
9 8
9 13 39
checkcorrect 0 0 real score 1.4739970192313194 Hits@1 0.820377358490566 Hits@3 0.970566037735849 Hits@10 1.0 MRR 0.8923761605271042 cur_rank 0 abs_cur_rank 0 total_num 1324 1429
9 69
9 69 32
checkcorrect 4 4 real score 1.85

9 37
9 22 18
checkcorrect 4 4 real score 1.7252974569797517 Hits@1 0.8233137829912024 Hits@3 0.9714076246334311 Hits@10 1.0 MRR 0.8942314853605178 cur_rank 0 abs_cur_rank 0 total_num 1363 1429
9 136
9 26 33
checkcorrect 4 4 real score 1.4715208172798158 Hits@1 0.8234432234432234 Hits@3 0.9714285714285714 Hits@10 1.0 MRR 0.8943089714518287 cur_rank 0 abs_cur_rank 0 total_num 1364 1429
9 13
9 39 18
checkcorrect 4 4 real score 1.6591706305742262 Hits@1 0.8235724743777453 Hits@3 0.9714494875549048 Hits@10 1.0 MRR 0.8943863440935185 cur_rank 0 abs_cur_rank 0 total_num 1365 1429
9 150
9 43 40
checkcorrect 4 4 real score 1.7276365935802458 Hits@1 0.8237015362106803 Hits@3 0.9714703730797366 Hits@10 1.0 MRR 0.894463603534562 cur_rank 0 abs_cur_rank 0 total_num 1366 1429
0 1
9 6 9
checkcorrect 4 4 real score 0.6787554234266281 Hits@1 0.8238304093567251 Hits@3 0.9714912280701754 Hits@10 1.0 MRR 0.8945407500232063 cur_rank 0 abs_cur_rank 0 total_num 1367 1429
9 150
9 34 50
checkcorrect 0 0 real s

9 23
9 27 38
checkcorrect 4 4 real score 1.285265950858593 Hits@1 0.8244491826581379 Hits@3 0.9722814498933902 Hits@10 1.0 MRR 0.8952132760235108 cur_rank 0 abs_cur_rank 0 total_num 1406 1429
9 5
9 18 24
checkcorrect 4 4 real score 1.6979073345661164 Hits@1 0.8245738636363636 Hits@3 0.9723011363636364 Hits@10 1.0 MRR 0.8952876984126985 cur_rank 0 abs_cur_rank 0 total_num 1407 1429
0 0
9 28 7
checkcorrect 0 0 real score 0.7660936117172241 Hits@1 0.8246983676366217 Hits@3 0.9723207948899929 Hits@10 1.0 MRR 0.8953620151632928 cur_rank 0 abs_cur_rank 0 total_num 1408 1429
0 2
9 35 22
checkcorrect 4 4 real score 0.6234164535999298 Hits@1 0.824822695035461 Hits@3 0.9723404255319149 Hits@10 1.0 MRR 0.8954362265000564 cur_rank 0 abs_cur_rank 0 total_num 1409 1429
9 3
9 21 17
checkcorrect 4 4 real score 1.7661591559648513 Hits@1 0.8249468462083629 Hits@3 0.9723600283486888 Hits@10 1.0 MRR 0.8955103326471152 cur_rank 0 abs_cur_rank 0 total_num 1410 1429
9 3
9 15 23
checkcorrect 6 6 real score 1.

In [32]:
###########################################
##obtain the AUC-PR for the test triples###
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

# Balanced AUC-PR evaluation on the inductive test set.
# Every triple of data_ind_test is used as a positive; for each positive one
# negative is created by replacing its head or its tail (chosen with equal
# probability) with a random entity drawn from new_ent_set. Corrupted triples
# that already occur in any known triple collection are rejected and redrawn.


def _is_known_triple(triple, triple_collections):
    """Return True when `triple` is contained in any of `triple_collections`."""
    return any(triple in known for known in triple_collections)


#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#a corrupted triple must not appear in any of these collections
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#materialise the candidate entities once; converting the set to a list on
#every draw costs O(len(new_ent_set)) per draw and yields the same list
candidate_entities = list(new_ent_set)

#we build the negative samples by randomly replacing the head or tail entity in the triple.
neg_triples = list()

for pos_triple in pos_triples:

    s_pos, r_pos, t_pos = pos_triple[0], pos_triple[1], pos_triple[2]

    #decide to replace the head (True) or the tail (False) entity
    replace_head = random.uniform(0, 1) < 0.5

    #rejection sampling: redraw until the corrupted triple is not a known one
    #NOTE: this never terminates if every candidate yields a known triple
    while True:
        entity = random.choice(candidate_entities)

        if replace_head:
            neg_triple = (entity, r_pos, t_pos)
        else:
            neg_triple = (s_pos, r_pos, entity)

        if not _is_known_triple(neg_triple, known_triple_sets):
            break

    neg_triples.append(neg_triple)

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')

#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array: 1 for positive triples, 0 for negative triples
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score array (filled in one triple at a time below)
y_score = np.zeros((len(y_test),))

#implement the scoring
for i, triple in enumerate(all_triples):

    s, r, t = triple[0], triple[1], triple[2]

    #only the subgraph-based score is used; the path-based score and the
    #average of both are kept here as disabled alternatives
    #path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #ave_score = (path_score + subg_score)/float(2)

    #y_score[i] = ave_score
    y_score[i] = subg_score

    #progress report every 20 triples: AUC-PR over the first i scored triples
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))

        # Data to plot precision - recall curve
        precision, recall, thresholds = precision_recall_curve(y_test[:i], y_score[:i])
        # Use AUC function to calculate the area under the curve of precision recall curve
        auc_precision_recall = auc(recall, precision)
        print('AUC-PR is:', auc_precision_recall)


#final AUC-PR over all scored triples
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_test, y_score)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('AUC-PR is:', auc_precision_recall)

evaluating scores 20 2858
AUC-PR is: 0.6311310556611308
evaluating scores 40 2858
AUC-PR is: 0.6833320470769427
evaluating scores 60 2858
AUC-PR is: 0.7186420618957116
evaluating scores 80 2858
AUC-PR is: 0.7052031339509326
evaluating scores 100 2858
AUC-PR is: 0.732042150545161
evaluating scores 120 2858
AUC-PR is: 0.7538639985917822
evaluating scores 140 2858
AUC-PR is: 0.7660928763574997
evaluating scores 160 2858
AUC-PR is: 0.7780692475347163
evaluating scores 180 2858
AUC-PR is: 0.7507508969394479
evaluating scores 200 2858
AUC-PR is: 0.7456688285116878
evaluating scores 220 2858
AUC-PR is: 0.731848281973119
evaluating scores 240 2858
AUC-PR is: 0.7247874156686389
evaluating scores 260 2858
AUC-PR is: 0.7255516994704185
evaluating scores 280 2858
AUC-PR is: 0.7249921676995906
evaluating scores 300 2858
AUC-PR is: 0.7350333510682039
evaluating scores 320 2858
AUC-PR is: 0.7309674530422767
evaluating scores 340 2858
AUC-PR is: 0.7308877272362144
evaluating scores 360 2858
AUC-PR is:

In [33]:
##########################################################
##obtain the AUC-PR for the test triples, using sklearn###
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

# Balanced ROC-AUC / average-precision evaluation on the inductive test set.
# Every triple of data_ind_test is used as a positive; for each positive one
# negative is created by replacing its head or its tail (chosen with equal
# probability) with a random entity drawn from new_ent_set. Corrupted triples
# that already occur in any known triple collection are rejected and redrawn.


def _is_known_triple(triple, triple_collections):
    """Return True when `triple` is contained in any of `triple_collections`."""
    return any(triple in known for known in triple_collections)


#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#a corrupted triple must not appear in any of these collections
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#materialise the candidate entities once; converting the set to a list on
#every draw costs O(len(new_ent_set)) per draw and yields the same list
candidate_entities = list(new_ent_set)

#we build the negative samples by randomly replacing the head or tail entity in the triple.
neg_triples = list()

for pos_triple in pos_triples:

    s_pos, r_pos, t_pos = pos_triple[0], pos_triple[1], pos_triple[2]

    #decide to replace the head (True) or the tail (False) entity
    replace_head = random.uniform(0, 1) < 0.5

    #rejection sampling: redraw until the corrupted triple is not a known one
    #NOTE: this never terminates if every candidate yields a known triple
    while True:
        entity = random.choice(candidate_entities)

        if replace_head:
            neg_triple = (entity, r_pos, t_pos)
        else:
            neg_triple = (s_pos, r_pos, entity)

        if not _is_known_triple(neg_triple, known_triple_sets):
            break

    neg_triples.append(neg_triple)

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')

#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array: 1 for positive triples, 0 for negative triples
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score array (filled in one triple at a time below)
y_score = np.zeros((len(y_test),))

#implement the scoring
for i, triple in enumerate(all_triples):

    s, r, t = triple[0], triple[1], triple[2]

    #only the subgraph-based score is used; the path-based score and the
    #average of both are kept here as disabled alternatives
    #path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #ave_score = (path_score + subg_score)/float(2)

    #y_score[i] = ave_score
    y_score[i] = subg_score

    #progress report every 20 triples: metrics over the first i scored triples
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))
        #stored as roc_auc so the `auc` function imported from
        #sklearn.metrics is not overwritten by a float
        roc_auc = metrics.roc_auc_score(y_test[:i], y_score[:i])
        auc_pr = metrics.average_precision_score(y_test[:i], y_score[:i])
        print('auc, auc-pr', roc_auc, auc_pr)

#final metrics over all scored triples
print('evaluating scores', i, len(all_triples))
roc_auc = metrics.roc_auc_score(y_test, y_score)
auc_pr = metrics.average_precision_score(y_test, y_score)
print('(final) auc, auc-pr', roc_auc, auc_pr)

evaluating scores 20 2858
auc, auc-pr 0.6428571428571428 0.8378076031050379
evaluating scores 40 2858
auc, auc-pr 0.6483516483516484 0.7452322776120757
evaluating scores 60 2858
auc, auc-pr 0.6434285714285715 0.7084267409875705
evaluating scores 80 2858
auc, auc-pr 0.7476190476190476 0.7681339496392743
evaluating scores 100 2858
auc, auc-pr 0.746688077077479 0.7214242544758477
evaluating scores 120 2858
auc, auc-pr 0.7300785634118967 0.7311812081572158
evaluating scores 140 2858
auc, auc-pr 0.7608205128205128 0.7485674516633705
evaluating scores 160 2858
auc, auc-pr 0.7099308610936518 0.7093189602927049
evaluating scores 180 2858
auc, auc-pr 0.7168989547038327 0.7225214476030981
evaluating scores 200 2858
auc, auc-pr 0.6951121794871794 0.6905389597224258
evaluating scores 220 2858
auc, auc-pr 0.6758523667659716 0.6792864878173438
evaluating scores 240 2858
auc, auc-pr 0.6910916545062886 0.686455130007279
evaluating scores 260 2858
auc, auc-pr 0.6856017997750281 0.6886143394347848
evalu

evaluating scores 2160 2858
auc, auc-pr 0.7239963957352391 0.7291719441750266
evaluating scores 2180 2858
auc, auc-pr 0.7256329691917538 0.7306854698638788
evaluating scores 2200 2858
auc, auc-pr 0.7277678098432138 0.7327424542823548
evaluating scores 2220 2858
auc, auc-pr 0.7282224761706879 0.7332511634171501
evaluating scores 2240 2858
auc, auc-pr 0.728494623655914 0.7345634752303121
evaluating scores 2260 2858
auc, auc-pr 0.7285242825503335 0.7351961252502717
evaluating scores 2280 2858
auc, auc-pr 0.7279376176805306 0.733750031153221
evaluating scores 2300 2858
auc, auc-pr 0.7285268268633852 0.7346872951710492
evaluating scores 2320 2858
auc, auc-pr 0.728083923962488 0.7345190792033187
evaluating scores 2340 2858
auc, auc-pr 0.7266436845592527 0.7342376642866558
evaluating scores 2360 2858
auc, auc-pr 0.727305308625394 0.7367689257618482
evaluating scores 2380 2858
auc, auc-pr 0.7284609995791068 0.7371791826625107
evaluating scores 2400 2858
auc, auc-pr 0.7271782489006671 0.7345423

In [34]:
######################################################
#obtain the Hits@N for entity prediction##############
#
#For every triple of the inductive test set the true triple is ranked against
#a fixed number of corrupted triples, each obtained by replacing either the
#head or the tail entity with a randomly drawn entity. Hits@1, Hits@3, Hits@10
#and the raw MRR are accumulated and their running values are printed after
#every test triple.
#
#NOTE: only the subgraph-based score (model_2) is used for ranking here. The
#path-based score (path_based_triple_scoring with `model`) and the average of
#both scores were disabled in this cell.

#we select all the triples in the inductive test set
selected = list(data_ind_test)

#candidate entities for corruption, materialised once instead of rebuilding
#the list from the set for every single draw
candidate_entities = list(new_ent_set)

#a corrupted triple must not appear in any of these collections
#(checked in this order, stopping at the first hit)
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#number of corrupted triples ranked against each true triple
num_negatives = 50


def _is_known_triple(triple):
    """Return True if `triple` occurs in any of the known triple collections."""
    return any(triple in triple_set for triple_set in known_triple_sets)


Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i, test_triple in enumerate(selected):

    s_pos, r_pos, t_pos = test_triple[0], test_triple[1], test_triple[2]
    true_triple = (s_pos, r_pos, t_pos)

    #score the true triple
    subg_score = subgraph_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #each entry is [triple, score]
    triple_list = [[true_triple, subg_score]]

    #generate the random corrupted samples
    for sub_i in range(num_negatives):

        #decide to replace the head or tail entity
        replace_head = random.uniform(0, 1) < 0.5

        #redraw until the corrupted triple is not an existing triple
        while True:

            entity = random.choice(candidate_entities)

            if replace_head:
                neg_triple = (entity, r_pos, t_pos)
            else:
                neg_triple = (s_pos, r_pos, entity)

            if not _is_known_triple(neg_triple):
                break

        subg_score = subgraph_triple_scoring(neg_triple[0], r_pos, neg_triple[2], lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

        triple_list.append([neg_triple, subg_score])

    #random shuffle, so that triples with equal scores are ordered randomly
    #by the stable sort below
    random.shuffle(triple_list)

    #sort by score, best first
    sorted_list = sorted(triple_list, key = lambda x: x[-1], reverse=True)

    #p is the 0-based rank of the true triple
    p = 0

    while p < len(sorted_list) and sorted_list[p][0] != true_triple:

        p += 1

    if p == 0:

        Hits_at_1 += 1

    if p < 3:

        Hits_at_3 += 1

    if p < 10:

        Hits_at_10 += 1

    MRR_raw += 1./float(p + 1.)

    #running metrics over the i+1 triples evaluated so far
    print('checkcorrect', (s_pos, r_pos, t_pos), sorted_list[p][0],
          'real score', sorted_list[p][-1],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'rank', p,
          'total_num', i, len(selected))

checkcorrect (6804, 4, 5424) (6804, 4, 5424) real score 0.8000518888235092 Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.3333333333333333 rank 2 total_num 0 1429
checkcorrect (6823, 4, 6068) (6823, 4, 6068) real score 0.651672613620758 Hits@1 0.0 Hits@3 0.5 Hits@10 0.5 MRR 0.2121212121212121 rank 10 total_num 1 1429
checkcorrect (6374, 0, 9656) (6374, 0, 9656) real score 0.0 Hits@1 0.0 Hits@3 0.3333333333333333 Hits@10 0.3333333333333333 MRR 0.15121806298276885 rank 33 total_num 2 1429
checkcorrect (6801, 4, 7818) (6801, 4, 7818) real score 0.4773150682449341 Hits@1 0.0 Hits@3 0.25 Hits@10 0.25 MRR 0.13614081996434937 rank 10 total_num 3 1429
checkcorrect (10426, 0, 4263) (10426, 0, 4263) real score 0.10916919112205506 Hits@1 0.0 Hits@3 0.2 Hits@10 0.2 MRR 0.11843646549528901 rank 20 total_num 4 1429
checkcorrect (10269, 4, 7231) (10269, 4, 7231) real score 0.2962911680340767 Hits@1 0.0 Hits@3 0.16666666666666666 Hits@10 0.16666666666666666 MRR 0.10564149902385195 rank 23 total_num 5 1429
ch

checkcorrect (4952, 4, 9721) (4952, 4, 9721) real score 0.6498563110828399 Hits@1 0.13333333333333333 Hits@3 0.3111111111111111 Hits@10 0.5777777777777777 MRR 0.26429669441580345 rank 3 total_num 44 1429
checkcorrect (9284, 4, 8543) (9284, 4, 8543) real score 0.7207768529653549 Hits@1 0.15217391304347827 Hits@3 0.32608695652173914 Hits@10 0.5869565217391305 MRR 0.28029024453719903 rank 0 total_num 45 1429
checkcorrect (8500, 4, 8321) (8500, 4, 8321) real score 0.5618980035185814 Hits@1 0.14893617021276595 Hits@3 0.3191489361702128 Hits@10 0.574468085106383 MRR 0.27521314713569833 rank 23 total_num 46 1429
checkcorrect (7914, 4, 9675) (7914, 4, 9675) real score 0.46443679109215735 Hits@1 0.14583333333333334 Hits@3 0.3125 Hits@10 0.5625 MRR 0.27096763514179983 rank 13 total_num 47 1429
checkcorrect (9995, 4, 5316) (9995, 4, 5316) real score 0.7696879625320434 Hits@1 0.14285714285714285 Hits@3 0.32653061224489793 Hits@10 0.5714285714285714 MRR 0.27564176503686516 rank 1 total_num 48 1429


checkcorrect (4677, 6, 6717) (4677, 6, 6717) real score 0.8018379747867584 Hits@1 0.13793103448275862 Hits@3 0.3563218390804598 Hits@10 0.5747126436781609 MRR 0.27785249999259076 rank 0 total_num 86 1429
checkcorrect (7345, 4, 6253) (7345, 4, 6253) real score 0.6297228872776032 Hits@1 0.13636363636363635 Hits@3 0.3522727272727273 Hits@10 0.5795454545454546 MRR 0.27583144885631133 rank 9 total_num 87 1429
checkcorrect (5334, 10, 4649) (5334, 10, 4649) real score 0.4221215099096298 Hits@1 0.1348314606741573 Hits@3 0.34831460674157305 Hits@10 0.5730337078651685 MRR 0.27311966581995606 rank 28 total_num 88 1429
checkcorrect (5712, 4, 9106) (5712, 4, 9106) real score 0.0 Hits@1 0.13333333333333333 Hits@3 0.34444444444444444 Hits@10 0.5666666666666667 MRR 0.2703214094857863 rank 46 total_num 89 1429
checkcorrect (4741, 2, 10130) (4741, 2, 10130) real score 0.38227771520614623 Hits@1 0.13186813186813187 Hits@3 0.34065934065934067 Hits@10 0.5714285714285714 MRR 0.2700980972936348 rank 3 total_

checkcorrect (5122, 4, 7282) (5122, 4, 7282) real score 0.82185577750206 Hits@1 0.1328125 Hits@3 0.3125 Hits@10 0.546875 MRR 0.26270331945316766 rank 0 total_num 127 1429
checkcorrect (9967, 4, 9339) (9967, 4, 9339) real score 0.25972591042518617 Hits@1 0.13178294573643412 Hits@3 0.31007751937984496 Hits@10 0.5426356589147286 MRR 0.26093416781880735 rank 28 total_num 128 1429
checkcorrect (8003, 4, 6453) (8003, 4, 6453) real score 0.6602312564849854 Hits@1 0.13076923076923078 Hits@3 0.3076923076923077 Hits@10 0.5384615384615384 MRR 0.2595680075535345 rank 11 total_num 129 1429
checkcorrect (4204, 4, 9586) (4204, 4, 9586) real score 0.7790071785449981 Hits@1 0.1297709923664122 Hits@3 0.31297709923664124 Hits@10 0.5419847328244275 MRR 0.2614033662744999 rank 1 total_num 130 1429
checkcorrect (4190, 2, 7201) (4190, 2, 7201) real score 0.2864329123403877 Hits@1 0.12878787878787878 Hits@3 0.3181818181818182 Hits@10 0.5454545454545454 MRR 0.2632109165299961 rank 1 total_num 131 1429
checkcor

checkcorrect (8100, 4, 5750) (8100, 4, 5750) real score 0.6097413152456284 Hits@1 0.1242603550295858 Hits@3 0.2958579881656805 Hits@10 0.5443786982248521 MRR 0.2569158968706063 rank 3 total_num 168 1429
checkcorrect (4398, 4, 4515) (4398, 4, 4515) real score 0.7094680070877075 Hits@1 0.12352941176470589 Hits@3 0.3 Hits@10 0.5470588235294118 MRR 0.2583458033596027 rank 1 total_num 169 1429
checkcorrect (8279, 4, 7478) (8279, 4, 7478) real score 0.41060999447945506 Hits@1 0.12280701754385964 Hits@3 0.2982456140350877 Hits@10 0.543859649122807 MRR 0.25708926802340365 rank 22 total_num 170 1429
checkcorrect (8484, 4, 4157) (8484, 4, 4157) real score 0.8446371257305145 Hits@1 0.12790697674418605 Hits@3 0.3023255813953488 Hits@10 0.5465116279069767 MRR 0.2614085164651281 rank 0 total_num 171 1429
checkcorrect (4961, 4, 5258) (4961, 4, 5258) real score 0.6818532556295395 Hits@1 0.12716763005780346 Hits@3 0.30057803468208094 Hits@10 0.5491329479768786 MRR 0.2613425712832487 rank 3 total_num 17

checkcorrect (5322, 4, 5321) (5322, 4, 5321) real score 0.5344616383314132 Hits@1 0.1523809523809524 Hits@3 0.319047619047619 Hits@10 0.5666666666666667 MRR 0.28154023651096816 rank 24 total_num 209 1429
checkcorrect (7373, 4, 9345) (7373, 4, 9345) real score 0.6706053048372269 Hits@1 0.15165876777251186 Hits@3 0.3175355450236967 Hits@10 0.5639810426540285 MRR 0.28057048693946157 rank 12 total_num 210 1429
checkcorrect (9060, 4, 7995) (9060, 4, 7995) real score 0.8039636701345444 Hits@1 0.15566037735849056 Hits@3 0.32075471698113206 Hits@10 0.5660377358490566 MRR 0.2839640223784264 rank 0 total_num 211 1429
checkcorrect (6259, 0, 9790) (6259, 0, 9790) real score 0.0 Hits@1 0.15492957746478872 Hits@3 0.3192488262910798 Hits@10 0.5633802816901409 MRR 0.2827544062615027 rank 37 total_num 212 1429
checkcorrect (5823, 0, 4637) (5823, 0, 4637) real score 0.660650098323822 Hits@1 0.1542056074766355 Hits@3 0.3177570093457944 Hits@10 0.5607476635514018 MRR 0.28176690236041424 rank 13 total_num 

checkcorrect (10805, 4, 9848) (10805, 4, 9848) real score 0.6974461108446122 Hits@1 0.14342629482071714 Hits@3 0.30677290836653387 Hits@10 0.5577689243027888 MRR 0.27421132675796356 rank 3 total_num 250 1429
checkcorrect (8763, 0, 10852) (8763, 0, 10852) real score 0.0 Hits@1 0.14285714285714285 Hits@3 0.3055555555555556 Hits@10 0.5555555555555556 MRR 0.2732094529630303 rank 45 total_num 251 1429
checkcorrect (4887, 4, 4886) (4887, 4, 4886) real score 0.6817069947719574 Hits@1 0.1422924901185771 Hits@3 0.30434782608695654 Hits@10 0.5533596837944664 MRR 0.27248889817230326 rank 10 total_num 252 1429
checkcorrect (8906, 0, 8024) (8906, 0, 8024) real score 0.4944456547498703 Hits@1 0.14173228346456693 Hits@3 0.3031496062992126 Hits@10 0.5551181102362205 MRR 0.2722035088094202 rank 4 total_num 253 1429
checkcorrect (10505, 4, 5903) (10505, 4, 5903) real score 0.6337791442871094 Hits@1 0.1411764705882353 Hits@3 0.30196078431372547 Hits@10 0.5529411764705883 MRR 0.27143770319417965 rank 12 t

checkcorrect (8345, 4, 10019) (8345, 4, 10019) real score 0.6964188277721405 Hits@1 0.13356164383561644 Hits@3 0.3047945205479452 Hits@10 0.5616438356164384 MRR 0.2710048934233015 rank 4 total_num 291 1429
checkcorrect (3935, 4, 9069) (3935, 4, 9069) real score 0.7056389302015305 Hits@1 0.13651877133105803 Hits@3 0.30716723549488056 Hits@10 0.5631399317406144 MRR 0.2734929313297066 rank 0 total_num 292 1429
checkcorrect (4748, 0, 7473) (4748, 0, 7473) real score 0.3215655263513327 Hits@1 0.1360544217687075 Hits@3 0.30612244897959184 Hits@10 0.5612244897959183 MRR 0.2727627632959721 rank 16 total_num 293 1429
checkcorrect (5762, 0, 8300) (5762, 0, 8300) real score 0.5305330604314804 Hits@1 0.13559322033898305 Hits@3 0.3050847457627119 Hits@10 0.559322033898305 MRR 0.27203754555399173 rank 16 total_num 294 1429
checkcorrect (8131, 4, 6385) (8131, 4, 6385) real score 0.6810719907283783 Hits@1 0.13851351351351351 Hits@3 0.30743243243243246 Hits@10 0.5608108108108109 MRR 0.2744968781703634 

checkcorrect (8543, 4, 8542) (8543, 4, 8542) real score 0.38855174034833906 Hits@1 0.13213213213213212 Hits@3 0.29429429429429427 Hits@10 0.5435435435435435 MRR 0.26590928788848744 rank 20 total_num 332 1429
checkcorrect (5191, 4, 6922) (5191, 4, 6922) real score 0.636806321144104 Hits@1 0.1317365269461078 Hits@3 0.2934131736526946 Hits@10 0.5449101796407185 MRR 0.2654458202933456 rank 8 total_num 333 1429
checkcorrect (7230, 0, 5222) (7230, 0, 5222) real score 0.6367726616561413 Hits@1 0.13134328358208955 Hits@3 0.29253731343283584 Hits@10 0.5462686567164179 MRR 0.26525045963575355 rank 4 total_num 334 1429
checkcorrect (4338, 4, 6918) (4338, 4, 6918) real score 0.6178750693798065 Hits@1 0.13095238095238096 Hits@3 0.2916666666666667 Hits@10 0.5446428571428571 MRR 0.2646263676593244 rank 17 total_num 335 1429
checkcorrect (8825, 4, 6041) (8825, 4, 6041) real score 0.8018056392669678 Hits@1 0.13056379821958458 Hits@3 0.29376854599406527 Hits@10 0.5459940652818991 MRR 0.2648302458957458 

checkcorrect (9667, 4, 9528) (9667, 4, 9528) real score 0.7350888729095459 Hits@1 0.12834224598930483 Hits@3 0.2807486631016043 Hits@10 0.5427807486631016 MRR 0.26030175031366076 rank 2 total_num 373 1429
checkcorrect (6652, 2, 10250) (6652, 2, 10250) real score 0.8977586627006531 Hits@1 0.13066666666666665 Hits@3 0.2826666666666667 Hits@10 0.544 MRR 0.262274278979491 rank 0 total_num 374 1429
checkcorrect (9682, 0, 7612) (9682, 0, 7612) real score 0.40711998343467715 Hits@1 0.13031914893617022 Hits@3 0.28191489361702127 Hits@10 0.5425531914893617 MRR 0.2618185205005804 rank 10 total_num 375 1429
checkcorrect (9025, 0, 6625) (9025, 0, 6625) real score 0.0 Hits@1 0.129973474801061 Hits@3 0.28116710875331563 Hits@10 0.5411140583554377 MRR 0.26119384482146396 rank 37 total_num 376 1429
checkcorrect (6414, 0, 7572) (6414, 0, 7572) real score 0.011618781043216586 Hits@1 0.12962962962962962 Hits@3 0.2804232804232804 Hits@10 0.5396825396825397 MRR 0.2605706892151665 rank 38 total_num 377 1429

checkcorrect (9154, 4, 9444) (9154, 4, 9444) real score 0.7306611955165863 Hits@1 0.13012048192771083 Hits@3 0.27228915662650605 Hits@10 0.5253012048192771 MRR 0.25706876167760795 rank 7 total_num 414 1429
checkcorrect (3862, 4, 5173) (3862, 4, 5173) real score 0.6698927491903305 Hits@1 0.12980769230769232 Hits@3 0.27163461538461536 Hits@10 0.5264423076923077 MRR 0.25671790194066924 rank 8 total_num 415 1429
checkcorrect (3862, 4, 3861) (3862, 4, 3861) real score 0.8230519354343414 Hits@1 0.13189448441247004 Hits@3 0.2733812949640288 Hits@10 0.5275779376498801 MRR 0.25850035301515206 rank 0 total_num 416 1429
checkcorrect (10694, 0, 5417) (10694, 0, 5417) real score 0.12761461585760117 Hits@1 0.13157894736842105 Hits@3 0.2727272727272727 Hits@10 0.5263157894736842 MRR 0.2579644257558351 rank 28 total_num 417 1429
checkcorrect (6869, 2, 9186) (6869, 2, 9186) real score 0.3860436499118805 Hits@1 0.13126491646778043 Hits@3 0.2744630071599045 Hits@10 0.5274463007159904 MRR 0.25814430381687

checkcorrect (10486, 0, 5936) (10486, 0, 5936) real score 0.14582612160593272 Hits@1 0.12280701754385964 Hits@3 0.25877192982456143 Hits@10 0.5219298245614035 MRR 0.24840491278902077 rank 20 total_num 455 1429
checkcorrect (5050, 4, 10738) (5050, 4, 10738) real score 0.5305334568023682 Hits@1 0.12253829321663019 Hits@3 0.25820568927789933 Hits@10 0.5207877461706784 MRR 0.2480176560245559 rank 13 total_num 456 1429
checkcorrect (10199, 2, 9282) (10199, 2, 9282) real score 0.3036777630448341 Hits@1 0.1222707423580786 Hits@3 0.2576419213973799 Hits@10 0.5218340611353712 MRR 0.2478400337770496 rank 5 total_num 457 1429
checkcorrect (10312, 4, 5175) (10312, 4, 5175) real score 0.7273694336414337 Hits@1 0.12418300653594772 Hits@3 0.25925925925925924 Hits@10 0.5228758169934641 MRR 0.24947872651391878 rank 0 total_num 458 1429
checkcorrect (6464, 4, 10916) (6464, 4, 10916) real score 0.7070821553468705 Hits@1 0.12391304347826088 Hits@3 0.2608695652173913 Hits@10 0.5239130434782608 MRR 0.249661

checkcorrect (6948, 0, 9693) (6948, 0, 9693) real score 0.4334400504827499 Hits@1 0.12096774193548387 Hits@3 0.25806451612903225 Hits@10 0.5161290322580645 MRR 0.24704002357086177 rank 6 total_num 495 1429
checkcorrect (5771, 4, 5563) (5771, 4, 5563) real score 0.6828381180763244 Hits@1 0.12072434607645875 Hits@3 0.2575452716297787 Hits@10 0.5150905432595574 MRR 0.24671063385207398 rank 11 total_num 496 1429
checkcorrect (5828, 0, 9476) (5828, 0, 9476) real score 0.46938281655311587 Hits@1 0.12048192771084337 Hits@3 0.2570281124497992 Hits@10 0.5140562248995983 MRR 0.24630650516051267 rank 21 total_num 497 1429
checkcorrect (4692, 4, 7624) (4692, 4, 7624) real score 0.6670344084501266 Hits@1 0.12024048096192384 Hits@3 0.2565130260521042 Hits@10 0.5150300601202404 MRR 0.2461469062857755 rank 5 total_num 498 1429
checkcorrect (7863, 0, 4946) (7863, 0, 4946) real score 0.0 Hits@1 0.12 Hits@3 0.256 Hits@10 0.514 MRR 0.2457415689949431 rank 22 total_num 499 1429
checkcorrect (10643, 0, 5306

checkcorrect (6030, 4, 6029) (6030, 4, 6029) real score 0.7639159083366394 Hits@1 0.1191806331471136 Hits@3 0.2569832402234637 Hits@10 0.5176908752327747 MRR 0.2456123466508295 rank 0 total_num 536 1429
checkcorrect (4808, 4, 7070) (4808, 4, 7070) real score 0.7990691423416137 Hits@1 0.120817843866171 Hits@3 0.258364312267658 Hits@10 0.5185873605947955 MRR 0.2470145541849358 rank 0 total_num 537 1429
checkcorrect (4556, 4, 4077) (4556, 4, 4077) real score 0.8067680925130845 Hits@1 0.12059369202226346 Hits@3 0.2597402597402597 Hits@10 0.5194805194805194 MRR 0.24748391493783942 rank 1 total_num 538 1429
checkcorrect (7076, 2, 8757) (7076, 2, 8757) real score 0.48893132209777834 Hits@1 0.12037037037037036 Hits@3 0.2611111111111111 Hits@10 0.5203703703703704 MRR 0.24764289534227554 rank 2 total_num 539 1429
checkcorrect (9620, 4, 5705) (9620, 4, 5705) real score 0.585905385017395 Hits@1 0.12014787430683918 Hits@3 0.26062846580406657 Hits@10 0.5194085027726433 MRR 0.24731717570472714 rank 1

checkcorrect (9528, 4, 6382) (9528, 4, 6382) real score 0.8305343508720398 Hits@1 0.1245674740484429 Hits@3 0.2664359861591695 Hits@10 0.5224913494809689 MRR 0.25131652248132097 rank 1 total_num 577 1429
checkcorrect (4441, 0, 4436) (4441, 0, 4436) real score 0.7360121071338653 Hits@1 0.12435233160621761 Hits@3 0.26770293609671847 Hits@10 0.5233160621761658 MRR 0.25174602762384024 rank 1 total_num 578 1429
checkcorrect (4260, 4, 5661) (4260, 4, 5661) real score 0.608814948797226 Hits@1 0.12413793103448276 Hits@3 0.2672413793103448 Hits@10 0.5224137931034483 MRR 0.2514134026269229 rank 16 total_num 579 1429
checkcorrect (5417, 4, 5416) (5417, 4, 5416) real score 0.7327927023172378 Hits@1 0.12392426850258176 Hits@3 0.2667814113597246 Hits@10 0.5232358003442341 MRR 0.251410969920164 rank 3 total_num 580 1429
checkcorrect (9977, 0, 4899) (9977, 0, 4899) real score 0.46042028442025185 Hits@1 0.12371134020618557 Hits@3 0.2663230240549828 Hits@10 0.5240549828178694 MRR 0.25132263492030116 ran

checkcorrect (5815, 4, 5738) (5815, 4, 5738) real score 0.6604584515094757 Hits@1 0.12277867528271405 Hits@3 0.2649434571890145 Hits@10 0.5282714054927302 MRR 0.25048515947420497 rank 11 total_num 618 1429
checkcorrect (7312, 4, 6173) (7312, 4, 6173) real score 0.4017525613307953 Hits@1 0.12258064516129032 Hits@3 0.2645161290322581 Hits@10 0.5274193548387097 MRR 0.2501349145933326 rank 29 total_num 619 1429
checkcorrect (9208, 4, 8168) (9208, 4, 8168) real score 0.6890942960977554 Hits@1 0.12238325281803543 Hits@3 0.2640901771336554 Hits@10 0.5281803542673108 MRR 0.24996216455833067 rank 6 total_num 620 1429
checkcorrect (9828, 4, 9479) (9828, 4, 9479) real score 0.7467141091823578 Hits@1 0.12218649517684887 Hits@3 0.26366559485530544 Hits@10 0.5289389067524116 MRR 0.24982824896686495 rank 5 total_num 621 1429
checkcorrect (10589, 0, 5505) (10589, 0, 5505) real score 0.6340525031089783 Hits@1 0.12199036918138041 Hits@3 0.26324237560192615 Hits@10 0.5280898876404494 MRR 0.24949412122641

checkcorrect (9897, 4, 4677) (9897, 4, 4677) real score 0.8583393812179565 Hits@1 0.12121212121212122 Hits@3 0.26515151515151514 Hits@10 0.5303030303030303 MRR 0.2502726103375471 rank 0 total_num 659 1429
checkcorrect (4795, 4, 4386) (4795, 4, 4386) real score 0.7803653419017792 Hits@1 0.12102874432677761 Hits@3 0.26626323751891073 Hits@10 0.5310136157337367 MRR 0.25065041274248273 rank 1 total_num 660 1429
checkcorrect (10561, 0, 4866) (10561, 0, 4866) real score 0.8736135542392731 Hits@1 0.12235649546827794 Hits@3 0.2673716012084592 Hits@10 0.5317220543806647 MRR 0.25178236075948807 rank 0 total_num 661 1429
checkcorrect (6783, 4, 9405) (6783, 4, 9405) real score 0.03940975535660982 Hits@1 0.12217194570135746 Hits@3 0.2669683257918552 Hits@10 0.530920060331825 MRR 0.25144696016212215 rank 33 total_num 662 1429
checkcorrect (5049, 4, 7434) (5049, 4, 7434) real score 0.6687075376510621 Hits@1 0.12198795180722892 Hits@3 0.26656626506024095 Hits@10 0.5316265060240963 MRR 0.25131927899721

checkcorrect (9686, 0, 7014) (9686, 0, 7014) real score 0.09965866280253977 Hits@1 0.11840228245363767 Hits@3 0.26390870185449355 Hits@10 0.5306704707560628 MRR 0.24827308478744947 rank 28 total_num 700 1429
checkcorrect (4615, 0, 7672) (4615, 0, 7672) real score 0.34352488815784454 Hits@1 0.11823361823361823 Hits@3 0.26353276353276356 Hits@10 0.5299145299145299 MRR 0.24800321362594566 rank 16 total_num 701 1429
checkcorrect (4520, 4, 4683) (4520, 4, 4683) real score 0.7739460527896881 Hits@1 0.11806543385490754 Hits@3 0.26458036984352773 Hits@10 0.5305832147937412 MRR 0.24836167278152751 rank 1 total_num 702 1429
checkcorrect (5870, 4, 10814) (5870, 4, 10814) real score 0.8036500811576843 Hits@1 0.11931818181818182 Hits@3 0.265625 Hits@10 0.53125 MRR 0.24942934085996285 rank 0 total_num 703 1429
checkcorrect (7163, 4, 7689) (7163, 4, 7689) real score 0.3272116851061583 Hits@1 0.11914893617021277 Hits@3 0.2652482269503546 Hits@10 0.5304964539007092 MRR 0.24914646236228916 rank 19 total

checkcorrect (7930, 4, 6747) (7930, 4, 6747) real score 0.8253880202770233 Hits@1 0.11725067385444744 Hits@3 0.2628032345013477 Hits@10 0.532345013477089 MRR 0.24786781166157834 rank 1 total_num 741 1429
checkcorrect (7197, 4, 4849) (7197, 4, 4849) real score 0.1760527916252613 Hits@1 0.11709286675639301 Hits@3 0.26244952893674295 Hits@10 0.531628532974428 MRR 0.24756479613138413 rank 43 total_num 742 1429
checkcorrect (7181, 4, 5743) (7181, 4, 5743) real score 0.17180822379887103 Hits@1 0.11693548387096774 Hits@3 0.2620967741935484 Hits@10 0.5309139784946236 MRR 0.24729925205056238 rank 19 total_num 743 1429
checkcorrect (8723, 4, 3931) (8723, 4, 3931) real score 0.7871675491333008 Hits@1 0.11677852348993288 Hits@3 0.26174496644295303 Hits@10 0.5315436241610738 MRR 0.2473028772155952 rank 3 total_num 744 1429
checkcorrect (6248, 4, 6286) (6248, 4, 6286) real score 0.5707873538136482 Hits@1 0.11662198391420911 Hits@3 0.2613941018766756 Hits@10 0.5308310991957105 MRR 0.24707448606238805

checkcorrect (7335, 4, 5727) (7335, 4, 5727) real score 0.6609603464603424 Hits@1 0.11749680715197956 Hits@3 0.25798212005108556 Hits@10 0.5300127713920817 MRR 0.2462603228665023 rank 5 total_num 782 1429
checkcorrect (4387, 4, 4150) (4387, 4, 4150) real score 0.6930803894996643 Hits@1 0.11734693877551021 Hits@3 0.2576530612244898 Hits@10 0.5306122448979592 MRR 0.24607376633223377 rank 9 total_num 783 1429
checkcorrect (6088, 4, 4438) (6088, 4, 4438) real score 0.9225254476070404 Hits@1 0.11847133757961784 Hits@3 0.2585987261146497 Hits@10 0.5312101910828025 MRR 0.24703418191652393 rank 0 total_num 784 1429
checkcorrect (9881, 0, 7321) (9881, 0, 7321) real score 0.6625404000282288 Hits@1 0.1183206106870229 Hits@3 0.2582697201017812 Hits@10 0.5305343511450382 MRR 0.24681076510928734 rank 13 total_num 785 1429
checkcorrect (6749, 4, 9067) (6749, 4, 9067) real score 0.7990912139415741 Hits@1 0.1181702668360864 Hits@3 0.25921219822109276 Hits@10 0.531130876747141 MRR 0.24692070484019468 ra

checkcorrect (10607, 4, 9613) (10607, 4, 9613) real score 0.6761652171611786 Hits@1 0.11771844660194175 Hits@3 0.25728155339805825 Hits@10 0.5303398058252428 MRR 0.24575080918536993 rank 6 total_num 823 1429
checkcorrect (4494, 4, 5952) (4494, 4, 5952) real score 0.4891082912683487 Hits@1 0.11757575757575757 Hits@3 0.25696969696969696 Hits@10 0.5296969696969697 MRR 0.24556312225412594 rank 10 total_num 824 1429
checkcorrect (4742, 6, 7305) (4742, 6, 7305) real score 0.4645273588597775 Hits@1 0.11743341404358354 Hits@3 0.25786924939467315 Hits@10 0.5302663438256658 MRR 0.24566938158957294 rank 2 total_num 825 1429
checkcorrect (3876, 0, 4792) (3876, 0, 4792) real score 0.04921239521354437 Hits@1 0.11729141475211609 Hits@3 0.25755743651753327 Hits@10 0.5296251511487303 MRR 0.24540500147522887 rank 36 total_num 826 1429
checkcorrect (5026, 4, 7417) (5026, 4, 7417) real score 0.6436939537525177 Hits@1 0.11714975845410629 Hits@3 0.2572463768115942 Hits@10 0.5301932367149759 MRR 0.2454105509

checkcorrect (3972, 4, 8678) (3972, 4, 8678) real score 0.8893337786197663 Hits@1 0.11805555555555555 Hits@3 0.2604166666666667 Hits@10 0.5335648148148148 MRR 0.24715422582894991 rank 0 total_num 863 1429
checkcorrect (5707, 4, 4510) (5707, 4, 4510) real score 0.42264038920402525 Hits@1 0.11791907514450867 Hits@3 0.26011560693641617 Hits@10 0.5329479768786127 MRR 0.24692354932234886 rank 20 total_num 864 1429
checkcorrect (9244, 4, 7125) (9244, 4, 7125) real score 0.7666876912117004 Hits@1 0.11893764434180139 Hits@3 0.26096997690531176 Hits@10 0.5334872979214781 MRR 0.24779315261412443 rank 0 total_num 865 1429
checkcorrect (8784, 0, 8999) (8784, 0, 8999) real score 0.2816000372171402 Hits@1 0.11880046136101499 Hits@3 0.2606689734717416 Hits@10 0.5340253748558247 MRR 0.2476226876168763 rank 9 total_num 866 1429
checkcorrect (7902, 4, 6818) (7902, 4, 6818) real score 0.6356777787208557 Hits@1 0.11866359447004608 Hits@3 0.26036866359447003 Hits@10 0.5345622119815668 MRR 0.247567822769391

checkcorrect (4337, 4, 5994) (4337, 4, 5994) real score 0.5870719298720359 Hits@1 0.11836283185840708 Hits@3 0.2588495575221239 Hits@10 0.536504424778761 MRR 0.24693914962623328 rank 8 total_num 903 1429
checkcorrect (8014, 0, 9301) (8014, 0, 9301) real score 0.0 Hits@1 0.11823204419889503 Hits@3 0.2585635359116022 Hits@10 0.5359116022099447 MRR 0.24670575198655156 rank 27 total_num 904 1429
checkcorrect (10648, 0, 7337) (10648, 0, 7337) real score 0.13033900200389326 Hits@1 0.11810154525386314 Hits@3 0.2582781456953642 Hits@10 0.5353200883002207 MRR 0.2464794395303486 rank 23 total_num 905 1429
checkcorrect (7389, 0, 4099) (7389, 0, 4099) real score 0.6997329890727997 Hits@1 0.11797133406835722 Hits@3 0.2579933847850055 Hits@10 0.5347298787210585 MRR 0.24630791764653243 rank 10 total_num 906 1429
checkcorrect (10074, 4, 4361) (10074, 4, 4361) real score 0.7353520721197129 Hits@1 0.11784140969162996 Hits@3 0.2577092511013216 Hits@10 0.5352422907488987 MRR 0.24631198381652525 rank 3 tot

checkcorrect (3894, 4, 4354) (3894, 4, 4354) real score 0.7012547433376313 Hits@1 0.11522198731501057 Hits@3 0.25792811839323465 Hits@10 0.5359408033826638 MRR 0.24461101858252754 rank 8 total_num 945 1429
checkcorrect (5162, 4, 5581) (5162, 4, 5581) real score 0.716680034995079 Hits@1 0.11615628299894404 Hits@3 0.2587117212249208 Hits@10 0.5364308342133052 MRR 0.24540868382161674 rank 0 total_num 946 1429
checkcorrect (5346, 0, 5344) (5346, 0, 5344) real score 0.18745028525590895 Hits@1 0.1160337552742616 Hits@3 0.25843881856540085 Hits@10 0.5358649789029536 MRR 0.24519567704635087 rank 22 total_num 947 1429
checkcorrect (7447, 4, 10119) (7447, 4, 10119) real score 0.5110747307538986 Hits@1 0.11591148577449947 Hits@3 0.2581664910432034 Hits@10 0.5363540569020021 MRR 0.24511292782571895 rank 5 total_num 948 1429
checkcorrect (3906, 4, 3905) (3906, 4, 3905) real score 0.8273166596889496 Hits@1 0.11578947368421053 Hits@3 0.25894736842105265 Hits@10 0.5368421052631579 MRR 0.24538123000695

checkcorrect (4653, 4, 4181) (4653, 4, 4181) real score 0.6397791832685471 Hits@1 0.11448834853090173 Hits@3 0.25835866261398177 Hits@10 0.541033434650456 MRR 0.24560004744984157 rank 3 total_num 986 1429
checkcorrect (5090, 4, 5659) (5090, 4, 5659) real score 0.7609439253807068 Hits@1 0.11437246963562753 Hits@3 0.2591093117408907 Hits@10 0.541497975708502 MRR 0.24585753728035792 rank 1 total_num 987 1429
checkcorrect (4932, 4, 6460) (4932, 4, 6460) real score 0.6219876736402512 Hits@1 0.11425682507583418 Hits@3 0.25884732052578363 Hits@10 0.5409504550050556 MRR 0.24570086544378433 rank 10 total_num 988 1429
checkcorrect (4178, 4, 8621) (4178, 4, 8621) real score 0.13511207289993762 Hits@1 0.11414141414141414 Hits@3 0.2585858585858586 Hits@10 0.5404040404040404 MRR 0.24550078279951693 rank 20 total_num 989 1429
checkcorrect (6244, 16, 6002) (6244, 16, 6002) real score 0.008214320428669453 Hits@1 0.11402623612512613 Hits@3 0.25832492431886983 Hits@10 0.5398587285570131 MRR 0.24528560346

checkcorrect (8359, 16, 5236) (8359, 16, 5236) real score 0.041398338647559284 Hits@1 0.11197663096397274 Hits@3 0.2590068159688413 Hits@10 0.5374878286270691 MRR 0.24426748144818355 rank 23 total_num 1026 1429
checkcorrect (6130, 8, 6129) (6130, 8, 6129) real score 0.4749195620417595 Hits@1 0.11186770428015565 Hits@3 0.2587548638132296 Hits@10 0.5379377431906615 MRR 0.24427305782809777 rank 3 total_num 1027 1429
checkcorrect (6448, 4, 4307) (6448, 4, 4307) real score 0.7178565859794617 Hits@1 0.11175898931000972 Hits@3 0.2585034013605442 Hits@10 0.5383867832847424 MRR 0.24427862336956707 rank 3 total_num 1028 1429
checkcorrect (9045, 4, 7871) (9045, 4, 7871) real score 0.5854726269841194 Hits@1 0.11165048543689321 Hits@3 0.258252427184466 Hits@10 0.537864077669903 MRR 0.24410618457665162 rank 14 total_num 1029 1429
checkcorrect (8827, 4, 3986) (8827, 4, 3986) real score 0.6125819861888886 Hits@1 0.11154219204655674 Hits@3 0.2580019398642095 Hits@10 0.5373423860329777 MRR 0.24393003890

checkcorrect (8294, 8, 8147) (8294, 8, 8147) real score 0.9161451280117034 Hits@1 0.1105904404873477 Hits@3 0.2614807872539831 Hits@10 0.5388940955951266 MRR 0.24432278344144967 rank 1 total_num 1066 1429
checkcorrect (6913, 4, 3900) (6913, 4, 3900) real score 0.4415543735027313 Hits@1 0.1104868913857678 Hits@3 0.2612359550561798 Hits@10 0.5383895131086143 MRR 0.24413860391352607 rank 20 total_num 1067 1429
checkcorrect (5850, 4, 4103) (5850, 4, 4103) real score 0.7348300635814666 Hits@1 0.11038353601496725 Hits@3 0.26192703461178674 Hits@10 0.538821328344247 MRR 0.2443779504019138 rank 1 total_num 1068 1429
checkcorrect (3928, 0, 8003) (3928, 0, 8003) real score -0.02518251962028444 Hits@1 0.1102803738317757 Hits@3 0.2616822429906542 Hits@10 0.5383177570093458 MRR 0.244173523369427 rank 38 total_num 1069 1429
checkcorrect (10566, 4, 7155) (10566, 4, 7155) real score 0.6084582388401032 Hits@1 0.11017740429505135 Hits@3 0.26143790849673204 Hits@10 0.5378151260504201 MRR 0.24399740948724

checkcorrect (5285, 4, 5284) (5285, 4, 5284) real score 0.40378288179636 Hits@1 0.1093044263775971 Hits@3 0.26196928635953026 Hits@10 0.5356820234869015 MRR 0.24338239052030683 rank 8 total_num 1106 1429
checkcorrect (10645, 4, 5111) (10645, 4, 5111) real score 0.6856196969747543 Hits@1 0.11010830324909747 Hits@3 0.26263537906137185 Hits@10 0.5361010830324909 MRR 0.24406525839889862 rank 0 total_num 1107 1429
checkcorrect (5615, 2, 7158) (5615, 2, 7158) real score 0.5790218263864517 Hits@1 0.11000901713255185 Hits@3 0.26330027051397653 Hits@10 0.5365193868349865 MRR 0.24429603814786263 rank 1 total_num 1108 1429
checkcorrect (5842, 10, 5058) (5842, 10, 5058) real score 0.442401596903801 Hits@1 0.10990990990990991 Hits@3 0.26306306306306304 Hits@10 0.536036036036036 MRR 0.24412600167705872 rank 17 total_num 1109 1429
checkcorrect (4941, 0, 5594) (4941, 0, 5594) real score 0.4153663866221905 Hits@1 0.10981098109810981 Hits@3 0.26282628262826285 Hits@10 0.5355535553555355 MRR 0.2439662723

checkcorrect (7965, 4, 8164) (7965, 4, 8164) real score 0.6812143832445144 Hits@1 0.10985178727114212 Hits@3 0.26068003487358327 Hits@10 0.5300784655623365 MRR 0.2425560489487127 rank 0 total_num 1146 1429
checkcorrect (5278, 4, 4527) (5278, 4, 4527) real score 0.8144953668117523 Hits@1 0.10975609756097561 Hits@3 0.2613240418118467 Hits@10 0.5304878048780488 MRR 0.24263512323824632 rank 2 total_num 1147 1429
checkcorrect (9653, 0, 7981) (9653, 0, 7981) real score 0.057361011113971475 Hits@1 0.10966057441253264 Hits@3 0.26109660574412535 Hits@10 0.5300261096605744 MRR 0.24245742647168697 rank 25 total_num 1148 1429
checkcorrect (8767, 4, 9155) (8767, 4, 9155) real score 0.5964163452386856 Hits@1 0.10956521739130434 Hits@3 0.2608695652173913 Hits@10 0.5295652173913044 MRR 0.2423190576950449 rank 11 total_num 1149 1429
checkcorrect (6222, 4, 7058) (6222, 4, 7058) real score 0.21065251380205155 Hits@1 0.10947002606429192 Hits@3 0.26064291920069504 Hits@10 0.529105125977411 MRR 0.2421407066

checkcorrect (5825, 0, 4728) (5825, 0, 4728) real score 0.6821044147014618 Hits@1 0.11204717775905644 Hits@3 0.2620050547598989 Hits@10 0.5290648694187026 MRR 0.24378655673766908 rank 13 total_num 1186 1429
checkcorrect (8291, 4, 5683) (8291, 4, 5683) real score 0.7136280626058579 Hits@1 0.11195286195286196 Hits@3 0.2617845117845118 Hits@10 0.5294612794612794 MRR 0.24374969936667776 rank 4 total_num 1187 1429
checkcorrect (8100, 4, 9306) (8100, 4, 9306) real score 0.29093727469444275 Hits@1 0.1118587047939445 Hits@3 0.26156433978132887 Hits@10 0.5290159798149706 MRR 0.24359726059513304 rank 15 total_num 1188 1429
checkcorrect (5081, 4, 5703) (5081, 4, 5703) real score 0.7014019787311554 Hits@1 0.11176470588235295 Hits@3 0.26134453781512607 Hits@10 0.5285714285714286 MRR 0.24346258502600546 rank 11 total_num 1189 1429
checkcorrect (8322, 0, 6337) (8322, 0, 6337) real score 0.4104331940412521 Hits@1 0.11167086481947942 Hits@3 0.2611251049538203 Hits@10 0.5281276238455079 MRR 0.2433048125

checkcorrect (5350, 0, 3888) (5350, 0, 3888) real score 0.5470288008451462 Hits@1 0.11165444172779136 Hits@3 0.2624286878565607 Hits@10 0.530562347188264 MRR 0.2438265186882005 rank 9 total_num 1226 1429
checkcorrect (6658, 4, 3893) (6658, 4, 3893) real score 0.68611781001091 Hits@1 0.11156351791530944 Hits@3 0.26221498371335505 Hits@10 0.5301302931596091 MRR 0.24368225170772692 rank 14 total_num 1227 1429
checkcorrect (8311, 4, 7778) (8311, 4, 7778) real score 0.6521175980567933 Hits@1 0.1114727420667209 Hits@3 0.26200162733930027 Hits@10 0.5305126118795769 MRR 0.24364670878526334 rank 4 total_num 1228 1429
checkcorrect (10265, 0, 4326) (10265, 0, 4326) real score 0.2948948871344328 Hits@1 0.11138211382113822 Hits@3 0.26178861788617885 Hits@10 0.5300813008130081 MRR 0.24349943503828347 rank 15 total_num 1229 1429
checkcorrect (8656, 4, 9184) (8656, 4, 9184) real score 0.6970500677824021 Hits@1 0.11129163281884646 Hits@3 0.26238830219333875 Hits@10 0.5304630381803412 MRR 0.243707802678

checkcorrect (10449, 0, 5520) (10449, 0, 5520) real score 0.4914641439914703 Hits@1 0.11041009463722397 Hits@3 0.2610410094637224 Hits@10 0.5291798107255521 MRR 0.24297646574337234 rank 12 total_num 1267 1429
checkcorrect (8200, 4, 8199) (8200, 4, 8199) real score 0.5294870406389236 Hits@1 0.11032308904649331 Hits@3 0.26083530338849487 Hits@10 0.52876280535855 MRR 0.24283752973149156 rank 14 total_num 1268 1429
checkcorrect (6423, 4, 10373) (6423, 4, 10373) real score 0.7132550358772278 Hits@1 0.11023622047244094 Hits@3 0.2606299212598425 Hits@10 0.5291338582677165 MRR 0.2427338081420267 rank 8 total_num 1269 1429
checkcorrect (3864, 0, 9500) (3864, 0, 9500) real score 0.0 Hits@1 0.11014948859166011 Hits@3 0.26042486231313927 Hits@10 0.5295043273013376 MRR 0.24270018594836654 rank 4 total_num 1270 1429
checkcorrect (4964, 4, 8808) (4964, 4, 8808) real score 0.755115520954132 Hits@1 0.11006289308176101 Hits@3 0.2610062893081761 Hits@10 0.529874213836478 MRR 0.24290246567639456 rank 1 to

checkcorrect (5401, 6, 7146) (5401, 6, 7146) real score 0.7510314106941223 Hits@1 0.1099236641221374 Hits@3 0.26183206106870227 Hits@10 0.5297709923664122 MRR 0.24327656829827632 rank 1 total_num 1309 1429
checkcorrect (7671, 4, 6637) (7671, 4, 6637) real score 0.4843511790037155 Hits@1 0.10983981693363844 Hits@3 0.2616323417238749 Hits@10 0.5293668954996186 MRR 0.24313114877932063 rank 18 total_num 1310 1429
checkcorrect (7070, 4, 4027) (7070, 4, 4027) real score 0.6816487848758698 Hits@1 0.10975609756097561 Hits@3 0.2614329268292683 Hits@10 0.5297256097560976 MRR 0.24303052375061007 rank 8 total_num 1311 1429
checkcorrect (8012, 4, 5104) (8012, 4, 5104) real score 0.7925066351890564 Hits@1 0.10967250571210967 Hits@3 0.26123381568926124 Hits@10 0.5300837776085301 MRR 0.24294062997776117 rank 7 total_num 1312 1429
checkcorrect (4904, 4, 4076) (4904, 4, 4076) real score 0.8075861394405365 Hits@1 0.11035007610350075 Hits@3 0.2617960426179604 Hits@10 0.530441400304414 MRR 0.24351677866118

checkcorrect (7124, 4, 10564) (7124, 4, 10564) real score 0.803963017463684 Hits@1 0.1111111111111111 Hits@3 0.2651851851851852 Hits@10 0.5355555555555556 MRR 0.24532044459047106 rank 2 total_num 1349 1429
checkcorrect (6538, 4, 8570) (6538, 4, 8570) real score 0.6032017976045608 Hits@1 0.11102886750555144 Hits@3 0.26572908956328645 Hits@10 0.535899333826795 MRR 0.24538559106622446 rank 2 total_num 1350 1429
checkcorrect (9699, 4, 8825) (9699, 4, 8825) real score 0.7817675292491912 Hits@1 0.11168639053254438 Hits@3 0.26627218934911245 Hits@10 0.5362426035502958 MRR 0.2459437378183944 rank 0 total_num 1351 1429
checkcorrect (7367, 0, 6692) (7367, 0, 6692) real score 0.17383537739515303 Hits@1 0.11160384331116038 Hits@3 0.2660753880266075 Hits@10 0.5358462675535847 MRR 0.24578744736813743 rank 28 total_num 1352 1429
checkcorrect (10335, 4, 6204) (10335, 4, 6204) real score 0.5521692227572202 Hits@1 0.11152141802067947 Hits@3 0.2658788774002954 Hits@10 0.5354505169867061 MRR 0.24564695114

checkcorrect (5044, 4, 5043) (5044, 4, 5043) real score 0.6136346518993377 Hits@1 0.11294964028776978 Hits@3 0.26618705035971224 Hits@10 0.5381294964028777 MRR 0.24656009250900657 rank 13 total_num 1389 1429
checkcorrect (5198, 4, 7040) (5198, 4, 7040) real score 0.6379575461149216 Hits@1 0.11286843997124371 Hits@3 0.26599568655643424 Hits@10 0.5384615384615384 MRR 0.2464547293943344 rank 9 total_num 1390 1429
checkcorrect (4610, 4, 4609) (4610, 4, 4609) real score 0.6840748131275177 Hits@1 0.11278735632183907 Hits@3 0.26580459770114945 Hits@10 0.5380747126436781 MRR 0.24633754448337103 rank 11 total_num 1391 1429
checkcorrect (6140, 4, 6139) (6140, 4, 6139) real score 0.5664209961891175 Hits@1 0.11270638908829864 Hits@3 0.2656137832017229 Hits@10 0.5376884422110553 MRR 0.2462205278206646 rank 11 total_num 1392 1429
checkcorrect (6510, 0, 5901) (6510, 0, 5901) real score 0.42181834280490876 Hits@1 0.11262553802008608 Hits@3 0.2654232424677188 Hits@10 0.5373027259684362 MRR 0.2461036790

#### Fine-tuning

In [34]:
#build the big batch of samples used to train the path-based model
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                           x_p_list, x_r_list, y_list,
                           relation2id, entity2id, id2relation, id2entity):
    """Append balanced positive/negative path samples to the given lists.

    For every key ``s`` of ``one_hop`` the paths towards its targets are
    fetched once through ``Class_2.obtain_paths`` (mode
    ``'direct_neighbour'``).  Then, ten times over, every target ``t`` that
    has at least three paths, at least one even-id relation in
    ``s_t_r[(s, t)]`` and fewer such relations than there are even ids in
    total contributes two samples built from the same three randomly drawn
    paths:

    * a positive one (label ``1.``) with one of the true even-id relations;
    * a negative one (label ``0.``) with an even-id relation that does not
      hold between ``s`` and ``t``.

    Args:
        lower_bd, upper_bd: forwarded to ``Class_2.obtain_paths``;
            ``upper_bd`` is also the length every path is padded to.
        data: not used by this function.
        one_hop: container whose keys are the source entities to process
            (also forwarded to ``Class_2.obtain_paths``).
        s_t_r: mapping ``(s, t) -> relation ids`` between the two entities.
        x_p_list: dict with keys ``'1'``, ``'2'``, ``'3'``; one padded path
            is appended to each list per sample.
        x_r_list: list receiving ``[relation_id]`` per sample.
        y_list: list receiving the label (``1.`` or ``0.``) per sample.
        relation2id, entity2id: not used by this function.
        id2relation: mapping whose keys must be exactly ``0 .. len-1``.
        id2entity: only used to build the error raised for an empty
            ``s_t_r[(s, t)]``.

    Returns:
        None; the three output containers are extended in place.

    Raises:
        ValueError: if ``id2relation`` lacks one of the ids ``0 .. len-1``,
            if the number of even ids differs from ``int(len/2)``, or if a
            reached pair ``(s, t)`` has no relation in ``s_t_r``.
    """
    num_r = len(id2relation)

    #relation ids have to be exactly 0 .. num_r-1
    for rel_id in range(num_r):
        if rel_id not in id2relation:
            raise ValueError('error when generaing id2relation')

    #initial (non-inverse) relations always carry the even ids
    ini_r_id_set = set(range(0, num_r, 2))
    num_ini_r = len(ini_r_id_set)

    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')

    def pad(path):
        #right-pad the path with the place-holder id num_r
        return list(path) + [num_r]*abs(len(path)-upper_bd)

    def append_paths(paths):
        #store one freshly padded copy of each path in its slot
        for slot, path in zip(('1', '2', '3'), paths):
            x_p_list[slot].append(pad(path))

    #not every entity of entity2id has to be a key of one_hop,
    #so only the entities that really occur there are used as sources
    existing_ids = set()
    for node in one_hop:
        existing_ids.add(node)

    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    total = len(existing_ids)

    for count, s in enumerate(existing_ids, start=1):

        #run the path finding algorithm once per source entity
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)

        for iteration in range(10):

            for t in result:

                relations_st = s_t_r[(s,t)]

                if len(relations_st) == 0:
                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #only forward (even-id) links matter for relation prediction
                ini_r_list = [rel for rel in relations_st if rel % 2 == 0]

                #skip the pair unless there are at least three paths, at least
                #one initial relation, and at least one initial relation that
                #does NOT hold (otherwise no negative relation can be drawn)
                if len(result[t]) < 3 or not ini_r_list or len(ini_r_list) >= int(num_ini_r):
                    continue

                chosen = random.sample(list(result[t]), 3)

                #####positive#####################
                append_paths(chosen)
                x_r_list.append([random.choice(ini_r_list)])
                y_list.append(1.)

                #####negative#####################
                #same three paths, but paired with a relation that is not true
                append_paths(chosen)
                neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                x_r_list.append([random.choice(neg_r_list)])
                y_list.append(0.)

        if count % 100 == 0:
            print('generating big-batches for path-based model', count, total)

In [35]:
#Running the path-finding algorithm again and again on the complete graph is too slow.
#Instead, the paths leaving each entity are collected once by this function;
#the stored result is then reused many times by the subgraph-based training.
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    """Collect up to 100 randomly chosen paths for every key of ``one_hop``.

    For each source entity ``s`` the function calls
    ``Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd,
    one_hop)``, merges the paths of all returned targets into one set and
    keeps a random sample of at most 100 of them.

    Args:
        lower_bd, upper_bd: forwarded to ``Class_2.obtain_paths``.
        data, s_t_r, relation2id, entity2id, id2entity: not used by this
            function.
        one_hop: container whose keys are the source entities (also
            forwarded to ``Class_2.obtain_paths``).
        id2relation: mapping whose keys must be exactly ``0 .. len-1``.

    Returns:
        dict mapping each source entity to a list of sampled paths.

    Raises:
        ValueError: if ``id2relation`` lacks one of the ids ``0 .. len-1``.
    """
    #relation ids have to be exactly 0 .. len(id2relation)-1
    for rel_id in range(len(id2relation)):
        if rel_id not in id2relation:
            raise ValueError('error when generaing id2relation')

    #not every entity of entity2id has to be a key of one_hop,
    #so only the entities that really occur there are used as sources
    existing_ids = set()
    for node in one_hop:
        existing_ids.add(node)

    #randomise the order in which the sources are processed
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    total = len(existing_ids)

    #maps each entity to the paths sampled for it
    Dict_1 = dict()

    for count, s in enumerate(existing_ids, start=1):

        result, _ = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        #merge the paths of all targets, dropping duplicates
        path_set = set()
        for t_ in result:
            for path in result[t_]:
                path_set.add(path)

        path_list = list(path_set)
        sample_size = min(len(path_list), 100)

        Dict_1[s] = deepcopy(random.sample(path_list, sample_size))

        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, total)

    return Dict_1

In [36]:
#function to build the big batch for the subgraph-based training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    """Append positive/negative subgraph samples to the given lists.

    An entity is *qualified* when ``Dict`` stores at least six paths for it.
    The triples of ``data`` are shuffled and traversed ten times; every
    triple ``(s, r, t)`` whose two entities are both qualified contributes
    six samples, each made of six padded paths of the source side, six padded
    paths of the target side, a relation id and a label:

    1. positive: paths of ``s`` and ``t`` with ``r`` (label ``1.``);
    2. negative relation: the same paths with a different even-id relation
       (label ``0.``);
    3. positive: freshly sampled paths of ``s`` and ``t`` with ``r``;
    4. negative source: paths of a random qualified entity instead of ``s``,
       target paths of sample 3 (label ``0.``);
    5. positive: freshly sampled paths of ``s`` and ``t`` with ``r``;
    6. negative target: source paths of sample 5, paths of a random
       qualified entity instead of ``t`` (label ``0.``).

    Args:
        lower_bd: not used by this function.
        upper_bd: length every path is padded to (with the id
            ``len(id2relation)`` as place holder).
        data: iterable of triples indexed as ``(s, r, t)``.
        one_hop, s_t_r, relation2id, entity2id, id2entity: not used by this
            function.
        x_s_list, x_t_list: dicts with keys ``'1'`` .. ``'6'``; one padded
            path is appended to each list per sample.
        x_r_list: list receiving ``[relation_id]`` per sample.
        y_list: list receiving the label (``1.`` or ``0.``) per sample.
        Dict: mapping entity -> collection of stored paths.
        id2relation: mapping whose keys must be exactly ``0 .. len-1``.

    Returns:
        None; the output containers are extended in place.

    Raises:
        ValueError: if ``id2relation`` lacks one of the ids ``0 .. len-1``
            or the number of even ids differs from ``int(len/2)``.
        IndexError: raised by ``random.choice`` when a qualified triple is
            processed but no even relation id other than ``r`` exists.
    """
    num_r = len(id2relation)

    #relation ids have to be exactly 0 .. num_r-1
    for rel_id in range(num_r):
        if rel_id not in id2relation:
            raise ValueError('error when generaing id2relation')

    #initial (non-inverse) relations always carry the even ids
    ini_r_id_set = set(range(0, num_r, 2))

    if len(ini_r_id_set) != int(num_r/2):
        raise ValueError('error when generating id2relation')

    #an entity with at least six stored paths is a qualified one.
    #The set gives O(1) membership tests inside the triple loop, while the
    #list (same order as iterating the set) is what random.choice needs.
    qualified_set = set()
    for e in Dict:
        if len(Dict[e]) >= 6:
            qualified_set.add(e)
    qualified = list(qualified_set)

    slots = ('1', '2', '3', '4', '5', '6')

    def pad(path):
        #right-pad the path with the place-holder id num_r
        return list(path) + [num_r]*abs(len(path)-upper_bd)

    def emit(s_paths, t_paths, rel, label):
        #append one complete sample: six source paths, six target paths,
        #the relation id and the label
        for slot, path in zip(slots, s_paths):
            x_s_list[slot].append(pad(path))
        for slot, path in zip(slots, t_paths):
            x_t_list[slot].append(pad(path))
        x_r_list.append([rel])
        y_list.append(label)

    data = list(data)

    for iteration in range(10):

        data = shuffle(data)

        for i_0, triple in enumerate(data):

            s, r, t = triple[0], triple[1], triple[2] #entity and relation IDs

            if s in qualified_set and t in qualified_set:

                #all stored paths of the true entities
                path_s, path_t = list(Dict[s]), list(Dict[t])

                #####positive step###########
                #randomly draw six paths per true entity
                temp_s = random.sample(path_s, 6)
                temp_t = random.sample(path_t, 6)
                emit(temp_s, temp_t, r, 1.)

                #####negative step for relation###########
                #same paths, but with an initial relation different from r
                neg_r_list = list(ini_r_id_set.difference({r}))
                r_ran = random.choice(neg_r_list)
                emit(temp_s, temp_t, r_ran, 0.)

                ##############################################
                #randomly choose two entities for negative sampling
                s_ran = random.choice(qualified)
                t_ran = random.choice(qualified)

                #all stored paths of the random entities
                path_s_ran, path_t_ran = list(Dict[s_ran]), list(Dict[t_ran])

                #####positive step#################
                temp_s = random.sample(path_s, 6)
                temp_t = random.sample(path_t, 6)
                emit(temp_s, temp_t, r, 1.)

                #####negative for source entity###########
                #source paths come from the random entity, target paths are
                #the ones of the positive sample just above
                temp_s = random.sample(path_s_ran, 6)
                emit(temp_s, temp_t, r, 0.)

                #####positive step###########
                temp_s = random.sample(path_s, 6)
                temp_t = random.sample(path_t, 6)
                emit(temp_s, temp_t, r, 1.)

                #####negative for target entity###########
                #target paths come from the random entity, source paths are
                #the ones of the positive sample just above
                temp_t = random.sample(path_t_ran, 6)
                emit(temp_s, temp_t, r, 0.)

            if i_0 % 2000 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

In [37]:
###fine-tune the path-based model
lower_bd, upper_bd = lower_bound, upper_bound_path
batch_size = 32

#training containers: one list per path slot, the relation ids and the labels
train_p_list = {slot: [] for slot in ('1', '2', '3')}
train_r_list = []
train_y_list = []

#######################################
###build the big-batches###############

#the containers are filled in place
build_big_batches_path(
    lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
    train_p_list, train_r_list, train_y_list,
    relation2id, entity2id, id2relation, id2entity,
)

#######################################
###do the training#####################

#turn the python lists into integer input arrays
x_train_1, x_train_2, x_train_3 = [
    np.asarray(train_p_list[slot], dtype='int') for slot in ('1', '2', '3')
]
x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model.fit(
    [x_train_1, x_train_2, x_train_3, x_train_r],
    y_train,
    batch_size=batch_size,
    epochs=5,
)

generating big-batches for path-based model 100 225
generating big-batches for path-based model 200 225
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f9acdf6db20>

In [39]:
###fine-tune the subgraph model
lower_bd, upper_bd = lower_bound, upper_bound_subg
batch_size = 32

#collect the stored paths of every entity once; they are reused below
Dict_train_ind = store_subgraph_dicts(
    lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
    relation2id, entity2id, id2relation, id2entity,
)

#training containers: six path slots per side, the relation ids and the labels
train_s_list = {slot: [] for slot in ('1', '2', '3', '4', '5', '6')}
train_t_list = {slot: [] for slot in ('1', '2', '3', '4', '5', '6')}
train_r_list = []
train_y_list = []

#######################################
###build the big-batches###############

#the containers are filled in place
build_big_batches_subgraph(
    lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
    train_s_list, train_t_list, train_r_list, train_y_list, Dict_train_ind,
    relation2id, entity2id, id2relation, id2entity,
)

#######################################
###do the training#####################

#turn the python lists into integer input arrays
(x_train_s_1, x_train_s_2, x_train_s_3,
 x_train_s_4, x_train_s_5, x_train_s_6) = [
    np.asarray(train_s_list[slot], dtype='int')
    for slot in ('1', '2', '3', '4', '5', '6')
]

(x_train_t_1, x_train_t_2, x_train_t_3,
 x_train_t_4, x_train_t_5, x_train_t_6) = [
    np.asarray(train_t_list[slot], dtype='int')
    for slot in ('1', '2', '3', '4', '5', '6')
]

x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model_2.fit(
    [
        x_train_s_1, x_train_s_2, x_train_s_3, x_train_s_4, x_train_s_5, x_train_s_6,
        x_train_t_1, x_train_t_2, x_train_t_3, x_train_t_4, x_train_t_5, x_train_t_6,
        x_train_r,
    ],
    y_train,
    batch_size=batch_size,
    epochs=5,
)

generating and storing paths for the path-based model 100 225
generating and storing paths for the path-based model 200 225
generating big-batches for subgraph-based model 0 833 0
generating big-batches for subgraph-based model 0 833 1
generating big-batches for subgraph-based model 0 833 2
generating big-batches for subgraph-based model 0 833 3
generating big-batches for subgraph-based model 0 833 4
generating big-batches for subgraph-based model 0 833 5
generating big-batches for subgraph-based model 0 833 6
generating big-batches for subgraph-based model 0 833 7
generating big-batches for subgraph-based model 0 833 8
generating big-batches for subgraph-based model 0 833 9
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f9a299a7550>

In [None]:
########################################################
#obtain the Hits@N for relation prediction##############
#
#For every triple of the inductive test set the initial (even-id) relations
#are ranked by the sum of the path-based and the subgraph-based score. The
#rank of the true relation is "filtered": candidates ranked above it that
#form an already known triple are not counted.

#we select all the triples in the inductive test set
selected = list(data_ind_test)

#every triple collection consulted when filtering out known triples
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#running totals; divided by the number of processed triples when printed
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i, triple in enumerate(selected):

    s_true, r_true, t_true = triple[0], triple[1], triple[2]

    #run the path-based scoring
    score_dict_path = path_based_relation_scoring(s_true, t_true, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    #run the subgraph-based scoring
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #final score: sum of both models, 0.0 for relations scored by neither
    score_dict = defaultdict(float)

    for r in score_dict_path:
        score_dict[r] += score_dict_path[r]
    for r in score_dict_subg:
        score_dict[r] += score_dict_subg[r]

    #[... [score, r], ...] -- only initial (even-id) relations are ranked
    temp_list = [[score_dict.get(r, 0.0), r] for r in id2relation if r % 2 == 0]

    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)

    #walk down the ranking until the true relation is reached
    p = 0
    exist_tri = 0

    while p < len(sorted_list) and sorted_list[p][1] != r_true:

        #candidates that form an existing triple do not count against the rank
        candidate = (s_true, sorted_list[p][1], t_true)

        if any(candidate in triples for triples in known_triple_sets):

            exist_tri += 1

        p += 1

    #the true relation must be one of the ranked candidates; without this
    #check the metrics below would be updated with a meaningless rank and
    #sorted_list[p] in the print would fail with an IndexError
    if p == len(sorted_list):
        raise ValueError('true relation %r of test triple %r is not among '
                         'the ranked initial relations' % (r_true, triple))

    cur_rank = p - exist_tri

    if cur_rank == 0:

        Hits_at_1 += 1

    if cur_rank < 3:

        Hits_at_3 += 1

    if cur_rank < 10:

        Hits_at_10 += 1

    MRR_raw += 1./float(cur_rank + 1.)

    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', cur_rank,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

In [None]:
###########################################
##obtain the AUC-PR for the test triples###
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#all triple sets a corrupted triple is checked against: a candidate found in
#any of them is an existing triple and must not be used as a negative sample
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#build the candidate entity list once; converting new_ent_set to a list
#for every single draw would repeat the same work for each sample
candidate_entities = list(new_ent_set)

#we build the negative samples by randomly replacing the head or tail entity
#in the triple (exactly one negative triple per positive triple)
neg_triples = list()

for triple in pos_triples:

    s_pos, r_pos, t_pos = triple[0], triple[1], triple[2]

    #decide to replace the head or tail entity (each with probability 0.5)
    replace_head = random.uniform(0, 1) < 0.5

    #filter out the existing triples: redraw until the corrupted triple
    #is not contained in any of the known triple sets
    while True:

        ent_neg = random.choice(candidate_entities)

        if replace_head:
            neg_triple = (ent_neg, r_pos, t_pos)
        else:
            neg_triple = (s_pos, r_pos, ent_neg)

        if not any(neg_triple in triples for triples in known_triple_sets):
            break

    neg_triples.append(neg_triple)

#sanity check: the loop above must produce one negative per positive
if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')
        
#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array: 1 for every positive triple, 0 for every negative one
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional); triples and labels are
#shuffled together so that they stay aligned
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score array (filled in one entry per triple below)
y_score = np.zeros((len(y_test),))

#implement the scoring
for i, triple in enumerate(all_triples):

    s, r, t = triple[0], triple[1], triple[2]

    #NOTE: only the subgraph-based score is used; the path-based score and
    #the average of both are kept below as disabled alternatives
    #path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #ave_score = (path_score + subg_score)/float(2)

    #y_score[i] = ave_score
    y_score[i] = subg_score

    #report a running AUC-PR every 20 triples
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))

        #entries 0..i are scored at this point, so the running estimate
        #has to cover i + 1 entries (a slice [:i] would drop the latest one)
        n_scored = i + 1

        # Data to plot precision - recall curve
        precision, recall, thresholds = precision_recall_curve(y_test[:n_scored], y_score[:n_scored])
        # Use AUC function to calculate the area under the curve of precision recall curve
        auc_precision_recall = auc(recall, precision)
        print('AUC-PR is:', auc_precision_recall)


#final AUC-PR over all scored triples
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_test, y_score)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('AUC-PR is:', auc_precision_recall)

In [None]:
######################################################
#obtain the Hits@N for entity prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

#all triple sets a corrupted triple is checked against: a candidate found in
#any of them is an existing triple and must not be used as a negative sample
known_triple_sets = (data_test, data_valid, data,
                     data_ind, data_ind_valid, data_ind_test)

#build the candidate entity list once; converting new_ent_set to a list
#for every single draw would repeat the same work for each sample
candidate_entities = list(new_ent_set)

#number of randomly corrupted triples each true triple is ranked against
num_neg_samples = 50

###Hits@1 / Hits@3 / Hits@10 / MRR accumulators#############################
#the negative samples are generated by randomly replacing the head or the
#tail entity of the true triple
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):

    #entries are [triple, score]; the true triple is added first
    triple_list = list()

    #score the true triple
    s_pos, r_pos, t_pos = selected[i][0], selected[i][1], selected[i][2]
    true_triple = (s_pos, r_pos, t_pos)

    #NOTE: only the subgraph-based score is used; the path-based score and
    #the average of both are kept below as disabled alternatives
    #path_score = path_based_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

    #ave_score = (path_score + subg_score)/float(2)

    triple_list.append([true_triple, subg_score])

    #generate and score the random negative samples
    for sub_i in range(num_neg_samples):

        #decide to replace the head or tail entity (each with probability 0.5)
        replace_head = random.uniform(0, 1) < 0.5

        #filter out the existing triples: redraw until the corrupted triple
        #is not contained in any of the known triple sets
        while True:

            ent_neg = random.choice(candidate_entities)

            if replace_head:
                neg_triple = (ent_neg, r_pos, t_pos)
            else:
                neg_triple = (s_pos, r_pos, ent_neg)

            if not any(neg_triple in triples for triples in known_triple_sets):
                break

        #path_score = path_based_triple_scoring(neg_triple[0], neg_triple[1], neg_triple[2], lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

        subg_score = subgraph_triple_scoring(neg_triple[0], neg_triple[1], neg_triple[2], lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

        #ave_score = (path_score + subg_score)/float(2)

        triple_list.append([neg_triple, subg_score])

    #random shuffle! sorted() is stable, so shuffling first means triples
    #with equal scores end up in random relative order
    random.shuffle(triple_list)

    #sort by score, best first
    sorted_list = sorted(triple_list, key = lambda x: x[-1], reverse=True)

    #p is the 0-based position of the true triple in the ranking
    p = 0

    while p < len(sorted_list) and sorted_list[p][0] != true_triple:

        p += 1

    if p == 0:

        Hits_at_1 += 1

    if p < 3:

        Hits_at_3 += 1

    if p < 10:

        Hits_at_10 += 1

    #reciprocal of the 1-based rank
    MRR_raw += 1./float(p + 1.)

    #running averages over the i + 1 triples evaluated so far
    print('checkcorrect', true_triple, sorted_list[p][0],
          'real score', sorted_list[p][-1],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'rank', p,
          'total_num', i, len(selected))