### Train the inductive link prediction model

In [1]:
data_name = 'nell_v3'
model_id = 'main_inductive'

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle
from sys import getsizeof

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [5]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):
        
        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion
                        
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 20 recursions if there are too many (r,t)
                for i in temp_list[:self.amount_bd]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [6]:
train_path = '../data/' + data_name + '/train.txt'
valid_path = '../data/' + data_name + '/valid.txt'
test_path = '../data/' + data_name + '/test.txt'

In [7]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [8]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)

#define the dictionaries, which is shared by initail and inductive train/valid/test
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

In [9]:
#define the dictionaries and sets for load KG
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(valid_path, one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

In [10]:
#define the dictionaries and sets for load KG
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(test_path, one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)

#### Build the path-based siamese neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [11]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#reduce max
fst_reduce_max = tf.reduce_max(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_max(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_max(thd_lstm_out, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#add dropout on top of the concatenation from all channels
dropout = layers.Dropout(0.25)(path_concat)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-03-24 20:33:55.697851: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
#config the Adam optimizer 
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

#### Build the subgraph-based siamese neural network

In [13]:
#each input is an vector with number of relations to be dim:
#each dim represent the existence (1) or not (0) of an out-going relation from the entity
source_path_1 = keras.Input(shape=(None,), dtype="int32")
source_path_2 = keras.Input(shape=(None,), dtype="int32")
source_path_3 = keras.Input(shape=(None,), dtype="int32")

target_path_1 = keras.Input(shape=(None,), dtype="int32")
target_path_2 = keras.Input(shape=(None,), dtype="int32")
target_path_3 = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela_ = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var_ = layers.Embedding(len(relation2id)+1, 300)

# Obtain the source embeddings
source_embd_1 = in_embd_var_(source_path_1)
source_embd_2 = in_embd_var_(source_path_2)
source_embd_3 = in_embd_var_(source_path_3)

#Obtain the target embeddings
target_embd_1 = in_embd_var_(target_path_1)
target_embd_2 = in_embd_var_(target_path_2)
target_embd_3 = in_embd_var_(target_path_3)

# Embed each integer in a 300-dimensional vector as output
rela_embd_ = layers.Embedding(len(relation2id)+1, 300)(id_rela_)

#add 2 layer bi-directional LSTM network
lstm_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

###source lstm implimentation########
#first LSTM layer
source_mid_1 = lstm_1(source_embd_1)
source_mid_2 = lstm_1(source_embd_2)
source_mid_3 = lstm_1(source_embd_3)

#second LSTM layer
source_out_1 = lstm_2(source_mid_1)
source_out_2 = lstm_2(source_mid_2)
source_out_3 = lstm_2(source_mid_3)

#reduce max
source_max_1 = tf.reduce_max(source_out_1, axis=1)
source_max_2 = tf.reduce_max(source_out_2, axis=1)
source_max_3 = tf.reduce_max(source_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
source_concat = layers.concatenate([source_max_1, source_max_2, source_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
source_dropout = layers.Dropout(0.25)(source_concat)

###target lstm implimentation########
#first LSTM layer
target_mid_1 = lstm_1(target_embd_1)
target_mid_2 = lstm_1(target_embd_2)
target_mid_3 = lstm_1(target_embd_3)

#second LSTM layer
target_out_1 = lstm_2(target_mid_1)
target_out_2 = lstm_2(target_mid_2)
target_out_3 = lstm_2(target_mid_3)

#reduce max
target_max_1 = tf.reduce_max(target_out_1, axis=1)
target_max_2 = tf.reduce_max(target_out_2, axis=1)
target_max_3 = tf.reduce_max(target_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
target_concat = layers.concatenate([target_max_1, target_max_2, target_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
target_dropout = layers.Dropout(0.25)(target_concat)

#further concatenate source and target output embeddings: (Batch, 1800)
final_concat = layers.concatenate([source_dropout, target_dropout], axis=-1)

#multiply into output embd size by dense layer: (Batch, 300)
out_vect = layers.Dense(300, activation='tanh')(final_concat)

#remove the time dimension from the output embd since there is only one step
rela_out_embd_ = tf.reduce_sum(rela_embd_, axis=1)

# Normalize the vectors to have unit length
out_vect_norm = tf.math.l2_normalize(out_vect, axis=-1)
rela_out_embd_norm_ = tf.math.l2_normalize(rela_out_embd_, axis=-1)

# Calculate the dot product
dot_product_ = layers.Dot(axes=-1)([out_vect_norm, rela_out_embd_norm_])

#put together the model
model_2 = keras.Model([source_path_1, source_path_2, source_path_3,
                       target_path_1, target_path_2, target_path_3, id_rela_], dot_product_)

In [14]:
#config the Adam optimizer 
opt_ = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model_2.compile(loss='binary_crossentropy', optimizer=opt_, metrics=['binary_accuracy'])

### Build the big-batch for path-based model
We will build the big-batch for the path-based model training. That is, we will build three list to store three paths, respectively.

In order to reduce computational complexity, we will run the path-finding algorithm for each entity e in the dataset before the training. That is, for each entity e, we will have two dictionaries. Dict 1 stores the paths between e and any other entities in the dataset. Will Dict 2 stores the paths between e and its direct neighbors. The two dicts will be used and invariant throughout the training.

* At each step, three different paths between two entities s and t are selected. Each path is append to one of the list. 
* If this step is for positive samples, the existing relation r will be selected between s and t. If there are more than one relation from s to t, we randomly choose one. Also, the label list will be appended 1.
* If this step is for negative samples, one relation that does not exist between s and t will be selected randomly and append to the relation list. Also, the label list will be appended 0.
* In practice, the positive step is always fallowed by a negative step. The same paths in the positive step will be used in the next negative step, while the relation is a negative one chosen in the above way.
* We do this until the length limit is reached.

**For relation prediciton, we will only need to train using (s,r,t) triple. (t,r-1,s) is not necessary and hence not included in training.**

In [15]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        for iteration in range(10):

            #proceed only if at least three paths are between s and t
            for t in result:

                if len(s_t_r[(s,t)]) == 0:

                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #we are only interested in forward link in relation prediciton
                ini_r_list = list()

                #obtain initial relations between s and t
                for r in s_t_r[(s,t)]:
                    if r % 2 == 0:#initial relation id is always an even number
                        ini_r_list.append(r)

                #if there exist more than three paths between s and t, 
                #and inital connection between s and t exists,
                #and not every r in the relation dictionary exists between s and t (although this is rare)
                #we then proceed
                if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):

                    #obtain the list form of all the paths from s to t
                    temp_path_list = list(result[t])

                    temp_pair = random.sample(temp_path_list, 3)

                    path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                    #####positive#####################
                    #append the paths: note that we add the space holder id at the end of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    r = random.choice(ini_r_list)
                    x_r_list.append([r])
                    y_list.append(1.)

                    #####negative#####################
                    #append the paths: note that we add the space holder id at the end
                    #of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                    r_ran = random.choice(neg_r_list)
                    x_r_list.append([r_ran])
                    y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

### Build the big-batch for the subgraph-based network training

Again, to reduce computational complexity, we store the subgraph of each entity e at the biginning.

* At each step, we will select one triple (s,r,t) from the dataset. Then, reaching out paths of s and t is generated respectively according to their out-going relations.
* We will select three paths for each of source and target entity. Add them to the corresponding list.
* If this is a positive sample step, the id of relation r is appended to the relation list.
* If this is a negative sample step, the id of a random relation is appended to the relation lsit.
* Similarly, one negative sample step always follows one positive step. The one-hop vectors from the previous positve step is used again for the negative step.

In [16]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
    
    #the ids to start path finding
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    #Dict stores the subgraph for each entity
    Dict_1 = dict()
    
    count = 0
    for s in existing_ids:
        
        path_set = set()
            
        result, length_dict = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        for t_ in result:
            for path in result[t_]:
                path_set.add(path)

        del(result, length_dict)
        
        path_list = list(path_set)
        
        path_select = random.sample(path_list, min(len(path_list), 100))
            
        Dict_1[s] = deepcopy(path_select)
        
        count += 1
        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, len(existing_ids))
        
    return(Dict_1)

In [17]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    data = list(data)
    
    for iteration in range(10):

        data = shuffle(data)

        for i_0 in range(len(data)):

            triple = data[i_0]

            s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs

            path_s, path_t = Dict[s], Dict[t] #sets holding all the paths from s or t

            #see if both path_s and path_t have at least three paths
            if len(path_s) >= 3 and len(path_t) >= 3:

                #change to lists
                path_s, path_t = list(path_s), list(path_t)

                #randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #####positive step###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative step###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference({r}))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)

            if i_0 % 1000 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [18]:
model_name

'Model_main_inductive_nell_v3'

In [19]:
one_hop_model_name

'One_hop_model_main_inductive_nell_v3'

In [20]:
ids_name

'IDs_main_inductive_nell_v3'

In [21]:
#first, we save the relation and ids
Dict = dict()

#save training data
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r

#save valid data
Dict['one_hop_valid'] = one_hop_valid
Dict['data_valid'] = data_valid
Dict['s_t_r_valid'] = s_t_r_valid

#save test data
Dict['one_hop_test'] = one_hop_test
Dict['data_test'] = data_test
Dict['s_t_r_test'] = s_t_r_test

#save shared dictionaries
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [22]:
###train the path-based model
lower_bd = 1
upper_bd = 10
num_epoch = 10
batch_size = 32
        
#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_p_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_path(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_p_list, valid_r_list, valid_y_list,
                      relation2id, entity2id, id2relation, id2entity)    

#######################################
###do the training#####################
#sometimes the validation dataset is so small so sparse, 
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    x_train_1 = np.asarray(train_p_list['1'], dtype='int')
    x_train_2 = np.asarray(train_p_list['2'], dtype='int')
    x_train_3 = np.asarray(train_p_list['3'], dtype='int')
    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #generate the validation arrays
    x_valid_1 = np.asarray(valid_p_list['1'], dtype='int')
    x_valid_2 = np.asarray(valid_p_list['2'], dtype='int')
    x_valid_3 = np.asarray(valid_p_list['3'], dtype='int')
    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)
    #generate the input arrays
    x_train_1 = np.asarray(train_p_list['1'][:split], dtype='int')
    x_train_2 = np.asarray(train_p_list['2'][:split], dtype='int')
    x_train_3 = np.asarray(train_p_list['3'][:split], dtype='int')
    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #generate the validation arrays
    x_valid_1 = np.asarray(train_p_list['1'][split:], dtype='int')
    x_valid_2 = np.asarray(train_p_list['2'][split:], dtype='int')
    x_valid_3 = np.asarray(train_p_list['3'][split:], dtype='int')
    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training
model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train, 
          validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)   

# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

generating big-batches for path-based model 100 4647
generating big-batches for path-based model 200 4647
generating big-batches for path-based model 300 4647
generating big-batches for path-based model 400 4647
generating big-batches for path-based model 500 4647
generating big-batches for path-based model 600 4647
generating big-batches for path-based model 700 4647
generating big-batches for path-based model 800 4647
generating big-batches for path-based model 900 4647
generating big-batches for path-based model 1000 4647
generating big-batches for path-based model 1100 4647
generating big-batches for path-based model 1200 4647
generating big-batches for path-based model 1300 4647
generating big-batches for path-based model 1400 4647
generating big-batches for path-based model 1500 4647
generating big-batches for path-based model 1600 4647
generating big-batches for path-based model 1700 4647
generating big-batches for path-based model 1800 4647
generating big-batches for path-based

In [23]:
###train the subgraph-based model
lower_bd = 1
upper_bd = 3
num_epoch = 10
batch_size = 32

Dict_train = store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity)

Dict_valid = store_subgraph_dicts(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                         relation2id, entity2id, id2relation, id2entity)
        
#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_s_list, valid_t_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_subgraph(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_s_list, valid_t_list, valid_r_list, valid_y_list, Dict_valid,
                      relation2id, entity2id, id2relation, id2entity)    

#######################################
###do the training#####################
#sometimes the validation dataset is so small so sparse, 
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
    x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
    x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

    x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
    x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
    x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #generate the validation arrays
    x_valid_s_1 = np.asarray(valid_s_list['1'], dtype='int')
    x_valid_s_2 = np.asarray(valid_s_list['2'], dtype='int')
    x_valid_s_3 = np.asarray(valid_s_list['3'], dtype='int')

    x_valid_t_1 = np.asarray(valid_t_list['1'], dtype='int')
    x_valid_t_2 = np.asarray(valid_t_list['2'], dtype='int')
    x_valid_t_3 = np.asarray(valid_t_list['3'], dtype='int')

    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)
    #generate the input arrays
    x_train_s_1 = np.asarray(train_s_list['1'][:split], dtype='int')
    x_train_s_2 = np.asarray(train_s_list['2'][:split], dtype='int')
    x_train_s_3 = np.asarray(train_s_list['3'][:split], dtype='int')

    x_train_t_1 = np.asarray(train_t_list['1'][:split], dtype='int')
    x_train_t_2 = np.asarray(train_t_list['2'][:split], dtype='int')
    x_train_t_3 = np.asarray(train_t_list['3'][:split], dtype='int')

    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #generate the validation arrays
    x_valid_s_1 = np.asarray(train_s_list['1'][split:], dtype='int')
    x_valid_s_2 = np.asarray(train_s_list['2'][split:], dtype='int')
    x_valid_s_3 = np.asarray(train_s_list['3'][split:], dtype='int')

    x_valid_t_1 = np.asarray(train_t_list['1'][split:], dtype='int')
    x_valid_t_2 = np.asarray(train_t_list['2'][split:], dtype='int')
    x_valid_t_3 = np.asarray(train_t_list['3'][split:], dtype='int')

    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training
model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train, 
          validation_data=([x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
one_hop_add_h5 = one_hop_model_name + '.h5'
one_hop_save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(one_hop_save_dir):
    os.makedirs(one_hop_save_dir)
one_hop_model_path = os.path.join(one_hop_save_dir, one_hop_add_h5)
model_2.save(one_hop_model_path)
print('Save model')
del(model_2, Dict_train, Dict_valid)

generating and storing paths for the path-based model 100 4647
generating and storing paths for the path-based model 200 4647
generating and storing paths for the path-based model 300 4647
generating and storing paths for the path-based model 400 4647
generating and storing paths for the path-based model 500 4647
generating and storing paths for the path-based model 600 4647
generating and storing paths for the path-based model 700 4647
generating and storing paths for the path-based model 800 4647
generating and storing paths for the path-based model 900 4647
generating and storing paths for the path-based model 1000 4647
generating and storing paths for the path-based model 1100 4647
generating and storing paths for the path-based model 1200 4647
generating and storing paths for the path-based model 1300 4647
generating and storing paths for the path-based model 1400 4647
generating and storing paths for the path-based model 1500 4647
generating and storing paths for the path-based m

generating big-batches for subgraph-based model 6000 16393 5
generating big-batches for subgraph-based model 7000 16393 5
generating big-batches for subgraph-based model 8000 16393 5
generating big-batches for subgraph-based model 9000 16393 5
generating big-batches for subgraph-based model 10000 16393 5
generating big-batches for subgraph-based model 11000 16393 5
generating big-batches for subgraph-based model 12000 16393 5
generating big-batches for subgraph-based model 13000 16393 5
generating big-batches for subgraph-based model 14000 16393 5
generating big-batches for subgraph-based model 15000 16393 5
generating big-batches for subgraph-based model 16000 16393 5
generating big-batches for subgraph-based model 0 16393 6
generating big-batches for subgraph-based model 1000 16393 6
generating big-batches for subgraph-based model 2000 16393 6
generating big-batches for subgraph-based model 3000 16393 6
generating big-batches for subgraph-based model 4000 16393 6
generating big-batch

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
data_name = 'fb237_v4'
model_id = 'main_inductive'

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
ids_name

'IDs_main_inductive_fb237_v4'

In [4]:
one_hop_model_name

'One_hop_model_main_inductive_fb237_v4'

In [5]:
model_name

'Model_main_inductive_fb237_v4'

In [6]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [7]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [8]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):
        
        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion
                        
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 20 recursions if there are too many (r,t)
                for i in temp_list[:self.amount_bd]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [9]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [10]:
#load ids and relation/entity dicts
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)
    
#save training data
one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']

#save valid data
one_hop_valid = Dict['one_hop_valid']
data_valid = Dict['data_valid']
s_t_r_valid = Dict['s_t_r_valid']

#save test data
one_hop_test = Dict['one_hop_test']
data_test = Dict['data_test']
s_t_r_test = Dict['s_t_r_test']

#save shared dictionaries
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts before adding new entities
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)
num_r

438

In [11]:
ids_name

'IDs_main_inductive_fb237_v4'

In [12]:
model_name

'Model_main_inductive_fb237_v4'

In [13]:
#load the model
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-03-27 10:21:31.919630: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
#load the one-hop neighbor model
model_2 = keras.models.load_model('../weight_bin/' + one_hop_model_name + '.h5')

In [15]:
ind_train_path = '../data/' + data_name + '_ind/train.txt'
ind_valid_path = '../data/' + data_name + '_ind/valid.txt'
ind_test_path = '../data/' + data_name + '_ind/test.txt'

In [16]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [17]:
print(size_0, size_1, len(data_ind))

4707 7758 11714


In [18]:
#load the test dataset
one_hop_ind_test = dict() 
data_ind_test = set()
s_t_r_ind_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_ind_test, data_ind_test, s_t_r_ind_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [19]:
print(size_0, size_1, len(data_ind_test))

7758 7758 1424


In [20]:
#load the validation for existing triple removal when ranking
one_hop_ind_valid = dict() 
data_ind_valid = set()
s_t_r_ind_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_ind_valid, data_ind_valid, s_t_r_ind_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [21]:
print(size_0, size_1, len(data_ind_valid))

7758 7758 1416


In [22]:
print(len(entity2id), len(entity2id_ini))

7758 4707


In [23]:
#obtain all the inital entities and new entities
ini_ent_set, new_ent_set, all_ent_set = set(), set(), set()

for ID in id2entity:
    all_ent_set.add(ID)
    if ID in id2entity_ini:
        ini_ent_set.add(ID)
    else:
        new_ent_set.add(ID)
        
print(len(ini_ent_set), len(new_ent_set), len(all_ent_set))

4707 3051 7758


In [24]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [25]:
overlapping = 0

for ele in data_ind_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [26]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [27]:
#the function to do path-based relation scoring
def path_based_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(3):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
                
        del(result, length_dict)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)
    
    count = 0
    
    if len(path_holder) >= 3:
    
        #iterate over path_1
        while count < 10:

            temp_pair = random.sample(path_holder, 3)

            path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

            list_1 = list()
            list_2 = list()
            list_3 = list()
            list_r = list()

            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                
                #only care about initial relations
                if i % 2 == 0:

                    list_1.append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    list_2.append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    list_3.append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))
                    list_r.append([i])
            
            #change to arrays
            input_1 = np.array(list_1)
            input_2 = np.array(list_2)
            input_3 = np.array(list_3)
            input_r = np.array(list_r)

            pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])
                count_dict[2*i] += 1

            count += 1
            
    #average the score
    for r in score_dict:
        score_dict[r] = deepcopy(score_dict[r]/float(count_dict[r]))
    
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [28]:
#the function to do path-based triple scoring: input one triple
def path_based_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(3):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
                
        del(result, length_dict)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score = 0.
    count = 0
    
    if len(path_holder) >= 3:
    
        #iterate over path_1
        while count < 10:

            temp_pair = random.sample(path_holder, 3)
            path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

            list_1 = list()
            list_2 = list()
            list_3 = list()
            list_r = list()

            list_1.append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
            list_2.append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
            list_3.append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))
            list_r.append([r])
            
            count += 1
            
        #change to arrays
        input_1 = np.array(list_1)
        input_2 = np.array(list_2)
        input_3 = np.array(list_3)
        input_r = np.array(list_r)

        pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

        for i in range(pred.shape[0]):
            score += float(pred[i])
            
        #average the score
        score = score/float(count)

    return(score)

