In [1]:
#name of the dataset folder under ../data/; it is also embedded in every saved artefact name
data_name = 'FB15K237'
#tag of this experiment / model variant, embedded in the saved artefact names
model_id = 'main_transductive'

In [2]:
#define the file-name stems used when saving the models and the id dictionaries;
#every stem ends with '<model_id>_<data_name>'
model_name = f'Model_{model_id}_{data_name}'
one_hop_model_name = f'One_hop_model_{model_id}_{data_name}'
ids_name = f'IDs_{model_id}_{data_name}'

In [3]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle
from sys import getsizeof

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [4]:
class LoadKG:
    """Load a knowledge-graph triple file into id-indexed lookup structures."""

    def __init__(self):

        #placeholder attribute; nothing in this class reads it (kept for compatibility)
        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """Parse a whitespace-separated triple file and fill the given containers in place.

        Every non-blank line must hold at least three fields: ``source relation target``
        (only the first three fields are used). Ids are handed out in order of first
        appearance in the file, so the same file always produces the same ids. Each
        new relation gets the next free id and its inverse, named
        ``'_inverse_' + relation``, gets the id right after it.

        Parameters
        ----------
        data_path : str
            Path of the triple file.
        one_hop : dict
            Filled with ``entity_id -> sorted list of (relation_id, neighbour_id)``,
            covering both the original and the inverse direction. Pass a dict whose
            values (if any) are sets; the lists produced here have no ``add``.
        data : set
            Receives the id triples ``(s, r, t)`` in the original direction only.
        s_t_r : collections.defaultdict(set)
            Receives ``(s, t) -> {relation ids}`` for both directions.
        entity2id, id2entity, relation2id, id2relation : dict
            Lookup tables, extended (never reset) so that several files can share ids.
            ``relation2id`` must hold an even number of entries on entry, otherwise
            the "even id = original, odd id = inverse" pairing is broken.

        Raises
        ------
        ValueError
            If a non-blank line has fewer than three fields.
        """

        #a dict used as an ordered set: removes duplicate triples but keeps file order.
        #Iterating a plain set of strings would make the ids depend on the per-process
        #string hash seed.
        triples = dict()

        with open(data_path, 'r') as f:

            for line_no, line in enumerate(f, start=1):

                fields = line.split()

                if not fields: #tolerate blank lines, e.g. a trailing newline
                    continue

                if len(fields) < 3:
                    raise ValueError(
                        'line %d of %s is not a triple: %r' % (line_no, data_path, line))

                triples[tuple(fields[:3])] = None

        ####relation dict#################
        index = len(relation2id)

        for _, relation, _ in triples:

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation
                index += 1

                #the inverse relation always takes the id right after its original
                iv_r = '_inverse_' + relation
                relation2id[iv_r] = index
                id2relation[index] = iv_r
                index += 1

        #get the id of the inverse relation: by the definition above an original relation
        #has an even id and its inverse is the following odd id
        def inverse_r(r):

            return r + 1 if r % 2 == 0 else r - 1

        ####entity dict###################
        index = len(entity2id)

        for source, _, target in triples:

            for entity in (source, target):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        #create the set of triples using id instead of string
        for source, relation, target in triples:

            s = entity2id[source]
            r = relation2id[relation]
            t = entity2id[target]
            r_inv = inverse_r(r)

            data.add((s, r, t))

            #forward edge s --r--> t
            s_t_r[(s, t)].add(r)
            one_hop.setdefault(s, set()).add((r, t))

            #backward edge t --r_inv--> s
            s_t_r[(t, s)].add(r_inv)
            one_hop.setdefault(t, set()).add((r_inv, s))

        #change each set in one_hop to a list so that it can be indexed;
        #sorting keeps the neighbour order independent of set internals
        for e in one_hop:

            one_hop[e] = sorted(one_hop[e])

In [5]:
class ObtainPathsByDynamicProgramming:
    """Randomised, budget-limited recursive enumeration of relation paths from one entity."""

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):

        #how many tuples of one_hop[node] are followed from one node
        self.amount_bd = amount_bd

        #upper limit on the number of paths stored per target entity
        self.size_bd = size_bd

        #upper limit on the expansion steps counted per path length
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        '''
        Find relation paths that start at entity s, using recursion.

        One may refer to LeetCode Problem 797 for the basic idea:
            https://leetcode.com/problems/all-paths-from-source-to-target/

        mode selects which end entities are recorded:
            'direct_neighbour' -- only the entities listed in one_hop[s]
            'target_specified' -- only t_input
            'any_target'       -- every key of one_hop

        Only paths whose length lies in [lower_bd, upper_bd] are recorded, and a path
        never visits the same entity twice.

        Returns (res, count_dict): res maps an end entity to the set of relation-id
        tuples leading to it; count_dict maps a path length to the number of
        expansion steps counted at that length.

        Raises TypeError for invalid bounds and ValueError for an unknown s or mode.
        '''

        bound_checks = (
            (lower_bd, "!!! invalid lower bound setting, must >= 1 !!!"),
            (upper_bd, "!!! invalid upper bound setting, must >= 1 !!!"),
        )

        for bound, message in bound_checks:

            #exact type test on purpose: a bool is rejected although it subclasses int
            if type(bound) is not int or bound < 1:

                raise TypeError(message)

        if lower_bd > upper_bd:

            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:

            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #the end entities that are allowed to appear in the result
        if mode == 'direct_neighbour':

            qualified_t = {edge[1] for edge in one_hop[s]}

        elif mode == 'target_specified':

            qualified_t = {t_input}

        elif mode == 'any_target':

            qualified_t = set(one_hop)

        else:

            raise ValueError('not a valid mode')

        #result dict: end entity -> set of paths (tuples of relation ids) from s
        res = defaultdict(set)

        #path length -> number of expansion steps performed at that length
        count_dict = defaultdict(int)

        def expand(node, path, visited):
            #node is the current entity, path the relation ids walked so far and
            #visited the entities already on the path (s included)

            depth = len(path)

            #record the path if its length is in range, the node is wanted and
            #the node has not yet collected size_bd paths
            if lower_bd <= depth <= upper_bd and node in qualified_t and (
                    len(res[node]) < self.size_bd):

                res[node].add(tuple(path))

            #stop at the length limit or once this length used up its step budget
            if depth >= upper_bd or count_dict[depth] > self.threshold:

                return

            edges = one_hop[node]

            #visit the outgoing edges in random order, at most amount_bd of them
            order = list(range(len(edges)))
            random.shuffle(order)

            for idx in order[:self.amount_bd]:

                r, t = edges[idx][0], edges[idx][1]

                #the step is counted even when it is not followed
                count_dict[depth] += 1

                if t in visited or count_dict[depth] > self.threshold:

                    continue

                expand(t, path + [r], visited | {t})

        expand(s, [], {s})

        return(res, count_dict)

In [6]:
#locations of the training and validation triple files of the selected dataset
train_path = f'../data/{data_name}/train.txt'
valid_path = f'../data/{data_name}/valid.txt'

In [7]:
#instantiate the helper classes:
#Class_1 parses triple files into the id dictionaries,
#Class_2 is the path finder that the batch-building functions below read as a global
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [8]:
#containers that describe the training graph
one_hop, data, s_t_r = {}, set(), defaultdict(set)

#id lookup tables, shared by the training and the validation graph
entity2id, id2entity = {}, {}
relation2id, id2relation = {}, {}

#parse the training triples into the containers above
Class_1.load_train_data(
    train_path, one_hop, data, s_t_r,
    entity2id, id2entity, relation2id, id2relation,
)

In [9]:
#separate containers for the validation graph, so that its links stay apart from training
one_hop_valid, data_valid, s_t_r_valid = {}, set(), defaultdict(set)

#parse the validation triples; the id lookup tables are shared with training and
#are extended whenever the validation file introduces a new entity or relation
Class_1.load_train_data(
    valid_path, one_hop_valid, data_valid, s_t_r_valid,
    entity2id, id2entity, relation2id, id2relation,
)

#### Build the path-based siamese neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [10]:
# Input layers: each of the three inputs is one path, given as a sequence of
# integer relation ids (variable length, hence shape=(None,))
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")
thd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input.
# The table has len(relation2id)+1 rows: the extra last id is the "space holder"
# (padding) id that fills up paths shorter than the longest one.
# NOTE(review): no mask is configured, so padded steps presumably take part in the
# LSTM and in the max-pooling below -- confirm this is intended.
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding (the same table is shared by the three path inputs)
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)
thd_p_embd = in_embd_var(thd_path)

# Separate embedding table for the candidate relation (the "output" embedding)
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#two stacked bi-directional LSTM layers, shared by the three path channels;
#150 units per direction give a 300-dimensional output per step
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)
thd_lstm_mid = lstm_layer_1(thd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)
thd_lstm_out = lstm_layer_2(thd_lstm_mid)

#max-pool over the step axis: (Batch, steps, 300) -> (Batch, 300)
fst_reduce_max = tf.reduce_max(fst_lstm_out, axis=1)
scd_reduce_max = tf.reduce_max(scd_lstm_out, axis=1)
thd_reduce_max = tf.reduce_max(thd_lstm_out, axis=1)

#concatenate the output vectors of the three siamese channels: (Batch, 900)
path_concat = layers.concatenate([fst_reduce_max, scd_reduce_max, thd_reduce_max], axis=-1)

#add dropout on top of the concatenation from all channels
dropout = layers.Dropout(0.25)(path_concat)

#project to the output embedding size with a dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(dropout)

#remove the step dimension from the relation embedding by summing over it;
#the batch builders feed exactly one relation id per sample
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Dot product of two unit vectors = cosine similarity, a value in [-1, 1]
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, thd_path, id_rela], dot_product)

2023-03-14 10:23:15.188576: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
#config the Adam optimizer
#NOTE(review): the `decay` keyword is presumably rejected by the non-legacy optimizers
#of newer TensorFlow releases -- confirm against the installed version
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
#NOTE(review): the model output is a cosine similarity in [-1, 1], while
#binary_crossentropy and binary_accuracy are defined for probabilities in [0, 1]
#-- confirm that training on the raw cosine value is intended
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

#### Build the subgraph-based siamese neural network

In [12]:
#each input is one path taken from the stored subgraph of an entity, given as a
#sequence of integer relation ids: three paths for the source entity and three
#paths for the target entity
#NOTE(review): the earlier comment described 0/1 relation-indicator vectors, but the
#batch builder feeds padded path sequences -- confirm
source_path_1 = keras.Input(shape=(None,), dtype="int32")
source_path_2 = keras.Input(shape=(None,), dtype="int32")
source_path_3 = keras.Input(shape=(None,), dtype="int32")

target_path_1 = keras.Input(shape=(None,), dtype="int32")
target_path_2 = keras.Input(shape=(None,), dtype="int32")
target_path_3 = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela_ = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input.
# The table has len(relation2id)+1 rows: the extra last id is the "space holder"
# (padding) id that fills up paths shorter than the longest one.
in_embd_var_ = layers.Embedding(len(relation2id)+1, 300)

# Obtain the source embeddings
source_embd_1 = in_embd_var_(source_path_1)
source_embd_2 = in_embd_var_(source_path_2)
source_embd_3 = in_embd_var_(source_path_3)

#Obtain the target embeddings (same embedding table as the source paths)
target_embd_1 = in_embd_var_(target_path_1)
target_embd_2 = in_embd_var_(target_path_2)
target_embd_3 = in_embd_var_(target_path_3)

# Separate embedding table for the candidate relation (the "output" embedding)
rela_embd_ = layers.Embedding(len(relation2id)+1, 300)(id_rela_)

#two stacked bi-directional LSTM layers, shared by all six path channels
lstm_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

###source lstm implementation########
#first LSTM layer
source_mid_1 = lstm_1(source_embd_1)
source_mid_2 = lstm_1(source_embd_2)
source_mid_3 = lstm_1(source_embd_3)

#second LSTM layer
source_out_1 = lstm_2(source_mid_1)
source_out_2 = lstm_2(source_mid_2)
source_out_3 = lstm_2(source_mid_3)

#max-pool over the step axis: (Batch, steps, 300) -> (Batch, 300)
source_max_1 = tf.reduce_max(source_out_1, axis=1)
source_max_2 = tf.reduce_max(source_out_2, axis=1)
source_max_3 = tf.reduce_max(source_out_3, axis=1)

