In [1]:
# Dataset folder name (under ../data/) and the tag of this model variant.
# Both end up in the file names of the saved model and of the saved ID dictionaries.
data_name, model_id = 'nell_v4', 'main_7'

In [2]:
# Define the file-name stems used when saving the model and the ID dictionaries.
model_name = '_'.join(['Model', model_id, data_name])
ids_name = '_'.join(['IDs', model_id, data_name])

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    """Load a knowledge-graph triple file into integer-id lookup structures."""

    def __init__(self):

        #placeholder attribute, not read anywhere in this class
        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """
        Read whitespace-separated (source, relation, target) triples from data_path
        and update all the given containers in place. Nothing is returned.

        Parameters
        ----------
        data_path : path of the text file, one triple per line. Blank lines are
            skipped; only the first three fields of a line are used.
        one_hop : dict, filled as one_hop[entity_id][relation_id] -> set of neighbour ids.
            Every triple is inserted in both directions (relation and its inverse).
        data : set, receives the (s, r, t) id triples of the forward direction only.
        s_t_r : dict, filled as s_t_r[(s, t)] -> set of relation ids linking s to t,
            again for both directions.
        entity2id, id2entity : entity name <-> id dictionaries, extended with unseen entities.
        relation2id, id2relation : relation name <-> id dictionaries, extended with
            unseen relations. A new relation gets the next free id and its inverse
            ('_inverse_' + name) the id right after it.

        Raises
        ------
        ValueError : if a non-blank line has fewer than three fields.

        Notes
        -----
        IDs are assigned in the order of first appearance in the file, so that the
        same file always produces the same IDs. (Iterating over a set of strings
        instead would give an order that can change from one Python process to the next.)
        """

        ####load the triples, keeping the file order and dropping duplicates####
        triples = dict()

        with open(data_path, 'r') as f:

            for line_no, line in enumerate(f, start=1):

                fields = line.split()

                #a blank line carries no triple
                if not fields:
                    continue

                if len(fields) < 3:
                    raise ValueError('line {} of {} is not a (source, relation, target) triple: {!r}'.format(
                        line_no, data_path, line))

                triples[tuple(fields[:3])] = None

        ####relation dict#################
        index = len(relation2id)

        for _, relation, _ in triples:

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation
                index += 1

                #the inverse relation always directly follows the initial relation
                iv_r = '_inverse_' + relation
                relation2id[iv_r] = index
                id2relation[index] = iv_r
                index += 1

        #get the id of the inverse relation. With the numbering above (and as long as
        #the dictionary held an even number of relations before), an initial relation
        #has an even id and its inverse has the next odd id.
        def inverse_r(r):

            return r + 1 if r % 2 == 0 else r - 1

        ####entity dict###################
        index = len(entity2id)

        for source, _, target in triples:

            for entity in (source, target):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        ####store the triples using ids instead of strings####
        for source, relation, target in triples:

            s = entity2id[source]
            r = relation2id[relation]
            t = entity2id[target]
            r_inv = inverse_r(r)

            data.add((s, r, t))

            #forward direction
            s_t_r.setdefault((s, t), set()).add(r)
            one_hop.setdefault(s, dict()).setdefault(r, set()).add(t)

            #backward direction through the inverse relation
            s_t_r.setdefault((t, s), set()).add(r_inv)
            one_hop.setdefault(t, dict()).setdefault(r_inv, set()).add(s)

In [5]:
class ObtainPathsByDynamicProgramming:
    """
    Randomised depth-first search for relation paths that start at a given entity.

    Despite the class name, the search is a plain recursion without memoisation.
    One may refer to LeetCode Problem 797 for the basic idea:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    """

    def __init__(self, size_bd=50, threshold=100000):

        #size_bd: no more paths of a given length are recorded once this many were found
        self.size_bd = size_bd

        #threshold: cap on the number of recursive expansions started from paths of one length
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        """
        Collect paths (tuples of relation ids) from entity s to a set of target entities.

        Parameters
        ----------
        mode : which entities count as targets
            'direct_neighbour' : every entity directly connected to s
            'target_specified' : only t_input
            'any_target'       : every entity that is a key of one_hop
        s : id of the source entity, must be a key of one_hop
        t_input : target entity id, only read in mode 'target_specified'
        lower_bd, upper_bd : minimal / maximal path length (int, >= 1, lower_bd <= upper_bd)
        one_hop : dict, one_hop[entity][relation] -> set of neighbour entities

        Returns
        -------
        (res, length_dict)
            res : defaultdict(set), target entity -> set of paths from s to that target
            length_dict : defaultdict(int), path length -> number of recorded paths.
                The same relation sequence leading to two different targets is counted twice.

        Raises
        ------
        TypeError : invalid lower_bd / upper_bd
        ValueError : s not in one_hop, or unknown mode

        Notes
        -----
        Paths never visit an entity twice. Neighbours are expanded in a random order,
        each one at most once per visited node. When neither size_bd nor threshold is
        reached the result is therefore the complete set of paths; otherwise it is a
        random subset.
        """

        if type(lower_bd) != type(1) or lower_bd < 1:

            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")

        if type(upper_bd) != type(1) or upper_bd < 1:

            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")

        if lower_bd > upper_bd:

            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:

            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')

        #direct_nb contains the entities that are accepted as the end of a path
        if mode == 'direct_neighbour':

            direct_nb = {t for targets in one_hop[s].values() for t in targets}

        elif mode == 'target_specified':

            direct_nb = {t_input}

        elif mode == 'any_target':

            direct_nb = set(one_hop)

        else:

            raise ValueError('not a valid mode')

        #result dict: key is a target entity, value is the set of paths from s to it
        res = defaultdict(set)

        #length_dict: path length -> number of paths recorded with that length
        #count_dict: path length -> number of expansions started from a path of that length
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)

        def helper(node, path, on_path_en):
            '''
            node is the current entity, path = [r1, ..., rk] the relations walked from s
            to node, on_path_en the entities visited so far (including s and node).
            '''
            depth = len(path)

            #record the path when its length is within the bounds, it ends on a target
            #and the quota of this length is not yet used up
            if (lower_bd <= depth <= upper_bd) and (node in direct_nb) and (
                length_dict[depth] < self.size_bd):

                if tuple(path) not in res[node]:

                    res[node].add(tuple(path))

                    length_dict[depth] += 1

            #stop when the path cannot grow, when the next length is already full,
            #or when too many expansions were started from this length. The last rule
            #bounds the search for entities whose surroundings contain hardly any target.
            if (depth >= upper_bd) or (length_dict[depth + 1] >= self.size_bd) or (
                count_dict[depth] > self.threshold):
                return

            #visit every relation and every neighbour once, in random order
            relations = list(one_hop[node])
            random.shuffle(relations)

            for r in relations:

                if count_dict[depth] > self.threshold:
                    break

                targets = list(one_hop[node][r])
                random.shuffle(targets)

                for t in targets:

                    if count_dict[depth] > self.threshold:
                        break

                    #only proceed to entities that are not yet on the path
                    if t not in on_path_en:

                        count_dict[depth] += 1

                        helper(t, path + [r], on_path_en | {t})

        helper(s, [], {s})

        return(res, length_dict)

In [6]:
# Triple file of the training graph of the selected dataset.
train_path = '/'.join(['../data', data_name, 'train.txt'])

