In [1]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

In [2]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [3]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [4]:
train_path = '../data/nell_v4/train.txt'

In [5]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [6]:
#define the dictionaries and sets for load KG
one_hop = dict() 
data = set()
s_t_r = defaultdict(set)
entity2id = dict()
id2entity = dict()
relation2id = dict()
id2relation = dict()

#fill in the sets and dicts
Class_1.load_train_data(train_path, one_hop, data, s_t_r,
                        entity2id, id2entity, relation2id, id2relation)

### Build the deep neural network structure

We use biLSTM to train on the input path embedding sequence to predict the output embedding or the relation.

In [7]:
# Input layer, using integer to represent each relation type
#note that inputs_path is the path inputs, while inputs_out_re is the output relation inputs
fst_path = keras.Input(shape=(None,), dtype="int32")
scd_path = keras.Input(shape=(None,), dtype="int32")

#the relation input layer (for output embedding)
id_rela = keras.Input(shape=(None,), dtype="int32")

# Embed each integer in a 300-dimensional vector as input,
# note that we add another "space holder" embedding, 
# which hold the spaces if the initial length of two paths are not the same
in_embd_var = layers.Embedding(len(relation2id)+1, 300)

# Obtain the embedding
fst_p_embd = in_embd_var(fst_path)
scd_p_embd = in_embd_var(scd_path)

# Embed each integer in a 300-dimensional vector as output
rela_embd = layers.Embedding(len(relation2id)+1, 300)(id_rela)

#add 2 layer bi-directional LSTM
lstm_layer_1 = layers.Bidirectional(layers.LSTM(150, return_sequences=True))
lstm_layer_2 = layers.Bidirectional(layers.LSTM(150, return_sequences=False))

#first LSTM layer
fst_lstm_mid = lstm_layer_1(fst_p_embd)
scd_lstm_mid = lstm_layer_1(scd_p_embd)

#second LSTM layer
fst_lstm_out = lstm_layer_2(fst_lstm_mid)
scd_lstm_out = lstm_layer_2(scd_lstm_mid)

#concatenate the output vector from both siamese tunnel
path_concat = layers.concatenate([fst_lstm_out, scd_lstm_out], axis=-1)

#remove the time dimension from the output embd since there is only one step
sum_r_embd = tf.reduce_sum(rela_embd, axis=1)

#concatenate the lstm output and output embd
concat = layers.concatenate([path_concat, sum_r_embd], axis=-1)

#add the dense layer
dense_1 = layers.Dense(32, activation='relu')(concat)
batch_norm = layers.BatchNormalization()(dense_1)
dropout = layers.Dropout(0.25)(batch_norm)

#final layer
final_out = layers.Dense(2, activation='softmax')(dropout)

#put together the model
model = keras.Model([fst_path, scd_path, id_rela], final_out)

2023-01-22 22:22:19.735895: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
#optimization settings
opt = keras.optimizers.Adam(learning_rate=0.0005, decay=1e-6)

### Build the batches
We build each big-batch for each path combination with length (i,j). Then, we iteratively train the siamese network on different big-batches. The length of each big-batch is N.

To be specific:
* If we allow the length difference between two paths in a combination to be d, then the combination with path length i and path length j, denoted as (i,j), will be like (2,2), (2,3), (2,4), (3,3), (3,4), (3,5), ... 
* We will first build all the big-batches before fitting the NN model. 
* That is, we will perform the ObtainPathsByDynamicProgramming class function for some randomly chosen source entities. Then, for each target entity, we will further have two for loops:
* for path_1 in all the 
* Do this until all the slots in all big-batchs are filled.
* In every epoch, big-batchs will be re-filled.

Then, in the training, we will use negative sampling: In each batch (actual batch, not the big-batch), we will include K true output relation embeddings and K random selected output relation embeddings. The true label is [1,0], while the false label is [0,1].

