In [1]:
import torch
import torch.nn as nn

class Config:
    """Hyper-parameters and runtime settings for ComplEx training.

    Only ``ent_num`` and ``rel_num`` are required; every other setting keeps
    the value the notebook originally hard-coded.

    Args:
        ent_num: number of distinct entities (rows of the entity tables).
        rel_num: number of distinct relations (rows of the relation tables).
        dim: size of each real / imaginary embedding vector.
        neg_ratio: corrupted triples generated per positive training triple.
        batch_size: positive triples per training batch.
        device: anything ``torch.device`` accepts; when ``None`` the device is
            chosen by ``_default_device``.
        lambd: weight of the regularization term in the training loss.
        lr: Adam learning rate.
        epochs: number of training epochs.
    """

    def __init__(self, ent_num, rel_num, dim=200, neg_ratio=100, batch_size=100,
                 device=None, lambd=0.00001, lr=0.001, epochs=30):
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.dim = dim
        self.neg_ratio = neg_ratio
        self.batch_size = batch_size
        self.device = torch.device(device) if device is not None else self._default_device()
        self.lambd = lambd
        self.lr = lr
        self.epochs = epochs

    @staticmethod
    def _default_device():
        """Prefer ``cuda:1``, fall back to ``cuda:0`` on single-GPU hosts, else CPU."""
        if not torch.cuda.is_available():
            return torch.device("cpu")
        # Unconditionally asking for cuda:1 fails with "invalid device ordinal"
        # at first use when only one GPU is present.
        return torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")

In [2]:
class ComplEX(nn.Module):
    """ComplEx knowledge-graph embedding model (Trouillon et al., 2016).

    Every entity and relation owns a real and an imaginary embedding of size
    ``config.dim``. The score of a triple (h, r, t) is the real part of the
    trilinear product ``sum_k h_k * r_k * conj(t_k)``.

    Args:
        config: object exposing ``ent_num``, ``rel_num`` and ``dim``.
    """

    def __init__(self, config):
        super(ComplEX, self).__init__()
        self.config = config
        self.ent_re_embedding = nn.Embedding(config.ent_num, config.dim)
        self.ent_im_embedding = nn.Embedding(config.ent_num, config.dim)

        self.rel_re_embedding = nn.Embedding(config.rel_num, config.dim)
        self.rel_im_embedding = nn.Embedding(config.rel_num, config.dim)
        self.init()
        # softplus(-y * score) is the logistic loss for labels y in {+1, -1};
        # the training loop applies it to -labels * scores.
        self.loss = nn.Softplus()

    def init(self):
        """Xavier-uniform initialise the four embedding tables in place."""
        # Same order as the tables are declared, so RNG consumption is unchanged.
        for table in (self.ent_re_embedding, self.ent_im_embedding,
                      self.rel_re_embedding, self.rel_im_embedding):
            nn.init.xavier_uniform_(table.weight.data)

    def _lookup(self, h, r, t):
        """Return ``(h_re, h_im, r_re, r_im, t_re, t_im)`` embeddings for index tensors h, r, t."""
        return (self.ent_re_embedding(h), self.ent_im_embedding(h),
                self.rel_re_embedding(r), self.rel_im_embedding(r),
                self.ent_re_embedding(t), self.ent_im_embedding(t))

    def _cal(self, h_re, h_im, r_re, r_im, t_re, t_im):
        """Re(<h, r, conj(t)>) summed over the last (embedding) dimension."""
        score = (h_re * t_re * r_re
                 + h_im * t_im * r_re
                 + h_re * t_im * r_im
                 - h_im * t_re * r_im)
        return torch.sum(score, -1)

    def forward(self, h, r, t):
        """Score triples given integer index tensors h, r, t.

        Returns a tensor with the shape of the (broadcast) index tensors;
        higher means more plausible.
        """
        return self._cal(*self._lookup(h, r, t))

    def regularization(self, h, r, t):
        """Average over the six looked-up tables of the mean squared embedding value."""
        parts = self._lookup(h, r, t)
        return sum(torch.mean(p ** 2) for p in parts) / len(parts)
        
    

