In [15]:
import os
import time
from tqdm import tqdm, trange
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 
class Config:
    def __init__(self, ent_num, rel_num):
        self.dim =  100
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.inputDrop = 0.2
        self.hideDrop = 0.3
        self.featureDrop = 0.2
        #self.filter_num = 64
        #self.dim1 = 20
        #self.hide_size = 19456
        self.smooth = 0.1
        self.lr = 0.003
        self.weight_decay = 0.0
        self.epochs = 300
        self.batch_size = 128
        self.out_channels = 100
        self.kernal_size = [5, 5]
        self.reshape = [10, 10]

In [16]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter 
import torch.nn.functional as F
class ConvR(nn.Module):
    def __init__(self, config):
        super(ConvR, self).__init__()
        self.config = config
        self.out_channels = config.out_channels
        self.kernal_size = config.kernal_size
        self.relation_dim = self.out_channels*self.kernal_size[0]*self.kernal_size[1]
        self.reshape = config.reshape
        
        self.E = nn.Embedding(self.config.ent_num, self.config.dim)
        self.R = nn.Embedding(self.config.rel_num, self.relation_dim)
        self.init()
        
        self.input_drop = nn.Dropout(config.inputDrop)
        self.feature_map_drop = nn.Dropout2d(config.featureDrop)
        self.hidden_drop = nn.Dropout(config.hideDrop)
        
        
        self.bn0 =  nn.BatchNorm2d(1)
        self.bn1 = nn.BatchNorm2d(self.out_channels)
        self.bn2 = nn.BatchNorm1d(config.dim)
        self.bn3 = nn.BatchNorm1d(self.relation_dim)
        
        self.register_parameter('b', Parameter(torch.zeros(config.ent_num)))
        self.filter = [
            self.reshape[0] - self.kernal_size[0]+1, 
            self.reshape[1] - self.kernal_size[1]+1
        ]
        
        self.fc_length = self.filter[0]*self.filter[1]
        
        self.fc = nn.Linear(self.fc_length, self.config.dim)
        
        self.loss = nn.BCELoss()
        
    def init(self):
        nn.init.xavier_normal_(self.E.weight.data)
        nn.init.xavier_normal_(self.R.weight.data)
    
    def forward(self, batch_h, batch_r):
        batch_size = batch_h.shape[0]
        
        e1 = self.E(batch_h).view(-1, 1, *self.reshape)
        e1 = self.bn0(e1).view(1, -1, *self.reshape) #1*batch*10*10
        #print(e1.shape)
        e1 = self.input_drop(e1)
        
        r = self.R(batch_r)
        r = self.bn3(r)
        r = self.input_drop(r)
        r = r.view(-1, 1, *self.kernal_size)#batch*100*1*5*5
        #print(r.shape)
        
        x = F.conv2d(e1, r, groups=batch_size)
        #print(x.shape)
        x = x.view(batch_size, self.out_channels, *self.filter)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.sum(dim = 1)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        
        x = torch.mm(x, self.E.weight.transpose(1, 0))
        x += self.b.expand_as(x)
        y = torch.sigmoid(x)
        return y
        
        
        
        