#concatenate the output vectors of the three source channels: (Batch, 900)
source_concat = layers.concatenate([source_max_1, source_max_2, source_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
source_dropout = layers.Dropout(0.25)(source_concat)

###target lstm implementation########
#first LSTM layer
target_mid_1 = lstm_1(target_embd_1)
target_mid_2 = lstm_1(target_embd_2)
target_mid_3 = lstm_1(target_embd_3)

#second LSTM layer
target_out_1 = lstm_2(target_mid_1)
target_out_2 = lstm_2(target_mid_2)
target_out_3 = lstm_2(target_mid_3)

#max-pool over the step axis: (Batch, steps, 300) -> (Batch, 300)
target_max_1 = tf.reduce_max(target_out_1, axis=1)
target_max_2 = tf.reduce_max(target_out_2, axis=1)
target_max_3 = tf.reduce_max(target_out_3, axis=1)

#concatenate the output vectors of the three target channels: (Batch, 900)
target_concat = layers.concatenate([target_max_1, target_max_2, target_max_3], axis=-1)

#add dropout on top of the concatenation from all channels
target_dropout = layers.Dropout(0.25)(target_concat)

#further concatenate source and target output embeddings: (Batch, 1800)
final_concat = layers.concatenate([source_dropout, target_dropout], axis=-1)

#project to the output embedding size with a dense layer: (Batch, 300)
out_vect = layers.Dense(300, activation='tanh')(final_concat)

#remove the step dimension from the relation embedding by summing over it;
#the batch builder feeds exactly one relation id per sample
rela_out_embd_ = tf.reduce_sum(rela_embd_, axis=1)

# Normalize the vectors to have unit length
out_vect_norm = tf.math.l2_normalize(out_vect, axis=-1)
rela_out_embd_norm_ = tf.math.l2_normalize(rela_out_embd_, axis=-1)

# Dot product of two unit vectors = cosine similarity, a value in [-1, 1]
dot_product_ = layers.Dot(axes=-1)([out_vect_norm, rela_out_embd_norm_])

#put together the model
model_2 = keras.Model([source_path_1, source_path_2, source_path_3,
                       target_path_1, target_path_2, target_path_3, id_rela_], dot_product_)

In [13]:
#config the Adam optimizer (same settings as for the path-based model)
#NOTE(review): the `decay` keyword is presumably rejected by the non-legacy optimizers
#of newer TensorFlow releases -- confirm against the installed version
opt_ = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
#NOTE(review): the model output is a cosine similarity in [-1, 1], while
#binary_crossentropy and binary_accuracy are defined for probabilities in [0, 1]
#-- confirm that training on the raw cosine value is intended
model_2.compile(loss='binary_crossentropy', optimizer=opt_, metrics=['binary_accuracy'])

### Build the big-batch for path-based model
We will build the big-batch for the path-based model training. That is, we will build three lists to store the three paths of each sample, respectively. Unlike in relation prediction, the negative samples here are paths from s to another entity t', while the relation r is kept unchanged.

In order to reduce computational complexity, we will run the path-finding algorithm for each entity e in the dataset before the training. That is, for each entity e, we will have two dictionaries. Dict 1 stores the paths between e and any other entities in the dataset, while Dict 2 stores the paths between e and its direct neighbors. The two dicts will be used and kept unchanged throughout the training.

* At each step, three different paths between two entities s and t are selected. Each path is append to one of the list. 
* If this step is for positive samples, t will be the directly connected entity from s. Then, relation r will be the existing connection from s to t. The label will be 1.
* If this step is for negative samples, t will be another entity that can be reached from s, but not directly connected. Relation r will be the same. The label will be 0.
* In practice, the positive step is always followed by a negative step. That is, for the positive step we have s and t, but for the negative step we have s and t'. Relation r will be the same in both steps.
* We do this until all the entity in the dataset is visited as s. Then we complete building the entire batches for one session. The model will be trained on these batches for fixed number of epochs in each session.

**For source and target entity prediction, we will need to train using both the (s, r, t) triple and the inverse (t, r⁻¹, s) triple**

In [14]:
#function to build the big batche for path-based training
def build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity,
                      path_finder=None, num_rounds=5):
    """Append positive/negative training samples for the path-based model.

    For every entity ``s`` in ``one_hop`` the path finder is run once. Each target
    with at least three paths from ``s`` is "qualified". For every qualified direct
    neighbour ``t`` (``num_rounds`` times each) a relation ``r`` holding between ``s``
    and ``t`` is drawn and two rows are appended, always in this order:

    * positive: three random paths s -> t, relation ``r``, label 1.
    * negative: three random paths s -> t', relation ``r``, label 0.

    ``t'`` is a qualified target for which ``(s, r, t')`` is NOT a known triple, so a
    row labelled 0 never describes a true triple (in particular ``t'`` is never ``t``).
    If no such target exists for ``r``, the pair is skipped to keep labels balanced.

    Parameters
    ----------
    lower_bd, upper_bd : int
        Path length range passed to the path finder; every path is right-padded
        with the placeholder id ``len(id2relation)`` up to ``upper_bd``.
    data, relation2id, entity2id :
        Not used; kept so that existing calls keep working.
    one_hop : dict
        Adjacency structure handed to the path finder; its keys are the sources.
    s_t_r : mapping
        ``(s, t) -> set of relation ids``; only read, never extended.
    x_p_list : dict
        ``{'1': [], '2': [], '3': []}``; receives the padded paths.
    x_r_list, y_list : list
        Receive ``[r]`` and the label of every row.
    id2relation : dict
        Must contain every id in ``range(len(id2relation))``.
    id2entity : dict
        Only used to build the error raised for an inconsistent ``s_t_r``.
    path_finder : object, optional
        Object with ``obtain_paths(mode, s, t_input, lower_bd, upper_bd, one_hop)``;
        defaults to the notebook-level ``Class_2`` instance.
    num_rounds : int, optional
        How many times each qualified neighbour is visited per source.

    Raises
    ------
    ValueError
        If ``id2relation`` has a gap, or if a direct neighbour has no relation
        recorded in ``s_t_r``.
    """

    if path_finder is None:
        #notebook-level default; pass path_finder explicitly to avoid the global
        path_finder = Class_2

    #id of the placeholder appended to paths shorter than upper_bd
    num_r = len(id2relation)

    #relation ids must be exactly 0 .. num_r-1, otherwise num_r collides with a real id
    for i in range(num_r):
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')

    def pad(path):
        #right-pad a path with the placeholder id up to upper_bd
        return list(path) + [num_r]*abs(len(path)-upper_bd)

    def append_sample(paths, r, label):
        #draw three distinct paths and append one row to every output list
        chosen = random.sample(list(paths), 3)
        for key, path in zip(('1', '2', '3'), chosen):
            x_p_list[key].append(pad(path))
        x_r_list.append([r])
        y_list.append(label)

    #only the entities that really occur in one_hop can be used as sources
    existing_ids = list(one_hop)
    random.shuffle(existing_ids)

    for count, s in enumerate(existing_ids, start=1):

        #paths from s to every entity that was reached
        result_2, _ = path_finder.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        #qualified_2: any target with at least 3 different paths from s
        #qualified_1: the subset that is a direct neighbour of s
        qualified_2 = [t for t, paths in result_2.items() if len(paths) >= 3]
        qualified_1 = [t for t in qualified_2 if (s, t) in s_t_r or (t, s) in s_t_r]

        #relation id -> qualified targets t' for which (s, r, t') is not a known triple;
        #filled lazily because only a few relations leave each s
        negative_pool = dict()

        for _ in range(num_rounds):

            random.shuffle(qualified_1)

            for t in qualified_1:

                #.get() instead of [] so that a defaultdict is not silently extended
                relations = s_t_r.get((s, t))

                if not relations:
                    raise ValueError(s, t, id2entity[s], id2entity[t])

                r = random.choice(list(relations))

                if r not in negative_pool:
                    negative_pool[r] = [c for c in qualified_2
                                        if r not in s_t_r.get((s, c), ())]

                if not negative_pool[r]:
                    continue #no valid negative for this relation: skip the pair

                t_ = random.choice(negative_pool[r])

                #####positive#####################
                append_sample(result_2[t], r, 1.)

                #####negative#####################
                append_sample(result_2[t_], r, 0.)

        if count % 100 == 0:
            print('generating big-batches for path-based model', count, len(existing_ids))

### Build the big-batch for the subgraph-based network training

Unlike in relation prediction, the negative sample will be a corrupted triple (s,r,t') or (s',r,t). Also, in order to reduce computational complexity, we will store the subgraph of each entity e at the beginning.

* At each step, we will select one triple (s,r,t) from the dataset. Then, subgraphs of s and t are selected.
* We will select three paths for each of source and target entity from their subgraphs, respectively. Add them to the corresponding list.
* If this is a positive sample step, (s,r,t) is a true triple.
* If this is a negative sample step, we will use either (s',r,t) or (s,r,t').
* Similarly, one negative sample step always follows one positive step.
* After we go over all the triples, we have built the big-batch for one session. Then, a fixed number of epochs is trained. Then we enter a new session to rebuild the big-batch. However, as we indicated, the subgraphs are stored and won't change across different sessions.

In [15]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity,
                         path_finder=None, max_paths=100):
    """Run the path finder once per entity and keep a random sample of its paths.

    For every entity ``s`` in ``one_hop`` all distinct paths found from ``s`` (to any
    target) are pooled and at most ``max_paths`` of them are kept.

    Parameters
    ----------
    lower_bd, upper_bd : int
        Path length range passed to the path finder.
    data, s_t_r, relation2id, entity2id, id2entity :
        Not used; kept so that existing calls keep working.
    one_hop : dict
        Adjacency structure handed to the path finder; its keys are the sources.
    id2relation : dict
        Must contain every id in ``range(len(id2relation))``.
    path_finder : object, optional
        Object with ``obtain_paths(mode, s, t_input, lower_bd, upper_bd, one_hop)``;
        defaults to the notebook-level ``Class_2`` instance.
    max_paths : int, optional
        Upper limit on the number of paths stored per entity.

    Returns
    -------
    dict
        ``entity id -> list of paths`` (each path is a tuple of relation ids).

    Raises
    ------
    ValueError
        If ``id2relation`` has a gap in its ids.
    """

    if path_finder is None:
        #notebook-level default; pass path_finder explicitly to avoid the global
        path_finder = Class_2

    #relation ids must be exactly 0 .. len(id2relation)-1
    for i in range(len(id2relation)):
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')

    #only the entities that really occur in one_hop can be used as sources
    existing_ids = list(one_hop)
    random.shuffle(existing_ids)

    #Dict_1 stores the sampled paths (the "subgraph") of each entity
    Dict_1 = dict()

    for count, s in enumerate(existing_ids, start=1):

        result, _ = path_finder.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        #pool the paths of all targets; the set drops paths shared by several targets
        path_list = list(set().union(*result.values()))

        #random.sample already returns a new list of immutable tuples, no copy needed
        Dict_1[s] = random.sample(path_list, min(len(path_list), max_paths))

        if count % 100 == 0:
            print('generating and storing paths for the subgraph-based model', count, len(existing_ids))

    return(Dict_1)

In [16]:
#function to build the big-batch for one-hope neighbor training
def build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      x_s_list, x_t_list, x_r_list, y_list, Dict,
                      relation2id, entity2id, id2relation, id2entity):
    """Append positive/negative training samples for the subgraph-based model.

    An entity is "qualified" if ``Dict`` stores at least three paths for it. The
    triples of ``data`` are visited five times in random order. For every triple
    ``(s, r, t)`` whose two entities are qualified, four rows are appended in order:

    1. positive: three paths of ``s``, three paths of ``t``, label 1
    2. negative source: three paths of a random ``s'``, the target paths of row 1, label 0
    3. positive: freshly drawn paths of ``s`` and ``t``, label 1
    4. negative target: the source paths of row 3, three paths of a random ``t'``, label 0

    A corrupted entity is only accepted if the corrupted triple is not in ``data``,
    so a row labelled 0 never describes a known triple. If no such entity turns up
    within a bounded number of draws, that positive/negative pair is skipped.

    Parameters
    ----------
    lower_bd, one_hop, s_t_r, relation2id, entity2id, id2entity :
        Not used; kept so that existing calls keep working.
    upper_bd : int
        Every path is right-padded with the placeholder id ``len(id2relation)``
        up to this length.
    data : iterable of (s, r, t)
        The true triples; not modified.
    x_s_list, x_t_list : dict
        ``{'1': [], '2': [], '3': []}``; receive the padded source / target paths.
    x_r_list, y_list : list
        Receive ``[r]`` and the label of every row.
    Dict : dict
        ``entity id -> collection of paths``; not modified.
    id2relation : dict
        Must contain every id in ``range(len(id2relation))`` and have even length
        (every relation is stored together with its inverse).

    Raises
    ------
    ValueError
        If ``id2relation`` has a gap in its ids or an odd number of entries.
    """

    #id of the placeholder appended to paths shorter than upper_bd
    num_r = len(id2relation)

    for i in range(num_r):
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')

    #original relations take the even ids and inverses the odd ones, so the
    #number of ids has to be even
    if num_r % 2 != 0:
        raise ValueError('error when generating id2relation')

    #a list for random.choice and a set for the per-triple membership test
    #(testing membership in the list would scan it for every single triple)
    qualified = [e for e in Dict if len(Dict[e]) >= 3]
    qualified_set = set(qualified)

    true_triples = set(data)
    triples = list(true_triples)

    keys = ('1', '2', '3')
    max_tries = 100 #draws allowed when looking for a valid corrupted entity

    def pad(path):
        #right-pad a path with the placeholder id up to upper_bd
        return list(path) + [num_r]*abs(len(path)-upper_bd)

    def draw_paths(entity):
        #three distinct stored paths of the entity
        return random.sample(list(Dict[entity]), 3)

    def append_sample(source_paths, target_paths, r, label):
        #append one row to every output list
        for key, s_path, t_path in zip(keys, source_paths, target_paths):
            x_s_list[key].append(pad(s_path))
            x_t_list[key].append(pad(t_path))
        x_r_list.append([r])
        y_list.append(label)

    def draw_corrupted(is_true_with):
        #random qualified entity whose corrupted triple is not a known triple;
        #None if max_tries draws did not produce one
        for _ in range(max_tries):
            candidate = random.choice(qualified)
            if not is_true_with(candidate):
                return candidate
        return None

    for iteration in range(5):

        #random.shuffle keeps all randomness of the notebook on one generator
        random.shuffle(triples)

        for i_0, (s, r, t) in enumerate(triples):

            if s in qualified_set and t in qualified_set:

                #####positive step + negative for source entity###########
                s_ran = draw_corrupted(lambda e: (e, r, t) in true_triples)

                if s_ran is not None:
                    source_paths, target_paths = draw_paths(s), draw_paths(t)
                    append_sample(source_paths, target_paths, r, 1.)
                    append_sample(draw_paths(s_ran), target_paths, r, 0.)

                #####positive step + negative for target entity###########
                t_ran = draw_corrupted(lambda e: (s, r, e) in true_triples)

                if t_ran is not None:
                    source_paths, target_paths = draw_paths(s), draw_paths(t)
                    append_sample(source_paths, target_paths, r, 1.)
                    append_sample(source_paths, draw_paths(t_ran), r, 0.)

            if i_0 % 2000 == 0:
                print('generating big-batches for subgraph-based model', i_0, len(triples), iteration)

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [17]:
model_name