In [9]:
#function to build all the big batches
def build_big_batches(diff, holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_dict, x_r_dict, y_dict, path_comb, filled, total_num_need_to_fill,
                      relation2id, entity2id, id2relation, id2entity, epoch):
    
    if holder_len % 10 != 0:
        raise ValueError('We would like to take 10X as a big-batch size')
    
    #the set of all relation IDs
    relation_id_set = set()
    for i in range(len(id2relation)):
        
        if i not in id2relation:
            raise ValueError('error when generaing id2relation')
        
        relation_id_set.add(i)
    
    num_r = len(id2relation)
    
    #count how many appending has performed
    count = 0
    
    #we count how many combinations need to fill
    need_to_fill = set()
    
    for i in range(lower_bd, upper_bd+1):

        Max = min(upper_bd, i+diff)

        for j in range(i, Max+1):
            
            need_to_fill.add((i,j))

    #in case not all entities in entity2id are in one_hop, 
    #so we need to find out who are indeed in
    existing_ids = set()
    
    for s_1 in one_hop:
        existing_ids.add(s_1)
        
    existing_ids = list(existing_ids)
    
    carry_on = True
    
    while carry_on:

        #obtain paths by dynamic programming
        source_id = random.choice(existing_ids)

        result, length_dict = Class_2.obtain_paths('direct_neighbour', source_id, 
                                                   'not_specified', lower_bd, upper_bd, one_hop)
        for target_id in result:

            if not carry_on:
                break
            
            #we want to make sure s, t are indeed directly connected, 
            #otherwise there is no relation for positive sample
            #also, we want to make sure s and t and not connected by all relations, 
            #although this situation is rare. 
            #But in that case, there is no relation for negative samples 
            if ((source_id, target_id) in s_t_r) and (
                len(s_t_r[(source_id, target_id)]) < len(id2relation)):
                
                dir_r = list(s_t_r[(source_id, target_id)])
                
                non_dir_r = list(relation_id_set.difference(dir_r))
                
                if len(dir_r) <= 0:
                    
                    raise ValueError('errors when creating s_t_r !!')
                
                #iterate over path_1
                for path_1 in result[target_id]:

                    if not carry_on:
                        break

                    #iterate over path_2
                    for path_2 in result[target_id]:

                        if not carry_on:
                            break

                        #decide which path is shorter and which is longer
                        if len(path_1) <= len(path_2):

                            path_s, path_l = path_1, path_2

                        else:

                            path_s, path_l = path_2, path_1                            

                        #whether lengths of the two paths satisfies the requirments
                        if (len(path_s) >= lower_bd) and (len(path_l) <= upper_bd) and (
                            abs(len(path_s)-len(path_l)) <= diff):

                            #further consider: whether the corresponding length comb is not full,
                            #and whether this path pair is new, and whether the two paths are different
                            #But it is optional to require the path to be new. 
                            #We may remove this requirment, especially for short paths
                            '''remember to cancel the comment below when using path_comb'''
                            if (len(y_dict[(len(path_s), len(path_l))]) < holder_len) and (
                                path_s != path_l):
                                
                                #we always add one positive and one negative situation together,
                                #hence, the length of list should always be even.
                                #also we want to make sure the length of lists coincide
                                if (len(x_p_dict[(len(path_s), len(path_l))]['s']) != len(
                                    y_dict[(len(path_s), len(path_l))])) or (
                                    len(x_p_dict[(len(path_s), len(path_l))]['s']) != len(
                                        x_p_dict[(len(path_s), len(path_l))]['l'])) or (
                                    len(y_dict[(len(path_s), len(path_l))]) != len(
                                        x_r_dict[(len(path_s), len(path_l))])) or (
                                    len(y_dict[(len(path_s), len(path_l))]) % 2 != 0):
                                    
                                    raise ValueError('error when building big batches: length error')
                                
                                #####positive#####################
                                relation_id = random.choice(dir_r)
                                
                                #append the paths: note that we add the space holder id at the end
                                #of the shorter path
                                x_p_dict[(len(path_s), len(path_l))]['s'].append(
                                          list(path_s) + [num_r]*abs(len(path_s)-len(path_l)))
                                x_p_dict[(len(path_s), len(path_l))]['l'].append(list(path_l))

                                #append relation
                                x_r_dict[(len(path_s), len(path_l))].append([relation_id])
                                y_dict[(len(path_s), len(path_l))].append([1., 0.])
                                
                                #####negative#####################
                                relation_id = random.choice(non_dir_r)
                                
                                #append the paths: note that we add the space holder id at the end
                                #of the shorter path
                                x_p_dict[(len(path_s), len(path_l))]['s'].append(
                                          list(path_s) + [num_r]*abs(len(path_s)-len(path_l)))
                                x_p_dict[(len(path_s), len(path_l))]['l'].append(list(path_l))

                                #append relation
                                x_r_dict[(len(path_s), len(path_l))].append([relation_id])
                                y_dict[(len(path_s), len(path_l))].append([0., 1.])
                                
                                ######add to path combinations#####
                                #here is the tricky part: we have to add both (path_s, path_l)
                                #and (path_l, path_s). This is because when the length are the same
                                #adding only one situation won't guarantee that 
                                #the same path with different order is also considered.
                                #in other words: path combination don't have order, but our dict does.
                                #so we have to add both situations.
                                '''remember to cancel the comment here when using path_comb'''
                                #path_comb[(len(path_s), len(path_l))].add((path_s, path_l))
                                #path_comb[(len(path_s), len(path_l))].add((path_l, path_s))
                                
                                count += 1
                                
                                if count % 10000 == 0:
                                    print('generating big-batches', count, 
                                          int(holder_len*0.5*len(need_to_fill)), 'in epoch', epoch)
                                
                            if len(y_dict[(len(path_s), len(path_l))]) >= holder_len:
                                
                                prev_size = len(filled)

                                filled.add((len(path_s), len(path_l)))
                                
                                post_size = len(filled)
                                
                                #when we indeed find a new filled combo, we will print, which looks like:
                                #big-batches 1 ( 2 2 ) in N completed in epoch k
                                #big-batches 2 ( 3 5 ) in N completed in epoch k
                                #big-batches 3 ( 3 4 ) in N completed in epoch k
                                if post_size > prev_size:

                                    print('big-batches', len(filled), 
                                          '(', len(path_s), len(path_l), ')',
                                          'in', total_num_need_to_fill, 
                                          'completed in epoch', epoch)
                        
                        #check whether to finish
                        if len(need_to_fill.difference(filled)) == 0:
                            
                            carry_on = False

### Start Training: load the KG and call classes

Here, we use the validation set to see the training efficiency. That is, we use the validation to check whether the true relation between entities can be predicted by paths.

The trick is: in validation, we have to use the same relation ID and entity ID as in the training. But we don't want to use the links in training anymore. That is, in validation, we want to use (and update if necessary) entity2id, id2entity, relation2id and id2relation. But we want to use new one_hop, data, data_ and s_t_r for validation set. Then, path-finding will also be based on new one_hop.


In [10]:
#difine the names for saving
model_name = 'Model_main_1_nell_v4'
ids_name = 'IDs_main_1_nell_v4'

In [11]:
#first, we save the relation and ids
Dict = dict()
Dict['one_hop'] = one_hop
Dict['data'] = data
Dict['s_t_r'] = s_t_r
Dict['entity2id'] = entity2id
Dict['id2entity'] = id2entity
Dict['relation2id'] = relation2id
Dict['id2relation'] = id2relation