In [7]:
# Instantiate the KG loader (Class_1) and the path finder (Class_2, default limits).
Class_1, Class_2 = LoadKG(), ObtainPathsByDynamicProgramming()

In [8]:
# Empty containers for the training graph; load_train_data fills all of them in place.
one_hop, data, s_t_r = dict(), set(), defaultdict(set)
entity2id, id2entity = dict(), dict()
relation2id, id2relation = dict(), dict()

# Read the training triples and build the ID dictionaries and adjacency structures.
Class_1.load_train_data(
    train_path,
    one_hop, data, s_t_r,
    entity2id, id2entity,
    relation2id, id2relation,
)

### Build the deep neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [9]:
# Input layers: each of the three paths is a sequence of integer relation ids.
# The sequence length is left open (None); build_big_batches pads every path to upper_bd.
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding); the batches feed one relation id per sample
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input.
# The table has one extra row (index len(relation2id)) for the "space holder" id
# that build_big_batches appends to paths shorter than upper_bd.
# The same embedding table is shared by the three paths.
# NOTE(review): no masking is configured, so the padded steps are presumably
# processed by the LSTMs like real relations - confirm this is intended.
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Separate 300-dimensional embedding table for the candidate (output) relation
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#2 stacked bi-directional LSTM layers, shared by the three paths.
#Both return the full sequence.
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#pool over the time axis. Note: despite the variable names this is a SUM
#(tf.reduce_sum), not a max.
fst_reduce_max = tf.reduce_sum(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_sum(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_sum(thd_lstm_out, axis=1)

#concatenate the pooled vectors of the three path towers: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#project to the output embedding size by a dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(path_concat)

#remove the time dimension from the output embd; with a single relation id per
#sample the sum over axis 1 just drops that axis
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Dot product of the two unit vectors, i.e. their cosine similarity
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model: inputs are the three paths and the candidate relation,
#output is the similarity score
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-01-24 23:18:26.087583: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
# Adam optimizer with a small learning rate and learning-rate decay.
# NOTE(review): newer Keras optimizer implementations reportedly reject the `decay`
# argument - confirm against the installed TensorFlow version.
opt = keras.optimizers.Adam(decay=1e-6, learning_rate=0.0005)

# Compile the model for binary classification of (paths, relation) samples.
# NOTE(review): the model output is a dot product of two L2-normalised vectors and
# can therefore be negative, while 'binary_crossentropy' with its default settings
# presumably expects probabilities in [0, 1] - confirm this combination is intended.
model.compile(optimizer=opt,
              loss='binary_crossentropy',
              metrics=['binary_accuracy'])

### Build the batches
We build each big-batch for each path combination with length (i,j). Then, we iteratively train the siamese network on different big-batches. The length of each big-batch is N.

To be specific:
* If we allow the length difference between two paths in a combination to be d, then the combination with path length i and path length j, denoted as (i,j), will be like (2,2), (2,3), (2,4), (3,3), (3,4), (3,5), ... 
* We will first build all the big-batches before fitting the NN model. 
* That is, we run the `obtain_paths` function of the ObtainPathsByDynamicProgramming class for some randomly chosen source entities. Then, for each target entity reached from a source, we repeatedly sample path combinations:
* in each draw, we randomly pick three distinct paths (path_1, path_2, path_3) from all the paths found between the source and this target.
* Do this until all the slots in all big-batches are filled.
* In every epoch, big-batches will be re-filled.

Then, in the training, we use negative sampling: in each batch (the actual batch, not the big-batch), we include K true output relation embeddings and K randomly selected output relation embeddings. A true relation is labelled 1, while a randomly selected (false) relation is labelled 0.

In [11]:
#function to build all the big batches
def build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity,
                      max_samples_per_source=1000, max_samples_per_target=50):
    """
    Fill x_p_list, x_r_list and y_list in place with training samples until
    y_list holds holder_len labels.

    Samples are always appended in pairs that share the same three paths:
    first a positive one (a relation that directly links source and target, label 1.),
    then a negative one (a relation that does not link them, label 0.).

    Parameters
    ----------
    holder_len : number of samples wanted, must be a multiple of 10
    lower_bd, upper_bd : path length bounds handed to Class_2.obtain_paths;
        every path is right-padded to upper_bd with the id len(id2relation)
    Class_2 : object with an obtain_paths(mode, s, t_input, lower_bd, upper_bd, one_hop)
        method that returns (target -> collection of paths, length statistics)
    one_hop : adjacency dict, its keys are the candidate source entities
    s_t_r : dict, (source, target) -> set of relation ids directly linking them
    x_p_list : dict with the keys '1', '2', '3', each a list receiving padded paths
    x_r_list : list receiving [relation_id]
    y_list : list receiving the labels
    id2relation : relation id -> name, the ids must be exactly 0 .. len(id2relation)-1
    relation2id, entity2id, id2entity : not used, kept so that existing calls still work
    max_samples_per_source : a source entity is abandoned once more than this many
        path triples were drawn for it (drawn, not necessarily appended)
    max_samples_per_target : same limit for a single (source, target) pair

    Raises
    ------
    ValueError : holder_len is not a multiple of 10, the ids of id2relation have gaps,
        one_hop is empty although samples are needed, or s_t_r holds an empty relation set

    Notes
    -----
    The function keeps drawing new source entities until enough samples exist. It does
    not terminate if the graph cannot provide any (source, target) pair with at least
    three paths.
    """

    if holder_len % 10 != 0:
        raise ValueError('We would like to take 10X as a big-batch size')

    #the relation ids must be exactly 0 .. num_r-1, because num_r is the padding id
    num_r = len(id2relation)

    for i in range(num_r):

        if i not in id2relation:
            raise ValueError('error when generaing id2relation')

    relation_id_set = set(range(num_r))

    #not all entities in entity2id have to be in one_hop,
    #only those that are can serve as a source
    existing_ids = list(one_hop)

    if len(y_list) < holder_len and not existing_ids:
        raise ValueError('one_hop is empty, there is no source entity to sample from')

    def pad(path):
        #append the space holder id at the end of a path shorter than upper_bd
        return list(path) + [num_r] * (upper_bd - len(path))

    #count how many samples were appended by this call
    count = 0

    while len(y_list) < holder_len:

        #obtain the paths around one random source entity
        source_id = random.choice(existing_ids)

        result, _ = Class_2.obtain_paths('direct_neighbour', source_id,
                                         'not_specified', lower_bd, upper_bd, one_hop)

        #We want to increase the diversity of paths and targets,
        #so the number of draws per source (count_0) and per target (count_1) is limited.
        count_0 = 0

        for target_id in result:

            if len(y_list) >= holder_len or count_0 > max_samples_per_source:
                break

            #s and t must be directly connected, otherwise there is no positive relation
            if (source_id, target_id) not in s_t_r:
                continue

            dir_r = list(s_t_r[(source_id, target_id)])

            #s and t must not be connected by all relations (no negative relation left),
            #and we need at least three different paths between them
            if len(dir_r) >= num_r or len(result[target_id]) < 3:
                continue

            if len(dir_r) <= 0:

                raise ValueError('errors when creating s_t_r !!')

            non_dir_r = list(relation_id_set.difference(dir_r))

            temp_path_list = list(result[target_id])

            count_1 = 0

            while (count_1 <= max_samples_per_target) and (
                count_0 <= max_samples_per_source) and (len(y_list) < holder_len):

                path_1, path_2, path_3 = random.sample(temp_path_list, 3)

                #the three paths have to differ from each other
                if (path_1 != path_2) and (path_2 != path_3) and (path_1 != path_3):

                    #positive sample first (random direct relation),
                    #then the negative one (random non-direct relation)
                    for relation_id, label in ((random.choice(dir_r), 1.),
                                               (random.choice(non_dir_r), 0.)):

                        for key, path in zip(('1', '2', '3'), (path_1, path_2, path_3)):
                            x_p_list[key].append(pad(path))

                        x_r_list.append([relation_id])
                        y_list.append(label)

                    count += 2

                    if count % 20000 == 0:
                        print('generating big-batches', count, holder_len)

                count_1 += 1
                count_0 += 1

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [12]:
# Display the file-name stem under which the model will be saved (sanity check).
model_name

'Model_main_7_nell_v4'

In [13]:
# Display the file-name stem under which the ID dictionaries will be saved (sanity check).
ids_name

'IDs_main_7_nell_v4'

In [14]:
#first, we save the graph structures and the ids of the training KG,
#so that the evaluation part can reload exactly the same relation / entity ids
Dict = {
    'one_hop': one_hop,
    'data': data,
    's_t_r': s_t_r,
    'entity2id': entity2id,
    'id2entity': id2entity,
    'relation2id': relation2id,
    'id2relation': id2relation,
}

#the output folder is otherwise only created right before the model is saved,
#so make sure it exists before the first file is written into it
os.makedirs('../weight_bin', exist_ok=True)

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [15]:
# Settings of the sample generation (holder_len, path length bounds) and of the fit.
holder_len = 1000000
lower_bd = 2
upper_bd = 8
num_epoch = 10
batch_size = 32

# The first 90% of the generated samples are used for training, the last 10% for validation.
train_len = 9 * (holder_len // 10)

# Containers that build_big_batches fills in place:
# the three padded paths, the candidate relation and the label of every sample.
x_p_list = {'1': [], '2': [], '3': []}
x_r_list = list()
y_list = list()

# Generate the samples.
build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                  x_p_list, x_r_list, y_list,
                  relation2id, entity2id, id2relation, id2entity)

# Training arrays, shuffled consistently across all inputs and the labels.
x_train_1, x_train_2, x_train_3 = (
    np.asarray(x_p_list[slot][:train_len], dtype='int') for slot in ('1', '2', '3'))
x_train_r = np.asarray(x_r_list[:train_len], dtype='int')
y_train = np.asarray(y_list[:train_len], dtype='int')

x_train_1, x_train_2, x_train_3, x_train_r, y_train = shuffle(
    x_train_1, x_train_2, x_train_3, x_train_r, y_train, random_state=None)

# Validation arrays, shuffled the same way.
x_valid_1, x_valid_2, x_valid_3 = (
    np.asarray(x_p_list[slot][train_len:], dtype='int') for slot in ('1', '2', '3'))
x_valid_r = np.asarray(x_r_list[train_len:], dtype='int')
y_valid = np.asarray(y_list[train_len:], dtype='int')

x_valid_1, x_valid_2, x_valid_3, x_valid_r, y_valid = shuffle(
    x_valid_1, x_valid_2, x_valid_3, x_valid_r, y_valid, random_state=None)

# Train the network.
model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
          validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save the model and its weights as an .h5 file in the weight folder.
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')
os.makedirs(save_dir, exist_ok=True)

model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')

# Free the memory held by the model and by the generated samples.
del model
del x_train_1, x_train_2, x_train_3, x_train_r, y_train
del x_valid_1, x_valid_2, x_valid_3, x_valid_r, y_valid
del x_p_list, x_r_list, y_list

generating big-batches 20000 1000000
generating big-batches 40000 1000000
generating big-batches 60000 1000000
generating big-batches 80000 1000000
generating big-batches 100000 1000000
generating big-batches 120000 1000000
generating big-batches 140000 1000000
generating big-batches 160000 1000000
generating big-batches 180000 1000000
generating big-batches 200000 1000000
generating big-batches 220000 1000000
generating big-batches 240000 1000000
generating big-batches 260000 1000000
generating big-batches 280000 1000000
generating big-batches 300000 1000000
generating big-batches 320000 1000000
generating big-batches 340000 1000000
generating big-batches 360000 1000000
generating big-batches 380000 1000000
generating big-batches 400000 1000000
generating big-batches 420000 1000000
generating big-batches 440000 1000000
generating big-batches 460000 1000000
generating big-batches 480000 1000000
generating big-batches 500000 1000000
generating big-batches 520000 1000000
generating big-b

NameError: name 'x_train_s' is not defined

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [50]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

In [51]:
#NOTE(review): this class is also defined in the training part of this notebook.
#Keep both definitions identical, or move the class into a shared module.
class LoadKG:
    """Load a knowledge-graph triple file into integer-id lookup structures."""

    def __init__(self):

        #placeholder attribute, not read anywhere in this class
        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """
        Read whitespace-separated (source, relation, target) triples from data_path
        and update all the given containers in place. Nothing is returned.

        Parameters
        ----------
        data_path : path of the text file, one triple per line. Blank lines are
            skipped; only the first three fields of a line are used.
        one_hop : dict, filled as one_hop[entity_id][relation_id] -> set of neighbour ids.
            Every triple is inserted in both directions (relation and its inverse).
        data : set, receives the (s, r, t) id triples of the forward direction only.
        s_t_r : dict, filled as s_t_r[(s, t)] -> set of relation ids linking s to t,
            again for both directions.
        entity2id, id2entity : entity name <-> id dictionaries, extended with unseen entities.
        relation2id, id2relation : relation name <-> id dictionaries, extended with
            unseen relations. A new relation gets the next free id and its inverse
            ('_inverse_' + name) the id right after it.

        Raises
        ------
        ValueError : if a non-blank line has fewer than three fields.

        Notes
        -----
        IDs are assigned in the order of first appearance in the file, so that the
        same file always produces the same IDs. (Iterating over a set of strings
        instead would give an order that can change from one Python process to the next.)
        """

        ####load the triples, keeping the file order and dropping duplicates####
        triples = dict()

        with open(data_path, 'r') as f:

            for line_no, line in enumerate(f, start=1):

                fields = line.split()

                #a blank line carries no triple
                if not fields:
                    continue

                if len(fields) < 3:
                    raise ValueError('line {} of {} is not a (source, relation, target) triple: {!r}'.format(
                        line_no, data_path, line))

                triples[tuple(fields[:3])] = None

        ####relation dict#################
        index = len(relation2id)

        for _, relation, _ in triples:

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation
                index += 1

                #the inverse relation always directly follows the initial relation
                iv_r = '_inverse_' + relation
                relation2id[iv_r] = index
                id2relation[index] = iv_r
                index += 1

        #get the id of the inverse relation. With the numbering above (and as long as
        #the dictionary held an even number of relations before), an initial relation
        #has an even id and its inverse has the next odd id.
        def inverse_r(r):

            return r + 1 if r % 2 == 0 else r - 1

        ####entity dict###################
        index = len(entity2id)

        for source, _, target in triples:

            for entity in (source, target):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        ####store the triples using ids instead of strings####
        for source, relation, target in triples:

            s = entity2id[source]
            r = relation2id[relation]
            t = entity2id[target]
            r_inv = inverse_r(r)

            data.add((s, r, t))

            #forward direction
            s_t_r.setdefault((s, t), set()).add(r)
            one_hop.setdefault(s, dict()).setdefault(r, set()).add(t)

            #backward direction through the inverse relation
            s_t_r.setdefault((t, s), set()).add(r_inv)
            one_hop.setdefault(t, dict()).setdefault(r_inv, set()).add(s)

In [52]:
#NOTE(review): this class is also defined in the training part of this notebook.
#Keep both definitions identical, or move the class into a shared module.
class ObtainPathsByDynamicProgramming:
    """
    Randomised depth-first search for relation paths that start at a given entity.

    Despite the class name, the search is a plain recursion without memoisation.
    One may refer to LeetCode Problem 797 for the basic idea:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    """

    def __init__(self, size_bd=50, threshold=100000):

        #size_bd: no more paths of a given length are recorded once this many were found
        self.size_bd = size_bd

        #threshold: cap on the number of recursive expansions started from paths of one length
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        """
        Collect paths (tuples of relation ids) from entity s to a set of target entities.

        Parameters
        ----------
        mode : which entities count as targets
            'direct_neighbour' : every entity directly connected to s
            'target_specified' : only t_input
            'any_target'       : every entity that is a key of one_hop
        s : id of the source entity, must be a key of one_hop
        t_input : target entity id, only read in mode 'target_specified'
        lower_bd, upper_bd : minimal / maximal path length (int, >= 1, lower_bd <= upper_bd)
        one_hop : dict, one_hop[entity][relation] -> set of neighbour entities

        Returns
        -------
        (res, length_dict)
            res : defaultdict(set), target entity -> set of paths from s to that target
            length_dict : defaultdict(int), path length -> number of recorded paths.
                The same relation sequence leading to two different targets is counted twice.

        Raises
        ------
        TypeError : invalid lower_bd / upper_bd
        ValueError : s not in one_hop, or unknown mode

        Notes
        -----
        Paths never visit an entity twice. Neighbours are expanded in a random order,
        each one at most once per visited node. When neither size_bd nor threshold is
        reached the result is therefore the complete set of paths; otherwise it is a
        random subset.
        """

        if type(lower_bd) != type(1) or lower_bd < 1:

            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")

        if type(upper_bd) != type(1) or upper_bd < 1:

            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")

        if lower_bd > upper_bd:

            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:

            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')

        #direct_nb contains the entities that are accepted as the end of a path
        if mode == 'direct_neighbour':

            direct_nb = {t for targets in one_hop[s].values() for t in targets}

        elif mode == 'target_specified':

            direct_nb = {t_input}

        elif mode == 'any_target':

            direct_nb = set(one_hop)

        else:

            raise ValueError('not a valid mode')

        #result dict: key is a target entity, value is the set of paths from s to it
        res = defaultdict(set)

        #length_dict: path length -> number of paths recorded with that length
        #count_dict: path length -> number of expansions started from a path of that length
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)

        def helper(node, path, on_path_en):
            '''
            node is the current entity, path = [r1, ..., rk] the relations walked from s
            to node, on_path_en the entities visited so far (including s and node).
            '''
            depth = len(path)

            #record the path when its length is within the bounds, it ends on a target
            #and the quota of this length is not yet used up
            if (lower_bd <= depth <= upper_bd) and (node in direct_nb) and (
                length_dict[depth] < self.size_bd):

                if tuple(path) not in res[node]:

                    res[node].add(tuple(path))

                    length_dict[depth] += 1

            #stop when the path cannot grow, when the next length is already full,
            #or when too many expansions were started from this length. The last rule
            #bounds the search for entities whose surroundings contain hardly any target.
            if (depth >= upper_bd) or (length_dict[depth + 1] >= self.size_bd) or (
                count_dict[depth] > self.threshold):
                return

            #visit every relation and every neighbour once, in random order
            relations = list(one_hop[node])
            random.shuffle(relations)

            for r in relations:

                if count_dict[depth] > self.threshold:
                    break

                targets = list(one_hop[node][r])
                random.shuffle(targets)

                for t in targets:

                    if count_dict[depth] > self.threshold:
                        break

                    #only proceed to entities that are not yet on the path
                    if t not in on_path_en:

                        count_dict[depth] += 1

                        helper(t, path + [r], on_path_en | {t})

        helper(s, [], {s})

        return(res, length_dict)