In [29]:
#subgraph based relation scoring
def subgraph_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    
    #lists holding the input to the network
    list_s_1 = list()
    list_s_2 = list()
    list_s_3 = list()
    list_t_1 = list()
    list_t_2 = list()
    list_t_3 = list()
    list_r = list()
    
    path_s, path_t = set(), set() #sets holding all the paths from s or t
    
    for iteration in range(3):
    
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
                
        del(result_s, length_dict_s, result_t, length_dict_t)
    
    #final output: the score dict
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)
    
    #see if both path_s and path_t have at least three paths
    if len(path_s) >= 3 and len(path_t) >= 3:

        #change to lists
        path_s, path_t = list(path_s), list(path_t)
        
        count = 0
        while count < 10:

            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]
            
            #add all forward (initial relation)
            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                    
                if i % 2 == 0:

                    #append the paths: note that we add the space holder id at the end of the shorter path
                    list_s_1.append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                    list_s_2.append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                    list_s_3.append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
                    
                    list_t_1.append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                    list_t_2.append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                    list_t_3.append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))
                    
                    list_r.append([i])
                
            #change to arrays
            input_s_1 = np.array(list_s_1)
            input_s_2 = np.array(list_s_2)
            input_s_3 = np.array(list_s_3)
            input_t_1 = np.array(list_t_1)
            input_t_2 = np.array(list_t_2)
            input_t_3 = np.array(list_t_3)
            input_r = np.array(list_r)
            
            pred = model_2.predict([input_s_1, input_s_2, input_s_3,
                                    input_t_1, input_t_2, input_t_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])
                count_dict[2*i] += 1

            count += 1
            
    #average the score
    for r in score_dict:
        score_dict[r] = deepcopy(score_dict[r]/float(count_dict[r]))
            
    print(len(score_dict), len(path_s), len(path_t))
        
    return(score_dict)

