In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_
from torch.optim import lr_scheduler

import numpy as np
import torch
import os
from tqdm import tqdm_notebook as tqdm
from random import choice
import random
np.random.seed(1)
cuda = torch.cuda.is_available()
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
class Data_processer():
    def __init__(self, train_set, valid_set, test_set, unseen_node, hop_num):
        self.train_set = 'Freebase13/train.txt' #train_set
        self.valid_set = 'Freebase13/dev.txt' #valid_set
        self.test_set = 'Freebase13/test.txt'   #test_set
        self.unseen_node = False
        self.hop_num = 1
        
    def data_preprocesser(self, data_path):
        triple_list = []
        entity_list = []
        head_entity_list = []
        tail_entity_list = []
        relation_list = []
        with open(data_path, 'r') as f:
            i = 0
            for line in f:
                line = line.strip('\n')
                triple = line.split('\t')
                triple_list.append(triple)
                entity_list.append(triple[0])
                head_entity_list.append(triple[0])
                relation_list.append(triple[1])
                entity_list.append(triple[2])
                tail_entity_list.append(triple[2])
        entity_set = set(entity_list)
        entity_list = list(entity_set)
        print('entity_len', len(entity_list))
        relation_set = set(relation_list)
        relation_list = list(relation_set)
        print('relation_len', len(relation_list))
        
        if self.unseen_node == True:
            pass  #这里是用于unseen_node 实验的数据处理部分
        return triple_list, entity_list, head_entity_list, tail_entity_list, relation_list
    
    def unseen_data_processer(self):
        pass     #这里是unseen_node 实验数据处理的具体操作
    
    def convert_triple_to_index(self, entity_emb_dict, relation_emb_dict, triple_list):
        triple_id_list = []
        for triple in triple_list:
            triple_id = []
            triple_id.append(entity_emb_dict[triple[0]])
            triple_id.append(relation_emb_dict[triple[1]])
            triple_id.append(entity_emb_dict[triple[2]])
            triple_id_list.append(triple_id)
        return triple_id_list
        
        
    def vacabulary_builder(self, train_entity_list, train_relation_list):
        train_entity_list += ['<UNK_TOKEN>','<PAD_TOKEN>']
        entity_emb_dict = dict([(e, idx) for idx, e in enumerate(train_entity_list)])
        print(len(entity_emb_dict))
        print('antoine_brutus_menier:',entity_emb_dict['antoine_brutus_menier'])
        print('roman_catholic_church:',entity_emb_dict['roman_catholic_church'])
        train_relation_list += ['<UNK_TOKEN>','<PAD_TOKEN>']
        relation_emb_dict = dict([(e, idx) for idx, e in enumerate(train_relation_list)])
        print(len(relation_emb_dict))
        print('religion', relation_emb_dict['religion'])
        return entity_emb_dict, relation_emb_dict
    
    def negative_triple_sampling(self, train_entity_list, positive_triple_list):
        negative_triple_list = []
        print('negative sampling processing:')
        pbar = tqdm(total=len(positive_triple_list))
        for triple in positive_triple_list:
            negative_entity = choice(train_entity_list)
            if np.random.randint(0, 2) == 0:
                triple[0] = negative_entity
            else:
                triple[2] = negative_entity
            negative_triple_list.append(triple)
            pbar.update(1)
        pbar.close()
        return negative_triple_list
    
    def entity_neighbors_dict_generator(self, train_triple_id_list):
        entity_neighbors_id_dict = {}
        entity_neighbors_triple_id_dict = {}
        for triple_id in train_triple_id_list:
            if triple_id[0] not in entity_neighbors_id_dict:
                entity_neighbors_id_dict[triple_id[0]] = [[triple_id[1], triple_id[2]]]
                entity_neighbors_triple_id_dict[triple_id[0]] = [triple_id]
            else:
                entity_neighbors_id_dict[triple_id[0]].append([triple_id[1], triple_id[2]])
                entity_neighbors_triple_id_dict[triple_id[0]].append(triple_id)
            if triple_id[2] not in entity_neighbors_id_dict:
                entity_neighbors_id_dict[triple_id[2]] = [[triple_id[1], triple_id[0]]]
                entity_neighbors_triple_id_dict[triple_id[2]] = [triple_id]
            else:
                entity_neighbors_id_dict[triple_id[2]].append([triple_id[1], triple_id[0]])
                entity_neighbors_triple_id_dict[triple_id[2]].append(triple_id)
        return entity_neighbors_id_dict, entity_neighbors_triple_id_dict
                        
    def data_info_generator(self):
        train_triple_list, train_entity_list, train_head_entity_list, train_tail_entity_list, train_relation_list = \
        self.data_preprocesser(self.train_set)
        print('train_entity_list', train_entity_list[:10])
        
        valid_triple_list, valid_entity_list, valid_head_entity_list, valid_tail_entity_list, valid_relation_list = \
        self.data_preprocesser(self.valid_set)
        
        print(valid_triple_list[0])
        
        test_triple_list, test_entity_list, test_head_entity_list, test_tail_entity_list, test_relation_list = \
        self.data_preprocesser(self.test_set)
        
        #         all_entity_list = train_entity_list + valid_entity_list + test_entity_list