In [53]:
#instantiate the two helper classes:
#Class_1 provides load_train_data (reads a triple file into the given containers),
#Class_2 provides obtain_paths (used by relation_ranking to collect paths)
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [54]:
#read the pickled bundle holding the ids and the relation/entity dicts
with open('../weight_bin/{}.pickle'.format(ids_name), 'rb') as pickle_file:
    Dict = pickle.load(pickle_file)

#pull every stored object out of the bundle under its own name
(one_hop, data, s_t_r,
 entity2id, id2entity,
 relation2id, id2relation) = (Dict[key] for key in (
    'one_hop', 'data', 's_t_r',
    'entity2id', 'id2entity',
    'relation2id', 'id2relation'))

#independent copies of the dicts as they were stored,
#so that later in-place changes to the originals can be compared against them
entity2id_ini, id2entity_ini = deepcopy(entity2id), deepcopy(id2entity)
relation2id_ini, id2relation_ini = deepcopy(relation2id), deepcopy(id2relation)

#number of relation ids
num_r = len(id2relation)
num_r

152

In [55]:
#show the file stem of the pickled id dictionaries
ids_name

'IDs_main_7_nell_v4'

In [56]:
#show the file stem of the saved model
model_name

'Model_main_7_nell_v4'

In [57]:
#restore the saved keras model from the weight_bin folder
model = keras.models.load_model('../weight_bin/{}.h5'.format(model_name))