with open('../weight_bin/' + ids_name + '.pickle', 'wb') as handle:
    pickle.dump(Dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
holder_len = 200000
lower_bd = 2
upper_bd = 5
diff = 2
each_epoch = 20
entire_epoch = 0

#90% to be train, 10% to be validation
train_len = 9*int(holder_len/10)

current = 0
    
######################################
###pre-define the lists###############

#define the lists
x_p_list, x_p_dict, x_r_dict, y_dict, path_comb = list(), dict(), dict(), dict(), dict()
filled, total_num_need_to_fill = set(), 0

#build the lists first
for i in range(lower_bd, upper_bd+1):

    Max = min(upper_bd, i+diff)

    for j in range(i, Max+1):

        x_p_list.append((i,j))
        x_p_dict[(i,j)] = {'s': [], 'l': []}
        x_r_dict[(i,j)] = list()
        y_dict[(i,j)] = list()
        path_comb[(i,j)] = set()
        total_num_need_to_fill += 1

#######################################
###build the big-batches###############

#here is a tricky thing: if the distance between upper_bd and lower_bd is big,
#then when run path_finding_dynamic_programming, 
#the chance to obtain paths with length close to lower bd is low.
#For instance, if lower_bd = 2, upper_bd = 10, then most found paths will have length
#from 6 to 10. Paths with length 2 or 3 is relatively rare.
#Hence, we need to first build big-batches for combination between shorter lengths.
#otherwise the while loop may last for really really long time,
#due to difficulty to find shorter paths.
build_big_batches(diff, holder_len, lower_bd, lower_bd, Class_2, one_hop, s_t_r,
                      x_p_dict, x_r_dict, y_dict, path_comb, filled, total_num_need_to_fill,
                      relation2id, entity2id, id2relation, id2entity, entire_epoch)    

if upper_bd - lower_bd > 3:

    build_big_batches(diff, holder_len, lower_bd, lower_bd + 3, Class_2, one_hop, s_t_r,
                      x_p_dict, x_r_dict, y_dict, path_comb, filled, total_num_need_to_fill,
                      relation2id, entity2id, id2relation, id2entity, entire_epoch)        

#fill in the training array list
build_big_batches(diff, holder_len, lower_bd, upper_bd, Class_2, one_hop, s_t_r,
                      x_p_dict, x_r_dict, y_dict, path_comb, filled, total_num_need_to_fill,
                      relation2id, entity2id, id2relation, id2entity, entire_epoch)

#######################################
###do the training#####################
for key in x_p_list:

    #generate the input arrays
    x_train_s = np.asarray(x_p_dict[key]['s'][:train_len], dtype='int')
    x_train_l = np.asarray(x_p_dict[key]['l'][:train_len], dtype='int')
    x_train_r = np.asarray(x_r_dict[key][:train_len], dtype='int')
    y_train = np.asarray(y_dict[key][:train_len], dtype='int')

    x_valid_s = np.asarray(x_p_dict[key]['s'][train_len:], dtype='int')
    x_valid_l = np.asarray(x_p_dict[key]['l'][train_len:], dtype='int')
    x_valid_r = np.asarray(x_r_dict[key][train_len:], dtype='int')
    y_valid = np.asarray(y_dict[key][train_len:], dtype='int')

    print('training on length', key, 'for epoch', entire_epoch)

    #we use a bit dummy method: in order to make sure the model is saved,
    #we save it at each length pair in each epoch!!!
    #so unless it is the first pair first epoch, we will always read from previous checkpoint
    if key != x_p_list[0]:

        print('load model')

        model = keras.models.load_model('../weight_bin/' + model_name + '.h5')

    else:

        #compile the model
        model.compile(
            loss='categorical_crossentropy',
            optimizer=opt,
            metrics=["categorical_accuracy"],)

    model.fit([x_train_s, x_train_l, x_train_r], y_train, 
              validation_data=([x_valid_s, x_valid_l, x_valid_r], y_valid),
              batch_size=4, epochs=current+each_epoch, initial_epoch=current)

    current += each_epoch

    # Save model and weights
    add_h5 = model_name + '.h5'
    save_dir = os.path.join(os.getcwd(), '../weight_bin')

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    model_path = os.path.join(save_dir, add_h5)
    model.save(model_path)
    print('Save model')
    del(model)

    del(x_train_s, x_train_l, x_train_r, y_train)
    del(x_valid_s, x_valid_l, x_valid_r, y_valid)

del(x_p_list, x_p_dict, x_r_dict, y_dict, path_comb)
del(filled, total_num_need_to_fill)

generating big-batches 10000 100000 in epoch 0
generating big-batches 20000 100000 in epoch 0
generating big-batches 30000 100000 in epoch 0
generating big-batches 40000 100000 in epoch 0
generating big-batches 50000 100000 in epoch 0
generating big-batches 60000 100000 in epoch 0
generating big-batches 70000 100000 in epoch 0
generating big-batches 80000 100000 in epoch 0
generating big-batches 90000 100000 in epoch 0
generating big-batches 100000 100000 in epoch 0
big-batches 1 ( 2 2 ) in 9 completed in epoch 0
generating big-batches 10000 900000 in epoch 0
generating big-batches 20000 900000 in epoch 0
generating big-batches 30000 900000 in epoch 0
generating big-batches 40000 900000 in epoch 0
generating big-batches 50000 900000 in epoch 0
generating big-batches 60000 900000 in epoch 0
generating big-batches 70000 900000 in epoch 0
generating big-batches 80000 900000 in epoch 0
generating big-batches 90000 900000 in epoch 0
generating big-batches 100000 900000 in epoch 0
generating

Save model
training on length (2, 3) for epoch 0
load model
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40
Save model
training on length (2, 4) for epoch 0
load model
Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60
Epoch 58/60
Epoch 59/60
Epoch 60/60
Save model
training on length (3, 3) for epoch 0
load model
Epoch 61/80
Epoch 62/80
Epoch 63/80
Epoch 64/80
Epoch 65/80
Epoch 66/80
Epoch 67/80
Epoch 68/80
Epoch 69/80
Epoch 70/80
Epoch 71/80
Epoch 72/80
Epoch 73/80
Epoch 74/80
Epoch 75/80
Epoch 76/80
Epoch 77/80
Epoch 78/80
Epoch 79/80
Epoch 80/80
Save model
training on length (3, 4) for epoch 0
load model
Epoch 81/100
Epoch 82/100
Epoch 83/100
E

Epoch 114/120
Epoch 115/120
Epoch 116/120
Epoch 117/120
Epoch 118/120
Epoch 119/120
Epoch 120/120
Save model
training on length (4, 4) for epoch 0
load model
Epoch 121/140
Epoch 122/140
Epoch 123/140
Epoch 124/140
Epoch 125/140
Epoch 126/140
Epoch 127/140
Epoch 128/140
Epoch 129/140
Epoch 130/140
Epoch 131/140
Epoch 132/140
Epoch 133/140
Epoch 134/140
Epoch 135/140
Epoch 136/140
Epoch 137/140
Epoch 138/140
Epoch 139/140
Epoch 140/140
Save model
training on length (4, 5) for epoch 0
load model
Epoch 141/160
Epoch 142/160
Epoch 143/160
Epoch 144/160
Epoch 145/160
Epoch 146/160
Epoch 147/160
Epoch 148/160
Epoch 149/160
Epoch 150/160
Epoch 151/160
Epoch 152/160
Epoch 153/160
Epoch 154/160
Epoch 155/160
Epoch 156/160
Epoch 157/160
Epoch 158/160
Epoch 159/160
Epoch 160/160
Save model
training on length (5, 5) for epoch 0
load model
Epoch 161/180
Epoch 162/180
Epoch 163/180
Epoch 164/180
Epoch 165/180
Epoch 166/180
Epoch 167/180
Epoch 168/180
Epoch 169/180
Epoch 170/180
Epoch 171/180
Epoch 17

### Result on the testset for inductive link prediction

We use the testset for inductive link prediction.

In [1]:
import librosa
import opensmile
import os
import sys
import numpy as np
import random
from collections import defaultdict
from copy import deepcopy
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

In [2]:
class LoadKG:
    
    def __init__(self):
        
        self.x = 'Hello'
        
    def load_train_data(self, data_path, one_hop, data, s_t_r, entity2id, id2entity,
                     relation2id, id2relation):
        
        data_ = set()
    
        ####load the train, valid and test set##########
        with open (data_path, 'r') as f:
            
            data_ini = f.readlines()
                        
            for i in range(len(data_ini)):
            
                x = data_ini[i].split()
                
                x_ = tuple(x)
                
                data_.add(x_)
        
        ####relation dict#################
        index = len(relation2id)
     
        for key in data_:
            
            if key[1] not in relation2id:
                
                relation = key[1]
                
                relation2id[relation] = index
                
                id2relation[index] = relation
                
                index += 1
                
                #the inverse relation
                iv_r = '_inverse_' + relation
                
                relation2id[iv_r] = index
                
                id2relation[index] = iv_r
                
                index += 1
        
        #get the id of the inverse relation, by above definition, initial relation has 
        #always even id, while inverse relation has always odd id.
        def inverse_r(r):
            
            if r % 2 == 0: #initial relation
                
                iv_r = r + 1
            
            else: #inverse relation
                
                iv_r = r - 1
            
            return(iv_r)
        
        ####entity dict###################
        index = len(entity2id)
        
        for key in data_:
            
            source, target = key[0], key[2]
            
            if source not in entity2id:
                                
                entity2id[source] = index
                
                id2entity[index] = source
                
                index += 1
            
            if target not in entity2id:
                
                entity2id[target] = index
                
                id2entity[index] = target
                
                index += 1
                
        #create the set of triples using id instead of string        
        for ele in data_:
            
            s = entity2id[ele[0]]
            
            r = relation2id[ele[1]]
            
            t = entity2id[ele[2]]
            
            if (s,r,t) not in data:
                
                data.add((s,r,t))
            
            s_t_r[(s,t)].add(r)
            
            if s not in one_hop:
                
                one_hop[s] = dict()
            
            if r not in one_hop[s]:
                
                one_hop[s][r] = set()
            
            one_hop[s][r].add(t)
            
            if t not in one_hop:
                
                one_hop[t] = dict()
            
            r_inv = inverse_r(r)
            
            s_t_r[(t,s)].add(r_inv)
            
            if r_inv not in one_hop[t]:
                
                one_hop[t][r_inv] = set()
            
            one_hop[t][r_inv].add(s)

In [3]:
class ObtainPathsByDynamicProgramming:

    def __init__(self, size_bd=50, threshold=100000):
                
        self.size_bd = size_bd
        
        self.threshold = threshold
    
    '''
    Given an entity s, here is the function to find:
      1. any else entity t that is directely connected to s
      2. most of the paths from s to each t with length L
    
    One may refer to LeetCode Problem 797 for details:
        https://leetcode.com/problems/all-paths-from-source-to-target/
    '''
    def obtain_paths(self, mode, s, t_input, lower_bd, upper_bd, one_hop):

        if type(lower_bd) != type(1) or lower_bd < 1:
            
            raise TypeError("!!! invalid lower bound setting, must >= 1 !!!")
            
        if type(upper_bd) != type(1) or upper_bd < 1:
            
            raise TypeError("!!! invalid upper bound setting, must >= 1 !!!")
            
        if lower_bd > upper_bd:
            
            raise TypeError("!!! lower bound must not exced upper bound !!!")
            
        if s not in one_hop:
            
            raise ValueError('!!! entity not in one_hop. Please work on active entities for validation')
        
        #here is the result dict. Its key is each entity t that is directly connected to s
        #The value of each t is a set containing the paths from s to t
        #These paths can be either the direct connection r, or a multi-hop path
        res = defaultdict(set)
        
        #direct_nb contains all the direct neighbour of s
        direct_nb = set()
        
        if mode == 'direct_neighbour':
        
            for r in one_hop[s]:
            
                for t in one_hop[s][r]:
                
                    direct_nb.add(t)
                    
        elif mode == 'target_specified':
            
            direct_nb.add(t_input)
            
        elif mode == 'any_target':
            
            for s_any in one_hop:
                
                direct_nb.add(s_any)
                
        else:
            
            raise ValueError('not a valid mode')
        
        '''
        We use recursion to find the paths
        On current node with the path [r1, ..., rk] and on-path entities {e1, ..., ek-1, node}
        from s to this node, we further find the direct neighbor t' of this node. 
        If t' is not a on-path entity (not among e1,...ek-1), we recursively proceed to t' 
        '''
        def helper(node, path, on_path_en, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict):
            
            #when the current path is within lower_bd and upper_bd and its corresponding
            #length still within the size_bd and its tail node is within the note dict, 
            #we will then intend to add this path
            if (len(path) >= lower_bd) and (len(path) <= upper_bd) and (
                node in direct_nb) and (length_dict[len(path)] < self.size_bd):
                
                #if this path already exists between the source entity and the current target node,
                #we will not count it.
                #here is an interesting situation: this path may exist between s and some other node t,
                #however, it does not exist between s and this node t. Then, we still count it: length_dict[len(path)] += 1
                #That is, each path may be counted for multiple times.
                #We count how many paths we "actually" found between entity pairs
                #Same type of path between different entity pairs are count separately.
                if tuple(path) not in res[node]:
                
                    res[node].add(tuple(path))
                
                    length_dict[len(path)] += 1
                
            #For some rare entities, we may face such a case: so many paths are evaluated,
            #but no entities on the paths are direct neighbors of the rare entity.
            #In this case, the recursion cannot be bounded and stoped by the size threshold.
            #In order to cure this, we count how many times the recursion happens on a specific length, using the count_dict.
            #Its key is length, value counts the recursion occurred to that length. 
            #The recursion is forced to stop for that length (and hence for longer lengths) once reach the threshold.
            if (len(path) < upper_bd) and (length_dict[len(path) + 1] < self.size_bd) and (
                count_dict[len(path)] <= self.threshold):
                
                #we randomly shuffle relation r so that the reading in order is not fixed
                temp_list = list()
                
                for r in one_hop[node]:
                    
                    temp_list.append(r)
                
                for i_0 in range(len(temp_list)):
                    
                    if count_dict[len(path)] > self.threshold:
                        break
                    
                    r = random.choice(temp_list)
                    
                    for i_1 in range(len(one_hop[node][r])):
                        
                        if count_dict[len(path)] > self.threshold:
                            break
                        
                        t = random.choice(list(one_hop[node][r]))
                        
                        if t not in on_path_en:
                                
                            count_dict[len(path)] += 1

                            helper(t, path + [r], on_path_en.union({t}), res, direct_nb, 
                                   lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        length_dict = defaultdict(int)
        count_dict = defaultdict(int)
        
        helper(s, [], {s}, res, direct_nb, lower_bd, upper_bd, one_hop, length_dict, count_dict)
        
        return(res, length_dict)

In [4]:
#load the classes
Class_1 = LoadKG()
Class_2 = ObtainPathsByDynamicProgramming()

In [5]:
#load ids and relation/entity dicts
with open('../weight_bin/IDs_main_1_nell_v4.pickle', 'rb') as handle:
    Dict = pickle.load(handle)

one_hop = Dict['one_hop']
data = Dict['data']
s_t_r = Dict['s_t_r']
entity2id = Dict['entity2id']
id2entity = Dict['id2entity']
relation2id = Dict['relation2id']
id2relation = Dict['id2relation']

#we want to keep the initial entity/relation dicts
entity2id_ini = deepcopy(entity2id)
id2entity_ini = deepcopy(id2entity)
relation2id_ini = deepcopy(relation2id)
id2relation_ini = deepcopy(id2relation)

num_r = len(id2relation)

In [6]:
#load the model
model = keras.models.load_model('../weight_bin/Model_main_1_nell_v4.h5')

2023-01-24 07:46:36.552120: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
ind_train_path = '../data/nell_v4_ind/train.txt'
ind_valid_path = '../data/nell_v4_ind/valid.txt'
ind_test_path = '../data/nell_v4_ind/test.txt'

In [8]:
#load the test dataset
one_hop_ind = dict() 
data_ind = set()
s_t_r_ind = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_train_path, 
                        one_hop_ind, data_ind, s_t_r_ind,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [9]:
print(size_0, size_1, len(data_ind))

2092 4886 7073


In [10]:
#load the test dataset
one_hop_test = dict() 
data_test = set()
s_t_r_test = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_test_path, 
                        one_hop_test, data_test, s_t_r_test,
                        entity2id, id2entity, relation2id, id2relation)


len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [11]:
print(size_0, size_1, len(data_test))

4886 4886 731


In [12]:
#load the validation for existing triple removal when ranking
one_hop_valid = dict() 
data_valid = set()
s_t_r_valid = defaultdict(set)

len_0 = len(relation2id)
size_0 = len(entity2id)

#fill in the sets and dicts
Class_1.load_train_data(ind_valid_path, 
                        one_hop_valid, data_valid, s_t_r_valid,
                        entity2id, id2entity, relation2id, id2relation)

len_1 = len(relation2id)
size_1 = len(entity2id)

if len_0 != len_1:
    raise ValueError('unseen relation!')

In [13]:
print(size_0, size_1, len(data_valid))

4886 4886 716


In [14]:
print(len(entity2id), len(entity2id_ini))

4886 2092


In [15]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_test:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

11

In [16]:
overlapping = 0

for ele in data_valid:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

17

In [17]:
#we want to check whether there are overlapping 
#between the entities of train triples and inductive test and valid triples
overlapping = 0

for ele in data_ind:
    
    s, r, t = ele[0], ele[1], ele[2]
    
    if s in id2entity_ini or t in id2entity_ini:
        
        overlapping += 1
        
overlapping

110

In [18]:
def relation_ranking(s, t, diff, lower_bd, upper_bd, one_hop, id2relation, model):
    
    path_holder = set()
    
    for iteration in range(20):
    
        result, length_dict = Class_2.obtain_paths('target_specified', 
                                                   s, t, lower_bd, upper_bd, one_hop)
        if t in result:
            
            for path in result[t]:
                
                path_holder.add(path)
    
    path_holder_1 = list(path_holder)
    path_holder_2 = list(path_holder)
    
    random.shuffle(path_holder_1)
    random.shuffle(path_holder_2)
    
    score_dict = defaultdict(float)
    
    count = 0
    
    #iterate over path_1
    for path_1 in path_holder_1:
        
        if count == 30:
            break

        #iterate over path_2
        for path_2 in path_holder_2:
            
            if count == 30:
                break

            #decide which path is shorter and which is longer
            if len(path_1) <= len(path_2):

                path_s, path_l = path_1, path_2

            else:

                path_s, path_l = path_2, path_1                            

            #whether lengths of the two paths satisfies the requirments
            if (len(path_s) >= lower_bd) and (len(path_l) <= upper_bd) and (
                abs(len(path_s)-len(path_l)) <= diff):
                
                list_s = list()
                list_l = list()
                list_r = list()

                for i in range(len(id2relation)):
                    
                    if i not in id2relation:
                        
                        raise ValueError ('error when generating id2relation')

                    list_s.append(list(path_s) + [num_r]*abs(len(path_s)-len(path_l)))
                    list_l.append(list(path_l))
                    list_r.append([i])
                
                input_s = np.array(list_s)
                input_l = np.array(list_l)
                input_r = np.array(list_r)
                
                pred = model.predict([input_s, input_l, input_r], verbose = 0)
    
                for i in range(pred.shape[0]):

                    score_dict[i] += float(pred[i][0])
                
                count += 1
                
    print(len(score_dict), len(path_holder))

    return(score_dict)

In [20]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), min(len(data_test), 500))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0
Hits_at_10 = 0
MRR_raw = 0.

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 5, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    if p - inverse_r - exist_tri < 10:
        
        Hits_at_10 += 1
        
    MRR_raw += 1./float(p - inverse_r - exist_tri + 1.) 
        
    print('Hits@1', Hits_at_1/(i+1),
          'Hits@10', Hits_at_10/(i+1),
          'MRR', MRR_raw/(i+1),
          'cur_rank', p - inverse_r - exist_tri, 
          'total_num', i, len(selected))