#         all_entity_set = set(all_entity_list)
#         all_entity_list = list(all_entity_set)
#        print('all_entity_len', len(all_entity_list))
        entity_emb_dict, relation_emb_dict = self.vacabulary_builder(train_entity_list, train_relation_list)
        train_triple_id_list = self.convert_triple_to_index(entity_emb_dict, relation_emb_dict, train_triple_list)
        print('train_triple_list',train_triple_list[0])
        print('train_triple_id', train_triple_id_list[0])
#         print('train_triple', train_triple_list[0])
#         print(entity_emb_dict['/m/027rn'], relation_emb_dict['/location/country/form_of_government'], entity_emb_dict['/m/06cx9'])
        negative_triple_list = self.negative_triple_sampling(train_entity_list, train_triple_list)
        negative_triple_id_list = self.convert_triple_to_index(entity_emb_dict, relation_emb_dict, negative_triple_list)
        entity_neighbors_id_dict, entity_neighbors_triple_id_dict = self.entity_neighbors_dict_generator(train_triple_id_list)
        print(entity_neighbors_id_dict[train_triple_id_list[0][0]], entity_neighbors_triple_id_dict[train_triple_id_list[0][0]])
        data_info = {'train_triple_list': train_triple_list, 'train_entity_list': train_entity_list, \
                    'train_relation_list': train_relation_list, 'valid_triple_list': valid_triple_list,\
                    'valid_entity_list': valid_entity_list, 'valid_relation_list': valid_relation_list, \
                    'test_triple_list': test_triple_list, 'test_entity_list': test_entity_list,\
                    'test_relation_list': test_relation_list, 'entity_emb_dict': entity_emb_dict, \
                    'relation_emb_dict': relation_emb_dict, 'train_triple_id_list': train_triple_id_list,\
                    'negative_triple_list': negative_triple_list, 'negative_triple_id_list': negative_triple_id_list,\
                    'entity_neighbors_triple_id_dict': entity_neighbors_triple_id_dict,\
                     'entity_neighbors_id_dict': entity_neighbors_id_dict}
        return data_info

In [3]:
class Data_Loader():
    def __init__(self, data_info):
        self.data_info = data_info
        self.entity_emb_dict = data_info['entity_emb_dict']
        self.relation_emb_dict = data_info['relation_emb_dict']
        self.train_triple_id_list = data_info['train_triple_id_list']
        self.negative_triple_id_list = data_info['negative_triple_id_list']
        self.entity_neighbors_id_dict = data_info['entity_neighbors_id_dict'] 
        
        self.train_init()

    def train_init(self):
        self.train_triple_id_pair_list = []
        for i in range(len(self.train_triple_id_list)):
            self.train_triple_id_pair_list.append([{'train_triple_id_list': self.train_triple_id_list[i]}, \
                                              {'negative_triple_id_list': self.negative_triple_id_list[i]}])
        #print('train_triple_id_pair_list', len(self.train_triple_id_pair_list))
        random.shuffle(self.train_triple_id_pair_list)

    def data_iter_train(self, batch_size=300):
        train_triple_id_pair_list = self.train_triple_id_pair_list
        
        barch_num = int(len(train_triple_id_pair_list)/batch_size)
        
        for batch_i in range(batch_num+1):
            batch_start = batch_i * batch_size
            batch_end = min((batch_i+1)*batch_size, len(train_triple_id_pair_list))
            
            batch_train_triple_id_pair_list = train_triple_id_pair_list[batch_start: batch_end]
            batch_elem = {}
            batch_elem['batch_size'] = len(batch_train_triple_id_pair_list)
            batch_elem['batch_train_triple_id_pair_list'] = batch_train_triple_id_pair_list
            
            yield batch_elem