In [58]:
#train/valid/test triple files of the '<data_name>_ind' dataset folder
ind_train_path = '../data/{}_ind/train.txt'.format(data_name)
ind_valid_path = '../data/{}_ind/valid.txt'.format(data_name)
ind_test_path = '../data/{}_ind/test.txt'.format(data_name)

In [59]:
#containers for the triples read from ind_train_path
one_hop_ind, data_ind, s_t_r_ind = {}, set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#fill the containers; the shared entity/relation dicts are handed in as well
Class_1.load_train_data(
    ind_train_path,
    one_hop_ind, data_ind, s_t_r_ind,
    entity2id, id2entity, relation2id, id2relation,
)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dict must not have grown while loading
if not len_0 == len_1:
    raise ValueError('unseen relation!')

In [60]:
#entity-dict size before and after loading, and the number of triples in data_ind
print(size_0, size_1, len(data_ind))

2092 4886 7073


In [61]:
#containers for the triples read from ind_test_path
one_hop_test, data_test, s_t_r_test = {}, set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#fill the containers; the shared entity/relation dicts are handed in as well
Class_1.load_train_data(
    ind_test_path,
    one_hop_test, data_test, s_t_r_test,
    entity2id, id2entity, relation2id, id2relation,
)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dict must not have grown while loading
if not len_0 == len_1:
    raise ValueError('unseen relation!')

In [62]:
#entity-dict size before and after loading, and the number of triples in data_test
print(size_0, size_1, len(data_test))

4886 4886 731