152 202
Hits@1 0.0 Hits@10 0.0 MRR 0.023255813953488372 cur_rank 42 total_num 0 500
0 0
Hits@1 0.0 Hits@10 0.0 MRR 0.020102483247930625 cur_rank 58 total_num 1 500
152 322
Hits@1 0.0 Hits@10 0.0 MRR 0.024895908372183636 cur_rank 28 total_num 2 500
152 298
Hits@1 0.0 Hits@10 0.0 MRR 0.024921931279137728 cur_rank 39 total_num 3 500
152 301
Hits@1 0.0 Hits@10 0.0 MRR 0.024381989467754626 cur_rank 44 total_num 4 500
152 75
Hits@1 0.16666666666666666 Hits@10 0.16666666666666666 MRR 0.18698499122312887 cur_rank 0 total_num 5 500
152 42
Hits@1 0.14285714285714285 Hits@10 0.14285714285714285 MRR 0.162614769994532 cur_rank 60 total_num 6 500
152 199
Hits@1 0.125 Hits@10 0.125 MRR 0.14659826857280173 cur_rank 28 total_num 7 500
152 215
Hits@1 0.1111111111111111 Hits@10 0.1111111111111111 MRR 0.13272503100191071 cur_rank 45 total_num 8 500
152 248
Hits@1 0.1 Hits@10 0.1 MRR 0.12153586123505297 cur_rank 47 total_num 9 500
152 312
Hits@1 0.09090909090909091 Hits@10 0.09090909090909091 MRR 0.1130123

