In [1]:

import logging
import os
logger = logging.getLogger()
 
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
# Setup file handler
fhandler  = logging.FileHandler('my.log')
fhandler.setLevel(logging.DEBUG)
fhandler.setFormatter(formatter)
 
# Configure stream handler for the cells
chandler = logging.StreamHandler()
chandler.setLevel(logging.DEBUG)
chandler.setFormatter(formatter)
 
# Add both handlers
logger.addHandler(fhandler)
logger.addHandler(chandler)
logger.setLevel(logging.DEBUG)
 
# Show the handlers
logger.handlers
 
# Log Something
logger.info("Test info")
logger.debug("Test debug")
logger.error("Test error")


2022-05-05 03:21:15,794 - root - INFO - Test info
2022-05-05 03:21:15,795 - root - DEBUG - Test debug
2022-05-05 03:21:15,796 - root - ERROR - Test error


In [2]:
class Config:
    def __init__(self):
        self.data_path = "/home/qiupp/data/"
        self.file_name = "FB15K"
        self.ent_num = 10
        self.rel_num = 10
        self.smooth = 0.1
        self.dim = 200
        self.lr = 0.001
        self.batch_size = 1400
        self.eval_batch = 500
        self.drop_out = 0.5
        self.lambd = 0.00001
        self.epochs = 400
        self.model_path =  "/home/qiupp/codestore/CrossE/models"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    def init_num(self, ent_num, rel_num):
        self.ent_num = ent_num
        self.rel_num = rel_num

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict as ddict
class loadData:
    def __init__(self, config):
#         super(loadData, self).__init__()
        self.file_path = config.data_path+config.file_name+"/"
        self.entity2id = {}
        self.rel2id = {}
        
        logger.info("load "+config.file_name)
        self.data = {sql:self.read_file(sql) for sql in ['train', 'valid', 'test']}
        logger.info("load "+config.file_name+" is over")
        
        logger.info("数据开始处理")
        self.train_hr_tlist, self.train_tr_hlist = self.get_train_list()
        
        
        self.valid_hr_tlist,self.valid_tr_hlist,self.test_hr_tlist,self.test_tr_hlist = self.get_valid_test_list()
        logger.info("数据处理结束")
        
        logger.info("ent_num: "+str(self.get_ent_num()))
        logger.info("rel_num: "+str(self.get_rel_num()))
        
    
    def get_ent_num(self):
        return len(self.entity2id)
    
    def get_rel_num(self):
        return len(self.rel2id)
    
    def get_train_list(self):
        hr_tlist = ddict(set)
        tr_hlist = ddict(set)
        
        new_hr_tlist = []
        new_tr_hlist = []
        for (h, r, t) in self.data['train']:
            hr_tlist[(h, r)].add(t)
            tr_hlist[(t, r)].add(h)
        for (h, r, t) in self.data['train']:
            new_hr_tlist.append({"triple":(h, r), 'labels':hr_tlist[(h, r)]})
            new_tr_hlist.append({"triple":(t, r), 'labels':tr_hlist[(t, r)]})
        return new_hr_tlist, new_tr_hlist
    
    def get_valid_test_list(self):
        hr_tlist = ddict(set)
        tr_hlist = ddict(set)
        
        valid_hr_tlist = []
        valid_tr_hlist = []
        
        test_hr_tlist = []
        test_tr_hlist = []
        
        for sql in ['test', 'valid', 'train']:
            for (h, r, t) in self.data[sql]:
                hr_tlist[(h, r)].add(t)
                tr_hlist[(t, r)].add(h)
                
        for (h, r, t) in self.data['valid']:
            valid_hr_tlist.append({"triple":(h, r, t), 'labels':hr_tlist[(h, r)]})
            valid_tr_hlist.append({"triple":(t, r, h), 'labels':tr_hlist[(t, r)]})
        
        for (h, r, t) in self.data['test']:
            test_hr_tlist.append({"triple":(h, r, t), 'labels':hr_tlist[(h, r)]})
            test_tr_hlist.append({"triple":(t, r, h), 'labels':tr_hlist[(t, r)]})
        return valid_hr_tlist, valid_tr_hlist, test_hr_tlist, test_tr_hlist      
        
    def get_entity_id(self, entity):
        if not entity in self.entity2id.keys():
            self.entity2id[entity] = len(self.entity2id)
        return self.entity2id[entity]
    
    def get_rel_id(self, rel):
        if not rel in self.rel2id.keys():
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]
    
    def get_triple(self, line):
        parts = line.strip().split("\t")
        return tuple((self.get_entity_id(parts[0]), self.get_rel_id(parts[1]), self.get_entity_id(parts[2])))
    
    def read_file(self, text):
        triples = []
        with open(self.file_path+text+".txt", 'r')as f:
            lines = f.readlines()
        for line in lines:
            triples.append(self.get_triple(line))
        return triples
    
    
        
        
        
    
    

