### Train the inductive link prediction model

In [1]:
data_name = 'WN18RR_v4'
model_id = 'SiaLP_3_new'
lower_bound = 1
upper_bound_path = 10
upper_bound_subg = 3

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle
from sys import getsizeof

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [5]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):
        
        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion
                        
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 20 recursions if there are too many (r,t)
                for i in temp_list[:self.amount_bd]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [6]:
train_path = '../data/' + data_name + '/train.txt'
valid_path = '../data/' + data_name + '/valid.txt'
test_path = '../data/' + data_name + '/test.txt'

In [7]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [8]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)

#define the dictionaries, which is shared by initail and inductive train/valid/test
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

In [9]:
#define the dictionaries and sets for load KG
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(valid_path, one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

In [10]:
#define the dictionaries and sets for load KG
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

#fill in the sets and dicts
Class_1.load_train_data(test_path, one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)

#### Build the path-based siamese neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [11]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#reduce max
fst_reduce_max = tf.reduce_max(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_max(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_max(thd_lstm_out, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#add dropout on top of the concatenation from all channels
dropout = layers.Dropout(0.25)(path_concat)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-05-15 13:37:22.889519: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
#config the Adam optimizer 
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

#### Build the subgraph-based siamese neural network

In [13]:
#each input is an vector with number of relations to be dim:
#each dim represent the existence (1) or not (0) of an out-going relation from the entity
source_path_1 = keras.Input(shape=(None,), dtype="int32")
source_path_2 = keras.Input(shape=(None,), dtype="int32")
source_path_3 = keras.Input(shape=(None,), dtype="int32")

target_path_1 = keras.Input(shape=(None,), dtype="int32")
target_path_2 = keras.Input(shape=(None,), dtype="int32")
target_path_3 = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela_ = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of paths are not the same
in_embd_var_ = layers.Embedding(len(relation2id)+1, 300)

# Obtain the source embeddings
source_embd_1 = in_embd_var_(source_path_1)
source_embd_2 = in_embd_var_(source_path_2)
source_embd_3 = in_embd_var_(source_path_3)

#Obtain the target embeddings
target_embd_1 = in_embd_var_(target_path_1)
target_embd_2 = in_embd_var_(target_path_2)
target_embd_3 = in_embd_var_(target_path_3)

# Embed each integer in a 300-dimensional vector as output
rela_embd_ = layers.Embedding(len(relation2id)+1, 300)(id_rela_)

#add 2 layer bi-directional LSTM network
lstm_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

###source lstm implimentation########
#first LSTM layer
source_mid_1 = lstm_1(source_embd_1)
source_mid_2 = lstm_1(source_embd_2)
source_mid_3 = lstm_1(source_embd_3)

#second LSTM layer
source_out_1 = lstm_2(source_mid_1)
source_out_2 = lstm_2(source_mid_2)
source_out_3 = lstm_2(source_mid_3)

#reduce max
source_max_1 = tf.reduce_max(source_out_1, axis=1)
source_max_2 = tf.reduce_max(source_out_2, axis=1)
source_max_3 = tf.reduce_max(source_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
source_concat = layers.concatenate([source_max_1, source_max_2, source_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
source_dropout = layers.Dropout(0.25)(source_concat)

###target lstm implimentation########
#first LSTM layer
target_mid_1 = lstm_1(target_embd_1)
target_mid_2 = lstm_1(target_embd_2)
target_mid_3 = lstm_1(target_embd_3)

#second LSTM layer
target_out_1 = lstm_2(target_mid_1)
target_out_2 = lstm_2(target_mid_2)
target_out_3 = lstm_2(target_mid_3)

#reduce max
target_max_1 = tf.reduce_max(target_out_1, axis=1)
target_max_2 = tf.reduce_max(target_out_2, axis=1)
target_max_3 = tf.reduce_max(target_out_3, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 900)
target_concat = layers.concatenate([target_max_1, target_max_2, target_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
target_dropout = layers.Dropout(0.25)(target_concat)

#further concatenate source and target output embeddings: (Batch, 1800)
final_concat = layers.concatenate([source_dropout, target_dropout], axis=-1)

#multiply into output embd size by dense layer: (Batch, 300)
out_vect = layers.Dense(300, activation='tanh')(final_concat)

#remove the time dimension from the output embd since there is only one step
rela_out_embd_ = tf.reduce_sum(rela_embd_, axis=1)

# Normalize the vectors to have unit length
out_vect_norm = tf.math.l2_normalize(out_vect, axis=-1)
rela_out_embd_norm_ = tf.math.l2_normalize(rela_out_embd_, axis=-1)

# Calculate the dot product
dot_product_ = layers.Dot(axes=-1)([out_vect_norm, rela_out_embd_norm_])

#put together the model
model_2 = keras.Model([source_path_1, source_path_2, source_path_3,
                       target_path_1, target_path_2, target_path_3, id_rela_], dot_product_)

In [14]:
#config the Adam optimizer 
opt_ = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model_2.compile(loss='binary_crossentropy', optimizer=opt_, metrics=['binary_accuracy'])

### Build the big-batch for path-based model
We will build the big-batch for the path-based model training. That is, we will build three list to store three paths, respectively.

In order to reduce computational complexity, we will run the path-finding algorithm for each entity e in the dataset before the training. That is, for each entity e, we will have two dictionaries. Dict 1 stores the paths between e and any other entities in the dataset. Will Dict 2 stores the paths between e and its direct neighbors. The two dicts will be used and invariant throughout the training.

* At each step, three different paths between two entities s and t are selected. Each path is append to one of the list. 
* If this step is for positive samples, the existing relation r will be selected between s and t. If there are more than one relation from s to t, we randomly choose one. Also, the label list will be appended 1.
* If this step is for negative samples, one relation that does not exist between s and t will be selected randomly and append to the relation list. Also, the label list will be appended 0.
* In practice, the positive step is always fallowed by a negative step. The same paths in the positive step will be used in the next negative step, while the relation is a negative one chosen in the above way.
* We do this until the length limit is reached.

**For relation prediciton, we will only need to train using (s,r,t) triple. (t,r-1,s) is not necessary and hence not included in training.**

In [15]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        for iteration in range(10):

            #proceed only if at least three paths are between s and t
            for t in result:

                if len(s_t_r[(s,t)]) == 0:

                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #we are only interested in forward link in relation prediciton
                ini_r_list = list()

                #obtain initial relations between s and t
                for r in s_t_r[(s,t)]:
                    if r % 2 == 0:#initial relation id is always an even number
                        ini_r_list.append(r)

                #if there exist more than three paths between s and t, 
                #and inital connection between s and t exists,
                #and not every r in the relation dictionary exists between s and t (although this is rare)
                #we then proceed
                if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):

                    #obtain the list form of all the paths from s to t
                    temp_path_list = list(result[t])

                    temp_pair = random.sample(temp_path_list, 3)

                    path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                    #####positive#####################
                    #append the paths: note that we add the space holder id at the end of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    r = random.choice(ini_r_list)
                    x_r_list.append([r])
                    y_list.append(1.)

                    #####negative#####################
                    #append the paths: note that we add the space holder id at the end
                    #of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                    r_ran = random.choice(neg_r_list)
                    x_r_list.append([r_ran])
                    y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

### Build the big-batch for the subgraph-based network training

Again, to reduce computational complexity, we store the subgraph of each entity e at the biginning.

* At each step, we will select one triple (s,r,t) from the dataset. Then, reaching out paths of s and t is generated respectively according to their out-going relations.
* We will select three paths for each of source and target entity. Add them to the corresponding list.
* If this is a positive sample step, the id of relation r is appended to the relation list.
* If this is a negative sample step, the id of a random relation is appended to the relation lsit.
* Similarly, one negative sample step always follows one positive step. The one-hop vectors from the previous positve step is used again for the negative step.

In [16]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
    
    #the ids to start path finding
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    #Dict stores the subgraph for each entity
    Dict_1 = dict()
    
    count = 0
    for s in existing_ids:
        
        path_set = set()
            
        result, length_dict = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        for t_ in result:
            for path in result[t_]:
                path_set.add(path)

        del(result, length_dict)
        
        path_list = list(path_set)
        
        path_select = random.sample(path_list, min(len(path_list), 100))
            
        Dict_1[s] = deepcopy(path_select)
        
        count += 1
        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, len(existing_ids))
        
    return(Dict_1)

In [17]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
        
    #if an entity has at least three out-stretching paths, it is a qualified one
    qualified = set()
    for e in Dict:
        if len(Dict[e]) >= 3:
            qualified.add(e)
    qualified = list(qualified)
    
    data = list(data)
    
    for iteration in range(10):

        data = shuffle(data)

        for i_0 in range(len(data)):

            triple = data[i_0]

            s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs

            if s in qualified and t in qualified:

                #obtain the path list for true entities
                path_s, path_t = list(Dict[s]), list(Dict[t])

                #####positive step###########
                #randomly obtain three paths for true entities
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative step for relation###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference({r}))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)
                
                ##############################################
                ##############################################
                #randomly choose two negative sampled entities
                s_ran = random.choice(qualified)
                t_ran = random.choice(qualified)

                #obtain the path list for random entities
                path_s_ran, path_t_ran = list(Dict[s_ran]), list(Dict[t_ran])
                
                #####positive step#################
                #Again: randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative for source entity###########
                #randomly obtain three paths
                temp_s = random.sample(path_s_ran, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(0.)

                #####positive step###########
                #Again: randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative for target entity###########
                #randomly obtain three paths
                temp_t = random.sample(path_t_ran, 3)
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(0.)

            if i_0 % 200 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [18]:
model_name

'Model_SiaLP_3_new_WN18RR_v4'

In [19]:
one_hop_model_name

'One_hop_model_SiaLP_3_new_WN18RR_v4'

In [20]:
ids_name

'IDs_SiaLP_3_new_WN18RR_v4'

In [21]:
#first, we save the relation and ids
Dict = dict()

#save training data
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r

#save valid data
Dict['one_hop_valid'] = one_hop_valid
Dict['data_valid'] = data_valid
Dict['s_t_r_valid'] = s_t_r_valid

#save test data
Dict['one_hop_test'] = one_hop_test
Dict['data_test'] = data_test
Dict['s_t_r_test'] = s_t_r_test

#save shared dictionaries
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [22]:
###train the path-based model
lower_bd = lower_bound
upper_bd = upper_bound_path
num_epoch = 10
batch_size = 32
        
#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_p_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_path(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_p_list, valid_r_list, valid_y_list,
                      relation2id, entity2id, id2relation, id2entity)    

#######################################
###do the training#####################
#sometimes the validation dataset is so small so sparse, 
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    x_train_1 = np.asarray(train_p_list['1'], dtype='int')
    x_train_2 = np.asarray(train_p_list['2'], dtype='int')
    x_train_3 = np.asarray(train_p_list['3'], dtype='int')
    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #generate the validation arrays
    x_valid_1 = np.asarray(valid_p_list['1'], dtype='int')
    x_valid_2 = np.asarray(valid_p_list['2'], dtype='int')
    x_valid_3 = np.asarray(valid_p_list['3'], dtype='int')
    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)
    #generate the input arrays
    x_train_1 = np.asarray(train_p_list['1'][:split], dtype='int')
    x_train_2 = np.asarray(train_p_list['2'][:split], dtype='int')
    x_train_3 = np.asarray(train_p_list['3'][:split], dtype='int')
    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #generate the validation arrays
    x_valid_1 = np.asarray(train_p_list['1'][split:], dtype='int')
    x_valid_2 = np.asarray(train_p_list['2'][split:], dtype='int')
    x_valid_3 = np.asarray(train_p_list['3'][split:], dtype='int')
    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training
model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train, 
          validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)   

# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

generating big-batches for path-based model 100 3861
generating big-batches for path-based model 200 3861
generating big-batches for path-based model 300 3861
generating big-batches for path-based model 400 3861
generating big-batches for path-based model 500 3861
generating big-batches for path-based model 600 3861
generating big-batches for path-based model 700 3861
generating big-batches for path-based model 800 3861
generating big-batches for path-based model 900 3861
generating big-batches for path-based model 1000 3861
generating big-batches for path-based model 1100 3861
generating big-batches for path-based model 1200 3861
generating big-batches for path-based model 1300 3861
generating big-batches for path-based model 1400 3861
generating big-batches for path-based model 1500 3861
generating big-batches for path-based model 1600 3861
generating big-batches for path-based model 1700 3861
generating big-batches for path-based model 1800 3861
generating big-batches for path-based

In [23]:
###train the subgraph-based model
lower_bd = lower_bound
upper_bd = upper_bound_subg
num_epoch = 10
batch_size = 32

Dict_train = store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity)

Dict_valid = store_subgraph_dicts(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                         relation2id, entity2id, id2relation, id2entity)
        
#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_s_list, valid_t_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_subgraph(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_s_list, valid_t_list, valid_r_list, valid_y_list, Dict_valid,
                      relation2id, entity2id, id2relation, id2entity)    

#######################################
###do the training#####################
#sometimes the validation dataset is so small so sparse, 
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
    x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
    x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

    x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
    x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
    x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

    x_train_r = np.asarray(train_r_list, dtype='int')
    y_train = np.asarray(train_y_list, dtype='int')

    #generate the validation arrays
    x_valid_s_1 = np.asarray(valid_s_list['1'], dtype='int')
    x_valid_s_2 = np.asarray(valid_s_list['2'], dtype='int')
    x_valid_s_3 = np.asarray(valid_s_list['3'], dtype='int')

    x_valid_t_1 = np.asarray(valid_t_list['1'], dtype='int')
    x_valid_t_2 = np.asarray(valid_t_list['2'], dtype='int')
    x_valid_t_3 = np.asarray(valid_t_list['3'], dtype='int')

    x_valid_r = np.asarray(valid_r_list, dtype='int')
    y_valid = np.asarray(valid_y_list, dtype='int')

else:
    split = int(len(train_y_list)*0.8)
    #generate the input arrays
    x_train_s_1 = np.asarray(train_s_list['1'][:split], dtype='int')
    x_train_s_2 = np.asarray(train_s_list['2'][:split], dtype='int')
    x_train_s_3 = np.asarray(train_s_list['3'][:split], dtype='int')

    x_train_t_1 = np.asarray(train_t_list['1'][:split], dtype='int')
    x_train_t_2 = np.asarray(train_t_list['2'][:split], dtype='int')
    x_train_t_3 = np.asarray(train_t_list['3'][:split], dtype='int')

    x_train_r = np.asarray(train_r_list[:split], dtype='int')
    y_train = np.asarray(train_y_list[:split], dtype='int')

    #generate the validation arrays
    x_valid_s_1 = np.asarray(train_s_list['1'][split:], dtype='int')
    x_valid_s_2 = np.asarray(train_s_list['2'][split:], dtype='int')
    x_valid_s_3 = np.asarray(train_s_list['3'][split:], dtype='int')

    x_valid_t_1 = np.asarray(train_t_list['1'][split:], dtype='int')
    x_valid_t_2 = np.asarray(train_t_list['2'][split:], dtype='int')
    x_valid_t_3 = np.asarray(train_t_list['3'][split:], dtype='int')

    x_valid_r = np.asarray(train_r_list[split:], dtype='int')
    y_valid = np.asarray(train_y_list[split:], dtype='int')

#do the training
model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train, 
          validation_data=([x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
one_hop_add_h5 = one_hop_model_name + '.h5'
one_hop_save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(one_hop_save_dir):
    os.makedirs(one_hop_save_dir)
one_hop_model_path = os.path.join(one_hop_save_dir, one_hop_add_h5)
model_2.save(one_hop_model_path)
print('Save model')
del(model_2, Dict_train, Dict_valid)

generating and storing paths for the path-based model 100 3861
generating and storing paths for the path-based model 200 3861
generating and storing paths for the path-based model 300 3861
generating and storing paths for the path-based model 400 3861
generating and storing paths for the path-based model 500 3861
generating and storing paths for the path-based model 600 3861
generating and storing paths for the path-based model 700 3861
generating and storing paths for the path-based model 800 3861
generating and storing paths for the path-based model 900 3861
generating and storing paths for the path-based model 1000 3861
generating and storing paths for the path-based model 1100 3861
generating and storing paths for the path-based model 1200 3861
generating and storing paths for the path-based model 1300 3861
generating and storing paths for the path-based model 1400 3861
generating and storing paths for the path-based model 1500 3861
generating and storing paths for the path-based m

generating big-batches for subgraph-based model 1400 7940 2
generating big-batches for subgraph-based model 1600 7940 2
generating big-batches for subgraph-based model 1800 7940 2
generating big-batches for subgraph-based model 2000 7940 2
generating big-batches for subgraph-based model 2200 7940 2
generating big-batches for subgraph-based model 2400 7940 2
generating big-batches for subgraph-based model 2600 7940 2
generating big-batches for subgraph-based model 2800 7940 2
generating big-batches for subgraph-based model 3000 7940 2
generating big-batches for subgraph-based model 3200 7940 2
generating big-batches for subgraph-based model 3400 7940 2
generating big-batches for subgraph-based model 3600 7940 2
generating big-batches for subgraph-based model 3800 7940 2
generating big-batches for subgraph-based model 4000 7940 2
generating big-batches for subgraph-based model 4200 7940 2
generating big-batches for subgraph-based model 4400 7940 2
generating big-batches for subgraph-base

generating big-batches for subgraph-based model 200 7940 6
generating big-batches for subgraph-based model 400 7940 6
generating big-batches for subgraph-based model 600 7940 6
generating big-batches for subgraph-based model 800 7940 6
generating big-batches for subgraph-based model 1000 7940 6
generating big-batches for subgraph-based model 1200 7940 6
generating big-batches for subgraph-based model 1400 7940 6
generating big-batches for subgraph-based model 1600 7940 6
generating big-batches for subgraph-based model 1800 7940 6
generating big-batches for subgraph-based model 2000 7940 6
generating big-batches for subgraph-based model 2200 7940 6
generating big-batches for subgraph-based model 2400 7940 6
generating big-batches for subgraph-based model 2600 7940 6
generating big-batches for subgraph-based model 2800 7940 6
generating big-batches for subgraph-based model 3000 7940 6
generating big-batches for subgraph-based model 3200 7940 6
generating big-batches for subgraph-based mo

generating big-batches for subgraph-based model 4000 7940 9
generating big-batches for subgraph-based model 4200 7940 9
generating big-batches for subgraph-based model 4400 7940 9
generating big-batches for subgraph-based model 4600 7940 9
generating big-batches for subgraph-based model 4800 7940 9
generating big-batches for subgraph-based model 5000 7940 9
generating big-batches for subgraph-based model 5200 7940 9
generating big-batches for subgraph-based model 5400 7940 9
generating big-batches for subgraph-based model 5600 7940 9
generating big-batches for subgraph-based model 5800 7940 9
generating big-batches for subgraph-based model 6000 7940 9
generating big-batches for subgraph-based model 6200 7940 9
generating big-batches for subgraph-based model 6400 7940 9
generating big-batches for subgraph-based model 6600 7940 9
generating big-batches for subgraph-based model 6800 7940 9
generating big-batches for subgraph-based model 7000 7940 9
generating big-batches for subgraph-base

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
data_name = 'fb237_v2'
model_id = 'SiaLP_3_new'
lower_bound = 1
upper_bound_path = 10
upper_bound_subg = 3

In [2]:
#difine the names for saving
model_name = 'Model_' + model_id + '_' + data_name
one_hop_model_name = 'One_hop_model_' + model_id + '_' + data_name
ids_name = 'IDs_' + model_id + '_' + data_name

In [3]:
ids_name

'IDs_SiaLP_3_new_fb237_v2'

In [4]:
one_hop_model_name

'One_hop_model_SiaLP_3_new_fb237_v2'

In [5]:
model_name

'Model_SiaLP_3_new_fb237_v2'

In [6]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [7]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = set()
            
            one_hop[s].add((r,t))
            
            if t not in one_hop:
                
                one_hop[t] = set()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            one_hop[t].add((r_inv,s))
            
        #change each set in one_hop to list
        for e in one_hop:
            
            one_hop[e] = list(one_hop[e])

In [8]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):
        
        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion
                        
        self.size_bd = size_bd #size bound limit the number of paths to a target entity t
        
        #number of times paths with specific length been performed for recursion
        self.threshold = threshold
        
    '''
    Given an entity s, the function will find the paths from s to other entities, using recursion.
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':
        
            for Tuple in one_hop[s]:
            
                t = Tuple[1]
                
                qualified_t.add(t)
        
        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':
            
            qualified_t.add(t_input)
        
        #under this mode, we will consider any entity
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                qualified_t.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        from s to this node, we will further find the direct neighbor t' of this node. 
        If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict):

            #when the current path is within lower_bd and upper_bd, 
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):
                
                res[node].add(tuple(path))
                    
            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (len(path) < upper_bd) and (count_dict[len(path)] <= self.threshold):
                                
                #temp list is the id list for us to go-over one_hop[node]
                temp_list = [i for i in range(len(one_hop[node]))]
                random.shuffle(temp_list) #so we random-shuffle the list
                
                #only take 20 recursions if there are too many (r,t)
                for i in temp_list[:self.amount_bd]:
                    
                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]
                    
                    #add to count_dict even if eventually this step not proceed
                    count_dict[len(path)] += 1
                    
                    #if t not on the path and we not exceed the computation threshold, 
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[len(path)] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}), res, qualified_t, 
                               lower_bd, upper_bd, one_hop, count_dict)

        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, qualified_t, lower_bd, upper_bd, one_hop, count_dict)
        
        return(res, count_dict)

In [9]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [10]:
#load ids and relation/entity dicts
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)
    
#save training data
one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']

#save valid data
one_hop_valid = Dict['one_hop_valid']
data_valid = Dict['data_valid']
s_t_r_valid = Dict['s_t_r_valid']

#save test data
one_hop_test = Dict['one_hop_test']
data_test = Dict['data_test']
s_t_r_test = Dict['s_t_r_test']

#save shared dictionaries
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts before adding new entities
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)
num_r

400

In [11]:
ids_name

'IDs_SiaLP_3_new_fb237_v2'

In [12]:
model_name

'Model_SiaLP_3_new_fb237_v2'

In [13]:
#load the model
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-05-16 13:17:24.759985: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
#load the one-hop neighbor model
model_2 = keras.models.load_model('../weight_bin/' + one_hop_model_name + '.h5')

In [15]:
ind_train_path = '../data/' + data_name + '_ind/train.txt'
ind_valid_path = '../data/' + data_name + '_ind/valid.txt'
ind_test_path = '../data/' + data_name + '_ind/test.txt'

In [16]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [17]:
print(size_0, size_1, len(data_ind))

2608 4268 4145


In [18]:
#load the test dataset
one_hop_ind_test = dict() 
data_ind_test = set()
s_t_r_ind_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_ind_test, data_ind_test, s_t_r_ind_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [19]:
print(size_0, size_1, len(data_ind_test))

4268 4268 478


In [20]:
#load the validation for existing triple removal when ranking
one_hop_ind_valid = dict() 
data_ind_valid = set()
s_t_r_ind_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_ind_valid, data_ind_valid, s_t_r_ind_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [21]:
print(size_0, size_1, len(data_ind_valid))

4268 4268 469


In [22]:
print(len(entity2id), len(entity2id_ini))

4268 2608


In [23]:
#obtain all the inital entities and new entities
ini_ent_set, new_ent_set, all_ent_set = set(), set(), set()

for ID in id2entity:
    all_ent_set.add(ID)
    if ID in id2entity_ini:
        ini_ent_set.add(ID)
    else:
        new_ent_set.add(ID)
        
print(len(ini_ent_set), len(new_ent_set), len(all_ent_set))

2608 1660 4268


In [24]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [25]:
overlapping = 0

for ele in data_ind_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [26]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [27]:
#the function to do path-based relation scoring
def path_based_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(3):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
                
        del(result, length_dict)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)
    
    count = 0
    
    if len(path_holder) >= 3:
    
        #iterate over path_1
        while count < 10:

            temp_pair = random.sample(path_holder, 3)

            path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

            list_1 = list()
            list_2 = list()
            list_3 = list()
            list_r = list()

            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                
                #only care about initial relations
                if i % 2 == 0:

                    list_1.append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    list_2.append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    list_3.append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))
                    list_r.append([i])
            
            #change to arrays
            input_1 = np.array(list_1)
            input_2 = np.array(list_2)
            input_3 = np.array(list_3)
            input_r = np.array(list_r)

            pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])
                count_dict[2*i] += 1

            count += 1
            
    #average the score
    for r in score_dict:
        score_dict[r] = deepcopy(score_dict[r]/float(count_dict[r]))
    
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [28]:
#the function to do path-based triple scoring: input one triple
def path_based_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(3):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
                
        del(result, length_dict)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score = 0.
    count = 0
    
    if len(path_holder) >= 3:
        
        list_1 = list()
        list_2 = list()
        list_3 = list()
        list_r = list()
    
        #iterate over path_1
        while count < 10:

            temp_pair = random.sample(path_holder, 3)
            path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

            list_1.append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
            list_2.append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
            list_3.append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))
            list_r.append([r])
            
            count += 1
            
        #change to arrays
        input_1 = np.array(list_1)
        input_2 = np.array(list_2)
        input_3 = np.array(list_3)
        input_r = np.array(list_r)

        pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

        for i in range(pred.shape[0]):
            score += float(pred[i])
            
        #average the score
        score = score/float(count)

    return(score)

In [29]:
#subgraph based relation scoring
def subgraph_relation_scoring(s, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    
    path_s, path_t = set(), set() #sets holding all the paths from s or t
    
    for iteration in range(3):
    
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
                
        del(result_s, length_dict_s, result_t, length_dict_t)
    
    #final output: the score dict
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)
    
    #see if both path_s and path_t have at least three paths
    if len(path_s) >= 3 and len(path_t) >= 3:

        #change to lists
        path_s, path_t = list(path_s), list(path_t)
        
        count = 0
        while count < 10:
            
            #lists holding the input to the network
            list_s_1 = list()
            list_s_2 = list()
            list_s_3 = list()
            list_t_1 = list()
            list_t_2 = list()
            list_t_3 = list()
            list_r = list()

            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]
            
            #add all forward (initial relation)
            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')
                    
                if i % 2 == 0:

                    #append the paths: note that we add the space holder id at the end of the shorter path
                    list_s_1.append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                    list_s_2.append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                    list_s_3.append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))
                    
                    list_t_1.append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                    list_t_2.append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                    list_t_3.append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))
                    
                    list_r.append([i])
                
            #change to arrays
            input_s_1 = np.array(list_s_1)
            input_s_2 = np.array(list_s_2)
            input_s_3 = np.array(list_s_3)
            input_t_1 = np.array(list_t_1)
            input_t_2 = np.array(list_t_2)
            input_t_3 = np.array(list_t_3)
            input_r = np.array(list_r)
            
            pred = model_2.predict([input_s_1, input_s_2, input_s_3,
                                    input_t_1, input_t_2, input_t_3, input_r], verbose = 0)

            for i in range(pred.shape[0]):
                #need to times 2 to go back to relation id from pred position
                score_dict[2*i] += float(pred[i])
                count_dict[2*i] += 1

            count += 1
            
    #average the score
    for r in score_dict:
        score_dict[r] = deepcopy(score_dict[r]/float(count_dict[r]))
            
    print(len(score_dict), len(path_s), len(path_t))
        
    return(score_dict)

In [30]:
#subgraph based triple scoring
def subgraph_triple_scoring(s, r, t, lower_bd, upper_bd, one_hop, id2relation, model_2):
    
    path_s, path_t = set(), set() #sets holding all the paths from s or t
    
    for iteration in range(3):
    
        #obtain the paths out from s or t by "any target" mode. That is, 
        result_s, length_dict_s = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)
        result_t, length_dict_t = Class_2.obtain_paths('any_target', t, 'any', lower_bd, upper_bd, one_hop)

        #add paths to the source/target path_set
        for e in result_s:
            for path in result_s[e]:
                path_s.add(path)
        for e in result_t:
            for path in result_t[e]:
                path_t.add(path)
                
        del(result_s, length_dict_s, result_t, length_dict_t)
    
    #final output: the score dict
    score = 0.
    
    #see if both path_s and path_t have at least three paths
    if len(path_s) >= 3 and len(path_t) >= 3:

        #change to lists
        path_s, path_t = list(path_s), list(path_t)
        
        #lists holding the input to the network
        list_s_1 = list()
        list_s_2 = list()
        list_s_3 = list()
        list_t_1 = list()
        list_t_2 = list()
        list_t_3 = list()
        list_r = list()
        
        count = 0
        while count < 10:

            #randomly obtain three paths
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)
            s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
            t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

            #append the paths: note that we add the space holder id at the end of the shorter path
            list_s_1.append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
            list_s_2.append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
            list_s_3.append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

            list_t_1.append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
            list_t_2.append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
            list_t_3.append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

            list_r.append([r])
            count += 1
                
        #change to arrays
        input_s_1 = np.array(list_s_1)
        input_s_2 = np.array(list_s_2)
        input_s_3 = np.array(list_s_3)
        input_t_1 = np.array(list_t_1)
        input_t_2 = np.array(list_t_2)
        input_t_3 = np.array(list_t_3)
        input_r = np.array(list_r)

        pred = model_2.predict([input_s_1, input_s_2, input_s_3,
                                input_t_1, input_t_2, input_t_3, input_r], verbose = 0)

        for i in range(pred.shape[0]):
            score += float(pred[i])

        #average the score
        score = score/float(count)
        
    return(score)

#### Not fine tuned 

In [31]:
########################################################
#obtain the Hits@N for relation prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #run the path-based scoring
    score_dict_path = path_based_relation_scoring(s_true, t_true, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    #final score dict
    score_dict = defaultdict(float)
    
    for r in score_dict_path:
        score_dict[r] += score_dict_path[r]
    for r in score_dict_subg:
        score_dict[r] += score_dict_subg[r]
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

200 9
200 64 27
checkcorrect 62 62 real score 1.9108411490917205 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 478
200 150
200 97 703
checkcorrect 158 158 real score 1.6362456351518633 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 1 478
200 150
200 50 356
checkcorrect 74 74 real score 1.9243963003158568 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 2 478
200 39
200 123 67
checkcorrect 216 216 real score 1.4700393468141555 Hits@1 0.75 Hits@3 1.0 Hits@10 1.0 MRR 0.875 cur_rank 1 abs_cur_rank 1 total_num 3 478
200 73
200 82 73
checkcorrect 380 380 real score 1.0794658376835287 Hits@1 0.6 Hits@3 0.8 Hits@10 1.0 MRR 0.75 cur_rank 3 abs_cur_rank 3 total_num 4 478
0 1
200 3 3
checkcorrect 66 66 real score 0.9896178603172302 Hits@1 0.6666666666666666 Hits@3 0.8333333333333334 Hits@10 1.0 MRR 0.7916666666666666 cur_rank 0 abs_cur_rank 0 total_num 5 478
200 150
200 611 339
checkcorrect 48 48 real sc

200 85
200 20 174
checkcorrect 74 74 real score 1.9228682994842528 Hits@1 0.8444444444444444 Hits@3 0.9333333333333333 Hits@10 0.9777777777777777 MRR 0.902020202020202 cur_rank 0 abs_cur_rank 0 total_num 44 478
200 149
200 163 277
checkcorrect 78 78 real score 1.8364808320999146 Hits@1 0.8478260869565217 Hits@3 0.9347826086956522 Hits@10 0.9782608695652174 MRR 0.9041501976284586 cur_rank 0 abs_cur_rank 0 total_num 45 478
200 150
200 824 748
checkcorrect 228 228 real score 1.9637518525123596 Hits@1 0.8297872340425532 Hits@3 0.9361702127659575 Hits@10 0.9787234042553191 MRR 0.8955512572533849 cur_rank 1 abs_cur_rank 2 total_num 46 478
200 41
200 24 98
checkcorrect 80 80 real score 1.0792312014847993 Hits@1 0.8333333333333334 Hits@3 0.9375 Hits@10 0.9791666666666666 MRR 0.8977272727272728 cur_rank 0 abs_cur_rank 0 total_num 47 478
200 150
200 355 521
checkcorrect 138 138 real score 1.659937709569931 Hits@1 0.8163265306122449 Hits@3 0.9387755102040817 Hits@10 0.9795918367346939 MRR 0.88961

200 147
200 112 306
checkcorrect 12 12 real score 1.8252848625183105 Hits@1 0.8235294117647058 Hits@3 0.9647058823529412 Hits@10 0.9882352941176471 MRR 0.899108734402852 cur_rank 0 abs_cur_rank 0 total_num 84 478
200 136
200 52 57
checkcorrect 258 258 real score 1.9561445593833922 Hits@1 0.8255813953488372 Hits@3 0.9651162790697675 Hits@10 0.9883720930232558 MRR 0.9002818886539816 cur_rank 0 abs_cur_rank 0 total_num 85 478
200 5
200 369 255
checkcorrect 178 178 real score 1.4796856105327607 Hits@1 0.8160919540229885 Hits@3 0.9540229885057471 Hits@10 0.9885057471264368 MRR 0.8922326715430163 cur_rank 4 abs_cur_rank 6 total_num 86 478
200 51
200 112 99
checkcorrect 74 74 real score 1.686429101228714 Hits@1 0.8181818181818182 Hits@3 0.9545454545454546 Hits@10 0.9886363636363636 MRR 0.8934573002754821 cur_rank 0 abs_cur_rank 0 total_num 87 478
200 134
200 65 203
checkcorrect 38 38 real score 1.8777843296527863 Hits@1 0.8202247191011236 Hits@3 0.9550561797752809 Hits@10 0.9887640449438202 M

200 150
200 495 583
checkcorrect 42 42 real score 1.852071362733841 Hits@1 0.8145161290322581 Hits@3 0.9596774193548387 Hits@10 0.9919354838709677 MRR 0.8923998044965786 cur_rank 0 abs_cur_rank 1 total_num 123 478
200 121
200 51 374
checkcorrect 74 74 real score 1.8251884758472443 Hits@1 0.816 Hits@3 0.96 Hits@10 0.992 MRR 0.8932606060606061 cur_rank 0 abs_cur_rank 0 total_num 124 478
200 114
200 64 50
checkcorrect 80 80 real score 1.9218526899814607 Hits@1 0.8174603174603174 Hits@3 0.9603174603174603 Hits@10 0.9920634920634921 MRR 0.8941077441077441 cur_rank 0 abs_cur_rank 0 total_num 125 478
200 150
200 396 403
checkcorrect 28 28 real score 1.9344902157783508 Hits@1 0.8188976377952756 Hits@3 0.9606299212598425 Hits@10 0.9921259842519685 MRR 0.8949415413982343 cur_rank 0 abs_cur_rank 0 total_num 126 478
200 56
200 70 54
checkcorrect 20 20 real score 1.1833135902881622 Hits@1 0.8125 Hits@3 0.953125 Hits@10 0.9921875 MRR 0.8899029356060606 cur_rank 3 abs_cur_rank 3 total_num 127 478
200

200 78
200 122 71
checkcorrect 12 12 real score 1.721058374643326 Hits@1 0.8170731707317073 Hits@3 0.9390243902439024 Hits@10 0.9939024390243902 MRR 0.8877155703375216 cur_rank 0 abs_cur_rank 0 total_num 163 478
200 150
200 346 561
checkcorrect 48 48 real score 1.9711020171642304 Hits@1 0.8181818181818182 Hits@3 0.9393939393939394 Hits@10 0.9939393939393939 MRR 0.8883960820324457 cur_rank 0 abs_cur_rank 0 total_num 164 478
200 5
200 20 152
checkcorrect 302 302 real score 0.8113203063607216 Hits@1 0.8132530120481928 Hits@3 0.9337349397590361 Hits@10 0.9939759036144579 MRR 0.8842491176828525 cur_rank 4 abs_cur_rank 4 total_num 165 478
200 103
200 164 702
checkcorrect 214 214 real score 1.8021708011627195 Hits@1 0.8143712574850299 Hits@3 0.9341317365269461 Hits@10 0.9940119760479041 MRR 0.8849422367386438 cur_rank 0 abs_cur_rank 2 total_num 166 478
200 150
200 254 834
checkcorrect 18 18 real score 1.9884229838848113 Hits@1 0.8154761904761905 Hits@3 0.9345238095238095 Hits@10 0.99404761904

200 44
200 114 134
checkcorrect 164 164 real score 1.5182994425296783 Hits@1 0.8177339901477833 Hits@3 0.9310344827586207 Hits@10 0.9802955665024631 MRR 0.8819482723555855 cur_rank 2 abs_cur_rank 2 total_num 202 478
200 20
200 253 174
checkcorrect 42 42 real score 1.8481536090373991 Hits@1 0.8186274509803921 Hits@3 0.9313725490196079 Hits@10 0.9803921568627451 MRR 0.8825269572950188 cur_rank 0 abs_cur_rank 0 total_num 203 478
200 150
200 71 322
checkcorrect 12 12 real score 1.916122806072235 Hits@1 0.8195121951219512 Hits@3 0.9317073170731708 Hits@10 0.9804878048780488 MRR 0.8830999965277261 cur_rank 0 abs_cur_rank 0 total_num 204 478
200 65
200 87 244
checkcorrect 68 68 real score 1.6640178680419924 Hits@1 0.8203883495145631 Hits@3 0.9320388349514563 Hits@10 0.9805825242718447 MRR 0.8836674722727371 cur_rank 0 abs_cur_rank 1 total_num 205 478
200 129
200 664 182
checkcorrect 70 70 real score 1.7586157739162445 Hits@1 0.821256038647343 Hits@3 0.9323671497584541 Hits@10 0.98067632850241

200 150
200 333 225
checkcorrect 74 74 real score 1.9049596905708313 Hits@1 0.8181818181818182 Hits@3 0.9338842975206612 Hits@10 0.9752066115702479 MRR 0.8814119658488214 cur_rank 0 abs_cur_rank 0 total_num 241 478
0 1
200 402 71
checkcorrect 202 202 real score 0.6594969987869262 Hits@1 0.8148148148148148 Hits@3 0.9300411522633745 Hits@10 0.9753086419753086 MRR 0.8783726455895965 cur_rank 6 abs_cur_rank 6 total_num 242 478
200 120
200 300 109
checkcorrect 24 24 real score 1.9750674724578858 Hits@1 0.8155737704918032 Hits@3 0.930327868852459 Hits@10 0.9754098360655737 MRR 0.8788711183535735 cur_rank 0 abs_cur_rank 0 total_num 243 478
200 150
200 114 271
checkcorrect 74 74 real score 1.8727739572525026 Hits@1 0.8163265306122449 Hits@3 0.9306122448979591 Hits@10 0.9755102040816327 MRR 0.8793655219521304 cur_rank 0 abs_cur_rank 0 total_num 244 478
200 150
200 201 608
checkcorrect 16 16 real score 1.9925607860088348 Hits@1 0.8170731707317073 Hits@3 0.9308943089430894 Hits@10 0.9756097560975

0 1
200 135 702
checkcorrect 0 0 real score 0.8881226778030396 Hits@1 0.8149466192170819 Hits@3 0.9288256227758007 Hits@10 0.9715302491103203 MRR 0.8774146876770331 cur_rank 0 abs_cur_rank 0 total_num 280 478
200 137
200 140 191
checkcorrect 26 26 real score 1.6935439258813858 Hits@1 0.8156028368794326 Hits@3 0.9290780141843972 Hits@10 0.9716312056737588 MRR 0.8778493873661216 cur_rank 0 abs_cur_rank 0 total_num 281 478
200 150
200 196 308
checkcorrect 208 208 real score 1.8821768939495085 Hits@1 0.8162544169611308 Hits@3 0.9293286219081273 Hits@10 0.9717314487632509 MRR 0.8782810149726018 cur_rank 0 abs_cur_rank 0 total_num 282 478
200 149
200 248 200
checkcorrect 24 24 real score 1.7593129195272923 Hits@1 0.8169014084507042 Hits@3 0.9295774647887324 Hits@10 0.971830985915493 MRR 0.8787096029480503 cur_rank 0 abs_cur_rank 0 total_num 283 478
200 135
200 247 293
checkcorrect 68 68 real score 1.9195111095905304 Hits@1 0.8175438596491228 Hits@3 0.9298245614035088 Hits@10 0.97192982456140

200 142
200 207 112
checkcorrect 26 26 real score 1.741340148448944 Hits@1 0.821875 Hits@3 0.93125 Hits@10 0.971875 MRR 0.880935019529975 cur_rank 0 abs_cur_rank 0 total_num 319 478
200 57
200 25 69
checkcorrect 118 118 real score 1.831160020828247 Hits@1 0.822429906542056 Hits@3 0.9314641744548287 Hits@10 0.9719626168224299 MRR 0.8813059384722493 cur_rank 0 abs_cur_rank 0 total_num 320 478
200 150
200 292 408
checkcorrect 24 24 real score 1.8454511940479277 Hits@1 0.8229813664596274 Hits@3 0.9316770186335404 Hits@10 0.9720496894409938 MRR 0.8816745535701614 cur_rank 0 abs_cur_rank 0 total_num 321 478
200 150
200 737 357
checkcorrect 96 96 real score 1.949259376525879 Hits@1 0.8204334365325078 Hits@3 0.9318885448916409 Hits@10 0.9721362229102167 MRR 0.880492898605548 cur_rank 1 abs_cur_rank 1 total_num 322 478
0 0
0 64 1
checkcorrect 296 296 real score 0.0 Hits@1 0.8179012345679012 Hits@3 0.9290123456790124 Hits@10 0.9691358024691358 MRR 0.8777960421573703 cur_rank 148 abs_cur_rank 148

200 150
200 85 236
checkcorrect 74 74 real score 1.8896984636783598 Hits@1 0.8133704735376045 Hits@3 0.9192200557103064 Hits@10 0.9665738161559888 MRR 0.872705339287708 cur_rank 0 abs_cur_rank 0 total_num 358 478
200 60
200 158 101
checkcorrect 8 8 real score 1.9604782938957215 Hits@1 0.8138888888888889 Hits@3 0.9194444444444444 Hits@10 0.9666666666666667 MRR 0.8730589355674643 cur_rank 0 abs_cur_rank 0 total_num 359 478
200 85
200 51 452
checkcorrect 124 124 real score 1.8293111503124238 Hits@1 0.814404432132964 Hits@3 0.9196675900277008 Hits@10 0.9667590027700831 MRR 0.8734105728650614 cur_rank 0 abs_cur_rank 0 total_num 360 478
200 150
200 95 121
checkcorrect 98 98 real score 1.949452668428421 Hits@1 0.8149171270718232 Hits@3 0.919889502762431 Hits@10 0.9668508287292817 MRR 0.8737602674151579 cur_rank 0 abs_cur_rank 0 total_num 361 478
200 150
200 843 303
checkcorrect 48 48 real score 1.8171990677714347 Hits@1 0.8154269972451791 Hits@3 0.9201101928374655 Hits@10 0.9669421487603306 M

200 17
200 55 22
checkcorrect 74 74 real score 1.7133781850337981 Hits@1 0.8165829145728644 Hits@3 0.9170854271356784 Hits@10 0.964824120603015 MRR 0.8737249486070349 cur_rank 0 abs_cur_rank 0 total_num 397 478
200 52
200 113 201
checkcorrect 242 242 real score 1.7363760769367218 Hits@1 0.8170426065162907 Hits@3 0.9172932330827067 Hits@10 0.9649122807017544 MRR 0.8740414274325812 cur_rank 0 abs_cur_rank 1 total_num 398 478
200 78
200 136 349
checkcorrect 200 200 real score 1.9779090881347656 Hits@1 0.8175 Hits@3 0.9175 Hits@10 0.965 MRR 0.8743563238639998 cur_rank 0 abs_cur_rank 0 total_num 399 478
200 140
200 43 57
checkcorrect 72 72 real score 1.9699953317642211 Hits@1 0.8179551122194514 Hits@3 0.9177057356608479 Hits@10 0.9650872817955112 MRR 0.8746696497396506 cur_rank 0 abs_cur_rank 0 total_num 400 478
200 150
200 103 869
checkcorrect 56 56 real score 1.958080530166626 Hits@1 0.818407960199005 Hits@3 0.917910447761194 Hits@10 0.9651741293532339 MRR 0.8749814167800992 cur_rank 0 ab

200 150
200 847 191
checkcorrect 134 134 real score 1.8982531368732452 Hits@1 0.8260869565217391 Hits@3 0.9221967963386728 Hits@10 0.9679633867276888 MRR 0.8809897701272309 cur_rank 0 abs_cur_rank 0 total_num 436 478
200 149
200 364 239
checkcorrect 62 62 real score 1.8970001995563508 Hits@1 0.8264840182648402 Hits@3 0.9223744292237442 Hits@10 0.9680365296803652 MRR 0.8812614829808216 cur_rank 0 abs_cur_rank 0 total_num 437 478
0 2
200 92 340
checkcorrect 364 364 real score 0.884582805633545 Hits@1 0.826879271070615 Hits@3 0.9225512528473804 Hits@10 0.9681093394077449 MRR 0.8815319579626422 cur_rank 0 abs_cur_rank 1 total_num 438 478
200 104
200 524 372
checkcorrect 28 28 real score 1.904426145553589 Hits@1 0.8272727272727273 Hits@3 0.9227272727272727 Hits@10 0.9681818181818181 MRR 0.8818012035127271 cur_rank 0 abs_cur_rank 0 total_num 439 478
200 58
200 280 152
checkcorrect 28 28 real score 1.8506553769111633 Hits@1 0.8276643990929705 Hits@3 0.9229024943310657 Hits@10 0.96825396825396

200 150
200 96 174
checkcorrect 26 26 real score 1.9433418631553652 Hits@1 0.8298319327731093 Hits@3 0.9243697478991597 Hits@10 0.9684873949579832 MRR 0.8833110739753498 cur_rank 0 abs_cur_rank 0 total_num 475 478
200 24
200 70 62
checkcorrect 10 10 real score 1.825635540485382 Hits@1 0.8301886792452831 Hits@3 0.9245283018867925 Hits@10 0.9685534591194969 MRR 0.8835557048475189 cur_rank 0 abs_cur_rank 0 total_num 476 478
200 150
200 405 547
checkcorrect 48 48 real score 1.9568548321723938 Hits@1 0.8305439330543933 Hits@3 0.9246861924686193 Hits@10 0.9686192468619247 MRR 0.8837993121595534 cur_rank 0 abs_cur_rank 0 total_num 477 478


In [32]:
###########################################
##obtain the AUC-PR for the test triples###
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#we build the negative samples by randomly replace head or tail entity in the triple.
neg_triples = list()

for i in range(len(pos_triples)):
    
    s_pos, r_pos, t_pos = pos_triples[i][0], pos_triples[i][1], pos_triples[i][2]
    
    #decide to replace the head or tail entity
    number_0 = random.uniform(0, 1)
    
    if number_0 < 0.5: #replace head entity
        
        s_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_neg, r_pos, t_pos) in data_test) or (
               (s_neg, r_pos, t_pos) in data_valid) or (
               (s_neg, r_pos, t_pos) in data) or (
               (s_neg, r_pos, t_pos) in data_ind) or (
               (s_neg, r_pos, t_pos) in data_ind_valid) or (
               (s_neg, r_pos, t_pos) in data_ind_test):
            
            s_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_neg, r_pos, t_pos))
    
    else: #replace tail entity

        t_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_pos, r_pos, t_neg) in data_test) or (
               (s_pos, r_pos, t_neg) in data_valid) or (
               (s_pos, r_pos, t_neg) in data) or (
               (s_pos, r_pos, t_neg) in data_ind) or (
               (s_pos, r_pos, t_neg) in data_ind_valid) or (
               (s_pos, r_pos, t_neg) in data_ind_test):
            
            t_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_pos, r_pos, t_neg))

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')
        
#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score aray
y_score = np.zeros((len(y_test),))

#implement the scoring
for i in range(len(all_triples)):
    
    s, r, t = all_triples[i][0], all_triples[i][1], all_triples[i][2]
    
    path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)
    
    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    ave_score = (path_score + subg_score)/float(2)
    
    y_score[i] = ave_score
    
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))
        
        # Data to plot precision - recall curve
        precision, recall, thresholds = precision_recall_curve(y_test[:i], y_score[:i])
        # Use AUC function to calculate the area under the curve of precision recall curve
        auc_precision_recall = auc(recall, precision)
        print('AUC-PR is:', auc_precision_recall)
        
        
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_test, y_score)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('AUC-PR is:', auc_precision_recall)

evaluating scores 20 956
AUC-PR is: 0.9428104575163401
evaluating scores 40 956
AUC-PR is: 0.9304020326079152
evaluating scores 60 956
AUC-PR is: 0.9667135622483582
evaluating scores 80 956
AUC-PR is: 0.9713179240822574
evaluating scores 100 956
AUC-PR is: 0.9573408606611016
evaluating scores 120 956
AUC-PR is: 0.9410166569651407
evaluating scores 140 956
AUC-PR is: 0.9360520513576525
evaluating scores 160 956
AUC-PR is: 0.9349266988461623
evaluating scores 180 956
AUC-PR is: 0.935354040836165
evaluating scores 200 956
AUC-PR is: 0.9330070390575947
evaluating scores 220 956
AUC-PR is: 0.9280776893187379
evaluating scores 240 956
AUC-PR is: 0.9363563679084085
evaluating scores 260 956
AUC-PR is: 0.9358637288791885
evaluating scores 280 956
AUC-PR is: 0.9364016428630065
evaluating scores 300 956
AUC-PR is: 0.9407425780391334
evaluating scores 320 956
AUC-PR is: 0.943779486524014
evaluating scores 340 956
AUC-PR is: 0.9422729729801615
evaluating scores 360 956
AUC-PR is: 0.934823234647539

In [33]:
##########################################################
##obtain the AUC-PR for the test triples, using sklearn###
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#we build the negative samples by randomly replace head or tail entity in the triple.
neg_triples = list()

for i in range(len(pos_triples)):
    
    s_pos, r_pos, t_pos = pos_triples[i][0], pos_triples[i][1], pos_triples[i][2]
    
    #decide to replace the head or tail entity
    number_0 = random.uniform(0, 1)
    
    if number_0 < 0.5: #replace head entity
        
        s_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_neg, r_pos, t_pos) in data_test) or (
               (s_neg, r_pos, t_pos) in data_valid) or (
               (s_neg, r_pos, t_pos) in data) or (
               (s_neg, r_pos, t_pos) in data_ind) or (
               (s_neg, r_pos, t_pos) in data_ind_valid) or (
               (s_neg, r_pos, t_pos) in data_ind_test):
            
            s_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_neg, r_pos, t_pos))
    
    else: #replace tail entity

        t_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_pos, r_pos, t_neg) in data_test) or (
               (s_pos, r_pos, t_neg) in data_valid) or (
               (s_pos, r_pos, t_neg) in data) or (
               (s_pos, r_pos, t_neg) in data_ind) or (
               (s_pos, r_pos, t_neg) in data_ind_valid) or (
               (s_pos, r_pos, t_neg) in data_ind_test):
            
            t_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_pos, r_pos, t_neg))

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')
        
#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score aray
y_score = np.zeros((len(y_test),))

#implement the scoring
for i in range(len(all_triples)):
    
    s, r, t = all_triples[i][0], all_triples[i][1], all_triples[i][2]
    
    path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)
    
    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    ave_score = (path_score + subg_score)/float(2)
    
    y_score[i] = ave_score
    
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))
        auc = metrics.roc_auc_score(y_test[:i], y_score[:i])
        auc_pr = metrics.average_precision_score(y_test[:i], y_score[:i])
        print('auc, auc-pr', auc, auc_pr)
        
print('evaluating scores', i, len(all_triples))
auc = metrics.roc_auc_score(y_test, y_score)
auc_pr = metrics.average_precision_score(y_test, y_score)
print('(final) auc, auc-pr', auc, auc_pr)

evaluating scores 20 956
auc, auc-pr 1.0 1.0
evaluating scores 40 956
auc, auc-pr 0.9799498746867168 0.9739721590035262
evaluating scores 60 956
auc, auc-pr 0.9830316742081447 0.9751607570180985
evaluating scores 80 956
auc, auc-pr 0.9561904761904761 0.9483390526464994
evaluating scores 100 956
auc, auc-pr 0.9614443084455324 0.9550752489734347
evaluating scores 120 956
auc, auc-pr 0.9459307237397916 0.9414779867147277
evaluating scores 140 956
auc, auc-pr 0.9317617866004962 0.924296481755754
evaluating scores 160 956
auc, auc-pr 0.931704260651629 0.9299228119999657
evaluating scores 180 956
auc, auc-pr 0.9329117063492063 0.9291536097395061
evaluating scores 200 956
auc, auc-pr 0.9351823937292735 0.9309768265221041
evaluating scores 220 956
auc, auc-pr 0.9388262599469496 0.9362063074756625
evaluating scores 240 956
auc, auc-pr 0.9382483987747146 0.9352695227184451
evaluating scores 260 956
auc, auc-pr 0.9401222479378079 0.9372858906252133
evaluating scores 280 956
auc, auc-pr 0.94302081

In [34]:
######################################################
#obtain the Hits@N for entity prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    triple_list = list()
    
    #score the true triple
    s_pos, r_pos, t_pos = selected[i][0], selected[i][1], selected[i][2]

    path_score = path_based_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    ave_score = (path_score + subg_score)/float(2)
    
    triple_list.append([(s_pos, r_pos, t_pos), ave_score])
    
    #generate the 50 random samples
    for sub_i in range(50):
        
        #decide to replace the head or tail entity
        number_0 = random.uniform(0, 1)

        if number_0 < 0.5: #replace head entity
            
            s_neg = random.choice(list(new_ent_set))
            
            while ((s_neg, r_pos, t_pos) in data_test) or (
                   (s_neg, r_pos, t_pos) in data_valid) or (
                   (s_neg, r_pos, t_pos) in data) or (
                   (s_neg, r_pos, t_pos) in data_ind) or (
                   (s_neg, r_pos, t_pos) in data_ind_valid) or (
                   (s_neg, r_pos, t_pos) in data_ind_test):

                s_neg = random.choice(list(new_ent_set))
            
            path_score = path_based_triple_scoring(s_neg, r_pos, t_pos, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

            subg_score = subgraph_triple_scoring(s_neg, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

            ave_score = (path_score + subg_score)/float(2)

            triple_list.append([(s_neg, r_pos, t_pos), ave_score])
            
        else: #replace tail entity

            t_neg = random.choice(list(new_ent_set))
            
            #filter out the existing triples
            while ((s_pos, r_pos, t_neg) in data_test) or (
                   (s_pos, r_pos, t_neg) in data_valid) or (
                   (s_pos, r_pos, t_neg) in data) or (
                   (s_pos, r_pos, t_neg) in data_ind) or (
                   (s_pos, r_pos, t_neg) in data_ind_valid) or (
                   (s_pos, r_pos, t_neg) in data_ind_test):

                t_neg = random.choice(list(new_ent_set))
            
            path_score = path_based_triple_scoring(s_pos, r_pos, t_neg, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

            subg_score = subgraph_triple_scoring(s_pos, r_pos, t_neg, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

            ave_score = (path_score + subg_score)/float(2)

            triple_list.append([(s_pos, r_pos, t_neg), ave_score])
            
    #random shuffle!
    random.shuffle(triple_list)
    
    #sort
    sorted_list = sorted(triple_list, key = lambda x: x[-1], reverse=True)
    
    p = 0
    
    while p < len(sorted_list) and sorted_list[p][0] != (s_pos, r_pos, t_pos):
            
        p += 1
    
    if p == 0:
        
        Hits_at_1 += 1
        
    if p < 3:
        
        Hits_at_3 += 1
        
    if p < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p + 1.) 
        
    print('checkcorrect', (s_pos, r_pos, t_pos), sorted_list[p][0],
          'real score', sorted_list[p][-1],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'rank', p,
          'total_num', i, len(selected))

checkcorrect (2903, 62, 3913) (2903, 62, 3913) real score 0.9517079740762711 Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 rank 0 total_num 0 478
checkcorrect (3160, 158, 2713) (3160, 158, 2713) real score 0.7914817497134208 Hits@1 0.5 Hits@3 0.5 Hits@10 1.0 MRR 0.5833333333333334 rank 5 total_num 1 478
checkcorrect (3741, 74, 3484) (3741, 74, 3484) real score 0.958466213941574 Hits@1 0.3333333333333333 Hits@3 0.6666666666666666 Hits@10 1.0 MRR 0.5 rank 2 total_num 2 478
checkcorrect (3615, 216, 3911) (3615, 216, 3911) real score 0.7211202189326287 Hits@1 0.25 Hits@3 0.5 Hits@10 1.0 MRR 0.40625 rank 7 total_num 3 478
checkcorrect (3617, 380, 3338) (3617, 380, 3338) real score 0.6343149177730083 Hits@1 0.2 Hits@3 0.6 Hits@10 1.0 MRR 0.39166666666666666 rank 2 total_num 4 478
checkcorrect (3196, 66, 2895) (3196, 66, 2895) real score 0.49480889439582826 Hits@1 0.3333333333333333 Hits@3 0.6666666666666666 Hits@10 1.0 MRR 0.4930555555555555 rank 0 total_num 5 478
checkcorrect (2774, 48, 3656) (

checkcorrect (2649, 78, 2798) (2649, 78, 2798) real score 0.9355483949184418 Hits@1 0.3695652173913043 Hits@3 0.6739130434782609 Hits@10 0.9782608695652174 MRR 0.5460769090244793 rank 5 total_num 45 478
checkcorrect (2678, 228, 3103) (2678, 228, 3103) real score 0.9857043206691742 Hits@1 0.3829787234042553 Hits@3 0.6808510638297872 Hits@10 0.9787234042553191 MRR 0.5557348471303415 rank 0 total_num 46 478
checkcorrect (3225, 80, 3612) (3225, 80, 3612) real score 0.6196498405188322 Hits@1 0.375 Hits@3 0.6875 Hits@10 0.9791666666666666 MRR 0.5511014822595705 rank 2 total_num 47 478
checkcorrect (3866, 138, 2765) (3866, 138, 2765) real score 0.8432230561971664 Hits@1 0.3673469387755102 Hits@3 0.673469387755102 Hits@10 0.9795918367346939 MRR 0.5439361458869261 rank 4 total_num 48 478
checkcorrect (2756, 48, 3750) (2756, 48, 3750) real score 0.9865304350852966 Hits@1 0.38 Hits@3 0.68 Hits@10 0.98 MRR 0.5530574229691876 rank 0 total_num 49 478
checkcorrect (3279, 68, 2672) (3279, 68, 2672) re

checkcorrect (3794, 74, 3655) (3794, 74, 3655) real score 0.8279321998357774 Hits@1 0.4318181818181818 Hits@3 0.7159090909090909 Hits@10 0.9659090909090909 MRR 0.6003978864272983 rank 11 total_num 87 478
checkcorrect (4138, 38, 4022) (4138, 38, 4022) real score 0.9370129048824309 Hits@1 0.42696629213483145 Hits@3 0.7078651685393258 Hits@10 0.9662921348314607 MRR 0.5955245019356057 rank 5 total_num 88 478
checkcorrect (3042, 206, 3224) (3042, 206, 3224) real score 0.9167108595371246 Hits@1 0.4222222222222222 Hits@3 0.7111111111111111 Hits@10 0.9666666666666667 MRR 0.5944631185807657 rank 1 total_num 89 478
checkcorrect (3038, 40, 4032) (3038, 40, 4032) real score 0.9866609513759613 Hits@1 0.42857142857142855 Hits@3 0.7142857142857143 Hits@10 0.967032967032967 MRR 0.5989195678271308 rank 0 total_num 90 478
checkcorrect (2741, 28, 2742) (2741, 28, 2742) real score 0.8791063711047173 Hits@1 0.42391304347826086 Hits@3 0.7065217391304348 Hits@10 0.967391304347826 MRR 0.5945834855681403 rank 

checkcorrect (2622, 70, 3720) (2622, 70, 3720) real score 0.8451667606830597 Hits@1 0.3953488372093023 Hits@3 0.7054263565891473 Hits@10 0.9457364341085271 MRR 0.5775291385414504 rank 2 total_num 128 478
checkcorrect (3034, 68, 3887) (3034, 68, 3887) real score 0.911619558930397 Hits@1 0.4 Hits@3 0.7076923076923077 Hits@10 0.9461538461538461 MRR 0.5807789143988238 rank 0 total_num 129 478
checkcorrect (3015, 24, 2765) (3015, 24, 2765) real score 0.9829557865858078 Hits@1 0.40458015267175573 Hits@3 0.7099236641221374 Hits@10 0.9465648854961832 MRR 0.5839790753576114 rank 0 total_num 130 478
checkcorrect (3448, 48, 4029) (3448, 48, 4029) real score 0.9843365788459777 Hits@1 0.4015151515151515 Hits@3 0.7121212121212122 Hits@10 0.946969696969697 MRR 0.5833428702412659 rank 1 total_num 131 478
checkcorrect (2678, 134, 3969) (2678, 134, 3969) real score 0.8768970191478729 Hits@1 0.39849624060150374 Hits@3 0.706766917293233 Hits@10 0.9473684210526315 MRR 0.5804605930214067 rank 4 total_num 13

checkcorrect (3265, 12, 2875) (3265, 12, 2875) real score 0.9401580691337585 Hits@1 0.3764705882352941 Hits@3 0.6823529411764706 Hits@10 0.9529411764705882 MRR 0.5633829674746629 rank 1 total_num 169 478
checkcorrect (4096, 8, 3594) (4096, 8, 3594) real score 0.48189271688461305 Hits@1 0.38011695906432746 Hits@3 0.6842105263157895 Hits@10 0.9532163742690059 MRR 0.5659362834543432 rank 0 total_num 170 478
checkcorrect (2959, 20, 3611) (2959, 20, 3611) real score 0.9285505264997482 Hits@1 0.38372093023255816 Hits@3 0.686046511627907 Hits@10 0.9534883720930233 MRR 0.5684599097133296 rank 0 total_num 171 478
checkcorrect (4048, 26, 3557) (4048, 26, 3557) real score 0.47381618320941926 Hits@1 0.3815028901734104 Hits@3 0.6820809248554913 Hits@10 0.953757225433526 MRR 0.5659997781130048 rank 6 total_num 172 478
checkcorrect (2974, 12, 2875) (2974, 12, 2875) real score 0.8818329155445099 Hits@1 0.3793103448275862 Hits@3 0.6781609195402298 Hits@10 0.9540229885057471 MRR 0.5633216184686772 rank 

checkcorrect (3103, 228, 3051) (3103, 228, 3051) real score 0.9851728707551957 Hits@1 0.3981042654028436 Hits@3 0.6872037914691943 Hits@10 0.943127962085308 MRR 0.5750963095524118 rank 0 total_num 210 478
checkcorrect (3019, 20, 3416) (3019, 20, 3416) real score 0.9371877402067185 Hits@1 0.4009433962264151 Hits@3 0.6886792452830188 Hits@10 0.9433962264150944 MRR 0.5771005722432022 rank 0 total_num 211 478
checkcorrect (3923, 80, 2652) (3923, 80, 2652) real score 0.9799066722393036 Hits@1 0.40375586854460094 Hits@3 0.6901408450704225 Hits@10 0.9436619718309859 MRR 0.5790860155660041 rank 0 total_num 212 478
checkcorrect (2678, 18, 2817) (2678, 18, 2817) real score 0.9674478322267532 Hits@1 0.40654205607476634 Hits@3 0.6915887850467289 Hits@10 0.9439252336448598 MRR 0.581052903343733 rank 0 total_num 213 478
checkcorrect (3091, 230, 2665) (3091, 230, 2665) real score 0.0 Hits@1 0.4046511627906977 Hits@3 0.6883720930232559 Hits@10 0.9395348837209302 MRR 0.5784610736714809 rank 41 total_nu

checkcorrect (3154, 146, 2616) (3154, 146, 2616) real score 0.4935576766729355 Hits@1 0.4087301587301587 Hits@3 0.6984126984126984 Hits@10 0.9444444444444444 MRR 0.5843916223687553 rank 1 total_num 251 478
checkcorrect (2609, 48, 2878) (2609, 48, 2878) real score 0.9816827088594436 Hits@1 0.41106719367588934 Hits@3 0.6996047430830039 Hits@10 0.9446640316205533 MRR 0.5860343432289579 rank 0 total_num 252 478
checkcorrect (2772, 96, 2744) (2772, 96, 2744) real score 0.9882924437522889 Hits@1 0.41338582677165353 Hits@3 0.7007874015748031 Hits@10 0.9448818897637795 MRR 0.5876641292792375 rank 0 total_num 253 478
checkcorrect (3320, 28, 2978) (3320, 28, 2978) real score 0.7678994722664356 Hits@1 0.4117647058823529 Hits@3 0.6980392156862745 Hits@10 0.9450980392156862 MRR 0.5858497601448092 rank 7 total_num 254 478
checkcorrect (2614, 158, 3479) (2614, 158, 3479) real score 0.40132721662521365 Hits@1 0.41015625 Hits@3 0.6953125 Hits@10 0.94140625 MRR 0.5838617652884742 rank 12 total_num 255 4

checkcorrect (3680, 364, 2729) (3680, 364, 2729) real score 0.8319695740938187 Hits@1 0.3993174061433447 Hits@3 0.6757679180887372 Hits@10 0.931740614334471 MRR 0.570825580395318 rank 1 total_num 292 478
checkcorrect (2733, 28, 2732) (2733, 28, 2732) real score 0.9245199203491211 Hits@1 0.3979591836734694 Hits@3 0.6768707482993197 Hits@10 0.9319727891156463 MRR 0.5705846770606401 rank 1 total_num 293 478
checkcorrect (2817, 134, 2823) (2817, 134, 2823) real score 0.8381123871542513 Hits@1 0.39661016949152544 Hits@3 0.6745762711864407 Hits@10 0.9322033898305084 MRR 0.5693284578163667 rank 4 total_num 294 478
checkcorrect (4047, 182, 3353) (4047, 182, 3353) real score 0.9092901021242141 Hits@1 0.39864864864864863 Hits@3 0.6756756756756757 Hits@10 0.9324324324324325 MRR 0.5707834292426628 rank 0 total_num 295 478
checkcorrect (4008, 86, 4252) (4008, 86, 4252) real score 0.4103895902633667 Hits@1 0.39730639730639733 Hits@3 0.6734006734006734 Hits@10 0.9326599326599326 MRR 0.569535000187973

checkcorrect (2638, 42, 3130) (2638, 42, 3130) real score 0.9420070677995682 Hits@1 0.38323353293413176 Hits@3 0.6676646706586826 Hits@10 0.9281437125748503 MRR 0.5589181252152164 rank 3 total_num 333 478
checkcorrect (3576, 24, 2619) (3576, 24, 2619) real score 0.8888330712914467 Hits@1 0.382089552238806 Hits@3 0.6686567164179105 Hits@10 0.9283582089552239 MRR 0.5587422502145739 rank 1 total_num 334 478
checkcorrect (3998, 310, 3497) (3998, 310, 3497) real score 0.0 Hits@1 0.38095238095238093 Hits@3 0.6666666666666666 Hits@10 0.9255952380952381 MRR 0.557193795715309 rank 25 total_num 335 478
checkcorrect (3105, 80, 2611) (3105, 80, 2611) real score 0.9735586971044541 Hits@1 0.3827893175074184 Hits@3 0.6676557863501483 Hits@10 0.9258160237388724 MRR 0.5585077607131864 rank 0 total_num 336 478
checkcorrect (2758, 38, 4033) (2758, 38, 4033) real score 0.9586482614278793 Hits@1 0.3816568047337278 Hits@3 0.6686390532544378 Hits@10 0.9260355029585798 MRR 0.5583346608294196 rank 1 total_num 

checkcorrect (2691, 48, 2934) (2691, 48, 2934) real score 0.9804432541131973 Hits@1 0.38666666666666666 Hits@3 0.672 Hits@10 0.9253333333333333 MRR 0.5617047374200135 rank 2 total_num 374 478
checkcorrect (2758, 48, 2639) (2758, 48, 2639) real score 0.9753034979104995 Hits@1 0.38563829787234044 Hits@3 0.6728723404255319 Hits@10 0.925531914893617 MRR 0.5610973666644639 rank 2 total_num 375 478
checkcorrect (2901, 70, 2641) (2901, 70, 2641) real score 0.9581540197134018 Hits@1 0.38461538461538464 Hits@3 0.6737400530503979 Hits@10 0.9257294429708223 MRR 0.5609353046839215 rank 1 total_num 376 478
checkcorrect (3886, 10, 3131) (3886, 10, 3131) real score 0.9695760518312454 Hits@1 0.3862433862433862 Hits@3 0.6746031746031746 Hits@10 0.9259259259259259 MRR 0.562096851496927 rank 0 total_num 377 478
checkcorrect (3279, 132, 2809) (3279, 132, 2809) real score 0.8931122660636901 Hits@1 0.38522427440633245 Hits@3 0.6728232189973615 Hits@10 0.9261213720316622 MRR 0.5609069155064631 rank 8 total_n

checkcorrect (2825, 106, 3132) (2825, 106, 3132) real score 0.37193499207496644 Hits@1 0.38221153846153844 Hits@3 0.6730769230769231 Hits@10 0.9230769230769231 MRR 0.5587715846944842 rank 7 total_num 415 478
checkcorrect (3251, 80, 3191) (3251, 80, 3191) real score 0.9393518745899201 Hits@1 0.381294964028777 Hits@3 0.6714628297362111 Hits@10 0.9232613908872902 MRR 0.5577741879514689 rank 6 total_num 416 478
checkcorrect (2825, 22, 4212) (2825, 22, 4212) real score 0.9641784191131592 Hits@1 0.3827751196172249 Hits@3 0.6722488038277512 Hits@10 0.9234449760765551 MRR 0.5588321444396233 rank 0 total_num 417 478
checkcorrect (4250, 42, 3094) (4250, 42, 3094) real score 0.9412201225757599 Hits@1 0.3818615751789976 Hits@3 0.6706443914081146 Hits@10 0.9236276849642004 MRR 0.5578393640062522 rank 6 total_num 418 478
checkcorrect (3794, 80, 2965) (3794, 80, 2965) real score 0.9349306643009185 Hits@1 0.38095238095238093 Hits@3 0.669047619047619 Hits@10 0.9238095238095239 MRR 0.556987365520523 ran

checkcorrect (2691, 48, 2896) (2691, 48, 2896) real score 0.9854595929384231 Hits@1 0.36323851203501095 Hits@3 0.6542669584245077 Hits@10 0.925601750547046 MRR 0.543781275898729 rank 0 total_num 456 478
checkcorrect (2760, 62, 2758) (2760, 62, 2758) real score 0.6708037823438644 Hits@1 0.3624454148471616 Hits@3 0.6528384279475983 Hits@10 0.925764192139738 MRR 0.5430306617592121 rank 4 total_num 457 478
checkcorrect (2944, 244, 2677) (2944, 244, 2677) real score 0.9585238426923752 Hits@1 0.3638344226579521 Hits@3 0.6535947712418301 Hits@10 0.9259259259259259 MRR 0.5440262376595189 rank 0 total_num 458 478
checkcorrect (3316, 56, 3831) (3316, 56, 3831) real score 0.8399241983890533 Hits@1 0.3630434782608696 Hits@3 0.6543478260869565 Hits@10 0.9260869565217391 MRR 0.5435682096066359 rank 2 total_num 459 478
checkcorrect (2721, 44, 3260) (2721, 44, 3260) real score 0.7592933133244515 Hits@1 0.36225596529284165 Hits@3 0.6550976138828634 Hits@10 0.9262472885032538 MRR 0.5431121686602729 rank

#### Fine tuned

In [34]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    count = 0
    for s in existing_ids:
        
        #impliment the path finding algorithm to find paths between s and t
        result, length_dict = Class_2.obtain_paths('direct_neighbour', s, 'nb', lower_bd, upper_bd, one_hop)
        
        for iteration in range(2):

            #proceed only if at least three paths are between s and t
            for t in result:

                if len(s_t_r[(s,t)]) == 0:

                    raise ValueError(s,t,id2entity[s], id2entity[t])

                #we are only interested in forward link in relation prediciton
                ini_r_list = list()

                #obtain initial relations between s and t
                for r in s_t_r[(s,t)]:
                    if r % 2 == 0:#initial relation id is always an even number
                        ini_r_list.append(r)

                #if there exist more than three paths between s and t, 
                #and inital connection between s and t exists,
                #and not every r in the relation dictionary exists between s and t (although this is rare)
                #we then proceed
                if len(result[t]) >= 3 and len(ini_r_list) > 0 and len(ini_r_list) < int(num_ini_r):

                    #obtain the list form of all the paths from s to t
                    temp_path_list = list(result[t])

                    temp_pair = random.sample(temp_path_list, 3)

                    path_1, path_2, path_3 = temp_pair[0], temp_pair[1], temp_pair[2]

                    #####positive#####################
                    #append the paths: note that we add the space holder id at the end of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    r = random.choice(ini_r_list)
                    x_r_list.append([r])
                    y_list.append(1.)

                    #####negative#####################
                    #append the paths: note that we add the space holder id at the end
                    #of the shorter path
                    x_p_list['1'].append(list(path_1) + [num_r]*abs(len(path_1)-upper_bd))
                    x_p_list['2'].append(list(path_2) + [num_r]*abs(len(path_2)-upper_bd))
                    x_p_list['3'].append(list(path_3) + [num_r]*abs(len(path_3)-upper_bd))

                    #append relation
                    neg_r_list = list(ini_r_id_set.difference(set(ini_r_list)))
                    r_ran = random.choice(neg_r_list)
                    x_r_list.append([r_ran])
                    y_list.append(0.)
        
        count += 1
        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

In [35]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
    
    #the ids to start path finding
    existing_ids = list(existing_ids)
    random.shuffle(existing_ids)
    
    #Dict stores the subgraph for each entity
    Dict_1 = dict()
    
    count = 0
    for s in existing_ids:
        
        path_set = set()
            
        result, length_dict = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        for t_ in result:
            for path in result[t_]:
                path_set.add(path)

        del(result, length_dict)
        
        path_list = list(path_set)
        
        path_select = random.sample(path_list, min(len(path_list), 100))
            
        Dict_1[s] = deepcopy(path_select)
        
        count += 1
        if count % 100 == 0:
            print('generating and storing paths for the path-based model', count, len(existing_ids))
        
    return(Dict_1)

In [36]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    
    #the set of all relation IDs
    relation_id_set = set()
    
    #the set of all initial relations
    ini_r_id_set = set()
    
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
        
        if i % 2 == 0: #initial relation id is always an even number
            ini_r_id_set.add(i)
    
    num_r = len(id2relation)
    num_ini_r = len(ini_r_id_set)
    
    if num_ini_r != int(num_r/2):
        raise ValueError('error when generating id2relation')
        
    #if an entity has at least three out-stretching paths, it is a qualified one
    qualified = set()
    for e in Dict:
        if len(Dict[e]) >= 3:
            qualified.add(e)
    qualified = list(qualified)
    
    data = list(data)
    
    for iteration in range(2):

        data = shuffle(data)

        for i_0 in range(len(data)):

            triple = data[i_0]

            s, r, t = triple[0], triple[1], triple[2] #obtain entities and relation IDs

            if s in qualified and t in qualified:

                #obtain the path list for true entities
                path_s, path_t = list(Dict[s]), list(Dict[t])

                #####positive step###########
                #randomly obtain three paths for true entities
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative step for relation###########
                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                neg_r_list = list(ini_r_id_set.difference({r}))
                r_ran = random.choice(neg_r_list)
                x_r_list.append([r_ran])
                y_list.append(0.)
                
                ##############################################
                ##############################################
                #randomly choose two negative sampled entities
                s_ran = random.choice(qualified)
                t_ran = random.choice(qualified)

                #obtain the path list for random entities
                path_s_ran, path_t_ran = list(Dict[s_ran]), list(Dict[t_ran])
                
                #####positive step#################
                #Again: randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative for source entity###########
                #randomly obtain three paths
                temp_s = random.sample(path_s_ran, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(0.)

                #####positive step###########
                #Again: randomly obtain three paths
                temp_s = random.sample(path_s, 3)
                temp_t = random.sample(path_t, 3)
                s_p_1, s_p_2, s_p_3 = temp_s[0], temp_s[1], temp_s[2]
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(1.)

                #####negative for target entity###########
                #randomly obtain three paths
                temp_t = random.sample(path_t_ran, 3)
                t_p_1, t_p_2, t_p_3 = temp_t[0], temp_t[1], temp_t[2]

                #append the paths: note that we add the space holder id at the end of the shorter path
                x_s_list['1'].append(list(s_p_1) + [num_r]*abs(len(s_p_1)-upper_bd))
                x_s_list['2'].append(list(s_p_2) + [num_r]*abs(len(s_p_2)-upper_bd))
                x_s_list['3'].append(list(s_p_3) + [num_r]*abs(len(s_p_3)-upper_bd))

                x_t_list['1'].append(list(t_p_1) + [num_r]*abs(len(t_p_1)-upper_bd))
                x_t_list['2'].append(list(t_p_2) + [num_r]*abs(len(t_p_2)-upper_bd))
                x_t_list['3'].append(list(t_p_3) + [num_r]*abs(len(t_p_3)-upper_bd))

                #append relation
                x_r_list.append([r])
                y_list.append(0.)

            if i_0 % 200 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(data), iteration)

In [37]:
###fine tune the path-based model
lower_bd = lower_bound
upper_bd = upper_bound_path
batch_size = 32

#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)   

#######################################
###do the training#####################

#generate the input arrays
x_train_1 = np.asarray(train_p_list['1'], dtype='int')
x_train_2 = np.asarray(train_p_list['2'], dtype='int')
x_train_3 = np.asarray(train_p_list['3'], dtype='int')
x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
           batch_size=batch_size, epochs=5)

generating big-batches for path-based model 100 922
generating big-batches for path-based model 200 922
generating big-batches for path-based model 300 922
generating big-batches for path-based model 400 922
generating big-batches for path-based model 500 922
generating big-batches for path-based model 600 922
generating big-batches for path-based model 700 922
generating big-batches for path-based model 800 922
generating big-batches for path-based model 900 922
Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7ff2daab26a0>

In [38]:
###fine tune the subgraph model
lower_bd = lower_bound
upper_bd = upper_bound_subg
batch_size = 32

Dict_train_ind = store_subgraph_dicts(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                         relation2id, entity2id, id2relation, id2entity)

#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data_ind, one_hop_ind, s_t_r_ind,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train_ind,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################

#generate the input arrays
x_train_s_1 = np.asarray(train_s_list['1'], dtype='int')
x_train_s_2 = np.asarray(train_s_list['2'], dtype='int')
x_train_s_3 = np.asarray(train_s_list['3'], dtype='int')

x_train_t_1 = np.asarray(train_t_list['1'], dtype='int')
x_train_t_2 = np.asarray(train_t_list['2'], dtype='int')
x_train_t_3 = np.asarray(train_t_list['3'], dtype='int')

x_train_r = np.asarray(train_r_list, dtype='int')
y_train = np.asarray(train_y_list, dtype='int')

model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, 
             x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train,
             batch_size=batch_size, epochs=5)

generating and storing paths for the path-based model 100 922
generating and storing paths for the path-based model 200 922
generating and storing paths for the path-based model 300 922
generating and storing paths for the path-based model 400 922
generating and storing paths for the path-based model 500 922
generating and storing paths for the path-based model 600 922
generating and storing paths for the path-based model 700 922
generating and storing paths for the path-based model 800 922
generating and storing paths for the path-based model 900 922
generating big-batches for subgraph-based model 0 1618 0
generating big-batches for subgraph-based model 200 1618 0
generating big-batches for subgraph-based model 400 1618 0
generating big-batches for subgraph-based model 600 1618 0
generating big-batches for subgraph-based model 800 1618 0
generating big-batches for subgraph-based model 1000 1618 0
generating big-batches for subgraph-based model 1200 1618 0
generating big-batches for su

<keras.callbacks.History at 0x7ff290652df0>

In [39]:
########################################################
#obtain the Hits@N for relation prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #run the path-based scoring
    score_dict_path = path_based_relation_scoring(s_true, t_true, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)
    
    #run the one-hop neighbour based scoring
    score_dict_subg = subgraph_relation_scoring(s_true, t_true, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    #final score dict
    score_dict = defaultdict(float)
    
    for r in score_dict_path:
        score_dict[r] += score_dict_path[r]
    for r in score_dict_subg:
        score_dict[r] += score_dict_subg[r]
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        #again, we only care about initial relation prediciton
        if r % 2 == 0:
        
            if r in score_dict:

                temp_list.append([score_dict[r], r])

            else:

                temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
            (s_true, sorted_list[p][1], t_true) in data_valid) or (
            (s_true, sorted_list[p][1], t_true) in data) or (
            (s_true, sorted_list[p][1], t_true) in data_ind) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_valid) or (
            (s_true, sorted_list[p][1], t_true) in data_ind_test):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('checkcorrect', r_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

0 0
9 10 10
checkcorrect 0 0 real score 0.3973050579428673 Hits@1 0.0 Hits@3 1.0 Hits@10 1.0 MRR 0.5 cur_rank 1 abs_cur_rank 1 total_num 0 188
0 1
9 11 31
checkcorrect 0 0 real score 0.7147584974765777 Hits@1 0.5 Hits@3 1.0 Hits@10 1.0 MRR 0.75 cur_rank 0 abs_cur_rank 0 total_num 1 188
0 1
9 34 10
checkcorrect 0 0 real score 0.7171109318733215 Hits@1 0.6666666666666666 Hits@3 1.0 Hits@10 1.0 MRR 0.8333333333333334 cur_rank 0 abs_cur_rank 0 total_num 2 188
9 61
9 13 22
checkcorrect 0 0 real score 1.6940622508525849 Hits@1 0.75 Hits@3 1.0 Hits@10 1.0 MRR 0.875 cur_rank 0 abs_cur_rank 0 total_num 3 188
9 11
9 47 31
checkcorrect 0 0 real score 1.4414083257317545 Hits@1 0.8 Hits@3 1.0 Hits@10 1.0 MRR 0.9 cur_rank 0 abs_cur_rank 0 total_num 4 188
0 2
9 14 58
checkcorrect 8 8 real score 0.2447445958852768 Hits@1 0.6666666666666666 Hits@3 1.0 Hits@10 1.0 MRR 0.8333333333333334 cur_rank 1 abs_cur_rank 1 total_num 5 188
9 30
9 22 9
checkcorrect 8 8 real score 0.5786192186176777 Hits@1 0.57142857

9 38
9 18 8
checkcorrect 0 0 real score 1.2535340383648874 Hits@1 0.7843137254901961 Hits@3 0.9803921568627451 Hits@10 1.0 MRR 0.8830065359477124 cur_rank 0 abs_cur_rank 0 total_num 50 188
9 54
9 23 20
checkcorrect 0 0 real score 1.7500863373279572 Hits@1 0.7884615384615384 Hits@3 0.9807692307692307 Hits@10 1.0 MRR 0.8852564102564102 cur_rank 0 abs_cur_rank 0 total_num 51 188
9 5
9 71 28
checkcorrect 0 0 real score 1.5344414383172988 Hits@1 0.7924528301886793 Hits@3 0.9811320754716981 Hits@10 1.0 MRR 0.8874213836477987 cur_rank 0 abs_cur_rank 0 total_num 52 188
9 53
9 35 13
checkcorrect 0 0 real score 1.4713187865912913 Hits@1 0.7962962962962963 Hits@3 0.9814814814814815 Hits@10 1.0 MRR 0.8895061728395062 cur_rank 0 abs_cur_rank 0 total_num 53 188
9 4
9 23 22
checkcorrect 10 10 real score 1.7709940746426582 Hits@1 0.8 Hits@3 0.9818181818181818 Hits@10 1.0 MRR 0.8915151515151515 cur_rank 0 abs_cur_rank 0 total_num 54 188
9 86
9 32 19
checkcorrect 0 0 real score 1.4118973553180694 Hits@1

9 150
9 12 13
checkcorrect 0 0 real score 1.5491866767406464 Hits@1 0.8526315789473684 Hits@3 0.9789473684210527 Hits@10 1.0 MRR 0.9164912280701754 cur_rank 0 abs_cur_rank 0 total_num 94 188
9 7
9 26 9
checkcorrect 0 0 real score 1.6799812585115432 Hits@1 0.8541666666666666 Hits@3 0.9791666666666666 Hits@10 1.0 MRR 0.9173611111111111 cur_rank 0 abs_cur_rank 0 total_num 95 188
0 0
0 28 1
checkcorrect 0 0 real score 0.0 Hits@1 0.8556701030927835 Hits@3 0.979381443298969 Hits@10 1.0 MRR 0.9182130584192439 cur_rank 0 abs_cur_rank 0 total_num 96 188
0 0
9 32 5
checkcorrect 8 8 real score 0.2664215698838234 Hits@1 0.8469387755102041 Hits@3 0.9795918367346939 Hits@10 1.0 MRR 0.9139455782312925 cur_rank 1 abs_cur_rank 1 total_num 97 188
9 65
9 27 34
checkcorrect 0 0 real score 1.5753029331564905 Hits@1 0.8484848484848485 Hits@3 0.9797979797979798 Hits@10 1.0 MRR 0.9148148148148147 cur_rank 0 abs_cur_rank 0 total_num 98 188
0 1
9 15 46
checkcorrect 0 0 real score 0.6504505783319473 Hits@1 0.85 

9 45
9 28 37
checkcorrect 0 0 real score 1.6402869790792465 Hits@1 0.8561151079136691 Hits@3 0.9784172661870504 Hits@10 1.0 MRR 0.9189448441247002 cur_rank 0 abs_cur_rank 0 total_num 138 188
0 1
9 45 22
checkcorrect 0 0 real score 0.6212806850671768 Hits@1 0.8571428571428571 Hits@3 0.9785714285714285 Hits@10 1.0 MRR 0.9195238095238096 cur_rank 0 abs_cur_rank 0 total_num 139 188
9 93
9 31 26
checkcorrect 0 0 real score 1.6512427419424056 Hits@1 0.8581560283687943 Hits@3 0.9787234042553191 Hits@10 1.0 MRR 0.9200945626477542 cur_rank 0 abs_cur_rank 0 total_num 140 188
9 87
9 39 44
checkcorrect 0 0 real score 1.5928798407316207 Hits@1 0.8591549295774648 Hits@3 0.9788732394366197 Hits@10 1.0 MRR 0.9206572769953053 cur_rank 0 abs_cur_rank 0 total_num 141 188
0 2
9 14 9
checkcorrect 0 0 real score 0.46829487979412077 Hits@1 0.8531468531468531 Hits@3 0.9790209790209791 Hits@10 1.0 MRR 0.9177156177156178 cur_rank 1 abs_cur_rank 1 total_num 142 188
9 132
9 58 37
checkcorrect 0 0 real score 1.661

9 16 34
checkcorrect 0 0 real score 0.6142259627580643 Hits@1 0.8524590163934426 Hits@3 0.9726775956284153 Hits@10 1.0 MRR 0.9163023679417123 cur_rank 0 abs_cur_rank 0 total_num 182 188
9 4
9 22 19
checkcorrect 10 10 real score 1.9311034083366394 Hits@1 0.8532608695652174 Hits@3 0.9728260869565217 Hits@10 1.0 MRR 0.9167572463768117 cur_rank 0 abs_cur_rank 0 total_num 183 188
9 22
9 50 58
checkcorrect 0 0 real score 1.5718801110982894 Hits@1 0.8540540540540541 Hits@3 0.972972972972973 Hits@10 1.0 MRR 0.9172072072072073 cur_rank 0 abs_cur_rank 0 total_num 184 188
9 18
9 16 10
checkcorrect 0 0 real score 1.3778696570545435 Hits@1 0.8548387096774194 Hits@3 0.9731182795698925 Hits@10 1.0 MRR 0.9176523297491039 cur_rank 0 abs_cur_rank 0 total_num 185 188
9 14
9 12 33
checkcorrect 0 0 real score 1.2510857969522475 Hits@1 0.8556149732620321 Hits@3 0.9732620320855615 Hits@10 1.0 MRR 0.9180926916221034 cur_rank 0 abs_cur_rank 0 total_num 186 188
0 0
0 1 51
checkcorrect 0 0 real score 0.0 Hits@1 

In [40]:
###########################################
##obtain the AUC-PR for the test triples###
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc, plot_precision_recall_curve
import matplotlib.pyplot as plt

#we select all the triples in the inductive test set
pos_triples = list(data_ind_test)

#we build the negative samples by randomly replace head or tail entity in the triple.
neg_triples = list()

for i in range(len(pos_triples)):
    
    s_pos, r_pos, t_pos = pos_triples[i][0], pos_triples[i][1], pos_triples[i][2]
    
    #decide to replace the head or tail entity
    number_0 = random.uniform(0, 1)
    
    if number_0 < 0.5: #replace head entity
        
        s_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_neg, r_pos, t_pos) in data_test) or (
               (s_neg, r_pos, t_pos) in data_valid) or (
               (s_neg, r_pos, t_pos) in data) or (
               (s_neg, r_pos, t_pos) in data_ind) or (
               (s_neg, r_pos, t_pos) in data_ind_valid) or (
               (s_neg, r_pos, t_pos) in data_ind_test):
            
            s_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_neg, r_pos, t_pos))
    
    else: #replace tail entity

        t_neg = random.choice(list(new_ent_set))
        
        #filter out the existing triples
        while ((s_pos, r_pos, t_neg) in data_test) or (
               (s_pos, r_pos, t_neg) in data_valid) or (
               (s_pos, r_pos, t_neg) in data) or (
               (s_pos, r_pos, t_neg) in data_ind) or (
               (s_pos, r_pos, t_neg) in data_ind_valid) or (
               (s_pos, r_pos, t_neg) in data_ind_test):
            
            t_neg = random.choice(list(new_ent_set))
        
        neg_triples.append((s_pos, r_pos, t_neg))

if len(pos_triples) != len(neg_triples):
    raise ValueError('error when generating negative triples')
        
#combine all triples
all_triples = pos_triples + neg_triples

#obtain the label array
arr1 = np.ones((len(pos_triples),))
arr2 = np.zeros((len(neg_triples),))
y_test = np.concatenate((arr1, arr2))

#shuffle positive and negative triples (optional)
all_triples, y_test = shuffle(all_triples, y_test)

#obtain the score aray
y_score = np.zeros((len(y_test),))

#implement the scoring
for i in range(len(all_triples)):
    
    s, r, t = all_triples[i][0], all_triples[i][1], all_triples[i][2]
    
    #path_score = path_based_triple_scoring(s, r, t, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)
    
    subg_score = subgraph_triple_scoring(s, r, t, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    #ave_score = (path_score + subg_score)/float(2)
    
    #y_score[i] = ave_score
    y_score[i] = subg_score
    
    if i % 20 == 0 and i > 0:
        print('evaluating scores', i, len(all_triples))
        
        # Data to plot precision - recall curve
        precision, recall, thresholds = precision_recall_curve(y_test[:i], y_score[:i])
        # Use AUC function to calculate the area under the curve of precision recall curve
        auc_precision_recall = auc(recall, precision)
        print('AUC-PR is:', auc_precision_recall)
        
        
# Data to plot precision - recall curve
precision, recall, thresholds = precision_recall_curve(y_test, y_score)
# Use AUC function to calculate the area under the curve of precision recall curve
auc_precision_recall = auc(recall, precision)
print('AUC-PR is:', auc_precision_recall)

evaluating scores 20 376
AUC-PR is: 0.8158124039702987
evaluating scores 40 376
AUC-PR is: 0.770603450876653
evaluating scores 60 376
AUC-PR is: 0.7587513046465775
evaluating scores 80 376
AUC-PR is: 0.7889110246028787
evaluating scores 100 376
AUC-PR is: 0.77768457943568
evaluating scores 120 376
AUC-PR is: 0.7778079025696366
evaluating scores 140 376
AUC-PR is: 0.7901361514723642
evaluating scores 160 376
AUC-PR is: 0.7539759441124845
evaluating scores 180 376
AUC-PR is: 0.7480333614245009
evaluating scores 200 376
AUC-PR is: 0.7373484679242981
evaluating scores 220 376
AUC-PR is: 0.7403965974871085
evaluating scores 240 376
AUC-PR is: 0.7600854965707611
evaluating scores 260 376
AUC-PR is: 0.7360974188142251
evaluating scores 280 376
AUC-PR is: 0.7377639545950183
evaluating scores 300 376
AUC-PR is: 0.725380399187881
evaluating scores 320 376
AUC-PR is: 0.7198041945682027
evaluating scores 340 376
AUC-PR is: 0.7173594115404267
evaluating scores 360 376
AUC-PR is: 0.7250464293392811


In [None]:
######################################################
#obtain the Hits@N for entity prediction##############

#we select all the triples in the inductive test set
selected = list(data_ind_test)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    triple_list = list()
    
    #score the true triple
    s_pos, r_pos, t_pos = selected[i][0], selected[i][1], selected[i][2]

    #path_score = path_based_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

    subg_score = subgraph_triple_scoring(s_pos, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)
    
    #ave_score = (path_score + subg_score)/float(2)
    
    triple_list.append([(s_pos, r_pos, t_pos), subg_score])
    
    #generate the 50 random samples
    for sub_i in range(50):
        
        #decide to replace the head or tail entity
        number_0 = random.uniform(0, 1)

        if number_0 < 0.5: #replace head entity
            
            s_neg = random.choice(list(new_ent_set))
            
            while ((s_neg, r_pos, t_pos) in data_test) or (
                   (s_neg, r_pos, t_pos) in data_valid) or (
                   (s_neg, r_pos, t_pos) in data) or (
                   (s_neg, r_pos, t_pos) in data_ind) or (
                   (s_neg, r_pos, t_pos) in data_ind_valid) or (
                   (s_neg, r_pos, t_pos) in data_ind_test):

                s_neg = random.choice(list(new_ent_set))
            
            #path_score = path_based_triple_scoring(s_neg, r_pos, t_pos, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

            subg_score = subgraph_triple_scoring(s_neg, r_pos, t_pos, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

            #ave_score = (path_score + subg_score)/float(2)

            triple_list.append([(s_neg, r_pos, t_pos), subg_score])
            
        else: #replace tail entity

            t_neg = random.choice(list(new_ent_set))
            
            #filter out the existing triples
            while ((s_pos, r_pos, t_neg) in data_test) or (
                   (s_pos, r_pos, t_neg) in data_valid) or (
                   (s_pos, r_pos, t_neg) in data) or (
                   (s_pos, r_pos, t_neg) in data_ind) or (
                   (s_pos, r_pos, t_neg) in data_ind_valid) or (
                   (s_pos, r_pos, t_neg) in data_ind_test):

                t_neg = random.choice(list(new_ent_set))
            
            #path_score = path_based_triple_scoring(s_pos, r_pos, t_neg, lower_bound, upper_bound_path, one_hop_ind, id2relation, model)

            subg_score = subgraph_triple_scoring(s_pos, r_pos, t_neg, lower_bound, upper_bound_subg, one_hop_ind, id2relation, model_2)

            #ave_score = (path_score + subg_score)/float(2)

            triple_list.append([(s_pos, r_pos, t_neg), subg_score])
            
    #random shuffle!
    random.shuffle(triple_list)
    
    #sort
    sorted_list = sorted(triple_list, key = lambda x: x[-1], reverse=True)
    
    p = 0
    
    while p < len(sorted_list) and sorted_list[p][0] != (s_pos, r_pos, t_pos):
            
        p += 1
    
    if p == 0:
        
        Hits_at_1 += 1
        
    if p < 3:
        
        Hits_at_3 += 1
        
    if p < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p + 1.) 
        
    print('checkcorrect', (s_pos, r_pos, t_pos), sorted_list[p][0],
          'real score', sorted_list[p][-1],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'rank', p,
          'total_num', i, len(selected))

checkcorrect (3616, 0, 3274) (3616, 0, 3274) real score 0.4064545691013336 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.03571428571428571 rank 27 total_num 0 188
checkcorrect (3613, 0, 3144) (3613, 0, 3144) real score 0.7161331176757812 Hits@1 0.0 Hits@3 0.5 Hits@10 0.5 MRR 0.26785714285714285 rank 1 total_num 1 188
checkcorrect (2799, 0, 3251) (2799, 0, 3251) real score 0.7103965640068054 Hits@1 0.0 Hits@3 0.3333333333333333 Hits@10 0.6666666666666666 MRR 0.2261904761904762 rank 6 total_num 2 188
checkcorrect (3170, 0, 2776) (3170, 0, 2776) real score 0.7012916684150696 Hits@1 0.0 Hits@3 0.25 Hits@10 0.75 MRR 0.20535714285714285 rank 6 total_num 3 188
checkcorrect (3074, 0, 3194) (3074, 0, 3194) real score 0.749379575252533 Hits@1 0.0 Hits@3 0.4 Hits@10 0.8 MRR 0.23095238095238094 rank 2 total_num 4 188
checkcorrect (3329, 8, 2803) (3329, 8, 2803) real score 0.287155294418335 Hits@1 0.0 Hits@3 0.3333333333333333 Hits@10 0.6666666666666666 MRR 0.19912698412698412 rank 24 total_num 5 188
che

checkcorrect (2940, 0, 3343) (2940, 0, 3343) real score 0.7114482671022415 Hits@1 0.0 Hits@3 0.3404255319148936 Hits@10 0.48936170212765956 MRR 0.1805547654622057 rank 3 total_num 46 188
checkcorrect (3339, 0, 3170) (3339, 0, 3170) real score 0.5601408511400223 Hits@1 0.0 Hits@3 0.3333333333333333 Hits@10 0.4791666666666667 MRR 0.17839577195097386 rank 12 total_num 47 188
checkcorrect (3212, 8, 3283) (3212, 8, 3283) real score 0.5529832601547241 Hits@1 0.0 Hits@3 0.32653061224489793 Hits@10 0.46938775510204084 MRR 0.1764557221832669 rank 11 total_num 48 188
checkcorrect (3377, 8, 3576) (3377, 8, 3576) real score 0.0 Hits@1 0.0 Hits@3 0.32 Hits@10 0.46 MRR 0.17345292352907526 rank 37 total_num 49 188
checkcorrect (3154, 0, 3597) (3154, 0, 3597) real score 0.25739180445671084 Hits@1 0.0 Hits@3 0.3137254901960784 Hits@10 0.45098039215686275 MRR 0.17058182751923118 rank 36 total_num 50 188
checkcorrect (3150, 0, 3219) (3150, 0, 3219) real score 0.726323926448822 Hits@1 0.0 Hits@3 0.3076923

checkcorrect (3572, 0, 2813) (3572, 0, 2813) real score 0.5788687020540237 Hits@1 0.02247191011235955 Hits@3 0.24719101123595505 Hits@10 0.4044943820224719 MRR 0.17062255277834687 rank 23 total_num 88 188
checkcorrect (3540, 0, 2833) (3540, 0, 2833) real score 0.6214325815439224 Hits@1 0.022222222222222223 Hits@3 0.24444444444444444 Hits@10 0.4111111111111111 MRR 0.17150452441414302 rank 3 total_num 89 188
checkcorrect (2859, 0, 3040) (2859, 0, 3040) real score 0.6616534888744354 Hits@1 0.02197802197802198 Hits@3 0.24175824175824176 Hits@10 0.4175824175824176 MRR 0.170718760409592 rank 9 total_num 90 188
checkcorrect (2956, 0, 2903) (2956, 0, 2903) real score 0.7293300211429596 Hits@1 0.021739130434782608 Hits@3 0.25 Hits@10 0.42391304347826086 MRR 0.17248631011528484 rank 2 total_num 91 188
checkcorrect (3388, 0, 2809) (3388, 0, 2809) real score 0.7448042243719101 Hits@1 0.03225806451612903 Hits@3 0.25806451612903225 Hits@10 0.43010752688172044 MRR 0.18138430678071185 rank 0 total_num