In [4]:
train_set = 'Freebase13/train.txt' #train_set
valid_set = 'Freebase13/dev.txt' #valid_set
test_set = 'Freebase13/test.txt'  
data = Data_processer(train_set, valid_set, test_set, unseen_node=False, hop_num=1)
data_info = data.data_info_generator()
data_loader = Data_Loader(data_info)


entity_len 75043
relation_len 13
train_entity_list ['robert_lewis_taylor', 'cappadocia', 'augustus_iii_of_poland', 'hermann_klaatsch', 'cipriano_ferrandini', 'george_frisbie_hoar', 'john_c_rice', 'miep_gies', 'thomas_bernard_hackett', 'langdon_cheves']
entity_len 7010
relation_len 7
['cornelie_van_zanten', 'gender', 'female', '1']
entity_len 19557
relation_len 7
75045
antoine_brutus_menier: 29388
roman_catholic_church: 63682
15
religion 4
train_triple_list ['antoine_brutus_menier', 'religion', 'roman_catholic_church']
train_triple_id [29388, 4, 63682]
negative sampling processing:


HBox(children=(IntProgress(value=0, max=316232), HTML(value='')))


[[4, 63682], [6, 27074], [1, 40192], [10, 40751], [8, 56435], [11, 40751]] [[29388, 4, 63682], [29388, 6, 27074], [29388, 1, 40192], [40751, 10, 29388], [29388, 8, 56435], [29388, 11, 40751]]


# Model: simple_version_LAN

