### Important

In [1]:
data_name = 'WN18RR_v3'
model_id = 'main_10'

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [5]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=20000):
                
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 50 recursions if there are too many (r,t)
                for i in temp_list[:50]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [6]:
train_path = '../data/' + data_name + '/train.txt'
valid_path = '../data/' + data_name + '/valid.txt'
test_path = '../data/' + data_name + '/test.txt'

In [7]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [8]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)

#define the dictionaries, which is shared by initail and inductive train/valid/test
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

In [9]:
#define the dictionaries and sets for load KG
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(valid_path, one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

In [10]:
#define the dictionaries and sets for load KG
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(test_path, one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)

#### Build the path-based siamese neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [11]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#reduce max
fst_reduce_max = tf.reduce_max(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_max(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_max(thd_lstm_out, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#add dropout on top of the concatenation from all channels
dropout = layers.Dropout(0.25)(path_concat)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-03-02 22:46:24.259600: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
#config the Adam optimizer 
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

#### Build the subgraph-based siamese neural network

In [13]:
#each input is an vector with number of relations to be dim:
#each dim represent the existence (1) or not (0) of an out-going relation from the entity
source_path_1 = keras.Input(shape=(None,), dtype="int32")
source_path_2 = keras.Input(shape=(None,), dtype="int32")
source_path_3 = keras.Input(shape=(None,), dtype="int32")

target_path_1 = keras.Input(shape=(None,), dtype="int32")
target_path_2 = keras.Input(shape=(None,), dtype="int32")
target_path_3 = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela_ = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var_ = layers.Embedding(len(relation2id)+1, 300)

# Obtain the source embeddings
source_embd_1 = in_embd_var_(source_path_1)
source_embd_2 = in_embd_var_(source_path_2)
source_embd_3 = in_embd_var_(source_path_3)

#Obtain the target embeddings
target_embd_1 = in_embd_var_(target_path_1)
target_embd_2 = in_embd_var_(target_path_2)
target_embd_3 = in_embd_var_(target_path_3)

# Embed each integer in a 300-dimensional vector as output
rela_embd_ = layers.Embedding(len(relation2id)+1, 300)(id_rela_)

#add 2 layer bi-directional LSTM network
lstm_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

###source lstm implimentation########
#first LSTM layer
source_mid_1 = lstm_1(source_embd_1)
source_mid_2 = lstm_1(source_embd_2)
source_mid_3 = lstm_1(source_embd_3)

#second LSTM layer
source_out_1 = lstm_2(source_mid_1)
source_out_2 = lstm_2(source_mid_2)
source_out_3 = lstm_2(source_mid_3)

#reduce max
source_max_1 = tf.reduce_max(source_out_1, axis=1)
source_max_2 = tf.reduce_max(source_out_2, axis=1)
source_max_3 = tf.reduce_max(source_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
source_concat = layers.concatenate([source_max_1, source_max_2, source_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
source_dropout = layers.Dropout(0.25)(source_concat)

###target lstm implimentation########
#first LSTM layer
target_mid_1 = lstm_1(target_embd_1)
target_mid_2 = lstm_1(target_embd_2)
target_mid_3 = lstm_1(target_embd_3)

#second LSTM layer
target_out_1 = lstm_2(target_mid_1)
target_out_2 = lstm_2(target_mid_2)
target_out_3 = lstm_2(target_mid_3)

#reduce max
target_max_1 = tf.reduce_max(target_out_1, axis=1)
target_max_2 = tf.reduce_max(target_out_2, axis=1)
target_max_3 = tf.reduce_max(target_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
target_concat = layers.concatenate([target_max_1, target_max_2, target_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
target_dropout = layers.Dropout(0.25)(target_concat)

#further concatenate source and target output embeddings
final_concat = layers.concatenate([source_dropout, target_dropout], axis=-1)

#multiply into output embd size by dense layer: (Batch, 300)
out_vect = layers.Dense(300, activation='tanh')(final_concat)

#remove the time dimension from the output embd since there is only one step
rela_out_embd_ = tf.reduce_sum(rela_embd_, axis=1)

# Normalize the vectors to have unit length
out_vect_norm = tf.math.l2_normalize(out_vect, axis=-1)
rela_out_embd_norm_ = tf.math.l2_normalize(rela_out_embd_, axis=-1)

# Calculate the dot product
dot_product_ = layers.Dot(axes=-1)([out_vect_norm, rela_out_embd_norm_])

#put together the model
model_2 = keras.Model([source_path_1, source_path_2, source_path_3,
                       target_path_1, target_path_2, target_path_3, id_rela_], dot_product_)

In [14]:
#config the Adam optimizer 
opt_ = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model_2.compile(loss='binary_crossentropy', optimizer=opt_, metrics=['binary_accuracy'])

### Build the big-batch for path-based model
We will build the big-batch for the path-based model training. That is, we will build three list to store three paths, respectively. 
* At each step, three different paths between two entities s and t are selected. Each path is append to one of the list. 
* If this step is for positive samples, the existing relation r will be selected between s and t. If there are more than one relation from s to t, we randomly choose one. Also, the label list will be appended [1,0].
* If this step is for negative samples, one relation that does not exist between s and t will be selected randomly and append to the relation list. Also, the label list will be appended [0,1].
* In practice, the positive step is always fallowed by a negative step. The same paths in the positive step will be used in the next negative step, while the relation is a negative one chosen in the above way.
* We do this until the length limit is reached.

**For relation prediciton, we will only need to train using (s,r,t) triple. (t,r-1,s) is not necessary and hence not included in training.**

In [15]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        #proceed only if at least three paths are between s and t
        for t in result:
            
            if len(s_t_r[(s,t)]) == 0:
                
                raise ValueError(s,t,id2entity[s], id2entity[t])

            #we are only interested in forward link in relation prediciton
            ini_r_list = list()
            
            #obtain initial relations between s and t
            for r in s_t_r[(s,t)]:
                if r % 2 == 0:#initial relation id is always an even number
                    ini_r_list.append(r)
            
            #if there exist more than three paths between s and t, 
            #and inital connection between s and t exists,
            #and not every r in the relation dictionary exists between s and t (although this is rare)
            #we then proceed
            if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):
                
                #obtain the list form of all the paths from s to t
                temp_path_list = list(result[t])

                temp_pair = random.sample(temp_path_list, 3)

                path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                #####positive#####################
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                #append relation
                r = random.choice(ini_r_list)
                x_r_list.append([r])
                y_list.append(1.)

                #####negative#####################
                #append the paths: note that we add the space holder id at the end
                #of the shorter path
                x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

### Build the big-batch for the subgraph-based network training

* At each step, we will select one triple (s,r,t) from the dataset. Then, reaching out paths of s and t is generated respectively according to their out-going relations.
* We will select three paths for each of source and target entity. Add them to the corresponding list.
* If this is a positive sample step, the id of relation r is appended to the relation list.
* If this is a negative sample step, the id of a random relation is appended to the relation lsit.
* Similarly, one negative sample step always follows one positive step. The one-hop vectors from the previous positve step is used again for the negative step.

In [16]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    data = list(data)    
    data = shuffle(data)
    
    for i_0 in range(len(data)):
        
        triple = data[i_0]
        
        s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs
        
        path_s, path_t = set(), set() #sets holding all the paths from s or t
        
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)
        
        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
        
        #see if both path_s and path_t have at least three paths
        if len(path_s) >= 3 and len(path_t) >= 3:
            
            #change to lists
            path_s, path_t = list(path_s), list(path_t)
            
            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]
            
            #####positive step###########
            #append the paths: note that we add the space holder id at the end of the shorter path
            x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
            
            x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            #append relation
            x_r_list.append([r])
            y_list.append(1.)
            
            #####negative step###########
            #append the paths: note that we add the space holder id at the end of the shorter path
            x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
            
            x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            #append relation
            neg_r_list = list(ini_r_id_set.difference({r}))
            r_ran = random.choice(neg_r_list)
            x_r_list.append([r_ran])
            y_list.append(0.)
        
        if i_0 % 500 == 0:
            print('generating big-batches for subgraph-based model', i_0, len(data))

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [17]:
model_name

'Model_main_10_WN18RR_v3'

In [18]:
one_hop_model_name

'One_hop_model_main_10_WN18RR_v3'

In [19]:
ids_name

'IDs_main_10_WN18RR_v3'

In [20]:
#first, we save the relation and ids
Dict = dict()

#save training data
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r

#save valid data
Dict['one_hop_valid'] = one_hop_valid
Dict['data_valid'] = data_valid
Dict['s_t_r_valid'] = s_t_r_valid

#save test data
Dict['one_hop_test'] = one_hop_test
Dict['data_test'] = data_test
Dict['s_t_r_test'] = s_t_r_test

#save shared dictionaries
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [21]:
###train the path-based model
lower_bd = 1
upper_bd = 10
num_epoch = 50
batch_size = 32
    
#we do model.fit inside a for loop, which is okay after setting epoch=epoch+1 and initial_epoch=epoch
for epoch in range(num_epoch):
    
    print('Current epoch:', epoch, num_epoch)
    
    #rebuild the big-batch every 4 epoch
    if epoch % 5 == 0:
        
        #define the training lists
        train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

        #define the validation lists
        valid_p_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, list(), list()

        #######################################
        ###build the big-batches###############      

        #fill in the training array list
        build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                              train_p_list, train_r_list, train_y_list,
                              relation2id, entity2id, id2relation, id2entity)

        #fill in the validation array list
        build_big_batches_path(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                              valid_p_list, valid_r_list, valid_y_list,
                              relation2id, entity2id, id2relation, id2entity)    

        #######################################
        ###do the training#####################
        #sometimes the validation dataset is so small so sparse, 
        #which cannot find three paths between any pair of s and t.
        #in such a case, we will divide the training big-batch into train and valid
        if len(valid_y_list) >= 100:
            #generate the input arrays
            x_train_1 = np.asarray(train_p_list['1'], dtype='int')
            x_train_2 = np.asarray(train_p_list['2'], dtype='int')
            x_train_3 = np.asarray(train_p_list['3'], dtype='int')
            x_train_r = np.asarray(train_r_list, dtype='int')
            y_train = np.asarray(train_y_list, dtype='int')

            #generate the validation arrays
            x_valid_1 = np.asarray(valid_p_list['1'], dtype='int')
            x_valid_2 = np.asarray(valid_p_list['2'], dtype='int')
            x_valid_3 = np.asarray(valid_p_list['3'], dtype='int')
            x_valid_r = np.asarray(valid_r_list, dtype='int')
            y_valid = np.asarray(valid_y_list, dtype='int')
        
        else:
            split = int(len(train_y_list)*0.8)
            #generate the input arrays
            x_train_1 = np.asarray(train_p_list['1'][:split], dtype='int')
            x_train_2 = np.asarray(train_p_list['2'][:split], dtype='int')
            x_train_3 = np.asarray(train_p_list['3'][:split], dtype='int')
            x_train_r = np.asarray(train_r_list[:split], dtype='int')
            y_train = np.asarray(train_y_list[:split], dtype='int')

            #generate the validation arrays
            x_valid_1 = np.asarray(train_p_list['1'][split:], dtype='int')
            x_valid_2 = np.asarray(train_p_list['2'][split:], dtype='int')
            x_valid_3 = np.asarray(train_p_list['3'][split:], dtype='int')
            x_valid_r = np.asarray(train_r_list[split:], dtype='int')
            y_valid = np.asarray(train_y_list[split:], dtype='int')

    model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train, 
              validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
              batch_size=batch_size, epochs=epoch+1, initial_epoch=epoch)

    #rebuild the big-batch every 4 epoch
    if epoch % 5 == 4:

        del(x_train_1, x_train_2, x_train_3, x_train_r, y_train)
        del(x_valid_1, x_valid_2, x_valid_3, x_valid_r, y_valid)
        del(train_p_list, train_r_list, train_y_list)
        del(valid_p_list, valid_r_list, valid_y_list)    
    
# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

Current epoch: 0 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path-based model 1600 12078
generating big-batches for path-based model 1700 12078
generating big-batches for path-based model 1800 12078

Current epoch: 1 50
Epoch 2/2
Current epoch: 2 50
Epoch 3/3
Current epoch: 3 50
Epoch 4/4
Current epoch: 4 50
Epoch 5/5
Current epoch: 5 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path-based model 

Current epoch: 6 50
Epoch 7/7
Current epoch: 7 50
Epoch 8/8
Current epoch: 8 50
Epoch 9/9
Current epoch: 9 50
Epoch 10/10
Current epoch: 10 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path-based mod

Current epoch: 11 50
Epoch 12/12
Current epoch: 12 50
Epoch 13/13
Current epoch: 13 50
Epoch 14/14
Current epoch: 14 50
Epoch 15/15
Current epoch: 15 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path

Epoch 16/16
Current epoch: 16 50
Epoch 17/17
Current epoch: 17 50
Epoch 18/18
Current epoch: 18 50
Epoch 19/19
Current epoch: 19 50
Epoch 20/20
Current epoch: 20 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batc

Current epoch: 21 50
Epoch 22/22
Current epoch: 22 50
Epoch 23/23
Current epoch: 23 50
Epoch 24/24
Current epoch: 24 50
Epoch 25/25
Current epoch: 25 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path

Epoch 26/26
Current epoch: 26 50
Epoch 27/27
Current epoch: 27 50
Epoch 28/28
Current epoch: 28 50
Epoch 29/29
Current epoch: 29 50
Epoch 30/30
Current epoch: 30 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batc

Current epoch: 31 50
Epoch 32/32
Current epoch: 32 50
Epoch 33/33
Current epoch: 33 50
Epoch 34/34
Current epoch: 34 50
Epoch 35/35
Current epoch: 35 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path

Current epoch: 36 50
Epoch 37/37
Current epoch: 37 50
Epoch 38/38
Current epoch: 38 50
Epoch 39/39
Current epoch: 39 50
Epoch 40/40
Current epoch: 40 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path

Current epoch: 41 50
Epoch 42/42
Current epoch: 42 50
Epoch 43/43
Current epoch: 43 50
Epoch 44/44
Current epoch: 44 50
Epoch 45/45
Current epoch: 45 50
generating big-batches for path-based model 100 12078
generating big-batches for path-based model 200 12078
generating big-batches for path-based model 300 12078
generating big-batches for path-based model 400 12078
generating big-batches for path-based model 500 12078
generating big-batches for path-based model 600 12078
generating big-batches for path-based model 700 12078
generating big-batches for path-based model 800 12078
generating big-batches for path-based model 900 12078
generating big-batches for path-based model 1000 12078
generating big-batches for path-based model 1100 12078
generating big-batches for path-based model 1200 12078
generating big-batches for path-based model 1300 12078
generating big-batches for path-based model 1400 12078
generating big-batches for path-based model 1500 12078
generating big-batches for path

Current epoch: 46 50
Epoch 47/47
Current epoch: 47 50
Epoch 48/48
Current epoch: 48 50
Epoch 49/49
Current epoch: 49 50
Epoch 50/50
Save model


In [22]:
###train the subgraph-based model
lower_bd = 1
upper_bd = 3
num_epoch = 50
batch_size = 32
    
#we do model.fit inside a for loop, which is okay after setting epoch=epoch+1 and initial_epoch=epoch
for epoch in range(num_epoch):
    
    print('Current epoch:', epoch, num_epoch)

    #rebuild the big-batch every 4 epoch
    if epoch % 5 == 0:
        
        #define the training lists
        train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

        #define the validation lists
        valid_s_list, valid_t_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

        #######################################
        ###build the big-batches###############      

        #fill in the training array list
        build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                              train_s_list, train_t_list, train_r_list, train_y_list,
                              relation2id, entity2id, id2relation, id2entity)

        #fill in the validation array list
        build_big_batches_subgraph(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                              valid_s_list, valid_t_list, valid_r_list, valid_y_list,
                              relation2id, entity2id, id2relation, id2entity)    

        #######################################
        ###do the training#####################
        #sometimes the validation dataset is so small so sparse, 
        #which cannot find three paths between any pair of s and t.
        #in such a case, we will divide the training big-batch into train and valid
        if len(valid_y_list) >= 100:
            #generate the input arrays
            x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
            x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
            x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

            x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
            x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
            x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

            x_train_r = np.asarray(train_r_list, dtype='int')
            y_train = np.asarray(train_y_list, dtype='int')

            #generate the validation arrays
            x_valid_s_1 = np.asarray(valid_s_list['1'], dtype='int')
            x_valid_s_2 = np.asarray(valid_s_list['2'], dtype='int')
            x_valid_s_3 = np.asarray(valid_s_list['3'], dtype='int')

            x_valid_t_1 = np.asarray(valid_t_list['1'], dtype='int')
            x_valid_t_2 = np.asarray(valid_t_list['2'], dtype='int')
            x_valid_t_3 = np.asarray(valid_t_list['3'], dtype='int')

            x_valid_r = np.asarray(valid_r_list, dtype='int')
            y_valid = np.asarray(valid_y_list, dtype='int')
        
        else:
            split = int(len(train_y_list)*0.8)
            #generate the input arrays
            x_train_s_1 = np.asarray(train_s_list['1'][:split], dtype='int')
            x_train_s_2 = np.asarray(train_s_list['2'][:split], dtype='int')
            x_train_s_3 = np.asarray(train_s_list['3'][:split], dtype='int')

            x_train_t_1 = np.asarray(train_t_list['1'][:split], dtype='int')
            x_train_t_2 = np.asarray(train_t_list['2'][:split], dtype='int')
            x_train_t_3 = np.asarray(train_t_list['3'][:split], dtype='int')

            x_train_r = np.asarray(train_r_list[:split], dtype='int')
            y_train = np.asarray(train_y_list[:split], dtype='int')

            #generate the validation arrays
            x_valid_s_1 = np.asarray(train_s_list['1'][split:], dtype='int')
            x_valid_s_2 = np.asarray(train_s_list['2'][split:], dtype='int')
            x_valid_s_3 = np.asarray(train_s_list['3'][split:], dtype='int')

            x_valid_t_1 = np.asarray(train_t_list['1'][split:], dtype='int')
            x_valid_t_2 = np.asarray(train_t_list['2'][split:], dtype='int')
            x_valid_t_3 = np.asarray(train_t_list['3'][split:], dtype='int')

            x_valid_r = np.asarray(train_r_list[split:], dtype='int')
            y_valid = np.asarray(train_y_list[split:], dtype='int')

    model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train, 
              validation_data=([x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r], y_valid),
              batch_size=batch_size, epochs=epoch+1, initial_epoch=epoch)

    #rebuild the big-batch every 4 epoch
    if epoch % 5 == 4:
        
        del(x_train_s_1, x_train_s_2, x_train_s_3, x_train_t_1, x_train_t_2, x_train_t_3, x_train_r, y_train)
        del(x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r, y_valid)
        del(train_s_list, train_t_list, train_r_list, train_y_list)
        del(valid_s_list, valid_t_list, valid_r_list, valid_y_list)

# Save model and weights
one_hop_add_h5 = one_hop_model_name + '.h5'
one_hop_save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(one_hop_save_dir):
    os.makedirs(one_hop_save_dir)
one_hop_model_path = os.path.join(one_hop_save_dir, one_hop_add_h5)
model_2.save(one_hop_model_path)
print('Save model')
del(model_2)

Current epoch: 0 50
generating big-batches for subgraph-based model 0 25901
generating big-batches for subgraph-based model 500 25901
generating big-batches for subgraph-based model 1000 25901
generating big-batches for subgraph-based model 1500 25901
generating big-batches for subgraph-based model 2000 25901
generating big-batches for subgraph-based model 2500 25901
generating big-batches for subgraph-based model 3000 25901
generating big-batches for subgraph-based model 3500 25901
generating big-batches for subgraph-based model 4000 25901
generating big-batches for subgraph-based model 4500 25901
generating big-batches for subgraph-based model 5000 25901
generating big-batches for subgraph-based model 5500 25901
generating big-batches for subgraph-based model 6000 25901
generating big-batches for subgraph-based model 6500 25901
generating big-batches for subgraph-based model 7000 25901
generating big-batches for subgraph-based model 7500 25901
generating big-batches for subgraph-base

Current epoch: 7 50
Epoch 8/8
Current epoch: 8 50
Epoch 9/9
Current epoch: 9 50
Epoch 10/10
Current epoch: 10 50
generating big-batches for subgraph-based model 0 25901
generating big-batches for subgraph-based model 500 25901
generating big-batches for subgraph-based model 1000 25901
generating big-batches for subgraph-based model 1500 25901
generating big-batches for subgraph-based model 2000 25901
generating big-batches for subgraph-based model 2500 25901
generating big-batches for subgraph-based model 3000 25901
generating big-batches for subgraph-based model 3500 25901
generating big-batches for subgraph-based model 4000 25901
generating big-batches for subgraph-based model 4500 25901
generating big-batches for subgraph-based model 5000 25901
generating big-batches for subgraph-based model 5500 25901
generating big-batches for subgraph-based model 6000 25901
generating big-batches for subgraph-based model 6500 25901
generating big-batches for subgraph-based model 7000 25901
genera

generating big-batches for subgraph-based model 1500 3097
generating big-batches for subgraph-based model 2000 3097
generating big-batches for subgraph-based model 2500 3097
generating big-batches for subgraph-based model 3000 3097
Epoch 16/16
Current epoch: 16 50
Epoch 17/17
Current epoch: 17 50
Epoch 18/18
Current epoch: 18 50
Epoch 19/19
Current epoch: 19 50
Epoch 20/20
Current epoch: 20 50
generating big-batches for subgraph-based model 0 25901
generating big-batches for subgraph-based model 500 25901
generating big-batches for subgraph-based model 1000 25901
generating big-batches for subgraph-based model 1500 25901
generating big-batches for subgraph-based model 2000 25901
generating big-batches for subgraph-based model 2500 25901
generating big-batches for subgraph-based model 3000 25901
generating big-batches for subgraph-based model 3500 25901
generating big-batches for subgraph-based model 4000 25901
generating big-batches for subgraph-based model 4500 25901
generating big-ba

generating big-batches for subgraph-based model 22500 25901
generating big-batches for subgraph-based model 23000 25901
generating big-batches for subgraph-based model 23500 25901
generating big-batches for subgraph-based model 24000 25901
generating big-batches for subgraph-based model 24500 25901
generating big-batches for subgraph-based model 25000 25901
generating big-batches for subgraph-based model 25500 25901
generating big-batches for subgraph-based model 0 3097
generating big-batches for subgraph-based model 500 3097
generating big-batches for subgraph-based model 1000 3097
generating big-batches for subgraph-based model 1500 3097
generating big-batches for subgraph-based model 2000 3097
generating big-batches for subgraph-based model 2500 3097
generating big-batches for subgraph-based model 3000 3097
Epoch 26/26
Current epoch: 26 50
Epoch 27/27
Current epoch: 27 50
Epoch 28/28
Current epoch: 28 50
Epoch 29/29
Current epoch: 29 50
Epoch 30/30
Current epoch: 30 50
generating bi

generating big-batches for subgraph-based model 17500 25901
generating big-batches for subgraph-based model 18000 25901
generating big-batches for subgraph-based model 18500 25901
generating big-batches for subgraph-based model 19000 25901
generating big-batches for subgraph-based model 19500 25901
generating big-batches for subgraph-based model 20000 25901
generating big-batches for subgraph-based model 20500 25901
generating big-batches for subgraph-based model 21000 25901
generating big-batches for subgraph-based model 21500 25901
generating big-batches for subgraph-based model 22000 25901
generating big-batches for subgraph-based model 22500 25901
generating big-batches for subgraph-based model 23000 25901
generating big-batches for subgraph-based model 23500 25901
generating big-batches for subgraph-based model 24000 25901
generating big-batches for subgraph-based model 24500 25901
generating big-batches for subgraph-based model 25000 25901
generating big-batches for subgraph-base

generating big-batches for subgraph-based model 12500 25901
generating big-batches for subgraph-based model 13000 25901
generating big-batches for subgraph-based model 13500 25901
generating big-batches for subgraph-based model 14000 25901
generating big-batches for subgraph-based model 14500 25901
generating big-batches for subgraph-based model 15000 25901
generating big-batches for subgraph-based model 15500 25901
generating big-batches for subgraph-based model 16000 25901
generating big-batches for subgraph-based model 16500 25901
generating big-batches for subgraph-based model 17000 25901
generating big-batches for subgraph-based model 17500 25901
generating big-batches for subgraph-based model 18000 25901
generating big-batches for subgraph-based model 18500 25901
generating big-batches for subgraph-based model 19000 25901
generating big-batches for subgraph-based model 19500 25901
generating big-batches for subgraph-based model 20000 25901
generating big-batches for subgraph-base

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
data_name = 'WN18RR_v3'
model_id = 'main_10'

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
ids_name

'IDs_main_10_fb237_v3'

In [4]:
one_hop_model_name

'One_hop_model_main_10_fb237_v3'

In [5]:
model_name

'Model_main_10_fb237_v3'

In [6]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [7]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [8]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=20000):
                
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 50 recursions if there are too many (r,t)
                for i in temp_list[:50]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [9]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [10]:
#load ids and relation/entity dicts
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)
    
#save training data
one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']

#save valid data
one_hop_valid = Dict['one_hop_valid']
data_valid = Dict['data_valid']
s_t_r_valid = Dict['s_t_r_valid']

#save test data
one_hop_test = Dict['one_hop_test']
data_test = Dict['data_test']
s_t_r_test = Dict['s_t_r_test']

#save shared dictionaries
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts before adding new entities
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)
num_r