In [30]:
#subgraph based triple scoring
def subgraph_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    
    #lists holding the input to the network
    list_s_1 = list()
    list_s_2 = list()
    list_s_3 = list()
    list_t_1 = list()
    list_t_2 = list()
    list_t_3 = list()
    list_r = list()
    
    path_s, path_t = set(), set() #sets holding all the paths from s or t
    
    for iteration in range(3):
    
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
                
        del(result_s, length_dict_s, result_t, length_dict_t)
    
    #final output: the score dict
    score = 0.
    
    #see if both path_s and path_t have at least three paths
    if len(path_s) >= 3 and len(path_t) >= 3:

        #change to lists
        path_s, path_t = list(path_s), list(path_t)
        
        count = 0
        while count < 10:

            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

            #append the paths: note that we add the space holder id at the end of the shorter path
            list_s_1.append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            list_s_2.append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            list_s_3.append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

            list_t_1.append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            list_t_2.append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            list_t_3.append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            list_r.append([r])
            count += 1
                
        #change to arrays
        input_s_1 = np.array(list_s_1)
        input_s_2 = np.array(list_s_2)
        input_s_3 = np.array(list_s_3)
        input_t_1 = np.array(list_t_1)
        input_t_2 = np.array(list_t_2)
        input_t_3 = np.array(list_t_3)
        input_r = np.array(list_r)

        pred = model_2.predict([input_s_1, input_s_2, input_s_3,
                                input_t_1, input_t_2, input_t_3, input_r], verbose = 0)

        for i in range(pred.shape[0]):
            score += float(pred[i])

        #average the score
        score = score/float(count)
        
    return(score)

#### Not fine tuned 

In [28]:
########################################################
#obtain the Hits@N for relation prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #run the path-based scoring
    score_dict_path = path_based_relation_scoring(s_true, t_true, 1, 10, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, 1, 3, one_hop_ind, id2relation, model_2)
    
    #final score dict
    score_dict = defaultdict(float)
    
    for r in score_dict_path:
        score_dict[r] += score_dict_path[r]
    for r in score_dict_subg:
        score_dict[r] += score_dict_subg[r]
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

142 131
1420 723 883
checkcorrect 16 16 real score 1.9938775181770325 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 809
142 91
1420 44 10
checkcorrect 118 118 real score 1.992316997051239 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 809
142 29
1420 640 584
checkcorrect 198 198 real score 1.6865678131580353 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 2 total_num 2 809
142 83
1420 640 67
checkcorrect 108 108 real score 1.9815807819366456 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 3 809
142 126
1420 28 139
checkcorrect 34 34 real score 1.9599078178405762 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 4 809
142 3
1420 508 1974
checkcorrect 62 62 real score 0.08618164416402578 Hits@1 0.8333333333333334 Hits@3 0.8333333333333334 Hits@10 0.8333333333333334 MRR 0.8382352941176471 cur_rank 33 abs_cur_rank 35 total_num 5 809
142 118
1420 123

142 116
1420 80 42
checkcorrect 74 74 real score 1.9944399833679198 Hits@1 0.6744186046511628 Hits@3 0.8372093023255814 Hits@10 0.9302325581395349 MRR 0.7676333446475284 cur_rank 0 abs_cur_rank 0 total_num 42 809
142 150
1420 1776 2066
checkcorrect 196 196 real score 1.5893937945365906 Hits@1 0.6590909090909091 Hits@3 0.8409090909090909 Hits@10 0.9318181818181818 MRR 0.7577628898449332 cur_rank 2 abs_cur_rank 2 total_num 43 809
142 150
1420 181 184
checkcorrect 16 16 real score 1.9985550940036774 Hits@1 0.6666666666666666 Hits@3 0.8444444444444444 Hits@10 0.9333333333333333 MRR 0.763145936737268 cur_rank 0 abs_cur_rank 0 total_num 44 809
142 25
1420 86 175
checkcorrect 162 162 real score 1.9298656761646271 Hits@1 0.6739130434782609 Hits@3 0.8478260869565217 Hits@10 0.9347826086956522 MRR 0.7682949381125448 cur_rank 0 abs_cur_rank 0 total_num 45 809
142 6
1420 561 21
checkcorrect 176 176 real score 1.7987667322158813 Hits@1 0.6808510638297872 Hits@3 0.851063829787234 Hits@10 0.936170212

0 1
1420 216 38
checkcorrect 10 10 real score 0.9845277667045593 Hits@1 0.6506024096385542 Hits@3 0.8072289156626506 Hits@10 0.927710843373494 MRR 0.745215693456672 cur_rank 0 abs_cur_rank 0 total_num 82 809
142 150
1420 225 261
checkcorrect 252 252 real score 1.6205146372318269 Hits@1 0.6428571428571429 Hits@3 0.7976190476190477 Hits@10 0.9285714285714286 MRR 0.7393202685345688 cur_rank 3 abs_cur_rank 3 total_num 83 809
142 5
1420 28 14
checkcorrect 34 34 real score 1.9599109292030334 Hits@1 0.6470588235294118 Hits@3 0.8 Hits@10 0.9294117647058824 MRR 0.7423870889047504 cur_rank 0 abs_cur_rank 0 total_num 84 809
142 144
1420 307 2098
checkcorrect 116 116 real score 1.560139548778534 Hits@1 0.6395348837209303 Hits@3 0.8023255813953488 Hits@10 0.9302325581395349 MRR 0.7395686343826021 cur_rank 1 abs_cur_rank 1 total_num 85 809
142 51
1420 464 119
checkcorrect 160 160 real score 1.9513022780418396 Hits@1 0.6436781609195402 Hits@3 0.8045977011494253 Hits@10 0.9310344827586207 MRR 0.742562

142 150
1420 365 640
checkcorrect 12 12 real score 1.9937394499778747 Hits@1 0.6885245901639344 Hits@3 0.8278688524590164 Hits@10 0.9344262295081968 MRR 0.7744170868815777 cur_rank 0 abs_cur_rank 0 total_num 121 809
142 150
1420 218 181
checkcorrect 244 244 real score 0.7052269876003265 Hits@1 0.6829268292682927 Hits@3 0.8292682926829268 Hits@10 0.9349593495934959 MRR 0.7708310401047628 cur_rank 2 abs_cur_rank 3 total_num 122 809
142 147
1420 1089 2630
checkcorrect 116 116 real score 1.8933486938476562 Hits@1 0.6774193548387096 Hits@3 0.8306451612903226 Hits@10 0.9354838709677419 MRR 0.7686469188135953 cur_rank 1 abs_cur_rank 1 total_num 123 809
142 7
1420 206 244
checkcorrect 196 196 real score 1.1508235037326813 Hits@1 0.672 Hits@3 0.824 Hits@10 0.936 MRR 0.7640977434630866 cur_rank 4 abs_cur_rank 5 total_num 124 809
142 141
1420 626 1951
checkcorrect 198 198 real score 1.6642430245876312 Hits@1 0.6666666666666666 Hits@3 0.8253968253968254 Hits@10 0.9365079365079365 MRR 0.76067897830

142 7
1420 27 5
checkcorrect 34 34 real score 1.9598975479602814 Hits@1 0.6832298136645962 Hits@3 0.84472049689441 Hits@10 0.9503105590062112 MRR 0.7735230927508434 cur_rank 0 abs_cur_rank 0 total_num 160 809
142 60
1420 402 89
checkcorrect 30 30 real score 1.9848996579647065 Hits@1 0.6851851851851852 Hits@3 0.845679012345679 Hits@10 0.9506172839506173 MRR 0.7749210983511469 cur_rank 0 abs_cur_rank 0 total_num 161 809
142 3
1420 262 868
checkcorrect 252 252 real score 1.4877402663230896 Hits@1 0.6809815950920245 Hits@3 0.8466257668711656 Hits@10 0.950920245398773 MRR 0.7722119709583996 cur_rank 2 abs_cur_rank 3 total_num 162 809
142 40
1420 278 52
checkcorrect 130 130 real score 1.9534773588180543 Hits@1 0.6829268292682927 Hits@3 0.8475609756097561 Hits@10 0.9512195121951219 MRR 0.7736009223549947 cur_rank 0 abs_cur_rank 0 total_num 163 809
142 51
1420 223 54
checkcorrect 10 10 real score 1.800457274913788 Hits@1 0.6848484848484848 Hits@3 0.8484848484848485 Hits@10 0.9515151515151515 M

142 150
1420 323 474
checkcorrect 16 16 real score 1.9985447883605958 Hits@1 0.69 Hits@3 0.835 Hits@10 0.96 MRR 0.7756999785533177 cur_rank 0 abs_cur_rank 0 total_num 199 809
142 150
1420 1611 2700
checkcorrect 196 196 real score 1.434458214044571 Hits@1 0.6865671641791045 Hits@3 0.835820895522388 Hits@10 0.9601990049751243 MRR 0.7734991494726214 cur_rank 2 abs_cur_rank 4 total_num 200 809
142 150
1420 554 597
checkcorrect 16 16 real score 1.996891015768051 Hits@1 0.6881188118811881 Hits@3 0.8366336633663366 Hits@10 0.9603960396039604 MRR 0.7746204408118659 cur_rank 0 abs_cur_rank 0 total_num 201 809
142 10
1420 133 234
checkcorrect 36 36 real score 1.7097932279109955 Hits@1 0.6847290640394089 Hits@3 0.8374384236453202 Hits@10 0.9605911330049262 MRR 0.7732676307586054 cur_rank 1 abs_cur_rank 3 total_num 202 809
0 1
1420 682 84
checkcorrect 12 12 real score 0.9964189529418945 Hits@1 0.6862745098039216 Hits@3 0.8382352941176471 Hits@10 0.9607843137254902 MRR 0.7743790639411613 cur_rank 0

142 150
1420 217 205
checkcorrect 16 16 real score 1.7978214964270591 Hits@1 0.6903765690376569 Hits@3 0.8368200836820083 Hits@10 0.9665271966527197 MRR 0.7791073739637248 cur_rank 0 abs_cur_rank 0 total_num 238 809
142 150
1420 197 184
checkcorrect 16 16 real score 1.9790741860866548 Hits@1 0.6916666666666667 Hits@3 0.8375 Hits@10 0.9666666666666667 MRR 0.7800277599055426 cur_rank 0 abs_cur_rank 0 total_num 239 809
142 150
1420 518 483
checkcorrect 244 244 real score 0.32548848092556 Hits@1 0.6887966804979253 Hits@3 0.8340248962655602 Hits@10 0.9626556016597511 MRR 0.7770677553692816 cur_rank 14 abs_cur_rank 14 total_num 240 809
142 150
1420 428 212
checkcorrect 16 16 real score 1.9985065460205078 Hits@1 0.6900826446280992 Hits@3 0.8347107438016529 Hits@10 0.9628099173553719 MRR 0.7779889629917226 cur_rank 0 abs_cur_rank 0 total_num 241 809
142 84
1420 22 12
checkcorrect 118 118 real score 1.9923193156719208 Hits@1 0.691358024691358 Hits@3 0.8353909465020576 Hits@10 0.9629629629629629