152 400
Hits@1 0.08974358974358974 Hits@10 0.2564102564102564 MRR 0.1480872204305586 cur_rank 8 total_num 77 500
152 61
Hits@1 0.08860759493670886 Hits@10 0.26582278481012656 MRR 0.15043210793565703 cur_rank 2 total_num 78 500
152 689
Hits@1 0.0875 Hits@10 0.2625 MRR 0.1488424042608799 cur_rank 42 total_num 79 500
152 159
Hits@1 0.08641975308641975 Hits@10 0.25925925925925924 MRR 0.14727322804080464 cur_rank 45 total_num 80 500
0 0
Hits@1 0.08536585365853659 Hits@10 0.25609756097560976 MRR 0.14591275313438368 cur_rank 27 total_num 81 500
0 0
Hits@1 0.08433734939759036 Hits@10 0.26506024096385544 MRR 0.14535958743396943 cur_rank 9 total_num 82 500
0 0
Hits@1 0.08333333333333333 Hits@10 0.2619047619047619 MRR 0.1439508664767439 cur_rank 36 total_num 83 500
152 662
Hits@1 0.08235294117647059 Hits@10 0.25882352941176473 MRR 0.14251876477963188 cur_rank 44 total_num 84 500
0 0
Hits@1 0.08139534883720931 Hits@10 0.2558139534883721 MRR 0.14127685223236044 cur_rank 27 total_num 85 500
152 400


