In [1]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import initializers
from tensorflow.keras.utils import plot_model

In [2]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [3]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [4]:
train_path = '../data/fb237_v4/train.txt'

In [5]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [6]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

### Build the deep neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [7]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of two paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)

####apply the attention mechanism##############
#first expand the dimention at the end to get ready for conv2D: (Batch,time,300,1)
fst_exp_dim = tf.expand_dims(fst_lstm_out, axis=-1)
scd_exp_dim = tf.expand_dims(scd_lstm_out, axis=-1)

#define the attention layer using convolutional 2D: output is
att_layer_conv2D = layers.Conv2D(100, (1, 300), padding='valid', activation='relu', 
                   input_shape=(None, 300), data_format='channels_last')

#shape: (Batch,Time,1,100)
fst_att_mid = att_layer_conv2D(fst_exp_dim)
scd_att_mid = att_layer_conv2D(scd_exp_dim)

#squeeze out the dim 1 to become: (Batch, Time, 100)
fst_squez = tf.squeeze(fst_att_mid, 2)
scd_squez = tf.squeeze(scd_att_mid, 2)

#expand the dimention again to become (Batch, Time, 100, 1)
fst_exp_dim_2 = tf.expand_dims(fst_squez, axis=-1)
scd_exp_dim_2 = tf.expand_dims(scd_squez, axis=-1)

#obtain the attention score for each time step by another conv2D layer
att_layer_conv2D_2 = layers.Conv2D(1, (1, 100), padding='valid', activation='relu', 
                     input_shape=(None, 100), data_format='channels_last')

#obtain (Batch, Time, 1, 1)
fst_mid_score = att_layer_conv2D_2(fst_exp_dim_2)
scd_mid_score = att_layer_conv2D_2(scd_exp_dim_2)

#squeeze again to obtain (Batch, Time, 1)
fst_squez_2 = tf.squeeze(fst_mid_score, -1)
scd_squez_2 = tf.squeeze(scd_mid_score, -1)

#softmax the attention score
softmax_l = layers.Softmax(1) #define softmax

fst_att_score = softmax_l(fst_squez_2)
scd_att_score = softmax_l(scd_squez_2)

#multiply the attention score to lstm output
fst_att_befsum = layers.Multiply()([fst_lstm_out, fst_att_score])
scd_att_befsum = layers.Multiply()([scd_lstm_out, scd_att_score])

#sum over time dimension to complete the attention: (Batch, 300)
fst_att_out = tf.reduce_sum(fst_att_befsum, axis=1)
scd_att_out = tf.reduce_sum(scd_att_befsum, axis=1)

#concatenate the output vector from both siamese tunnel: (Batch, 600)
path_concat = layers.concatenate([fst_att_out, scd_att_out], axis=-1)

#multiply into output embd size by dense layer: (Batch, 300)
path_out_vect = layers.Dense(300, activation='tanh')(path_concat)

#remove the time dimension from the output embd since there is only one step
rela_out_embd = tf.reduce_sum(rela_embd, axis=1)

# Normalize the vectors to have unit length
path_out_vect_norm = tf.math.l2_normalize(path_out_vect, axis=-1)
rela_out_embd_norm = tf.math.l2_normalize(rela_out_embd, axis=-1)

# Calculate the dot product
dot_product = layers.Dot(axes=-1)([path_out_vect_norm, rela_out_embd_norm])

#put together the model
model = keras.Model([fst_path, scd_path, id_rela], dot_product)

2023-01-07 23:28:18.762446: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
#config the Adam optimizer 
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

#compile the model
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])

### Build the batches
We build each big-batch for each path combination with length (i,j). Then, we iteratively train the siamese network on different big-batches. The length of each big-batch is N.

To be specific:
* If we allow the length difference between two paths in a combination to be d, then the combination with path length i and path length j, denoted as (i,j), will be like (2,2), (2,3), (2,4), (3,3), (3,4), (3,5), ... 
* We will first build all the big-batches before fitting the NN model. 
* That is, we will perform the ObtainPathsByDynamicProgramming class function for some randomly chosen source entities. Then, for each target entity, we will further have two for loops:
* for path_1 in all the 
* Do this until all the slots in all big-batchs are filled.
* In every epoch, big-batchs will be re-filled.

Then, in the training, we will use negative sampling: In each batch (actual batch, not the big-batch), we will include K true output relation embeddings and K random selected output relation embeddings. The true label is [1,0], while the false label is [0,1].

