In [None]:
data_name = 'fb237_v4'
model_id = 'main_3'

In [None]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [1]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [2]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [3]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [4]:
train_path = '../data/' + data_name + '/train.txt'

In [5]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [6]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

### Build the deep neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [12]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of two paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)

###########################################
####apply the attention mechanism##########
#first expand the dimention at the end to get ready for conv2D: (Batch,time,300,1)
fst_exp_dim = tf.expand_dims(fst_lstm_out, axis=-1)
scd_exp_dim = tf.expand_dims(scd_lstm_out, axis=-1)

#define the attention layer using convolutional 2D: output is
att_layer_conv2D = layers.Conv2D(100, (1, 300), padding='valid', activation='relu', 
                   input_shape=(None, 300), data_format='channels_last')

#shape: (Batch,Time,1,100)
fst_att_mid = att_layer_conv2D(fst_exp_dim)
scd_att_mid = att_layer_conv2D(scd_exp_dim)

#squeeze out the dim 1 to become: (Batch, Time, 100)
fst_squez = tf.squeeze(fst_att_mid, 2)
scd_squez = tf.squeeze(scd_att_mid, 2)

#expand the dimention again to become (Batch, Time, 100, 1)
fst_exp_dim_2 = tf.expand_dims(fst_squez, axis=-1)
scd_exp_dim_2 = tf.expand_dims(scd_squez, axis=-1)

#obtain the attention score for each time step by another conv2D layer
att_layer_conv2D_2 = layers.Conv2D(1, (1, 100), padding='valid', activation='relu', 
                     input_shape=(None, 100), data_format='channels_last')

#obtain (Batch, Time, 1, 1)
fst_mid_score = att_layer_conv2D_2(fst_exp_dim_2)
scd_mid_score = att_layer_conv2D_2(scd_exp_dim_2)

#squeeze again to obtain (Batch, Time, 1)
fst_squez_2 = tf.squeeze(fst_mid_score, -1)
scd_squez_2 = tf.squeeze(scd_mid_score, -1)

#softmax the attention score
softmax_l = layers.Softmax(1) #define softmax

fst_att_score = softmax_l(fst_squez_2)
scd_att_score = softmax_l(scd_squez_2)

#multiply the attention score to lstm output
fst_att_befsum = layers.Multiply()([fst_lstm_out, fst_att_score])
scd_att_befsum = layers.Multiply()([scd_lstm_out, scd_att_score])

#sum over time dimension to complete the attention: (Batch, 300)
fst_att_out = tf.reduce_sum(fst_att_befsum, axis=1)
scd_att_out = tf.reduce_sum(scd_att_befsum, axis=1)
######################################

#concatenate the output vector from both siamese tunnel: (Batch, 600)
path_concat = layers.concatenate([fst_att_out, scd_att_out], axis=-1)

#add dropout on top of the concatenation from both channel
dropout = layers.Dropout(0.25)(path_concat)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, id_rela], dot_product)

In [13]:
#config the Adam optimizer 
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

### Build the batches
We build each big-batch for each path combination with length (i,j). Then, we iteratively train the siamese network on different big-batches. The length of each big-batch is N.

To be specific:
* If we allow the length difference between two paths in a combination to be d, then the combination with path length i and path length j, denoted as (i,j), will be like (2,2), (2,3), (2,4), (3,3), (3,4), (3,5), ... 
* We will first build all the big-batches before fitting the NN model. 
* That is, we will perform the ObtainPathsByDynamicProgramming class function for some randomly chosen source entities. Then, for each target entity, we will further have two for loops:
* for path_1 in all the 
* Do this until all the slots in all big-batchs are filled.
* In every epoch, big-batchs will be re-filled.

Then, in the training, we will use negative sampling: In each batch (actual batch, not the big-batch), we will include K true output relation embeddings and K random selected output relation embeddings. The true label is [1,0], while the false label is [0,1].