In [3]:
class loadData:
    """Read train/valid/test triple files and map entity/relation names to ids.

    Each ``<split>.txt`` holds one whitespace-separated ``head relation tail``
    triple per line; blank lines are ignored. Ids are assigned densely in order
    of first appearance while reading train, then valid, then test.

    Args:
        data_name: dataset label; stored but not used to locate files.
        path: directory containing ``train.txt``, ``valid.txt`` and
            ``test.txt``. Defaults to ``DEFAULT_PATH`` (set this to your own location).
    """

    DEFAULT_PATH = '/home/qiupp/data/FB15K/'

    def __init__(self, data_name, path=None):
        self.data_name = data_name
        self.path = path if path is not None else self.DEFAULT_PATH
        self.rel2id = {}
        self.ent2id = {}
        self.data = {split: self.read(split) for split in ['train', 'valid', 'test']}

    def read(self, file_name):
        """Parse ``<path>/<file_name>.txt`` into a list of ``(head_id, rel_id, tail_id)`` tuples."""
        import os
        triples = []
        with open(os.path.join(self.path, file_name + '.txt'), 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A trailing/blank line used to crash with IndexError.
                    continue
                # Keep the head -> relation -> tail call order so ids match the original numbering.
                triples.append((self.get_ent(fields[0]), self.get_rel(fields[1]), self.get_ent(fields[2])))
        return triples

    def get_ent(self, ent):
        """Return the id of entity ``ent``, assigning the next free id on first sight."""
        return self.ent2id.setdefault(ent, len(self.ent2id))

    def get_rel(self, rel):
        """Return the id of relation ``rel``, assigning the next free id on first sight."""
        return self.rel2id.setdefault(rel, len(self.rel2id))

    def ent_num(self):
        """Number of distinct entities seen across all splits."""
        return len(self.ent2id)

    def rel_num(self):
        """Number of distinct relations seen across all splits."""
        return len(self.rel2id)
    

In [4]:
from torch.utils.data import DataLoader, Dataset
import numpy as np
from random import randint,random, shuffle
class MyTrainData(Dataset):
    """Training dataset yielding each positive triple with ``neg_ratio`` corruptions.

    Item ``index`` is ``(pos, neg)``:
      * ``pos``: int64 tensor ``[h, r, t, 1]``;
      * ``neg``: int64 tensor of shape ``(neg_ratio, 4)``; each row is the
        positive triple with either the head or the tail (probability 0.5
        each) replaced by a uniformly random *different* entity, labelled -1.
    Negatives are not filtered against known true triples.

    Args:
        loaddata: object whose ``data['train']`` is a list of ``(h, r, t)`` id tuples.
        config: object exposing ``ent_num`` and ``neg_ratio``.

    Raises:
        ValueError: if ``config.ent_num < 2`` (no different entity exists to sample).
    """

    def __init__(self, loaddata, config):
        super(MyTrainData, self).__init__()
        self.data = loaddata.data['train']
        self.config = config
        if config.ent_num < 2:
            # randValue would otherwise spin forever looking for a different id.
            raise ValueError('negative sampling needs at least 2 entities')

    def __len__(self):
        return len(self.data)

    def randValue(self, value):
        """Return a uniformly random entity id in ``[0, ent_num)`` that differs from ``value``."""
        temp = randint(0, self.config.ent_num - 1)
        while temp == value:
            temp = randint(0, self.config.ent_num - 1)
        return temp

    def __getitem__(self, index):
        h, r, t = self.data[index]
        # Build negatives directly as int64 instead of appending a float label
        # column and casting back.
        neg = np.empty((self.config.neg_ratio, 4), dtype=np.int64)
        for i in range(self.config.neg_ratio):
            if random() < 0.5:
                neg[i] = (self.randValue(h), r, t, -1)
            else:
                neg[i] = (h, r, self.randValue(t), -1)
        pos = np.array([h, r, t, 1], dtype=np.int64)
        return torch.from_numpy(pos), torch.from_numpy(neg)
        
        
        
        

In [5]:
class MyTestData(Dataset):
    """Filtered link-prediction evaluation dataset.

    Item ``index`` is ``(neg_head, neg_tail)``, two int64 tensors of shape
    ``(k, 3)``. Row 0 of each is the true triple ``(h, r, t)``. The remaining
    rows replace the head (resp. tail) with every other entity id, skipping
    any triple that occurs in the train, valid or test split ("filtered" setting).
    Candidates are listed in increasing entity-id order.

    Args:
        loaddata: object with ``data[split]`` lists of ``(h, r, t)`` tuples for
            'train', 'valid' and 'test', and an ``ent_num()`` method.
        data_type: which split to iterate over.
    """

    def __init__(self, loaddata, data_type):
        self.data = loaddata.data[data_type]
        self.ent_num = loaddata.ent_num()
        self.loaddata = loaddata
        self.all_facts = set(self.get_all_facts())

    def get_all_facts(self):
        """Return every triple of the train, valid and test splits as one list."""
        return [fact for split in ('train', 'valid', 'test') for fact in self.loaddata.data[split]]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        fact = self.data[index]
        h, r, t = fact
        known = self.all_facts
        # One pass per side: filter against known facts while generating,
        # instead of building full lists and differencing temporary sets.
        neg_head = [fact] + [(e, r, t) for e in range(self.ent_num)
                             if e != h and (e, r, t) not in known]
        neg_tail = [fact] + [(h, r, e) for e in range(self.ent_num)
                             if e != t and (h, r, e) not in known]
        return torch.LongTensor(neg_head), torch.LongTensor(neg_tail)

In [6]:
class Measure:
    """Accumulate link-prediction ranking metrics, kept separately for head and tail queries."""

    def __init__(self):
        self.init()

    def updata(self, rank, head_tail):
        """Record one 1-based ``rank`` under side ``head_tail`` ('head' or 'tail')."""
        if rank == 1:
            self.hit1[head_tail] += 1
        if rank <= 3:
            self.hit3[head_tail] += 1
        if rank <= 10:
            self.hit10[head_tail] += 1
        self.mr[head_tail] += rank
        self.mrr[head_tail] += 1.0 / rank

    def total_deal(self, fact_num):
        """Print and return Hits@1/3/10, MR and MRR averaged over ``fact_num`` queries.

        ``fact_num`` is the divisor applied to the head + tail sums, so it should
        equal the total number of ``updata`` calls (e.g. 2 x #triples when both
        sides are ranked).

        Returns:
            dict mapping 'hit1', 'hit3', 'hit10', 'mr', 'mrr' to their averages.
        """
        tables = (('hit1', self.hit1), ('hit3', self.hit3), ('hit10', self.hit10),
                  ('mr', self.mr), ('mrr', self.mrr))
        results = {name: (table['head'] + table['tail']) / fact_num for name, table in tables}
        print("---------result--------")
        for name, value in results.items():
            print(name + ':' + str(value))
        return results

    def init(self):
        """Reset every counter to zero for both sides."""
        self.mrr = {'head': 0.0, 'tail': 0.0}
        self.mr = {'head': 0.0, 'tail': 0.0}
        self.hit1 = {'head': 0.0, 'tail': 0.0}
        self.hit3 = {'head': 0.0, 'tail': 0.0}
        self.hit10 = {'head': 0.0, 'tail': 0.0}
        
        

In [7]:
import os
from tqdm import trange
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous
# (a debugging aid so device-side errors are reported at the failing call).
# It typically slows GPU training down; confirm it is still wanted for long runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
class Trainer:
    """Train a ComplEx model with negative sampling, checkpointing after every epoch.

    Args:
        config: ``Config`` with hyper-parameters and the target device.
        loaddata: ``loadData`` instance supplying the train/valid splits.
        model_path: checkpoint to resume from (a whole pickled model as written
            by ``save_mode``); ``None`` starts from a freshly initialised
            ``ComplEX``. Defaults to the checkpoint the notebook resumed from.
        save_dir: directory that receives ``<epoch>.pkl`` checkpoints.
    """

    DEFAULT_MODEL_PATH = "/home/qiupp/codestore/ComplEX/modelFB/122.pkl"
    DEFAULT_SAVE_DIR = '/home/qiupp/codestore/ComplEX/modelFB/'

    def __init__(self, config, loaddata, model_path=DEFAULT_MODEL_PATH, save_dir=DEFAULT_SAVE_DIR):
        self.config = config
        self.loaddata = loaddata
        self.save_dir = save_dir
        self.train_loader, self.valid_loader = self.init_data()
        if model_path is None:
            self.model = ComplEX(config).to(config.device)
        else:
            self.model = self._load_model(model_path)
        self.measure = Measure()
        self.fact_num = len(loaddata.data['valid'])

    def _load_model(self, model_path):
        """Unpickle a whole saved model onto ``config.device``.

        Checkpoints are full pickled modules, so torch >= 2.6 (default
        ``weights_only=True``) needs ``weights_only=False``. Unpickling runs
        arbitrary code: only load checkpoint files you trust.
        """
        try:
            return torch.load(model_path, map_location=self.config.device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` argument.
            return torch.load(model_path, map_location=self.config.device)

    def init_data(self):
        """Build the shuffled training loader and the filtered validation loader."""
        # Use self.config, not the notebook-global `config` the original relied on.
        train_loader = DataLoader(MyTrainData(self.loaddata, self.config),
                                  batch_size=self.config.batch_size, shuffle=True)
        valid_loader = DataLoader(MyTestData(self.loaddata, 'valid'), batch_size=1, shuffle=True)
        return train_loader, valid_loader

    def train(self, eval_every=0):
        """Run ``config.epochs`` epochs of Adam training.

        Each batch stacks the positives (label +1) on top of their negatives
        (label -1) and minimises ``sum(softplus(-label * score))`` plus
        ``lambd * regularization / rows``. A checkpoint is saved after every
        epoch. If ``eval_every > 0``, validation metrics are printed every
        ``eval_every`` epochs (before that epoch's checkpoint is written).
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        for epoch in trange(1, self.config.epochs + 1):
            self.model.train()
            total_loss = 0.0
            n_batches = 0
            for pos, neg in self.train_loader:
                # neg: (batch, neg_ratio, 4) -> (batch * neg_ratio, 4), stacked under pos.
                # Row order is irrelevant (the loss is a sum/mean over rows),
                # so no in-batch shuffle is needed.
                data = torch.cat([pos, neg.view(-1, neg.shape[-1])], dim=0).to(self.config.device)
                h, r, t, labels = data[:, 0], data[:, 1], data[:, 2], data[:, -1]
                optimizer.zero_grad()
                scores = self.model(h, r, t)
                loss = (torch.sum(self.model.loss(-labels * scores))
                        + self.config.lambd * self.model.regularization(h, r, t) / h.shape[0])
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
            print("------loss:" + str(total_loss / max(n_batches, 1)) + "-------")
            if eval_every and epoch % eval_every == 0:
                self.validate()
            self.save_mode(epoch)

    def validate(self):
        """Filtered head/tail ranking on the validation split; prints metrics via ``Measure``."""
        self.model.eval()
        self.measure.init()
        with torch.no_grad():
            for head, tail in self.valid_loader:
                self.measure.updata(self._rank(head), 'head')
                self.measure.updata(self._rank(tail), 'tail')
        # Each validation triple contributes one head and one tail query.
        self.measure.total_deal(self.fact_num * 2)

    def _rank(self, candidates):
        """1-based rank of row 0 (the true triple) among ``candidates``; ties count against it."""
        candidates = candidates.view(-1, 3).to(self.config.device)
        scores = self.model(candidates[:, 0], candidates[:, 1], candidates[:, 2])
        return int((scores >= scores[0]).sum().item())

    def save_mode(self, epoch):
        """Pickle the whole model to ``<save_dir>/<epoch>.pkl``, creating the directory if needed."""
        os.makedirs(self.save_dir, exist_ok=True)
        torch.save(self.model, os.path.join(self.save_dir, str(epoch) + '.pkl'))

In [8]:
# loadData reads the FB15K directory; the counts printed below (14951 entities,
# 1345 relations) are FB15K's, so label the dataset accordingly.
loaddata = loadData('FB15K')
config = Config(loaddata.ent_num(), loaddata.rel_num())
print(config.ent_num, config.rel_num)
# NOTE: Trainer resumes from the checkpoint path it is configured with and
# overwrites <epoch>.pkl files in its save directory.
trainer = Trainer(config, loaddata)
trainer.train()

14951 1345


  0%|                                                                                                        | 0/30 [00:13<?, ?it/s]


KeyboardInterrupt: 

In [15]:
class Tester:
    """Evaluate a saved ComplEx checkpoint on the test split (filtered setting).

    Args:
        model_path: whole pickled model, as written by the trainer's ``save_mode``.
        loaddata: ``loadData`` instance supplying the splits.
        config: ``Config`` whose ``device`` the model is mapped onto.
    """

    def __init__(self, model_path, loaddata, config):
        self.loaddata = loaddata
        self.config = config
        self.test_loader = self.loadTest()
        self.measure = Measure()
        self.fact_num = len(loaddata.data['test'])
        self.model = self._load_model(model_path)

    def _load_model(self, model_path):
        """Unpickle a whole saved model onto ``config.device``.

        torch >= 2.6 defaults to ``weights_only=True``, which cannot load full
        pickled modules. Unpickling runs arbitrary code: only load trusted files.
        """
        try:
            return torch.load(model_path, map_location=self.config.device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` argument.
            return torch.load(model_path, map_location=self.config.device)

    def loadTest(self):
        """DataLoader over the filtered test candidates, one triple per batch."""
        # Metrics are order-independent sums, so shuffling buys nothing.
        return DataLoader(MyTestData(self.loaddata, 'test'), batch_size=1, shuffle=False)

    def _rank(self, candidates):
        """1-based rank of row 0 (the true triple) among ``candidates``; ties count against it."""
        candidates = candidates.view(-1, 3).to(self.config.device)
        scores = self.model(candidates[:, 0], candidates[:, 1], candidates[:, 2])
        return int((scores >= scores[0]).sum().item())

    def test(self):
        """Rank every test triple against its head and tail corruptions and print the metrics."""
        self.model.eval()
        # Reset so repeated calls do not accumulate into earlier results.
        self.measure.init()
        # No autograd graph is needed for scoring thousands of candidates per query.
        with torch.no_grad():
            for head, tail in self.test_loader:
                self.measure.updata(self._rank(head), 'head')
                self.measure.updata(self._rank(tail), 'tail')
        # Each test triple contributes one head and one tail query.
        self.measure.total_deal(self.fact_num * 2)
            
            

In [16]:
print(config.device)
# Evaluate the checkpoint saved after epoch 28 (save_mode names files <epoch>.pkl).
checkpoint_path = "/home/qiupp/codestore/ComplEX/modelFB/28.pkl"
tester = Tester(checkpoint_path, loaddata, config)
tester.test()

cuda:1
---------result--------
hit1:0.5894093548441706
hit3:0.7245179529718474
hit10:0.8147314248954648
mr:147.18333869411387
mrr:0.671076330104978


In [None]:
# IPython shell escape: runs the `nvidia-smi` command to show GPU status.
!nvidia-smi