430

In [11]:
ids_name

'IDs_main_10_fb237_v3'

In [12]:
model_name

'Model_main_10_fb237_v3'

In [13]:
#load the model
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-02-28 20:48:43.663140: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
#load the one-hop neighbor model
model_2 = keras.models.load_model('../weight_bin/' + one_hop_model_name + '.h5')

In [15]:
ind_train_path = '../data/' + data_name + '_ind/train.txt'
ind_valid_path = '../data/' + data_name + '_ind/valid.txt'
ind_test_path = '../data/' + data_name + '_ind/test.txt'

In [16]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [17]:
print(size_0, size_1, len(data_ind))

3668 6169 7406


In [18]:
#load the test dataset
one_hop_ind_test = dict() 
data_ind_test = set()
s_t_r_ind_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_ind_test, data_ind_test, s_t_r_ind_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [19]:
print(size_0, size_1, len(data_ind_test))

6169 6169 865


In [20]:
#load the validation for existing triple removal when ranking
one_hop_ind_valid = dict() 
data_ind_valid = set()
s_t_r_ind_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_ind_valid, data_ind_valid, s_t_r_ind_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [21]:
print(size_0, size_1, len(data_ind_valid))

6169 6169 866


In [22]:
print(len(entity2id), len(entity2id_ini))