142 150
1420 2606 2394
checkcorrect 252 252 real score 1.3975101835560053 Hits@1 0.6834532374100719 Hits@3 0.8345323741007195 Hits@10 0.9640287769784173 MRR 0.7737599765771271 cur_rank 1 abs_cur_rank 4 total_num 277 809
142 70
1420 28 7
checkcorrect 118 118 real score 1.9923224985599517 Hits@1 0.6845878136200717 Hits@3 0.8351254480286738 Hits@10 0.96415770609319 MRR 0.7745708727184278 cur_rank 0 abs_cur_rank 0 total_num 278 809
142 150
1420 43 33
checkcorrect 74 74 real score 1.9940488159656524 Hits@1 0.6857142857142857 Hits@3 0.8357142857142857 Hits@10 0.9642857142857143 MRR 0.7753759767444334 cur_rank 0 abs_cur_rank 0 total_num 279 809
142 11
1420 176 25
checkcorrect 4 4 real score 0.5331997390603647 Hits@1 0.6832740213523132 Hits@3 0.8327402135231317 Hits@10 0.9644128113879004 MRR 0.7730614714891152 cur_rank 7 abs_cur_rank 7 total_num 280 809
142 132
1420 20 193
checkcorrect 10 10 real score 1.044320183992386 Hits@1 0.6808510638297872 Hits@3 0.8297872340425532 Hits@10 0.964539007092

142 150
1420 427 1388
checkcorrect 48 48 real score 1.989937311410904 Hits@1 0.6908517350157729 Hits@3 0.8359621451104101 Hits@10 0.9652996845425867 MRR 0.7789248024521459 cur_rank 0 abs_cur_rank 1 total_num 316 809
142 62
1420 210 64
checkcorrect 112 112 real score 0.1666665755212307 Hits@1 0.6886792452830188 Hits@3 0.8333333333333334 Hits@10 0.9622641509433962 MRR 0.7766999715369775 cur_rank 13 abs_cur_rank 13 total_num 317 809
142 150
1420 148 429
checkcorrect 244 244 real score 0.6519699066877365 Hits@1 0.6865203761755486 Hits@3 0.8307210031347962 Hits@10 0.9623824451410659 MRR 0.7748921346356076 cur_rank 4 abs_cur_rank 4 total_num 318 809
0 1
1420 566 59
checkcorrect 198 198 real score 0.8913632035255432 Hits@1 0.684375 Hits@3 0.83125 Hits@10 0.9625 MRR 0.773512263381538 cur_rank 2 abs_cur_rank 2 total_num 319 809
142 134
1420 80 42
checkcorrect 74 74 real score 1.995493185520172 Hits@1 0.6853582554517134 Hits@3 0.8317757009345794 Hits@10 0.9626168224299065 MRR 0.7742178326544927 

142 149
1420 474 900
checkcorrect 12 12 real score 1.993430233001709 Hits@1 0.6853932584269663 Hits@3 0.8342696629213483 Hits@10 0.9634831460674157 MRR 0.7746798095342138 cur_rank 0 abs_cur_rank 0 total_num 355 809
142 121
1420 96 56
checkcorrect 62 62 real score 1.8709112405776978 Hits@1 0.6862745098039216 Hits@3 0.834733893557423 Hits@10 0.9635854341736695 MRR 0.7753109585271152 cur_rank 0 abs_cur_rank 0 total_num 356 809
142 13
1420 381 34
checkcorrect 164 164 real score 1.3284802705049514 Hits@1 0.6871508379888268 Hits@3 0.835195530726257 Hits@10 0.9636871508379888 MRR 0.7759385815479891 cur_rank 0 abs_cur_rank 0 total_num 357 809
142 123
1420 18 174
checkcorrect 10 10 real score 1.1782634824514389 Hits@1 0.6852367688022284 Hits@3 0.8328690807799443 Hits@10 0.9637883008356546 MRR 0.7744735715715323 cur_rank 3 abs_cur_rank 3 total_num 358 809
142 150
1420 278 451
checkcorrect 16 16 real score 1.9656046390533448 Hits@1 0.6861111111111111 Hits@3 0.8333333333333334 Hits@10 0.9638888888

142 150
1420 638 110
checkcorrect 48 48 real score 1.985631400346756 Hits@1 0.6886075949367089 Hits@3 0.8253164556962025 Hits@10 0.9670886075949368 MRR 0.7753865287616374 cur_rank 0 abs_cur_rank 1 total_num 394 809
142 136
1420 473 413
checkcorrect 16 16 real score 1.9985552191734315 Hits@1 0.6893939393939394 Hits@3 0.8257575757575758 Hits@10 0.9671717171717171 MRR 0.7759537344970878 cur_rank 0 abs_cur_rank 0 total_num 395 809
0 0
1420 27 10
checkcorrect 118 118 real score 0.9960408210754395 Hits@1 0.690176322418136 Hits@3 0.8261964735516373 Hits@10 0.9672544080604534 MRR 0.7765180827729138 cur_rank 0 abs_cur_rank 0 total_num 396 809
142 35
1420 64 139
checkcorrect 10 10 real score 1.2624927550554275 Hits@1 0.6884422110552764 Hits@3 0.8266331658291457 Hits@10 0.9673366834170855 MRR 0.7754045532517088 cur_rank 2 abs_cur_rank 2 total_num 397 809
142 14
1420 216 58
checkcorrect 4 4 real score 0.8578274676343426 Hits@1 0.6867167919799498 Hits@3 0.8270676691729323 Hits@10 0.9674185463659147

0 1
1420 4 7
checkcorrect 128 128 real score 0.62395179271698 Hits@1 0.684331797235023 Hits@3 0.8202764976958525 Hits@10 0.9700460829493087 MRR 0.7714438873613247 cur_rank 5 abs_cur_rank 5 total_num 433 809
142 101
1420 389 70
checkcorrect 164 164 real score 1.0336823813617229 Hits@1 0.6850574712643678 Hits@3 0.8206896551724138 Hits@10 0.9701149425287356 MRR 0.7719693037122183 cur_rank 0 abs_cur_rank 0 total_num 434 809
142 26
1420 25 11
checkcorrect 118 118 real score 1.9923175036907197 Hits@1 0.6857798165137615 Hits@3 0.8211009174311926 Hits@10 0.9701834862385321 MRR 0.7724923098963645 cur_rank 0 abs_cur_rank 0 total_num 435 809
0 1
1420 48 386
checkcorrect 84 84 real score 0.8183444738388062 Hits@1 0.6842105263157895 Hits@3 0.8215102974828375 Hits@10 0.9702517162471396 MRR 0.7718687576998053 cur_rank 1 abs_cur_rank 1 total_num 436 809
142 150
1420 873 501
checkcorrect 16 16 real score 1.9933766305446625 Hits@1 0.684931506849315 Hits@3 0.821917808219178 Hits@10 0.9703196347031964 MRR

142 150
1420 166 132
checkcorrect 204 204 real score 1.4309698820114136 Hits@1 0.6849894291754757 Hits@3 0.8224101479915433 Hits@10 0.9704016913319239 MRR 0.7726826977217293 cur_rank 2 abs_cur_rank 2 total_num 472 809
142 150
1420 354 382
checkcorrect 16 16 real score 1.9977123379707336 Hits@1 0.6856540084388185 Hits@3 0.8227848101265823 Hits@10 0.9704641350210971 MRR 0.7731622700894049 cur_rank 0 abs_cur_rank 0 total_num 473 809
0 1
1420 88 7
checkcorrect 18 18 real score 0.9901179075241089 Hits@1 0.6863157894736842 Hits@3 0.8231578947368421 Hits@10 0.9705263157894737 MRR 0.7736398232050061 cur_rank 0 abs_cur_rank 0 total_num 474 809
142 79
1420 20 28
checkcorrect 118 118 real score 1.9923193812370301 Hits@1 0.6869747899159664 Hits@3 0.8235294117647058 Hits@10 0.9705882352941176 MRR 0.7741153697949116 cur_rank 0 abs_cur_rank 0 total_num 475 809
142 146
1420 23 18
checkcorrect 118 118 real score 1.9923216462135316 Hits@1 0.6876310272536688 Hits@3 0.8238993710691824 Hits@10 0.9706498951

142 150
1420 280 772
checkcorrect 48 48 real score 1.9857836723327638 Hits@1 0.685546875 Hits@3 0.818359375 Hits@10 0.96484375 MRR 0.7709765641816708 cur_rank 0 abs_cur_rank 1 total_num 511 809
142 78
1420 23 20
checkcorrect 118 118 real score 1.9923162817955018 Hits@1 0.6861598440545809 Hits@3 0.8187134502923976 Hits@10 0.9649122807017544 MRR 0.7714230036277104 cur_rank 0 abs_cur_rank 0 total_num 512 809
142 16
1420 204 57
checkcorrect 4 4 real score 0.7684235170949251 Hits@1 0.6848249027237354 Hits@3 0.8171206225680934 Hits@10 0.9649805447470817 MRR 0.77031128572182 cur_rank 4 abs_cur_rank 4 total_num 513 809
142 143
1420 123 247
checkcorrect 252 252 real score 1.5913399696350097 Hits@1 0.683495145631068 Hits@3 0.8155339805825242 Hits@10 0.9650485436893204 MRR 0.7693009725456611 cur_rank 3 abs_cur_rank 3 total_num 514 809
142 42
1420 20 161
checkcorrect 128 128 real score 1.760357129573822 Hits@1 0.6841085271317829 Hits@3 0.8158914728682171 Hits@10 0.9651162790697675 MRR 0.7697480636

142 122
1420 267 158
checkcorrect 264 264 real score 1.6284556984901428 Hits@1 0.6823956442831216 Hits@3 0.8185117967332124 Hits@10 0.9673321234119783 MRR 0.7695825786951278 cur_rank 0 abs_cur_rank 0 total_num 550 809
142 150
1420 440 550
checkcorrect 16 16 real score 1.998561531305313 Hits@1 0.6829710144927537 Hits@3 0.8188405797101449 Hits@10 0.967391304347826 MRR 0.7700000015598105 cur_rank 0 abs_cur_rank 0 total_num 551 809
142 55
1420 248 40
checkcorrect 106 106 real score 1.975869643688202 Hits@1 0.6835443037974683 Hits@3 0.8191681735985533 Hits@10 0.9674502712477396 MRR 0.7704159147577132 cur_rank 0 abs_cur_rank 0 total_num 552 809
142 150
1420 377 352
checkcorrect 16 16 real score 1.9985466599464417 Hits@1 0.6841155234657039 Hits@3 0.8194945848375451 Hits@10 0.9675090252707581 MRR 0.7708303264639268 cur_rank 0 abs_cur_rank 0 total_num 553 809
142 83
1420 120 149
checkcorrect 40 40 real score 1.2005861684679986 Hits@1 0.6828828828828829 Hits@3 0.8198198198198198 Hits@10 0.967567