In [17]:
from collections import defaultdict as ddict
class dataProcess:
    def __init__(self, file_name):
        self.dir = '/home/qiupp/code/ConvE/data/'+file_name+'/'
        self.ent2id = {}
        self.rel2id = {}
        self.data = {sql:self.read(sql) for sql in ['train', 'valid', 'test']}
        self.train_sr2o = {}
        self.all_sr2o = ddict(set)
        self.triples = ddict(list)
        self.sub_rel_map_obj()
        self.train_triples()
        self.valid_or_test_triples()
        
    def read(self, file_type):
        with open(self.dir+file_type+'.txt', 'r') as f:
            lines = f.readlines()
        triples = []
        for line in lines:
            temp = line.strip().split()
            triples.append((self.get_ent_id(temp[0]), self.get_rel_id(temp[1]), self.get_ent_id(temp[2])))
        return triples
    def get_ent_id(self, ent):
        if  not ent in self.ent2id.keys():
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]
        
    def get_rel_id(self, rel):
        if not rel  in self.rel2id.keys():
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]
    
    def ent_num(self):
        return len(self.ent2id)
    def rel_num(self):
        return len(self.rel2id)
    
    def sub_rel_map_obj(self):
        for fact in self.data['train']:
            h, r, t = fact
            self.all_sr2o[(h, r)].add(t)
            self.all_sr2o[(t, r)].add(h)
            #self.all_sr2o[(t, r+ self.rel_num())].add(h)
        #self.train_sr2o = {k: list(v) for k, v in self.all_sr2o.items()}
        for sql in ['valid', 'test']:
            for fact in self.data[sql]:
                h, r, t = fact
                self.all_sr2o[(h, r)].add(t)
                self.all_sr2o[(t,r)].add(h)
    def train_triples(self):
        for triple in self.data['train']:
            h, r, t = triple
            self.triples['train'].append({'triple':(h, r), 'labels':t})
    def valid_or_test_triples(self):
        for sql in ['valid', 'test']:
            for h, r, t in self.data[sql]:
                #rev_rel = r + self.rel_num()
                self.triples[sql+'_tail'].append({'triple':(h, r, t), 'labels':self.all_sr2o[(h, r)]})
                self.triples[sql+'_head'].append({'triple':(t,r, h), 'labels':self.all_sr2o[(t,r)]})
        

In [18]:
from torch.utils.data import DataLoader, Dataset
class MyDataSet(Dataset):
    def __init__(self, dataPro, type, smooth):
        self.type = type
        self.data = dataPro.triples[type]
        self.ent_num = dataPro.ent_num()
        self.smooth = smooth
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        element = self.data[index]
        
        triple = torch.LongTensor(element['triple'])
        labels = self.get_label(element['labels'])
        return triple, labels
    def get_label(self, labels):
        temp_label = torch.zeros(self.ent_num)
        #print(type(labels))
        if isinstance(labels,int):
            temp_label[labels] = 1
            
        else:
            for label in labels:
                temp_label[label] = 1
        if self.smooth != 0.0:
            temp_label = (1-self.smooth)*temp_label+1.0/self.ent_num
        return torch.FloatTensor(temp_label)
        
        