In [63]:
#containers for the triples read from ind_valid_path;
#they are used to remove existing triples when ranking
one_hop_valid, data_valid, s_t_r_valid = {}, set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#fill the containers; the shared entity/relation dicts are handed in as well
Class_1.load_train_data(
    ind_valid_path,
    one_hop_valid, data_valid, s_t_r_valid,
    entity2id, id2entity, relation2id, id2relation,
)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dict must not have grown while loading
if not len_0 == len_1:
    raise ValueError('unseen relation!')

In [64]:
#entity-dict size before and after loading, and the number of triples in data_valid
print(size_0, size_1, len(data_valid))

4886 4886 716


In [65]:
#current size of entity2id next to the size of its initial copy entity2id_ini
print(len(entity2id), len(entity2id_ini))

4886 2092


In [66]:
#valid/test triple files of the '<data_name>' dataset folder
valid_path = '../data/{}/valid.txt'.format(data_name)
test_path = '../data/{}/test.txt'.format(data_name)

In [67]:
#containers for the triples read from test_path
one_hop_train_test, data_train_test, s_t_r_train_test = {}, set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#fill the containers; the shared entity/relation dicts are handed in as well
Class_1.load_train_data(
    test_path,
    one_hop_train_test, data_train_test, s_t_r_train_test,
    entity2id, id2entity, relation2id, id2relation,
)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dict must not have grown while loading
if not len_0 == len_1:
    raise ValueError('unseen relation!')

In [68]:
#containers for the triples read from valid_path;
#they are used to remove existing triples when ranking
one_hop_train_valid, data_train_valid, s_t_r_train_valid = {}, set(), defaultdict(set)

#dictionary sizes before loading
len_0, size_0 = len(relation2id), len(entity2id)

#fill the containers; the shared entity/relation dicts are handed in as well
Class_1.load_train_data(
    valid_path,
    one_hop_train_valid, data_train_valid, s_t_r_train_valid,
    entity2id, id2entity, relation2id, id2relation,
)

#dictionary sizes after loading
len_1, size_1 = len(relation2id), len(entity2id)

#the relation dict must not have grown while loading
if not len_0 == len_1:
    raise ValueError('unseen relation!')

In [69]:
#count the triples of data_test whose first or third element
#is a key of the initial dict id2entity_ini
overlapping = sum(
    1 for triple in data_test
    if triple[0] in id2entity_ini or triple[2] in id2entity_ini
)

overlapping

11

In [70]:
#count the triples of data_valid whose first or third element
#is a key of the initial dict id2entity_ini
overlapping = sum(
    1 for triple in data_valid
    if triple[0] in id2entity_ini or triple[2] in id2entity_ini
)

overlapping

17

In [71]:
#count the triples of data_ind whose first or third element
#is a key of the initial dict id2entity_ini
overlapping = sum(
    1 for triple in data_ind
    if triple[0] in id2entity_ini or triple[2] in id2entity_ini
)

overlapping

110

In [72]:
def relation_ranking(s, t, lower_bd, upper_bd, one_hop, id2relation, model,
                     num_rounds = 20, num_samples = 51, score_threshold = 0.8):
    """Score every relation id as a candidate relation between s and t.

    Paths from s to t are collected with Class_2.obtain_paths (a notebook-level
    object). Random groups of three distinct paths are then fed to the model
    together with each relation id, and the predictions are summed per relation.

    Args:
        s, t: source and target entity, passed on to Class_2.obtain_paths.
        lower_bd, upper_bd: bounds passed on to obtain_paths; upper_bd is also
            the width every path is padded to.
        one_hop: neighbourhood structure passed on to obtain_paths.
        id2relation: mapping whose keys must be exactly 0 .. len(id2relation)-1.
            len(id2relation) is used as the padding id (the value the notebook
            stores in num_r), so the function no longer reads that global.
        model: object with predict([paths_1, paths_2, paths_3, relations], verbose=0)
            returning one score per input row.
        num_rounds: how many times obtain_paths is called to collect paths.
        num_samples: how many random path groups are scored. The default 51
            reproduces the previous `while count <= 50` loop.
        score_threshold: only used for the diagnostic count that is printed.

    Returns:
        defaultdict(float) mapping relation id -> summed score. It is empty when
        fewer than three distinct paths were found.

    Raises:
        ValueError: if an id in 0 .. len(id2relation)-1 is missing from id2relation.
    """
    
    #collect the distinct paths found over several randomised searches
    path_holder = set()
    
    for _ in range(num_rounds):
    
        result, _ = Class_2.obtain_paths('target_specified', 
                                         s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            path_holder.update(result[t])
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    
    if len(path_holder) >= 3:
        
        num_rel = len(id2relation)
        
        #the row index of a prediction is used as relation id below,
        #so the ids have to be contiguous
        for i in range(num_rel):
            
            if i not in id2relation:
                
                raise ValueError ('error when generating id2relation')
        
        pad_id = num_rel
        
        def padded_batch(path):
            #one identical padded row per candidate relation
            row = list(path) + [pad_id]*abs(len(path)-upper_bd)
            return np.array([row]*num_rel)
        
        #the relation column is the same for every sampled group: build it once
        input_r = np.array([[i] for i in range(num_rel)])
        
        for _ in range(num_samples):
            
            path_1, path_2, path_3 = random.sample(path_holder, 3)
            
            pred = np.asarray(model.predict([padded_batch(path_1),
                                             padded_batch(path_2),
                                             padded_batch(path_3),
                                             input_r], verbose = 0))
            
            for i in range(pred.shape[0]):
                
                #take the single score of the row explicitly; calling float() on a
                #size-1 array is deprecated in recent NumPy versions
                score_dict[i] += float(np.ravel(pred[i])[0])
    
    #diagnostics: relations scored, relations at or above the threshold, paths found
    count_abv = sum(1 for value in score_dict.values() if value >= score_threshold)
    
    print(len(score_dict), count_abv, len(path_holder))

    return(score_dict)

In [74]:
########################################################
#filtered relation-prediction metrics on data_test:
#Hits@1, Hits@3, Hits@10 and MRR of the true relation among all relations

#evaluate on at most 500 randomly chosen triples
#(random.sample already returns them in random order)
selected = random.sample(list(data_test), min(len(data_test), 500))

#every triple set that counts as an existing fact when filtering the ranking
known_triple_sets = (data_test, data_valid, data_ind, data,
                     data_train_valid, data_train_test)

Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 8, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]; relations without a score get 0.0
    temp_list = [[score_dict.get(r, 0.0), r] for r in id2relation]
    
    #NOTE(review): sorted() is stable, so relations with equal scores keep the
    #iteration order of id2relation. When relation_ranking returns no scores at all,
    #the rank of the true relation is decided by that order alone - consider
    #excluding or randomising such cases.
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    ranked_relations = [pair[1] for pair in sorted_list]
    
    #fail with a clear message instead of an IndexError further down
    if r_true not in ranked_relations:
        
        raise ValueError('true relation is not in id2relation: ' + str(r_true))
    
    #p: unfiltered (absolute) position of the true relation
    p = ranked_relations.index(r_true)
    
    #exist_tri: better-ranked relations that form an already existing triple
    exist_tri = 0
    
    for r in ranked_relations[:p]:
        
        if any((s_true, r, t_true) in triples for triples in known_triple_sets):
            
            exist_tri += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    #running averages over the triples evaluated so far
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