152 758
Hits@1 0.046357615894039736 Hits@10 0.2052980132450331 MRR 0.11047528082980249 cur_rank 37 total_num 150 500
0 0
Hits@1 0.046052631578947366 Hits@10 0.20394736842105263 MRR 0.10991294345592219 cur_rank 39 total_num 151 500
152 127
Hits@1 0.0457516339869281 Hits@10 0.20261437908496732 MRR 0.10947872984424667 cur_rank 22 total_num 152 500
152 184
Hits@1 0.045454545454545456 Hits@10 0.2012987012987013 MRR 0.10886929004006324 cur_rank 63 total_num 153 500
152 250
Hits@1 0.04516129032258064 Hits@10 0.2 MRR 0.10831027669930297 cur_rank 44 total_num 154 500
152 90
Hits@1 0.04487179487179487 Hits@10 0.1987179487179487 MRR 0.10807385551167009 cur_rank 13 total_num 155 500
0 0
Hits@1 0.044585987261146494 Hits@10 0.19745222929936307 MRR 0.10787544290919497 cur_rank 12 total_num 156 500
0 0
Hits@1 0.04430379746835443 Hits@10 0.20253164556962025 MRR 0.10782559833382033 cur_rank 9 total_num 157 500
152 757
Hits@1 0.0440251572327044 Hits@10 0.20125786163522014 MRR 0.1072812649842031 cur_rank 

