In [1]:
# encoding: utf-8
import os
import pandas as pd
import numpy as np
import time
from collections import defaultdict
# encoding: utf-8
import os
from collections import defaultdict as ddict

import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from torch.nn import functional as F
from torch.autograd import Variable
from numpy.random import RandomState



from sklearn.utils import shuffle as skshuffle
import os

## Datasets

In [2]:
class KnowledgeGraph:
    def __init__(self, data_dir, gran=1,rev_set=0):
        self.data_dir = data_dir
        self.entity_dict = {}
        self.gran = gran
        self.entities = []
        self.relation_dict = {}
        self.n_entity = 0
        self.n_relation = 0
        self.training_triples = []  # list of triples in the form of (h, t, r)
        self.validation_triples = []
        self.test_triples = []
        self.training_facts = []
        self.validation_facts = []
        self.test_facts = []
        self.n_training_triple = 0
        self.n_validation_triple = 0
        self.n_test_triple = 0
        self.rev_set = rev_set
        self.start_date = '2014-01-01' if self.data_dir == 'icews14' else '2005-01-01'
        self.start_sec = time.mktime(time.strptime(self.start_date,'%Y-%m-%d'))
        self.n_time=365 if self.data_dir == 'icews14' else 4017
        self.to_skip_final = {'lhs': {}, 'rhs': {}}
        '''load dicts and triples'''
        self.load_dicts()
        self.load_triples()
        self.load_filters()
        '''construct pools after loading'''
        # self.training_triple_pool = set(self.training_triples)
        # self.golden_triple_pool = set(self.training_triples) | set(self.validation_triples) | set(self.test_triples)

    def load_dicts(self):
        entity_dict_file = 'entity2id.txt'
        relation_dict_file = 'relation2id.txt'
        print('-----Loading entity dict-----')
        entity_df = pd.read_table(os.path.join(self.data_dir, entity_dict_file), header=None)
        self.entity_dict = dict(zip(entity_df[0], entity_df[1]))
        self.n_entity = len(self.entity_dict)
        self.entities = list(self.entity_dict.values())
        print('#entity: {}'.format(self.n_entity))
        print('-----Loading relation dict-----')
        relation_df = pd.read_table(os.path.join(self.data_dir, relation_dict_file), header=None)
        self.relation_dict = dict(zip(relation_df[0], relation_df[1]))
        self.n_relation = len(self.relation_dict)
        if self.rev_set>0: self.n_relation *= 2
        print('#relation: {}'.format(self.n_relation))

    def load_triples(self):
        training_file = 'train.txt'
        validation_file = 'valid.txt'
        test_file = 'test.txt'
        print('-----Loading training triples-----')
        training_df = pd.read_table(os.path.join(self.data_dir, training_file), header=None)
        training_df = np.array(training_df).tolist()
        for triple in training_df:
            end_sec = time.mktime(time.strptime(triple[3], '%Y-%m-%d'))
            day = int((end_sec - self.start_sec) / (self.gran*24 * 60 * 60))
            self.training_triples.append([self.entity_dict[triple[0]],self.entity_dict[triple[2]],self.relation_dict[triple[1]],day])
            self.training_facts.append([self.entity_dict[triple[0]],self.entity_dict[triple[2]],self.relation_dict[triple[1]],triple[3],0])
            if self.rev_set>0: self.training_triples.append([self.entity_dict[triple[2]],self.entity_dict[triple[0]],self.relation_dict[triple[1]]+self.n_relation//2,day])

        self.n_training_triple = len(self.training_triples)
        print('#training triple: {}'.format(self.n_training_triple))
        print('-----Loading validation triples-----')
        validation_df = pd.read_table(os.path.join(self.data_dir, validation_file), header=None)
        validation_df = np.array(validation_df).tolist()
        for triple in validation_df:
            end_sec = time.mktime(time.strptime(triple[3], '%Y-%m-%d'))
            day = int((end_sec - self.start_sec) / (self.gran*24 * 60 * 60))
            self.validation_triples.append([self.entity_dict[triple[0]],self.entity_dict[triple[2]],self.relation_dict[triple[1]],day])
            self.validation_facts.append([self.entity_dict[triple[0]],self.entity_dict[triple[2]],self.relation_dict[triple[1]],triple[3],0])

        self.n_validation_triple = len(self.validation_triples)
        print('#validation triple: {}'.format(self.n_validation_triple))
        print('-----Loading test triples------')
        test_df = pd.read_table(os.path.join(self.data_dir, test_file), header=None)
        test_df = np.array(test_df).tolist()
        for triple in test_df:
            end_sec = time.mktime(time.strptime(triple[3], '%Y-%m-%d'))
            day = int((end_sec - self.start_sec) / (self.gran*24 * 60 * 60))
            self.test_triples.append(
                    [self.entity_dict[triple[0]], self.entity_dict[triple[2]], self.relation_dict[triple[1]], day])
            self.test_facts.append([self.entity_dict[triple[0]],self.entity_dict[triple[2]],self.relation_dict[triple[1]],triple[3],0])

        self.n_test_triple = len(self.test_triples)
        print('#test triple: {}'.format(self.n_test_triple))


    def load_filters(self):
        print("creating filtering lists")
        to_skip = {'lhs': defaultdict(set), 'rhs': defaultdict(set)}
        facts_pool = [self.training_facts,self.validation_facts,self.test_facts]
        for facts in facts_pool:
            for fact in facts:
                to_skip['lhs'][(fact[1], fact[2],fact[3], fact[4])].add(fact[0])  # left prediction
                to_skip['rhs'][(fact[0], fact[2],fact[3], fact[4])].add(fact[1])  # right prediction
                
        for kk, skip in to_skip.items():
            for k, v in skip.items():
                self.to_skip_final[kk][k] = sorted(list(v))
        print("data preprocess completed")
        
        


In [3]:
class KnowledgeGraphYG:
    def __init__(self, data_dir, count=300, rev_set=0):
        self.data_dir = data_dir
        self.entity_dict = {}
        self.entities = []
        self.relation_dict = {}
        self.n_entity = 0
        self.n_relation = 0
        self.training_triples = []  # list of triples in the form of (h, t, r)
        self.validation_triples = []
        self.test_triples = []
        self.training_facts = []
        self.validation_facts = []
        self.test_facts = []
        self.n_training_triple = 0
        self.n_validation_triple = 0
        self.n_test_triple = 0
        self.n_time = 0
        self.start_year= -500
        self.end_year = 3000
        self.year_class=[]
        self.year2id = dict()
        self.rev_set = rev_set
        self.fact_count = count
        self.to_skip_final = {'lhs': {}, 'rhs': {}}
        '''load dicts and triples'''
        self.time_list()
        self.load_dicts()
        self.load_triples()
        self.load_filters()
        '''construct pools after loading'''
        # self.training_triple_pool = set(self.training_triples)
        # self.golden_triple_pool = set(self.training_triples) | set(self.validation_triples) | set(self.test_triples)

    def load_dicts(self):
        entity_dict_file = 'entity2id.txt'
        relation_dict_file = 'relation2id.txt'
        print('-----Loading entity dict-----')
        entity_df = pd.read_table(os.path.join(self.data_dir, entity_dict_file), header=None)
        self.entity_dict = dict(zip(entity_df[0], entity_df[1]))
        self.n_entity = len(self.entity_dict)
        self.entities = list(self.entity_dict.values())
        print('#entity: {}'.format(self.n_entity))
        print('-----Loading relation dict-----')
        relation_df = pd.read_table(os.path.join(self.data_dir, relation_dict_file), header=None)
        self.relation_dict = dict(zip(relation_df[0], relation_df[1]))
        self.n_relation = len(self.relation_dict)
        if self.rev_set>0: self.n_relation *= 2
        print('#relation: {}'.format(self.n_relation))

    def time_list(self):
        training_file = 'train.txt'
        validation_file = 'valid.txt'
        test_file = 'test.txt'
        triple_file = 'triple2id.txt'
        training_df = pd.read_table(os.path.join(self.data_dir, training_file), header=None)
        training_df = np.array(training_df).tolist()
        validation_df = pd.read_table(os.path.join(self.data_dir, validation_file), header=None)
        validation_df = np.array(validation_df).tolist()
        test_df = pd.read_table(os.path.join(self.data_dir, test_file), header=None)
        test_df = np.array(test_df).tolist()
 #       triple_df = pd.read_table(os.path.join(self.data_dir, triple_file), header=None)
 #       triple_df = np.array(triple_df).tolist()
        triple_df = np.concatenate([training_df,validation_df,test_df],axis=0)
        n=0
        
        year_list=[]
        for triple in triple_df:
            n+=1
            if triple[3][0]=='-':
                start = -int(triple[3].split('-')[1])
                year_list.append(start)
            else:
                start = triple[3].split('-')[0]
                if start =='####':
                    start = self.start_year
                else:
                    start = start.replace('#', '0')
                    start = int(start)
                    year_list.append(start)


            if triple[4][0]=='-':
                end = -int(triple[4].split('-')[1])
                year_list.append(end)
            else:
                end = triple[4].split('-')[0]
                if end =='####':
                    end = self.end_year
                else:
                    end = end.replace('#', '0')
                    end = int(end)
                    year_list.append(end)

#            for i in range(start,end):
#                 year_list.append(i)
            


        year_list.sort()

        freq=ddict(int)
        for year in year_list:
            freq[year]=freq[year]+1

        year_class=[]
        count=0
        for key in sorted(freq.keys()):
            count += freq[key]
            if count>=self.fact_count:
                year_class.append(key)
                count=0
        year_class[-1]=year_list[-1]

        year2id = dict()
        prev_year = year_list[0]
        i = 0
        for i, yr in enumerate(year_class): 
            year2id[(prev_year, yr)] = i
 #           if i>2: 
            prev_year = yr + 1

        self.year2id=year2id
        self.year_class = year_class
        self.n_time = len(self.year2id.keys())


    def load_triples(self):
        training_file = 'train.txt'
        validation_file = 'valid.txt'
        test_file = 'test.txt'
        print('-----Loading training triples-----')
        training_df = pd.read_table(os.path.join(self.data_dir, training_file), header=None)
        training_df = np.array(training_df).tolist()
        for triple in training_df:
            if triple[3].split('-')[0] == '####':
                start=self.start_year
                start_idx = 0
            elif triple[3][0] == '-':
                start=-int(triple[3].split('-')[1].replace('#', '0'))
            elif triple[3][0] != '-':
                start = int(triple[3].split('-')[0].replace('#','0'))
            
            if triple[4].split('-')[0] == '####':
                end = self.end_year
                end_idx = self.n_time-1
            elif triple[4][0] == '-':
                end =-int(triple[4].split('-')[1].replace('#', '0'))
            elif triple[4][0] != '-':
                end = int(triple[4].split('-')[0].replace('#','0'))
        
            for key, time_idx in sorted(self.year2id.items(), key=lambda x:x[1]):
                if start>=key[0] and start<=key[1]:
                    start_idx = time_idx
                if end>=key[0] and end<=key[1]:
                    end_idx = time_idx


            self.training_triples.append([triple[0],triple[2],triple[1],start_idx,end_idx])
            self.training_facts.append([triple[0],triple[2],triple[1],triple[3],triple[4]])
            if self.rev_set>0: self.training_triples.append([triple[2],triple[0],triple[1]+self.n_relation//2,start_idx,end_idx])
            # for day_idx in range(start_idx,end_idx+1):
            #     try:
            #         self.training_triples.append([triple[0],triple[2],triple[1],day_idx])
            #     except KeyError:
            #         continue
        self.n_training_triple = len(self.training_triples)
        print('#training triple: {}'.format(self.n_training_triple))
        print('-----Loading validation triples-----')
        validation_df = pd.read_table(os.path.join(self.data_dir, validation_file), header=None)
        validation_df = np.array(validation_df).tolist()
        for triple in validation_df:
            if triple[3].split('-')[0] == '####':
                start=self.start_year
                start_idx = 0
            elif triple[3][0] == '-':
                start=-int(triple[3].split('-')[1].replace('#', '0'))
            elif triple[3][0] != '-':
                start = int(triple[3].split('-')[0].replace('#','0'))
            
            if triple[4].split('-')[0] == '####':
                end = self.end_year
                end_idx = self.n_time-1
            elif triple[4][0] == '-':
                end =-int(triple[4].split('-')[1].replace('#', '0'))
            elif triple[4][0] != '-':
                end = int(triple[4].split('-')[0].replace('#','0'))
        
            for key, time_idx in sorted(self.year2id.items(), key=lambda x:x[1]):
                if start>=key[0] and start<=key[1]:
                    start_idx = time_idx
                if end>=key[0] and end<=key[1]:
                    end_idx = time_idx
            
                    
            self.validation_triples.append([triple[0],triple[2],triple[1],start_idx,end_idx])
            self.validation_facts.append([triple[0],triple[2],triple[1],triple[3],triple[4]])
            # for day_idx in range(start_idx,end_idx+1):
            #     try:
            #         self.validation_triples.append([triple[0],triple[2],triple[1],day_idx])
            #     except KeyError:
            #         continue
        self.n_validation_triple = len(self.validation_triples)
        print('#validation triple: {}'.format(self.n_validation_triple))
        print('-----Loading test triples------')
        test_df = pd.read_table(os.path.join(self.data_dir, test_file), header=None)
        test_df = np.array(test_df).tolist()
        for triple in test_df:
            if triple[3].split('-')[0] == '####':
                start=self.start_year
                start_idx = 0
            elif triple[3][0] == '-':
                start=-int(triple[3].split('-')[1].replace('#', '0'))
            elif triple[3][0] != '-':
                start = int(triple[3].split('-')[0].replace('#','0'))
            
            if triple[4].split('-')[0] == '####':
                end = self.end_year
                end_idx = self.n_time-1
            elif triple[4][0] == '-':
                end =-int(triple[4].split('-')[1].replace('#', '0'))
            elif triple[4][0] != '-':
                end = int(triple[4].split('-')[0].replace('#','0'))
        
            for key, time_idx in sorted(self.year2id.items(), key=lambda x:x[1]):
                if start>=key[0] and start<=key[1]:
                    start_idx = time_idx
                if end>=key[0] and end<=key[1]:
                    end_idx = time_idx
                    

            self.test_triples.append([triple[0],triple[2],triple[1],start_idx,end_idx])
            self.test_facts.append([triple[0],triple[2],triple[1],triple[3],triple[4]])
            # for day_idx in range(start_idx,end_idx+1):
            #     try:
            #         self.test_triples.append([triple[0],triple[2],triple[1],day_idx])
            #     except KeyError:
            #         continue
        self.n_test_triple = len(self.test_triples)
        print('#test triple: {}'.format(self.n_test_triple))

    def load_filters(self):
        print("creating filtering lists")
        to_skip = {'lhs': defaultdict(set), 'rhs': defaultdict(set)}
        facts_pool = [self.training_facts,self.validation_facts,self.test_facts]
        for facts in facts_pool:
            for fact in facts:
                to_skip['lhs'][(fact[1], fact[2],fact[3],fact[4])].add(fact[0])  # left prediction
                to_skip['rhs'][(fact[0], fact[2],fact[3],fact[4])].add(fact[1])  # right prediction
                
        for kk, skip in to_skip.items():
            for k, v in skip.items():
                self.to_skip_final[kk][k] = sorted(list(v))
        print("data preprocess completed")


## Models

In [4]:
def quat_dot_prod(q1, q2):
    return torch.sum(q1* q2, -1)

def hamilton_quat_prod(q1, q2):
    s_a, x_a, y_a, z_a = torch.chunk(q1, 4, dim=-1)
    s_b, x_b, y_b, z_b = torch.chunk(q2, 4, dim=-1)
    
    A = s_a * s_b - x_a * x_b - y_a * y_b - z_a * z_b
    B = s_a * x_b + s_b * x_a + y_a * z_b - y_b * z_a
    C = s_a * y_b + s_b * y_a + z_a * x_b - z_b * x_a
    D = s_a * z_b + s_b * z_a + x_a * y_b - x_b * y_a

    return torch.cat([A, B, C, D], dim=-1)

def quat_norm(q):
    s_b, x_b, y_b, z_b = torch.chunk(q, 4, dim=-1)
    denominator_b = torch.sqrt(s_b ** 2 + x_b ** 2 + y_b ** 2 + z_b ** 2)
    s_b = s_b / denominator_b
    x_b = x_b / denominator_b
    y_b = y_b / denominator_b
    z_b = z_b / denominator_b
    return torch.cat([s_b, x_b, y_b, z_b], dim=-1)

def quaternion_init(in_features, out_features, criterion='he'):
    fan_in = in_features
    fan_out = out_features

    if criterion == 'glorot':
        s = 1. / np.sqrt(2 * (fan_in + fan_out))
    elif criterion == 'he':
        s = 1. / np.sqrt(2 * fan_in)
    else:
        raise ValueError('Invalid criterion: ', criterion)
    rng = RandomState(123)

    # Generating randoms and purely imaginary quaternions :
    kernel_shape = (in_features, out_features)

    number_of_weights = np.prod(kernel_shape)
    v_i = np.random.uniform(0.0, 1.0, number_of_weights)
    v_j = np.random.uniform(0.0, 1.0, number_of_weights)
    v_k = np.random.uniform(0.0, 1.0, number_of_weights)

    # Purely imaginary quaternions unitary
    for i in range(0, number_of_weights):
        norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001
        v_i[i] /= norm
        v_j[i] /= norm
        v_k[i] /= norm
    v_i = v_i.reshape(kernel_shape)
    v_j = v_j.reshape(kernel_shape)
    v_k = v_k.reshape(kernel_shape)

    modulus = rng.uniform(low=-s, high=s, size=kernel_shape)
    phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

    weight_r = modulus * np.cos(phase)
    weight_i = modulus * v_i * np.sin(phase)
    weight_j = modulus * v_j * np.sin(phase)
    weight_k = modulus * v_k * np.sin(phase)

    return (weight_r, weight_i, weight_j, weight_k)

In [5]:
class TQUATDE(nn.Module):
    def __init__(self, kg, embedding_dim, batch_size, learning_rate, gran, gamma, n_day, gpu=True):
        super(TQUATDE, self).__init__()
        self.gpu = gpu
        self.kg = kg
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.n_day = n_day
        self.gran = gran

        # Nets
        self.emb_E = torch.nn.Embedding(self.kg.n_entity, self.embedding_dim * 4)
        self.emb_R = torch.nn.Embedding(self.kg.n_relation*2, self.embedding_dim * 4)
        self.emb_R_trans = torch.nn.Embedding(self.kg.n_relation*2, self.embedding_dim * 4)
        self.emb_Time = torch.nn.Embedding(n_day, self.embedding_dim * 4)
        
        # Initialization
        self.init_weights()
        
        if self.gpu:
            self.cuda()
            
    def init_weights(self):
        r, i, j, k = quaternion_init(self.kg.n_entity, self.embedding_dim)
        r, i, j, k = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k)
        vec1 = torch.cat([r, i, j, k], dim=1)
        self.emb_E.weight.data = vec1.type_as(self.emb_E.weight.data)

        s, x, y, z = quaternion_init(self.kg.n_relation*2, self.embedding_dim)
        s, x, y, z = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z)
        vec2 = torch.cat([s, x, y, z], dim=1)
        self.emb_R.data = vec2.type_as(self.emb_R.weight.data)

        s, x, y, z = quaternion_init(self.kg.n_relation*2, self.embedding_dim)
        s, x, y, z = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z)
        vec2 = torch.cat([s, x, y, z], dim=1)
        self.emb_R_trans.data = vec2.type_as(self.emb_R_trans.weight.data)

        r, i, j, k = quaternion_init(self.n_day, self.embedding_dim)
        r, i, j, k = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k)
        vec1 = torch.cat([r, i, j, k], dim=1)
        self.emb_Time.weight.data = vec1.type_as(self.emb_Time.weight.data)
        
    def _calc(self, h, r):
        return hamilton_quat_prod(h, quat_norm(r))
    
    def _transfer(self, x, x_transfer, r_transfer):
        ent_transfer = self._calc(x, x_transfer)
        ent_rel_transfer = self._calc(ent_transfer, r_transfer)

        return ent_rel_transfer
    
    def forward(self, X):
        h_i, t_i, r_i, d_i = X[:, 0].astype(np.int64), X[:, 1].astype(np.int64), X[:, 2].astype(np.int64), X[:, 3].astype(np.int64)//self.gran

        if self.gpu:
            h_i = Variable(torch.from_numpy(h_i).cuda())
            t_i = Variable(torch.from_numpy(t_i).cuda())
            r_i = Variable(torch.from_numpy(r_i).cuda())
            d_i = Variable(torch.from_numpy(d_i).cuda())
        else:
            h_i = Variable(torch.from_numpy(h_i))
            t_i = Variable(torch.from_numpy(t_i))
            r_i = Variable(torch.from_numpy(r_i))
            d_i = Variable(torch.from_numpy(d_i))
        h = self.emb_E(h_i)
        r = self.emb_R(r_i)
        t = self.emb_E(t_i)