152 45 1140
checkcorrect 124 124 real score 1.7071497396100312 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.045454545454545456 cur_rank 21 abs_cur_rank 21 total_num 0 500
0 0 1
checkcorrect 60 60 real score 0.0 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.03092399403874814 cur_rank 60 abs_cur_rank 60 total_num 1 500
152 4 818
checkcorrect 10 10 real score 50.99950700998306 Hits@1 0.3333333333333333 Hits@3 0.3333333333333333 Hits@10 0.3333333333333333 MRR 0.3539493293591654 cur_rank 0 abs_cur_rank 0 total_num 2 500
152 51 999
checkcorrect 124 124 real score 2.048174860421568 Hits@1 0.25 Hits@3 0.25 Hits@10 0.25 MRR 0.2750773816347587 cur_rank 25 abs_cur_rank 25 total_num 3 500
152 48 718
checkcorrect 124 124 real score 2.670605568215251 Hits@1 0.2 Hits@3 0.2 Hits@10 0.2 MRR 0.2318266111901599 cur_rank 16 abs_cur_rank 16 total_num 4 500
152 18 1641
checkcorrect 124 124 real score -1.498010371113196 Hits@1 0.16666666666666666 Hits@3 0.16666666666666666 Hits@10 0.16666666666666666 MRR 0.196661064880

152 68 1094
checkcorrect 94 94 real score 0.7811224139295518 Hits@1 0.09302325581395349 Hits@3 0.11627906976744186 Hits@10 0.20930232558139536 MRR 0.1454309907027077 cur_rank 68 abs_cur_rank 69 total_num 42 500
152 22 130
checkcorrect 104 104 real score -2.1691303779371083 Hits@1 0.09090909090909091 Hits@3 0.11363636363636363 Hits@10 0.20454545454545456 MRR 0.14268006463906524 cur_rank 40 abs_cur_rank 40 total_num 43 500
152 21 889
checkcorrect 124 124 real score -0.7596410135738552 Hits@1 0.08888888888888889 Hits@3 0.1111111111111111 Hits@10 0.2 MRR 0.13992868374771497 cur_rank 52 abs_cur_rank 52 total_num 44 500
152 29 2176
checkcorrect 124 124 real score -1.5018517160788178 Hits@1 0.08695652173913043 Hits@3 0.10869565217391304 Hits@10 0.1956521739130435 MRR 0.1372318214026128 cur_rank 62 abs_cur_rank 62 total_num 45 500
152 21 173
checkcorrect 30 30 real score -1.7473498624749482 Hits@1 0.0851063829787234 Hits@3 0.10638297872340426 Hits@10 0.19148936170212766 MRR 0.13502121527347918

152 43 1097
checkcorrect 66 66 real score 24.838398540392518 Hits@1 0.11904761904761904 Hits@3 0.16666666666666666 Hits@10 0.23809523809523808 MRR 0.17506574121741067 cur_rank 0 abs_cur_rank 0 total_num 83 500
152 69 1180
checkcorrect 124 124 real score 4.641451987903565 Hits@1 0.11764705882352941 Hits@3 0.16470588235294117 Hits@10 0.23529411764705882 MRR 0.17356636835154757 cur_rank 20 abs_cur_rank 20 total_num 84 500
152 65 746
checkcorrect 124 124 real score 3.355307932011783 Hits@1 0.11627906976744186 Hits@3 0.16279069767441862 Hits@10 0.23255813953488372 MRR 0.1721295501149017 cur_rank 19 abs_cur_rank 19 total_num 85 500
152 37 288
checkcorrect 4 4 real score 26.393575608730316 Hits@1 0.12643678160919541 Hits@3 0.1724137931034483 Hits@10 0.2413793103448276 MRR 0.18164530241243154 cur_rank 0 abs_cur_rank 0 total_num 86 500
152 23 984
checkcorrect 124 124 real score -0.6690998449921608 Hits@1 0.125 Hits@3 0.17045454545454544 Hits@10 0.23863636363636365 MRR 0.17982293074575256 cur_ra

152 63 1136
checkcorrect 124 124 real score 5.268470829818398 Hits@1 0.11290322580645161 Hits@3 0.1693548387096774 Hits@10 0.22580645161290322 MRR 0.17253211235052254 cur_rank 13 abs_cur_rank 13 total_num 123 500
152 61 1387
checkcorrect 124 124 real score 4.471054862020537 Hits@1 0.112 Hits@3 0.168 Hits@10 0.224 MRR 0.17157290808329734 cur_rank 18 abs_cur_rank 18 total_num 124 500
152 16 586
checkcorrect 94 94 real score -2.0000068659428507 Hits@1 0.1111111111111111 Hits@3 0.16666666666666666 Hits@10 0.2222222222222222 MRR 0.17032459702141176 cur_rank 69 abs_cur_rank 69 total_num 125 500
152 18 516
checkcorrect 18 18 real score 34.80863322969526 Hits@1 0.11023622047244094 Hits@3 0.16535433070866143 Hits@10 0.2283464566929134 MRR 0.17095196239919594 cur_rank 3 abs_cur_rank 3 total_num 126 500
152 19 133
checkcorrect 14 14 real score 37.37769392132759 Hits@1 0.109375 Hits@3 0.171875 Hits@10 0.234375 MRR 0.17222056685961887 cur_rank 2 abs_cur_rank 2 total_num 127 500
152 77 671
checkcorr

152 22 165
checkcorrect 14 14 real score 37.68952292203903 Hits@1 0.12121212121212122 Hits@3 0.18181818181818182 Hits@10 0.2606060606060606 MRR 0.18518914933749386 cur_rank 2 abs_cur_rank 2 total_num 164 500
152 39 1784
checkcorrect 114 114 real score -3.2176043624058366 Hits@1 0.12048192771084337 Hits@3 0.18072289156626506 Hits@10 0.25903614457831325 MRR 0.1841250400556235 cur_rank 116 abs_cur_rank 117 total_num 165 500
152 4 272
checkcorrect 10 10 real score 50.999507546424866 Hits@1 0.12574850299401197 Hits@3 0.18562874251497005 Hits@10 0.2634730538922156 MRR 0.1890105188576856 cur_rank 0 abs_cur_rank 0 total_num 166 500
152 19 336
checkcorrect 124 124 real score -2.547711888793856 Hits@1 0.125 Hits@3 0.18452380952380953 Hits@10 0.2619047619047619 MRR 0.1879608028397714 cur_rank 78 abs_cur_rank 78 total_num 167 500
152 20 1865
checkcorrect 124 124 real score 0.4384010194335133 Hits@1 0.1242603550295858 Hits@3 0.1834319526627219 Hits@10 0.2603550295857988 MRR 0.1870526487319662 cur_r