6169 3668


In [23]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [24]:
overlapping = 0

for ele in data_ind_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [25]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [28]:
#the function to do path-based relation scoring
def path_based_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(3):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
                
        del(result, length_dict)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    
    count = 0
    
    if len(path_holder) >= 3:
    
        #iterate over path_1
        while count < 10:

            temp_pair = random.sample(path_holder, 3)

            path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

            list_1 = list()
            list_2 = list()
            list_3 = list()
            list_r = list()

            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                
                #only care about initial relations
                if i % 2 == 0:

                    list_1.append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    list_2.append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    list_3.append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))
                    list_r.append([i])
            
            #change to arrays
            input_1 = np.array(list_1)
            input_2 = np.array(list_2)
            input_3 = np.array(list_3)
            input_r = np.array(list_r)

            pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])

            count += 1
    
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [29]:
#subgraph based relation scoring
def subgraph_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    
    #lists holding the input to the network
    list_s_1 = list()
    list_s_2 = list()
    list_s_3 = list()
    list_t_1 = list()
    list_t_2 = list()
    list_t_3 = list()
    list_r = list()
    
    path_s, path_t = set(), set() #sets holding all the paths from s or t
    
    for iteration in range(3):
    
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
                
        del(result_s, length_dict_s, result_t, length_dict_t)
    
    #final output: the score dict
    score_dict = defaultdict(float)
    
    #see if both path_s and path_t have at least three paths
    if len(path_s) >= 3 and len(path_t) >= 3:

        #change to lists
        path_s, path_t = list(path_s), list(path_t)
        
        count = 0
        while count < 10:

            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]
            
            #add all forward (initial relation)
            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                    
                if i % 2 == 0:

                    #append the paths: note that we add the space holder id at the end of the shorter path
                    list_s_1.append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                    list_s_2.append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                    list_s_3.append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
                    
                    list_t_1.append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                    list_t_2.append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                    list_t_3.append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))
                    
                    list_r.append([i])
                
            #change to arrays
            input_s_1 = np.array(list_s_1)
            input_s_2 = np.array(list_s_2)
            input_s_3 = np.array(list_s_3)
            input_t_1 = np.array(list_t_1)
            input_t_2 = np.array(list_t_2)
            input_t_3 = np.array(list_t_3)
            input_r = np.array(list_r)
            
            pred = model_2.predict([input_s_1, input_s_2, input_s_3,
                                    input_t_1, input_t_2, input_t_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])

            count += 1
            
    print(len(score_dict), len(path_s), len(path_t))
        
    return(score_dict)