In [19]:
class Measure:
    def __init__(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    def updata(self, rank):
        if rank == 1:
            self.hit1 += 1.0
        if rank <= 3:
            self.hit3 += 1.0
        if rank <= 10:
            self.hit10 += 1.0
        self.mr += rank
        self.mrr += 1.0 / rank
    def deal(self, facts_num):
        self.hit1 /=  facts_num
        self.hit3 /= facts_num
        self.hit10 /= facts_num
        self.mr /= facts_num
        self.mrr /= facts_num
    def print_(self):
        print('hit1: '+str(self.hit1))
        print('hit3: '+str(self.hit3))
        print('hit10: '+str(self.hit10))
        print('mr: '+str(self.mr))
        print('mrr: '+str(self.mrr))
    def init(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    
        
        


In [20]:
import numpy as np
import os
class Trainer:
    def __init__(self, dataPro, config, device):
        self.device = device
        self.dataPro = dataPro
        self.mydataset = MyDataSet(dataPro, 'train', 0.1)
        self.measure = Measure()
        self.model = ConvR(config).to(device)
        self.config = config
        self.valid_data = DataLoader(MyDataSet(dataPro, 'valid_tail',0.00), batch_size = config.batch_size)
        self.valid_data_len = len(dataPro.triples['valid_tail'])
        print("valid_data_len"+str(self.valid_data_len))
        
    def MyTrain(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr,weight_decay=self.config.weight_decay )
        dataload = DataLoader(self.mydataset, batch_size = self.config.batch_size, shuffle = True)
        for epoch in trange(1,self.config.epochs+1):
            self.model.train()
            total_loss = 0.0
            for i,(triples, labels) in enumerate(dataload):
                triples = triples.to(self.device)
                labels = labels.to(self.device)
                e1 = triples[:, 0]
                rel = triples[:, 1]
                optimizer.zero_grad()
                #print(e1, rel)
                score = self.model(e1, rel)
#                 if i% 500 == 0:
#                     print(score[1])
                loss = self.model.loss(score, labels)
                total_loss += loss.cpu().item()
                loss.backward()
                optimizer.step()
            #total_loss_to_tensor = torch.FloatTensor(total_loss)
            #print(total_loss_to_tensor)
            print("损失值为："+str(total_loss))
            
            if epoch % 1 == 0:
                self.save_model(epoch)
                self.measure.init()
                self.model.eval()
                with torch.no_grad():
                    tot_num = 0
                    for i, (triples, labels) in enumerate(self.valid_data):
                        tot_num += e1.shape[0]
                        e1 = triples[:, 0]
                        rel = triples[:, 1]
                        e2 = triples[:, 2]
                        e1 = e1.to(self.device)
                        rel = rel.to(self.device)
                        pred = self.model(e1, rel)
                        pred = pred.cpu()
                        self.predict(e2, labels,pred)
                    self.measure.deal(tot_num)
                    print(tot_num)
                    self.measure.print_()
                 
                
    
    def save_model(self, epoch):
        model_path = '/home/qiupp/codestore/ConvR/model/'
        if not os.path.exists(model_path):
            os.mkdir(model_path)
        torch.save(self.model, model_path+str(epoch)+'.pkl')
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)
        
        
        

In [21]:
dataPro = dataProcess('FB15k-237')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = Config(dataPro.ent_num(), dataPro.rel_num())
print(len(dataPro.data['valid']))
trainer = Trainer(dataPro,config, device)
trainer.MyTrain()


17535
valid_data_len17535


  0%|                                                                                                   | 0/300 [00:00<?, ?it/s]

损失值为：12.809991379501298


  0%|▎                                                                                        | 1/300 [01:23<6:56:03, 83.49s/it]

17523
hit1: 0.04451292586885807
hit3: 0.08548764480967871
hit10: 0.1524282371739999
mr: 3008.529589682132
mrr: 0.08080080621164672


  0%|▎                                                                                        | 1/300 [01:28<7:21:50, 88.66s/it]


KeyboardInterrupt: 

In [None]:
class Tester:
    def __init__(self,dataPro, model_path, device):
        self.model = torch.load(model_path, map_location = device)
        self.test_tail = MyDataSet(dataPro, 'test_tail', 0.0)
        self.test_head = MyDataSet(dataPro, 'test_head', 0.0)
        self.lens = len(dataPro.data['test'])
        self.test_tail_load = DataLoader(self.test_tail, batch_size = 128)
        self.test_head_load = DataLoader(self.test_head, batch_size = 128)
        self.measure = Measure()
        self.device = device
        self.model.eval()
    def test(self):
        with torch.no_grad():
            self.measure.init()
            for i, (triples, labels) in enumerate(self.test_tail_load):
                e1 = triples[:, 0]
                rel = triples[:,1]
                if i == 0:
                    print(labels)
                #print(e1)
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                #print(pred)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            print(self.measure.mrr)
            self.measure.deal(self.lens)
            self.measure.print_()
            self.measure.init()
        
            for i, (triples, labels) in enumerate(self.test_head_load):
            
                e1 = triples[:, 0]
                rel = triples[:,1]
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            self.measure.deal(self.lens)
            self.measure.print_()
        
            
            
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)    
            
    
    
    
  

In [20]:
model_path =  "/home/qiupp/codestore/ConvR/model/17.pkl"
tester = Tester(dataPro, model_path, device)
tester.test()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
7771.4487747531275
hit1: 0.294439558291801
hit3: 0.41454119026678393
hit10: 0.5457832502687384
mr: 262.0080132903352
mrr: 0.37972484973874365
hit1: 0.002638522427440633
hit3: 0.010847258868367049
hit10: 0.01744356493696863
mr: 6332.126453630412
mrr: 0.008788045393018128