142 150
1420 207 340
checkcorrect 16 16 real score 0.8845592722296715 Hits@1 0.6847457627118644 Hits@3 0.8203389830508474 Hits@10 0.9677966101694915 MRR 0.7714022075395203 cur_rank 2 abs_cur_rank 2 total_num 589 809
142 150
1420 630 261
checkcorrect 16 16 real score 1.9985979437828063 Hits@1 0.6852791878172588 Hits@3 0.8206429780033841 Hits@10 0.9678510998307953 MRR 0.7717890058347157 cur_rank 0 abs_cur_rank 0 total_num 590 809
0 1
1420 11 76
checkcorrect 208 208 real score 0.3651520907878876 Hits@1 0.6841216216216216 Hits@3 0.8192567567567568 Hits@10 0.9662162162162162 MRR 0.7706260739554903 cur_rank 11 abs_cur_rank 12 total_num 591 809
142 27
1420 53 26
checkcorrect 18 18 real score 1.9827616393566132 Hits@1 0.684654300168634 Hits@3 0.8195615514333895 Hits@10 0.9662731871838112 MRR 0.7710128765289213 cur_rank 0 abs_cur_rank 0 total_num 592 809
142 150
1420 44 369
checkcorrect 158 158 real score 0.09695198778063059 Hits@1 0.6835016835016835 Hits@3 0.8181818181818182 Hits@10 0.96464646

142 64
1420 224 53
checkcorrect 10 10 real score 1.6919856190681457 Hits@1 0.6931637519872814 Hits@3 0.8267090620031796 Hits@10 0.9666136724960255 MRR 0.7774859747371902 cur_rank 0 abs_cur_rank 0 total_num 628 809
142 114
1420 152 524
checkcorrect 48 48 real score 1.9912467539310454 Hits@1 0.6936507936507936 Hits@3 0.8269841269841269 Hits@10 0.9666666666666667 MRR 0.7778391716026867 cur_rank 0 abs_cur_rank 1 total_num 629 809
142 95
1420 140 135
checkcorrect 162 162 real score 1.6238716095685959 Hits@1 0.694136291600634 Hits@3 0.8272583201267829 Hits@10 0.9667194928684627 MRR 0.7781912489852498 cur_rank 0 abs_cur_rank 0 total_num 630 809
142 150
1420 433 489
checkcorrect 16 16 real score 1.9985749304294587 Hits@1 0.694620253164557 Hits@3 0.8275316455696202 Hits@10 0.9667721518987342 MRR 0.7785422121988808 cur_rank 0 abs_cur_rank 0 total_num 631 809
142 70
1420 475 69
checkcorrect 92 92 real score 1.9581538379192351 Hits@1 0.6951026856240127 Hits@3 0.8278041074249605 Hits@10 0.966824644

0 2
1420 13 18
checkcorrect 54 54 real score 0.6702583432197571 Hits@1 0.688622754491018 Hits@3 0.8203592814371258 Hits@10 0.9655688622754491 MRR 0.7737976170523855 cur_rank 4 abs_cur_rank 4 total_num 667 809
142 150
1420 262 123
checkcorrect 252 252 real score 1.632649451494217 Hits@1 0.6875934230194319 Hits@3 0.8191330343796711 Hits@10 0.9656203288490284 MRR 0.7730146609730845 cur_rank 3 abs_cur_rank 3 total_num 668 809
142 150
1420 151 271
checkcorrect 70 70 real score 1.9101448059082031 Hits@1 0.6880597014925374 Hits@3 0.8194029850746268 Hits@10 0.9656716417910448 MRR 0.7733534450611843 cur_rank 0 abs_cur_rank 0 total_num 669 809
0 1
1420 189 44
checkcorrect 162 162 real score 0.9706474542617798 Hits@1 0.6885245901639344 Hits@3 0.819672131147541 Hits@10 0.9657228017883756 MRR 0.7736912193606461 cur_rank 0 abs_cur_rank 0 total_num 670 809
142 133
1420 20 29
checkcorrect 118 118 real score 1.9923011183738708 Hits@1 0.6889880952380952 Hits@3 0.8199404761904762 Hits@10 0.96577380952380

142 150
1420 1606 853
checkcorrect 48 48 real score 1.974942409992218 Hits@1 0.695898161244696 Hits@3 0.8274398868458275 Hits@10 0.9674681753889675 MRR 0.7803820955082417 cur_rank 0 abs_cur_rank 0 total_num 706 809
0 1
1420 475 62
checkcorrect 50 50 real score 0.9756618738174438 Hits@1 0.6963276836158192 Hits@3 0.827683615819209 Hits@10 0.9675141242937854 MRR 0.7806922902885973 cur_rank 0 abs_cur_rank 0 total_num 707 809
0 1
1420 35 209
checkcorrect 28 28 real score 0.6545993089675903 Hits@1 0.6953455571227081 Hits@3 0.8265162200282088 Hits@10 0.9675599435825106 MRR 0.7798262456854633 cur_rank 5 abs_cur_rank 5 total_num 708 809
142 150
1420 413 319
checkcorrect 16 16 real score 1.9919181406497954 Hits@1 0.6957746478873239 Hits@3 0.8267605633802817 Hits@10 0.967605633802817 MRR 0.7801363495647796 cur_rank 0 abs_cur_rank 0 total_num 709 809
142 150
1420 89 387
checkcorrect 204 204 real score 1.9420485854148866 Hits@1 0.6962025316455697 Hits@3 0.8270042194092827 Hits@10 0.9676511954992968

142 150
1420 376 282
checkcorrect 16 16 real score 1.9985490620136261 Hits@1 0.6880856760374833 Hits@3 0.821954484605087 Hits@10 0.965194109772423 MRR 0.7745260238373556 cur_rank 0 abs_cur_rank 0 total_num 746 809
142 150
1420 1665 788
checkcorrect 196 196 real score 1.809610378742218 Hits@1 0.6885026737967914 Hits@3 0.8221925133689839 Hits@10 0.9652406417112299 MRR 0.7748274596343645 cur_rank 0 abs_cur_rank 0 total_num 747 809
142 66
1420 496 160
checkcorrect 16 16 real score 1.9985314786434174 Hits@1 0.6889185580774366 Hits@3 0.822429906542056 Hits@10 0.965287049399199 MRR 0.7751280905293787 cur_rank 0 abs_cur_rank 0 total_num 748 809
142 110
1420 84 361
checkcorrect 16 16 real score 1.9215801179409027 Hits@1 0.6893333333333334 Hits@3 0.8226666666666667 Hits@10 0.9653333333333334 MRR 0.7754279197420062 cur_rank 0 abs_cur_rank 0 total_num 749 809
142 122
1420 71 236
checkcorrect 62 62 real score 0.9550158113241196 Hits@1 0.6884154460719041 Hits@3 0.8229027962716379 Hits@10 0.965379494

142 150
1420 619 99
checkcorrect 92 92 real score 1.9582096695899964 Hits@1 0.6895674300254453 Hits@3 0.8244274809160306 Hits@10 0.9643765903307888 MRR 0.7756809484997335 cur_rank 0 abs_cur_rank 0 total_num 785 809
142 5
1420 194 52
checkcorrect 100 100 real score 0.9068430259823799 Hits@1 0.6886912325285895 Hits@3 0.8233799237611181 Hits@10 0.9644218551461246 MRR 0.774876852177443 cur_rank 6 abs_cur_rank 6 total_num 786 809
0 1
1420 12 43
checkcorrect 50 50 real score 0.992997407913208 Hits@1 0.6890862944162437 Hits@3 0.8236040609137056 Hits@10 0.9644670050761421 MRR 0.7751625414513295 cur_rank 0 abs_cur_rank 0 total_num 787 809
142 150
1420 258 502
checkcorrect 244 244 real score 0.6073213338851928 Hits@1 0.688212927756654 Hits@3 0.8225602027883396 Hits@10 0.9645120405576679 MRR 0.7744969362023417 cur_rank 3 abs_cur_rank 3 total_num 788 809
0 1
1420 35 202
checkcorrect 84 84 real score 0.8897531628608704 Hits@1 0.6886075949367089 Hits@3 0.8227848101265823 Hits@10 0.9645569620253165 M

In [31]:
###########################################
##obtain the AUC-PR for the test triples###
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#we build the negative samples by randomly replace head or tail entity in the triple.
neg_triples = list()

for i in range(len(pos_triples)):
    
    s_pos, r_pos, t_pos = pos_triples[i][0], pos_triples[i][1], pos_triples[i][2]
    
    #decide to replace the head or tail entity
    number_0 = random.uniform(0, 1)
    
    if number_0 < 0.5: #replace head entity
        s_neg = random.choice(list(new_ent_set))
        neg_triples.append((s_neg, r_pos, t_pos))
    else: #replace tail entity
        t_neg = random.choice(list(new_ent_set))
        neg_triples.append((s_pos, r_pos, t_neg))

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')
        
#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score aray
y_score = np.zeros((len(y_test),))

#implement the scoring
for i in range(len(all_triples)):
    
    s, r, t = all_triples[i][0], all_triples[i][1], all_triples[i][2]
    
    path_score = path_based_triple_scoring(s, r, t, 1, 10, one_hop_ind, id2relation, model)
    
    subg_score = subgraph_triple_scoring(s, r, t, 1, 3, one_hop_ind, id2relation, model_2)
    
    ave_score = (path_score + subg_score)/float(2)
    
    y_score[i] = ave_score
    
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))
        
        # Data to plot precision - recall curve
        precision, recall, thresholds = precision_recall_curve(y_test[:i], y_score[:i])
        # Use AUC function to calculate the area under the curve of precision recall curve
        auc_precision_recall = auc(recall, precision)
        print('AUC-PR is:', auc_precision_recall)
        
        
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_test, y_score)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('AUC-PR is:', auc_precision_recall)

evaluating scores 20 2848
AUC-PR is: 1.0
evaluating scores 40 2848
AUC-PR is: 0.942565153499751
evaluating scores 60 2848
AUC-PR is: 0.9050756700080163
evaluating scores 80 2848
AUC-PR is: 0.9119715822295396
evaluating scores 100 2848
AUC-PR is: 0.8051519168194391
evaluating scores 120 2848
AUC-PR is: 0.8494111342053469
evaluating scores 140 2848
AUC-PR is: 0.8484663869332756
evaluating scores 160 2848
AUC-PR is: 0.8542383906523069
evaluating scores 180 2848
AUC-PR is: 0.8571681212943643
evaluating scores 200 2848
AUC-PR is: 0.8688673062158704
evaluating scores 220 2848
AUC-PR is: 0.8777552422954364
evaluating scores 240 2848
AUC-PR is: 0.8786174114845786
evaluating scores 260 2848
AUC-PR is: 0.8901002910229088
evaluating scores 280 2848
AUC-PR is: 0.8984785893357252
evaluating scores 300 2848
AUC-PR is: 0.9023298662116466
evaluating scores 320 2848
AUC-PR is: 0.8971822634065922
evaluating scores 340 2848
AUC-PR is: 0.8978130814364801
evaluating scores 360 2848
AUC-PR is: 0.90067631223

In [None]:
###########################################
##obtain Hits@N for entities prediction####


#### Fine tuned