#         print('h.mean: ',h.mean().item())
#         print('r.mean: ',r.mean().item())
#         print('t.mean: ',t.mean().item())
        # (h, r, t) transfer vector
        time_transfer = self.emb_Time(d_i)
        r_transfer = self.emb_R_trans(r_i)

        h1 = self._transfer(h, time_transfer, r_transfer)
        t1 = self._transfer(t, time_transfer, r_transfer)
        # multiplication as QuatE
        hr = self._calc(h1, r)
        # Inner product as QuatE
        score = quat_dot_prod(hr, t1)
#         print('score.mean: ',score.mean().item())
        return score

    def normalize_embeddings(self):
        self.emb_E_real.weight.data.renorm_(p=2, dim=0, maxnorm=1)
        self.emb_E_img.weight.data.renorm_(p=2, dim=0, maxnorm=1)

    def log_rank_loss(self, y_pos, y_neg, temp=0):
#         print('y_pos.shape: ',y_pos.shape, '| y_neg.shape: ',y_neg.shape)
        M = y_pos.size(0)
        N = y_neg.size(0)
        y_pos = self.gamma-y_pos
        y_neg = self.gamma-y_neg
        C = int(N / M)
        y_neg = y_neg.view(C, -1).transpose(0, 1)
        p = F.softmax(temp * y_neg)
        loss_pos = torch.sum(F.softplus(-1 * y_pos))
        loss_neg = torch.sum(p * F.softplus(y_neg))
        loss = (loss_pos + loss_neg) / 2 / M
        if self.gpu:
            loss = loss.cuda()
        return loss


    def rank_loss(self, y_pos, y_neg):
        M = y_pos.size(0)
        N = y_neg.size(0)
        C = int(N / M)
        y_pos = y_pos.repeat(C)
        if self.gpu:
            target = Variable(torch.from_numpy(-np.ones(N, dtype=np.float32))).cuda()
        else:
            target = Variable(torch.from_numpy(-np.ones(N, dtype=np.float32))).cpu()
        loss = nn.MarginRankingLoss(margin=self.gamma)
        loss = loss(y_pos, y_neg, target)
        return loss



    def rank_left(self, X, facts, kg, timedisc, rev_set=0):
        rank = []
        with torch.no_grad():
            if timedisc:
                for triple, fact in zip(X, facts):
                    X_i = np.ones([self.kg.n_entity, 4])
                    Xe_i = np.ones([self.kg.n_entity, 4])
                    for i in range(0, self.kg.n_entity):
                        X_i[i, 0] = i
                        X_i[i, 1] = triple[1]
                        X_i[i, 2] = triple[2] if triple[3]>=0 else triple[2]+self.kg.n_relation
                        X_i[i, 3] = triple[3] if triple[3]>=0 else triple[4]
                        Xe_i[i, 0] = i
                        Xe_i[i, 1] = triple[1]
                        Xe_i[i, 2] = triple[2]+self.kg.n_relation if triple[4]>=0 else triple[2]
                        Xe_i[i, 3] = triple[4] if triple[4]>=0 else triple[3]
                    i_score = self.forward(X_i)+self.forward(Xe_i)
                    if rev_set>0:
                        X_rev = np.ones([self.kg.n_entity,4])
                        Xe_rev = np.ones([self.kg.n_entity,4])
                        for i in range(0, self.kg.n_entity):
                            X_rev[i, 0] = triple[1]
                            X_rev[i, 1] = i
                            X_rev[i, 2] = triple[2]+self.kg.n_relation//2 if triple[3]>=0 else triple[2]+self.kg.n_relation+self.kg.n_relation//2
                            X_rev[i, 3] = triple[3] if triple[3]>=0 else triple[4]
                            Xe_rev[i, 0] = triple[1]
                            Xe_rev[i, 1] = i
                            Xe_rev[i, 2] = triple[2]+self.kg.n_relation//2+self.kg.n_relation if triple[4]>=0 else triple[2]+self.kg.n_relation//2
                            Xe_rev[i, 3] = triple[4] if triple[4]>=0 else triple[3]
                        i_score = i_score + self.forward(X_rev).view(-1)+self.forward(Xe_rev).view(-1)
                    if self.gpu:
                        i_score = i_score.cuda()
        
                    filter_out = kg.to_skip_final['lhs'][(fact[1], fact[2],fact[3], fact[4])]                            
                    target = i_score[int(triple[0])].clone()
                    i_score[filter_out]=1e6 
                    rank_triple=torch.sum((i_score < target).float()).cpu().item()+1
                    rank.append(rank_triple)
                        

            else:
                for triple, fact in zip(X, facts):
                    X_i = np.ones([self.kg.n_entity, 4])
                    for i in range(0, self.kg.n_entity):
                        X_i[i, 0] = i
                        X_i[i, 1] = triple[1]
                        X_i[i, 2] = triple[2]
                        X_i[i, 3] = triple[3]
                    i_score = self.forward(X_i)
                    if rev_set>0:
                        X_rev = np.ones([self.kg.n_entity,4])
                        for i in range(0, self.kg.n_entity):
                            X_rev[i, 0] = triple[1]
                            X_rev[i, 1] = i
                            X_rev[i, 2] = triple[2]+self.kg.n_relation//2
                            X_rev[i, 3] = triple[3]
                        i_score = i_score + self.forward(X_rev).view(-1)
                    if self.gpu:
                        i_score = i_score.cuda()
        
                    filter_out = kg.to_skip_final['lhs'][(fact[1], fact[2],fact[3], fact[4])]                            
                    target = i_score[int(triple[0])].clone()
                    i_score[filter_out]=1e6 
                    rank_triple=torch.sum((i_score < target).float()).cpu().item()+1
                    rank.append(rank_triple)

        return rank

    def rank_right(self, X, facts, kg, timedisc, rev_set=0):
        rank = []
        with torch.no_grad():
            if timedisc:
                for triple, fact in zip(X, facts):
                    X_i = np.ones([self.kg.n_entity, 4])
                    Xe_i = np.ones([self.kg.n_entity, 4])
                    for i in range(0, self.kg.n_entity):
                        X_i[i, 0] = triple[0]
                        X_i[i, 1] = i
                        X_i[i, 2] = triple[2] if triple[3]>=0 else triple[2]+self.kg.n_relation
                        X_i[i, 3] = triple[3] if triple[3]>=0 else triple[4]
                        Xe_i[i, 0] = triple[0] 
                        Xe_i[i, 1] = i
                        Xe_i[i, 2] = triple[2]+self.kg.n_relation if triple[4]>=0 else triple[2]
                        Xe_i[i, 3] = triple[4] if triple[4]>=0 else triple[3]
                    i_score = self.forward(X_i)+self.forward(Xe_i)
                    if rev_set>0: 
                        X_rev = np.ones([self.kg.n_entity,4])
                        Xe_rev = np.ones([self.kg.n_entity,4])
                        for i in range(0, self.kg.n_entity):
                            X_rev[i, 0] = i
                            X_rev[i, 1] = triple[0]
                            X_rev[i, 2] = triple[2]+self.kg.n_relation//2 if triple[3]>=0 else triple[2]+self.kg.n_relation+self.kg.n_relation//2
                            X_rev[i, 3] = triple[3] if triple[3]>=0 else triple[4]
                            Xe_rev[i, 0] = i
                            Xe_rev[i, 1] = triple[0]
                            Xe_rev[i, 2] = triple[2]+self.kg.n_relation//2+self.kg.n_relation if triple[4]>=0 else triple[2]+self.kg.n_relation//2
                            Xe_rev[i, 3] = triple[4] if triple[4]>=0 else triple[3]
                        i_score = i_score + self.forward(X_rev).view(-1)+ self.forward(Xe_rev).view(-1)
                    if self.gpu:
                        i_score = i_score.cuda()
        
                    filter_out = kg.to_skip_final['rhs'][(fact[0], fact[2],fact[3], fact[4])]       
                    target = i_score[int(triple[1])].clone()
                    i_score[filter_out]=1e6
                    rank_triple=torch.sum((i_score < target).float()).cpu().item()+1
        
                    rank.append(rank_triple)
                    
            else:
                for triple, fact in zip(X, facts):
                    X_i = np.ones([self.kg.n_entity, 4])
                    for i in range(0, self.kg.n_entity):
                        X_i[i, 0] = triple[0]
                        X_i[i, 1] = i
                        X_i[i, 2] = triple[2]
                        X_i[i, 3] = triple[3]
                    i_score = self.forward(X_i)
                    if rev_set>0: 
                        X_rev = np.ones([self.kg.n_entity,4])
                        for i in range(0, self.kg.n_entity):
                            X_rev[i, 0] = i
                            X_rev[i, 1] = triple[0]
                            X_rev[i, 2] = triple[2]+self.kg.n_relation//2
                            X_rev[i, 3] = triple[3]
                        i_score = i_score + self.forward(X_rev).view(-1)
                    if self.gpu:
                        i_score = i_score.cuda()
        
                    filter_out = kg.to_skip_final['rhs'][(fact[0], fact[2],fact[3], fact[4])]       
                    target = i_score[int(triple[1])].clone()
                    i_score[filter_out]=1e6
                    rank_triple=torch.sum((i_score < target).float()).cpu().item()+1
        
                    rank.append(rank_triple)

        return rank

    def timepred(self, X):
        rank = []
        with torch.no_grad():
            for triple in X:
                X_i = np.ones([self.kg.n_day, len(triple)])
                for i in range(self.kg.n_day):
                    X_i[i, 0] = triple[0]
                    X_i[i, 1] = triple[1]
                    X_i[i, 2] = triple[2]
                    X_i[i, 3:] = self.kg.time_dict[i]
                i_score = self.forward(X_i)
                if self.gpu:
                    i_score = i_score.cuda()
    
                target = i_score[triple[3]]           
                rank_triple=torch.sum((i_score < target).float()).cpu().item()+1
                rank.append(rank_triple)

        return rank