In [4]:
import numpy as np
class myTrainDataSet(Dataset):
    def __init__(self, myLoadData, config):
        super(myTrainDataSet, self).__init__()
        self.data_hr_tlist = myLoadData.train_hr_tlist
        self.data_tr_hlist = myLoadData.train_tr_hlist
        self.ent_num = config.ent_num
        self.smooth = config.smooth
    def __len__(self):
        return len(self.data_hr_tlist)
    
    def __getitem__(self, index):
        h_triple = torch.LongTensor(self.data_hr_tlist[index]['triple'])
        h_list = torch.Tensor(self.getlabel(self.data_hr_tlist[index]['labels']))
        
        t_triple = torch.LongTensor(self.data_tr_hlist[index]['triple'])
        t_list = torch.Tensor(self.getlabel(self.data_tr_hlist[index]['labels']))
        
        return h_triple,h_list, t_triple,t_list 
    
    def getlabel(self, labels_list):
        labels = np.zeros(self.ent_num)
        for label in labels_list:
            labels[label] = 1
        if self.smooth != 0.0:
            labels = (1 -  self.smooth)*labels + self.smooth/self.ent_num
        return labels

    

In [5]:
class myTestDateSet(Dataset):
    def __init__(self, myLoadData, config, file_type):
        self.ent_num = config.ent_num
        if file_type == 'test':
            self.hr_tlist =  myLoadData.test_hr_tlist
            self.tr_hlist = myLoadData.test_tr_hlist
        else:
            
            self.hr_tlist = myLoadData.valid_hr_tlist
            self.tr_hlist = myLoadData.valid_tr_hlist
    
    def __len__(self):
        return len(self.hr_tlist)
    
    
    def __getitem__(self, index):
        fact_h = self.hr_tlist[index]['triple']
        fact_t = self.tr_hlist[index]['triple']
        
        label_h = self.hr_tlist[index]['labels']
        label_t = self.tr_hlist[index]['labels']
        return torch.LongTensor(fact_h), torch.Tensor(self.getlabel(label_h)), torch.LongTensor(fact_t), torch.LongTensor(self.getlabel(label_t))
    def getlabel(self, labels_list):
        labels = np.zeros(self.ent_num)
        for label in labels_list:
            labels[label] = 1
        return labels