'Model_main_transductive_FB15K237'

In [18]:
one_hop_model_name

'One_hop_model_main_transductive_FB15K237'

In [19]:
ids_name

'IDs_main_transductive_FB15K237'

In [20]:
#first, we save the relation and ids
#everything a later session needs to rebuild the graphs goes into one dictionary
Dict = {
    #training data
    'one_hop': one_hop,
    'data': data,
    's_t_r': s_t_r,

    #valid data
    'one_hop_valid': one_hop_valid,
    'data_valid': data_valid,
    's_t_r_valid': s_t_r_valid,

    #shared dictionaries
    'entity2id': entity2id,
    'id2entity': id2entity,
    'relation2id': relation2id,
    'id2relation': id2relation,
}

#open(..., 'wb') does not create missing folders, so make sure the output
#folder exists before writing (otherwise this cell fails on a fresh checkout)
os.makedirs('../weight_bin', exist_ok=True)

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [21]:
###train the path-based model
lower_bd = 1
upper_bd = 10
num_epoch = 10
batch_size = 32


#define the training lists
train_p_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_p_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############

#fill in the training array list
build_big_batches_path(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_p_list, train_r_list, train_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_path(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_p_list, valid_r_list, valid_y_list,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################
def _path_batch_arrays(p_list, r_list, y_list, rows=slice(None)):
    """Turn the selected rows of one big-batch into integer arrays.

    p_list maps '1', '2', '3' to the three path lists; r_list and y_list hold
    the relations and labels. `rows` is the slice of examples to keep (all by default).
    Returns ([x_1, x_2, x_3, x_r], y).
    """
    x = [np.asarray(p_list[key][rows], dtype='int') for key in ('1', '2', '3')]
    x.append(np.asarray(r_list[rows], dtype='int'))
    return x, np.asarray(y_list[rows], dtype='int')

#sometimes the validation dataset is so small so sparse,
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    (x_train_1, x_train_2, x_train_3, x_train_r), y_train = _path_batch_arrays(
        train_p_list, train_r_list, train_y_list)

    #generate the validation arrays
    (x_valid_1, x_valid_2, x_valid_3, x_valid_r), y_valid = _path_batch_arrays(
        valid_p_list, valid_r_list, valid_y_list)

else:
    #first 80% of the training big-batch for fitting, the remaining 20% for validation
    split = int(len(train_y_list)*0.8)

    #generate the input arrays
    (x_train_1, x_train_2, x_train_3, x_train_r), y_train = _path_batch_arrays(
        train_p_list, train_r_list, train_y_list, slice(None, split))

    #generate the validation arrays
    (x_valid_1, x_valid_2, x_valid_3, x_valid_r), y_valid = _path_batch_arrays(
        train_p_list, train_r_list, train_y_list, slice(split, None))

#do the training
model.fit([x_train_1, x_train_2, x_train_3, x_train_r], y_train,
          validation_data=([x_valid_1, x_valid_2, x_valid_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

#exist_ok replaces the separate isdir check, so there is no gap between check and creation
os.makedirs(save_dir, exist_ok=True)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

generating big-batches for path-based model 100 14505
generating big-batches for path-based model 200 14505
generating big-batches for path-based model 300 14505
generating big-batches for path-based model 400 14505
generating big-batches for path-based model 500 14505
generating big-batches for path-based model 600 14505
generating big-batches for path-based model 700 14505
generating big-batches for path-based model 800 14505
generating big-batches for path-based model 900 14505
generating big-batches for path-based model 1000 14505
generating big-batches for path-based model 1100 14505
generating big-batches for path-based model 1200 14505
generating big-batches for path-based model 1300 14505
generating big-batches for path-based model 1400 14505
generating big-batches for path-based model 1500 14505
generating big-batches for path-based model 1600 14505
generating big-batches for path-based model 1700 14505
generating big-batches for path-based model 1800 14505
generating big-batc

generating big-batches for path-based model 500 9809
generating big-batches for path-based model 600 9809
generating big-batches for path-based model 700 9809
generating big-batches for path-based model 800 9809
generating big-batches for path-based model 900 9809
generating big-batches for path-based model 1000 9809
generating big-batches for path-based model 1100 9809
generating big-batches for path-based model 1200 9809
generating big-batches for path-based model 1300 9809
generating big-batches for path-based model 1400 9809
generating big-batches for path-based model 1500 9809
generating big-batches for path-based model 1600 9809
generating big-batches for path-based model 1700 9809
generating big-batches for path-based model 1800 9809
generating big-batches for path-based model 1900 9809
generating big-batches for path-based model 2000 9809
generating big-batches for path-based model 2100 9809
generating big-batches for path-based model 2200 9809
generating big-batches for path-b

In [22]:
###train the subgraph-based model
lower_bd = 1
upper_bd = 3
num_epoch = 10
batch_size = 32

#sample and store the paths around every entity once, separately for train and valid
Dict_train = store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity)

Dict_valid = store_subgraph_dicts(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                         relation2id, entity2id, id2relation, id2entity)


#define the training lists
train_s_list, train_t_list, train_r_list, train_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#define the validation lists
valid_s_list, valid_t_list, valid_r_list, valid_y_list = {'1': [], '2': [], '3': []}, {'1': [], '2': [], '3': []}, list(), list()

#######################################
###build the big-batches###############

#fill in the training array list
build_big_batches_subgraph(lower_bd, upper_bd, data, one_hop, s_t_r,
                      train_s_list, train_t_list, train_r_list, train_y_list, Dict_train,
                      relation2id, entity2id, id2relation, id2entity)

#fill in the validation array list
build_big_batches_subgraph(lower_bd, upper_bd, data_valid, one_hop_valid, s_t_r_valid,
                      valid_s_list, valid_t_list, valid_r_list, valid_y_list, Dict_valid,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################
def _subgraph_batch_arrays(s_list, t_list, r_list, y_list, rows=slice(None)):
    """Turn the selected rows of one big-batch into integer arrays.

    s_list and t_list map '1', '2', '3' to the source-side and target-side path
    lists; r_list and y_list hold the relations and labels. `rows` is the slice
    of examples to keep (all by default).
    Returns ([x_s_1, x_s_2, x_s_3, x_t_1, x_t_2, x_t_3, x_r], y).
    """
    x = [np.asarray(s_list[key][rows], dtype='int') for key in ('1', '2', '3')]
    x += [np.asarray(t_list[key][rows], dtype='int') for key in ('1', '2', '3')]
    x.append(np.asarray(r_list[rows], dtype='int'))
    return x, np.asarray(y_list[rows], dtype='int')

#sometimes the validation dataset is so small so sparse,
#which cannot find three paths between any pair of s and t.
#in such a case, we will divide the training big-batch into train and valid
if len(valid_y_list) >= 100:
    #generate the input arrays
    (x_train_s_1, x_train_s_2, x_train_s_3,
     x_train_t_1, x_train_t_2, x_train_t_3, x_train_r), y_train = _subgraph_batch_arrays(
        train_s_list, train_t_list, train_r_list, train_y_list)

    #generate the validation arrays
    (x_valid_s_1, x_valid_s_2, x_valid_s_3,
     x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r), y_valid = _subgraph_batch_arrays(
        valid_s_list, valid_t_list, valid_r_list, valid_y_list)

else:
    #first 80% of the training big-batch for fitting, the remaining 20% for validation
    split = int(len(train_y_list)*0.8)

    #generate the input arrays
    (x_train_s_1, x_train_s_2, x_train_s_3,
     x_train_t_1, x_train_t_2, x_train_t_3, x_train_r), y_train = _subgraph_batch_arrays(
        train_s_list, train_t_list, train_r_list, train_y_list, slice(None, split))

    #generate the validation arrays
    (x_valid_s_1, x_valid_s_2, x_valid_s_3,
     x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r), y_valid = _subgraph_batch_arrays(
        train_s_list, train_t_list, train_r_list, train_y_list, slice(split, None))

#do the training
model_2.fit([x_train_s_1, x_train_s_2, x_train_s_3, x_train_t_1, x_train_t_2, x_train_t_3, x_train_r], y_train,
          validation_data=([x_valid_s_1, x_valid_s_2, x_valid_s_3, x_valid_t_1, x_valid_t_2, x_valid_t_3, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
one_hop_add_h5 = one_hop_model_name + '.h5'
one_hop_save_dir = os.path.join(os.getcwd(), '../weight_bin')

#exist_ok replaces the separate isdir check, so there is no gap between check and creation
os.makedirs(one_hop_save_dir, exist_ok=True)
one_hop_model_path = os.path.join(one_hop_save_dir, one_hop_add_h5)
model_2.save(one_hop_model_path)
print('Save model')
del(model_2, Dict_train, Dict_valid)

generating and storing paths for the path-based model 100 14505
generating and storing paths for the path-based model 200 14505
generating and storing paths for the path-based model 300 14505
generating and storing paths for the path-based model 400 14505
generating and storing paths for the path-based model 500 14505
generating and storing paths for the path-based model 600 14505
generating and storing paths for the path-based model 700 14505
generating and storing paths for the path-based model 800 14505
generating and storing paths for the path-based model 900 14505
generating and storing paths for the path-based model 1000 14505
generating and storing paths for the path-based model 1100 14505
generating and storing paths for the path-based model 1200 14505
generating and storing paths for the path-based model 1300 14505
generating and storing paths for the path-based model 1400 14505
generating and storing paths for the path-based model 1500 14505
generating and storing paths for t

generating and storing paths for the path-based model 12700 14505
generating and storing paths for the path-based model 12800 14505
generating and storing paths for the path-based model 12900 14505
generating and storing paths for the path-based model 13000 14505
generating and storing paths for the path-based model 13100 14505
generating and storing paths for the path-based model 13200 14505
generating and storing paths for the path-based model 13300 14505
generating and storing paths for the path-based model 13400 14505
generating and storing paths for the path-based model 13500 14505
generating and storing paths for the path-based model 13600 14505
generating and storing paths for the path-based model 13700 14505
generating and storing paths for the path-based model 13800 14505
generating and storing paths for the path-based model 13900 14505
generating and storing paths for the path-based model 14000 14505
generating and storing paths for the path-based model 14100 14505
generating

generating big-batches for subgraph-based model 22000 272115 0
generating big-batches for subgraph-based model 24000 272115 0
generating big-batches for subgraph-based model 26000 272115 0
generating big-batches for subgraph-based model 28000 272115 0
generating big-batches for subgraph-based model 30000 272115 0
generating big-batches for subgraph-based model 32000 272115 0
generating big-batches for subgraph-based model 34000 272115 0
generating big-batches for subgraph-based model 36000 272115 0
generating big-batches for subgraph-based model 38000 272115 0
generating big-batches for subgraph-based model 40000 272115 0
generating big-batches for subgraph-based model 42000 272115 0
generating big-batches for subgraph-based model 44000 272115 0
generating big-batches for subgraph-based model 46000 272115 0
generating big-batches for subgraph-based model 48000 272115 0
generating big-batches for subgraph-based model 50000 272115 0
generating big-batches for subgraph-based model 52000 2

generating big-batches for subgraph-based model 6000 272115 1
generating big-batches for subgraph-based model 8000 272115 1
generating big-batches for subgraph-based model 10000 272115 1
generating big-batches for subgraph-based model 12000 272115 1
generating big-batches for subgraph-based model 14000 272115 1
generating big-batches for subgraph-based model 16000 272115 1
generating big-batches for subgraph-based model 18000 272115 1
generating big-batches for subgraph-based model 20000 272115 1
generating big-batches for subgraph-based model 22000 272115 1
generating big-batches for subgraph-based model 24000 272115 1
generating big-batches for subgraph-based model 26000 272115 1
generating big-batches for subgraph-based model 28000 272115 1
generating big-batches for subgraph-based model 30000 272115 1
generating big-batches for subgraph-based model 32000 272115 1
generating big-batches for subgraph-based model 34000 272115 1
generating big-batches for subgraph-based model 36000 272

generating big-batches for subgraph-based model 266000 272115 1
generating big-batches for subgraph-based model 268000 272115 1
generating big-batches for subgraph-based model 270000 272115 1
generating big-batches for subgraph-based model 272000 272115 1
generating big-batches for subgraph-based model 0 272115 2
generating big-batches for subgraph-based model 2000 272115 2
generating big-batches for subgraph-based model 4000 272115 2
generating big-batches for subgraph-based model 6000 272115 2
generating big-batches for subgraph-based model 8000 272115 2
generating big-batches for subgraph-based model 10000 272115 2
generating big-batches for subgraph-based model 12000 272115 2
generating big-batches for subgraph-based model 14000 272115 2
generating big-batches for subgraph-based model 16000 272115 2
generating big-batches for subgraph-based model 18000 272115 2
generating big-batches for subgraph-based model 20000 272115 2
generating big-batches for subgraph-based model 22000 27211

generating big-batches for subgraph-based model 250000 272115 2
generating big-batches for subgraph-based model 252000 272115 2
generating big-batches for subgraph-based model 254000 272115 2
generating big-batches for subgraph-based model 256000 272115 2
generating big-batches for subgraph-based model 258000 272115 2
generating big-batches for subgraph-based model 260000 272115 2
generating big-batches for subgraph-based model 262000 272115 2
generating big-batches for subgraph-based model 264000 272115 2
generating big-batches for subgraph-based model 266000 272115 2
generating big-batches for subgraph-based model 268000 272115 2
generating big-batches for subgraph-based model 270000 272115 2
generating big-batches for subgraph-based model 272000 272115 2
generating big-batches for subgraph-based model 0 272115 3
generating big-batches for subgraph-based model 2000 272115 3
generating big-batches for subgraph-based model 4000 272115 3
generating big-batches for subgraph-based model 6

generating big-batches for subgraph-based model 234000 272115 3
generating big-batches for subgraph-based model 236000 272115 3
generating big-batches for subgraph-based model 238000 272115 3
generating big-batches for subgraph-based model 240000 272115 3
generating big-batches for subgraph-based model 242000 272115 3
generating big-batches for subgraph-based model 244000 272115 3
generating big-batches for subgraph-based model 246000 272115 3
generating big-batches for subgraph-based model 248000 272115 3
generating big-batches for subgraph-based model 250000 272115 3
generating big-batches for subgraph-based model 252000 272115 3
generating big-batches for subgraph-based model 254000 272115 3
generating big-batches for subgraph-based model 256000 272115 3
generating big-batches for subgraph-based model 258000 272115 3
generating big-batches for subgraph-based model 260000 272115 3
generating big-batches for subgraph-based model 262000 272115 3
generating big-batches for subgraph-base

generating big-batches for subgraph-based model 218000 272115 4
generating big-batches for subgraph-based model 220000 272115 4
generating big-batches for subgraph-based model 222000 272115 4
generating big-batches for subgraph-based model 224000 272115 4
generating big-batches for subgraph-based model 226000 272115 4
generating big-batches for subgraph-based model 228000 272115 4
generating big-batches for subgraph-based model 230000 272115 4
generating big-batches for subgraph-based model 232000 272115 4
generating big-batches for subgraph-based model 234000 272115 4
generating big-batches for subgraph-based model 236000 272115 4
generating big-batches for subgraph-based model 238000 272115 4
generating big-batches for subgraph-based model 240000 272115 4
generating big-batches for subgraph-based model 242000 272115 4
generating big-batches for subgraph-based model 244000 272115 4
generating big-batches for subgraph-based model 246000 272115 4
generating big-batches for subgraph-base

### Result on the testset for transductive link prediction

We evaluate the trained models on the test set for transductive link prediction.

In [1]:
#dataset name: used in the saved file names and in the path of the test file
data_name = 'FB15K237'
#tag of the training run: part of the saved model / ID file names
model_id = 'main_transductive'

In [2]:
#define the file-name stems of the saved models and of the pickled ID dictionaries
model_name = f'Model_{model_id}_{data_name}'
one_hop_model_name = f'One_hop_model_{model_id}_{data_name}'
ids_name = f'IDs_{model_id}_{data_name}'

In [3]:
ids_name

'IDs_main_transductive_FB15K237'

In [4]:
one_hop_model_name

'One_hop_model_main_transductive_FB15K237'

In [5]:
model_name

'Model_main_transductive_FB15K237'

In [6]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
import pickle

from collections import defaultdict
from copy import deepcopy
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [7]:
class LoadKG:
    """Reads a whitespace-separated triple file and fills ID-based graph containers."""

    def __init__(self):

        self.x = 'Hello'

    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        """Load the triples stored in `data_path` into the given containers.

        Every non-blank line is split on whitespace; fields 0, 1 and 2 are read as
        source entity, relation and target entity. Blank lines are skipped and
        duplicate lines are used once. New ids are handed out in the order the
        names first appear in the file, so repeated runs give the same ids.

        All containers are updated in place:
            one_hop: dict, entity id -> collection of (relation id, neighbour id).
                Values are sets while loading and are converted to lists before
                returning, so entries already present must support `.add`
                (i.e. pass a dict that has not been through this method yet).
            data: set-like with `.add`; receives the (s, r, t) id triples.
            s_t_r: mapping (s, t) -> set of relation ids. Missing keys must be
                created on lookup, e.g. `defaultdict(set)`.
            entity2id / id2entity: extended with unseen entities; new ids
                continue from `len(entity2id)`.
            relation2id / id2relation: extended with unseen relations; each new
                relation takes two consecutive ids, the second one for
                '_inverse_' + relation. New ids continue from `len(relation2id)`.

        For every triple (s, r, t) the reverse edge (t, inverse of r, s) is stored
        as well in `s_t_r` and `one_hop`.

        Returns None.
        """
        ####load the train, valid and test set##########
        #a dict keeps the first-seen order and drops duplicates at the same time
        triples = dict()

        with open(data_path, 'r') as f:

            for line in f:

                fields = line.split()

                #a blank line has no fields and would break the key[1] lookups below
                if not fields:
                    continue

                triples[tuple(fields)] = None

        ####relation dict#################
        index = len(relation2id)

        for key in triples:

            relation = key[1]

            if relation not in relation2id:

                relation2id[relation] = index
                id2relation[index] = relation
                index += 1

                #the inverse relation
                iv_r = '_inverse_' + relation

                relation2id[iv_r] = index
                id2relation[index] = iv_r
                index += 1

        #get the id of the inverse relation, by above definition, initial relation has
        #always even id, while inverse relation has always odd id.
        #NOTE(review): this only holds when relation2id was empty, or filled by this
        #method only, before the call - confirm for externally built dictionaries
        def inverse_r(r):

            if r % 2 == 0: #initial relation
                return r + 1

            return r - 1 #inverse relation

        ####entity dict###################
        index = len(entity2id)

        for key in triples:

            #source first, then target, so ids follow the reading order of the file
            for entity in (key[0], key[2]):

                if entity not in entity2id:

                    entity2id[entity] = index
                    id2entity[index] = entity
                    index += 1

        #create the set of triples using id instead of string
        for ele in triples:

            s = entity2id[ele[0]]
            r = relation2id[ele[1]]
            t = entity2id[ele[2]]

            data.add((s, r, t))

            #forward edge
            s_t_r[(s, t)].add(r)

            if s not in one_hop:
                one_hop[s] = set()

            one_hop[s].add((r, t))

            #reverse edge, labelled with the inverse relation
            if t not in one_hop:
                one_hop[t] = set()

            r_inv = inverse_r(r)

            s_t_r[(t, s)].add(r_inv)

            one_hop[t].add((r_inv, s))

        #change each set in one_hop to list
        for e in one_hop:

            one_hop[e] = list(one_hop[e])

In [8]:
class ObtainPathsByDynamicProgramming:
    """Randomised, budget-limited recursive search for relation paths starting at one entity."""

    def __init__(self, amount_bd=50, size_bd=50, threshold=20000):

        self.amount_bd = amount_bd #how many Tuples we choose in one_hop[node] for next recursion

        self.size_bd = size_bd #size bound limit the number of paths to a target entity t

        #number of times paths with specific length been performed for recursion
        self.threshold = threshold

    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):
        """Find paths from entity `s` to other entities, using recursion.

        One may refer to LeetCode Problem 797 for details:
            https://leetcode.com/problems/all-paths-from-source-to-target/

        Args:
            mode: which entities may receive paths.
                'direct_neighbour' - only the neighbours listed in one_hop[s];
                'target_specified' - only `t_input`;
                'any_target' - every key of `one_hop`.
            s: start entity; must be a key of `one_hop`.
            t_input: target entity, only read in 'target_specified' mode.
            lower_bd, upper_bd: inclusive bounds on the length of the stored
                paths; both must be plain ints >= 1 with lower_bd <= upper_bd.
            one_hop: dict, entity -> indexable sequence of (relation, neighbour).

        Returns:
            (res, count_dict). `res` is a defaultdict(set) mapping each reached
            entity to a set of paths, each path a tuple of relations; at most
            `size_bd` paths are kept per entity. `count_dict` is a
            defaultdict(int) mapping a path length to the number of neighbours
            examined from paths of that length.

        Raises:
            TypeError: invalid bounds (including lower_bd > upper_bd).
            ValueError: `s` is not in `one_hop`, or `mode` is unknown.

        Neighbours are visited in random order, so results can differ between
        calls unless the `random` module is seeded.
        """
        if type(lower_bd) != type(1) or lower_bd < 1:

            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")

        if type(upper_bd) != type(1) or upper_bd < 1:

            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")

        if lower_bd > upper_bd:

            raise TypeError("!!! lower bound must not exced upper bound !!!")

        if s not in one_hop:

            raise ValueError('!!! entity not in one_hop. Please work on existing entities')

        #here is the result dict. Its key is each entity t sharing paths from s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)

        #qualified_t contains the types of t we want to consider,
        #that is, what t will be added to the result set.
        qualified_t = set()

        #under this mode, we will only consider the direct neighbour of s
        if mode == 'direct_neighbour':

            for Tuple in one_hop[s]:

                qualified_t.add(Tuple[1])

        #under this mode, we will only consider one specified entity t
        elif mode == 'target_specified':

            qualified_t.add(t_input)

        #under this mode, we will consider any entity
        elif mode == 'any_target':

            qualified_t.update(one_hop)

        else:

            raise ValueError('not a valid mode')

        #how many neighbours have been examined from paths of each length
        count_dict = defaultdict(int)

        #We use recursion to find the paths
        #On current node with the path [r1, ..., rk] and on-path entities {s, e1, ..., ek-1, node}
        #from s to this node, we will further find the direct neighbor t' of this node.
        #If t' is not an on-path entity (not among s, e1,...ek-1, node), we recursively proceed to t'
        def helper(node, path, on_path_en):

            depth = len(path)

            #when the current path is within lower_bd and upper_bd,
            #and the node is among the qualified t, and it has not been fill of paths w.r.t size_limit,
            #we will add this path to the node
            if (depth >= lower_bd) and (depth <= upper_bd) and (
                node in qualified_t) and (len(res[node]) < self.size_bd):

                res[node].add(tuple(path))

            #won't start new recursions if the current path length already reaches upper limit
            #or the number of recursions performed on this length has reached the limit
            if (depth < upper_bd) and (count_dict[depth] <= self.threshold):

                #temp list is the id list for us to go-over one_hop[node]
                temp_list = list(range(len(one_hop[node])))
                random.shuffle(temp_list) #so we random-shuffle the list

                #only follow the first amount_bd neighbours if there are too many (r,t)
                for i in temp_list[:self.amount_bd]:

                    #obtain tuple of (r,t)
                    Tuple = one_hop[node][i]
                    r, t = Tuple[0], Tuple[1]

                    #add to count_dict even if eventually this step not proceed
                    count_dict[depth] += 1

                    #if t not on the path and we not exceed the computation threshold,
                    #then finally proceed to next recursion
                    if (t not in on_path_en) and (count_dict[depth] <= self.threshold):

                        helper(t, path + [r], on_path_en.union({t}))

        helper(s, [], {s})

        return(res, count_dict)

In [9]:
#load the classes
#Class_1 reads triple files into the ID-based graph containers
Class_1 = LoadKG()
#Class_2 does the path finding; created with the default bounds
#(amount_bd=50, size_bd=50, threshold=20000)
Class_2 = ObtainPathsByDynamicProgramming()

In [10]:
#load ids and relation/entity dicts
#NOTE: pickle.load can run arbitrary code, so only open pickle files you created yourself
with open('../weight_bin/' + ids_name + '.pickle', 'rb') as handle:
    Dict = pickle.load(handle)

#restore the training graph
one_hop, data, s_t_r = Dict['one_hop'], Dict['data'], Dict['s_t_r']

#restore the validation graph
one_hop_valid, data_valid, s_t_r_valid = (
    Dict['one_hop_valid'], Dict['data_valid'], Dict['s_t_r_valid'])

#restore the shared dictionaries
entity2id, id2entity = Dict['entity2id'], Dict['id2entity']
relation2id, id2relation = Dict['relation2id'], Dict['id2relation']

#we want to keep the initial entity/relation dicts before adding new entities
entity2id_ini, id2entity_ini = deepcopy(entity2id), deepcopy(id2entity)
relation2id_ini, id2relation_ini = deepcopy(relation2id), deepcopy(id2relation)

#number of relation ids stored in the loaded dictionary
num_r = len(id2relation)
num_r

474

In [11]:
#location of the test triples of the chosen dataset
test_path = f'../data/{data_name}/test.txt'

In [12]:
#define the dictionaries and sets for load KG
#fresh containers, so the test graph is kept separate from the training/validation graphs
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

#fill in the sets and dicts
#the shared entity/relation dictionaries are passed in and extended in place,
#so names that only occur in the test file receive new ids
Class_1.load_train_data(test_path, one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)

In [13]:
#load the model
#reads the .h5 file stored under ../weight_bin/ with the stem model_name
model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

2023-03-16 22:28:34.214914: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
#load the one-hop neighbor model
#reads the .h5 file stored under ../weight_bin/ with the stem one_hop_model_name
model_2 = keras.models.load_model('../weight_bin/' + one_hop_model_name + '.h5')

In [15]:
#the function to do path-based entity scoring
def path_based_entity_scoring(s, r, lower_bd, upper_bd, one_hop, model):
    """Score candidate target entities t for the query (s, r, ?) from the paths s -> t.

    Paths from `s` are collected with the module-level `Class_2.obtain_paths`
    in 'any_target' mode. For every entity t reached by at least three paths,
    three of them are sampled at random 10 times; each sample, together with
    `r`, is one input row for `model`. The score of t is the mean of the model
    outputs over its 10 rows.

    Paths shorter than `upper_bd` are padded at the end with the module-level
    `num_r`, which serves as the space-holder id.

    Args:
        s: source entity id.
        r: relation id of the query.
        lower_bd, upper_bd: path-length bounds handed to the path finder.
        one_hop: dict, entity id -> list of (relation id, neighbour id).
        model: object with `predict([x_1, x_2, x_3, x_r], verbose=0)` that
            returns exactly one value per input row.

    Returns:
        defaultdict(float) mapping entity t -> mean score. It is empty when `s`
        is not in `one_hop` or when no entity is reached by three or more paths.
    """
    #key is each entity t, value is the score of (s,r,t)
    score_dict = defaultdict(float)

    #if s not in one_hop then directly return empty score dict
    if s not in one_hop:
        return(score_dict)

    #hold all connected entities to s, as well as the paths from s to them
    result, _counts = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

    #input holder lists for big-batch
    #unlike the relation prediction, we need to mark for each position:
    #which entity it stands for. So another list_mark is needed.
    list_m, list_1, list_2, list_3, list_r = list(), list(), list(), list(), list()

    def pad(path):
        #add the space holder id at the end of the shorter path
        return list(path) + [num_r]*abs(len(path)-upper_bd)

    #for each entity t, if there are more than three paths from s to it,
    #we do 10 times path selection
    for t in result:

        if len(result[t]) >= 3: #won't proceed if less than three existing paths

            #build the list once instead of once per draw
            candidates = list(result[t])

            for _draw in range(10):

                path_1, path_2, path_3 = random.sample(candidates, 3)

                list_m.append(t)
                list_1.append(pad(path_1))
                list_2.append(pad(path_2))
                list_3.append(pad(path_3))
                list_r.append([r])

    #nothing to score: do not call the model on empty arrays
    if not list_m:
        return(score_dict)

    #change to arrays
    input_1 = np.array(list_1)
    input_2 = np.array(list_2)
    input_3 = np.array(list_3)
    input_r = np.array(list_r)

    pred = model.predict([input_1, input_2, input_3, input_r], verbose = 0)

    #one score per row; the reshape fails loudly if the model returns anything else
    #and avoids float() on a one-element array, which newer NumPy deprecates
    scores = np.asarray(pred).reshape(len(list_m))

    #since we want to take the average score, we need to count how many path-combo
    #we have added to the same entity t
    count_dict = defaultdict(int)

    #add to score dict using list_m as indication on which entity t it is
    for t, score in zip(list_m, scores):

        score_dict[t] += float(score)
        count_dict[t] += 1

    for t in score_dict:

        score_dict[t] = score_dict[t]/float(count_dict[t])

    return(score_dict)

In [16]:
#Again, it is too slow to run the path-finding algorithm again and again on the complete FB15K-237
#Instead, we will find the subgraph for each entity once.
#then in the subgraph based training, the subgraphs are stored and used for multiple times
def store_subgraph_dicts(lower_bd, upper_bd, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity):
    """Build the per-entity path cache used by the subgraph-based model.

    Every key of ``one_hop`` is used once as a start entity for
    ``Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)``.
    The paths returned for all keys of the first return value are pooled,
    de-duplicated, and at most 100 of them are kept (uniform random sample).

    Parameters
    ----------
    lower_bd, upper_bd :
        Forwarded unchanged to ``Class_2.obtain_paths`` (presumably the
        minimum / maximum path length -- confirm against that method).
    one_hop :
        Iterated for its keys (the start entities) and forwarded to
        ``Class_2.obtain_paths``.
    id2relation :
        Only validated: its keys must be exactly ``0 .. len(id2relation) - 1``.
    data, s_t_r, relation2id, entity2id, id2entity :
        Not read by this function; kept so existing calls keep working.

    Returns
    -------
    dict
        ``{start entity: list of sampled paths}`` with at most 100 paths each.

    Raises
    ------
    ValueError
        If some integer in ``range(len(id2relation))`` is not a key of
        ``id2relation``.
    """
    #relation ids are expected to be contiguous, starting from 0
    for rel_id in range(len(id2relation)):
        if rel_id not in id2relation:
            raise ValueError('error when generaing id2relation')

    #not every entity of entity2id has to be in one_hop, so start only from its keys.
    #the set is filled element by element on purpose, then visited in random order
    start_ids = list({ent for ent in one_hop})
    random.shuffle(start_ids)
    total = len(start_ids)

    #start entity -> sampled paths
    subgraphs = dict()

    for done, s in enumerate(start_ids, start=1):

        found, found_lengths = Class_2.obtain_paths('any_target', s, 'any', lower_bd, upper_bd, one_hop)

        #pool the paths of every key in the result, dropping duplicates
        unique_paths = list({p for tgt in found for p in found[tgt]})

        #release the raw path-finding result before handling the next entity
        del found, found_lengths

        keep = min(len(unique_paths), 100)
        subgraphs[s] = deepcopy(random.sample(unique_paths, keep))

        if done % 100 == 0:
            print('generating and storing paths for the path-based model', done, total)

    return subgraphs

In [17]:
#subgraph based source entity scoring
def subgraph_source_scoring(r, t, lower_bd, upper_bd, one_hop, Dict, id2entity, model_2,
                            pad_id=None):
    """Score every candidate source entity ``s`` for the query ``(?, r, t)``.

    A candidate is an entity of ``id2entity`` that is also a key of ``one_hop``
    and has at least three stored paths in ``Dict``. For each candidate, three
    of its paths and three paths of ``t`` are drawn at random, ten times. All
    draws are sent to ``model_2`` in one batched ``predict`` call and the ten
    predictions of each candidate are averaged.

    Parameters
    ----------
    r :
        Relation id of the query; fed to the model as a one-element row.
    t :
        The known target entity.
    lower_bd :
        Not read by this function; kept so existing calls keep working.
    upper_bd :
        Length every path is padded to.
    one_hop :
        Only used for membership tests (``t in one_hop``, ``s in one_hop``).
    Dict :
        Mapping ``entity -> collection of paths``, e.g. the return value of
        ``store_subgraph_dicts``.
    id2entity :
        Iterated to enumerate the candidate source entities.
    model_2 :
        Object with ``predict(list_of_7_arrays, verbose=0)`` that returns one
        prediction per input row.
    pad_id :
        Id appended to paths shorter than ``upper_bd``. When ``None`` (default)
        the notebook-level global ``num_r`` is used, as before.

    Returns
    -------
    defaultdict(float)
        ``{candidate s: mean prediction}``. Empty when ``t`` is not in
        ``one_hop``, when ``t`` has fewer than three stored paths, or when no
        candidate qualifies.

    Raises
    ------
    ValueError
        If ``model_2.predict`` does not return exactly one value per row.
    """
    #final output: the score dict
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)

    #if t not in one_hop then directly return empty score dict
    if t not in one_hop:
        return score_dict

    path_t = list(Dict[t])

    #every draw needs three paths of t. This check does not depend on the candidate,
    #so do it once here instead of building an empty batch for the model
    if len(path_t) < 3:
        return score_dict

    #only fall back to the notebook-level num_r when the caller gave no padding id
    if pad_id is None:
        pad_id = num_r

    def _pad(path):
        #add the space holder id at the end of the shorter path.
        #NOTE(review): because of abs(), a path longer than upper_bd is extended
        #even further instead of being cut -- confirm such paths cannot occur
        return list(path) + [pad_id]*abs(len(path)-upper_bd)

    #lists holding the input to the network
    #again, we need another list to store which s each row is for
    list_s = list()
    s_cols = ([], [], [])
    t_cols = ([], [], [])
    list_r = list()

    for s in id2entity:

        #proceed only if s exists in the training dataset
        if s not in one_hop:
            continue

        path_s = list(Dict[s])

        if len(path_s) < 3:
            continue

        for _ in range(10):

            #randomly obtain three paths (source first, then target, as before,
            #so a seeded run draws the same samples)
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)

            for col, path in zip(s_cols, temp_s):
                col.append(_pad(path))
            for col, path in zip(t_cols, temp_t):
                col.append(_pad(path))

            list_r.append([r])
            list_s.append(s)

    #no candidate qualified: nothing to predict
    if not list_s:
        return score_dict

    #change to arrays, in the order s_1, s_2, s_3, t_1, t_2, t_3, r
    inputs = [np.array(col) for col in s_cols + t_cols]
    inputs.append(np.array(list_r))

    #make prediction
    pred = model_2.predict(inputs, verbose = 0)

    #flatten to one scalar per row; float() on a 1-element array is deprecated in NumPy
    flat_pred = np.ravel(pred)
    if flat_pred.shape[0] != len(list_s):
        raise ValueError('model_2.predict must return exactly one score per input row')

    #add to dictionary
    for s, score in zip(list_s, flat_pred):
        score_dict[s] += float(score)
        count_dict[s] += 1

    #average the score
    for s in score_dict:
        score_dict[s] = score_dict[s]/float(count_dict[s])

    print(len(score_dict))

    return score_dict

In [18]:
#subgraph based target entity scoring
def subgraph_target_scoring(s, r, lower_bd, upper_bd, one_hop, Dict, id2entity, model_2,
                            pad_id=None):
    """Score every candidate target entity ``t`` for the query ``(s, r, ?)``.

    A candidate is an entity of ``id2entity`` that is also a key of ``one_hop``
    and has at least three stored paths in ``Dict``. For each candidate, three
    paths of ``s`` and three of the candidate's paths are drawn at random, ten
    times. All draws are sent to ``model_2`` in one batched ``predict`` call
    and the ten predictions of each candidate are averaged.

    Parameters
    ----------
    s :
        The known source entity.
    r :
        Relation id of the query; fed to the model as a one-element row.
    lower_bd :
        Not read by this function; kept so existing calls keep working.
    upper_bd :
        Length every path is padded to.
    one_hop :
        Only used for membership tests (``s in one_hop``, ``t in one_hop``).
    Dict :
        Mapping ``entity -> collection of paths``, e.g. the return value of
        ``store_subgraph_dicts``.
    id2entity :
        Iterated to enumerate the candidate target entities.
    model_2 :
        Object with ``predict(list_of_7_arrays, verbose=0)`` that returns one
        prediction per input row.
    pad_id :
        Id appended to paths shorter than ``upper_bd``. When ``None`` (default)
        the notebook-level global ``num_r`` is used, as before.

    Returns
    -------
    defaultdict(float)
        ``{candidate t: mean prediction}``. Empty when ``s`` is not in
        ``one_hop``, when ``s`` has fewer than three stored paths, or when no
        candidate qualifies.

    Raises
    ------
    ValueError
        If ``model_2.predict`` does not return exactly one value per row.
    """
    #final output: the score dict
    score_dict = defaultdict(float)
    count_dict = defaultdict(int)

    #if s not in one_hop then directly return empty score dict
    if s not in one_hop:
        return score_dict

    path_s = list(Dict[s])

    #every draw needs three paths of s. This check does not depend on the candidate,
    #so do it once here instead of building an empty batch for the model
    if len(path_s) < 3:
        return score_dict

    #only fall back to the notebook-level num_r when the caller gave no padding id
    if pad_id is None:
        pad_id = num_r

    def _pad(path):
        #add the space holder id at the end of the shorter path.
        #NOTE(review): because of abs(), a path longer than upper_bd is extended
        #even further instead of being cut -- confirm such paths cannot occur
        return list(path) + [pad_id]*abs(len(path)-upper_bd)

    #lists holding the input to the network
    #again, we need another list to store which t each row is for
    list_t = list()
    s_cols = ([], [], [])
    t_cols = ([], [], [])
    list_r = list()

    for t in id2entity:

        #proceed only if t exists in the training dataset
        if t not in one_hop:
            continue

        path_t = list(Dict[t])

        if len(path_t) < 3:
            continue

        for _ in range(10):

            #randomly obtain three paths (source first, then target, as before,
            #so a seeded run draws the same samples)
            temp_s = random.sample(path_s, 3)
            temp_t = random.sample(path_t, 3)

            for col, path in zip(s_cols, temp_s):
                col.append(_pad(path))
            for col, path in zip(t_cols, temp_t):
                col.append(_pad(path))

            list_r.append([r])
            list_t.append(t)

    #no candidate qualified: nothing to predict
    if not list_t:
        return score_dict

    #change to arrays, in the order s_1, s_2, s_3, t_1, t_2, t_3, r
    inputs = [np.array(col) for col in s_cols + t_cols]
    inputs.append(np.array(list_r))

    #make prediction
    pred = model_2.predict(inputs, verbose = 0)

    #flatten to one scalar per row; float() on a 1-element array is deprecated in NumPy
    flat_pred = np.ravel(pred)
    if flat_pred.shape[0] != len(list_t):
        raise ValueError('model_2.predict must return exactly one score per input row')

    #add to dictionary
    for t, score in zip(list_t, flat_pred):
        score_dict[t] += float(score)
        count_dict[t] += 1

    #average the score
    for t in score_dict:
        score_dict[t] = score_dict[t]/float(count_dict[t])

    print(len(score_dict))

    return score_dict

In [19]:
#obtain the subgraph Dict
#Dict_train = store_subgraph_dicts(1, 3, data, one_hop, s_t_r,
#                         relation2id, entity2id, id2relation, id2entity)

In [20]:
############################################
####source entity prediction################
#For every triple (s, r, t) in data_test: score candidate entities with the
#path-based model, rank all entities of id2entity by that score, and accumulate
#Hits@1 / Hits@3 / Hits@10 and MRR for the position of the true source s.

#we evaluate on all the triples in data_test
#NOTE(review): this comment used to say "inductive test set", while model_id at the
#top of the notebook is 'main_transductive' -- confirm which split data_test holds.
selected = list(data_test)

###running totals of the metrics########
#number of test triples whose filtered rank is 0, < 3 and < 10,
#and the running sum of reciprocal filtered ranks
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #id of the inverse relation: LoadKG.load_train_data gives '_inverse_' + relation
    #the id directly after the relation's own id
    r_true_inv = r_true + 1
    
    #run the path-based scoring, starting from the known target with the inverse relation
    #(1 and 10 are presumably lower_bd / upper_bd -- the def line is outside this cell)
    score_dict = path_based_entity_scoring(t_true, r_true_inv, 1, 10, one_hop, model)
    
    #[... [score, entity], ...] one row per entity in id2entity
    temp_list = list()
    
    for e in id2entity:

        if e in score_dict:

            temp_list.append([score_dict[e], e])

        else:

            #entities the model did not score get 0.0
            temp_list.append([0.0, e])

    #highest score first; sorted() is stable, so ties keep the id2entity order
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    #p: 0-based position of the true source in the ranking
    #exist_tri: entities ranked above it that form a known triple with (r_true, t_true)
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != s_true:
        
        #moreover, we want to remove existing triples (filtered setting):
        #a higher-ranked entity that is a correct answer in test, valid or train is not counted
        if ((sorted_list[p][1], r_true, t_true) in data_test) or (
            (sorted_list[p][1], r_true, t_true) in data_valid) or (
            (sorted_list[p][1], r_true, t_true) in data):
            
            exist_tri += 1
            
        p += 1
    
    #p - exist_tri is the 0-based filtered rank of the true source
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    #reciprocal of the 1-based filtered rank
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    #running averages over the i+1 triples evaluated so far
    #NOTE(review): if s_true never appears in sorted_list, p == len(sorted_list) here
    #and sorted_list[p] raises IndexError
    print('checkcorrect', s_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

checkcorrect 3241 3241 real score 0.7750978827476501 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.0007434944237918215 cur_rank 1344 abs_cur_rank 1454 total_num 0 20466
checkcorrect 2382 2382 real score 0.0 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.0006395565316602385 cur_rank 1866 abs_cur_rank 3699 total_num 1 20466
checkcorrect 6498 6498 real score 0.9808678388595581 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.004171689373166751 cur_rank 88 abs_cur_rank 90 total_num 2 20466
checkcorrect 2298 2298 real score 0.6329669773578643 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.003551778874206704 cur_rank 590 abs_cur_rank 2861 total_num 3 20466
checkcorrect 13682 13682 real score 0.9968073487281799 Hits@1 0.2 Hits@3 0.2 Hits@10 0.2 MRR 0.20284142309936537 cur_rank 0 abs_cur_rank 0 total_num 4 20466
checkcorrect 10895 10895 real score 0.8823171377182006 Hits@1 0.16666666666666666 Hits@3 0.16666666666666666 Hits@10 0.16666666666666666 MRR 0.1695975823125342 cur_rank 295 abs_cur_rank 302 total_num 5 20466
check

checkcorrect 7590 7590 real score 0.8413313090801239 Hits@1 0.045454545454545456 Hits@3 0.09090909090909091 Hits@10 0.13636363636363635 MRR 0.08542415440889954 cur_rank 265 abs_cur_rank 340 total_num 43 20466
checkcorrect 2733 2733 real score 0.9828138589859009 Hits@1 0.044444444444444446 Hits@3 0.08888888888888889 Hits@10 0.15555555555555556 MRR 0.08599497566894869 cur_rank 8 abs_cur_rank 8 total_num 44 20466
checkcorrect 2929 2929 real score 0.9123480498790741 Hits@1 0.043478260869565216 Hits@3 0.08695652173913043 Hits@10 0.15217391304347827 MRR 0.08429404006711276 cur_rank 128 abs_cur_rank 147 total_num 45 20466
checkcorrect 4841 4841 real score 0.5510504096746445 Hits@1 0.0425531914893617 Hits@3 0.0851063829787234 Hits@10 0.14893617021276595 MRR 0.08252335434996569 cur_rank 932 abs_cur_rank 936 total_num 46 20466
checkcorrect 1360 1360 real score 0.6874051451683044 Hits@1 0.041666666666666664 Hits@3 0.08333333333333333 Hits@10 0.14583333333333334 MRR 0.08086881759396873 cur_rank 32

checkcorrect 114 114 real score 0.7798423767089844 Hits@1 0.023529411764705882 Hits@3 0.047058823529411764 Hits@10 0.09411764705882353 MRR 0.05395905163817013 cur_rank 572 abs_cur_rank 581 total_num 84 20466
checkcorrect 6696 6696 real score 0.6301090955734253 Hits@1 0.023255813953488372 Hits@3 0.046511627906976744 Hits@10 0.09302325581395349 MRR 0.05335153160478586 cur_rank 583 abs_cur_rank 606 total_num 85 20466
checkcorrect 5676 5676 real score 0.8187301337718964 Hits@1 0.022988505747126436 Hits@3 0.04597701149425287 Hits@10 0.09195402298850575 MRR 0.05299372345096329 cur_rank 44 abs_cur_rank 151 total_num 86 20466
checkcorrect 1132 1132 real score 0.9995072722434998 Hits@1 0.022727272727272728 Hits@3 0.056818181818181816 Hits@10 0.10227272727272728 MRR 0.05617940083599022 cur_rank 2 abs_cur_rank 2 total_num 87 20466
checkcorrect 5713 5713 real score 0.8230918645858765 Hits@1 0.02247191011235955 Hits@3 0.056179775280898875 Hits@10 0.10112359550561797 MRR 0.05556206030910426 cur_rank

checkcorrect 14363 14363 real score 0.0 Hits@1 0.016 Hits@3 0.064 Hits@10 0.12 MRR 0.059180407835358204 cur_rank 7139 abs_cur_rank 7145 total_num 124 20466
checkcorrect 3131 3131 real score 0.9876857817173004 Hits@1 0.015873015873015872 Hits@3 0.06349206349206349 Hits@10 0.11904761904761904 MRR 0.0589843947463529 cur_rank 28 abs_cur_rank 29 total_num 125 20466
checkcorrect 12537 12537 real score 0.0 Hits@1 0.015748031496062992 Hits@3 0.06299212598425197 Hits@10 0.11811023622047244 MRR 0.05852098375748654 cur_rank 7621 abs_cur_rank 7624 total_num 126 20466
checkcorrect 10915 10915 real score 0.7099664628505706 Hits@1 0.015625 Hits@3 0.0625 Hits@10 0.1171875 MRR 0.05807412256659017 cur_rank 755 abs_cur_rank 1195 total_num 127 20466
checkcorrect 7964 7964 real score 0.995951509475708 Hits@1 0.015503875968992248 Hits@3 0.06201550387596899 Hits@10 0.11627906976744186 MRR 0.05817764542598537 cur_rank 13 abs_cur_rank 13 total_num 128 20466
checkcorrect 10356 10356 real score 0.381872104108333

checkcorrect 8540 8540 real score 0.9476346433162689 Hits@1 0.018072289156626505 Hits@3 0.05421686746987952 Hits@10 0.1144578313253012 MRR 0.05662685768512574 cur_rank 8 abs_cur_rank 8 total_num 165 20466
checkcorrect 4445 4445 real score 0.7605818331241607 Hits@1 0.017964071856287425 Hits@3 0.05389221556886228 Hits@10 0.11377245508982035 MRR 0.056297432808123826 cur_rank 619 abs_cur_rank 700 total_num 166 20466
checkcorrect 9626 9626 real score 0.6732264280319213 Hits@1 0.017857142857142856 Hits@3 0.05357142857142857 Hits@10 0.1130952380952381 MRR 0.055976990078237324 cur_rank 405 abs_cur_rank 642 total_num 167 20466
checkcorrect 2280 2280 real score 0.7675442099571228 Hits@1 0.01775147928994083 Hits@3 0.05325443786982249 Hits@10 0.11242603550295859 MRR 0.05565326485364977 cur_rank 788 abs_cur_rank 843 total_num 168 20466
checkcorrect 10468 10468 real score 0.6004910618066788 Hits@1 0.01764705882352941 Hits@3 0.052941176470588235 Hits@10 0.11176470588235295 MRR 0.05534421779449599 cur

checkcorrect 4651 4651 real score 0.9917716562747956 Hits@1 0.014563106796116505 Hits@3 0.04854368932038835 Hits@10 0.11165048543689321 MRR 0.053973183347806986 cur_rank 16 abs_cur_rank 16 total_num 205 20466
checkcorrect 1033 1033 real score 0.4353216737508774 Hits@1 0.014492753623188406 Hits@3 0.04830917874396135 Hits@10 0.1111111111111111 MRR 0.05372031127605484 cur_rank 613 abs_cur_rank 717 total_num 206 20466
checkcorrect 6774 6774 real score 0.9959148824214935 Hits@1 0.014423076923076924 Hits@3 0.04807692307692308 Hits@10 0.11057692307692307 MRR 0.053762521317996886 cur_rank 15 abs_cur_rank 15 total_num 207 20466
checkcorrect 6194 6194 real score 0.9918090045452118 Hits@1 0.014354066985645933 Hits@3 0.04784688995215311 Hits@10 0.11004784688995216 MRR 0.05380432743609259 cur_rank 15 abs_cur_rank 15 total_num 208 20466
checkcorrect 3319 3319 real score 0.0 Hits@1 0.014285714285714285 Hits@3 0.047619047619047616 Hits@10 0.10952380952380952 MRR 0.05354938721101979 cur_rank 3746 abs_c

checkcorrect 1540 1540 real score 0.9568442523479461 Hits@1 0.016260162601626018 Hits@3 0.04878048780487805 Hits@10 0.10975609756097561 MRR 0.05369859226397286 cur_rank 78 abs_cur_rank 94 total_num 245 20466
checkcorrect 4193 4193 real score 0.8628424286842347 Hits@1 0.016194331983805668 Hits@3 0.05263157894736842 Hits@10 0.11336032388663968 MRR 0.055505480554402124 cur_rank 1 abs_cur_rank 6 total_num 246 20466
checkcorrect 3350 3350 real score 0.9656907558441162 Hits@1 0.016129032258064516 Hits@3 0.05241935483870968 Hits@10 0.11290322580645161 MRR 0.05555048533711287 cur_rank 14 abs_cur_rank 34 total_num 247 20466
checkcorrect 463 463 real score 0.965251910686493 Hits@1 0.01606425702811245 Hits@3 0.05220883534136546 Hits@10 0.11244979919678715 MRR 0.055372013954679125 cur_rank 89 abs_cur_rank 89 total_num 248 20466
checkcorrect 4098 4098 real score 0.9748398661613464 Hits@1 0.016 Hits@3 0.052 Hits@10 0.112 MRR 0.05540052589886041 cur_rank 15 abs_cur_rank 26 total_num 249 20466
checkco

checkcorrect 1135 1135 real score 0.7495968699455261 Hits@1 0.013986013986013986 Hits@3 0.05244755244755245 Hits@10 0.12237762237762238 MRR 0.05626504256118205 cur_rank 841 abs_cur_rank 889 total_num 285 20466
checkcorrect 3897 3897 real score 0.7885922789573669 Hits@1 0.013937282229965157 Hits@3 0.05226480836236934 Hits@10 0.12195121951219512 MRR 0.056078263926711336 cur_rank 375 abs_cur_rank 389 total_num 286 20466
checkcorrect 5311 5311 real score 0.9983724415302276 Hits@1 0.013888888888888888 Hits@3 0.052083333333333336 Hits@10 0.125 MRR 0.056269350201657166 cur_rank 8 abs_cur_rank 8 total_num 287 20466
checkcorrect 6 6 real score 0.9904924392700195 Hits@1 0.01384083044982699 Hits@3 0.05190311418685121 Hits@10 0.1245674740484429 MRR 0.05632180425434546 cur_rank 13 abs_cur_rank 14 total_num 288 20466
checkcorrect 1723 1723 real score 0.9968178033828735 Hits@1 0.013793103448275862 Hits@3 0.05172413793103448 Hits@10 0.12413793103448276 MRR 0.056246497200436306 cur_rank 28 abs_cur_rank

checkcorrect 12929 12929 real score 0.9542066097259522 Hits@1 0.015337423312883436 Hits@3 0.049079754601226995 Hits@10 0.1165644171779141 MRR 0.054650550568675 cur_rank 51 abs_cur_rank 53 total_num 325 20466
checkcorrect 4770 4770 real score 0.964806342124939 Hits@1 0.01529051987767584 Hits@3 0.04892966360856269 Hits@10 0.1162079510703364 MRR 0.05452711070236626 cur_rank 69 abs_cur_rank 80 total_num 326 20466
checkcorrect 2429 2429 real score 0.7378715515136719 Hits@1 0.01524390243902439 Hits@3 0.04878048780487805 Hits@10 0.11585365853658537 MRR 0.054377621052342305 cur_rank 181 abs_cur_rank 189 total_num 327 20466
checkcorrect 4850 4850 real score 0.7626906722784043 Hits@1 0.015197568389057751 Hits@3 0.0486322188449848 Hits@10 0.11550151975683891 MRR 0.054221808419348885 cur_rank 320 abs_cur_rank 328 total_num 328 20466
checkcorrect 2641 2641 real score 0.9842485010623931 Hits@1 0.015151515151515152 Hits@3 0.048484848484848485 Hits@10 0.11818181818181818 MRR 0.0544362877877751 cur_ran

checkcorrect 7552 7552 real score 0.9957388460636138 Hits@1 0.01366120218579235 Hits@3 0.04644808743169399 Hits@10 0.11475409836065574 MRR 0.05256104324408243 cur_rank 11 abs_cur_rank 14 total_num 365 20466
checkcorrect 981 981 real score 0.815712320804596 Hits@1 0.013623978201634877 Hits@3 0.04632152588555858 Hits@10 0.11444141689373297 MRR 0.05242382689511165 cur_rank 453 abs_cur_rank 501 total_num 366 20466
checkcorrect 7003 7003 real score 0.7033947288990021 Hits@1 0.01358695652173913 Hits@3 0.04619565217391304 Hits@10 0.11413043478260869 MRR 0.05228647872215787 cur_rank 531 abs_cur_rank 538 total_num 367 20466
checkcorrect 4592 4592 real score 0.6132359743118286 Hits@1 0.013550135501355014 Hits@3 0.04607046070460705 Hits@10 0.11382113821138211 MRR 0.05214675618625436 cur_rank 1371 abs_cur_rank 1407 total_num 368 20466
checkcorrect 6126 6126 real score 0.7439913809299469 Hits@1 0.013513513513513514 Hits@3 0.04594594594594595 Hits@10 0.11351351351351352 MRR 0.05200840780306484 cur_r

checkcorrect 4657 4657 real score 0.8342693090438843 Hits@1 0.012315270935960592 Hits@3 0.04433497536945813 Hits@10 0.11822660098522167 MRR 0.05161358941709326 cur_rank 6 abs_cur_rank 420 total_num 405 20466
checkcorrect 10999 10999 real score 0.0 Hits@1 0.012285012285012284 Hits@3 0.044226044226044224 Hits@10 0.11793611793611794 MRR 0.05148729902938996 cur_rank 4685 abs_cur_rank 7288 total_num 406 20466
checkcorrect 5457 5457 real score 0.517805990576744 Hits@1 0.012254901960784314 Hits@3 0.04411764705882353 Hits@10 0.11764705882352941 MRR 0.05136312027789893 cur_rank 1215 abs_cur_rank 1308 total_num 407 20466
checkcorrect 6038 6038 real score 0.7007575616240501 Hits@1 0.012224938875305624 Hits@3 0.044009779951100246 Hits@10 0.11735941320293398 MRR 0.05124134054605406 cur_rank 642 abs_cur_rank 646 total_num 408 20466
checkcorrect 4068 4068 real score 0.3232242401689291 Hits@1 0.012195121951219513 Hits@3 0.04390243902439024 Hits@10 0.11707317073170732 MRR 0.05111837738931001 cur_rank 1

checkcorrect 71 71 real score 0.9995151877403259 Hits@1 0.01569506726457399 Hits@3 0.04484304932735426 Hits@10 0.11883408071748879 MRR 0.053502382073292415 cur_rank 0 abs_cur_rank 0 total_num 445 20466
checkcorrect 59 59 real score 0.7519051313400269 Hits@1 0.015659955257270694 Hits@3 0.0447427293064877 Hits@10 0.1185682326621924 MRR 0.05338492707983986 cur_rank 999 abs_cur_rank 1452 total_num 446 20466
checkcorrect 5194 5194 real score 0.6515090614557266 Hits@1 0.015625 Hits@3 0.044642857142857144 Hits@10 0.11830357142857142 MRR 0.05326712785320891 cur_rank 1636 abs_cur_rank 1828 total_num 447 20466
checkcorrect 4683 4683 real score 0.8227619528770447 Hits@1 0.015590200445434299 Hits@3 0.044543429844097995 Hits@10 0.11804008908685969 MRR 0.05319798552440939 cur_rank 44 abs_cur_rank 201 total_num 448 20466
checkcorrect 2126 2126 real score 0.9995068788528443 Hits@1 0.015555555555555555 Hits@3 0.044444444444444446 Hits@10 0.11777777777777777 MRR 0.05320322456892305 cur_rank 17 abs_cur_r

KeyboardInterrupt: 

In [None]:
############################################
####target entity prediction################
#For every triple (s, r, t) in data_test: score candidate entities with the
#path-based model, rank all entities of id2entity by that score, and accumulate
#Hits@1 / Hits@3 / Hits@10 and MRR for the position of the true target t.

#we evaluate on all the triples in data_test
#NOTE(review): this comment used to say "inductive test set", while model_id at the
#top of the notebook is 'main_transductive' -- confirm which split data_test holds.
selected = list(data_test)

###running totals of the metrics########
#number of test triples whose filtered rank is 0, < 3 and < 10,
#and the running sum of reciprocal filtered ranks
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
        
    #run the path-based scoring, starting from the known source with the relation itself
    #(1 and 10 are presumably lower_bd / upper_bd -- the def line is outside this cell)
    score_dict = path_based_entity_scoring(s_true, r_true, 1, 10, one_hop, model)
    
    #[... [score, entity], ...] one row per entity in id2entity
    temp_list = list()
    
    for e in id2entity:

        if e in score_dict:

            temp_list.append([score_dict[e], e])

        else:

            #entities the model did not score get 0.0
            temp_list.append([0.0, e])

    #highest score first; sorted() is stable, so ties keep the id2entity order
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    #p: 0-based position of the true target in the ranking
    #exist_tri: entities ranked above it that form a known triple with (s_true, r_true)
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != t_true:
        
        #moreover, we want to remove existing triples (filtered setting):
        #a higher-ranked entity that is a correct answer in test, valid or train is not counted
        if ((s_true, r_true, sorted_list[p][1]) in data_test) or (
            (s_true, r_true, sorted_list[p][1]) in data_valid) or (
            (s_true, r_true, sorted_list[p][1]) in data):
            
            exist_tri += 1
            
        p += 1
    
    #p - exist_tri is the 0-based filtered rank of the true target
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    #reciprocal of the 1-based filtered rank
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    #running averages over the i+1 triples evaluated so far
    #NOTE(review): if t_true never appears in sorted_list, p == len(sorted_list) here
    #and sorted_list[p] raises IndexError
    print('checkcorrect', t_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

**bin**

In [None]:
#obtain the subgraph Dict: one list of sampled paths for every entity in one_hop,
#built with lower_bd = 1 and upper_bd = 3.
#The subgraph scoring below receives it as its Dict argument.
#NOTE(review): this calls the path finder once per entity in one_hop, so it is slow;
#the cells that use Dict_train only work after this cell has been run.
Dict_train = store_subgraph_dicts(1, 3, data, one_hop, s_t_r,
                         relation2id, entity2id, id2relation, id2entity)

In [22]:
############################################
####source entity prediction################
#Same evaluation as the path-based source prediction, but the ranking score of an
#entity is the sum of its path-based score and its subgraph-based score.
#Needs Dict_train (built by store_subgraph_dicts) and model_2 in the kernel.

#we evaluate on all the triples in data_test
#NOTE(review): this comment used to say "inductive test set", while model_id at the
#top of the notebook is 'main_transductive' -- confirm which split data_test holds.
selected = list(data_test)

###running totals of the metrics########
#number of test triples whose filtered rank is 0, < 3 and < 10,
#and the running sum of reciprocal filtered ranks
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    #id of the inverse relation: LoadKG.load_train_data gives '_inverse_' + relation
    #the id directly after the relation's own id
    r_true_inv = r_true + 1
    
    #first run the path-based scoring, starting from the known target with the inverse relation
    #(1 and 10 are presumably lower_bd / upper_bd -- the def line is outside this cell)
    score_dict_path = path_based_entity_scoring(t_true, r_true_inv, 1, 10, one_hop, model)
    
    #then run the subgraph scoring; it takes the forward relation r_true and upper_bd = 3,
    #the same upper_bd Dict_train was built with
    score_dict_subg = subgraph_source_scoring(r_true, t_true, 1, 3, one_hop, Dict_train, id2entity, model_2)
    
    #final score dict: sum of both scores; an entity missing from one of the
    #two dicts simply gets nothing added from it
    score_dict = defaultdict(float)
    
    for s in score_dict_path:
        score_dict[s] += score_dict_path[s]
    for s in score_dict_subg:
        score_dict[s] += score_dict_subg[s]
    
    #[... [score, entity], ...] one row per entity in id2entity
    temp_list = list()
    
    for e in id2entity:

        if e in score_dict:

            temp_list.append([score_dict[e], e])

        else:

            #entities that neither model scored get 0.0
            #NOTE(review): the recorded output shows negative combined scores, so an
            #unscored entity can rank above a scored one -- confirm this is intended
            temp_list.append([0.0, e])

    #highest score first; sorted() is stable, so ties keep the id2entity order
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    #p: 0-based position of the true source in the ranking
    #exist_tri: entities ranked above it that form a known triple with (r_true, t_true)
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != s_true:
        
        #moreover, we want to remove existing triples (filtered setting):
        #a higher-ranked entity that is a correct answer in test, valid or train is not counted
        if ((sorted_list[p][1], r_true, t_true) in data_test) or (
            (sorted_list[p][1], r_true, t_true) in data_valid) or (
            (sorted_list[p][1], r_true, t_true) in data):
            
            exist_tri += 1
            
        p += 1
    
    #p - exist_tri is the 0-based filtered rank of the true source
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    #reciprocal of the 1-based filtered rank
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    #running averages over the i+1 triples evaluated so far
    #NOTE(review): if s_true never appears in sorted_list, p == len(sorted_list) here
    #and sorted_list[p] raises IndexError
    print('checkcorrect', s_true, sorted_list[p][1],
          'real score', sorted_list[p][0],
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

14496
checkcorrect 2382 2382 real score 1.5545192778110504 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.001869158878504673 cur_rank 534 abs_cur_rank 1658 total_num 0 20466
14496
checkcorrect 3241 3241 real score 1.743541157245636 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.002537143541816439 cur_rank 311 abs_cur_rank 351 total_num 1 20466
14496
checkcorrect 6498 6498 real score 1.8224830389022828 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.008937805839471829 cur_rank 45 abs_cur_rank 46 total_num 2 20466
14496
checkcorrect 2298 2298 real score 1.4572955012321471 Hits@1 0.0 Hits@3 0.0 Hits@10 0.0 MRR 0.00702427864147807 cur_rank 778 abs_cur_rank 3221 total_num 3 20466
14496
checkcorrect 13682 13682 real score 1.7605762034654617 Hits@1 0.0 Hits@3 0.0 Hits@10 0.2 MRR 0.04561942291318245 cur_rank 4 abs_cur_rank 4 total_num 4 20466
14496
checkcorrect 10895 10895 real score 1.7302967727184297 Hits@1 0.0 Hits@3 0.0 Hits@10 0.16666666666666666 MRR 0.042182852427652046 cur_rank 39 abs_cur_rank 39 total_num 5

14496
checkcorrect 2929 2929 real score 1.3325686648488044 Hits@1 0.021739130434782608 Hits@3 0.06521739130434782 Hits@10 0.10869565217391304 MRR 0.06398536070062191 cur_rank 1636 abs_cur_rank 1705 total_num 45 20466
14496
checkcorrect 4841 4841 real score 1.5733115732669831 Hits@1 0.02127659574468085 Hits@3 0.06382978723404255 Hits@10 0.10638297872340426 MRR 0.06267276957894169 cur_rank 435 abs_cur_rank 437 total_num 46 20466
14496
checkcorrect 1360 1360 real score 1.4202022314071656 Hits@1 0.020833333333333332 Hits@3 0.0625 Hits@10 0.10416666666666667 MRR 0.06143607142684177 cur_rank 301 abs_cur_rank 305 total_num 47 20466
14496
checkcorrect 5029 5029 real score 1.7455645203590393 Hits@1 0.02040816326530612 Hits@3 0.061224489795918366 Hits@10 0.10204081632653061 MRR 0.060283304561998166 cur_rank 201 abs_cur_rank 1572 total_num 48 20466
14496
checkcorrect 5161 5161 real score -0.2713398814201355 Hits@1 0.02 Hits@3 0.06 Hits@10 0.1 MRR 0.05907918849788368 cur_rank 12902 abs_cur_rank 13

14496
checkcorrect 6696 6696 real score 1.199754573404789 Hits@1 0.011627906976744186 Hits@3 0.03488372093023256 Hits@10 0.08139534883720931 MRR 0.0424191540818083 cur_rank 2981 abs_cur_rank 3018 total_num 85 20466
14496
checkcorrect 5676 5676 real score 1.6874215111136437 Hits@1 0.011494252873563218 Hits@3 0.034482758620689655 Hits@10 0.08045977011494253 MRR 0.04197561688114982 cur_rank 260 abs_cur_rank 427 total_num 86 20466
14496
checkcorrect 1132 1132 real score 1.9611951053142547 Hits@1 0.011363636363636364 Hits@3 0.03409090909090909 Hits@10 0.09090909090909091 MRR 0.04339256062871251 cur_rank 5 abs_cur_rank 8 total_num 87 20466
14496
checkcorrect 5713 5713 real score 1.5829115450382232 Hits@1 0.011235955056179775 Hits@3 0.033707865168539325 Hits@10 0.0898876404494382 MRR 0.04291246456124574 cur_rank 1505 abs_cur_rank 1547 total_num 88 20466
14496
checkcorrect 10389 10389 real score 1.8206190764904022 Hits@1 0.011111111111111112 Hits@3 0.03333333333333333 Hits@10 0.1 MRR 0.0435467

14496
checkcorrect 14363 14363 real score 1.8736393332481385 Hits@1 0.008 Hits@3 0.04 Hits@10 0.112 MRR 0.04152261776066093 cur_rank 23 abs_cur_rank 24 total_num 124 20466
14496
checkcorrect 3131 3131 real score 1.8916177451610565 Hits@1 0.007936507936507936 Hits@3 0.03968253968253968 Hits@10 0.1111111111111111 MRR 0.04130179246206033 cur_rank 72 abs_cur_rank 75 total_num 125 20466
14496
checkcorrect 12537 12537 real score 1.5409196615219116 Hits@1 0.007874015748031496 Hits@3 0.03937007874015748 Hits@10 0.11023622047244094 MRR 0.0412915421277134 cur_rank 24 abs_cur_rank 25 total_num 126 20466
14496
checkcorrect 10915 10915 real score 0.8439805388450623 Hits@1 0.0078125 Hits@3 0.0390625 Hits@10 0.109375 MRR 0.04097156483109482 cur_rank 2989 abs_cur_rank 4035 total_num 127 20466
14496
checkcorrect 7964 7964 real score 0.8505736768245697 Hits@1 0.007751937984496124 Hits@3 0.03875968992248062 Hits@10 0.10852713178294573 MRR 0.04074735264458311 cur_rank 82 abs_cur_rank 83 total_num 128 2046

14496
checkcorrect 8618 8618 real score 1.8677910268306732 Hits@1 0.018292682926829267 Hits@3 0.042682926829268296 Hits@10 0.11585365853658537 MRR 0.05017379245701209 cur_rank 264 abs_cur_rank 305 total_num 163 20466
14496
checkcorrect 10649 10649 real score 0.8772981464862823 Hits@1 0.01818181818181818 Hits@3 0.04242424242424243 Hits@10 0.11515151515151516 MRR 0.04987505803323079 cur_rank 1132 abs_cur_rank 3884 total_num 164 20466
14496
checkcorrect 8540 8540 real score 1.79371035695076 Hits@1 0.018072289156626505 Hits@3 0.04216867469879518 Hits@10 0.1144578313253012 MRR 0.04990927789782311 cur_rank 17 abs_cur_rank 18 total_num 165 20466
14496
checkcorrect 4445 4445 real score 1.5892368614673615 Hits@1 0.017964071856287425 Hits@3 0.041916167664670656 Hits@10 0.11377245508982035 MRR 0.049666910738340043 cur_rank 105 abs_cur_rank 122 total_num 166 20466
14496
checkcorrect 9626 9626 real score 1.3667500019073486 Hits@1 0.017857142857142856 Hits@3 0.041666666666666664 Hits@10 0.1130952380

14496
checkcorrect 1841 1841 real score 1.6323628067970275 Hits@1 0.019704433497536946 Hits@3 0.03940886699507389 Hits@10 0.11330049261083744 MRR 0.0509178512290713 cur_rank 95 abs_cur_rank 329 total_num 202 20466
14496
checkcorrect 835 835 real score 1.266238847374916 Hits@1 0.0196078431372549 Hits@3 0.0392156862745098 Hits@10 0.11274509803921569 MRR 0.05067155935458294 cur_rank 1482 abs_cur_rank 1487 total_num 203 20466
14496
checkcorrect 3058 3058 real score 1.6953365862369538 Hits@1 0.01951219512195122 Hits@3 0.03902439024390244 Hits@10 0.11219512195121951 MRR 0.05050434902906278 cur_rank 60 abs_cur_rank 63 total_num 204 20466
14496
checkcorrect 4651 4651 real score 1.9013964414596556 Hits@1 0.019417475728155338 Hits@3 0.038834951456310676 Hits@10 0.11165048543689321 MRR 0.050501900732805194 cur_rank 19 abs_cur_rank 19 total_num 205 20466
14496
checkcorrect 1033 1033 real score 1.6141491025686263 Hits@1 0.01932367149758454 Hits@3 0.03864734299516908 Hits@10 0.1111111111111111 MRR 0

14496
checkcorrect 2201 2201 real score 1.849783968925476 Hits@1 0.024793388429752067 Hits@3 0.045454545454545456 Hits@10 0.128099173553719 MRR 0.0562052999193027 cur_rank 75 abs_cur_rank 353 total_num 241 20466
14496
checkcorrect 7140 7140 real score 1.5672392219305038 Hits@1 0.024691358024691357 Hits@3 0.04526748971193416 Hits@10 0.12757201646090535 MRR 0.05598153943334466 cur_rank 545 abs_cur_rank 553 total_num 242 20466
14496
checkcorrect 3714 3714 real score 1.8071396321058275 Hits@1 0.02459016393442623 Hits@3 0.045081967213114756 Hits@10 0.12704918032786885 MRR 0.055792286901112216 cur_rank 101 abs_cur_rank 136 total_num 243 20466
14496
checkcorrect 4091 4091 real score 1.7185117244720458 Hits@1 0.024489795918367346 Hits@3 0.044897959183673466 Hits@10 0.12653061224489795 MRR 0.05562371737752885 cur_rank 68 abs_cur_rank 68 total_num 244 20466
14496
checkcorrect 1540 1540 real score 1.8897958815097808 Hits@1 0.024390243902439025 Hits@3 0.044715447154471545 Hits@10 0.126016260162601

14496
checkcorrect 697 697 real score 1.8267245829105376 Hits@1 0.021352313167259787 Hits@3 0.0498220640569395 Hits@10 0.13167259786476868 MRR 0.056043792827669005 cur_rank 2 abs_cur_rank 2 total_num 280 20466
14496
checkcorrect 246 246 real score -0.21830328702926635 Hits@1 0.02127659574468085 Hits@3 0.04964539007092199 Hits@10 0.13120567375886524 MRR 0.05584538181206867 cur_rank 10882 abs_cur_rank 10882 total_num 281 20466
14496
checkcorrect 1633 1633 real score 1.8074527442455293 Hits@1 0.02120141342756184 Hits@3 0.04946996466431095 Hits@10 0.13074204946996468 MRR 0.055969281844213625 cur_rank 10 abs_cur_rank 11 total_num 282 20466
14496
checkcorrect 11858 11858 real score 1.5171434104442596 Hits@1 0.02112676056338028 Hits@3 0.04929577464788732 Hits@10 0.13028169014084506 MRR 0.055777083815013034 cur_rank 721 abs_cur_rank 1386 total_num 283 20466
14496
checkcorrect 937 937 real score 1.9244524121284483 Hits@1 0.021052631578947368 Hits@3 0.04912280701754386 Hits@10 0.1298245614035087

14496
checkcorrect 5963 5963 real score 1.6562221020460128 Hits@1 0.025 Hits@3 0.05 Hits@10 0.121875 MRR 0.05734530793939532 cur_rank 1054 abs_cur_rank 1528 total_num 319 20466
14496
checkcorrect 3864 3864 real score 1.4615337669849398 Hits@1 0.024922118380062305 Hits@3 0.04984423676012461 Hits@10 0.12149532710280374 MRR 0.05716916032780557 cur_rank 1246 abs_cur_rank 1263 total_num 320 20466
14496
checkcorrect 2452 2452 real score 1.5384984865784643 Hits@1 0.024844720496894408 Hits@3 0.049689440993788817 Hits@10 0.12111801242236025 MRR 0.05699543157283878 cur_rank 813 abs_cur_rank 820 total_num 321 20466
14496
checkcorrect 3863 3863 real score 1.9082739770412445 Hits@1 0.02476780185758514 Hits@3 0.04953560371517028 Hits@10 0.12074303405572756 MRR 0.05691003322340549 cur_rank 33 abs_cur_rank 59 total_num 322 20466
14496
checkcorrect 593 593 real score 1.3420038908720016 Hits@1 0.024691358024691357 Hits@3 0.04938271604938271 Hits@10 0.12037037037037036 MRR 0.05674168147331427 cur_rank 42

14496
checkcorrect 885 885 real score 1.5343511432409287 Hits@1 0.025069637883008356 Hits@3 0.05013927576601671 Hits@10 0.11977715877437325 MRR 0.056882267055091404 cur_rank 725 abs_cur_rank 978 total_num 358 20466
14496
checkcorrect 9637 9637 real score 1.7651563882827759 Hits@1 0.025 Hits@3 0.05 Hits@10 0.12222222222222222 MRR 0.057187223720679116 cur_rank 5 abs_cur_rank 5 total_num 359 20466
14496
checkcorrect 1420 1420 real score 1.9593048453330995 Hits@1 0.024930747922437674 Hits@3 0.04986149584487535 Hits@10 0.12188365650969529 MRR 0.05728063609516225 cur_rank 10 abs_cur_rank 51 total_num 360 20466
14496
checkcorrect 3514 3514 real score 0.870198130607605 Hits@1 0.024861878453038673 Hits@3 0.049723756906077346 Hits@10 0.12154696132596685 MRR 0.05712465918181664 cur_rank 1223 abs_cur_rank 4182 total_num 361 20466
14496
checkcorrect 1994 1994 real score 1.5652328193187715 Hits@1 0.024793388429752067 Hits@3 0.049586776859504134 Hits@10 0.12121212121212122 MRR 0.05699540139242501 cur

14496
checkcorrect 7384 7384 real score 1.8864255964756012 Hits@1 0.02512562814070352 Hits@3 0.04773869346733668 Hits@10 0.12311557788944724 MRR 0.05715515504806254 cur_rank 0 abs_cur_rank 0 total_num 397 20466
14496
checkcorrect 10822 10822 real score 0.7609229892492294 Hits@1 0.02506265664160401 Hits@3 0.047619047619047616 Hits@10 0.12280701754385964 MRR 0.057013891850549864 cur_rank 1263 abs_cur_rank 1263 total_num 398 20466
14496
checkcorrect 5819 5819 real score 1.4669575214385988 Hits@1 0.025 Hits@3 0.0475 Hits@10 0.1225 MRR 0.05691103966060603 cur_rank 62 abs_cur_rank 62 total_num 399 20466
14496
checkcorrect 4995 4995 real score 1.7726864218711853 Hits@1 0.02493765586034913 Hits@3 0.04738154613466334 Hits@10 0.12219451371571072 MRR 0.056806337250316065 cur_rank 66 abs_cur_rank 71 total_num 400 20466
14496
checkcorrect 8552 8552 real score 1.836226212978363 Hits@1 0.024875621890547265 Hits@3 0.04975124378109453 Hits@10 0.12437810945273632 MRR 0.05790880904820085 cur_rank 1 abs_c

14496
checkcorrect 5966 5966 real score 1.4378474205732346 Hits@1 0.02517162471395881 Hits@3 0.05263157894736842 Hits@10 0.12585812356979406 MRR 0.05930618073887425 cur_rank 800 abs_cur_rank 3100 total_num 436 20466
14496
checkcorrect 11485 11485 real score 0.7678351819515228 Hits@1 0.02511415525114155 Hits@3 0.05251141552511415 Hits@10 0.12557077625570776 MRR 0.05917132666421567 cur_rank 4164 abs_cur_rank 7825 total_num 437 20466
14496
checkcorrect 8282 8282 real score 1.62577945291996 Hits@1 0.025056947608200455 Hits@3 0.05239179954441914 Hits@10 0.1252847380410023 MRR 0.05922636540378085 cur_rank 11 abs_cur_rank 248 total_num 438 20466
14496
checkcorrect 4754 4754 real score 1.8117413192987442 Hits@1 0.025 Hits@3 0.05227272727272727 Hits@10 0.125 MRR 0.059119476214115944 cur_rank 81 abs_cur_rank 155 total_num 439 20466
14496
checkcorrect 7612 7612 real score 1.5691459357738495 Hits@1 0.024943310657596373 Hits@3 0.05215419501133787 Hits@10 0.12471655328798185 MRR 0.05898902349054479 

14496
checkcorrect 1346 1346 real score 1.8513185679912567 Hits@1 0.0273109243697479 Hits@3 0.052521008403361345 Hits@10 0.12815126050420167 MRR 0.06137963817788047 cur_rank 105 abs_cur_rank 114 total_num 475 20466
14496
checkcorrect 9512 9512 real score 1.7273660480976103 Hits@1 0.027253668763102725 Hits@3 0.05450733752620545 Hits@10 0.129979035639413 MRR 0.062299177720484496 cur_rank 1 abs_cur_rank 1 total_num 476 20466
14496
checkcorrect 4528 4528 real score 0.6196912705898284 Hits@1 0.027196652719665274 Hits@3 0.05439330543933055 Hits@10 0.1297071129707113 MRR 0.06223857971967456 cur_rank 29 abs_cur_rank 44 total_num 477 20466
14496
checkcorrect 7799 7799 real score 1.8701491177082064 Hits@1 0.027139874739039668 Hits@3 0.054279749478079335 Hits@10 0.12943632567849686 MRR 0.062131841789385286 cur_rank 89 abs_cur_rank 90 total_num 478 20466
14496
checkcorrect 2227 2227 real score 1.7701385140419008 Hits@1 0.027083333333333334 Hits@3 0.05416666666666667 Hits@10 0.12916666666666668 MRR

KeyboardInterrupt: 

In [None]:
############################################
####target entity prediction################

# Evaluate target-entity prediction on every triple of the test set
# (data_test). For each (s, r, t) the true target t is ranked against all
# entities in id2entity; the rank is "filtered", i.e. higher-ranked candidates
# that already form a known triple are not counted against the true target.
selected = list(data_test)

# Running totals; divide by the number of evaluated triples to get the rates.
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
        
    #first run the path-based scoring
    # NOTE(review): the literals (1, 10) and (1, 3) are passed straight to the
    # scoring helpers; their meaning is defined there -- consider naming them
    # in a config cell.
    score_dict_path = path_based_entity_scoring(s_true, r_true, 1, 10, one_hop, model)
    
    #then run the subgraph scoring
    score_dict_subg = subgraph_target_scoring(s_true, r_true, 1, 3, one_hop, Dict_train, id2entity, model_2)
    
    #final score dict: the sum of both scores for every candidate entity
    score_dict = defaultdict(float)
    
    for t in score_dict_path:
        score_dict[t] += score_dict_path[t]
    for t in score_dict_subg:
        score_dict[t] += score_dict_subg[t]
    
    #[... [score, entity], ...]
    #entities that neither scorer returned get the default score 0.0
    temp_list = [[score_dict.get(e, 0.0), e] for e in id2entity]

    #sorted() is stable, so tied candidates keep their id2entity order
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0          #raw (unfiltered) position of the true target
    exist_tri = 0  #known triples ranked above the true target
    
    while p < len(sorted_list) and sorted_list[p][1] != t_true:
        
        #moreover, we want to remove existing triples
        #(`data` is presumably the training triples -- TODO confirm)
        if ((s_true, r_true, sorted_list[p][1]) in data_test) or (
            (s_true, r_true, sorted_list[p][1]) in data_valid) or (
            (s_true, r_true, sorted_list[p][1]) in data):
            
            exist_tri += 1
            
        p += 1
    
    #filtered rank is p - exist_tri (0-based)
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
    
    #If t_true is not among the ranked candidates the loop above stops with
    #p == len(sorted_list); indexing sorted_list[p] would then raise IndexError
    #and abort the whole evaluation. The triple is still counted with the
    #worst rank above; only the progress line needs a safe fallback.
    if p < len(sorted_list):
        
        ranked_score, ranked_entity = sorted_list[p]
        
    else:
        
        ranked_score, ranked_entity = score_dict.get(t_true, 0.0), None
        
    print('checkcorrect', t_true, ranked_entity,
          'real score', ranked_score,
          'Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))