#### Not fine tuned 

In [30]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_ind_test), min(len(data_ind_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #first run the path-based scoring
    score_dict = path_based_relation_scoring(s_true, t_true, 1, 10, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring when not enough paths
    if len(score_dict) == 0:
        del(score_dict)
        score_dict = subgraph_relation_scoring(s_true, t_true, 1, 3, one_hop_ind, id2relation, model_2)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

215 114
checkcorrect 80 80 real score 9.974733293056488 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 500
215 9
checkcorrect 104 104 real score 9.734038293361664 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 500
215 150
checkcorrect 64 64 real score 9.966531872749329 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 2 500
215 150
checkcorrect 64 64 real score 9.967007339000702 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 3 500
215 149
checkcorrect 2 2 real score 9.98240315914154 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 4 500
215 150
checkcorrect 176 176 real score 9.95535123348236 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 5 500
215 150
checkcorrect 102 102 real score 9.955845892429352 Hits@1 0.8571428571428571 Hits@3 1.0 Hits@10 1.0 MRR 0.9285714285714286 cur_rank 1 abs_cur_rank 1 total_

215 150
checkcorrect 22 22 real score 9.986281216144562 Hits@1 0.8367346938775511 Hits@3 0.9591836734693877 Hits@10 1.0 MRR 0.9013605442176872 cur_rank 0 abs_cur_rank 0 total_num 48 500
215 150
checkcorrect 22 22 real score 9.986466348171234 Hits@1 0.84 Hits@3 0.96 Hits@10 1.0 MRR 0.9033333333333334 cur_rank 0 abs_cur_rank 0 total_num 49 500
215 6
checkcorrect 62 62 real score 9.747672080993652 Hits@1 0.8235294117647058 Hits@3 0.9607843137254902 Hits@10 1.0 MRR 0.8954248366013072 cur_rank 1 abs_cur_rank 1 total_num 50 500
215 17
checkcorrect 144 144 real score 9.719329953193665 Hits@1 0.8269230769230769 Hits@3 0.9615384615384616 Hits@10 1.0 MRR 0.8974358974358976 cur_rank 0 abs_cur_rank 0 total_num 51 500
215 150
checkcorrect 14 14 real score 9.99780684709549 Hits@1 0.8301886792452831 Hits@3 0.9622641509433962 Hits@10 1.0 MRR 0.8993710691823901 cur_rank 0 abs_cur_rank 0 total_num 52 500
215 134
checkcorrect 144 144 real score 9.917206883430481 Hits@1 0.8333333333333334 Hits@3 0.9629629

215 94
checkcorrect 62 62 real score 8.182882726192474 Hits@1 0.8260869565217391 Hits@3 0.9456521739130435 Hits@10 0.9891304347826086 MRR 0.889130434782609 cur_rank 0 abs_cur_rank 1 total_num 91 500
215 150
checkcorrect 10 10 real score 9.97201156616211 Hits@1 0.8279569892473119 Hits@3 0.946236559139785 Hits@10 0.989247311827957 MRR 0.8903225806451616 cur_rank 0 abs_cur_rank 0 total_num 92 500
215 145
checkcorrect 64 64 real score 9.966530561447144 Hits@1 0.8297872340425532 Hits@3 0.9468085106382979 Hits@10 0.9893617021276596 MRR 0.8914893617021279 cur_rank 0 abs_cur_rank 0 total_num 93 500
215 70
checkcorrect 2 2 real score 9.982407331466675 Hits@1 0.8315789473684211 Hits@3 0.9473684210526315 Hits@10 0.9894736842105263 MRR 0.8926315789473687 cur_rank 0 abs_cur_rank 0 total_num 94 500
215 146
checkcorrect 206 206 real score 8.210960507392883 Hits@1 0.8333333333333334 Hits@3 0.9479166666666666 Hits@10 0.9895833333333334 MRR 0.8937500000000003 cur_rank 0 abs_cur_rank 0 total_num 95 500
2

215 146
checkcorrect 150 150 real score 8.333629846572876 Hits@1 0.835820895522388 Hits@3 0.9402985074626866 Hits@10 0.9925373134328358 MRR 0.8906716417910449 cur_rank 2 abs_cur_rank 2 total_num 133 500
215 150
checkcorrect 22 22 real score 9.986063957214355 Hits@1 0.837037037037037 Hits@3 0.9407407407407408 Hits@10 0.9925925925925926 MRR 0.8914814814814815 cur_rank 0 abs_cur_rank 0 total_num 134 500
215 109
checkcorrect 112 112 real score 9.996204614639282 Hits@1 0.8382352941176471 Hits@3 0.9411764705882353 Hits@10 0.9926470588235294 MRR 0.8922794117647059 cur_rank 0 abs_cur_rank 0 total_num 135 500
215 150
checkcorrect 40 40 real score 9.994498789310455 Hits@1 0.8394160583941606 Hits@3 0.9416058394160584 Hits@10 0.9927007299270073 MRR 0.893065693430657 cur_rank 0 abs_cur_rank 0 total_num 136 500
215 145
checkcorrect 144 144 real score 9.016707181930542 Hits@1 0.8405797101449275 Hits@3 0.9420289855072463 Hits@10 0.9927536231884058 MRR 0.893840579710145 cur_rank 0 abs_cur_rank 0 total_

215 146
checkcorrect 102 102 real score 9.922627866268158 Hits@1 0.84 Hits@3 0.9428571428571428 Hits@10 0.9885714285714285 MRR 0.8952142857142859 cur_rank 0 abs_cur_rank 0 total_num 174 500
0 2
2150 49 27
checkcorrect 190 190 real score 4.432106614112854 Hits@1 0.8352272727272727 Hits@3 0.9431818181818182 Hits@10 0.9886363636363636 MRR 0.8929687500000001 cur_rank 1 abs_cur_rank 1 total_num 175 500
215 149
checkcorrect 52 52 real score 9.94662219285965 Hits@1 0.8305084745762712 Hits@3 0.943502824858757 Hits@10 0.9887005649717514 MRR 0.8907485875706216 cur_rank 1 abs_cur_rank 1 total_num 176 500
215 33
checkcorrect 26 26 real score 9.97364330291748 Hits@1 0.8314606741573034 Hits@3 0.9438202247191011 Hits@10 0.9887640449438202 MRR 0.891362359550562 cur_rank 0 abs_cur_rank 0 total_num 177 500
215 8
checkcorrect 94 94 real score 9.612185180187225 Hits@1 0.8324022346368715 Hits@3 0.9441340782122905 Hits@10 0.9888268156424581 MRR 0.8919692737430169 cur_rank 0 abs_cur_rank 0 total_num 178 500


215 96
checkcorrect 22 22 real score 9.987213015556335 Hits@1 0.8379629629629629 Hits@3 0.9398148148148148 Hits@10 0.9861111111111112 MRR 0.8934349279835393 cur_rank 0 abs_cur_rank 0 total_num 215 500
215 97
checkcorrect 346 346 real score 6.360789597034454 Hits@1 0.8341013824884793 Hits@3 0.9354838709677419 Hits@10 0.9861751152073732 MRR 0.8902393753200206 cur_rank 4 abs_cur_rank 5 total_num 216 500
215 58
checkcorrect 48 48 real score 9.919833540916443 Hits@1 0.8302752293577982 Hits@3 0.9357798165137615 Hits@10 0.9862385321100917 MRR 0.8884492864424058 cur_rank 1 abs_cur_rank 1 total_num 217 500
215 149
checkcorrect 68 68 real score 9.954762637615204 Hits@1 0.8310502283105022 Hits@3 0.9360730593607306 Hits@10 0.9863013698630136 MRR 0.8889586504312533 cur_rank 0 abs_cur_rank 0 total_num 218 500
215 150
checkcorrect 102 102 real score 9.934009790420532 Hits@1 0.8318181818181818 Hits@3 0.9363636363636364 Hits@10 0.9863636363636363 MRR 0.889463383838384 cur_rank 0 abs_cur_rank 0 total_nu

215 149
checkcorrect 362 362 real score 9.591470420360565 Hits@1 0.8171206225680934 Hits@3 0.9299610894941635 Hits@10 0.9844357976653697 MRR 0.8796620035822371 cur_rank 0 abs_cur_rank 0 total_num 256 500
215 129
checkcorrect 112 112 real score 9.984529495239258 Hits@1 0.8178294573643411 Hits@3 0.9302325581395349 Hits@10 0.9844961240310077 MRR 0.8801284299249417 cur_rank 0 abs_cur_rank 0 total_num 257 500
215 150
checkcorrect 28 28 real score 9.996497094631195 Hits@1 0.8185328185328186 Hits@3 0.9305019305019305 Hits@10 0.9845559845559846 MRR 0.880591254519826 cur_rank 0 abs_cur_rank 0 total_num 258 500
215 131
checkcorrect 0 0 real score 9.930523455142975 Hits@1 0.8192307692307692 Hits@3 0.9307692307692308 Hits@10 0.9846153846153847 MRR 0.881050518925519 cur_rank 0 abs_cur_rank 0 total_num 259 500
215 147
checkcorrect 144 144 real score 9.77740228176117 Hits@1 0.8199233716475096 Hits@3 0.9310344827586207 Hits@10 0.9846743295019157 MRR 0.8815062640637354 cur_rank 0 abs_cur_rank 0 total_n

215 150
checkcorrect 40 40 real score 9.994495689868927 Hits@1 0.8120805369127517 Hits@3 0.9194630872483222 Hits@10 0.9798657718120806 MRR 0.873704845676322 cur_rank 0 abs_cur_rank 0 total_num 297 500
215 41
checkcorrect 340 340 real score 8.483961939811707 Hits@1 0.8127090301003345 Hits@3 0.919732441471572 Hits@10 0.979933110367893 MRR 0.8741272374968025 cur_rank 0 abs_cur_rank 0 total_num 298 500
215 131
checkcorrect 2 2 real score 9.98238354921341 Hits@1 0.8133333333333334 Hits@3 0.92 Hits@10 0.98 MRR 0.8745468133718132 cur_rank 0 abs_cur_rank 0 total_num 299 500
215 150
checkcorrect 2 2 real score 9.982386112213135 Hits@1 0.813953488372093 Hits@3 0.920265780730897 Hits@10 0.9800664451827242 MRR 0.8749636013672557 cur_rank 0 abs_cur_rank 0 total_num 300 500
215 150
checkcorrect 28 28 real score 9.996475577354431 Hits@1 0.8145695364238411 Hits@3 0.9205298013245033 Hits@10 0.9801324503311258 MRR 0.8753776291772979 cur_rank 0 abs_cur_rank 0 total_num 301 500
215 82
checkcorrect 0 0 rea

215 150
checkcorrect 102 102 real score 9.95671546459198 Hits@1 0.8053097345132744 Hits@3 0.9144542772861357 Hits@10 0.976401179941003 MRR 0.8693361448449942 cur_rank 0 abs_cur_rank 0 total_num 338 500
215 120
checkcorrect 86 86 real score 9.944234371185303 Hits@1 0.8029411764705883 Hits@3 0.9147058823529411 Hits@10 0.9764705882352941 MRR 0.8682498620660384 cur_rank 1 abs_cur_rank 1 total_num 339 500
215 42
checkcorrect 30 30 real score 9.982017755508423 Hits@1 0.8035190615835777 Hits@3 0.9149560117302052 Hits@10 0.9765395894428153 MRR 0.8686362261069004 cur_rank 0 abs_cur_rank 0 total_num 340 500
215 84
checkcorrect 2 2 real score 9.982370853424072 Hits@1 0.804093567251462 Hits@3 0.9152046783625731 Hits@10 0.9766081871345029 MRR 0.869020330708927 cur_rank 0 abs_cur_rank 0 total_num 341 500
215 150
checkcorrect 122 122 real score 9.960891664028168 Hits@1 0.8046647230320699 Hits@3 0.9154518950437318 Hits@10 0.9766763848396501 MRR 0.8694021956339738 cur_rank 0 abs_cur_rank 0 total_num 34

215 150
checkcorrect 362 362 real score 9.732589721679688 Hits@1 0.8052631578947368 Hits@3 0.9157894736842105 Hits@10 0.9789473684210527 MRR 0.8705832099187361 cur_rank 0 abs_cur_rank 0 total_num 379 500
215 122
checkcorrect 142 142 real score 7.655006408691406 Hits@1 0.8031496062992126 Hits@3 0.916010498687664 Hits@10 0.979002624671916 MRR 0.8691731052557822 cur_rank 2 abs_cur_rank 2 total_num 380 500
215 103
checkcorrect 102 102 real score 9.988345324993134 Hits@1 0.8036649214659686 Hits@3 0.9162303664921466 Hits@10 0.9790575916230366 MRR 0.8695155840378351 cur_rank 0 abs_cur_rank 0 total_num 381 500
215 150
checkcorrect 142 142 real score 8.141679584980011 Hits@1 0.804177545691906 Hits@3 0.9164490861618799 Hits@10 0.97911227154047 MRR 0.8698562744189374 cur_rank 0 abs_cur_rank 0 total_num 382 500
215 150
checkcorrect 64 64 real score 9.966722190380096 Hits@1 0.8046875 Hits@3 0.9166666666666666 Hits@10 0.9791666666666666 MRR 0.8701951903709714 cur_rank 0 abs_cur_rank 0 total_num 383 

215 75
checkcorrect 62 62 real score 9.767525494098663 Hits@1 0.7985781990521327 Hits@3 0.9170616113744076 Hits@10 0.9786729857819905 MRR 0.8665425724253843 cur_rank 1 abs_cur_rank 1 total_num 421 500
215 150
checkcorrect 68 68 real score 9.954949080944061 Hits@1 0.7990543735224587 Hits@3 0.91725768321513 Hits@10 0.9787234042553191 MRR 0.8668580746182322 cur_rank 0 abs_cur_rank 0 total_num 422 500
215 120
checkcorrect 200 200 real score 8.512043356895447 Hits@1 0.7971698113207547 Hits@3 0.9174528301886793 Hits@10 0.9787735849056604 MRR 0.8659928433101702 cur_rank 1 abs_cur_rank 2 total_num 423 500
215 150
checkcorrect 194 194 real score 9.40682327747345 Hits@1 0.7952941176470588 Hits@3 0.9176470588235294 Hits@10 0.9788235294117648 MRR 0.8651316836788522 cur_rank 1 abs_cur_rank 3 total_num 424 500
215 150
checkcorrect 142 142 real score 8.841113150119781 Hits@1 0.795774647887324 Hits@3 0.9178403755868545 Hits@10 0.9788732394366197 MRR 0.8654482759706859 cur_rank 0 abs_cur_rank 0 total_n

215 147
checkcorrect 26 26 real score 9.974011242389679 Hits@1 0.7926565874730022 Hits@3 0.9114470842332614 Hits@10 0.9762419006479481 MRR 0.861664506302836 cur_rank 0 abs_cur_rank 0 total_num 462 500
215 150
checkcorrect 340 340 real score 8.607909739017487 Hits@1 0.790948275862069 Hits@3 0.9116379310344828 Hits@10 0.9762931034482759 MRR 0.8605258615335052 cur_rank 2 abs_cur_rank 2 total_num 463 500
215 147
checkcorrect 144 144 real score 9.919025242328644 Hits@1 0.789247311827957 Hits@3 0.9118279569892473 Hits@10 0.9763440860215054 MRR 0.8597505371000997 cur_rank 1 abs_cur_rank 1 total_num 464 500
215 9
checkcorrect 94 94 real score 9.865845739841461 Hits@1 0.7896995708154506 Hits@3 0.9120171673819742 Hits@10 0.9763948497854077 MRR 0.8600515016127604 cur_rank 0 abs_cur_rank 0 total_num 465 500
215 150
checkcorrect 86 86 real score 9.945280611515045 Hits@1 0.7880085653104925 Hits@3 0.9122055674518201 Hits@10 0.9764453961456103 MRR 0.8592805133866089 cur_rank 1 abs_cur_rank 1 total_num

#### Fine tuned

In [31]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        #proceed only if at least three paths are between s and t
        for t in result:
            
            if len(s_t_r[(s,t)]) == 0:
                
                raise ValueError(s,t,id2entity[s], id2entity[t])

            #we are only interested in forward link in relation prediciton
            ini_r_list = list()
            
            #obtain initial relations between s and t
            for r in s_t_r[(s,t)]:
                if r % 2 == 0:#initial relation id is always an even number
                    ini_r_list.append(r)
            
            #if there exist more than three paths between s and t, 
            #and inital connection between s and t exists,
            #and not every r in the relation dictionary exists between s and t (although this is rare)
            #we then proceed
            if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):
                
                #obtain the list form of all the paths from s to t
                temp_path_list = list(result[t])

                temp_pair = random.sample(temp_path_list, 3)

                path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                #####positive#####################
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                #append relation
                r = random.choice(ini_r_list)
                x_r_list.append([r])
                y_list.append(1.)

                #####negative#####################
                #append the paths: note that we add the space holder id at the end
                #of the shorter path
                x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

In [32]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    data = list(data)    
    data = shuffle(data)
    
    for i_0 in range(len(data)):
        
        triple = data[i_0]
        
        s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs
        
        path_s, path_t = set(), set() #sets holding all the paths from s or t
        
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)
        
        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
        
        #see if both path_s and path_t have at least three paths
        if len(path_s) >= 3 and len(path_t) >= 3:
            
            #change to lists
            path_s, path_t = list(path_s), list(path_t)
            
            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]
            
            #####positive step###########
            #append the paths: note that we add the space holder id at the end of the shorter path
            x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
            
            x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            #append relation
            x_r_list.append([r])
            y_list.append(1.)
            
            #####negative step###########
            #append the paths: note that we add the space holder id at the end of the shorter path
            x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
            
            x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            #append relation
            neg_r_list = list(ini_r_id_set.difference({r}))
            r_ran = random.choice(neg_r_list)
            x_r_list.append([r_ran])
            y_list.append(0.)
        
        if i_0 % 500 == 0:
            print('generating big-batches for subgraph-based model', i_0, len(data))