In [14]:
#function to build all the big batches
def build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    if holder_len % 10 != 0:
        raise ValueError('We would like to take 10X as a big-batch size')
    
    #the set of all relation IDs
    relation_id_set = set()
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #count how many appending has performed
    count = 0

    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    
    carry_on = True
    
    while carry_on:

        #obtain paths by dynamic programming
        source_id = random.choice(existing_ids)

        result, length_dict = Class_2.obtain_paths('direct_neighbour', source_id, 
                                                   'not_specified', lower_bd, upper_bd, one_hop)
        
        #We want to increase the diversity of paths and targets.
        #So we abandon one sub-graph from a source_id, if we sampled more than K1 path pairs
        #Note that we mean "sampled", not "appended"! 
        #We do not care whether the pair is actually appended.
        threshold_0 = 1000
        count_0 = 0
        
        for target_id in result:

            if (not carry_on) or (count_0 > threshold_0):
                break
            
            #we want to make sure s, t are indeed directly connected, 
            #otherwise there is no relation for positive sample
            #also, we want to make sure s and t and not connected by all relations, 
            #although this situation is rare. 
            #But in that case, there is no relation for negative samples
            #Also, we want at least two different paths here between s and t
            if ((source_id, target_id) in s_t_r) and (
                len(s_t_r[(source_id, target_id)]) < len(id2relation)) and (
                len(result[target_id]) >= 2):
                
                dir_r = list(s_t_r[(source_id, target_id)])
                
                non_dir_r = list(relation_id_set.difference(dir_r))
                
                if len(dir_r) <= 0:
                    
                    raise ValueError('errors when creating s_t_r !!')
                    
                temp_path_list = list(result[target_id])
                    
                #futhermore, we will abandon one targed_id if we sampled more than K2 times
                threshold_1 = 50
                count_1 = 0
                
                while count_1 <= threshold_1 and count_0 <= threshold_0:
                
                    temp_pair = random.sample(temp_path_list,2)
                    
                    path_1, path_2 = temp_pair[0], temp_pair[1]

                    #decide which path is shorter and which is longer
                    if len(path_1) <= len(path_2):

                        path_s, path_l = path_1, path_2

                    else:

                        path_s, path_l = path_2, path_1                            

                    if (len(path_s) < lower_bd) or (len(path_l) > upper_bd):

                        raise ValueError('something wrong with the path finding')

                    #proceed when the entire length not yet reached,
                    #and whether this path pair is new, and whether the two paths are different
                    #But it is optional to require the path to be new. 
                    #We may remove this requirment, especially for short paths
                    '''remember to cancel the comment below when using path_comb'''
                    if (carry_on) and (path_s != path_l):

                        #we always add one positive and one negative situation together,
                        #hence, the length of list should always be even.
                        #also we want to make sure the length of lists coincide
                        if (len(x_p_list['s']) != len(y_list)) or (
                            len(x_p_list['s']) != len(x_p_list['l'])) or (
                            len(y_list) != len(x_r_list)) or (
                            len(y_list) % 2 != 0):

                            raise ValueError('error when building big batches: length error')

                        #####positive#####################
                        #we randomly choose one direction relation as the target relation
                        relation_id = random.choice(dir_r)

                        #append the paths: note that we add the space holder id at the end
                        #of the shorter path
                        x_p_list['s'].append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                        x_p_list['l'].append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))

                        #append relation
                        x_r_list.append([relation_id])
                        y_list.append(1.)

                        #####negative#####################
                        relation_id = random.choice(non_dir_r)

                        #append the paths: note that we add the space holder id at the end
                        #of the shorter path
                        x_p_list['s'].append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                        x_p_list['l'].append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))

                        #append relation
                        x_r_list.append([relation_id])
                        y_list.append(0.)

                        ######add to path combinations#####
                        #here is the tricky part: we have to add both (path_s, path_l)
                        #and (path_l, path_s). This is because when the length are the same
                        #adding only one situation won't guarantee that 
                        #the same path with different order is also considered.
                        #in other words: path combination don't have order, but our dict does.
                        #so we have to add both situations.
                        '''remember to cancel the comment here when using path_comb'''
                        #path_comb[(len(path_s), len(path_l))].add((path_s, path_l))
                        #path_comb[(len(path_s), len(path_l))].add((path_l, path_s))

                        count += 2

                        if count % 20000 == 0:
                            print('generating big-batches', count, holder_len)

                    if len(y_list) >= holder_len:

                        carry_on = False
                        
                    count_1 += 1
                    count_0 += 1

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [16]:
model_name

'Model_main_3_fb237_v4'

In [17]:
ids_name

'IDs_main_3_fb237_v4'

In [18]:
#first, we save the relation and ids
Dict = dict()
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
holder_len = 1000000
lower_bd = 2
upper_bd = 10
num_epoch = 10
batch_size = 4

#90% to be train, 10% to be validation
train_len = 9*int(holder_len/10)
    
######################################
###pre-define the lists###############

