In [1]:
class Config:
    def __init__(self, ent_num, rel_num):
        self.dim =  200
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.droprate =  0.4
        self.smooth = 0.1
        self.lr = 0.001
        self.epochs = 300
        self.batch_size = 128
        self.weight_decay = 0.0

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ProjE(nn.Module):
    def __init__(self, config):
        super(ProjE, self).__init__()
        self.E = nn.Embedding(config.ent_num, config.dim)
        self.R = nn.Embedding(config.rel_num, config.dim)
        k = 6.0/math.sqrt(config.dim)
        self.eh = nn.Parameter(torch.FloatTensor(config.dim, config.dim).uniform_(-k, k))
        self.rh = nn.Parameter(torch.FloatTensor(config.dim, config.dim).uniform_(-k, k))
        
        #self.eh = nn.Parameter(torch.FloatTensor(config.dim).uniform_(-k, k))
        #self.rh = nn.Parameter(torch.FloatTensor(config.dim).uniform_(-k, k))
        
        self.bc = nn.Parameter(torch.rand(config.dim))
        self.bp = nn.Parameter(torch.rand(1))
        
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(config.droprate)
        self.loss = nn.BCELoss()
    def forward(self, h, r):
        h_emb = self.E(h)
        r_emb = self.R(r)#batch*dim
        combination = torch.mm(h_emb, self.eh) + torch.mm(r_emb, self.rh) + self.bc
        #combination = self.eh*h_emb + self.rh*r_emb + self.bc
        
        #batch*1*dim*dim *n = batch *1*n;
        combination = combination.unsqueeze(1)
        out = torch.matmul(self.dropout(self.tanh(combination)), self.E.weight.transpose(0, 1)) + self.bp
        out = torch.sigmoid(out)
        out = out.squeeze(1)
        return out
        
        
        
        

In [16]:
from collections import defaultdict as ddict
class dataProcess:
    def __init__(self, file_name):
        self.dir = '/home/qiupp/code/ConvE/data/'+file_name+'/'
        self.ent2id = {}
        self.rel2id = {}
        self.data = {sql:self.read(sql) for sql in ['train', 'valid', 'test']}
        self.train_sr2o = {}
        self.all_sr2o = ddict(set)
        self.triples = ddict(list)
        self.sub_rel_map_obj()
        self.train_triples()
        self.valid_or_test_triples()
        
    def read(self, file_type):
        with open(self.dir+file_type+'.txt', 'r') as f:
            lines = f.readlines()
        triples = []
        for line in lines:
            temp = line.strip().split()
            triples.append((self.get_ent_id(temp[0]), self.get_rel_id(temp[1]), self.get_ent_id(temp[2])))
        return triples
    def get_ent_id(self, ent):
        if  not ent in self.ent2id.keys():
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]
        
    def get_rel_id(self, rel):
        if not rel  in self.rel2id.keys():
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]
    def ent_num(self):
        return len(self.ent2id)
    def rel_num(self):
        return len(self.rel2id)
    def sub_rel_map_obj(self):
        for fact in self.data['train']:
            h, r, t = fact
            self.all_sr2o[(h, r)].add(t)
            self.all_sr2o[(t, r+ self.rel_num())].add(h)
        
        self.train_sr2o = {k: list(v) for k, v in self.all_sr2o.items()}
        for sql in ['valid', 'test']:
            for fact in self.data[sql]:
                h, r, t = fact
                self.all_sr2o[(h, r)].add(t)
                self.all_sr2o[(t,r+self.rel_num())].add(h)
    def train_triples(self):
        for (h, r), t in self.train_sr2o.items():
            self.triples['train'].append({'triple':(h, r), 'labels':t})
    def valid_or_test_triples(self):
        for sql in ['valid', 'test']:
            for h, r, t in self.data[sql]:
                rev_rel = r + self.rel_num()
                self.triples[sql+'_tail'].append({'triple':(h, r, t), 'labels':self.all_sr2o[(h, r)]})
                self.triples[sql+'_head'].append({'triple':(t,rev_rel, h), 'labels':self.all_sr2o[(t, rev_rel)]})
        