152 30 1027
checkcorrect 124 124 real score -0.46996454847976565 Hits@1 0.13170731707317074 Hits@3 0.2048780487804878 Hits@10 0.2731707317073171 MRR 0.19806688493014427 cur_rank 55 abs_cur_rank 55 total_num 204 500
152 19 310
checkcorrect 18 18 real score 37.04158188402653 Hits@1 0.13106796116504854 Hits@3 0.20388349514563106 Hits@10 0.2766990291262136 MRR 0.19807626898388145 cur_rank 4 abs_cur_rank 4 total_num 205 500
152 20 110
checkcorrect 18 18 real score 33.65720055997372 Hits@1 0.13043478260869565 Hits@3 0.20772946859903382 Hits@10 0.28019323671497587 MRR 0.19872968475368555 cur_rank 2 abs_cur_rank 2 total_num 206 500
152 40 1330
checkcorrect 124 124 real score 1.7261174423620105 Hits@1 0.12980769230769232 Hits@3 0.20673076923076922 Hits@10 0.27884615384615385 MRR 0.19799278504551662 cur_rank 21 abs_cur_rank 21 total_num 207 500
152 56 924
checkcorrect 124 124 real score 2.3443680992349982 Hits@1 0.1291866028708134 Hits@3 0.20574162679425836 Hits@10 0.27751196172248804 MRR 0.1972

152 19 2980
checkcorrect 124 124 real score -0.4119412188883871 Hits@1 0.1306122448979592 Hits@3 0.20408163265306123 Hits@10 0.27755102040816326 MRR 0.1970599599922403 cur_rank 33 abs_cur_rank 33 total_num 244 500
152 44 262
checkcorrect 24 24 real score 9.644036976969801 Hits@1 0.13008130081300814 Hits@3 0.2032520325203252 Hits@10 0.2764227642276423 MRR 0.19652990595433145 cur_rank 14 abs_cur_rank 14 total_num 245 500
152 54 1088
checkcorrect 124 124 real score 3.668862466234714 Hits@1 0.12955465587044535 Hits@3 0.20242914979757085 Hits@10 0.27530364372469635 MRR 0.1959591595964417 cur_rank 17 abs_cur_rank 17 total_num 246 500
152 23 1057
checkcorrect 124 124 real score 0.35503339720889926 Hits@1 0.12903225806451613 Hits@3 0.20161290322580644 Hits@10 0.27419354838709675 MRR 0.19529500975935926 cur_rank 31 abs_cur_rank 31 total_num 247 500
152 21 107
checkcorrect 148 148 real score -1.526664787903428 Hits@1 0.1285140562248996 Hits@3 0.20080321285140562 Hits@10 0.27309236947791166 MRR 0

152 40 164
checkcorrect 32 32 real score 49.771566569805145 Hits@1 0.12982456140350876 Hits@3 0.2 Hits@10 0.2736842105263158 MRR 0.19599659157121269 cur_rank 0 abs_cur_rank 0 total_num 284 500
152 59 1513
checkcorrect 124 124 real score 4.513111753854901 Hits@1 0.12937062937062938 Hits@3 0.1993006993006993 Hits@10 0.2727272727272727 MRR 0.1955169654797461 cur_rank 16 abs_cur_rank 16 total_num 285 500
0 0 0
checkcorrect 18 18 real score 0.0 Hits@1 0.1289198606271777 Hits@3 0.1986062717770035 Hits@10 0.27177700348432055 MRR 0.19501910699008623 cur_rank 18 abs_cur_rank 18 total_num 286 500
152 21 183
checkcorrect 90 90 real score -2.98527490766719 Hits@1 0.1284722222222222 Hits@3 0.19791666666666666 Hits@10 0.2708333333333333 MRR 0.19438825360933365 cur_rank 74 abs_cur_rank 74 total_num 287 500
152 65 869
checkcorrect 124 124 real score 3.9669655908364803 Hits@1 0.12802768166089964 Hits@3 0.1972318339100346 Hits@10 0.2698961937716263 MRR 0.19390786365067006 cur_rank 17 abs_cur_rank 17 tot

ValueError: Sample larger than population or is negative

In [73]:
########################################################
#filtered relation-prediction metrics on data_ind:
#Hits@1, Hits@3, Hits@10 and MRR of the true relation among all relations

#evaluate on at most 500 randomly chosen triples
#(random.sample already returns them in random order)
selected = random.sample(list(data_ind), min(len(data_ind), 500))

#every triple set that counts as an existing fact when filtering the ranking
known_triple_sets = (data_test, data_valid, data_ind, data,
                     data_train_valid, data_train_test)

Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 8, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]; relations without a score get 0.0
    temp_list = [[score_dict.get(r, 0.0), r] for r in id2relation]
    
    #NOTE(review): sorted() is stable, so relations with equal scores keep the
    #iteration order of id2relation. When relation_ranking returns no scores at all,
    #the rank of the true relation is decided by that order alone - consider
    #excluding or randomising such cases.
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    ranked_relations = [pair[1] for pair in sorted_list]
    
    #fail with a clear message instead of an IndexError further down
    if r_true not in ranked_relations:
        
        raise ValueError('true relation is not in id2relation: ' + str(r_true))
    
    #p: unfiltered (absolute) position of the true relation
    p = ranked_relations.index(r_true)
    
    #exist_tri: better-ranked relations that form an already existing triple
    exist_tri = 0
    
    for r in ranked_relations[:p]:
        
        if any((s_true, r, t_true) in triples for triples in known_triple_sets):
            
            exist_tri += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    #running averages over the triples evaluated so far
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

152 40 303
checkcorrect 24 24 real score 3.467067563906312 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.03333333333333333 cur_rank 29 abs_cur_rank 29 total_num 0 500
152 58 121
checkcorrect 32 32 real score 0.03531505726277828 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.023076923076923078 cur_rank 77 abs_cur_rank 77 total_num 1 500
152 46 369
checkcorrect 34 34 real score 32.589312763884664 Hits@1 0.0 Hits@3 0.0 Hits@10 0.3333333333333333 MRR 0.09871794871794871 cur_rank 3 abs_cur_rank 3 total_num 2 500
152 24 2121
checkcorrect 124 124 real score 0.014906604774296284 Hits@1 0.0 Hits@3 0.0 Hits@10 0.25 MRR 0.0807952182952183 cur_rank 36 abs_cur_rank 36 total_num 3 500
152 62 1161
checkcorrect 124 124 real score 4.658795545808971 Hits@1 0.0 Hits@3 0.0 Hits@10 0.2 MRR 0.07713617463617464 cur_rank 15 abs_cur_rank 15 total_num 4 500
0 0 0
checkcorrect 18 18 real score 0.0 Hits@1 0.0 Hits@3 0.0 Hits@10 0.16666666666666666 MRR 0.07305207535470694 cur_rank 18 abs_cur_rank 18 total_num 5 500
152 46 470


152 35 20
checkcorrect 0 0 real score 38.94167596101761 Hits@1 0.22727272727272727 Hits@3 0.22727272727272727 Hits@10 0.3181818181818182 MRR 0.2657478326883759 cur_rank 0 abs_cur_rank 0 total_num 43 500
0 0 0
checkcorrect 60 60 real score 0.0 Hits@1 0.2222222222222222 Hits@3 0.2222222222222222 Hits@10 0.3111111111111111 MRR 0.2602066240202553 cur_rank 60 abs_cur_rank 60 total_num 44 500
152 19 2495
checkcorrect 124 124 real score -0.3196194116026163 Hits@1 0.21739130434782608 Hits@3 0.21739130434782608 Hits@10 0.30434782608695654 MRR 0.2551073718815764 cur_rank 38 abs_cur_rank 38 total_num 45 500
152 51 960
checkcorrect 124 124 real score 2.676052486291155 Hits@1 0.2127659574468085 Hits@3 0.2127659574468085 Hits@10 0.2978723404255319 MRR 0.2506927266845013 cur_rank 20 abs_cur_rank 20 total_num 46 500
152 33 1006
checkcorrect 124 124 real score -0.7940260786563158 Hits@1 0.20833333333333334 Hits@3 0.20833333333333334 Hits@10 0.2916666666666667 MRR 0.24578090681887274 cur_rank 66 abs_cur