In [9]:
#function to build all the big batches
def build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity):
    
    if holder_len % 10 != 0:
        raise ValueError('We would like to take 10X as a big-batch size')
    
    #the set of all relation IDs
    relation_id_set = set()
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #count how many appending has performed
    count = 0

    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    
    carry_on = True
    
    while carry_on:

        #obtain paths by dynamic programming
        source_id = random.choice(existing_ids)

        result, length_dict = Class_2.obtain_paths('direct_neighbour', source_id, 
                                                   'not_specified', lower_bd, upper_bd, one_hop)
        
        #We want to increase the diversity of paths and targets.
        #So we abandon one sub-graph from a source_id, if we sampled more than K1 path pairs
        #Note that we mean "sampled", not "appended"! 
        #We do not care whether the pair is actually appended.
        threshold_0 = 1000
        count_0 = 0
        
        for target_id in result:

            if (not carry_on) or (count_0 > threshold_0):
                break
            
            #we want to make sure s, t are indeed directly connected, 
            #otherwise there is no relation for positive sample
            #also, we want to make sure s and t and not connected by all relations, 
            #although this situation is rare. 
            #But in that case, there is no relation for negative samples
            #Also, we want at least two different paths here between s and t
            if ((source_id, target_id) in s_t_r) and (
                len(s_t_r[(source_id, target_id)]) < len(id2relation)) and (
                len(result[target_id]) >= 2):
                
                dir_r = list(s_t_r[(source_id, target_id)])
                
                non_dir_r = list(relation_id_set.difference(dir_r))
                
                if len(dir_r) <= 0:
                    
                    raise ValueError('errors when creating s_t_r !!')
                    
                temp_path_list = list(result[target_id])
                    
                #futhermore, we will abandon one targed_id if we sampled more than K2 times
                threshold_1 = 50
                count_1 = 0
                
                while count_1 <= threshold_1 and count_0 <= threshold_0:
                
                    temp_pair = random.sample(temp_path_list,2)
                    
                    path_1, path_2 = temp_pair[0], temp_pair[1]

                    #decide which path is shorter and which is longer
                    if len(path_1) <= len(path_2):

                        path_s, path_l = path_1, path_2

                    else:

                        path_s, path_l = path_2, path_1                            

                    if (len(path_s) < lower_bd) or (len(path_l) > upper_bd):

                        raise ValueError('something wrong with the path finding')

                    #proceed when the entire length not yet reached,
                    #and whether this path pair is new, and whether the two paths are different
                    #But it is optional to require the path to be new. 
                    #We may remove this requirment, especially for short paths
                    '''remember to cancel the comment below when using path_comb'''
                    if (carry_on) and (path_s != path_l):

                        #we always add one positive and one negative situation together,
                        #hence, the length of list should always be even.
                        #also we want to make sure the length of lists coincide
                        if (len(x_p_list['s']) != len(y_list)) or (
                            len(x_p_list['s']) != len(x_p_list['l'])) or (
                            len(y_list) != len(x_r_list)) or (
                            len(y_list) % 2 != 0):

                            raise ValueError('error when building big batches: length error')

                        #####positive#####################
                        #we randomly choose one direction relation as the target relation
                        relation_id = random.choice(dir_r)

                        #append the paths: note that we add the space holder id at the end
                        #of the shorter path
                        x_p_list['s'].append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                        x_p_list['l'].append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))

                        #append relation
                        x_r_list.append([relation_id])
                        y_list.append(1.)

                        #####negative#####################
                        relation_id = random.choice(non_dir_r)

                        #append the paths: note that we add the space holder id at the end
                        #of the shorter path
                        x_p_list['s'].append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                        x_p_list['l'].append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))

                        #append relation
                        x_r_list.append([relation_id])
                        y_list.append(0.)

                        ######add to path combinations#####
                        #here is the tricky part: we have to add both (path_s, path_l)
                        #and (path_l, path_s). This is because when the length are the same
                        #adding only one situation won't guarantee that 
                        #the same path with different order is also considered.
                        #in other words: path combination don't have order, but our dict does.
                        #so we have to add both situations.
                        '''remember to cancel the comment here when using path_comb'''
                        #path_comb[(len(path_s), len(path_l))].add((path_s, path_l))
                        #path_comb[(len(path_s), len(path_l))].add((path_l, path_s))

                        count += 2

                        if count % 20000 == 0:
                            print('generating big-batches', count, holder_len)

                    if len(y_list) >= holder_len:

                        carry_on = False
                        
                    count_1 += 1
                    count_0 += 1

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [10]:
#difine the names for saving
model_name = 'SiaLP_Jan_7_2023_fb237_v4'
ids_name = 'IDs_Jan_7_2023_fb237_v4'

In [11]:
#first, we save the relation and ids
Dict = dict()
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
holder_len = 1000000
lower_bd = 2
upper_bd = 10
num_epoch = 10
batch_size = 4

#90% to be train, 10% to be validation
train_len = 9*int(holder_len/10)
    
######################################
###pre-define the lists###############

#define the lists
x_p_list, x_r_list, y_list = {'s': [], 'l': []}, list(), list()

#######################################
###build the big-batches###############      

#fill in the training array list
build_big_batches(holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_list, x_r_list, y_list,
                      relation2id, entity2id, id2relation, id2entity)

#######################################
###do the training#####################

#generate the input arrays
x_train_s = np.asarray(x_p_list['s'][:train_len], dtype='int')
x_train_l = np.asarray(x_p_list['l'][:train_len], dtype='int')
x_train_r = np.asarray(x_r_list[:train_len], dtype='int')
y_train = np.asarray(y_list[:train_len], dtype='int')

x_valid_s = np.asarray(x_p_list['s'][train_len:], dtype='int')
x_valid_l = np.asarray(x_p_list['l'][train_len:], dtype='int')
x_valid_r = np.asarray(x_r_list[train_len:], dtype='int')
y_valid = np.asarray(y_list[train_len:], dtype='int')

model.fit([x_train_s, x_train_l, x_train_r], y_train, 
          validation_data=([x_valid_s, x_valid_l, x_valid_r], y_valid),
          batch_size=batch_size, epochs=num_epoch)

# Save model and weights
add_h5 = model_name + '.h5'
save_dir = os.path.join(os.getcwd(), '../weight_bin')

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, add_h5)
model.save(model_path)
print('Save model')
del(model)

del(x_train_s, x_train_l, x_train_r, y_train)
del(x_valid_s, x_valid_l, x_valid_r, y_valid)

del(x_p_list, x_r_list, y_list)