## Train Loop

In [11]:

def mean_rank(rank):
    m_r = 0
    N = len(rank)
    for i in rank:
        m_r = m_r + i / N

    return m_r


def mrr(rank):
    mrr = 0
    N = len(rank)
    for i in rank:
        mrr = mrr + 1 / i / N

    return mrr


def hit_N(rank, N):
    hit = 0
    for i in rank:
        if i <= N:
            hit = hit + 1

    hit = hit / len(rank)

    return hit

def get_minibatches(X, mb_size, shuffle=True):
    """
    Generate minibatches from given dataset for training.

    Params:
    -------
    X: np.array of M x 3
        Contains the triplets from dataset. The entities and relations are
        translated to its unique indices.

    mb_size: int
        Size of each minibatch.

    shuffle: bool, default True
        Whether to shuffle the dataset before dividing it into minibatches.

    Returns:
    --------
    mb_iter: generator
        Example usage:
        --------------
        mb_iter = get_minibatches(X_train, mb_size)
        for X_mb in mb_iter:
            // do something with X_mb, the minibatch
    """
    X_shuff = X.copy()
    if shuffle:
        X_shuff = skshuffle(X_shuff)

    for i in range(0, X_shuff.shape[0], mb_size):
        yield X_shuff[i:i + mb_size]