In [5]:
class Transform_Embedding(nn.Module):
    def __init__(self, entity_embed_num, entity_embed_dim, relation_embed_num, relation_embed_dim, \
                 entity_neighbors_dict, entity_neighbors_triple_dict):
        super(Transform_Embedding, self).__init__()
        self.entity_embed_num = entity_embed_num
        self.entity_embed_dim = entity_embed_dim
        self.relation_embed_num = relation_embed_num
        self.relation_embed_dim = relation_embed_dim
        self.entity_neighbors_dict = entity_neighbors_dict
        self.entity_neighbors_triple_dict = entity_neighbors_triple_dict
        
        self.entity_embed = nn.Embedding(self.entity_embed_num, self.entity_embed_dim)
        self.relation_embed = nn.Embedding(self.relation_embed_num, self.relation_embed_dim)
        
    def get_unitized_divisor(self, relation_embedding):
        quadratic_sum = sum([float(x)*float(x) for x in relation_embedding])
        #print('quadratic_sum', quadratic_sum)
        return quadratic_sum
    
    def transform_module(self, entity_neighbors):
        transformed_neighbors_embedding_list = []
        neighbors_relation_embedding_list = []
        for neighbors_relation, neighbors_entity in entity_neighbors:
            #print(neighbors_entity, neighbors_relation)
            neighbors_entity = self.entity_embed(neighbors_entity).unsqueeze(0)
            neighbors_relation = self.relation_embed(neighbors_relation)
            neighbors_relation_embedding_list.append(neighbors_relation.unsqueeze(0))
            #print(neighbors_entity.size(), neighbors_relation.size())
            unitized_divisor = self.get_unitized_divisor(neighbors_relation)
            neighbors_relation = neighbors_relation.unsqueeze(0).permute(1, 0)
            #print(neighbors_entity.size(), neighbors_relation.size())
            projection_coefficient = torch.matmul(neighbors_entity, neighbors_relation) / unitized_divisor
            transformed_neighbors_embedding_list.append(neighbors_entity - torch.matmul(projection_coefficient, neighbors_entity))
        transformed_neighbors_embedding = torch.stack(transformed_neighbors_embedding_list).permute(1,0,2)
        neighbors_relation_embedding = torch.stack(neighbors_relation_embedding_list).permute(1,0,2)
        return transformed_neighbors_embedding, neighbors_relation_embedding
    
    def input_triple_embedding(self, neighbors_triple_list):
        neighbors_triple_embedding_list = []
        for neighbors_triple in neighbors_triple_list:
            head_entity_embedding = self.entity_embed(neighbors_triple[0])
            relation_embedding = self.relation_embed(neighbors_triple[1])
            tail_entity_embedding = self.entity_embed(neighbors_triple[2])
            neighbors_triple_embedding_list.append([head_entity_embedding, relation_embedding, tail_entity_embedding])
        return neighbors_triple_embedding_list
    
    def forward(self, triple):
        #head_entity_neighbors, use to get neighbors' relation and entity
        head_entity_neighbors = self.entity_neighbors_dict[triple[0]]
        print(head_entity_neighbors)
        head_entity_neighbors = torch.LongTensor(head_entity_neighbors).cuda()
        #tail_entity_beighbors, use to get neighbors' relation and entity
        tail_entity_neighbors = self.entity_neighbors_dict[triple[2]]
        print('tail_entity_neighbors', tail_entity_neighbors)
        tail_entity_neighbors = torch.LongTensor(tail_entity_neighbors).cuda()
        #neighbors_triple_list, use to get the input triple embeddings to calculate the input loss
        neighbors_triple_list = self.entity_neighbors_triple_dict[triple[0]] + self.entity_neighbors_triple_dict[triple[2]]
        neighbors_triple_list = torch.LongTensor(neighbors_triple_list).cuda()
        #get the transformed neighbors embeddings to feed the attention layer
        head_transformed_neighbors_embedding, head_neighbors_relation_embedding = self.transform_module(head_entity_neighbors)
        tail_transformed_neighbors_embedding, tail_neighbors_relation_embedding = self.transform_module(tail_entity_neighbors)
        #get the input triple embeddings to calculate the input loss
        neighbors_triple_embedding_list = self.input_triple_embedding(neighbors_triple_list)
        print(head_transformed_neighbors_embedding.size(),tail_transformed_neighbors_embedding.size())
        print(head_neighbors_relation_embedding.size(), tail_neighbors_relation_embedding.size())
        print(neighbors_triple_embedding_list[0][0].size())
        head_transformed_concate_embedding = torch.cat((head_transformed_neighbors_embedding, head_neighbors_relation_embedding), dim=2)
        print(head_transformed_concate_embedding.size())
        return head_transformed_neighbors_embedding, tail_transformed_neighbors_embedding, \
                head_neighbors_relation_embedding, tail_neighbors_relation_embedding
    
class NN_Attention(nn.Module):
    def __init__(self, output_dim):
        super(Attention, self).__init__()
        self.linear_output = nn.Linear(dim*2, dim)
        self.mask = None

    def set_mask(self, mask):
        self.mask = mask

        
    def forward(self, output, context):
        '''
        output: decoder,  (batch, 1, hiddem_dim2)
        context: from encoder, (batch, n, hidden_dim1)
        actually, dim2 == dim1, otherwise cannot do matrix multiplication 
        '''
        batch_size = output.size(0)
        hidden_size = output.size(2)
        input_size = context.size(1)
        # (b, o, dim) * (b, dim, i) -> (b, o, i)
        
        attn = torch.bmm(output, context.transpose(1,2))
        if self.mask is not None:
            attn.data.masked_fill_(self.mask, -float('inf'))
        attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

        # (b, o, i) * (b, i, dim) -> (b, o, dim)
        mix = torch.bmm(attn, context)

        combined = torch.cat((mix, output), dim=2)

        output = F.tanh(self.linear_out(combined.view(-1, 2*hidden_size)))\
                            .view(batch_size, -1, hidden_size)

        # output: (b, o, dim)
        # attn  : (b, o, i)
        return output, attn