152 37
Hits@1 0.03587443946188341 Hits@10 0.18834080717488788 MRR 0.09814488560932806 cur_rank 4 total_num 222 500
152 479
Hits@1 0.03571428571428571 Hits@10 0.1875 MRR 0.09795475467158801 cur_rank 17 total_num 223 500
152 1
Hits@1 0.04 Hits@10 0.19111111111111112 MRR 0.1019638446508254 cur_rank 0 total_num 224 500
152 565
Hits@1 0.03982300884955752 Hits@10 0.1902654867256637 MRR 0.10160682142557698 cur_rank 46 total_num 225 500
0 0
Hits@1 0.039647577092511016 Hits@10 0.1894273127753304 MRR 0.10131654593786203 cur_rank 27 total_num 226 500
152 285
Hits@1 0.039473684210526314 Hits@10 0.18859649122807018 MRR 0.10099400748101955 cur_rank 35 total_num 227 500
152 303
Hits@1 0.039301310043668124 Hits@10 0.18777292576419213 MRR 0.10067428595393117 cur_rank 35 total_num 228 500
152 460
Hits@1 0.0391304347826087 Hits@10 0.18695652173913044 MRR 0.10035734461403485 cur_rank 35 total_num 229 500
152 219
Hits@1 0.03896103896103896 Hits@10 0.18614718614718614 MRR 0.10009605740791348 cur_rank 24 tot

0 0
Hits@1 0.03728813559322034 Hits@10 0.18305084745762712 MRR 0.10044464077067257 cur_rank 15 total_num 294 500
0 0
Hits@1 0.037162162162162164 Hits@10 0.18243243243243243 MRR 0.10020767587044405 cur_rank 32 total_num 295 500
152 234
Hits@1 0.037037037037037035 Hits@10 0.18181818181818182 MRR 0.09994857869227247 cur_rank 42 total_num 296 500
152 436
Hits@1 0.03691275167785235 Hits@10 0.18120805369127516 MRR 0.09971187797419734 cur_rank 33 total_num 297 500
152 305
Hits@1 0.03678929765886288 Hits@10 0.1806020066889632 MRR 0.09946200547261139 cur_rank 39 total_num 298 500
152 280
Hits@1 0.03666666666666667 Hits@10 0.18 MRR 0.09931565063955454 cur_rank 17 total_num 299 500
0 0
Hits@1 0.036544850498338874 Hits@10 0.17940199335548174 MRR 0.09907798328785429 cur_rank 35 total_num 300 500
152 303
Hits@1 0.03642384105960265 Hits@10 0.17880794701986755 MRR 0.09883940396248732 cur_rank 36 total_num 301 500
152 131
Hits@1 0.036303630363036306 Hits@10 0.1782178217821782 MRR 0.0986452145104659 cur

152 93
Hits@1 0.04087193460490463 Hits@10 0.1771117166212534 MRR 0.10061929679010634 cur_rank 1 total_num 366 500
152 819
Hits@1 0.04076086956521739 Hits@10 0.1766304347826087 MRR 0.10039915697039752 cur_rank 50 total_num 367 500
152 1
Hits@1 0.04065040650406504 Hits@10 0.17886178861788618 MRR 0.10080457930923112 cur_rank 3 total_num 368 500
152 34
Hits@1 0.04054054054054054 Hits@10 0.1783783783783784 MRR 0.10060518051927922 cur_rank 36 total_num 369 500
152 183
Hits@1 0.04043126684636118 Hits@10 0.1778975741239892 MRR 0.10039526702118755 cur_rank 43 total_num 370 500
152 152
Hits@1 0.04032258064516129 Hits@10 0.1774193548387097 MRR 0.1001825824209819 cur_rank 46 total_num 371 500
0 0
Hits@1 0.040214477211796246 Hits@10 0.1769436997319035 MRR 0.10000974516439556 cur_rank 27 total_num 372 500
152 253
Hits@1 0.040106951871657755 Hits@10 0.17647058823529413 MRR 0.09980175713513843 cur_rank 44 total_num 373 500
152 632
Hits@1 0.04 Hits@10 0.176 MRR 0.09975784133833362 cur_rank 11 total_num

152 312
Hits@1 0.04100227790432802 Hits@10 0.17995444191343962 MRR 0.10255793654202368 cur_rank 24 total_num 438 500
152 23
Hits@1 0.04090909090909091 Hits@10 0.18181818181818182 MRR 0.10264952564728533 cur_rank 6 total_num 439 500
152 56
Hits@1 0.04081632653061224 Hits@10 0.1836734693877551 MRR 0.10298365370704204 cur_rank 3 total_num 440 500
0 0
Hits@1 0.04072398190045249 Hits@10 0.18552036199095023 MRR 0.10297690335928857 cur_rank 9 total_num 441 500
0 0
Hits@1 0.040632054176072234 Hits@10 0.18510158013544017 MRR 0.10290568816305669 cur_rank 13 total_num 442 500
152 185
Hits@1 0.04054054054054054 Hits@10 0.18468468468468469 MRR 0.10273318839123378 cur_rank 37 total_num 443 500
152 817
Hits@1 0.04044943820224719 Hits@10 0.1842696629213483 MRR 0.10257042399103557 cur_rank 32 total_num 444 500
0 0
Hits@1 0.04035874439461883 Hits@10 0.18385650224215247 MRR 0.1023711598792552 cur_rank 72 total_num 445 500
152 60
Hits@1 0.040268456375838924 Hits@10 0.18568232662192394 MRR 0.10288785377959