def sample_negatives(X, C, kg):
    """
    Perform negative sampling by corrupting head or tail of each triplets in
    dataset.

    Params:
    -------
    X: int matrix of M x 3, where M is the (mini)batch size
        First column contains index of head entities.
        Second column contains index of relationships.
        Third column contains index of tail entities.

    n_e: int
        Number of entities in dataset.

    Returns:
    --------
    X_corr: int matrix of M x 3, where M is the (mini)batch size
        Similar to input param X, but at each column, either first or third col
        is subtituted with random entity.
        
    """
    M = X.shape[0]
    X_corr = X
    for i in range(C-1):
        X_corr = np.concatenate((X_corr,X),0)
    X_corr[:int(M*C/2),0]=torch.randint(kg.n_entity,[int(M*C/2)])        
    X_corr[int(M*C/2):,1]=torch.randint(kg.n_entity,[int(M*C/2)]) 

    return X_corr


def sample_negatives_t(X, C, n_day):
    """
    Perform negative sampling by corrupting head or tail of each triplets in
    dataset.

    Params:
    -------
    X: int matrix of M x 4, where M is the (mini)batch size
        First column contains index of head entities.
        Second column contains index of relationships.
        Third column contains index of tail entities.

    n_e: int
        Number of entities in dataset.

    Returns:
    --------
    X_corr: int matrix of M x 4, where M is the (mini)batch size
        Similar to input param X, but at each column, either first or third col
        is subtituted with random entity.
    """
    M = X.shape[0]
    X_corr = X
    for i in range(C-1):
        X_corr = torch.cat((X_corr,X),0)
    X_corr[:,3]=torch.randint(n_day,[int(M*C)])        


    return X_corr