In [6]:
class Measure:
    def __init__(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    
    def update(self, rank):
        if rank == 1:
            self.hit1 += 1
        if rank <= 3:
            self.hit3 += 1
        if rank <= 10:
            self.hit10 += 1
        self.mr += rank
        self.mrr += 1.0/rank
    def deal(self, fact_num):
        self.hit1 /= fact_num
        self.hit3 /= fact_num
        self.hit10 /= fact_num
        self.mr /= fact_num
        self.mrr /= fact_num
        return self.mrr
    def init(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    def print_(self):
        logger.info("--------------------")
        logger.info("hit1: "+str(self.hit1))
        logger.info("hit3: "+str(self.hit3))
        logger.info("hit10: "+str(self.hit10))
        logger.info("mr: "+ str(self.mr))
        logger.info("mrr: "+str(self.mrr))

In [7]:
class CrossE(nn.Module):
    def __init__(self, config):
        super(CrossE, self).__init__()
        self.lambd = config.lambd
        
        self.ent_embedding  = nn.Embedding(config.ent_num, config.dim)
        
        self.rel_embedding = nn.Embedding(config.rel_num, config.dim)
        
        self.rel_rev_embedding = nn.Embedding(config.rel_num, config.dim)
        
        self.h_weight = nn.Embedding(config.rel_num, config.dim)
        
        self.t_weight = nn.Embedding(config.rel_num, config.dim)
        
        self.init()
        
        self.h_t_bias = nn.Parameter(torch.zeros([config.ent_num]), requires_grad= True)
        
        self.t_h_bias = nn.Parameter(torch.zeros([config.ent_num]), requires_grad= True)
        
        self.h_bias = nn.Parameter(torch.zeros([config.dim]), requires_grad= True)
        
        self.t_bias = nn.Parameter(torch.zeros([config.dim]), requires_grad= True)
        
        self.dropout = nn.Dropout(config.drop_out)
        self.loss = nn.BCELoss()
    def init(self):
        nn.init.xavier_normal_(self.ent_embedding.weight.data)
        nn.init.xavier_normal_(self.rel_embedding.weight.data)
        nn.init.xavier_normal_(self.rel_rev_embedding.weight.data)
        nn.init.xavier_normal_(self.h_weight.weight.data)
        nn.init.xavier_normal_(self.t_weight.weight.data)
    def regulation(self):
        return ((torch.norm(self.ent_embedding.weight, 2)**2)\
                +(torch.norm(self.rel_embedding.weight, 2)**2)\
                +(torch.norm(self.rel_rev_embedding.weight, 2)**2)\
                +(torch.norm(self.h_weight.weight, 2)**2)\
                +(torch.norm(self.t_weight.weight, 2)**2))/5
#         return (torch.sum(torch.abs(self.ent_embedding.weight))\
#                 +torch.sum(torch.abs(self.rel_embedding.weight))\
#                 +torch.sum(torch.abs(self.rel_rev_embedding.weight))\
#                 +torch.sum(torch.abs(self.h_weight.weight))\
#                 +torch.sum(torch.abs(self.t_weight.weight)))/5
    
    
    def forward(self,h_list, h_labels, t_list, t_labels):
        h_h = h_list[:, 0]
        h_r = h_list[:, 1]
        
        t_t = t_list[:, 0]
        t_r = t_list[:, 1]
        
        h_emb = self.ent_embedding(h_h)
        r_emb = self.rel_embedding(h_r)
        h_w = self.h_weight(h_r)
        
        h_res = h_emb*h_w +  r_emb*h_emb*h_w
        
        hrt = torch.mm(self.dropout(torch.tanh(h_res + self.h_bias)),self.ent_embedding.weight.transpose(0, 1))
        hrt += self.h_t_bias.expand_as(hrt)
#         print(hrt.shape)
        hrt = torch.sigmoid(hrt)
    
        hrt_loss = -torch.sum(torch.log(torch.clamp(hrt, 1e-10, 1.0)) * h_labels\
                + torch.log(torch.clamp(1 - hrt, 1e-10, 1.0)) *( 1-h_labels))
    
    
        
        t_emb = self.ent_embedding(t_t)
        r_rev = self.rel_rev_embedding(t_r)
        t_w = self.t_weight(t_r)
        
        t_res = t_emb*t_w + r_rev*t_emb*t_w
        trh = torch.mm(self.dropout(torch.tanh(t_res + self.t_bias)), self.ent_embedding.weight.transpose(0, 1))
        trh += self.t_h_bias.expand_as(trh)
        
        trh = torch.sigmoid(trh)
        
        trh_loss = -torch.sum(torch.log(torch.clamp(trh, 1e-10, 1.0)) *  t_labels\
                + torch.log(torch.clamp(1 - trh, 1e-10, 1.0)) * (1-t_labels))
            
#         print(self.regulation())
        
        return hrt_loss + trh_loss + self.lambd*self.regulation()
    
    def pred(self, h_list, h_labels, t_list, t_labels):
        h_h = h_list[:, 0]
        h_r = h_list[:, 1]
        
        t_t = t_list[:, 0]
        t_r = t_list[:, 1]
        
        h_emb = self.ent_embedding(h_h)
        r_emb = self.rel_embedding(h_r)
        h_w = self.h_weight(h_r)
        
        h_res = h_emb*h_w +  r_emb*h_emb*h_w
        
        hrt = torch.mm(torch.tanh(h_res + self.h_bias),self.ent_embedding.weight.transpose(0, 1))
        hrt += self.h_t_bias.expand_as(hrt)
#         print(hrt.shape)
        hrt = torch.sigmoid(hrt)
        
        t_emb = self.ent_embedding(t_t)
        r_rev = self.rel_rev_embedding(t_r)
        t_w = self.t_weight(t_r)
        
        t_res = t_emb*t_w + r_rev*t_emb*t_w
        trh = torch.mm(torch.tanh(t_res + self.t_bias), self.ent_embedding.weight.transpose(0, 1))
        trh += self.t_h_bias.expand_as(trh)
        
        trh = torch.sigmoid(trh)
        
        return hrt.detach(), trh.detach()
        
        
        

In [8]:
class Trainer:
    def __init__(self, myLoadData, config, model):
        self.model = model.to(config.device)
        self.device = config.device
        self.measure = Measure()
        self.train_loader = DataLoader(myTrainDataSet(myLoadData, config), batch_size = config.batch_size, shuffle = True)
        self.config = config
        testdata = myTestDateSet(myLoadData, config, 'valid')
        self.test_loader = DataLoader(testdata, batch_size = config.eval_batch, shuffle = True)
#         self.epochs = config.epochs
        self.fact_num = len(testdata)
    def train(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr = self.config.lr)
        for epoch in range(1, self.config.epochs+1):
            self.model.train()
            tot_loss = 0
            for i, (h, h_l, t,t_l) in enumerate(self.train_loader):
                optimizer.zero_grad()
                h = h.to(self.device)
                h_l = h_l.to(self.device)
                t = t.to(self.device)
                t_l = t_l.to(self.device)
                
                loss = self.model(h, h_l, t, t_l)
                loss.backward()
                tot_loss += loss.cpu().item()
                optimizer.step()
            logger.info(str(epoch)+"     loss: "+str(tot_loss))
            self.measure.init()
            self.eval_valid()
    def save_models(self):
        if not os.path.exists(self.config.model_path):
            os.mkdir(self.config.model_path)
        torch.save(self.model, self.config.model_path+"/"+self.config.file_name+"-best_model.pkl")
        
    def eval_valid(self):
        self.model.eval()
        best_acc = 0.0
        cn = 0
        for i, (h, h_l, t,t_l) in enumerate(self.test_loader):
            cn += h.shape[0]
            
            h = h.to(self.device)
            h_l = h_l.to(self.device)
            t = t.to(self.device)
            t_l = t_l.to(self.device)
            h_pred, t_pred= self.model.pred(h, h_l, t, t_l)
            
            h_target = h[:,-1]
            t_target = t[:, -1]
            self.pred(h_pred, h_l, h_target)
            self.pred(t_pred, t_l, t_target)
#         if self.fact_num == cn:
#             print("kjkjkj-----------")
#         else:
#             print("------------------"+str(cn))
        acc = self.measure.deal(self.fact_num*2)
        if acc > best_acc:
            self.save_models()
        self.measure.print_()
    
    def pred(self, pred, labels, t):
        batch = pred.shape[0]
        pred_size = torch.arange(batch)
        target = pred[pred_size,t]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[pred_size, t] = target
        pred = pred.cpu().numpy()
        t = t.cpu().numpy()
        for i in range(batch):
            e = t[i]
            one_pred = pred[i]
            aim = one_pred[e]
            sc = np.delete(one_pred, e)
            sc = np.insert(sc, 0,aim)
            rank = (sc[0] <= sc).sum()
            self.measure.update(rank)
                            
                        
        
            
                            
                            
                            
                            
    
        

In [9]:
config = Config()
myLoadData = loadData(config)
config.init_num(myLoadData.get_ent_num(), myLoadData.get_rel_num())
model = CrossE(config)  


trainer = Trainer(myLoadData, config, model)
trainer.train()

2022-05-05 03:21:17,998 - root - INFO - load FB15K
2022-05-05 03:21:19,864 - root - INFO - load FB15K is over
2022-05-05 03:21:19,865 - root - INFO - 数据开始处理
2022-05-05 03:21:27,848 - root - INFO - 数据处理结束
2022-05-05 03:21:27,849 - root - INFO - ent_num: 14951
2022-05-05 03:21:27,850 - root - INFO - rel_num: 1345
2022-05-05 03:24:34,629 - root - INFO - 1     loss: 2171157840.0
  pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
2022-05-05 03:24:57,825 - root - INFO - --------------------
2022-05-05 03:24:57,826 - root - INFO - hit1: 0.00453
2022-05-05 03:24:57,826 - root - INFO - hit3: 0.01992
2022-05-05 03:24:57,827 - root - INFO - hit10: 0.03627
2022-05-05 03:24:57,827 - root - INFO - mr: 6102.83665
2022-05-05 03:24:57,828 - root - INFO - mrr: 0.016292392055911563
2022-05-05 03:27:58,123 - root - INFO - 2     loss: 474990671.71875
2022-05-05 03:28:20,900 - root - INFO - --------------------
2022-05-05 03:28:20,900 - root - INFO - hit1: 0.01025
2022-05-05 03:28:20,901 - ro

2022-05-05 04:26:20,302 - root - INFO - hit1: 0.31274
2022-05-05 04:26:20,303 - root - INFO - hit3: 0.41
2022-05-05 04:26:20,304 - root - INFO - hit10: 0.51623
2022-05-05 04:26:20,304 - root - INFO - mr: 503.65031
2022-05-05 04:26:20,305 - root - INFO - mrr: 0.38251550919659544
2022-05-05 04:29:28,580 - root - INFO - 20     loss: 102704020.6953125
2022-05-05 04:29:52,430 - root - INFO - --------------------
2022-05-05 04:29:52,432 - root - INFO - hit1: 0.32327
2022-05-05 04:29:52,433 - root - INFO - hit3: 0.42155
2022-05-05 04:29:52,433 - root - INFO - hit10: 0.52992
2022-05-05 04:29:52,433 - root - INFO - mr: 451.36334
2022-05-05 04:29:52,434 - root - INFO - mrr: 0.3938698936538486
2022-05-05 04:33:00,144 - root - INFO - 21     loss: 101124022.23828125
2022-05-05 04:33:24,041 - root - INFO - --------------------
2022-05-05 04:33:24,043 - root - INFO - hit1: 0.33183
2022-05-05 04:33:24,044 - root - INFO - hit3: 0.43325
2022-05-05 04:33:24,045 - root - INFO - hit10: 0.54257
2022-05-05 0

2022-05-05 05:36:50,841 - root - INFO - --------------------
2022-05-05 05:36:50,843 - root - INFO - hit1: 0.40931
2022-05-05 05:36:50,844 - root - INFO - hit3: 0.52941
2022-05-05 05:36:50,845 - root - INFO - hit10: 0.642
2022-05-05 05:36:50,846 - root - INFO - mr: 197.74656
2022-05-05 05:36:50,847 - root - INFO - mrr: 0.4902660975924878
2022-05-05 05:39:58,629 - root - INFO - 40     loss: 88117521.11328125
2022-05-05 05:40:22,534 - root - INFO - --------------------
2022-05-05 05:40:22,535 - root - INFO - hit1: 0.41169
2022-05-05 05:40:22,536 - root - INFO - hit3: 0.53172
2022-05-05 05:40:22,536 - root - INFO - hit10: 0.64616
2022-05-05 05:40:22,537 - root - INFO - mr: 195.13601
2022-05-05 05:40:22,537 - root - INFO - mrr: 0.4929168621381835
2022-05-05 05:43:29,707 - root - INFO - 41     loss: 87808325.81835938
2022-05-05 05:43:53,971 - root - INFO - --------------------
2022-05-05 05:43:53,973 - root - INFO - hit1: 0.4134
2022-05-05 05:43:53,975 - root - INFO - hit3: 0.53457
2022-05-

2022-05-05 06:47:26,147 - root - INFO - --------------------
2022-05-05 06:47:26,148 - root - INFO - hit1: 0.43683
2022-05-05 06:47:26,149 - root - INFO - hit3: 0.56524
2022-05-05 06:47:26,150 - root - INFO - hit10: 0.67773
2022-05-05 06:47:26,151 - root - INFO - mr: 160.34693
2022-05-05 06:47:26,152 - root - INFO - mrr: 0.5213775612828307
2022-05-05 06:50:33,675 - root - INFO - 60     loss: 84113496.421875
2022-05-05 06:50:57,487 - root - INFO - --------------------
2022-05-05 06:50:57,488 - root - INFO - hit1: 0.43785
2022-05-05 06:50:57,489 - root - INFO - hit3: 0.56682
2022-05-05 06:50:57,489 - root - INFO - hit10: 0.67902
2022-05-05 06:50:57,490 - root - INFO - mr: 159.11039
2022-05-05 06:50:57,490 - root - INFO - mrr: 0.5226494098998804
2022-05-05 06:54:05,033 - root - INFO - 61     loss: 83972438.68164062
2022-05-05 06:54:28,843 - root - INFO - --------------------
2022-05-05 06:54:28,844 - root - INFO - hit1: 0.43816
2022-05-05 06:54:28,845 - root - INFO - hit3: 0.56764
2022-05

KeyboardInterrupt: 

In [None]:
model.load_state_dict(torch.load(config.model_path+"/"+config.file_name+"-best_model.pkl"))
class Tester:
    def __init__(self,myLoadData, config, model):
        self.config = config
        self.device = config.device 
        self.model = model.to(config.device)
        testdata = myTestDateSet(myLoadData, config, 'test')
        self.loader_test = DataLoader(testdata, batch_size = config.eval_batch, shuffle = True)
#         self.epochs = config.epochs
        self.fact_num = len(testdata)
    def eval_valid(self):
        self.model.eval()
        best_acc = 0.0
        for i, (h, h_l, t,t_l) in enumerate(self.test_loader):
            h = h.to(self.device)
            h_l = h_l.to(self.device)
            t = t.to(self.device)
            t_l = t_l.to(self.device)
            h_pred, t_pred= self.model.pred(h, h_l, t, t_l)
            
            h_target = h[:,-1]
            t_target = h[:, -1]
            self.pred(h_pred, h_l, h_target)
            self.pred(t_pred, t_l, t_target)
        acc = self.measure.deal(self.fact_num)
        self.measure.print_()
    
    def pred(self, pred, labels, t):
        batch = pred.shape[0]
        pred_size = torch.arange(batch)
        target = pred[pred_size,t]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[pred_size, t] = target
        pred = pred.cpu()numpy()
        t = t.cpu().numpy()
        for i in range(batch):
            e = t[i]
            one_pred = pred[i]
            aim = one_pred[e]
            sc = np.delete(one_pred, e)
            sc = np.insert(sc, 0, aim)
            rank = (sc[0] <= sc).sum()
            self.measure.update(rank)
                            