In [33]:
###fine tune the path-based model
lower_bd = 1
upper_bd = 10
batch_size = 32

#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)   

#######################################
###do the training#####################

#generate the input arrays
x_train_1 = np.asarray(train_p_list['1'], dtype='int')
x_train_2 = np.asarray(train_p_list['2'], dtype='int')
x_train_3 = np.asarray(train_p_list['3'], dtype='int')
x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
           batch_size=batch_size, epochs=2)

generating big-batches for path-based model 100 2501
generating big-batches for path-based model 200 2501
generating big-batches for path-based model 300 2501
generating big-batches for path-based model 400 2501
generating big-batches for path-based model 500 2501
generating big-batches for path-based model 600 2501
generating big-batches for path-based model 700 2501
generating big-batches for path-based model 800 2501
generating big-batches for path-based model 900 2501
generating big-batches for path-based model 1000 2501
generating big-batches for path-based model 1100 2501
generating big-batches for path-based model 1200 2501
generating big-batches for path-based model 1300 2501
generating big-batches for path-based model 1400 2501
generating big-batches for path-based model 1500 2501
generating big-batches for path-based model 1600 2501
generating big-batches for path-based model 1700 2501
generating big-batches for path-based model 1800 2501
generating big-batches for path-based

<keras.callbacks.History at 0x7f98100b0a00>