generating big-batches 20000 1000000
generating big-batches 40000 1000000
generating big-batches 60000 1000000
generating big-batches 80000 1000000
generating big-batches 100000 1000000
generating big-batches 120000 1000000
generating big-batches 140000 1000000
generating big-batches 160000 1000000
generating big-batches 180000 1000000
generating big-batches 200000 1000000
generating big-batches 220000 1000000
generating big-batches 240000 1000000
generating big-batches 260000 1000000
generating big-batches 280000 1000000
generating big-batches 300000 1000000
generating big-batches 320000 1000000
generating big-batches 340000 1000000
generating big-batches 360000 1000000
generating big-batches 380000 1000000
generating big-batches 400000 1000000
generating big-batches 420000 1000000
generating big-batches 440000 1000000
generating big-batches 460000 1000000
generating big-batches 480000 1000000
generating big-batches 500000 1000000
generating big-batches 520000 1000000
generating big-b

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

In [2]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [3]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [4]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [5]:
#load ids and relation/entity dicts
with open('../weight_bin/IDs_Jan_7_2023_fb237_v4.pickle', 'rb') as handle:
    Dict = pickle.load(handle)

one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)

In [6]:
#load the model
model = keras.models.load_model('../weight_bin/SiaLP_Jan_7_2023_fb237_v4.h5')

2023-01-13 22:28:07.028174: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
ind_train_path = '../data/fb237_v4_ind/train.txt'
ind_valid_path = '../data/fb237_v4_ind/valid.txt'
ind_test_path = '../data/fb237_v4_ind/test.txt'

In [8]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [9]:
print(size_0, size_1, len(data_ind))

4707 7758 11714


In [10]:
#load the test dataset
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [11]:
print(size_0, size_1, len(data_test))

7758 7758 1424


In [12]:
#load the validation for existing triple removal when ranking
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [13]:
print(size_0, size_1, len(data_valid))

7758 7758 1416


In [14]:
print(len(entity2id), len(entity2id_ini))

7758 4707


In [15]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [16]:
overlapping = 0

for ele in data_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [17]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

0