In [29]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        for iteration in range(2):

            #proceed only if at least three paths are between s and t
            for t in result:

                if len(s_t_r[(s,t)]) == 0:

                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #we are only interested in forward link in relation prediciton
                ini_r_list = list()

                #obtain initial relations between s and t
                for r in s_t_r[(s,t)]:
                    if r % 2 == 0:#initial relation id is always an even number
                        ini_r_list.append(r)

                #if there exist more than three paths between s and t, 
                #and inital connection between s and t exists,
                #and not every r in the relation dictionary exists between s and t (although this is rare)
                #we then proceed
                if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):

                    #obtain the list form of all the paths from s to t
                    temp_path_list = list(result[t])

                    temp_pair = random.sample(temp_path_list, 3)

                    path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                    #####positive#####################
                    #append the paths: note that we add the space holder id at the end of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    r = random.choice(ini_r_list)
                    x_r_list.append([r])
                    y_list.append(1.)

                    #####negative#####################
                    #append the paths: note that we add the space holder id at the end
                    #of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                    r_ran = random.choice(neg_r_list)
                    x_r_list.append([r_ran])
                    y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

In [30]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
    
    #the ids to start path finding
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    #Dict stores the subgraph for each entity
    Dict_1 = dict()
    
    count = 0
    for s in existing_ids:
        
        path_set = set()
            
        result, length_dict = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        for t_ in result:
            for path in result[t_]:
                path_set.add(path)

        del(result, length_dict)
        
        path_list = list(path_set)
        
        path_select = random.sample(path_list, min(len(path_list), 100))
            
        Dict_1[s] = deepcopy(path_select)
        
        count += 1
        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, len(existing_ids))
        
    return(Dict_1)

In [31]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    data = list(data)
    
    for iteration in range(2):

        data = shuffle(data)

        for i_0 in range(len(data)):

            triple = data[i_0]

            s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs

            path_s, path_t = Dict[s], Dict[t] #sets holding all the paths from s or t

            #see if both path_s and path_t have at least three paths
            if len(path_s) >= 3 and len(path_t) >= 3:

                #change to lists
                path_s, path_t = list(path_s), list(path_t)

                #randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #####positive step###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative step###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference({r}))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)

            if i_0 % 500 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

In [32]:
###fine tune the path-based model
lower_bd = 1
upper_bd = 10
batch_size = 32

#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)   

#######################################
###do the training#####################

#generate the input arrays
x_train_1 = np.asarray(train_p_list['1'], dtype='int')
x_train_2 = np.asarray(train_p_list['2'], dtype='int')
x_train_3 = np.asarray(train_p_list['3'], dtype='int')
x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
           batch_size=batch_size, epochs=2)

generating big-batches for path-based model 100 3566
generating big-batches for path-based model 200 3566
generating big-batches for path-based model 300 3566
generating big-batches for path-based model 400 3566
generating big-batches for path-based model 500 3566
generating big-batches for path-based model 600 3566
generating big-batches for path-based model 700 3566
generating big-batches for path-based model 800 3566
generating big-batches for path-based model 900 3566
generating big-batches for path-based model 1000 3566
generating big-batches for path-based model 1100 3566
generating big-batches for path-based model 1200 3566
generating big-batches for path-based model 1300 3566
generating big-batches for path-based model 1400 3566
generating big-batches for path-based model 1500 3566
generating big-batches for path-based model 1600 3566
generating big-batches for path-based model 1700 3566
generating big-batches for path-based model 1800 3566
generating big-batches for path-based

<keras.callbacks.History at 0x7fe1c0623eb0>

In [33]:
###fine tune the subgraph model
lower_bd = 1
upper_bd = 3
batch_size = 32

Dict_train_ind = store_subgraph_dicts(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                         relation2id, entity2id, id2relation, id2entity)

#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train_ind,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################

#generate the input arrays
x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, 
             x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train,
             batch_size=batch_size, epochs=2)

generating and storing paths for the path-based model 100 3566
generating and storing paths for the path-based model 200 3566
generating and storing paths for the path-based model 300 3566
generating and storing paths for the path-based model 400 3566
generating and storing paths for the path-based model 500 3566
generating and storing paths for the path-based model 600 3566
generating and storing paths for the path-based model 700 3566
generating and storing paths for the path-based model 800 3566
generating and storing paths for the path-based model 900 3566
generating and storing paths for the path-based model 1000 3566
generating and storing paths for the path-based model 1100 3566
generating and storing paths for the path-based model 1200 3566
generating and storing paths for the path-based model 1300 3566
generating and storing paths for the path-based model 1400 3566
generating and storing paths for the path-based model 1500 3566
generating and storing paths for the path-based m

<keras.callbacks.History at 0x7fe1c2efad30>

In [34]:
########################################################
#obtain the Hits@N for relation prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #run the path-based scoring
    score_dict_path = path_based_relation_scoring(s_true, t_true, 1, 10, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, 1, 3, one_hop_ind, id2relation, model_2)
    
    #final score dict
    score_dict = defaultdict(float)
    
    for r in score_dict_path:
        score_dict[r] += score_dict_path[r]
    for r in score_dict_subg:
        score_dict[r] += score_dict_subg[r]
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

142 150
1420 724 873
checkcorrect 16 16 real score 1.9991523206233979 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 809
142 75
1420 43 10
checkcorrect 118 118 real score 1.999608075618744 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 809
142 26
1420 605 640
checkcorrect 198 198 real score 1.9544902563095092 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 2 total_num 2 809
142 53
1420 660 67
checkcorrect 108 108 real score 1.9539379000663757 Hits@1 0.75 Hits@3 1.0 Hits@10 1.0 MRR 0.875 cur_rank 1 abs_cur_rank 1 total_num 3 809
142 93
1420 28 136
checkcorrect 34 34 real score 1.9597204446792602 Hits@1 0.8 Hits@3 1.0 Hits@10 1.0 MRR 0.9 cur_rank 0 abs_cur_rank 0 total_num 4 809
142 3
1420 519 2087
checkcorrect 62 62 real score 0.517503422498703 Hits@1 0.6666666666666666 Hits@3 0.8333333333333334 Hits@10 1.0 MRR 0.7685185185185185 cur_rank 8 abs_cur_rank 10 total_num 5 809
142 106
1420 126 351
checkcorrec

142 146
1420 80 42
checkcorrect 74 74 real score 1.9941690504550933 Hits@1 0.7209302325581395 Hits@3 0.8837209302325582 Hits@10 0.9767441860465116 MRR 0.8056293835363603 cur_rank 0 abs_cur_rank 0 total_num 42 809
142 150
1420 1766 2041
checkcorrect 196 196 real score 1.900141680240631 Hits@1 0.7272727272727273 Hits@3 0.8863636363636364 Hits@10 0.9772727272727273 MRR 0.8100468975468975 cur_rank 0 abs_cur_rank 0 total_num 43 809
142 150
1420 181 184
checkcorrect 16 16 real score 1.999034398794174 Hits@1 0.7333333333333333 Hits@3 0.8888888888888888 Hits@10 0.9777777777777777 MRR 0.814268077601411 cur_rank 0 abs_cur_rank 0 total_num 44 809
142 24
1420 84 153
checkcorrect 162 162 real score 1.995279186964035 Hits@1 0.7391304347826086 Hits@3 0.8913043478260869 Hits@10 0.9782608695652174 MRR 0.8183057280883368 cur_rank 0 abs_cur_rank 0 total_num 45 809
142 6
1420 536 21
checkcorrect 176 176 real score 1.984796804189682 Hits@1 0.7446808510638298 Hits@3 0.8936170212765957 Hits@10 0.978723404255

142 139
1420 109 613
checkcorrect 12 12 real score 1.9634984970092773 Hits@1 0.7195121951219512 Hits@3 0.8536585365853658 Hits@10 0.975609756097561 MRR 0.8047566931713274 cur_rank 0 abs_cur_rank 0 total_num 81 809
0 1
1420 224 39
checkcorrect 10 10 real score 0.9819881916046143 Hits@1 0.7228915662650602 Hits@3 0.8554216867469879 Hits@10 0.9759036144578314 MRR 0.8071090221692632 cur_rank 0 abs_cur_rank 0 total_num 82 809
142 150
1420 223 259
checkcorrect 252 252 real score 1.9764680087566375 Hits@1 0.7261904761904762 Hits@3 0.8571428571428571 Hits@10 0.9761904761904762 MRR 0.8094053433339148 cur_rank 0 abs_cur_rank 0 total_num 83 809
142 6
1420 28 6
checkcorrect 34 34 real score 1.9383000910282135 Hits@1 0.7294117647058823 Hits@3 0.8588235294117647 Hits@10 0.9764705882352941 MRR 0.8116476334123394 cur_rank 0 abs_cur_rank 0 total_num 84 809
142 150
1420 306 2115
checkcorrect 116 116 real score 1.9675970792770385 Hits@1 0.7209302325581395 Hits@3 0.8604651162790697 Hits@10 0.97674418604651

142 124
1420 18 42
checkcorrect 62 62 real score 1.4683188259601594 Hits@1 0.7272727272727273 Hits@3 0.8842975206611571 Hits@10 0.9669421487603306 MRR 0.818362410434292 cur_rank 1 abs_cur_rank 1 total_num 120 809
142 137
1420 353 944
checkcorrect 12 12 real score 1.9721258461475373 Hits@1 0.7295081967213115 Hits@3 0.8852459016393442 Hits@10 0.9672131147540983 MRR 0.8198512431356503 cur_rank 0 abs_cur_rank 0 total_num 121 809
142 150
1420 218 181
checkcorrect 244 244 real score 1.8493283987045288 Hits@1 0.7317073170731707 Hits@3 0.8861788617886179 Hits@10 0.967479674796748 MRR 0.8213158671751979 cur_rank 0 abs_cur_rank 1 total_num 122 809
142 150
1420 1072 2072
checkcorrect 116 116 real score 1.9791572630405425 Hits@1 0.7338709677419355 Hits@3 0.8870967741935484 Hits@10 0.967741935483871 MRR 0.8227568682463656 cur_rank 0 abs_cur_rank 0 total_num 123 809
142 8
1420 218 251
checkcorrect 196 196 real score 1.3773263931274413 Hits@1 0.728 Hits@3 0.88 Hits@10 0.968 MRR 0.8181748133003947 cur

0 1
1420 19 169
checkcorrect 62 62 real score 0.9890041351318359 Hits@1 0.74375 Hits@3 0.90625 Hits@10 0.975 MRR 0.8350324062242669 cur_rank 0 abs_cur_rank 0 total_num 159 809
142 9
1420 28 5
checkcorrect 34 34 real score 1.958381599187851 Hits@1 0.7453416149068323 Hits@3 0.906832298136646 Hits@10 0.9751552795031055 MRR 0.8360570496638676 cur_rank 0 abs_cur_rank 0 total_num 160 809
142 88
1420 445 89
checkcorrect 30 30 real score 1.9844648122787476 Hits@1 0.7469135802469136 Hits@3 0.9074074074074074 Hits@10 0.9753086419753086 MRR 0.8370690431844611 cur_rank 0 abs_cur_rank 0 total_num 161 809
142 3
1420 257 771
checkcorrect 252 252 real score 1.9738841354846954 Hits@1 0.7484662576687117 Hits@3 0.9079754601226994 Hits@10 0.9754601226993865 MRR 0.8380686196066423 cur_rank 0 abs_cur_rank 0 total_num 162 809
142 55
1420 297 50
checkcorrect 130 130 real score 1.8692253232002258 Hits@1 0.75 Hits@3 0.9085365853658537 Hits@10 0.975609756097561 MRR 0.8390560060724555 cur_rank 0 abs_cur_rank 0 to