In [17]:
from torch.utils.data import DataLoader, Dataset
class MyDataSet(Dataset):
    def __init__(self, dataPro, type, smooth):
        self.type = type
        self.data = dataPro.triples[type]
        self.ent_num = dataPro.ent_num()
        self.smooth = smooth
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        element = self.data[index]
        
        triple = torch.LongTensor(element['triple'])
        labels = self.get_label(element['labels'])
        return triple, labels
    def get_label(self, labels):
        temp_label = torch.zeros(self.ent_num)
        for label in labels:
            temp_label[label] = 1
        if self.smooth != 0.0:
            temp_label = (1-self.smooth)*temp_label+1.0/self.ent_num
        return torch.FloatTensor(temp_label)
        
        

In [18]:
class Measure:
    def __init__(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    def updata(self, rank):
        if rank == 1:
            self.hit1 += 1.0
        if rank <= 3:
            self.hit3 += 1.0
        if rank <= 10:
            self.hit10 += 1.0
        self.mr += rank
        self.mrr += 1.0 / rank
    def deal(self, facts_num):
        self.hit1 /=  facts_num
        self.hit3 /= facts_num
        self.hit10 /= facts_num
        self.mr /= facts_num
        self.mrr /= facts_num
    def print_(self):
        print('hit1: '+str(self.hit1))
        print('hit3: '+str(self.hit3))
        print('hit10: '+str(self.hit10))
        print('mr: '+str(self.mr))
        print('mrr: '+str(self.mrr))
    def init(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    
        
        

In [19]:
import numpy as np
from tqdm import trange
import os
class Trainer:
    def __init__(self, dataPro, config, device):
        self.device = device
        self.dataPro = dataPro
        self.mydataset = MyDataSet(dataPro, 'train', 0.1)
        self.measure = Measure()
        self.model = ProjE(config).to(device)
        self.config = config
        self.valid_data = DataLoader(MyDataSet(dataPro, 'valid_tail',0.00), batch_size = config.batch_size)
        self.valid_data_len = len(dataPro.triples['valid_tail'])
        print("valid_data_len"+str(self.valid_data_len))
        
    def MyTrain(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr,weight_decay=self.config.weight_decay )
        dataload = DataLoader(self.mydataset, batch_size = self.config.batch_size, shuffle = True)
        for epoch in trange(1,self.config.epochs+1):
            self.model.train()
            total_loss = 0.0
            for i,(triples, labels) in enumerate(dataload):
                triples = triples.to(self.device)
                labels = labels.to(self.device)
                e1 = triples[:, 0]
                rel = triples[:, 1]
                optimizer.zero_grad()
                #print(e1, rel)
                score = self.model(e1, rel)
#                 if i% 500 == 0:
#                     print(score[1])
                loss = self.model.loss(score, labels)
                total_loss += loss.cpu().item()
                loss.backward()
                optimizer.step()
            #total_loss_to_tensor = torch.FloatTensor(total_loss)
            #print(total_loss_to_tensor)
            print("损失值为："+str(total_loss))
            
            if epoch % 5 == 0:
                self.save_model(epoch)
                self.measure.init()
                self.model.eval()
                with torch.no_grad():
                    tot_num = 0
                    for i, (triples, labels) in enumerate(self.valid_data):
                        tot_num += e1.shape[0]
                        e1 = triples[:, 0]
                        rel = triples[:, 1]
                        e2 = triples[:, 2]
                        e1 = e1.to(self.device)
                        rel = rel.to(self.device)
                        pred = self.model(e1, rel)
                        pred = pred.cpu()
                        self.predict(e2, labels,pred)
                    self.measure.deal(tot_num)
                    #print(tot_num)
                    self.measure.print_()
                 
                
    
    def save_model(self, epoch):
        model_path = '/home/qiupp/codestore/ProjE/models/'
        if not os.path.exists(model_path):
            os.mkdir(model_path)
        torch.save(self.model, model_path+str(epoch)+'.pkl')
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)
        
        
        

In [20]:
dataPro = dataProcess('FB15k-237')
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
config = Config(dataPro.ent_num(), dataPro.rel_num()*2)
print(len(dataPro.data['valid']))
trainer = Trainer(dataPro,config, device)
trainer.MyTrain()



17535
valid_data_len17535


  0%|▎                                                                                        | 1/300 [00:14<1:13:03, 14.66s/it]

损失值为：3363.754267467186


  1%|▌                                                                                        | 2/300 [00:28<1:11:06, 14.32s/it]

损失值为：18.180644224397838


  1%|▉                                                                                        | 3/300 [00:43<1:12:49, 14.71s/it]

损失值为：14.833424688316882


  1%|█▏                                                                                       | 4/300 [00:59<1:13:19, 14.86s/it]

损失值为：14.50577223720029
损失值为：14.487500756047666


  2%|█▍                                                                                       | 5/300 [01:18<1:21:04, 16.49s/it]

hit1: 0.022101345548239337
hit3: 0.06618952190094475
hit10: 0.11852275980532494
mr: 4358.7343830518175
mrr: 0.05641023581444981


  2%|█▊                                                                                       | 6/300 [01:33<1:18:45, 16.07s/it]

损失值为：14.547336080577224


  2%|██                                                                                       | 7/300 [01:48<1:17:12, 15.81s/it]

损失值为：14.637384577654302


  3%|██▎                                                                                      | 8/300 [02:03<1:15:43, 15.56s/it]

损失值为：14.664096639491618


  3%|██▋                                                                                      | 9/300 [02:18<1:14:34, 15.38s/it]

损失值为：14.725623349659145
损失值为：14.723705558106303


  3%|██▉                                                                                     | 10/300 [02:38<1:20:12, 16.59s/it]

hit1: 0.07340395075866017
hit3: 0.09206985399370168
hit10: 0.13529916976810766
mr: 3886.636243916404
mrr: 0.09473498783804728


  4%|███▏                                                                                    | 11/300 [02:53<1:17:13, 16.03s/it]

损失值为：14.643996466882527


  4%|███▌                                                                                    | 12/300 [03:07<1:15:14, 15.68s/it]

损失值为：14.580760075710714


  4%|███▊                                                                                    | 13/300 [03:22<1:13:39, 15.40s/it]

损失值为：14.557184929028153


  5%|████                                                                                    | 14/300 [03:37<1:12:30, 15.21s/it]

损失值为：14.483465882018209
损失值为：14.421251560561359


  5%|████▍                                                                                   | 15/300 [03:56<1:17:45, 16.37s/it]

hit1: 0.13924992842828515
hit3: 0.17520755797308904
hit10: 0.23778986544517608
mr: 3617.2559977097053
mrr: 0.17111295773305307


  5%|████▋                                                                                   | 16/300 [04:11<1:15:30, 15.95s/it]

损失值为：14.342028421349823


  6%|████▉                                                                                   | 17/300 [04:26<1:13:52, 15.66s/it]

损失值为：14.255457245744765


  6%|█████▎                                                                                  | 18/300 [04:41<1:12:36, 15.45s/it]

损失值为：14.09822151158005


  6%|█████▌                                                                                  | 19/300 [04:56<1:11:39, 15.30s/it]

损失值为：13.921970370225608
损失值为：13.79276301059872


  7%|█████▊                                                                                  | 20/300 [05:15<1:17:08, 16.53s/it]

hit1: 0.1742341826510163
hit3: 0.21328371027769824
hit10: 0.29470369310048666
mr: 3275.182880045806
mrr: 0.2096692062711896


  7%|██████▏                                                                                 | 21/300 [05:32<1:16:38, 16.48s/it]

损失值为：13.576257367618382


  7%|██████▍                                                                                 | 22/300 [05:47<1:14:21, 16.05s/it]

损失值为：13.37153011187911


  8%|██████▋                                                                                 | 23/300 [06:02<1:12:40, 15.74s/it]

损失值为：13.21185643505305


  8%|███████                                                                                 | 24/300 [06:17<1:11:22, 15.52s/it]

损失值为：12.995912225916982
损失值为：12.77499738521874


  8%|███████▎                                                                                | 25/300 [06:36<1:16:19, 16.65s/it]

hit1: 0.20183223590037216
hit3: 0.25863154881190953
hit10: 0.3271686229602061
mr: 2875.884798167764
mrr: 0.24347979874548467


  9%|███████▋                                                                                | 26/300 [06:51<1:13:36, 16.12s/it]

损失值为：12.559710930101573


  9%|███████▉                                                                                | 27/300 [07:06<1:11:41, 15.76s/it]

损失值为：12.360089147929102


  9%|████████▏                                                                               | 28/300 [07:21<1:10:19, 15.51s/it]

损失值为：12.214122628327459


 10%|████████▌                                                                               | 29/300 [07:36<1:09:11, 15.32s/it]

损失值为：12.005327257327735
损失值为：11.80861711455509


 10%|████████▊                                                                               | 30/300 [07:55<1:14:18, 16.51s/it]

hit1: 0.21002004008016031
hit3: 0.2789006584597767
hit10: 0.34572001145147435
mr: 2605.831663326653
mrr: 0.2572996686854263


 10%|█████████                                                                               | 31/300 [08:10<1:11:48, 16.02s/it]

损失值为：11.627767070196569


 11%|█████████▍                                                                              | 32/300 [08:25<1:10:01, 15.68s/it]

损失值为：11.457520791795105


 11%|█████████▋                                                                              | 33/300 [08:39<1:08:40, 15.43s/it]

损失值为：11.26436769682914


 11%|█████████▉                                                                              | 34/300 [08:54<1:07:40, 15.27s/it]

损失值为：11.104870683513582
损失值为：10.954393574502319


 12%|██████████▎                                                                             | 35/300 [09:14<1:12:34, 16.43s/it]

hit1: 0.21855138849126826
hit3: 0.2934440309189808
hit10: 0.3699971371314057
mr: 2421.7163469796737
mrr: 0.2694564709635863


 12%|██████████▌                                                                             | 36/300 [09:29<1:10:30, 16.03s/it]

损失值为：10.816362923011184


 12%|██████████▊                                                                             | 37/300 [09:44<1:08:56, 15.73s/it]

损失值为：10.66775973699987


 13%|███████████▏                                                                            | 38/300 [09:59<1:07:50, 15.54s/it]

损失值为：10.511794088874012


 13%|███████████▍                                                                            | 39/300 [10:14<1:06:58, 15.40s/it]

损失值为：10.328255964443088
损失值为：10.187579988036305


 13%|███████████▋                                                                            | 40/300 [10:33<1:11:50, 16.58s/it]

hit1: 0.22513598625823075
hit3: 0.30088748926424275
hit10: 0.38179215574005154
mr: 2307.7054108216435
mrr: 0.27774610523980486


 14%|████████████                                                                            | 41/300 [10:48<1:08:56, 15.97s/it]

损失值为：10.035231268499047


 14%|████████████▎                                                                           | 42/300 [11:02<1:07:04, 15.60s/it]

损失值为：9.88606122136116


 14%|████████████▌                                                                           | 43/300 [11:17<1:05:36, 15.32s/it]

损失值为：9.74089361121878


 15%|████████████▉                                                                           | 44/300 [11:32<1:04:26, 15.10s/it]

损失值为：9.584168389905244
损失值为：9.424278124701232


 15%|█████████████▏                                                                          | 45/300 [11:51<1:09:03, 16.25s/it]

hit1: 0.23366733466933867
hit3: 0.3186372745490982
hit10: 0.40074434583452617
mr: 2213.408932150014
mrr: 0.2906562069229736


 15%|█████████████▍                                                                          | 46/300 [12:05<1:06:39, 15.75s/it]

损失值为：9.30918749468401


 16%|█████████████▊                                                                          | 47/300 [12:20<1:05:01, 15.42s/it]

损失值为：9.174394571222365


 16%|██████████████                                                                          | 48/300 [12:34<1:03:50, 15.20s/it]

损失值为：9.02630351902917


 16%|██████████████▎                                                                         | 49/300 [12:50<1:03:20, 15.14s/it]

损失值为：8.879098047036678
损失值为：8.761931137181818


 17%|██████████████▋                                                                         | 50/300 [13:08<1:07:52, 16.29s/it]

hit1: 0.24088176352705412
hit3: 0.3320354995705697
hit10: 0.42175780131691953
mr: 2137.607729745205
mrr: 0.3013638090625652


 17%|██████████████▉                                                                         | 51/300 [13:23<1:05:01, 15.67s/it]

损失值为：8.62396102072671


 17%|███████████████▎                                                                        | 52/300 [13:37<1:02:57, 15.23s/it]

损失值为：8.529761186335236


 18%|███████████████▌                                                                        | 53/300 [13:51<1:01:21, 14.90s/it]

损失值为：8.3879580097273


 18%|███████████████▊                                                                        | 54/300 [14:05<1:00:13, 14.69s/it]

损失值为：8.234971835277975
损失值为：8.10941005917266


 18%|████████████████▏                                                                       | 55/300 [14:24<1:04:39, 15.84s/it]

hit1: 0.24780990552533638
hit3: 0.3448038935012883
hit10: 0.43418265101631837
mr: 2080.8518179215575
mrr: 0.3112083410881758


 19%|████████████████▍                                                                       | 56/300 [14:38<1:02:24, 15.35s/it]

损失值为：7.992814378812909


 19%|████████████████▋                                                                       | 57/300 [14:52<1:01:08, 15.10s/it]

损失值为：7.8694629543460906


 19%|█████████████████                                                                       | 58/300 [15:08<1:00:52, 15.09s/it]

损失值为：7.759684414137155


 20%|█████████████████▎                                                                      | 59/300 [15:23<1:00:39, 15.10s/it]

损失值为：7.632360314019024
损失值为：7.513805011752993


 20%|█████████████████▌                                                                      | 60/300 [15:42<1:05:32, 16.38s/it]

hit1: 0.25324935585456626
hit3: 0.34738047523618665
hit10: 0.43985113083309474
mr: 2034.8316060692814
mrr: 0.31660170452177083


 20%|█████████████████▉                                                                      | 61/300 [15:57<1:03:35, 15.96s/it]

损失值为：7.399132426362485


 21%|██████████████████▏                                                                     | 62/300 [16:12<1:02:16, 15.70s/it]

损失值为：7.269829350989312


 21%|██████████████████▍                                                                     | 63/300 [16:27<1:01:16, 15.51s/it]

损失值为：7.167306074872613


 21%|██████████████████▊                                                                     | 64/300 [16:42<1:00:30, 15.39s/it]

损失值为：7.0537875718437135
损失值为：6.950381953269243


 22%|███████████████████                                                                     | 65/300 [17:02<1:04:56, 16.58s/it]

hit1: 0.25662754079587746
hit3: 0.35596908101918123
hit10: 0.4570855997709705
mr: 1968.7978242198683
mrr: 0.32407488023878334


 22%|███████████████████▎                                                                    | 66/300 [17:17<1:02:54, 16.13s/it]

损失值为：6.8434994420968


 22%|███████████████████▋                                                                    | 67/300 [17:32<1:01:35, 15.86s/it]

损失值为：6.734260741155595


 23%|███████████████████▉                                                                    | 68/300 [17:47<1:00:29, 15.65s/it]

损失值为：6.620194626506418


 23%|████████████████████▋                                                                     | 69/300 [18:02<59:42, 15.51s/it]

损失值为：6.520468291826546
损失值为：6.423873818013817


 23%|████████████████████▌                                                                   | 70/300 [18:22<1:04:08, 16.73s/it]

hit1: 0.2589750930432293
hit3: 0.3624391640423705
hit10: 0.4669911251073576
mr: 1895.4452905811622
mrr: 0.3286532982426389


 24%|████████████████████▊                                                                   | 71/300 [18:37<1:01:56, 16.23s/it]

损失值为：6.314680885989219


 24%|█████████████████████                                                                   | 72/300 [18:52<1:00:19, 15.87s/it]

损失值为：6.209348425734788


 24%|█████████████████████▉                                                                    | 73/300 [19:07<59:10, 15.64s/it]

损失值为：6.124221497215331


 25%|██████████████████████▏                                                                   | 74/300 [19:22<58:16, 15.47s/it]

损失值为：6.028537467587739
损失值为：5.921514432411641


 25%|██████████████████████                                                                  | 75/300 [19:42<1:02:29, 16.67s/it]

hit1: 0.26263956484397366
hit3: 0.3691955339249928
hit10: 0.4798167764099628
mr: 1799.9447466361294
mrr: 0.33580277600111025


 25%|██████████████████████▎                                                                 | 76/300 [19:57<1:00:27, 16.19s/it]

损失值为：5.84829613333568


 26%|███████████████████████                                                                   | 77/300 [20:12<58:52, 15.84s/it]

损失值为：5.756033339072019


 26%|███████████████████████▍                                                                  | 78/300 [20:27<57:39, 15.58s/it]

损失值为：5.654589647427201


 26%|███████████████████████▋                                                                  | 79/300 [20:42<56:44, 15.40s/it]

损失值为：5.571468282258138
损失值为：5.485224981093779


 27%|███████████████████████▍                                                                | 80/300 [21:01<1:00:50, 16.59s/it]

hit1: 0.27220154594904095
hit3: 0.3806470083023189
hit10: 0.4918980818780418
mr: 1718.3157744059547
mrr: 0.3464484102501714


 27%|████████████████████████▎                                                                 | 81/300 [21:16<58:46, 16.10s/it]

损失值为：5.4107673284597695


 27%|████████████████████████▌                                                                 | 82/300 [21:31<57:14, 15.76s/it]

损失值为：5.335961579112336


 28%|████████████████████████▉                                                                 | 83/300 [21:46<56:07, 15.52s/it]

损失值为：5.237748353742063


 28%|█████████████████████████▏                                                                | 84/300 [22:01<55:08, 15.32s/it]

损失值为：5.169337625382468
损失值为：5.105533191934228


 28%|█████████████████████████▌                                                                | 85/300 [22:20<59:12, 16.52s/it]

hit1: 0.2759232751216719
hit3: 0.38786143716003435
hit10: 0.5002576581734899
mr: 1631.7492699685085
mrr: 0.3521926630566415


 29%|█████████████████████████▊                                                                | 86/300 [22:35<57:17, 16.06s/it]

损失值为：5.017489206977189


 29%|██████████████████████████                                                                | 87/300 [22:50<55:51, 15.74s/it]

损失值为：4.939611198147759


 29%|██████████████████████████▍                                                               | 88/300 [23:05<54:41, 15.48s/it]

损失值为：4.873119342839345


 30%|██████████████████████████▋                                                               | 89/300 [23:20<53:48, 15.30s/it]

损失值为：4.81423976784572
损失值为：4.727946565952152


 30%|███████████████████████████                                                               | 90/300 [23:39<57:36, 16.46s/it]

hit1: 0.28164901231033496
hit3: 0.3928428285141712
hit10: 0.5095906097910106
mr: 1530.0678499856856
mrr: 0.35858601636808496


 30%|███████████████████████████▎                                                              | 91/300 [23:54<55:39, 15.98s/it]

损失值为：4.663845242466778


 30%|███████████████████████████▎                                                              | 91/300 [24:00<55:07, 15.83s/it]


KeyboardInterrupt: 

In [12]:
class Tester:
    def __init__(self,dataPro, model_path, device):
        self.model = torch.load(model_path, map_location = device)
        self.test_tail = MyDataSet(dataPro, 'test_tail', 0.0)
        self.test_head = MyDataSet(dataPro, 'test_head', 0.0)
        self.lens = len(dataPro.data['test'])
        self.test_tail_load = DataLoader(self.test_tail, batch_size = 128)
        self.test_head_load = DataLoader(self.test_head, batch_size = 128)
        self.measure = Measure()
        self.device = device
        self.model.eval()
    def test(self):
        with torch.no_grad():
            self.measure.init()
            for i, (triples, labels) in enumerate(self.test_tail_load):
                e1 = triples[:, 0]
                rel = triples[:,1]
                if i == 0:
                    print(labels)
                #print(e1)
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                #print(pred)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            print(self.measure.mrr)
            self.measure.deal(self.lens)
            self.measure.print_()
            self.measure.init()
        
            for i, (triples, labels) in enumerate(self.test_head_load):
            
                e1 = triples[:, 0]
                rel = triples[:,1]
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            self.measure.deal(self.lens)
            self.measure.print_()
        
            
            
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)    
            
    
    
    
  

In [14]:
model_path =  "/home/qiupp/codestore/ProjE/models/55.pkl"
tester = Tester(dataPro, model_path, device)
tester.test()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
7329.083649609044
hit1: 0.27284276360793513
hit3: 0.3943613798495065
hit10: 0.5242841786377407
mr: 296.7852535913222
mrr: 0.35811021448299835
hit1: 0.09581745333724226
hit3: 0.16783934330108471
hit10: 0.279585654255839
mr: 593.4205022964917
mrr: 0.1569592806785564