def train(task ='LinkPrediction',
          modelname='ATISE',
          data_dir='yago',
          dim=500,
          batch=512,
          lr=0.1,
          max_epoch=5000,
          min_epoch=250,
          gamma=1,
          negsample_num=10,
          timedisc = 0,
          lossname = 'logloss',
          cmin = 0.001,
          cuda_able = True,
          rev_set = 1,
          temp = 0.5,
          gran = 7,
          count = 300
          ):

    randseed = 9999
    np.random.seed(randseed)
    torch.manual_seed(randseed)


    """
    Data Loading
    """
    if data_dir == 'yago' or data_dir == 'wikidata':
        kg = KnowledgeGraphYG(data_dir=data_dir, count = count,rev_set = rev_set)
        n_day = kg.n_time
        min_epoch=50
    elif data_dir=='icews14':
        n_day = 365
        kg = KnowledgeGraph(data_dir=data_dir,gran=gran,rev_set = rev_set)
    elif data_dir == 'icews05-15':
        n_day = 4017
        kg = KnowledgeGraph(data_dir=data_dir,gran=gran,rev_set = rev_set)      


    """
    Create a model
    """

    if modelname== 'TQUATDE':
        model = TQUATDE(kg, embedding_dim=dim, batch_size=batch, learning_rate=lr, gamma=gamma, gran=gran, n_day=kg.n_time,gpu=cuda_able)
    
    solver = torch.optim.Adam(model.parameters(), model.learning_rate)
    optimizer = 'Adam'
    
    if timedisc == 0 or timedisc ==2:
        train_pos = np.array(kg.training_triples)
        validation_pos = np.array(kg.validation_triples)
        test_pos = np.array(kg.test_triples)
        
    elif timedisc == 1:
        train_pos = []
        validation_pos = []
        test_pos = []
        for fact in kg.training_triples:
            for time_index in range(fact[3],fact[4]+1):
                train_pos.append([fact[0], fact[1], fact[2], time_index])
        train_pos = np.array(train_pos)
       # for fact in kg.validation_triples:
       #     for time_index in range(fact[3],fact[4]+1):
       #         validation_pos.append([fact[0], fact[1], fact[2], time_index])
        validation_pos = np.array(kg.validation_triples)
       # for fact in kg.test_triples:
       #     for time_index in range(fact[3],fact[4]+1):
       #         test_pos.append([fact[0], fact[1], fact[2], time_index])
       # test_pos = np.array(test_pos)        
        test_pos = np.array(kg.test_triples)

        
    losses = []
    mrr_std = 0
    C = negsample_num
    patience = 0
    path = os.path.join(data_dir,modelname,'timediscrete{:.0f}/dim{:.0f}/lr{:.4f}/neg_num{:.0f}/{:.0f}day/gamma{:.0f}/cmin{:.4f}'
                        .format(timedisc,dim,lr,negsample_num,gran,gamma,cmin))
    if timedisc: path = os.path.join(path,'{:.0f}count'.format(count))
    try: 
        os.makedirs(path)
    except:
        print('path existed')
        return
    
    
        
    """
    Training Process
    """
    for epoch in range(max_epoch):
        print('Epoch-{}'.format(epoch + 1))
        print('————————————————')
        it = 0
        train_triple = list(get_minibatches(train_pos, batch, shuffle=True))
        for iter_triple in train_triple:
            if iter_triple.shape[0] < batch:
                break
            start = time.time()
            if task=='TimePrediction':
                iter_neg = sample_negatives_t(iter_triple, C, n_day)
            else:
                iter_neg = sample_negatives(iter_triple, C, kg)
            if timedisc == 2:
                end_miss = np.where(iter_triple[:,4:5]<0)[0]
                start_miss = np.where(iter_triple[:,3:4]<0)[0]
                neg_end_miss = np.where(iter_neg[:,4:5]<0)[0]
                neg_start_miss = np.where(iter_neg[:,3:4]<0)[0]
                
                
                iter_triple_e = np.delete(iter_triple,3,1)
                iter_triple = np.delete(iter_triple,4,1)
                
                iter_triple_e[:,2:3] += kg.n_relation

                iter_triple_e[end_miss,:]=iter_triple[end_miss,:]
                iter_triple[start_miss,:]=iter_triple_e[start_miss,:]
                
                
                iter_neg_e = np.delete(iter_neg,3,1)
                iter_neg = np.delete(iter_neg,4,1)
                
                iter_neg_e[:,2:3] += kg.n_relation
                
                iter_neg_e[neg_end_miss,:]=iter_neg[neg_end_miss,:]
                iter_neg[neg_start_miss,:]=iter_neg_e[neg_start_miss,:]
                

            pos_score = model.forward(iter_triple)
            neg_score = model.forward(iter_neg)