142 150
1420 628 390
checkcorrect 16 16 real score 1.9988008797168733 Hits@1 0.7688442211055276 Hits@3 0.914572864321608 Hits@10 0.9798994974874372 MRR 0.8505006504539052 cur_rank 0 abs_cur_rank 0 total_num 198 809
142 150
1420 324 464
checkcorrect 16 16 real score 1.9986595034599304 Hits@1 0.77 Hits@3 0.915 Hits@10 0.98 MRR 0.8512481472016357 cur_rank 0 abs_cur_rank 0 total_num 199 809
142 150
1420 1548 2562
checkcorrect 196 196 real score 1.8419682264328003 Hits@1 0.7661691542288557 Hits@3 0.9154228855721394 Hits@10 0.9800995024875622 MRR 0.8486714565853756 cur_rank 2 abs_cur_rank 4 total_num 200 809
142 150
1420 539 609
checkcorrect 16 16 real score 1.9986069321632385 Hits@1 0.7673267326732673 Hits@3 0.9158415841584159 Hits@10 0.9801980198019802 MRR 0.8494206077903984 cur_rank 0 abs_cur_rank 0 total_num 201 809
142 33
1420 133 236
checkcorrect 36 36 real score 1.6875813722610473 Hits@1 0.7635467980295566 Hits@3 0.916256157635468 Hits@10 0.9802955665024631 MRR 0.8476993240081797 cur_

142 85
1420 187 136
checkcorrect 162 162 real score 1.982445389032364 Hits@1 0.773109243697479 Hits@3 0.9201680672268907 Hits@10 0.9831932773109243 MRR 0.8543401797212625 cur_rank 0 abs_cur_rank 0 total_num 237 809
142 150
1420 213 200
checkcorrect 16 16 real score 1.7147584438323975 Hits@1 0.7740585774058577 Hits@3 0.9205020920502092 Hits@10 0.9832635983263598 MRR 0.8549496350362363 cur_rank 0 abs_cur_rank 0 total_num 238 809
142 150
1420 196 184
checkcorrect 16 16 real score 1.9950452506542207 Hits@1 0.775 Hits@3 0.9208333333333333 Hits@10 0.9833333333333333 MRR 0.8555540115569187 cur_rank 0 abs_cur_rank 0 total_num 239 809
142 150
1420 509 479
checkcorrect 244 244 real score 1.8111855864524842 Hits@1 0.7717842323651453 Hits@3 0.921161825726141 Hits@10 0.983402489626556 MRR 0.8540786837081348 cur_rank 1 abs_cur_rank 1 total_num 240 809
142 150
1420 415 218
checkcorrect 16 16 real score 1.9985944211483002 Hits@1 0.7727272727272727 Hits@3 0.9214876033057852 Hits@10 0.9834710743801653 M

142 101
1420 210 199
checkcorrect 16 16 real score 1.9980390906333922 Hits@1 0.7870036101083032 Hits@3 0.927797833935018 Hits@10 0.9855595667870036 MRR 0.8637170737917949 cur_rank 0 abs_cur_rank 0 total_num 276 809
142 150
1420 2528 2522
checkcorrect 252 252 real score 1.9697080910205842 Hits@1 0.7877697841726619 Hits@3 0.9280575539568345 Hits@10 0.9856115107913669 MRR 0.8642073001450618 cur_rank 0 abs_cur_rank 1 total_num 277 809
142 70
1420 29 8
checkcorrect 118 118 real score 1.999607765674591 Hits@1 0.7885304659498208 Hits@3 0.9283154121863799 Hits@10 0.985663082437276 MRR 0.8646940123309218 cur_rank 0 abs_cur_rank 0 total_num 278 809
142 150
1420 43 33
checkcorrect 74 74 real score 1.9941161334514619 Hits@1 0.7892857142857143 Hits@3 0.9285714285714286 Hits@10 0.9857142857142858 MRR 0.8651772480011685 cur_rank 0 abs_cur_rank 0 total_num 279 809
142 11
1420 198 19
checkcorrect 4 4 real score 1.8860952436923981 Hits@1 0.7900355871886121 Hits@3 0.9288256227758007 Hits@10 0.98576512455

142 146
1420 451 89
checkcorrect 16 16 real score 1.9985892474651337 Hits@1 0.7974683544303798 Hits@3 0.9335443037974683 Hits@10 0.9841772151898734 MRR 0.8705860284679833 cur_rank 0 abs_cur_rank 0 total_num 315 809
142 146
1420 426 1343
checkcorrect 48 48 real score 1.9521126389503478 Hits@1 0.7981072555205048 Hits@3 0.9337539432176656 Hits@10 0.9842271293375394 MRR 0.8709942744349612 cur_rank 0 abs_cur_rank 1 total_num 316 809
142 22
1420 197 63
checkcorrect 112 112 real score 0.9958536744117736 Hits@1 0.7955974842767296 Hits@3 0.9308176100628931 Hits@10 0.9842767295597484 MRR 0.868884229546801 cur_rank 4 abs_cur_rank 4 total_num 317 809
142 142
1420 154 497
checkcorrect 244 244 real score 1.8519975602626801 Hits@1 0.7931034482758621 Hits@3 0.9310344827586207 Hits@10 0.9843260188087775 MRR 0.86772785265167 cur_rank 1 abs_cur_rank 1 total_num 318 809
0 1
1420 569 59
checkcorrect 198 198 real score 0.9807940125465393 Hits@1 0.790625 Hits@3 0.93125 Hits@10 0.984375 MRR 0.8660578697788001

142 64
1420 58 143
checkcorrect 204 204 real score 1.9946714460849762 Hits@1 0.7859154929577464 Hits@3 0.9295774647887324 Hits@10 0.9859154929577465 MRR 0.8638710777885362 cur_rank 0 abs_cur_rank 0 total_num 354 809
142 150
1420 467 782
checkcorrect 12 12 real score 1.98401198387146 Hits@1 0.7865168539325843 Hits@3 0.9297752808988764 Hits@10 0.9859550561797753 MRR 0.8642534624014898 cur_rank 0 abs_cur_rank 0 total_num 355 809
142 110
1420 96 56
checkcorrect 62 62 real score 1.967198246717453 Hits@1 0.7871148459383753 Hits@3 0.9299719887955182 Hits@10 0.9859943977591037 MRR 0.8646337048037265 cur_rank 0 abs_cur_rank 0 total_num 356 809
142 17
1420 392 34
checkcorrect 164 164 real score 1.7956245422363282 Hits@1 0.7877094972067039 Hits@3 0.9301675977653632 Hits@10 0.9860335195530726 MRR 0.8650118229467328 cur_rank 0 abs_cur_rank 0 total_num 357 809
142 87
1420 24 183
checkcorrect 10 10 real score 1.9788369953632354 Hits@1 0.7883008356545961 Hits@3 0.9303621169916435 Hits@10 0.98607242339

142 150
1420 662 209
checkcorrect 92 92 real score 1.9611512780189515 Hits@1 0.7918781725888325 Hits@3 0.9314720812182741 Hits@10 0.9873096446700508 MRR 0.8682087122206353 cur_rank 0 abs_cur_rank 0 total_num 393 809
142 119
1420 825 102
checkcorrect 48 48 real score 1.9107027530670166 Hits@1 0.7924050632911392 Hits@3 0.9316455696202531 Hits@10 0.9873417721518988 MRR 0.8685423610504566 cur_rank 0 abs_cur_rank 1 total_num 394 809
142 150
1420 468 409
checkcorrect 16 16 real score 1.998845785856247 Hits@1 0.7929292929292929 Hits@3 0.9318181818181818 Hits@10 0.9873737373737373 MRR 0.8688743247851777 cur_rank 0 abs_cur_rank 0 total_num 395 809
0 0
1420 26 12
checkcorrect 118 118 real score 0.9996461868286133 Hits@1 0.7934508816120907 Hits@3 0.9319899244332494 Hits@10 0.9874055415617129 MRR 0.8692046161585147 cur_rank 0 abs_cur_rank 0 total_num 396 809
142 66
1420 60 186
checkcorrect 10 10 real score 1.9750630140304566 Hits@1 0.7939698492462312 Hits@3 0.9321608040201005 Hits@10 0.98743718592

142 148
1420 390 133
checkcorrect 212 212 real score 1.9296435296535492 Hits@1 0.7944572748267898 Hits@3 0.9237875288683602 Hits@10 0.9884526558891455 MRR 0.8681687985168303 cur_rank 0 abs_cur_rank 0 total_num 432 809
0 1
1420 4 7
checkcorrect 128 128 real score 0.8474056720733643 Hits@1 0.7926267281105991 Hits@3 0.9216589861751152 Hits@10 0.988479262672811 MRR 0.8667444464465149 cur_rank 3 abs_cur_rank 3 total_num 433 809
142 128
1420 385 70
checkcorrect 164 164 real score 1.9174618065357207 Hits@1 0.7931034482758621 Hits@3 0.9218390804597701 Hits@10 0.9885057471264368 MRR 0.8670507810523851 cur_rank 0 abs_cur_rank 0 total_num 434 809
142 70
1420 25 11
checkcorrect 118 118 real score 1.9996083199977874 Hits@1 0.7935779816513762 Hits@3 0.9220183486238532 Hits@10 0.9885321100917431 MRR 0.867355710453641 cur_rank 0 abs_cur_rank 0 total_num 435 809
0 1
1420 48 387
checkcorrect 84 84 real score 0.9433215856552124 Hits@1 0.7917620137299771 Hits@3 0.9221967963386728 Hits@10 0.988558352402746

142 150
1420 820 503
checkcorrect 16 16 real score 1.9989515960216522 Hits@1 0.7902542372881356 Hits@3 0.923728813559322 Hits@10 0.989406779661017 MRR 0.8655552749105666 cur_rank 0 abs_cur_rank 0 total_num 471 809
142 150
1420 158 144
checkcorrect 204 204 real score 1.9785252690315247 Hits@1 0.7906976744186046 Hits@3 0.9238900634249472 Hits@10 0.9894291754756871 MRR 0.8658395132299945 cur_rank 0 abs_cur_rank 0 total_num 472 809
142 150
1420 355 372
checkcorrect 16 16 real score 1.9986991822719573 Hits@1 0.7911392405063291 Hits@3 0.9240506329113924 Hits@10 0.989451476793249 MRR 0.866122552231619 cur_rank 0 abs_cur_rank 0 total_num 473 809
0 0
1420 91 5
checkcorrect 18 18 real score 0.9884066581726074 Hits@1 0.791578947368421 Hits@3 0.9242105263157895 Hits@10 0.9894736842105263 MRR 0.8664043994900789 cur_rank 0 abs_cur_rank 0 total_num 474 809
142 79
1420 28 27
checkcorrect 118 118 real score 1.999608051776886 Hits@1 0.792016806722689 Hits@3 0.9243697478991597 Hits@10 0.9894957983193278 