In [6]:
Embedding = Transform_Embedding(75045, 100, 15, 100, data_info['entity_neighbors_id_dict'], data_info['entity_neighbors_triple_id_dict']).cuda()
Embedding([5867, 4, 63013])    #[47770, 1, 48576]

[[4, 18636], [4, 32459], [4, 31992], [4, 64462], [4, 7542], [4, 50170], [4, 67569], [4, 39453], [4, 3715], [4, 74906]]
tail_entity_neighbors [[6, 41028], [1, 35768], [1, 50908]]
torch.Size([1, 10, 100]) torch.Size([1, 3, 100])
torch.Size([1, 10, 100]) torch.Size([1, 3, 100])
torch.Size([100])
torch.Size([1, 10, 200])


(tensor([[[-1.2568, -0.1915, -1.1645,  1.1669, -0.0233, -0.4536,  0.1769,
           -0.5379,  0.9906, -0.3221,  0.0918, -0.6413,  0.6398,  0.0971,
           -0.4101, -0.6597, -0.3318,  1.6718,  0.5664,  0.4370, -0.1530,
           -0.0716, -0.1564,  0.6770,  1.0413,  0.9777, -1.2983,  0.6492,
            0.8654, -0.9530,  1.2200,  0.3792,  0.1065,  1.1964,  0.2683,
            0.5635, -0.3220, -0.1619, -1.1121,  0.2743, -0.4819, -0.9779,
            0.0340,  0.0650, -0.8013,  0.2553,  0.6784,  0.4926, -0.3126,
           -1.5158,  0.9968, -0.1886,  0.7535, -0.0924, -2.2433, -0.1958,
           -0.2301,  0.3421, -0.7785, -0.5358, -1.3350, -0.9280,  1.2613,
            0.5545, -0.4886, -0.3083,  0.9416, -0.4819,  0.7427,  0.4741,
           -0.9964, -0.5280, -0.0728, -0.3686,  1.5166, -0.2487, -0.2715,
           -1.2307,  1.6900,  0.8302, -0.4523, -0.9446, -0.2760, -0.6616,
           -1.4314,  0.5406, -0.8349,  0.1988,  0.6245, -0.5845, -0.1048,
           -1.3093,  0.5728,  1.3155, 

In [7]:
a = [-0.8513,  0.3108, -1.2641, -1.0224,  0.2763, -1.0591, -1.6643,  0.9950,
          2.1550, -0.1997, -1.8261, -0.7156, -0.1861,  1.9700, -1.2295, -1.9794,
          0.5312,  0.5054, -1.7210, -0.4392, -0.2408, -0.1616,  1.7982, -0.1468,
         -0.7754,  1.4710, -0.6014, -1.3528,  0.3130,  1.7130, -0.0720,  1.2369,
         -0.5501,  0.5647,  2.8201,  0.9985,  0.4545,  0.8498,  1.1995,  0.3889,
         -0.7076, -0.2447, -0.6196,  0.3353,  1.0386,  0.0815, -1.8523,  1.9071,
         -0.4660,  1.7228, -0.6703,  0.4635,  0.1990,  0.3829, -0.9589,  0.8642,
          2.0992,  0.5252,  2.1343, -1.4706, -0.3503,  0.6326, -0.0854, -0.7726,
          0.2166, -0.6305, -0.5669, -0.1626,  0.3553, -0.1221,  0.3710, -1.4171,
         -1.3543,  0.5404,  0.6815, -0.3347,  1.3084,  1.0757,  0.9373,  1.2821,
         -0.5696,  0.1902, -0.1564, -0.8229, -0.7222,  2.0384,  0.4083, -0.6134,
         -0.5457, -1.0211, -0.1539, -0.6772, -0.5714,  2.2151,  0.1445, -0.9518,
          0.8393,  0.6562,  0.7298,  0.1450]
print(type(a))
print(a[0])
qua_sum = sum([x*x for x in a])
print(qua_sum)

<class 'list'>
-0.8513
109.51043615000006


In [8]:
import torch
a=torch.randn(100)
b=torch.randn(100)
c = [a, b]
torch.stack(c).size()

torch.Size([2, 100])