In [18]:
'''Ranking on transductive setting'''

########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 66
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 0 672
438 1171
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 1 672
438 1044
Calculating_hit_at_one 0.0 current ranking: 1 triple number: 2 672
438 600
Calculating_hit_at_one 0.25 current ranking: 0 triple number: 3 672
438 4356
Calculating_hit_at_one 0.2 current ranking: 1 triple number: 4 672
438 377
Calculating_hit_at_one 0.3333333333333333 current ranking: 0 triple number: 5 672
438 1864
Calculating_hit_at_one 0.42857142857142855 current ranking: 0 triple number: 6 672
438 509
Calculating_hit_at_one 0.375 current ranking: 2 triple number: 7 672
438 1675
Calculating_hit_at_one 0.3333333333333333 current ranking: 27 triple number: 8 672
438 625
Calculating_hit_at_one 0.3 current ranking: 1 triple number: 9 672
438 2672
Calculating_hit_at_one 0.36363636363636365 current ranking: 0 triple number: 10 672
438 1404
Calculating_hit_at_one 0.3333333333333333 current ranking: 2 triple number: 11 672
438 553
Cal

KeyboardInterrupt: 

In [144]:
########################################################
#obtain the precision-recall area under curve (AUC-PR)##

#randomly select 10% of the triples
selected = random.sample(list(data_test), int(len(data_test)/5))

random.shuffle(selected)

###Hit at 1#############################
#generate the negative samples by randomly replace relation with all the other relaiton
Hits_at_1 = 0

for i in range(len(selected)):
    
    s_true, r_true, t_true = selected[i][0], selected[i][1], selected[i][2]
    
    score_dict = relation_ranking(s_true, t_true, 2, 2, 10, one_hop_ind, id2relation, model)
    
    #[... [score, r], ...]
    temp_list = list()
    
    for r in id2relation:
        
        if r in score_dict:
            
            temp_list.append([score_dict[r], r])
            
        else:
            
            temp_list.append([0.0, r])
        
    sorted_list = sorted(temp_list, key = lambda x: x[0], reverse=True)
    
    p = 0
    inverse_r = 0
    exist_tri = 0
    
    while p < len(sorted_list) and sorted_list[p][1] != r_true:
        
        #we want to see how many inverse relaiton are ranked above true relation
        #then, we remove them from ranking, otherwise it is not fair for us to compare our
        #result with other model who does not consider inverse relations
        if temp_list[p][1] % 2 != 0:
            
            inverse_r += 1
        
        #moreover, we want to remove existing triples
        if ((s_true, sorted_list[p][1], t_true) in data_test) or (
              (s_true, sorted_list[p][1], t_true) in data_valid) or (
              (s_true, sorted_list[p][1], t_true) in data_ind):
            
            exist_tri += 1
            
        p += 1
    
    if p - inverse_r - exist_tri == 0:
        
        Hits_at_1 += 1
        
    print('Calculating_hit_at_one', Hits_at_1/(i+1), 'current ranking:', 
          p - inverse_r - exist_tri, 'triple number:', i, len(selected))

438 263
Calculating_hit_at_one 1.0 current ranking: 0 0 284
438 3646
Calculating_hit_at_one 0.5 current ranking: 1 1 284
438 1588
Calculating_hit_at_one 0.3333333333333333 current ranking: 10 2 284
438 707
Calculating_hit_at_one 0.25 current ranking: 4 3 284
438 3451
Calculating_hit_at_one 0.4 current ranking: 0 4 284
438 2757
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 5 284
438 559
Calculating_hit_at_one 0.2857142857142857 current ranking: 1 6 284
438 2868
Calculating_hit_at_one 0.375 current ranking: 0 7 284
438 3601
Calculating_hit_at_one 0.3333333333333333 current ranking: 1 8 284
438 1470
Calculating_hit_at_one 0.4 current ranking: 0 9 284
438 37
Calculating_hit_at_one 0.36363636363636365 current ranking: 1 10 284
438 2329
Calculating_hit_at_one 0.3333333333333333 current ranking: 7 11 284
438 1230
Calculating_hit_at_one 0.38461538461538464 current ranking: 0 12 284
438 929
Calculating_hit_at_one 0.35714285714285715 current ranking: 1 13 284
438 691
Calculating_h

438 660
Calculating_hit_at_one 0.3063063063063063 current ranking: 1 110 284
438 2095
Calculating_hit_at_one 0.30357142857142855 current ranking: 2 111 284
438 2160
Calculating_hit_at_one 0.30973451327433627 current ranking: 0 112 284
438 1207
Calculating_hit_at_one 0.3157894736842105 current ranking: 0 113 284
438 1671
Calculating_hit_at_one 0.3130434782608696 current ranking: 1 114 284
438 5113
Calculating_hit_at_one 0.3103448275862069 current ranking: 1 115 284
438 4248
Calculating_hit_at_one 0.3162393162393162 current ranking: 0 116 284
438 2935
Calculating_hit_at_one 0.3135593220338983 current ranking: 96 117 284
438 36
Calculating_hit_at_one 0.31092436974789917 current ranking: 1 118 284
438 429
Calculating_hit_at_one 0.30833333333333335 current ranking: 1 119 284
438 4473
Calculating_hit_at_one 0.30578512396694213 current ranking: 3 120 284


KeyboardInterrupt: 