In [34]:
###fine tune the subgraph model
lower_bd = 1
upper_bd = 3
batch_size = 32

#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_s_list, train_t_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################

#generate the input arrays
x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, 
             x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train,
             batch_size=batch_size, epochs=2)

generating big-batches for subgraph-based model 0 7406
generating big-batches for subgraph-based model 500 7406
generating big-batches for subgraph-based model 1000 7406
generating big-batches for subgraph-based model 1500 7406
generating big-batches for subgraph-based model 2000 7406
generating big-batches for subgraph-based model 2500 7406
generating big-batches for subgraph-based model 3000 7406
generating big-batches for subgraph-based model 3500 7406
generating big-batches for subgraph-based model 4000 7406
generating big-batches for subgraph-based model 4500 7406
generating big-batches for subgraph-based model 5000 7406
generating big-batches for subgraph-based model 5500 7406
generating big-batches for subgraph-based model 6000 7406
generating big-batches for subgraph-based model 6500 7406
generating big-batches for subgraph-based model 7000 7406
Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f9863457ac0>

In [35]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_ind_test), min(len(data_ind_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #first run the path-based scoring
    score_dict = path_based_relation_scoring(s_true, t_true, 1, 10, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring when not enough paths
    if len(score_dict) == 0:
        del(score_dict)
        score_dict = subgraph_relation_scoring(s_true, t_true, 1, 3, one_hop_ind, id2relation, model_2)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

215 150
checkcorrect 102 102 real score 9.877023696899414 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 500
215 145
checkcorrect 22 22 real score 9.97813481092453 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 500
215 9
checkcorrect 122 122 real score 9.909422099590302 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 2 500
215 29
checkcorrect 216 216 real score 9.830390691757202 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 3 500
215 121
checkcorrect 42 42 real score 9.468562245368958 Hits@1 0.8 Hits@3 1.0 Hits@10 1.0 MRR 0.8666666666666666 cur_rank 2 abs_cur_rank 2 total_num 4 500
215 150
checkcorrect 144 144 real score 9.849848568439484 Hits@1 0.8333333333333334 Hits@3 1.0 Hits@10 1.0 MRR 0.8888888888888888 cur_rank 0 abs_cur_rank 1 total_num 5 500
215 150
checkcorrect 118 118 real score 8.371740877628326 Hits@1 0.7142857142857143 Hits@3 0.8571428571428571 H

215 125
checkcorrect 28 28 real score 9.936674416065216 Hits@1 0.782608695652174 Hits@3 0.9130434782608695 Hits@10 0.9782608695652174 MRR 0.8548567977915803 cur_rank 0 abs_cur_rank 0 total_num 45 500
215 97
checkcorrect 202 202 real score 9.796805441379547 Hits@1 0.7659574468085106 Hits@3 0.9148936170212766 Hits@10 0.9787234042553191 MRR 0.8473066531577169 cur_rank 1 abs_cur_rank 1 total_num 46 500
215 150
checkcorrect 172 172 real score 6.849293768405914 Hits@1 0.75 Hits@3 0.8958333333333334 Hits@10 0.9791666666666666 MRR 0.8326306216931217 cur_rank 6 abs_cur_rank 8 total_num 47 500
215 150
checkcorrect 102 102 real score 9.711022019386292 Hits@1 0.7551020408163265 Hits@3 0.8979591836734694 Hits@10 0.9795918367346939 MRR 0.8360463232912213 cur_rank 0 abs_cur_rank 0 total_num 48 500
215 150
checkcorrect 62 62 real score 9.881958782672882 Hits@1 0.74 Hits@3 0.9 Hits@10 0.98 MRR 0.8293253968253969 cur_rank 1 abs_cur_rank 1 total_num 49 500
215 150
checkcorrect 2 2 real score 9.9960463643

215 73
checkcorrect 122 122 real score 9.927406251430511 Hits@1 0.7613636363636364 Hits@3 0.8977272727272727 Hits@10 0.9772727272727273 MRR 0.8410218253968254 cur_rank 0 abs_cur_rank 0 total_num 87 500
215 150
checkcorrect 0 0 real score 9.988512694835663 Hits@1 0.7640449438202247 Hits@3 0.898876404494382 Hits@10 0.9775280898876404 MRR 0.8428080970215801 cur_rank 0 abs_cur_rank 0 total_num 88 500
215 137
checkcorrect 22 22 real score 9.976214468479156 Hits@1 0.7666666666666667 Hits@3 0.9 Hits@10 0.9777777777777777 MRR 0.8445546737213404 cur_rank 0 abs_cur_rank 0 total_num 89 500
215 150
checkcorrect 154 154 real score 7.347459554672241 Hits@1 0.7582417582417582 Hits@3 0.8901098901098901 Hits@10 0.978021978021978 MRR 0.8380211058782487 cur_rank 3 abs_cur_rank 4 total_num 90 500
215 150
checkcorrect 64 64 real score 9.998583674430847 Hits@1 0.7608695652173914 Hits@3 0.8913043478260869 Hits@10 0.9782608695652174 MRR 0.839781746031746 cur_rank 0 abs_cur_rank 0 total_num 91 500
215 150
chec

215 100
checkcorrect 254 254 real score 9.933134078979492 Hits@1 0.7846153846153846 Hits@3 0.9153846153846154 Hits@10 0.9846153846153847 MRR 0.8609737484737485 cur_rank 0 abs_cur_rank 0 total_num 129 500
215 106
checkcorrect 42 42 real score 9.65622067451477 Hits@1 0.7786259541984732 Hits@3 0.916030534351145 Hits@10 0.9847328244274809 MRR 0.8569459590451957 cur_rank 2 abs_cur_rank 2 total_num 130 500
215 13
checkcorrect 14 14 real score 9.980480253696442 Hits@1 0.7803030303030303 Hits@3 0.9166666666666666 Hits@10 0.9848484848484849 MRR 0.8580297017797017 cur_rank 0 abs_cur_rank 0 total_num 131 500
215 150
checkcorrect 362 362 real score 9.882170140743256 Hits@1 0.7819548872180451 Hits@3 0.9172932330827067 Hits@10 0.9849624060150376 MRR 0.8590971476309822 cur_rank 0 abs_cur_rank 0 total_num 132 500
215 18
checkcorrect 84 84 real score 9.986371219158173 Hits@1 0.7835820895522388 Hits@3 0.917910447761194 Hits@10 0.9850746268656716 MRR 0.8601486614546315 cur_rank 0 abs_cur_rank 0 total_num

215 126
checkcorrect 10 10 real score 9.970841407775879 Hits@1 0.7894736842105263 Hits@3 0.9181286549707602 Hits@10 0.9883040935672515 MRR 0.8633133760326743 cur_rank 0 abs_cur_rank 0 total_num 170 500
215 150
checkcorrect 362 362 real score 9.872535526752472 Hits@1 0.7906976744186046 Hits@3 0.9186046511627907 Hits@10 0.9883720930232558 MRR 0.8641080657069029 cur_rank 0 abs_cur_rank 0 total_num 171 500
215 150
checkcorrect 2 2 real score 9.996042132377625 Hits@1 0.791907514450867 Hits@3 0.9190751445086706 Hits@10 0.9884393063583815 MRR 0.8648935682172677 cur_rank 0 abs_cur_rank 0 total_num 172 500
0 1
2150 201 27
checkcorrect 360 360 real score 8.139489889144897 Hits@1 0.7873563218390804 Hits@3 0.9137931034482759 Hits@10 0.9885057471264368 MRR 0.8613596971355593 cur_rank 3 abs_cur_rank 3 total_num 173 500
215 59
checkcorrect 422 422 real score 1.7312836796045303 Hits@1 0.7828571428571428 Hits@3 0.9085714285714286 Hits@10 0.9828571428571429 MRR 0.8567737761771376 cur_rank 16 abs_cur_ran

215 150
checkcorrect 28 28 real score 9.939910769462585 Hits@1 0.7783018867924528 Hits@3 0.9056603773584906 Hits@10 0.9811320754716981 MRR 0.8516438965523316 cur_rank 0 abs_cur_rank 0 total_num 211 500
215 129
checkcorrect 22 22 real score 9.980900168418884 Hits@1 0.7793427230046949 Hits@3 0.9061032863849765 Hits@10 0.9812206572769953 MRR 0.852340404080255 cur_rank 0 abs_cur_rank 0 total_num 212 500
215 150
checkcorrect 144 144 real score 9.15781420469284 Hits@1 0.7757009345794392 Hits@3 0.9018691588785047 Hits@10 0.9813084112149533 MRR 0.8495257292948333 cur_rank 3 abs_cur_rank 5 total_num 213 500
215 145
checkcorrect 342 342 real score 9.69904226064682 Hits@1 0.772093023255814 Hits@3 0.9023255813953488 Hits@10 0.9813953488372092 MRR 0.8471248344298961 cur_rank 2 abs_cur_rank 2 total_num 214 500
215 49
checkcorrect 64 64 real score 9.999322891235352 Hits@1 0.7731481481481481 Hits@3 0.9027777777777778 Hits@10 0.9814814814814815 MRR 0.847832589826054 cur_rank 0 abs_cur_rank 0 total_num 

215 150
checkcorrect 200 200 real score 8.45682978630066 Hits@1 0.7747035573122529 Hits@3 0.9090909090909091 Hits@10 0.9841897233201581 MRR 0.850064676727081 cur_rank 2 abs_cur_rank 2 total_num 252 500
215 150
checkcorrect 40 40 real score 9.980799198150635 Hits@1 0.7755905511811023 Hits@3 0.9094488188976378 Hits@10 0.984251968503937 MRR 0.8506549732753995 cur_rank 0 abs_cur_rank 0 total_num 253 500
215 150
checkcorrect 64 64 real score 9.999295890331268 Hits@1 0.7764705882352941 Hits@3 0.9098039215686274 Hits@10 0.984313725490196 MRR 0.8512406400468686 cur_rank 0 abs_cur_rank 0 total_num 254 500
215 89
checkcorrect 144 144 real score 9.81760549545288 Hits@1 0.7734375 Hits@3 0.91015625 Hits@10 0.984375 MRR 0.8498686062966855 cur_rank 1 abs_cur_rank 1 total_num 255 500
215 142
checkcorrect 68 68 real score 9.933445036411285 Hits@1 0.77431906614786 Hits@3 0.9105058365758755 Hits@10 0.9844357976653697 MRR 0.85045277514378 cur_rank 0 abs_cur_rank 0 total_num 256 500
215 78
checkcorrect 160

215 27
checkcorrect 150 150 real score 9.946824848651886 Hits@1 0.7789115646258503 Hits@3 0.9115646258503401 Hits@10 0.9829931972789115 MRR 0.853475914175194 cur_rank 0 abs_cur_rank 0 total_num 293 500
215 141
checkcorrect 62 62 real score 9.88294804096222 Hits@1 0.7762711864406779 Hits@3 0.911864406779661 Hits@10 0.9830508474576272 MRR 0.852277690737312 cur_rank 1 abs_cur_rank 1 total_num 294 500
215 150
checkcorrect 16 16 real score 9.978627502918243 Hits@1 0.777027027027027 Hits@3 0.9121621621621622 Hits@10 0.9831081081081081 MRR 0.8527767525929292 cur_rank 0 abs_cur_rank 0 total_num 295 500
215 143
checkcorrect 144 144 real score 9.458842098712921 Hits@1 0.7777777777777778 Hits@3 0.9124579124579124 Hits@10 0.9831649831649831 MRR 0.8532724537626499 cur_rank 0 abs_cur_rank 2 total_num 296 500
215 26
checkcorrect 70 70 real score 9.533487439155579 Hits@1 0.7751677852348994 Hits@3 0.9093959731543624 Hits@10 0.9832214765100671 MRR 0.8512480495553928 cur_rank 3 abs_cur_rank 3 total_num 2

215 146
checkcorrect 2 2 real score 9.995907545089722 Hits@1 0.7761194029850746 Hits@3 0.9074626865671642 Hits@10 0.9850746268656716 MRR 0.851138634558301 cur_rank 0 abs_cur_rank 0 total_num 334 500
215 150
checkcorrect 342 342 real score 9.44409030675888 Hits@1 0.7767857142857143 Hits@3 0.9077380952380952 Hits@10 0.9851190476190477 MRR 0.8515816743364013 cur_rank 0 abs_cur_rank 0 total_num 335 500
215 113
checkcorrect 150 150 real score 9.94233912229538 Hits@1 0.7774480712166172 Hits@3 0.9080118694362018 Hits@10 0.9851632047477745 MRR 0.8520220847983111 cur_rank 0 abs_cur_rank 0 total_num 336 500
215 150
checkcorrect 52 52 real score 9.961074590682983 Hits@1 0.7751479289940828 Hits@3 0.908284023668639 Hits@10 0.985207100591716 MRR 0.8509805993403279 cur_rank 1 abs_cur_rank 1 total_num 337 500
215 4
checkcorrect 158 158 real score 9.972592830657959 Hits@1 0.775811209439528 Hits@3 0.9085545722713865 Hits@10 0.9852507374631269 MRR 0.8514201845930113 cur_rank 0 abs_cur_rank 0 total_num 33

215 148
checkcorrect 112 112 real score 9.827461540699005 Hits@1 0.7792553191489362 Hits@3 0.9069148936170213 Hits@10 0.9813829787234043 MRR 0.8518004850220475 cur_rank 0 abs_cur_rank 0 total_num 375 500
215 150
checkcorrect 2 2 real score 9.996079325675964 Hits@1 0.7798408488063661 Hits@3 0.9071618037135278 Hits@10 0.9814323607427056 MRR 0.8521935871837927 cur_rank 0 abs_cur_rank 0 total_num 376 500
0 1
0 1 1
checkcorrect 52 52 real score 0.0 Hits@1 0.7777777777777778 Hits@3 0.9047619047619048 Hits@10 0.9788359788359788 MRR 0.8500408568961677 cur_rank 25 abs_cur_rank 26 total_num 377 500
215 150
checkcorrect 48 48 real score 9.987634658813477 Hits@1 0.7783641160949868 Hits@3 0.9050131926121372 Hits@10 0.978891820580475 MRR 0.850436527458447 cur_rank 0 abs_cur_rank 0 total_num 378 500
215 150
checkcorrect 70 70 real score 9.529508650302887 Hits@1 0.7763157894736842 Hits@3 0.9026315789473685 Hits@10 0.9789473684210527 MRR 0.8488564313335563 cur_rank 3 abs_cur_rank 3 total_num 379 500
21

215 150
checkcorrect 48 48 real score 9.971401333808899 Hits@1 0.7817745803357314 Hits@3 0.9064748201438849 Hits@10 0.9808153477218226 MRR 0.8533144138451274 cur_rank 0 abs_cur_rank 0 total_num 416 500
215 150
checkcorrect 14 14 real score 9.980780601501465 Hits@1 0.7822966507177034 Hits@3 0.9066985645933014 Hits@10 0.9808612440191388 MRR 0.8536653363000433 cur_rank 0 abs_cur_rank 0 total_num 417 500
215 61
checkcorrect 186 186 real score 9.645696818828583 Hits@1 0.7804295942720764 Hits@3 0.9069212410501193 Hits@10 0.9809069212410502 MRR 0.8528212662850074 cur_rank 1 abs_cur_rank 1 total_num 418 500
215 130
checkcorrect 102 102 real score 9.979909598827362 Hits@1 0.780952380952381 Hits@3 0.9071428571428571 Hits@10 0.9809523809523809 MRR 0.8531716918414717 cur_rank 0 abs_cur_rank 0 total_num 419 500
0 1
2150 24 6
checkcorrect 388 388 real score 9.836360812187195 Hits@1 0.7814726840855107 Hits@3 0.9073634204275535 Hits@10 0.9809976247030879 MRR 0.8535204526684516 cur_rank 0 abs_cur_rank 

215 150
checkcorrect 102 102 real score 9.963393867015839 Hits@1 0.7816593886462883 Hits@3 0.9061135371179039 Hits@10 0.982532751091703 MRR 0.8535999503058619 cur_rank 0 abs_cur_rank 0 total_num 457 500
215 11
checkcorrect 132 132 real score 9.922476172447205 Hits@1 0.7821350762527233 Hits@3 0.906318082788671 Hits@10 0.9825708061002179 MRR 0.853918904662494 cur_rank 0 abs_cur_rank 0 total_num 458 500
215 150
checkcorrect 28 28 real score 9.93679940700531 Hits@1 0.782608695652174 Hits@3 0.9065217391304348 Hits@10 0.9826086956521739 MRR 0.8542364722610537 cur_rank 0 abs_cur_rank 0 total_num 459 500
215 141
checkcorrect 14 14 real score 9.98076981306076 Hits@1 0.7830802603036876 Hits@3 0.9067245119305857 Hits@10 0.982646420824295 MRR 0.8545526621259972 cur_rank 0 abs_cur_rank 0 total_num 460 500
0 1
2150 4 8
checkcorrect 38 38 real score 9.967418909072876 Hits@1 0.7835497835497836 Hits@3 0.9069264069264069 Hits@10 0.9826839826839827 MRR 0.8548674832036466 cur_rank 0 abs_cur_rank 1 total_n

215 148
checkcorrect 144 144 real score 9.916107833385468 Hits@1 0.7855711422845691 Hits@3 0.905811623246493 Hits@10 0.9799599198396793 MRR 0.8556237087299274 cur_rank 1 abs_cur_rank 1 total_num 498 500
215 15
checkcorrect 64 64 real score 9.999295115470886 Hits@1 0.786 Hits@3 0.906 Hits@10 0.98 MRR 0.8559124613124676 cur_rank 0 abs_cur_rank 0 total_num 499 500