142 150
1420 250 395
checkcorrect 16 16 real score 1.99862397313118 Hits@1 0.7847358121330724 Hits@3 0.9197651663405088 Hits@10 0.9843444227005871 MRR 0.8606160601923697 cur_rank 0 abs_cur_rank 0 total_num 510 809
142 150
1420 276 785
checkcorrect 48 48 real score 1.9244144678115844 Hits@1 0.78515625 Hits@3 0.919921875 Hits@10 0.984375 MRR 0.8608882944498065 cur_rank 0 abs_cur_rank 1 total_num 511 809
142 134
1420 20 20
checkcorrect 118 118 real score 1.9996080994606018 Hits@1 0.7855750487329435 Hits@3 0.9200779727095516 Hits@10 0.9844054580896686 MRR 0.8611594673651091 cur_rank 0 abs_cur_rank 0 total_num 512 809
142 21
1420 260 56
checkcorrect 4 4 real score 1.58378427028656 Hits@1 0.7859922178988327 Hits@3 0.9202334630350194 Hits@10 0.9844357976653697 MRR 0.8614295851328813 cur_rank 0 abs_cur_rank 0 total_num 513 809
142 139
1420 123 254
checkcorrect 252 252 real score 1.965339595079422 Hits@1 0.7864077669902912 Hits@3 0.920388349514563 Hits@10 0.9844660194174757 MRR 0.86169865389961

0 0
1420 2466 596
checkcorrect 108 108 real score 0.9512722492218018 Hits@1 0.7872727272727272 Hits@3 0.9236363636363636 Hits@10 0.9854545454545455 MRR 0.863287527439335 cur_rank 0 abs_cur_rank 3 total_num 549 809
142 129
1420 261 158
checkcorrect 264 264 real score 1.9589517533779144 Hits@1 0.7876588021778584 Hits@3 0.9237749546279492 Hits@10 0.985480943738657 MRR 0.863535644449427 cur_rank 0 abs_cur_rank 0 total_num 550 809
142 150
1420 436 544
checkcorrect 16 16 real score 1.9988021075725555 Hits@1 0.7880434782608695 Hits@3 0.9239130434782609 Hits@10 0.9855072463768116 MRR 0.8637828624848446 cur_rank 0 abs_cur_rank 0 total_num 551 809
142 70
1420 249 40
checkcorrect 106 106 real score 1.8593162834644317 Hits@1 0.7884267631103075 Hits@3 0.9240506329113924 Hits@10 0.9855334538878843 MRR 0.864029186422485 cur_rank 0 abs_cur_rank 0 total_num 552 809
142 150
1420 382 357
checkcorrect 16 16 real score 1.9987252175807952 Hits@1 0.7888086642599278 Hits@3 0.924187725631769 Hits@10 0.98555956

142 143
1420 63 280
checkcorrect 40 40 real score 0.6698711489792913 Hits@1 0.7928692699490663 Hits@3 0.9252971137521222 Hits@10 0.9864176570458404 MRR 0.8668785626909467 cur_rank 4 abs_cur_rank 4 total_num 588 809
142 147
1420 206 344
checkcorrect 16 16 real score 1.640233087539673 Hits@1 0.7932203389830509 Hits@3 0.9254237288135593 Hits@10 0.9864406779661017 MRR 0.8671041922457077 cur_rank 0 abs_cur_rank 0 total_num 589 809
142 133
1420 632 261
checkcorrect 16 16 real score 1.999118399620056 Hits@1 0.7935702199661591 Hits@3 0.9255499153976311 Hits@10 0.9864636209813875 MRR 0.867329058248676 cur_rank 0 abs_cur_rank 0 total_num 590 809
0 1
1420 11 76
checkcorrect 208 208 real score 0.4091142416000366 Hits@1 0.7922297297297297 Hits@3 0.9239864864864865 Hits@10 0.9847972972972973 MRR 0.8659939130099503 cur_rank 12 abs_cur_rank 13 total_num 591 809
142 24
1420 68 26
checkcorrect 18 18 real score 1.984128975868225 Hits@1 0.7925801011804384 Hits@3 0.924114671163575 Hits@10 0.984822934232715

142 150
1420 359 506
checkcorrect 16 16 real score 1.9565814018249512 Hits@1 0.7977707006369427 Hits@3 0.9267515923566879 Hits@10 0.9856687898089171 MRR 0.8698718274092383 cur_rank 0 abs_cur_rank 0 total_num 627 809
142 83
1420 221 54
checkcorrect 10 10 real score 1.980538558959961 Hits@1 0.7980922098569158 Hits@3 0.9268680445151033 Hits@10 0.985691573926868 MRR 0.8700787084467435 cur_rank 0 abs_cur_rank 0 total_num 628 809
142 114
1420 153 516
checkcorrect 48 48 real score 1.9333224833011626 Hits@1 0.7984126984126985 Hits@3 0.926984126984127 Hits@10 0.9857142857142858 MRR 0.8702849327190503 cur_rank 0 abs_cur_rank 1 total_num 629 809
142 111
1420 145 193
checkcorrect 162 162 real score 1.991891998052597 Hits@1 0.7987321711568938 Hits@3 0.9270998415213946 Hits@10 0.9857369255150554 MRR 0.8704905033486556 cur_rank 0 abs_cur_rank 0 total_num 630 809
142 148
1420 434 487
checkcorrect 16 16 real score 1.9987688541412354 Hits@1 0.7990506329113924 Hits@3 0.9272151898734177 Hits@10 0.98575949

142 150
1420 502 666
checkcorrect 16 16 real score 1.9990258812904358 Hits@1 0.7961019490254873 Hits@3 0.9265367316341829 Hits@10 0.9850074962518741 MRR 0.8688476581032716 cur_rank 0 abs_cur_rank 0 total_num 666 809
0 0
1420 13 18
checkcorrect 54 54 real score 0.5330755114555359 Hits@1 0.7949101796407185 Hits@3 0.9251497005988024 Hits@10 0.9850299401197605 MRR 0.8678463891540152 cur_rank 4 abs_cur_rank 4 total_num 667 809
142 141
1420 259 123
checkcorrect 252 252 real score 1.9782699346542358 Hits@1 0.7952167414050823 Hits@3 0.9252615844544095 Hits@10 0.9850523168908819 MRR 0.8680439281836804 cur_rank 0 abs_cur_rank 0 total_num 668 809
142 150
1420 154 259
checkcorrect 70 70 real score 1.5317565023899078 Hits@1 0.7955223880597015 Hits@3 0.9253731343283582 Hits@10 0.9850746268656716 MRR 0.8682408775446002 cur_rank 0 abs_cur_rank 0 total_num 669 809
0 1
1420 176 42
checkcorrect 162 162 real score 0.9982903003692627 Hits@1 0.7958271236959762 Hits@3 0.9254843517138599 Hits@10 0.98509687034

142 82
1420 175 876
checkcorrect 16 16 real score 1.9991290807723998 Hits@1 0.7960339943342776 Hits@3 0.9291784702549575 Hits@10 0.9858356940509915 MRR 0.8697658941759427 cur_rank 0 abs_cur_rank 0 total_num 705 809
142 150
1420 1585 828
checkcorrect 48 48 real score 1.9509746491909028 Hits@1 0.7963224893917963 Hits@3 0.9292786421499293 Hits@10 0.9858557284299858 MRR 0.8699501008319881 cur_rank 0 abs_cur_rank 0 total_num 706 809
0 1
1420 471 62
checkcorrect 50 50 real score 0.9991054534912109 Hits@1 0.7966101694915254 Hits@3 0.9293785310734464 Hits@10 0.9858757062146892 MRR 0.8701337871302479 cur_rank 0 abs_cur_rank 0 total_num 707 809
0 1
1420 34 207
checkcorrect 28 28 real score 0.8391821980476379 Hits@1 0.7954866008462623 Hits@3 0.9280677009873061 Hits@10 0.9858956276445698 MRR 0.869188605484084 cur_rank 4 abs_cur_rank 4 total_num 708 809
142 136
1420 411 319
checkcorrect 16 16 real score 1.999057376384735 Hits@1 0.795774647887324 Hits@3 0.928169014084507 Hits@10 0.9859154929577465 M

142 140
1420 175 47
checkcorrect 212 212 real score 1.9303200662136077 Hits@1 0.7932885906040269 Hits@3 0.9261744966442953 Hits@10 0.9852348993288591 MRR 0.8668562177285315 cur_rank 0 abs_cur_rank 0 total_num 744 809
0 0
0 2 7
checkcorrect 2 2 real score 0.0 Hits@1 0.792225201072386 Hits@3 0.9262734584450402 Hits@10 0.985254691689008 MRR 0.8663644533616032 cur_rank 1 abs_cur_rank 1 total_num 745 809
142 150
1420 369 290
checkcorrect 16 16 real score 1.9986763179302216 Hits@1 0.7925033467202142 Hits@3 0.9263721552878179 Hits@10 0.9852744310575636 MRR 0.8665433496757108 cur_rank 0 abs_cur_rank 0 total_num 746 809
142 150
1420 1647 828
checkcorrect 196 196 real score 1.9109717726707458 Hits@1 0.7927807486631016 Hits@3 0.9264705882352942 Hits@10 0.9852941176470589 MRR 0.8667217676574277 cur_rank 0 abs_cur_rank 0 total_num 747 809
142 110
1420 498 160
checkcorrect 16 16 real score 1.9985837757587432 Hits@1 0.7930574098798397 Hits@3 0.9265687583444593 Hits@10 0.9853137516688919 MRR 0.8668997

142 42
1420 315 132
checkcorrect 16 16 real score 1.2981344059109687 Hits@1 0.7895408163265306 Hits@3 0.923469387755102 Hits@10 0.9834183673469388 MRR 0.864035664080521 cur_rank 0 abs_cur_rank 0 total_num 783 809
142 150
1420 1405 573
checkcorrect 192 192 real score 1.9701748311519622 Hits@1 0.7898089171974523 Hits@3 0.9235668789808917 Hits@10 0.9834394904458599 MRR 0.8642088670562147 cur_rank 0 abs_cur_rank 0 total_num 784 809
142 147
1420 608 99
checkcorrect 92 92 real score 1.9633618414402008 Hits@1 0.7900763358778626 Hits@3 0.9236641221374046 Hits@10 0.9834605597964376 MRR 0.8643816293118683 cur_rank 0 abs_cur_rank 0 total_num 785 809
0 1
1420 182 52
checkcorrect 100 100 real score 0.7798398733139038 Hits@1 0.7903430749682337 Hits@3 0.9237611181702668 Hits@10 0.9834815756035579 MRR 0.8645539525274822 cur_rank 0 abs_cur_rank 0 total_num 786 809
0 1
1420 12 41
checkcorrect 50 50 real score 0.9994046688079834 Hits@1 0.7906091370558376 Hits@3 0.9238578680203046 Hits@10 0.98350253807106