In [18]:
def relation_ranking(s, t, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(20):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
    
    path_holder = list(path_holder)
    random.shuffle(path_holder)
    
    score_dict = defaultdict(float)
    
    count = 0
    
    if len(path_holder) >= 2:
    
        #iterate over path_1
        while count <= 50:

            temp_pair = random.sample(path_holder, 2)

            path_1, path_2 = temp_pair[0], temp_pair[1]

            #decide which path is shorter and which is longer
            if len(path_1) <= len(path_2):

                path_s, path_l = path_1, path_2

            else:

                path_s, path_l = path_2, path_1                            

            #whether lengths of the two paths satisfies the requirments
            if (len(path_s) < lower_bd) or (len(path_l) > upper_bd):

                raise ValueError('something wrong with path finding')

            list_s = list()
            list_l = list()
            list_r = list()

            for i in range(len(id2relation)):

                if i not in id2relation:

                    raise ValueError ('error when generating id2relation')

                list_s.append(list(path_s) + [num_r]*abs(len(path_s)-upper_bd))
                list_l.append(list(path_l) + [num_r]*abs(len(path_l)-upper_bd))
                list_r.append([i])

            input_s = np.array(list_s)
            input_l = np.array(list_l)
            input_r = np.array(list_r)

            pred = model.predict([input_s, input_l, input_r], verbose = 0)

            for i in range(pred.shape[0]):

                score_dict[i] += float(pred[i])

            count += 1
                
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [19]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), min(len(data_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - exist_tri + 1.) 
        
    print('Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - exist_tri,
          'abs_cur_rank', p,
          'total_num', i, len(selected))

438 785
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 abs_cur_rank 0 total_num 0 500
438 381
Hits@1 0.5 Hits@3 1.0 Hits@10 1.0 MRR 0.75 cur_rank 1 abs_cur_rank 1 total_num 1 500
438 2964
Hits@1 0.6666666666666666 Hits@3 1.0 Hits@10 1.0 MRR 0.8333333333333334 cur_rank 0 abs_cur_rank 0 total_num 2 500
438 433
Hits@1 0.75 Hits@3 1.0 Hits@10 1.0 MRR 0.875 cur_rank 0 abs_cur_rank 0 total_num 3 500
438 1667
Hits@1 0.8 Hits@3 1.0 Hits@10 1.0 MRR 0.9 cur_rank 0 abs_cur_rank 0 total_num 4 500
438 222
Hits@1 0.8333333333333334 Hits@3 1.0 Hits@10 1.0 MRR 0.9166666666666666 cur_rank 0 abs_cur_rank 0 total_num 5 500
438 962
Hits@1 0.8571428571428571 Hits@3 1.0 Hits@10 1.0 MRR 0.9285714285714286 cur_rank 0 abs_cur_rank 0 total_num 6 500
438 607
Hits@1 0.875 Hits@3 1.0 Hits@10 1.0 MRR 0.9375 cur_rank 0 abs_cur_rank 0 total_num 7 500
438 1425
Hits@1 0.8888888888888888 Hits@3 1.0 Hits@10 1.0 MRR 0.9444444444444444 cur_rank 0 abs_cur_rank 0 total_num 8 500
438 2525
Hits@1 0.9 Hits@3 1.0 Hits@10 1

438 2639
Hits@1 0.532258064516129 Hits@3 0.7741935483870968 Hits@10 0.9032258064516129 MRR 0.6605794649535989 cur_rank 0 abs_cur_rank 0 total_num 61 500
438 909
Hits@1 0.5238095238095238 Hits@3 0.7619047619047619 Hits@10 0.8888888888888888 MRR 0.6507289972559227 cur_rank 24 abs_cur_rank 24 total_num 62 500
438 4689
Hits@1 0.515625 Hits@3 0.765625 Hits@10 0.890625 MRR 0.6457696900071322 cur_rank 2 abs_cur_rank 2 total_num 63 500
438 5658
Hits@1 0.5076923076923077 Hits@3 0.7692307692307693 Hits@10 0.8923076923076924 MRR 0.6409629768275353 cur_rank 2 abs_cur_rank 2 total_num 64 500
438 559
Hits@1 0.5151515151515151 Hits@3 0.7727272727272727 Hits@10 0.8939393939393939 MRR 0.6464029317240879 cur_rank 0 abs_cur_rank 0 total_num 65 500
438 716
Hits@1 0.5074626865671642 Hits@3 0.7611940298507462 Hits@10 0.8955223880597015 MRR 0.638887322935029 cur_rank 6 abs_cur_rank 6 total_num 66 500
438 1560
Hits@1 0.5147058823529411 Hits@3 0.7647058823529411 Hits@10 0.8970588235294118 MRR 0.644197803480102

438 525
Hits@1 0.5042735042735043 Hits@3 0.7863247863247863 Hits@10 0.905982905982906 MRR 0.6474461694856419 cur_rank 17 abs_cur_rank 18 total_num 116 500
438 197
Hits@1 0.5 Hits@3 0.788135593220339 Hits@10 0.9067796610169492 MRR 0.6461966256764415 cur_rank 1 abs_cur_rank 2 total_num 117 500
438 96
Hits@1 0.5042016806722689 Hits@3 0.7899159663865546 Hits@10 0.907563025210084 MRR 0.6491697632757992 cur_rank 0 abs_cur_rank 0 total_num 118 500
438 91
Hits@1 0.5 Hits@3 0.7833333333333333 Hits@10 0.9 MRR 0.6444010408895265 cur_rank 12 abs_cur_rank 12 total_num 119 500
438 460
Hits@1 0.49586776859504134 Hits@3 0.7851239669421488 Hits@10 0.9008264462809917 MRR 0.6418302333890621 cur_rank 2 abs_cur_rank 2 total_num 120 500
438 815
Hits@1 0.5 Hits@3 0.7868852459016393 Hits@10 0.9016393442622951 MRR 0.6447660511481681 cur_rank 0 abs_cur_rank 0 total_num 121 500
438 5157
Hits@1 0.4959349593495935 Hits@3 0.7886178861788617 Hits@10 0.9024390243902439 MRR 0.6435890913827359 cur_rank 1 abs_cur_rank 1

438 812
Hits@1 0.4883720930232558 Hits@3 0.7790697674418605 Hits@10 0.9127906976744186 MRR 0.6394743391243205 cur_rank 0 abs_cur_rank 0 total_num 171 500
438 2365
Hits@1 0.48554913294797686 Hits@3 0.7803468208092486 Hits@10 0.9132947976878613 MRR 0.6386681290715788 cur_rank 1 abs_cur_rank 1 total_num 172 500
438 1435
Hits@1 0.4885057471264368 Hits@3 0.7816091954022989 Hits@10 0.9137931034482759 MRR 0.6407447490194432 cur_rank 0 abs_cur_rank 0 total_num 173 500
438 238
Hits@1 0.49142857142857144 Hits@3 0.7828571428571428 Hits@10 0.9142857142857143 MRR 0.6427976361679036 cur_rank 0 abs_cur_rank 0 total_num 174 500
438 565
Hits@1 0.4943181818181818 Hits@3 0.7840909090909091 Hits@10 0.9147727272727273 MRR 0.6448271950533132 cur_rank 0 abs_cur_rank 0 total_num 175 500
438 3989
Hits@1 0.4915254237288136 Hits@3 0.7853107344632768 Hits@10 0.9152542372881356 MRR 0.6430673427272117 cur_rank 2 abs_cur_rank 2 total_num 176 500
438 1573
Hits@1 0.4943820224719101 Hits@3 0.7865168539325843 Hits@10 0.

438 844
Hits@1 0.4823008849557522 Hits@3 0.7610619469026548 Hits@10 0.911504424778761 MRR 0.6293995878409668 cur_rank 0 abs_cur_rank 0 total_num 225 500
438 1287
Hits@1 0.4801762114537445 Hits@3 0.762114537444934 Hits@10 0.9118942731277533 MRR 0.628829545603782 cur_rank 1 abs_cur_rank 1 total_num 226 500
438 1419
Hits@1 0.4824561403508772 Hits@3 0.7631578947368421 Hits@10 0.9122807017543859 MRR 0.630457486193239 cur_rank 0 abs_cur_rank 0 total_num 227 500
438 2517
Hits@1 0.4847161572052402 Hits@3 0.7641921397379913 Hits@10 0.9126637554585153 MRR 0.6320712089609543 cur_rank 0 abs_cur_rank 0 total_num 228 500
438 1305
Hits@1 0.48695652173913045 Hits@3 0.7652173913043478 Hits@10 0.9130434782608695 MRR 0.6336708993567761 cur_rank 0 abs_cur_rank 0 total_num 229 500
438 261
Hits@1 0.48484848484848486 Hits@3 0.7619047619047619 Hits@10 0.9090909090909091 MRR 0.6310416564568493 cur_rank 37 abs_cur_rank 37 total_num 230 500
438 2815
Hits@1 0.4870689655172414 Hits@3 0.7629310344827587 Hits@10 0.9

438 5207
Hits@1 0.49642857142857144 Hits@3 0.7607142857142857 Hits@10 0.8964285714285715 MRR 0.637778644519189 cur_rank 1 abs_cur_rank 1 total_num 279 500
438 80
Hits@1 0.498220640569395 Hits@3 0.7615658362989324 Hits@10 0.896797153024911 MRR 0.6390676884888715 cur_rank 0 abs_cur_rank 0 total_num 280 500
438 1206
Hits@1 0.49645390070921985 Hits@3 0.7624113475177305 Hits@10 0.8971631205673759 MRR 0.6385745406573508 cur_rank 1 abs_cur_rank 1 total_num 281 500
438 900
Hits@1 0.49469964664310956 Hits@3 0.7632508833922261 Hits@10 0.8975265017667845 MRR 0.6380848779695155 cur_rank 1 abs_cur_rank 1 total_num 282 500
0 0
Hits@1 0.49295774647887325 Hits@3 0.7605633802816901 Hits@10 0.8943661971830986 MRR 0.6359911926980368 cur_rank 22 abs_cur_rank 22 total_num 283 500
438 1443
Hits@1 0.49473684210526314 Hits@3 0.7614035087719299 Hits@10 0.8947368421052632 MRR 0.6372684165833069 cur_rank 0 abs_cur_rank 0 total_num 284 500
438 81
Hits@1 0.4965034965034965 Hits@3 0.7622377622377622 Hits@10 0.89510

438 3998
Hits@1 0.5 Hits@3 0.7724550898203593 Hits@10 0.8982035928143712 MRR 0.6443915828121158 cur_rank 0 abs_cur_rank 0 total_num 333 500
438 1362
Hits@1 0.5014925373134328 Hits@3 0.7731343283582089 Hits@10 0.8985074626865671 MRR 0.6454531004753633 cur_rank 0 abs_cur_rank 0 total_num 334 500
438 3583
Hits@1 0.5 Hits@3 0.7738095238095238 Hits@10 0.8988095238095238 MRR 0.6450202043429961 cur_rank 1 abs_cur_rank 1 total_num 335 500
438 3671
Hits@1 0.5014836795252225 Hits@3 0.7744807121661721 Hits@10 0.8991097922848664 MRR 0.6460735568523641 cur_rank 0 abs_cur_rank 0 total_num 336 500
438 97
Hits@1 0.5029585798816568 Hits@3 0.7751479289940828 Hits@10 0.8994082840236687 MRR 0.6471206765066471 cur_rank 0 abs_cur_rank 0 total_num 337 500
438 3098
Hits@1 0.5014749262536873 Hits@3 0.775811209439528 Hits@10 0.8997050147492626 MRR 0.6461950501256049 cur_rank 2 abs_cur_rank 2 total_num 338 500
438 225
Hits@1 0.5 Hits@3 0.7764705882352941 Hits@10 0.9 MRR 0.6457650646840589 cur_rank 1 abs_cur_rank

438 2097
Hits@1 0.5180412371134021 Hits@3 0.7809278350515464 Hits@10 0.8994845360824743 MRR 0.6578917992171143 cur_rank 1 abs_cur_rank 1 total_num 387 500
438 2472
Hits@1 0.519280205655527 Hits@3 0.781491002570694 Hits@10 0.8997429305912596 MRR 0.6587712547461192 cur_rank 0 abs_cur_rank 0 total_num 388 500
438 526
Hits@1 0.5205128205128206 Hits@3 0.782051282051282 Hits@10 0.9 MRR 0.6596462002467701 cur_rank 0 abs_cur_rank 0 total_num 389 500
438 2400
Hits@1 0.5217391304347826 Hits@3 0.782608695652174 Hits@10 0.9002557544757033 MRR 0.6605166703228653 cur_rank 0 abs_cur_rank 0 total_num 390 500
438 37
Hits@1 0.5204081632653061 Hits@3 0.7806122448979592 Hits@10 0.8979591836734694 MRR 0.658841269119226 cur_rank 265 abs_cur_rank 265 total_num 391 500
438 2329
Hits@1 0.5216284987277354 Hits@3 0.7811704834605598 Hits@10 0.8982188295165394 MRR 0.6597093574929684 cur_rank 0 abs_cur_rank 0 total_num 392 500
438 290
Hits@1 0.5203045685279187 Hits@3 0.7817258883248731 Hits@10 0.8984771573604061 MR

438 4710
Hits@1 0.5226244343891403 Hits@3 0.7805429864253394 Hits@10 0.8914027149321267 MRR 0.6593831096879754 cur_rank 1 abs_cur_rank 1 total_num 441 500
438 3676
Hits@1 0.5237020316027088 Hits@3 0.781038374717833 Hits@10 0.891647855530474 MRR 0.6601519965735556 cur_rank 0 abs_cur_rank 0 total_num 442 500
438 3856
Hits@1 0.5225225225225225 Hits@3 0.7815315315315315 Hits@10 0.8918918918918919 MRR 0.6594159185031947 cur_rank 2 abs_cur_rank 2 total_num 443 500
438 2506
Hits@1 0.5235955056179775 Hits@3 0.7820224719101123 Hits@10 0.8921348314606742 MRR 0.6601812759897043 cur_rank 0 abs_cur_rank 0 total_num 444 500
438 1576
Hits@1 0.5246636771300448 Hits@3 0.7825112107623319 Hits@10 0.8923766816143498 MRR 0.660943201379862 cur_rank 0 abs_cur_rank 0 total_num 445 500
438 89
Hits@1 0.5234899328859061 Hits@3 0.7807606263982103 Hits@10 0.8903803131991052 MRR 0.6594843789103816 cur_rank 112 abs_cur_rank 112 total_num 446 500
438 587
Hits@1 0.5245535714285714 Hits@3 0.78125 Hits@10 0.890625 MRR 0

438 480
Hits@1 0.532258064516129 Hits@3 0.782258064516129 Hits@10 0.8891129032258065 MRR 0.6645114441731786 cur_rank 0 abs_cur_rank 0 total_num 495 500
438 5241
Hits@1 0.5311871227364185 Hits@3 0.7826961770623743 Hits@10 0.8893360160965795 MRR 0.6638450898254123 cur_rank 2 abs_cur_rank 3 total_num 496 500
438 2758
Hits@1 0.5321285140562249 Hits@3 0.7831325301204819 Hits@10 0.8895582329317269 MRR 0.6645200996852006 cur_rank 0 abs_cur_rank 0 total_num 497 500
438 2514
Hits@1 0.533066132264529 Hits@3 0.7835671342685371 Hits@10 0.8897795591182365 MRR 0.6651924040946491 cur_rank 0 abs_cur_rank 0 total_num 498 500
438 4812
Hits@1 0.534 Hits@3 0.784 Hits@10 0.89 MRR 0.6658620192864598 cur_rank 0 abs_cur_rank 0 total_num 499 500


In [19]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), min(len(data_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_3 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        #if sorted_list[p][1] % 2 != 0:
            
        #    inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - inverse_r - exist_tri < 3:
        
        Hits_at_3 += 1
        
    if p - inverse_r - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - inverse_r - exist_tri + 1.) 
        
    print('Hits@1', Hits_at_1/(i+1),
          'Hits@3', Hits_at_3/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - inverse_r - exist_tri, 
          'total_num', i, len(selected))

438 1441
Hits@1 1.0 Hits@3 1.0 Hits@10 1.0 MRR 1.0 cur_rank 0 total_num 0 500
438 12
Hits@1 0.5 Hits@3 0.5 Hits@10 1.0 MRR 0.625 cur_rank 3 total_num 1 500
438 251
Hits@1 0.3333333333333333 Hits@3 0.6666666666666666 Hits@10 1.0 MRR 0.5833333333333334 cur_rank 1 total_num 2 500
438 370
Hits@1 0.5 Hits@3 0.75 Hits@10 1.0 MRR 0.6875 cur_rank 0 total_num 3 500
438 3856
Hits@1 0.4 Hits@3 0.8 Hits@10 1.0 MRR 0.65 cur_rank 1 total_num 4 500
438 1952
Hits@1 0.5 Hits@3 0.8333333333333334 Hits@10 1.0 MRR 0.7083333333333334 cur_rank 0 total_num 5 500
0 0
Hits@1 0.42857142857142855 Hits@3 0.7142857142857143 Hits@10 0.8571428571428571 MRR 0.6104651162790697 cur_rank 42 total_num 6 500
438 1207
Hits@1 0.5 Hits@3 0.75 Hits@10 0.875 MRR 0.659156976744186 cur_rank 0 total_num 7 500
438 2502
Hits@1 0.4444444444444444 Hits@3 0.7777777777777778 Hits@10 0.8888888888888888 MRR 0.6229543496985357 cur_rank 2 total_num 8 500
438 1014
Hits@1 0.5 Hits@3 0.8 Hits@10 0.9 MRR 0.6606589147286821 cur_rank 0 total_num

438 1041
Hits@1 0.5303030303030303 Hits@3 0.8333333333333334 Hits@10 0.9242424242424242 MRR 0.6910402547948449 cur_rank 4 total_num 65 500
438 3656
Hits@1 0.5373134328358209 Hits@3 0.835820895522388 Hits@10 0.9253731343283582 MRR 0.6956515942755188 cur_rank 0 total_num 66 500
438 3823
Hits@1 0.5294117647058824 Hits@3 0.8382352941176471 Hits@10 0.9264705882352942 MRR 0.6927743649479376 cur_rank 1 total_num 67 500
438 789
Hits@1 0.5362318840579711 Hits@3 0.8405797101449275 Hits@10 0.927536231884058 MRR 0.6972269103834748 cur_rank 0 total_num 68 500
438 564
Hits@1 0.5285714285714286 Hits@3 0.8285714285714286 Hits@10 0.9142857142857143 MRR 0.6873700456181623 cur_rank 137 total_num 69 500
438 1508
Hits@1 0.5352112676056338 Hits@3 0.8309859154929577 Hits@10 0.9154929577464789 MRR 0.6917732844122726 cur_rank 0 total_num 70 500
438 177
Hits@1 0.5416666666666666 Hits@3 0.8333333333333334 Hits@10 0.9166666666666666 MRR 0.6960542110176577 cur_rank 0 total_num 71 500
438 1268
Hits@1 0.547945205479

438 2739
Hits@1 0.5669291338582677 Hits@3 0.84251968503937 Hits@10 0.9212598425196851 MRR 0.7111163381120214 cur_rank 0 total_num 126 500
438 368
Hits@1 0.5625 Hits@3 0.84375 Hits@10 0.921875 MRR 0.7094669917205212 cur_rank 1 total_num 127 500
438 306
Hits@1 0.5658914728682171 Hits@3 0.8449612403100775 Hits@10 0.9224806201550387 MRR 0.7117191855831528 cur_rank 0 total_num 128 500
438 2310
Hits@1 0.5692307692307692 Hits@3 0.8461538461538461 Hits@10 0.9230769230769231 MRR 0.7139367303094363 cur_rank 0 total_num 129 500
438 546
Hits@1 0.5648854961832062 Hits@3 0.8396946564885496 Hits@10 0.9236641221374046 MRR 0.710013549162036 cur_rank 4 total_num 130 500
438 3073
Hits@1 0.5681818181818182 Hits@3 0.8409090909090909 Hits@10 0.9242424242424242 MRR 0.7122104162138387 cur_rank 0 total_num 131 500
438 478
Hits@1 0.5714285714285714 Hits@3 0.8421052631578947 Hits@10 0.924812030075188 MRR 0.7143742476708775 cur_rank 0 total_num 132 500
438 1826
Hits@1 0.5746268656716418 Hits@3 0.8432835820895522 

438 993
Hits@1 0.5828877005347594 Hits@3 0.8449197860962567 Hits@10 0.93048128342246 MRR 0.7232853727454798 cur_rank 0 total_num 186 500
438 2521
Hits@1 0.5851063829787234 Hits@3 0.8457446808510638 Hits@10 0.9308510638297872 MRR 0.7247572590606634 cur_rank 0 total_num 187 500
438 554
Hits@1 0.5873015873015873 Hits@3 0.8465608465608465 Hits@10 0.9312169312169312 MRR 0.7262135698592842 cur_rank 0 total_num 188 500
438 4544
Hits@1 0.5842105263157895 Hits@3 0.8473684210526315 Hits@10 0.9315789473684211 MRR 0.7250229721231827 cur_rank 1 total_num 189 500
438 1153
Hits@1 0.5863874345549738 Hits@3 0.8481675392670157 Hits@10 0.9319371727748691 MRR 0.7264626424262027 cur_rank 0 total_num 190 500
438 1328
Hits@1 0.5885416666666666 Hits@3 0.8489583333333334 Hits@10 0.9322916666666666 MRR 0.7278873161635663 cur_rank 0 total_num 191 500
438 313
Hits@1 0.5854922279792746 Hits@3 0.844559585492228 Hits@10 0.9326424870466321 MRR 0.7254112160798172 cur_rank 3 total_num 192 500
438 246
Hits@1 0.582474226

438 3431
Hits@1 0.5465587044534413 Hits@3 0.8299595141700404 Hits@10 0.9271255060728745 MRR 0.6993979229164251 cur_rank 0 total_num 246 500
438 943
Hits@1 0.5443548387096774 Hits@3 0.8266129032258065 Hits@10 0.9274193548387096 MRR 0.6973842216143427 cur_rank 4 total_num 247 500
438 808
Hits@1 0.5461847389558233 Hits@3 0.8273092369477911 Hits@10 0.927710843373494 MRR 0.6985995460255301 cur_rank 0 total_num 248 500
438 3191
Hits@1 0.544 Hits@3 0.828 Hits@10 0.928 MRR 0.6978051478414279 cur_rank 1 total_num 249 500
438 1219
Hits@1 0.5418326693227091 Hits@3 0.8286852589641435 Hits@10 0.9282868525896414 MRR 0.6970170795233346 cur_rank 1 total_num 250 500
438 257
Hits@1 0.5436507936507936 Hits@3 0.8293650793650794 Hits@10 0.9285714285714286 MRR 0.6982193926998292 cur_rank 0 total_num 251 500
438 668
Hits@1 0.5454545454545454 Hits@3 0.8300395256916996 Hits@10 0.9288537549407114 MRR 0.6994122014243359 cur_rank 0 total_num 252 500
438 624
Hits@1 0.5433070866141733 Hits@3 0.8267716535433071 Hits

438 2146
Hits@1 0.5439739413680782 Hits@3 0.8110749185667753 Hits@10 0.9120521172638436 MRR 0.6910091280218743 cur_rank 0 total_num 306 500
438 4829
Hits@1 0.5422077922077922 Hits@3 0.8116883116883117 Hits@10 0.9123376623376623 MRR 0.6903889685153098 cur_rank 1 total_num 307 500
438 1476
Hits@1 0.540453074433657 Hits@3 0.8122977346278317 Hits@10 0.912621359223301 MRR 0.689233448660352 cur_rank 2 total_num 308 500
438 352
Hits@1 0.5419354838709678 Hits@3 0.8129032258064516 Hits@10 0.9129032258064517 MRR 0.6902359214066089 cur_rank 0 total_num 309 500
438 381
Hits@1 0.5434083601286174 Hits@3 0.8135048231511254 Hits@10 0.9131832797427653 MRR 0.6912319473827935 cur_rank 0 total_num 310 500
438 4568
Hits@1 0.5416666666666666 Hits@3 0.8141025641025641 Hits@10 0.9134615384615384 MRR 0.6906190244745153 cur_rank 1 total_num 311 500
438 97
Hits@1 0.5431309904153354 Hits@3 0.8146964856230032 Hits@10 0.9137380191693291 MRR 0.6916074620960024 cur_rank 0 total_num 312 500
438 1172
Hits@1 0.541401273

438 175
Hits@1 0.5313351498637602 Hits@3 0.8310626702997275 Hits@10 0.9237057220708447 MRR 0.6918290168642026 cur_rank 1 total_num 366 500
438 311
Hits@1 0.532608695652174 Hits@3 0.8315217391304348 Hits@10 0.9239130434782609 MRR 0.692666438014028 cur_rank 0 total_num 367 500
438 225
Hits@1 0.5338753387533876 Hits@3 0.8319783197831978 Hits@10 0.924119241192412 MRR 0.6934993202958329 cur_rank 0 total_num 368 500
438 1570
Hits@1 0.5351351351351351 Hits@3 0.8324324324324325 Hits@10 0.9243243243243243 MRR 0.6943277005112496 cur_rank 0 total_num 369 500
438 1109
Hits@1 0.5336927223719676 Hits@3 0.8301886792452831 Hits@10 0.9245283018867925 MRR 0.6931300517228095 cur_rank 3 total_num 370 500
438 327
Hits@1 0.532258064516129 Hits@3 0.8279569892473119 Hits@10 0.9247311827956989 MRR 0.6918044333041998 cur_rank 4 total_num 371 500
438 3122
Hits@1 0.5308310991957105 Hits@3 0.8257372654155496 Hits@10 0.9249329758713136 MRR 0.6904859227591483 cur_rank 4 total_num 372 500
438 270
Hits@1 0.52941176470

438 281
Hits@1 0.5362997658079626 Hits@3 0.8290398126463701 Hits@10 0.9227166276346604 MRR 0.6944485138930065 cur_rank 0 total_num 426 500
438 692
Hits@1 0.5350467289719626 Hits@3 0.8294392523364486 Hits@10 0.9228971962616822 MRR 0.6936047868356241 cur_rank 2 total_num 427 500
438 615
Hits@1 0.5337995337995338 Hits@3 0.8298368298368298 Hits@10 0.9230769230769231 MRR 0.693153493626217 cur_rank 1 total_num 428 500
438 4783
Hits@1 0.5348837209302325 Hits@3 0.8302325581395349 Hits@10 0.9232558139534883 MRR 0.6938670901526678 cur_rank 0 total_num 429 500
438 1302
Hits@1 0.5359628770301624 Hits@3 0.8306264501160093 Hits@10 0.9234338747099768 MRR 0.6945773753263275 cur_rank 0 total_num 430 500
438 3585
Hits@1 0.5370370370370371 Hits@3 0.8310185185185185 Hits@10 0.9236111111111112 MRR 0.6952843721427017 cur_rank 0 total_num 431 500
438 610
Hits@1 0.5381062355658198 Hits@3 0.8314087759815243 Hits@10 0.9237875288683602 MRR 0.6959881033848664 cur_rank 0 total_num 432 500
438 5649
Hits@1 0.5368663

438 5113
Hits@1 0.5462012320328542 Hits@3 0.8357289527720739 Hits@10 0.9301848049281314 MRR 0.7023364451039981 cur_rank 1 total_num 486 500
438 1095
Hits@1 0.5471311475409836 Hits@3 0.8360655737704918 Hits@10 0.930327868852459 MRR 0.7029464114050146 cur_rank 0 total_num 487 500
438 2316
Hits@1 0.5480572597137015 Hits@3 0.83640081799591 Hits@10 0.9304703476482618 MRR 0.7035538829563335 cur_rank 0 total_num 488 500
438 559
Hits@1 0.5489795918367347 Hits@3 0.8367346938775511 Hits@10 0.9306122448979591 MRR 0.7041588750319329 cur_rank 0 total_num 489 500


KeyboardInterrupt: 

In [18]:
'''Ranking on transductive setting'''

########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 66
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 0 672
438 1171
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 1 672
438 1044
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 2 672
438 600
Calculating_hit_at_one 0.25 current ranking: 0 triple number: 3 672
438 4356
Calculating_hit_at_one 0.2 current ranking: 1 triple number: 4 672
438 377
Calculating_hit_at_one 0.3333333333333333 current ranking: 0 triple number: 5 672
438 1864
Calculating_hit_at_one 0.42857142857142855 current ranking: 0 triple number: 6 672
438 509
Calculating_hit_at_one 0.375 current ranking: 2 triple number: 7 672
438 1675
Calculating_hit_at_one 0.3333333333333333 current ranking: 27 triple number: 8 672
438 625
Calculating_hit_at_one 0.3 current ranking: 1 triple number: 9 672
438 2672
Calculating_hit_at_one 0.36363636363636365 current ranking: 0 triple number: 10 672
438 1404
Calculating_hit_at_one 0.3333333333333333 current ranking: 2 triple number: 11 672
438 553
Cal

KeyboardInterrupt: 

In [144]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 263
Calculating_hit_at_one 1.0 current ranking: 0 0 284
438 3646
Calculating_hit_at_one 0.5 current ranking: 1 1 284
438 1588
Calculating_hit_at_one 0.3333333333333333 current ranking: 10 2 284
438 707
Calculating_hit_at_one 0.25 current ranking: 4 3 284
438 3451
Calculating_hit_at_one 0.4 current ranking: 0 4 284
438 2757
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 5 284
438 559
Calculating_hit_at_one 0.2857142857142857 current ranking: 1 6 284
438 2868
Calculating_hit_at_one 0.375 current ranking: 0 7 284
438 3601
Calculating_hit_at_one 0.3333333333333333 current ranking: 1 8 284
438 1470
Calculating_hit_at_one 0.4 current ranking: 0 9 284
438 37
Calculating_hit_at_one 0.36363636363636365 current ranking: 1 10 284
438 2329
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 11 284
438 1230
Calculating_hit_at_one 0.38461538461538464 current ranking: 0 12 284
438 929
Calculating_hit_at_one 0.35714285714285715 current ranking: 1 13 284
438 691
Calculating_h

438 660
Calculating_hit_at_one 0.3063063063063063 current ranking: 1 110 284
438 2095
Calculating_hit_at_one 0.30357142857142855 current ranking: 2 111 284
438 2160
Calculating_hit_at_one 0.30973451327433627 current ranking: 0 112 284
438 1207
Calculating_hit_at_one 0.3157894736842105 current ranking: 0 113 284
438 1671
Calculating_hit_at_one 0.3130434782608696 current ranking: 1 114 284
438 5113
Calculating_hit_at_one 0.3103448275862069 current ranking: 1 115 284
438 4248
Calculating_hit_at_one 0.3162393162393162 current ranking: 0 116 284
438 2935
Calculating_hit_at_one 0.3135593220338983 current ranking: 96 117 284
438 36
Calculating_hit_at_one 0.31092436974789917 current ranking: 1 118 284
438 429
Calculating_hit_at_one 0.30833333333333335 current ranking: 1 119 284
438 4473
Calculating_hit_at_one 0.30578512396694213 current ranking: 3 120 284


KeyboardInterrupt: 

In [13]:
test = {(1,2,3),(4,5),(6,7,8)}
test = list(test)
test

[(4, 5), (1, 2, 3), (6, 7, 8)]

In [18]:
K = random.sample(test,2)
K

[(6, 7, 8), (1, 2, 3)]