152 21 1137
checkcorrect 124 124 real score -0.4944718738552183 Hits@1 0.16279069767441862 Hits@3 0.20930232558139536 Hits@10 0.2558139534883721 MRR 0.21616067178206425 cur_rank 45 abs_cur_rank 45 total_num 85 500
152 47 764
checkcorrect 116 116 real score 14.68192882835865 Hits@1 0.16091954022988506 Hits@3 0.20689655172413793 Hits@10 0.26436781609195403 MRR 0.2149532055674556 cur_rank 8 abs_cur_rank 8 total_num 86 500
0 0 0
checkcorrect 60 60 real score 0.0 Hits@1 0.1590909090909091 Hits@3 0.20454545454545456 Hits@10 0.26136363636363635 MRR 0.21269684462490443 cur_rank 60 abs_cur_rank 60 total_num 87 500
152 24 84
checkcorrect 4 4 real score 27.22908551618457 Hits@1 0.15730337078651685 Hits@3 0.21348314606741572 Hits@10 0.2696629213483146 MRR 0.21405231079016765 cur_rank 2 abs_cur_rank 2 total_num 88 500
152 25 2439
checkcorrect 92 92 real score 0.5620250425999984 Hits@1 0.15555555555555556 Hits@3 0.2111111111111111 Hits@10 0.26666666666666666 MRR 0.2120323747204561 cur_rank 30 abs_cu

152 35 4
checkcorrect 60 60 real score 23.25424900650978 Hits@1 0.14960629921259844 Hits@3 0.2047244094488189 Hits@10 0.2755905511811024 MRR 0.20655310695032605 cur_rank 6 abs_cur_rank 6 total_num 126 500
152 26 349
checkcorrect 0 0 real score 50.85659921169281 Hits@1 0.15625 Hits@3 0.2109375 Hits@10 0.28125 MRR 0.21275191080227662 cur_rank 0 abs_cur_rank 0 total_num 127 500
0 0 0
checkcorrect 4 4 real score 0.0 Hits@1 0.15503875968992248 Hits@3 0.20930232558139536 Hits@10 0.2868217054263566 MRR 0.21265305878055354 cur_rank 4 abs_cur_rank 4 total_num 128 500
0 0 0
checkcorrect 12 12 real score 0.0 Hits@1 0.15384615384615385 Hits@3 0.2076923076923077 Hits@10 0.2846153846153846 MRR 0.21165829166172875 cur_rank 11 abs_cur_rank 12 total_num 129 500
152 67 797
checkcorrect 124 124 real score 4.4288312629796565 Hits@1 0.15267175572519084 Hits@3 0.20610687022900764 Hits@10 0.2824427480916031 MRR 0.21044434728986341 cur_rank 18 abs_cur_rank 18 total_num 130 500
152 42 1251
checkcorrect 114 114

0 0 0
checkcorrect 64 64 real score 0.0 Hits@1 0.13095238095238096 Hits@3 0.18452380952380953 Hits@10 0.27380952380952384 MRR 0.19187768153866286 cur_rank 64 abs_cur_rank 64 total_num 167 500
152 28 927
checkcorrect 124 124 real score 0.7664006177801639 Hits@1 0.1301775147928994 Hits@3 0.1834319526627219 Hits@10 0.27218934911242604 MRR 0.1909463506338228 cur_rank 28 abs_cur_rank 28 total_num 168 500
0 0 0
checkcorrect 24 24 real score 0.0 Hits@1 0.12941176470588237 Hits@3 0.18235294117647058 Hits@10 0.27058823529411763 MRR 0.19005843092421207 cur_rank 24 abs_cur_rank 24 total_num 169 500
152 4 352
checkcorrect 10 10 real score 50.99950647354126 Hits@1 0.13450292397660818 Hits@3 0.1871345029239766 Hits@10 0.27485380116959063 MRR 0.19479493132816403 cur_rank 0 abs_cur_rank 0 total_num 170 500
0 0 0
checkcorrect 120 120 real score 0.0 Hits@1 0.13372093023255813 Hits@3 0.18604651162790697 Hits@10 0.27325581395348836 MRR 0.19371045186003472 cur_rank 120 abs_cur_rank 120 total_num 171 500
15

152 38 523
checkcorrect 118 118 real score 21.713668674230576 Hits@1 0.1291866028708134 Hits@3 0.1722488038277512 Hits@10 0.2727272727272727 MRR 0.18862928463401446 cur_rank 7 abs_cur_rank 7 total_num 208 500
152 58 972
checkcorrect 124 124 real score 4.4494321728125215 Hits@1 0.12857142857142856 Hits@3 0.17142857142857143 Hits@10 0.2714285714285714 MRR 0.18801116199009898 cur_rank 16 abs_cur_rank 16 total_num 209 500
0 0 0
checkcorrect 24 24 real score 0.0 Hits@1 0.12796208530805686 Hits@3 0.17061611374407584 Hits@10 0.27014218009478674 MRR 0.1873096872887241 cur_rank 24 abs_cur_rank 24 total_num 210 500
0 0 0
checkcorrect 60 60 real score 0.0 Hits@1 0.12735849056603774 Hits@3 0.16981132075471697 Hits@10 0.2688679245283019 MRR 0.18650347858747046 cur_rank 60 abs_cur_rank 60 total_num 211 500
152 24 1341
checkcorrect 124 124 real score -0.17349485028535128 Hits@1 0.1267605633802817 Hits@3 0.16901408450704225 Hits@10 0.2676056338028169 MRR 0.18573705762674755 cur_rank 42 abs_cur_rank 42

152 30 1524
checkcorrect 60 60 real score 12.137865078635514 Hits@1 0.116 Hits@3 0.16 Hits@10 0.252 MRR 0.1749033377708479 cur_rank 10 abs_cur_rank 10 total_num 249 500
152 4 195
checkcorrect 10 10 real score 50.999507546424866 Hits@1 0.11952191235059761 Hits@3 0.16334661354581673 Hits@10 0.2549800796812749 MRR 0.17819057546897202 cur_rank 0 abs_cur_rank 0 total_num 250 500
152 18 307
checkcorrect 146 146 real score -2.036908849142492 Hits@1 0.11904761904761904 Hits@3 0.1626984126984127 Hits@10 0.25396825396825395 MRR 0.17754747410702398 cur_rank 61 abs_cur_rank 61 total_num 251 500
0 0 0
checkcorrect 10 10 real score 0.0 Hits@1 0.11857707509881422 Hits@3 0.16205533596837945 Hits@10 0.25296442687747034 MRR 0.1772050299046606 cur_rank 10 abs_cur_rank 10 total_num 252 500
152 33 183
checkcorrect 60 60 real score 25.16241727769375 Hits@1 0.11811023622047244 Hits@3 0.16141732283464566 Hits@10 0.2559055118110236 MRR 0.17716354028561337 cur_rank 5 abs_cur_rank 5 total_num 253 500
152 18 204


KeyboardInterrupt: 