#define the lists
x_p_list, x_r_list, y_list = {'s': [], 'l': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################

#generate the input arrays
x_train_s = np.asarray(x_p_list['s'][:train_len], dtype='int')
x_train_l = np.asarray(x_p_list['l'][:train_len], dtype='int')
x_train_r = np.asarray(x_r_list[:train_len], dtype='int')
y_train = np.asarray(y_list[:train_len], dtype='int')

x_valid_s = np.asarray(x_p_list['s'][train_len:], dtype='int')
x_valid_l = np.asarray(x_p_list['l'][train_len:], dtype='int')
x_valid_r = np.asarray(x_r_list[train_len:], dtype='int')
y_valid = np.asarray(y_list[train_len:], dtype='int')

model.fit([x_train_s, x_train_l, x_train_r], y_train, 
          validation_data=([x_valid_s, x_valid_l, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

del(x_train_s, x_train_l, x_train_r, y_train)
del(x_valid_s, x_valid_l, x_valid_r, y_valid)

del(x_p_list, x_r_list, y_list)

generating big-batches 20000 1000000
generating big-batches 40000 1000000
generating big-batches 60000 1000000
generating big-batches 80000 1000000
generating big-batches 100000 1000000
generating big-batches 120000 1000000
generating big-batches 140000 1000000
generating big-batches 160000 1000000
generating big-batches 180000 1000000
generating big-batches 200000 1000000
generating big-batches 220000 1000000
generating big-batches 240000 1000000
generating big-batches 260000 1000000
generating big-batches 280000 1000000
generating big-batches 300000 1000000
generating big-batches 320000 1000000
generating big-batches 340000 1000000
generating big-batches 360000 1000000
generating big-batches 380000 1000000
generating big-batches 400000 1000000
generating big-batches 420000 1000000
generating big-batches 440000 1000000
generating big-batches 460000 1000000
generating big-batches 480000 1000000
generating big-batches 500000 1000000
generating big-batches 520000 1000000
generating big-b

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
data_name = 'fb237_v4'
model_id = 'main_3'

In [2]:
model_name = 'Model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [5]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [6]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [7]:
#load ids and relation/entity dicts
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)

one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)

In [8]:
#load the model
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-01-10 12:19:30.871008: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
ind_train_path = '../data/' + data_name + '_ind/train.txt'
ind_valid_path = '../data/' + data_name + '_ind/valid.txt'
ind_test_path = '../data/' + data_name + '_ind/test.txt'

In [10]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [11]:
print(size_0, size_1, len(data_ind))

4707 7758 11714


In [12]:
#load the test dataset
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [13]:
print(size_0, size_1, len(data_test))

7758 7758 1424


In [14]:
#load the validation for existing triple removal when ranking
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [15]:
print(size_0, size_1, len(data_valid))

7758 7758 1416


In [16]:
print(len(entity2id), len(entity2id_ini))

7758 4707


In [17]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [18]:
overlapping = 0

for ele in data_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [19]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [20]:
def relation_ranking(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(20):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    
    count = 0
    
    if len(path_holder) >= 2:
    
        #iterate over path_1
        while count <= 50:

            temp_pair = random.sample(path_holder, 2)

            path_1, path_2 = temp_pair[0], temp_pair[1]

            #decide which path is shorter and which is longer
            if len(path_1) <= len(path_2):

                path_s, path_l = path_1, path_2

            else:

                path_s, path_l = path_2, path_1                            

            #whether lengths of the two paths satisfies the requirments
            if (len(path_s) < lower_bd) or (len(path_l) > upper_bd):

                raise ValueError('something wrong with path finding')

            list_s = list()
            list_l = list()
            list_r = list()

            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')

                list_s.append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                list_l.append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))
                list_r.append([i])

            input_s = np.array(list_s)
            input_l = np.array(list_l)
            input_r = np.array(list_r)

            pred = model.predict([input_s, input_l, input_r], verbose = 0)

            for i in range(pred.shape[0]):

                score_dict[i] += float(pred[i])

            count += 1
                
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [21]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), min(len(data_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

438 1720
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 500
438 579
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 500
438 16
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 2 500
438 5119
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 3 500
438 371
Hits@1 0.8 Hits@3 1.0 Hits@10 1.0 MRR 0.9 cur_rank 1 abs_cur_rank 1 total_num 4 500
438 693
Hits@1 0.6666666666666666 Hits@3 1.0 Hits@10 1.0 MRR 0.8055555555555555 cur_rank 2 abs_cur_rank 2 total_num 5 500
438 2244
Hits@1 0.7142857142857143 Hits@3 1.0 Hits@10 1.0 MRR 0.8333333333333333 cur_rank 0 abs_cur_rank 0 total_num 6 500
438 930
Hits@1 0.625 Hits@3 0.875 Hits@10 1.0 MRR 0.7470238095238095 cur_rank 6 abs_cur_rank 6 total_num 7 500
438 103
Hits@1 0.6666666666666666 Hits@3 0.8888888888888888 Hits@10 1.0 MRR 0.7751322751322751 cur_rank 0 abs_cur_rank 0 total_num 8 500
438 396
Hits@1 0.7 Hits@3 0.9 Hits@10 1.0 MRR 

438 2507
Hits@1 0.4838709677419355 Hits@3 0.7741935483870968 Hits@10 0.8548387096774194 MRR 0.6279439086835549 cur_rank 1 abs_cur_rank 1 total_num 61 500
438 2418
Hits@1 0.49206349206349204 Hits@3 0.7777777777777778 Hits@10 0.8571428571428571 MRR 0.633849560926673 cur_rank 0 abs_cur_rank 0 total_num 62 500
438 800
Hits@1 0.484375 Hits@3 0.765625 Hits@10 0.84375 MRR 0.6253661160826484 cur_rank 10 abs_cur_rank 11 total_num 63 500
438 324
Hits@1 0.47692307692307695 Hits@3 0.7538461538461538 Hits@10 0.8307692307692308 MRR 0.6164776996447469 cur_rank 20 abs_cur_rank 20 total_num 64 500
438 1463
Hits@1 0.4696969696969697 Hits@3 0.7575757575757576 Hits@10 0.8333333333333334 MRR 0.6147128860137658 cur_rank 1 abs_cur_rank 1 total_num 65 500
438 298
Hits@1 0.47761194029850745 Hits@3 0.7611940298507462 Hits@10 0.835820895522388 MRR 0.6204634399538589 cur_rank 0 abs_cur_rank 0 total_num 66 500
438 1015
Hits@1 0.4852941176470588 Hits@3 0.7647058823529411 Hits@10 0.8382352941176471 MRR 0.62604485995

438 1826
Hits@1 0.46153846153846156 Hits@3 0.7777777777777778 Hits@10 0.8717948717948718 MRR 0.6246260397094217 cur_rank 0 abs_cur_rank 0 total_num 116 500
438 1164
Hits@1 0.4661016949152542 Hits@3 0.7796610169491526 Hits@10 0.8728813559322034 MRR 0.6278071749661214 cur_rank 0 abs_cur_rank 0 total_num 117 500
438 3931
Hits@1 0.47058823529411764 Hits@3 0.7815126050420168 Hits@10 0.8739495798319328 MRR 0.6309348457647255 cur_rank 0 abs_cur_rank 0 total_num 118 500
438 1332
Hits@1 0.475 Hits@3 0.7833333333333333 Hits@10 0.875 MRR 0.634010388716686 cur_rank 0 abs_cur_rank 0 total_num 119 500
438 1505
Hits@1 0.4793388429752066 Hits@3 0.7851239669421488 Hits@10 0.8760330578512396 MRR 0.6370350962479531 cur_rank 0 abs_cur_rank 0 total_num 120 500
438 1287
Hits@1 0.47540983606557374 Hits@3 0.7868852459016393 Hits@10 0.8770491803278688 MRR 0.6359118577541174 cur_rank 1 abs_cur_rank 1 total_num 121 500
438 967
Hits@1 0.4796747967479675 Hits@3 0.7886178861788617 Hits@10 0.8780487804878049 MRR 0.6

438 2318
Hits@1 0.47953216374269003 Hits@3 0.8011695906432749 Hits@10 0.9064327485380117 MRR 0.6488785425320844 cur_rank 0 abs_cur_rank 0 total_num 170 500
438 1777
Hits@1 0.47674418604651164 Hits@3 0.7965116279069767 Hits@10 0.9069767441860465 MRR 0.6460749851142623 cur_rank 5 abs_cur_rank 5 total_num 171 500
438 204
Hits@1 0.4797687861271676 Hits@3 0.7976878612716763 Hits@10 0.9075144508670521 MRR 0.648120794448862 cur_rank 0 abs_cur_rank 0 total_num 172 500
438 1544
Hits@1 0.47701149425287354 Hits@3 0.7988505747126436 Hits@10 0.9080459770114943 MRR 0.6472695255152477 cur_rank 1 abs_cur_rank 1 total_num 173 500
438 1287
Hits@1 0.48 Hits@3 0.8 Hits@10 0.9085714285714286 MRR 0.6492851282265892 cur_rank 0 abs_cur_rank 0 total_num 174 500
438 340
Hits@1 0.4772727272727273 Hits@3 0.8011363636363636 Hits@10 0.9090909090909091 MRR 0.6484369172707564 cur_rank 1 abs_cur_rank 1 total_num 175 500
438 1315
Hits@1 0.4745762711864407 Hits@3 0.8022598870056498 Hits@10 0.9096045197740112 MRR 0.64759

438 182
Hits@1 0.4666666666666667 Hits@3 0.8088888888888889 Hits@10 0.9111111111111111 MRR 0.6455457704876795 cur_rank 2 abs_cur_rank 2 total_num 224 500
438 1666
Hits@1 0.4646017699115044 Hits@3 0.8097345132743363 Hits@10 0.911504424778761 MRR 0.6449017626536632 cur_rank 1 abs_cur_rank 1 total_num 225 500
438 4336
Hits@1 0.46255506607929514 Hits@3 0.8105726872246696 Hits@10 0.9118942731277533 MRR 0.6435292145068776 cur_rank 2 abs_cur_rank 2 total_num 226 500
438 274
Hits@1 0.4605263157894737 Hits@3 0.8070175438596491 Hits@10 0.9122807017543859 MRR 0.6414377121040696 cur_rank 5 abs_cur_rank 5 total_num 227 500
438 5
Hits@1 0.4585152838427948 Hits@3 0.8034934497816594 Hits@10 0.9082969432314411 MRR 0.6386508520217232 cur_rank 307 abs_cur_rank 307 total_num 228 500
0 0
Hits@1 0.45652173913043476 Hits@3 0.8 Hits@10 0.9043478260869565 MRR 0.635911916370211 cur_rank 114 abs_cur_rank 114 total_num 229 500
438 1773
Hits@1 0.45454545454545453 Hits@3 0.7965367965367965 Hits@10 0.904761904761904

438 2050
Hits@1 0.43010752688172044 Hits@3 0.7921146953405018 Hits@10 0.899641577060932 MRR 0.6196729544698186 cur_rank 8 abs_cur_rank 8 total_num 278 500
438 3733
Hits@1 0.42857142857142855 Hits@3 0.7928571428571428 Hits@10 0.9 MRR 0.6192455510609979 cur_rank 1 abs_cur_rank 1 total_num 279 500
438 253
Hits@1 0.4306049822064057 Hits@3 0.7935943060498221 Hits@10 0.900355871886121 MRR 0.6206005490999267 cur_rank 0 abs_cur_rank 0 total_num 280 500
438 366
Hits@1 0.4326241134751773 Hits@3 0.7943262411347518 Hits@10 0.900709219858156 MRR 0.6219459372236859 cur_rank 0 abs_cur_rank 0 total_num 281 500
438 214
Hits@1 0.43462897526501765 Hits@3 0.7950530035335689 Hits@10 0.901060070671378 MRR 0.6232818173041675 cur_rank 0 abs_cur_rank 0 total_num 282 500
438 671
Hits@1 0.43661971830985913 Hits@3 0.795774647887324 Hits@10 0.9014084507042254 MRR 0.6246082897784486 cur_rank 0 abs_cur_rank 0 total_num 283 500
438 2302
Hits@1 0.43859649122807015 Hits@3 0.7964912280701755 Hits@10 0.9017543859649123 M

438 2428
Hits@1 0.44144144144144143 Hits@3 0.7837837837837838 Hits@10 0.8978978978978979 MRR 0.6230401752256681 cur_rank 0 abs_cur_rank 0 total_num 332 500
438 944
Hits@1 0.4431137724550898 Hits@3 0.7844311377245509 Hits@10 0.8982035928143712 MRR 0.6241687974555314 cur_rank 0 abs_cur_rank 0 total_num 333 500
438 1065
Hits@1 0.44477611940298506 Hits@3 0.7850746268656716 Hits@10 0.8985074626865671 MRR 0.6252906816422313 cur_rank 0 abs_cur_rank 0 total_num 334 500
438 1256
Hits@1 0.44642857142857145 Hits@3 0.7857142857142857 Hits@10 0.8988095238095238 MRR 0.6264058879468676 cur_rank 0 abs_cur_rank 0 total_num 335 500
438 1416
Hits@1 0.44510385756676557 Hits@3 0.7833827893175074 Hits@10 0.8961424332344213 MRR 0.6247753751545121 cur_rank 12 abs_cur_rank 13 total_num 336 500
438 5499
Hits@1 0.4467455621301775 Hits@3 0.7840236686390533 Hits@10 0.8964497041420119 MRR 0.6258855071806821 cur_rank 0 abs_cur_rank 0 total_num 337 500
438 441
Hits@1 0.44542772861356933 Hits@3 0.7817109144542773 Hits

438 1269
Hits@1 0.45595854922279794 Hits@3 0.7849740932642487 Hits@10 0.8886010362694301 MRR 0.628641939326295 cur_rank 0 abs_cur_rank 0 total_num 385 500
438 1496
Hits@1 0.45478036175710596 Hits@3 0.7855297157622739 Hits@10 0.8888888888888888 MRR 0.6278788679929798 cur_rank 2 abs_cur_rank 2 total_num 386 500
438 5319
Hits@1 0.4536082474226804 Hits@3 0.7835051546391752 Hits@10 0.8891752577319587 MRR 0.6269049533847505 cur_rank 3 abs_cur_rank 3 total_num 387 500
438 335
Hits@1 0.455012853470437 Hits@3 0.7840616966580977 Hits@10 0.8894601542416453 MRR 0.6278640666151238 cur_rank 0 abs_cur_rank 0 total_num 388 500
438 714
Hits@1 0.45384615384615384 Hits@3 0.7846153846153846 Hits@10 0.8897435897435897 MRR 0.627108859606709 cur_rank 2 abs_cur_rank 2 total_num 389 500
438 2339
Hits@1 0.45268542199488493 Hits@3 0.782608695652174 Hits@10 0.8900255754475703 MRR 0.6260165095821394 cur_rank 4 abs_cur_rank 5 total_num 390 500
438 135
Hits@1 0.45153061224489793 Hits@3 0.7831632653061225 Hits@10 0.8

438 1275
Hits@1 0.4636363636363636 Hits@3 0.7863636363636364 Hits@10 0.8931818181818182 MRR 0.6321377739709999 cur_rank 0 abs_cur_rank 0 total_num 439 500
438 1871
Hits@1 0.46485260770975056 Hits@3 0.7868480725623582 Hits@10 0.8934240362811792 MRR 0.6329719286785487 cur_rank 0 abs_cur_rank 0 total_num 440 500
438 2800
Hits@1 0.4660633484162896 Hits@3 0.7873303167420814 Hits@10 0.8936651583710408 MRR 0.6338023089304071 cur_rank 0 abs_cur_rank 0 total_num 441 500
438 2595
Hits@1 0.4650112866817156 Hits@3 0.7878103837471784 Hits@10 0.8939051918735892 MRR 0.6331240493918132 cur_rank 2 abs_cur_rank 2 total_num 442 500
438 2475
Hits@1 0.46396396396396394 Hits@3 0.7882882882882883 Hits@10 0.8941441441441441 MRR 0.6324488450763662 cur_rank 2 abs_cur_rank 2 total_num 443 500
438 2643
Hits@1 0.4651685393258427 Hits@3 0.7887640449438202 Hits@10 0.8943820224719101 MRR 0.63327480272788 cur_rank 0 abs_cur_rank 0 total_num 444 500
438 463
Hits@1 0.4663677130044843 Hits@3 0.7892376681614349 Hits@10 0.

438 1929
Hits@1 0.4716599190283401 Hits@3 0.7935222672064778 Hits@10 0.8947368421052632 MRR 0.6378889088616351 cur_rank 11 abs_cur_rank 11 total_num 493 500
438 4642
Hits@1 0.4707070707070707 Hits@3 0.793939393939394 Hits@10 0.8949494949494949 MRR 0.6376103454093894 cur_rank 1 abs_cur_rank 1 total_num 494 500
438 610
Hits@1 0.4717741935483871 Hits@3 0.7943548387096774 Hits@10 0.8951612903225806 MRR 0.6383409697129995 cur_rank 0 abs_cur_rank 0 total_num 495 500
438 2551
Hits@1 0.47283702213279677 Hits@3 0.7947686116700201 Hits@10 0.8953722334004024 MRR 0.6390686538785669 cur_rank 0 abs_cur_rank 0 total_num 496 500
438 2582
Hits@1 0.4738955823293173 Hits@3 0.7951807228915663 Hits@10 0.8955823293172691 MRR 0.6397934156177666 cur_rank 0 abs_cur_rank 0 total_num 497 500
438 438
Hits@1 0.4749498997995992 Hits@3 0.7955911823647295 Hits@10 0.8957915831663327 MRR 0.640515272500296 cur_rank 0 abs_cur_rank 0 total_num 498 500
438 1915
Hits@1 0.474 Hits@3 0.796 Hits@10 0.896 MRR 0.6399009086219621

In [21]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), min(len(data_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if sorted_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - inverse_r - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - inverse_r - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - inverse_r - exist_tri + 1.) 
        
    print('Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - inverse_r - exist_tri, 
          'total_num', i, len(selected))

438 1112
Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.5 cur_rank 1 total_num 0 500
438 943
Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.41666666666666663 cur_rank 2 total_num 1 500
438 3148
Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.4444444444444444 cur_rank 1 total_num 2 500
438 4729
Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.41666666666666663 cur_rank 2 total_num 3 500
438 319
Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.4333333333333333 cur_rank 1 total_num 4 500
438 1274
Hits@1 0.16666666666666666 Hits@3 1.0 Hits@10 1.0 MRR 0.5277777777777778 cur_rank 0 total_num 5 500
438 2636
Hits@1 0.14285714285714285 Hits@3 1.0 Hits@10 1.0 MRR 0.5238095238095238 cur_rank 1 total_num 6 500
438 3157
Hits@1 0.125 Hits@3 1.0 Hits@10 1.0 MRR 0.5 cur_rank 2 total_num 7 500
438 287
Hits@1 0.1111111111111111 Hits@3 0.8888888888888888 Hits@10 1.0 MRR 0.4666666666666667 cur_rank 4 total_num 8 500
438 846
Hits@1 0.2 Hits@3 0.9 Hits@10 1.0 MRR 0.52 cur_rank 0 total_num 9 500
438 381
Hits@1 0.2727272727272727 Hits@3 0.909090909

438 1051
Hits@1 0.3880597014925373 Hits@3 0.8208955223880597 Hits@10 0.8805970149253731 MRR 0.6071595043931521 cur_rank 0 total_num 66 500
438 1068
Hits@1 0.39705882352941174 Hits@3 0.8235294117647058 Hits@10 0.8823529411764706 MRR 0.6129365705050176 cur_rank 0 total_num 67 500
438 1760
Hits@1 0.4057971014492754 Hits@3 0.8260869565217391 Hits@10 0.8840579710144928 MRR 0.6185461854252347 cur_rank 0 total_num 68 500
438 583
Hits@1 0.4 Hits@3 0.8142857142857143 Hits@10 0.8714285714285714 MRR 0.6097948453613368 cur_rank 167 total_num 69 500
438 541
Hits@1 0.4084507042253521 Hits@3 0.8169014084507042 Hits@10 0.8732394366197183 MRR 0.6152906926097687 cur_rank 0 total_num 70 500
0 0
Hits@1 0.4027777777777778 Hits@3 0.8055555555555556 Hits@10 0.8611111111111112 MRR 0.6069975137982693 cur_rank 54 total_num 71 500
438 721
Hits@1 0.3972602739726027 Hits@3 0.8082191780821918 Hits@10 0.863013698630137 MRR 0.6055317944311697 cur_rank 1 total_num 72 500
438 1749
Hits@1 0.40540540540540543 Hits@3 0.81

438 4200
Hits@1 0.453125 Hits@3 0.796875 Hits@10 0.8828125 MRR 0.6355563910247383 cur_rank 0 total_num 127 500
438 2585
Hits@1 0.4573643410852713 Hits@3 0.7984496124031008 Hits@10 0.8837209302325582 MRR 0.6383815352803605 cur_rank 0 total_num 128 500
438 575
Hits@1 0.45384615384615384 Hits@3 0.8 Hits@10 0.8846153846153846 MRR 0.63731706193205 cur_rank 1 total_num 129 500
438 2512
Hits@1 0.4580152671755725 Hits@3 0.8015267175572519 Hits@10 0.8854961832061069 MRR 0.6400856339783703 cur_rank 0 total_num 130 500
438 1164
Hits@1 0.45454545454545453 Hits@3 0.7954545454545454 Hits@10 0.8863636363636364 MRR 0.6364991266502513 cur_rank 5 total_num 131 500
438 4245
Hits@1 0.45864661654135336 Hits@3 0.7969924812030075 Hits@10 0.8872180451127819 MRR 0.6392322159235576 cur_rank 0 total_num 132 500
438 2937
Hits@1 0.4626865671641791 Hits@3 0.7985074626865671 Hits@10 0.8880597014925373 MRR 0.6419245128196506 cur_rank 0 total_num 133 500
438 482
Hits@1 0.45925925925925926 Hits@3 0.7925925925925926 Hit

438 858
Hits@1 0.4574468085106383 Hits@3 0.7978723404255319 Hits@10 0.8882978723404256 MRR 0.6383844221335447 cur_rank 0 total_num 187 500
438 1555
Hits@1 0.4603174603174603 Hits@3 0.798941798941799 Hits@10 0.8888888888888888 MRR 0.6402977320693461 cur_rank 0 total_num 188 500
438 2512
Hits@1 0.4631578947368421 Hits@3 0.8 Hits@10 0.8894736842105263 MRR 0.64219090190056 cur_rank 0 total_num 189 500
438 256
Hits@1 0.4607329842931937 Hits@3 0.8010471204188482 Hits@10 0.8900523560209425 MRR 0.6414464469167875 cur_rank 1 total_num 190 500
438 12
Hits@1 0.4583333333333333 Hits@3 0.796875 Hits@10 0.890625 MRR 0.6391472466724293 cur_rank 4 total_num 191 500
438 348
Hits@1 0.46113989637305697 Hits@3 0.7979274611398963 Hits@10 0.8911917098445595 MRR 0.6410169500575462 cur_rank 0 total_num 192 500
438 609
Hits@1 0.4587628865979381 Hits@3 0.7938144329896907 Hits@10 0.8865979381443299 MRR 0.6380809274872937 cur_rank 13 total_num 193 500
438 1587
Hits@1 0.46153846153846156 Hits@3 0.7948717948717948 

438 298
Hits@1 0.45564516129032256 Hits@3 0.8145161290322581 Hits@10 0.9032258064516129 MRR 0.6465738959482309 cur_rank 1 total_num 247 500
438 515
Hits@1 0.4578313253012048 Hits@3 0.8152610441767069 Hits@10 0.9036144578313253 MRR 0.647993277892214 cur_rank 0 total_num 248 500
438 2618
Hits@1 0.46 Hits@3 0.816 Hits@10 0.904 MRR 0.6494013047806451 cur_rank 0 total_num 249 500
438 3630
Hits@1 0.46215139442231074 Hits@3 0.8167330677290837 Hits@10 0.9043824701195219 MRR 0.6507981123313198 cur_rank 0 total_num 250 500
438 367
Hits@1 0.4603174603174603 Hits@3 0.8174603174603174 Hits@10 0.9047619047619048 MRR 0.6501997071236558 cur_rank 1 total_num 251 500
438 944
Hits@1 0.4624505928853755 Hits@3 0.8181818181818182 Hits@10 0.9051383399209486 MRR 0.651582316976922 cur_rank 0 total_num 252 500
438 1970
Hits@1 0.4645669291338583 Hits@3 0.8188976377952756 Hits@10 0.905511811023622 MRR 0.6529540401384302 cur_rank 0 total_num 253 500
438 113
Hits@1 0.4627450980392157 Hits@3 0.8196078431372549 Hits@

438 1022
Hits@1 0.4383116883116883 Hits@3 0.8116883116883117 Hits@10 0.9058441558441559 MRR 0.6375362199282065 cur_rank 3 total_num 307 500
438 2470
Hits@1 0.4401294498381877 Hits@3 0.8122977346278317 Hits@10 0.9061488673139159 MRR 0.6387092418701864 cur_rank 0 total_num 308 500
438 1683
Hits@1 0.43870967741935485 Hits@3 0.8096774193548387 Hits@10 0.9064516129032258 MRR 0.63745534108996 cur_rank 3 total_num 309 500
438 227
Hits@1 0.43729903536977494 Hits@3 0.8102893890675241 Hits@10 0.9067524115755627 MRR 0.6370133625012463 cur_rank 1 total_num 310 500
438 731
Hits@1 0.4391025641025641 Hits@3 0.8108974358974359 Hits@10 0.907051282051282 MRR 0.6381767812111782 cur_rank 0 total_num 311 500
438 786
Hits@1 0.43769968051118213 Hits@3 0.8083067092651757 Hits@10 0.9073482428115016 MRR 0.6366703591199816 cur_rank 5 total_num 312 500
438 1622
Hits@1 0.43630573248407645 Hits@3 0.8089171974522293 Hits@10 0.9076433121019108 MRR 0.6362351031992174 cur_rank 1 total_num 313 500
438 453
Hits@1 0.43492

438 155
Hits@1 0.4375 Hits@3 0.8070652173913043 Hits@10 0.9157608695652174 MRR 0.6396931792239805 cur_rank 0 total_num 367 500
438 1680
Hits@1 0.43902439024390244 Hits@3 0.8075880758807588 Hits@10 0.9159891598915989 MRR 0.640669620472696 cur_rank 0 total_num 368 500
438 1656
Hits@1 0.43783783783783786 Hits@3 0.8081081081081081 Hits@10 0.9162162162162162 MRR 0.6402894323092563 cur_rank 1 total_num 369 500
438 3147
Hits@1 0.4366576819407008 Hits@3 0.8086253369272237 Hits@10 0.9164420485175202 MRR 0.639911293677695 cur_rank 1 total_num 370 500
438 863
Hits@1 0.4381720430107527 Hits@3 0.8091397849462365 Hits@10 0.9166666666666666 MRR 0.6408792740710345 cur_rank 0 total_num 371 500
438 2801
Hits@1 0.43967828418230565 Hits@3 0.8096514745308311 Hits@10 0.9168900804289544 MRR 0.6418420642209781 cur_rank 0 total_num 372 500
438 3346
Hits@1 0.4385026737967914 Hits@3 0.8074866310160428 Hits@10 0.9144385026737968 MRR 0.6401394815415957 cur_rank 196 total_num 373 500
438 1292
Hits@1 0.4373333333333

438 7
Hits@1 0.4392523364485981 Hits@3 0.8037383177570093 Hits@10 0.9182242990654206 MRR 0.6405132132571166 cur_rank 29 total_num 427 500
438 790
Hits@1 0.4405594405594406 Hits@3 0.8041958041958042 Hits@10 0.9184149184149184 MRR 0.6413511777949789 cur_rank 0 total_num 428 500
438 40
Hits@1 0.43953488372093025 Hits@3 0.8023255813953488 Hits@10 0.9162790697674419 MRR 0.6400534618776261 cur_rank 11 total_num 429 500
438 6
Hits@1 0.4385150812064965 Hits@3 0.8004640371229699 Hits@10 0.91415313225058 MRR 0.6387230980836333 cur_rank 14 total_num 430 500
438 3219
Hits@1 0.4375 Hits@3 0.8009259259259259 Hits@10 0.9143518518518519 MRR 0.6384019798010322 cur_rank 1 total_num 431 500
0 0
Hits@1 0.43648960739030024 Hits@3 0.7990762124711316 Hits@10 0.9122401847575058 MRR 0.63704916132331 cur_rank 18 total_num 432 500
438 1388
Hits@1 0.4377880184331797 Hits@3 0.7995391705069125 Hits@10 0.9124423963133641 MRR 0.6378854535783255 cur_rank 0 total_num 433 500
0 0
Hits@1 0.4367816091954023 Hits@3 0.79770

438 453
Hits@1 0.44672131147540983 Hits@3 0.8012295081967213 Hits@10 0.9139344262295082 MRR 0.6424467732267992 cur_rank 1 total_num 487 500
438 379
Hits@1 0.4458077709611452 Hits@3 0.8016359918200409 Hits@10 0.9141104294478528 MRR 0.6421554710320614 cur_rank 1 total_num 488 500
438 981
Hits@1 0.44693877551020406 Hits@3 0.8020408163265306 Hits@10 0.9142857142857143 MRR 0.6428857659891388 cur_rank 0 total_num 489 500
438 1187
Hits@1 0.4460285132382892 Hits@3 0.8004073319755601 Hits@10 0.9144602851323829 MRR 0.6418673777546542 cur_rank 6 total_num 490 500
438 826
Hits@1 0.4451219512195122 Hits@3 0.8008130081300813 Hits@10 0.9146341463414634 MRR 0.6415790294258845 cur_rank 1 total_num 491 500
438 834
Hits@1 0.44421906693711966 Hits@3 0.7991886409736308 Hits@10 0.9127789046653144 MRR 0.6402891770704199 cur_rank 175 total_num 492 500
438 388
Hits@1 0.4433198380566802 Hits@3 0.7995951417004049 Hits@10 0.9129554655870445 MRR 0.6400051908820182 cur_rank 1 total_num 493 500
438 85
Hits@1 0.44242

In [18]:
'''Ranking on transductive setting'''

########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 66
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 0 672
438 1171
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 1 672
438 1044
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 2 672
438 600
Calculating_hit_at_one 0.25 current ranking: 0 triple number: 3 672
438 4356
Calculating_hit_at_one 0.2 current ranking: 1 triple number: 4 672
438 377
Calculating_hit_at_one 0.3333333333333333 current ranking: 0 triple number: 5 672
438 1864
Calculating_hit_at_one 0.42857142857142855 current ranking: 0 triple number: 6 672
438 509
Calculating_hit_at_one 0.375 current ranking: 2 triple number: 7 672
438 1675
Calculating_hit_at_one 0.3333333333333333 current ranking: 27 triple number: 8 672
438 625
Calculating_hit_at_one 0.3 current ranking: 1 triple number: 9 672
438 2672
Calculating_hit_at_one 0.36363636363636365 current ranking: 0 triple number: 10 672
438 1404
Calculating_hit_at_one 0.3333333333333333 current ranking: 2 triple number: 11 672
438 553
Cal

KeyboardInterrupt: 

In [144]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 263
Calculating_hit_at_one 1.0 current ranking: 0 0 284
438 3646
Calculating_hit_at_one 0.5 current ranking: 1 1 284
438 1588
Calculating_hit_at_one 0.3333333333333333 current ranking: 10 2 284
438 707
Calculating_hit_at_one 0.25 current ranking: 4 3 284
438 3451
Calculating_hit_at_one 0.4 current ranking: 0 4 284
438 2757
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 5 284
438 559
Calculating_hit_at_one 0.2857142857142857 current ranking: 1 6 284
438 2868
Calculating_hit_at_one 0.375 current ranking: 0 7 284
438 3601
Calculating_hit_at_one 0.3333333333333333 current ranking: 1 8 284
438 1470
Calculating_hit_at_one 0.4 current ranking: 0 9 284
438 37
Calculating_hit_at_one 0.36363636363636365 current ranking: 1 10 284
438 2329
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 11 284
438 1230
Calculating_hit_at_one 0.38461538461538464 current ranking: 0 12 284
438 929
Calculating_hit_at_one 0.35714285714285715 current ranking: 1 13 284
438 691
Calculating_h

438 660
Calculating_hit_at_one 0.3063063063063063 current ranking: 1 110 284
438 2095
Calculating_hit_at_one 0.30357142857142855 current ranking: 2 111 284
438 2160
Calculating_hit_at_one 0.30973451327433627 current ranking: 0 112 284
438 1207
Calculating_hit_at_one 0.3157894736842105 current ranking: 0 113 284
438 1671
Calculating_hit_at_one 0.3130434782608696 current ranking: 1 114 284
438 5113
Calculating_hit_at_one 0.3103448275862069 current ranking: 1 115 284
438 4248
Calculating_hit_at_one 0.3162393162393162 current ranking: 0 116 284
438 2935
Calculating_hit_at_one 0.3135593220338983 current ranking: 96 117 284
438 36
Calculating_hit_at_one 0.31092436974789917 current ranking: 1 118 284
438 429
Calculating_hit_at_one 0.30833333333333335 current ranking: 1 119 284
438 4473
Calculating_hit_at_one 0.30578512396694213 current ranking: 3 120 284


KeyboardInterrupt: 