#             print('pos_score.mean: ',pos_score.mean(), '| neg_score.mean: ',neg_score.mean())
            if timedisc ==2:
                pos_score += model.forward(iter_triple_e)
                neg_score += model.forward(iter_neg_e)
                
            if lossname == 'logloss':
                loss = model.log_rank_loss(pos_score, neg_score,temp=temp)
            else:
                loss = model.rank_loss(pos_score, neg_score)
            losses.append(loss.item())

            solver.zero_grad()
            loss.backward()
            solver.step()


#             if lossname == 'marginloss':
#                 model.normalize_embeddings()
#             if modelname == 'ATISE':
#                 model.regularization_embeddings()

            end = time.time()

            if it % 33 == 0:
                print('Iter-{}; loss: {:.4f};time per batch:{:.4f}s'.format(it, loss.item(), end - start))

            it += 1

        """
        Evaluation for Link Prediction
        """

        if ((epoch+1)//min_epoch>epoch//min_epoch and epoch < max_epoch) :
            if task == 'LinkPrediction':
                rank = model.rank_left(validation_pos,kg.validation_facts,kg,timedisc,rev_set=rev_set)
                rank_right = model.rank_right(validation_pos,kg.validation_facts,kg,timedisc,rev_set=rev_set)
                rank = rank + rank_right
            else:
                rank = model.timepred(validation_pos)

            m_rank = mean_rank(rank)
            mean_rr = mrr(rank)
            hit_1 = hit_N(rank, 1)
            hit_3 = hit_N(rank, 3)
            hit_5 = hit_N(rank, 5)
            hit_10 = hit_N(rank, 10)
            print('validation results:')
            print('Mean Rank: {:.0f}'.format(m_rank))
            print('Mean RR: {:.4f}'.format(mean_rr))
            print('Hit@1: {:.4f}'.format(hit_1))
            print('Hit@3: {:.4f}'.format(hit_3))
            print('Hit@5: {:.4f}'.format(hit_5))
            print('Hit@10: {:.4f}'.format(hit_10))
            f = open(os.path.join(path, 'result{:.0f}.txt'.format(epoch)), 'w')
            f.write('Mean Rank: {:.0f}\n'.format(m_rank))
            f.write('Mean RR: {:.4f}\n'.format(mean_rr))
            f.write('Hit@1: {:.4f}\n'.format(hit_1))
            f.write('Hit@3: {:.4f}\n'.format(hit_3))
            f.write('Hit@5: {:.4f}\n'.format(hit_5))
            f.write('Hit@10: {:.4f}\n'.format(hit_10))
            for loss in losses:
                f.write(str(loss))
                f.write('\n')
            f.close()
            if mean_rr < mrr_std and patience<3:
                patience+=1
            elif (mean_rr < mrr_std and patience>=3) or epoch==max_epoch-1:
                if epoch == max_epoch-1:
                    torch.save(model.state_dict(), os.path.join(path, 'params.pkl'))
                model.load_state_dict(torch.load(os.path.join(path,'params.pkl')))
                if task == 'LinkPrediction':
                    rank = model.rank_left(test_pos,kg.test_facts,kg,timedisc,rev_set=rev_set)
                    rank_right = model.rank_right(test_pos,kg.test_facts,kg,timedisc,rev_set=rev_set)
                    rank = rank + rank_right
                else:
                    rank = model.timepred(test_pos)


                m_rank = mean_rank(rank)
                mean_rr = mrr(rank)
                hit_1 = hit_N(rank, 1)
                hit_3 = hit_N(rank, 3)
                hit_5 = hit_N(rank, 5)
                hit_10 = hit_N(rank, 10)
                print('test result:')
                print('Mean Rank: {:.0f}'.format(m_rank))
                print('Mean RR: {:.4f}'.format(mean_rr))
                print('Hit@1: {:.4f}'.format(hit_1))
                print('Hit@3: {:.4f}'.format(hit_3))
                print('Hit@5: {:.4f}'.format(hit_5))
                print('Hit@10: {:.4f}'.format(hit_10))
                if epoch == max_epoch-1:
                    f = open(os.path.join(path, 'test_result{:.0f}.txt'.format(epoch)), 'w')
                else:
                    f = open(os.path.join(path, 'test_result{:.0f}.txt'.format(epoch)), 'w')
                f.write('Mean Rank: {:.0f}\n'.format(m_rank))
                f.write('Mean RR: {:.4f}\n'.format(mean_rr))
                f.write('Hit@1: {:.4f}\n'.format(hit_1))
                f.write('Hit@3: {:.4f}\n'.format(hit_3))
                f.write('Hit@5: {:.4f}\n'.format(hit_5))
                f.write('Hit@10: {:.4f}\n'.format(hit_10))
                for loss in losses:
                    f.write(str(loss))
                    f.write('\n')
                f.close()
                break
            if mean_rr>=mrr_std:
                
                torch.save(model.state_dict(), os.path.join(path, 'params.pkl'))
                mrr_std = mean_rr
                patience = 0




## Main

In [12]:
class Params():
    def __init__(self):
        pass

In [13]:
args = Params()
args.task = 'LinkPrediction'
args.model = 'TQUATDE'
args.dataset = 'icews14'
args.max_epoch = 5000
args.dim = 200
args.batch = 512
args.lr = 0.001
args.gamma = 10
# args.eta = 10
args.timedisc = 0
args.cuda = True
args.loss = 'logloss'
# rank_loss logloss
args.cmin = 0.005
args.gran = 1
args.thre = 1
args.negsample_num = 30
args.temp = 0.8


In [14]:
def main(args):
    print(args)
    train(task = args.task,
          modelname = args.model,
          data_dir = args.dataset,
          dim = args.dim,
          batch = args.batch,
          min_epoch = 10,
          negsample_num = args.negsample_num,
          temp = args.temp,
          lr = args.lr,
          max_epoch = args.max_epoch,
          gamma = args.gamma,
          lossname = args.loss,
#           negsample_num = args.eta,
          timedisc = args.timedisc,
          cuda_able = args.cuda,
          cmin = args.cmin,
          gran = args.gran,
          count = args.thre
          )

In [None]:
main(args)

<__main__.Params object at 0x7f8eabeec460>
-----Loading entity dict-----
#entity: 7129
-----Loading relation dict-----
#relation: 460
-----Loading training triples-----
#training triple: 145652
-----Loading validation triples-----
#validation triple: 8941
-----Loading test triples------
#test triple: 8963
creating filtering lists
data preprocess completed
Epoch-1
————————————————
Iter-0; loss: 5.0000;time per batch:0.0489s


  p = F.softmax(temp * y_neg)


Iter-33; loss: 5.0000;time per batch:0.1137s
Iter-66; loss: 4.9986;time per batch:0.0463s
Iter-99; loss: 4.9901;time per batch:0.1083s
Iter-132; loss: 4.9648;time per batch:0.0451s
Iter-165; loss: 4.8958;time per batch:0.1082s
Iter-198; loss: 4.7648;time per batch:0.0451s
Iter-231; loss: 4.5160;time per batch:0.1078s
Iter-264; loss: 4.2039;time per batch:0.0519s
Epoch-2
————————————————
Iter-0; loss: 3.9156;time per batch:0.0456s
Iter-33; loss: 3.3675;time per batch:0.1083s
Iter-66; loss: 2.6393;time per batch:0.0449s
Iter-99; loss: 2.0277;time per batch:0.1089s
Iter-132; loss: 1.4450;time per batch:0.0449s
Iter-165; loss: 1.1632;time per batch:0.1083s
Iter-198; loss: 0.9638;time per batch:0.0449s
Iter-231; loss: 0.9231;time per batch:0.1072s
Iter-264; loss: 0.7415;time per batch:0.0449s
Epoch-3
————————————————
Iter-0; loss: 0.6828;time per batch:0.0447s
Iter-33; loss: 0.6645;time per batch:0.1083s
Iter-66; loss: 0.6266;time per batch:0.0445s
Iter-99; loss: 0.6289;time per batch:0.108

Iter-231; loss: 0.0760;time per batch:0.1083s
Iter-264; loss: 0.0764;time per batch:0.0472s
Epoch-20
————————————————
Iter-0; loss: 0.0651;time per batch:0.0451s
Iter-33; loss: 0.0715;time per batch:0.1083s
Iter-66; loss: 0.1052;time per batch:0.0449s
Iter-99; loss: 0.0700;time per batch:0.1079s
Iter-132; loss: 0.0597;time per batch:0.0455s
Iter-165; loss: 0.0750;time per batch:0.1091s
Iter-198; loss: 0.0780;time per batch:0.0452s
Iter-231; loss: 0.0565;time per batch:0.1080s
Iter-264; loss: 0.0946;time per batch:0.0455s
validation results:
Mean Rank: 629
Mean RR: 0.4008
Hit@1: 0.3029
Hit@3: 0.4521
Hit@5: 0.5144
Hit@10: 0.5883
Epoch-21
————————————————
Iter-0; loss: 0.0508;time per batch:0.0454s
Iter-33; loss: 0.0775;time per batch:0.1091s
Iter-66; loss: 0.0815;time per batch:0.0450s
Iter-99; loss: 0.0556;time per batch:0.0518s
Iter-132; loss: 0.0795;time per batch:0.0494s
Iter-165; loss: 0.0887;time per batch:0.0526s
Iter-198; loss: 0.0730;time per batch:0.0521s
Iter-231; loss: 0.0584

Iter-33; loss: 0.0635;time per batch:0.0451s
Iter-66; loss: 0.0526;time per batch:0.1093s
Iter-99; loss: 0.0551;time per batch:0.0448s
Iter-132; loss: 0.0626;time per batch:0.1082s
Iter-165; loss: 0.0493;time per batch:0.0451s
Iter-198; loss: 0.0300;time per batch:0.1082s
Iter-231; loss: 0.0702;time per batch:0.0450s
Iter-264; loss: 0.0495;time per batch:0.1060s
Epoch-39
————————————————
Iter-0; loss: 0.0487;time per batch:0.1152s
Iter-33; loss: 0.0717;time per batch:0.0843s
Iter-66; loss: 0.0211;time per batch:0.1092s
Iter-99; loss: 0.0460;time per batch:0.0733s
Iter-132; loss: 0.0522;time per batch:0.1079s
Iter-165; loss: 0.0462;time per batch:0.0875s
Iter-198; loss: 0.0493;time per batch:0.1084s
Iter-231; loss: 0.0656;time per batch:0.0789s
Iter-264; loss: 0.0624;time per batch:0.1090s
Epoch-40
————————————————
Iter-0; loss: 0.0505;time per batch:0.0634s
Iter-33; loss: 0.0547;time per batch:0.1082s
Iter-66; loss: 0.0431;time per batch:0.0499s
Iter-99; loss: 0.0416;time per batch:0.1