In [1]:
import os
import time
from tqdm import tqdm, trange
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 
class Config:
    def __init__(self, ent_num, rel_num):
        self.dim =  200
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.inputDrop = 0.2
        self.hideDrop = 0.3
        self.featureDrop = 0.3
        self.filter_num = 64
        self.dim1 = 20
        self.hide_size = 19456
        self.smooth = 0.1
        self.lr = 0.0001
        self.weight_decay = 0.0
        self.epochs = 300
        self.batch_size = 128

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import  Parameter
class ConvE(nn.Module):
    def __init__(self, config):
        super(ConvE, self).__init__()
        self.config = config
        self.ent_embs = nn.Embedding(self.config.ent_num, self.config.dim)
        self.rel_embs = nn.Embedding(self.config.rel_num, self.config.dim)
        self.input_drop = nn.Dropout(config.inputDrop)
        self.hide_drop = nn.Dropout(config.hideDrop)
        self.feature_drop = nn.Dropout2d(config.featureDrop)
        self.conv = nn.Conv2d(1, 64, (3, 3), bias=True)
        self.bn0 = nn.BatchNorm2d(1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm1d(config.dim)
        self.fc = nn.Linear(config.hide_size, config.dim)
        self.dim = config.dim #dim = 200
        self.dim1 = config.dim1  #dim1 = 20
        self.dim2 = self.dim // self.dim1 # dim2 = 10
        self.loss = nn.BCELoss()
        self.register_parameter('b',Parameter(torch.zeros(config.ent_num)))
        self.init()
        
    def init(self):
        nn.init.xavier_normal_(self.ent_embs.weight.data)
        nn.init.xavier_normal_(self.rel_embs.weight.data)
    def forward(self, e1, rel):
        e1_emb = self.ent_embs(e1).view(-1, 1, self.dim1, self.dim2)#el_emb; batch*1*20*10
        rel_emb = self.rel_embs(rel).view(-1, 1 ,self.dim1, self.dim2)
        conv_input = torch.cat([e1_emb, rel_emb], dim = 2)#con_input: bath*1*40*10
        conv_input = torch.transpose(conv_input, 2, 1).reshape((-1, 1, 2*self.dim1, self.dim2))
        conv_input = self.bn0(conv_input)
        x = self.input_drop(conv_input)
        x = self.conv(conv_input)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        x = x.view(x.shape[0], -1)#bacth*hide_size(38*8*32 = 9728)
        x = self.fc(x)
        x = self.hide_drop(x)
        x = self.bn2(x)
        x = F.relu(x)#batch*dim          ent_ems.weight   dim*ent_num
        #print(x.shape, self.ent_embs.weight.shape)
        x = torch.mm(x, self.ent_embs.weight.transpose(1, 0))
        x += self.b.expand_as(x)
        pred = torch.sigmoid(x)
        return pred
        
        
        
        

In [3]:
#config = Config(4, 4)
#model = ConvE(config)
#e1 = torch.LongTensor([1, 2, 1])
#rel = torch.LongTensor([1, 2, 3])
#print(model(e1, rel))

In [4]:
from collections import defaultdict as ddict
class dataProcess:
    def __init__(self, file_name):
        self.dir = '/home/qiupp/code/ConvE/data/'+file_name+'/'
        self.ent2id = {}
        self.rel2id = {}
        self.data = {sql:self.read(sql) for sql in ['train', 'valid', 'test']}
        self.train_sr2o = {}
        self.all_sr2o = ddict(set)
        self.triples = ddict(list)
        self.sub_rel_map_obj()
        self.train_triples()
        self.valid_or_test_triples()
        
    def read(self, file_type):
        with open(self.dir+file_type+'.txt', 'r') as f:
            lines = f.readlines()
        triples = []
        for line in lines:
            temp = line.strip().split()
            triples.append((self.get_ent_id(temp[0]), self.get_rel_id(temp[1]), self.get_ent_id(temp[2])))
        return triples
    def get_ent_id(self, ent):
        if  not ent in self.ent2id.keys():
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]
        
    def get_rel_id(self, rel):
        if not rel  in self.rel2id.keys():
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]
    def ent_num(self):
        return len(self.ent2id)
    def rel_num(self):
        return len(self.rel2id)
    def sub_rel_map_obj(self):
        for fact in self.data['train']:
            h, r, t = fact
            self.all_sr2o[(h, r)].add(t)
            self.all_sr2o[(t, r+ self.rel_num())].add(h)
        
        self.train_sr2o = {k: list(v) for k, v in self.all_sr2o.items()}
        for sql in ['valid', 'test']:
            for fact in self.data[sql]:
                h, r, t = fact
                self.all_sr2o[(h, r)].add(t)
                self.all_sr2o[(t,r+self.rel_num())].add(h)
    def train_triples(self):
        for (h, r), t in self.train_sr2o.items():
            self.triples['train'].append({'triple':(h, r), 'labels':t})
    def valid_or_test_triples(self):
        for sql in ['valid', 'test']:
            for h, r, t in self.data[sql]:
                rev_rel = r + self.rel_num()
                self.triples[sql+'_tail'].append({'triple':(h, r, t), 'labels':self.all_sr2o[(h, r)]})
                self.triples[sql+'_head'].append({'triple':(t,rev_rel, h), 'labels':self.all_sr2o[(t, rev_rel)]})
        

In [5]:
from torch.utils.data import DataLoader, Dataset
class MyDataSet(Dataset):
    def __init__(self, dataPro, type, smooth):
        self.type = type
        self.data = dataPro.triples[type]
        self.ent_num = dataPro.ent_num()
        self.smooth = smooth
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        element = self.data[index]
        
        triple = torch.LongTensor(element['triple'])
        labels = self.get_label(element['labels'])
        return triple, labels
    def get_label(self, labels):
        temp_label = torch.zeros(self.ent_num)
        for label in labels:
            temp_label[label] = 1
        if self.smooth != 0.0:
            temp_label = (1-self.smooth)*temp_label+1.0/self.ent_num
        return torch.FloatTensor(temp_label)
        
        

In [6]:
class Measure:
    def __init__(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    def updata(self, rank):
        if rank == 1:
            self.hit1 += 1.0
        if rank <= 3:
            self.hit3 += 1.0
        if rank <= 10:
            self.hit10 += 1.0
        self.mr += rank
        self.mrr += 1.0 / rank
    def deal(self, facts_num):
        self.hit1 /=  facts_num
        self.hit3 /= facts_num
        self.hit10 /= facts_num
        self.mr /= facts_num
        self.mrr /= facts_num
    def print_(self):
        print('hit1: '+str(self.hit1))
        print('hit3: '+str(self.hit3))
        print('hit10: '+str(self.hit10))
        print('mr: '+str(self.mr))
        print('mrr: '+str(self.mrr))
    def init(self):
        self.hit1 = 0.0
        self.hit3 = 0.0
        self.hit10 = 0.0
        self.mr = 0.0
        self.mrr = 0.0
    
        
        

In [7]:
import numpy as np
import os
class Trainer:
    def __init__(self, dataPro, config, device):
        self.device = device
        self.dataPro = dataPro
        self.mydataset = MyDataSet(dataPro, 'train', 0.1)
        self.measure = Measure()
        self.model = ConvE(config).to(device)
        self.config = config
        self.valid_data = DataLoader(MyDataSet(dataPro, 'valid_tail',0.00), batch_size = config.batch_size)
        self.valid_data_len = len(dataPro.triples['valid_tail'])
        print("valid_data_len"+str(self.valid_data_len))
        
    def MyTrain(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr,weight_decay=self.config.weight_decay )
        dataload = DataLoader(self.mydataset, batch_size = self.config.batch_size, shuffle = True)
        for epoch in trange(1,self.config.epochs+1):
            self.model.train()
            total_loss = 0.0
            for i,(triples, labels) in enumerate(dataload):
                triples = triples.to(self.device)
                labels = labels.to(self.device)
                e1 = triples[:, 0]
                rel = triples[:, 1]
                optimizer.zero_grad()
                #print(e1, rel)
                score = self.model(e1, rel)
#                 if i% 500 == 0:
#                     print(score[1])
                loss = self.model.loss(score, labels)
                total_loss += loss.cpu().item()
                loss.backward()
                optimizer.step()
            #total_loss_to_tensor = torch.FloatTensor(total_loss)
            #print(total_loss_to_tensor)
            print("损失值为："+str(total_loss))
            
            if epoch % 1 == 0:
                self.save_model(epoch)
                self.measure.init()
                self.model.eval()
                with torch.no_grad():
                    tot_num = 0
                    for i, (triples, labels) in enumerate(self.valid_data):
                        tot_num += e1.shape[0]
                        e1 = triples[:, 0]
                        rel = triples[:, 1]
                        e2 = triples[:, 2]
                        e1 = e1.to(self.device)
                        rel = rel.to(self.device)
                        pred = self.model(e1, rel)
                        pred = pred.cpu()
                        self.predict(e2, labels,pred)
                    self.measure.deal(tot_num)
                    #print(tot_num)
                    self.measure.print_()
                 
                
    
    def save_model(self, epoch):
        model_path = '/home/qiupp/code/ConvE/models/'
        if not os.path.exists(model_path):
            os.mkdir(model_path)
        torch.save(self.model, model_path+str(epoch)+'.pkl')
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)
        
        
        

In [8]:
dataPro = dataProcess('FB15k-237')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config = Config(dataPro.ent_num(), dataPro.rel_num()*2)
print(len(dataPro.data['valid']))
trainer = Trainer(dataPro,config, device)
trainer.MyTrain()



17535
valid_data_len17535


  0%|                                                                                                   | 0/300 [00:00<?, ?it/s]

损失值为：123.67273036297411


  pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
  0%|▎                                                                                        | 1/300 [00:24<2:00:25, 24.16s/it]

hit1: 0.09544803893501289
hit3: 0.10775837389063841
hit10: 0.14285714285714285
mr: 4815.245805897509
mrr: 0.11077693791927085
损失值为：9.453660823404789


  1%|▌                                                                                        | 2/300 [01:21<3:35:32, 43.40s/it]

hit1: 0.11898081878041798
hit3: 0.1372459204122531
hit10: 0.19181219582021186
mr: 3818.9705124534785
mrr: 0.14150349964244144
损失值为：4.570632785791531


  1%|▉                                                                                        | 3/300 [02:08<3:44:47, 45.41s/it]

hit1: 0.14056684798167765
hit3: 0.17411966790724306
hit10: 0.22404809619238478
mr: 3184.323790438019
mrr: 0.16977583821323627
损失值为：3.3803812444675714


  1%|█▏                                                                                       | 4/300 [03:01<3:59:00, 48.45s/it]

hit1: 0.16095047237331805
hit3: 0.20647008302318923
hit10: 0.2718007443458345
mr: 2730.4263956484397
mrr: 0.19691674393858533
损失值为：2.9656731536379084


  2%|█▍                                                                                       | 5/300 [03:52<4:01:24, 49.10s/it]

hit1: 0.17829945605496708
hit3: 0.23469796736329804
hit10: 0.3027769825365016
mr: 2394.3884912682506
mrr: 0.22047426723740635
损失值为：2.7915118895471096


  2%|█▊                                                                                       | 6/300 [04:44<4:05:17, 50.06s/it]

hit1: 0.18202118522759805
hit3: 0.25078728886344115
hit10: 0.3342685370741483
mr: 2135.344288577154
mrr: 0.23153853478050726
损失值为：2.6883374421158805


  2%|██                                                                                       | 7/300 [05:17<3:38:21, 44.71s/it]

hit1: 0.19198396793587175
hit3: 0.2645863154881191
hit10: 0.36364156885198967
mr: 1906.8139707987402
mrr: 0.2468623920675806
损失值为：2.5878592174267396


  3%|██▎                                                                                      | 8/300 [05:41<3:05:30, 38.12s/it]

hit1: 0.20876037789865445
hit3: 0.28674491840824506
hit10: 0.37933008874892643
mr: 1634.722244488978
mrr: 0.26550334084545985
损失值为：2.4803286376409233


  3%|██▋                                                                                      | 9/300 [06:33<3:26:04, 42.49s/it]

hit1: 0.21780704265674206
hit3: 0.30329229888348125
hit10: 0.3968508445462353
mr: 1309.154995705697
mrr: 0.2789681252863997
损失值为：2.3817509156651795


  3%|██▉                                                                                     | 10/300 [07:28<3:43:34, 46.26s/it]

hit1: 0.22479244202691095
hit3: 0.3180647008302319
hit10: 0.4176925279129688
mr: 1016.575780131692
mrr: 0.2908952955489617
损失值为：2.3070615048054606


  4%|███▏                                                                                    | 11/300 [08:16<3:45:42, 46.86s/it]

hit1: 0.23527054108216433
hit3: 0.32888634411680506
hit10: 0.4349842542227312
mr: 755.8280561122244
mrr: 0.30245683169409243
损失值为：2.235534643754363


  4%|███▌                                                                                    | 12/300 [09:03<3:45:07, 46.90s/it]

hit1: 0.24151159461780705
hit3: 0.3398797595190381
hit10: 0.4506155167477813
mr: 595.4436873747495
mrr: 0.312869181203836
损失值为：2.187629775609821


  4%|███▊                                                                                    | 13/300 [09:59<3:56:38, 49.47s/it]

hit1: 0.24660750071571716
hit3: 0.3504151159461781
hit10: 0.4607500715717149
mr: 497.64689378757515
mrr: 0.3200907369001833
损失值为：2.1470480238785967


  5%|████                                                                                    | 14/300 [10:49<3:57:28, 49.82s/it]

hit1: 0.2533066132264529
hit3: 0.3566561694818208
hit10: 0.46973947895791585
mr: 422.5898654451761
mrr: 0.32730963180097816
损失值为：2.1115457486594096


  5%|████▍                                                                                   | 15/300 [11:13<3:19:00, 41.90s/it]

hit1: 0.2606928141998282
hit3: 0.36432865731462927
hit10: 0.4807901517320355
mr: 377.80933295161753
mrr: 0.3347507743740598
损失值为：2.072384605417028


  5%|████▋                                                                                   | 16/300 [11:38<2:53:57, 36.75s/it]

hit1: 0.2645863154881191
hit3: 0.36993987975951903
hit10: 0.48628685943315203
mr: 338.7569997137131
mrr: 0.33980512528180307
损失值为：2.048118746140972


  6%|████▉                                                                                   | 17/300 [12:02<2:35:57, 33.07s/it]

hit1: 0.26618952190094475
hit3: 0.3727454909819639
hit10: 0.4913255081591755
mr: 313.34749498997996
mrr: 0.3424034444193058
损失值为：2.0267560386564583


  6%|█████▎                                                                                  | 18/300 [12:26<2:22:48, 30.39s/it]

hit1: 0.2675636988262239
hit3: 0.38013169195533925
hit10: 0.4960778700257658
mr: 296.32527912968794
mrr: 0.3459205616890691
损失值为：1.997544475714676


  6%|█████▌                                                                                  | 19/300 [12:50<2:13:06, 28.42s/it]

hit1: 0.2749498997995992
hit3: 0.38614371600343544
hit10: 0.5046092184368738
mr: 281.2155167477813
mrr: 0.3528271884912547
损失值为：1.9790663719177246


  7%|█████▊                                                                                  | 20/300 [13:14<2:06:03, 27.01s/it]

hit1: 0.27827082736902375
hit3: 0.38946464357286
hit10: 0.5079301460062983
mr: 268.4705983395362
mrr: 0.35665735894877193
损失值为：1.9634904003469273


  7%|██████▏                                                                                 | 21/300 [13:38<2:01:31, 26.13s/it]

hit1: 0.28067563698826226
hit3: 0.39541941024906957
hit10: 0.5118809046664758
mr: 254.5580303464071
mrr: 0.3602322763248881
损失值为：1.9430071823298931


  7%|██████▍                                                                                 | 22/300 [14:02<1:57:48, 25.43s/it]

hit1: 0.2839965645576868
hit3: 0.3987403378184941
hit10: 0.5192671056398511
mr: 245.8728313770398
mrr: 0.3641883619498083
损失值为：1.928626561188139


  8%|██████▋                                                                                 | 23/300 [14:26<1:55:07, 24.94s/it]

hit1: 0.2849699398797595
hit3: 0.4020040080160321
hit10: 0.5220154594904094
mr: 234.75224735184656
mrr: 0.366223219501109
损失值为：1.9116828092373908


  8%|███████                                                                                 | 24/300 [14:50<1:53:25, 24.66s/it]

hit1: 0.288920698539937
hit3: 0.40589750930432295
hit10: 0.5245347838534211
mr: 227.15574005153164
mrr: 0.3702162508284524
损失值为：1.8967784906271845


  8%|███████▎                                                                                | 25/300 [15:14<1:52:09, 24.47s/it]

hit1: 0.29292871457200115
hit3: 0.40927569424563415
hit10: 0.5287718293730318
mr: 218.4440881763527
mrr: 0.37390590244654637
损失值为：1.8858761611627415


  9%|███████▋                                                                                | 26/300 [15:38<1:51:01, 24.31s/it]

hit1: 0.29561981105067275
hit3: 0.4106498711709132
hit10: 0.5330088748926425
mr: 213.61941024906957
mrr: 0.37632502124907413
损失值为：1.8698129842523485


  9%|███████▉                                                                                | 27/300 [16:02<1:50:47, 24.35s/it]

hit1: 0.2992270254795305
hit3: 0.4144861150873175
hit10: 0.5373031777841397
mr: 209.31285427998856
mrr: 0.3801366837728469
损失值为：1.8540815477026626


  9%|████████▏                                                                               | 28/300 [16:57<2:32:38, 33.67s/it]

hit1: 0.29974234182651016
hit3: 0.41815058688806184
hit10: 0.5415402233037504
mr: 202.00801603206412
mrr: 0.38234883839947403
损失值为：1.8433220533188432


 10%|████████▌                                                                               | 29/300 [17:55<3:04:13, 40.79s/it]

hit1: 0.30323504151159464
hit3: 0.4190667048382479
hit10: 0.5466361294016605
mr: 197.21643286573146
mrr: 0.3852079491583884
损失值为：1.82844506832771


 10%|████████▊                                                                               | 30/300 [18:54<3:29:01, 46.45s/it]

hit1: 0.30501002004008015
hit3: 0.42347552247351844
hit10: 0.5486973947895791
mr: 191.51680503864873
mrr: 0.38791783503167887
损失值为：1.8163489050930366


 10%|█████████                                                                               | 31/300 [19:53<3:44:14, 50.02s/it]

hit1: 0.3043801889493272
hit3: 0.42496421414257085
hit10: 0.5519038076152305
mr: 188.01024906956772
mrr: 0.38843271786649414
损失值为：1.805627437774092


 11%|█████████▍                                                                              | 32/300 [20:45<3:46:04, 50.61s/it]

hit1: 0.307357572287432
hit3: 0.4272545090180361
hit10: 0.5552819925565416
mr: 183.86847981677641
mrr: 0.3917041466305216
损失值为：1.7963070028927177


 11%|█████████▋                                                                              | 33/300 [21:39<3:50:13, 51.74s/it]

hit1: 0.3089607787002577
hit3: 0.4303464070999141
hit10: 0.5581448611508731
mr: 181.02673919267104
mrr: 0.39355255077482526
损失值为：1.782819879008457


 11%|█████████▉                                                                              | 34/300 [22:34<3:54:00, 52.78s/it]

hit1: 0.31176638992270256
hit3: 0.434697967363298
hit10: 0.5627827082736903
mr: 176.69876896650445
mrr: 0.3968633055372817
损失值为：1.7748408098705113


 12%|██████████▎                                                                             | 35/300 [23:30<3:56:39, 53.58s/it]

hit1: 0.31279702261666187
hit3: 0.435785857429144
hit10: 0.565931863727455
mr: 173.95344975665617
mrr: 0.3984693708677793
损失值为：1.7621306562796235


 12%|██████████▌                                                                             | 36/300 [24:22<3:54:17, 53.25s/it]

hit1: 0.31480103063269393
hit3: 0.4373890638419696
hit10: 0.5675350701402806
mr: 170.43864872602347
mrr: 0.40051518939253233
损失值为：1.7555804182775319


 12%|██████████▊                                                                             | 37/300 [25:15<3:52:27, 53.03s/it]

hit1: 0.3166905239049528
hit3: 0.4414543372459204
hit10: 0.5723446893787575
mr: 167.83544231319783
mrr: 0.40300417808943795
损失值为：1.7455870234407485


 13%|███████████▏                                                                            | 38/300 [26:10<3:54:17, 53.65s/it]

hit1: 0.3192098482679645
hit3: 0.44248496993987974
hit10: 0.5734325794446035
mr: 166.12865731462927
mrr: 0.4053684967222965
损失值为：1.7357602333649993


 13%|███████████▍                                                                            | 39/300 [27:00<3:48:11, 52.46s/it]

hit1: 0.3202977383338105
hit3: 0.4443744632121386
hit10: 0.5766962496421414
mr: 163.6403091898082
mrr: 0.4067314680442324
损失值为：1.7254469306208193


 13%|███████████▋                                                                            | 40/300 [27:48<3:41:49, 51.19s/it]

hit1: 0.32253077583738904
hit3: 0.44660750071571714
hit10: 0.5780131691955339
mr: 161.75470941883768
mrr: 0.4089796551540985
损失值为：1.7195133094210178


 14%|████████████                                                                            | 41/300 [28:40<3:41:36, 51.34s/it]

hit1: 0.3226452905811623
hit3: 0.4476953907815631
hit10: 0.5807615230460922
mr: 158.27632407672488
mrr: 0.40988921711828447
损失值为：1.7074601973872632


 14%|████████████▎                                                                           | 42/300 [29:34<3:45:11, 52.37s/it]

hit1: 0.32499284282851415
hit3: 0.45015745777268823
hit10: 0.5829373031777841
mr: 156.8826223876324
mrr: 0.41214742452554676
损失值为：1.7012796173803508


 14%|████████████▌                                                                           | 43/300 [30:29<3:47:58, 53.22s/it]

hit1: 0.3260234755224735
hit3: 0.4503292298883481
hit10: 0.5855711422845692
mr: 154.1939879759519
mrr: 0.4132600246178013
损失值为：1.694797778269276


 15%|████████████▉                                                                           | 44/300 [31:25<3:50:05, 53.93s/it]

hit1: 0.3289436014886917
hit3: 0.45508159175493845
hit10: 0.5864300028628686
mr: 153.46430002862868
mrr: 0.41614438915533986
损失值为：1.686352259828709


 15%|█████████████▏                                                                          | 45/300 [32:13<3:41:29, 52.11s/it]

hit1: 0.3296306899513312
hit3: 0.45588319496135127
hit10: 0.5897509304322932
mr: 152.61941024906957
mrr: 0.4170295260830658
损失值为：1.6757813539588824


 15%|█████████████▍                                                                          | 46/300 [32:45<3:14:38, 45.98s/it]

hit1: 0.3299169768107644
hit3: 0.4564557686802176
hit10: 0.5902662467792729
mr: 149.82902948754653
mrr: 0.41829733493543014
损失值为：1.6695831143297255


 16%|█████████████▊                                                                          | 47/300 [33:09<2:46:04, 39.38s/it]

hit1: 0.3314056684798168
hit3: 0.4588033209275694
hit10: 0.5915831663326653
mr: 147.8354995705697
mrr: 0.419939252378771
损失值为：1.6643766664201394


 16%|██████████████                                                                          | 48/300 [33:33<2:26:18, 34.84s/it]

hit1: 0.3319782421986831
hit3: 0.4612081305468079
hit10: 0.5942170054394503
mr: 147.77108502719724
mrr: 0.42105138713360507
损失值为：1.6550240977667272


 16%|██████████████▎                                                                         | 49/300 [33:57<2:12:01, 31.56s/it]

hit1: 0.33266533066132264
hit3: 0.46263956484397367
hit10: 0.5955911823647294
mr: 144.66229602061264
mrr: 0.42178406076016
损失值为：1.6480094797443599


 17%|██████████████▋                                                                         | 50/300 [34:21<2:02:08, 29.31s/it]

hit1: 0.3338677354709419
hit3: 0.463670197537933
hit10: 0.5961064987117092
mr: 144.3751503006012
mrr: 0.4230600767432955
损失值为：1.641967012314126


 17%|██████████████▉                                                                         | 51/300 [34:45<1:55:20, 27.79s/it]

hit1: 0.33432579444603494
hit3: 0.46555969081019183
hit10: 0.5979387346120814
mr: 142.23716003435442
mrr: 0.4242340105559397
损失值为：1.6329507402842864


 17%|███████████████▎                                                                        | 52/300 [35:09<1:50:41, 26.78s/it]

hit1: 0.3357572287432007
hit3: 0.46733466933867734
hit10: 0.598454050959061
mr: 142.01981105067279
mrr: 0.42530670804131765
损失值为：1.630333533976227


 18%|███████████████▌                                                                        | 53/300 [35:34<1:47:02, 26.00s/it]

hit1: 0.33518465502433437
hit3: 0.46802175780131694
hit10: 0.6004580589750931
mr: 140.54320068708847
mrr: 0.42561666340125986
损失值为：1.618699756800197


 18%|███████████████▊                                                                        | 54/300 [35:58<1:44:17, 25.44s/it]

hit1: 0.3381620383624392
hit3: 0.4679645004294303
hit10: 0.6017177211565989
mr: 140.9400515316347
mrr: 0.4275907543425222
损失值为：1.6172220355365425


 18%|████████████████▏                                                                       | 55/300 [36:22<1:42:27, 25.09s/it]

hit1: 0.337417692527913
hit3: 0.4693959347265961
hit10: 0.6017177211565989
mr: 139.48771829373032
mrr: 0.42758124231610073
损失值为：1.6110363982152194


 19%|████████████████▍                                                                       | 56/300 [36:46<1:40:50, 24.80s/it]

hit1: 0.3373604351560263
hit3: 0.4700830231892356
hit10: 0.6047523618665903
mr: 139.1789865445176
mrr: 0.4277971531545125
损失值为：1.6028581319842488


 19%|████████████████▋                                                                       | 57/300 [37:32<2:05:47, 31.06s/it]

hit1: 0.33873461208130545
hit3: 0.4713426853707415
hit10: 0.6080160320641282
mr: 137.3399942742628
mrr: 0.42942347708215856
损失值为：1.5983121088938788


 19%|█████████████████                                                                       | 58/300 [38:20<2:25:48, 36.15s/it]

hit1: 0.3381620383624392
hit3: 0.47059833953621527
hit10: 0.6065273403950758
mr: 138.0847409103922
mrr: 0.4289312535811438
损失值为：1.5945953142363578


 20%|█████████████████▎                                                                      | 59/300 [39:12<2:44:30, 40.96s/it]

hit1: 0.3382765531062124
hit3: 0.46979673632980246
hit10: 0.6077870025765817
mr: 136.92424849699398
mrr: 0.42923964464675923
损失值为：1.586737362202257


 20%|█████████████████▌                                                                      | 60/300 [40:02<2:54:20, 43.58s/it]

hit1: 0.3401087890065846
hit3: 0.47225880332092757
hit10: 0.6091039221299742
mr: 135.95768680217577
mrr: 0.4306210148745327
损失值为：1.5795602181460708


 20%|█████████████████▉                                                                      | 61/300 [40:54<3:04:19, 46.27s/it]

hit1: 0.3402805611222445
hit3: 0.47260234755224734
hit10: 0.6108216432865732
mr: 135.33014600629832
mrr: 0.43123894658215944
损失值为：1.5736489845439792


 21%|██████████████████▏                                                                     | 62/300 [41:43<3:05:53, 46.87s/it]

hit1: 0.3399370168909247
hit3: 0.4727168622960206
hit10: 0.6123103349556256
mr: 136.1421700543945
mrr: 0.4310956027514857
损失值为：1.571907832287252


 21%|██████████████████▍                                                                     | 63/300 [42:47<3:26:23, 52.25s/it]

hit1: 0.3391926710563985
hit3: 0.4720870312052677
hit10: 0.6123675923275121
mr: 135.39599198396795
mrr: 0.4313732371040222
损失值为：1.563640457810834


 21%|██████████████████▊                                                                     | 64/300 [43:40<3:25:49, 52.33s/it]

hit1: 0.3406813627254509
hit3: 0.4739765244775265
hit10: 0.6129974234182651
mr: 135.13277984540508
mrr: 0.4319753396944702
损失值为：1.5601406085770577


 22%|███████████████████                                                                     | 65/300 [44:29<3:21:40, 51.49s/it]

hit1: 0.33867735470941884
hit3: 0.4747781276839393
hit10: 0.6140853134841111
mr: 135.26916690523905
mrr: 0.43150723798000284
损失值为：1.5551119587616995


 22%|███████████████████▎                                                                    | 66/300 [45:25<3:25:11, 52.61s/it]

hit1: 0.3393071858001718
hit3: 0.4738620097337532
hit10: 0.6152877182937303
mr: 135.90518179215573
mrr: 0.4319778671585217
损失值为：1.5521972936112434


 22%|███████████████████▋                                                                    | 67/300 [46:11<3:16:57, 50.72s/it]

hit1: 0.3383338104780991
hit3: 0.47403378184941314
hit10: 0.6158602920125966
mr: 137.33140566847982
mrr: 0.4313151319660146
损失值为：1.5454261604463682


 23%|███████████████████▉                                                                    | 68/300 [46:41<2:51:50, 44.44s/it]

hit1: 0.3408531348411108
hit3: 0.47432006870884624
hit10: 0.6169481820784426
mr: 136.54926996850844
mrr: 0.4333744312380468
损失值为：1.5400776222813874


 23%|████████████████████▏                                                                   | 69/300 [47:46<3:15:24, 50.75s/it]

hit1: 0.34033781849413114
hit3: 0.47787002576581733
hit10: 0.6159748067563698
mr: 135.3361580303464
mrr: 0.43325202625816805
损失值为：1.538944575469941


 23%|████████████████████▌                                                                   | 70/300 [48:43<3:21:29, 52.56s/it]

hit1: 0.34056684798167763
hit3: 0.47500715717148584
hit10: 0.6182651016318351
mr: 136.07002576581735
mrr: 0.4335275038886014
损失值为：1.5310437850421295


 24%|████████████████████▊                                                                   | 71/300 [49:42<3:27:24, 54.34s/it]

hit1: 0.3406813627254509
hit3: 0.47580876037789865
hit10: 0.6191239622101345
mr: 135.67987403378186
mrr: 0.4336045591815634
损失值为：1.5280451336875558


 24%|█████████████████████                                                                   | 72/300 [50:32<3:22:16, 53.23s/it]

hit1: 0.3399942742628113
hit3: 0.477927283137704
hit10: 0.6174634984254223
mr: 136.94938448325223
mrr: 0.4335704352065721
损失值为：1.521868989802897


 24%|█████████████████████▍                                                                  | 73/300 [51:21<3:16:06, 51.84s/it]

hit1: 0.3397079874033782
hit3: 0.47746922416261095
hit10: 0.6180933295161752
mr: 137.78133409676497
mrr: 0.43364860780723624
损失值为：1.5188080130610615


 25%|█████████████████████▋                                                                  | 74/300 [52:22<3:26:11, 54.74s/it]

hit1: 0.3407958774692242
hit3: 0.476839393071858
hit10: 0.617807042656742
mr: 138.025765817349
mrr: 0.4340682129554641
损失值为：1.519717762595974


 25%|██████████████████████                                                                  | 75/300 [53:02<3:08:56, 50.38s/it]

hit1: 0.3433152018322359
hit3: 0.4783280847409104
hit10: 0.6193529916976811
mr: 137.46395648439736
mrr: 0.43565718756092053
损失值为：1.5108614061027765


 25%|██████████████████████▎                                                                 | 76/300 [53:58<3:14:04, 51.99s/it]

hit1: 0.34194102490695677
hit3: 0.47918694531920986
hit10: 0.6186086458631549
mr: 137.36175207557974
mrr: 0.4356333197894513
损失值为：1.507819625781849


 26%|██████████████████████▌                                                                 | 77/300 [54:48<3:11:10, 51.44s/it]

hit1: 0.34136845118809045
hit3: 0.47746922416261095
hit10: 0.6202691096478672
mr: 138.15722874320068
mrr: 0.43458900511707077
损失值为：1.5078031609300524


 26%|██████████████████████▉                                                                 | 78/300 [55:42<3:12:42, 52.08s/it]

hit1: 0.34360148869166907
hit3: 0.477927283137704
hit10: 0.6190094474663613
mr: 140.06183796163756
mrr: 0.4359193255855359
损失值为：1.504186131292954


 26%|███████████████████████▏                                                                | 79/300 [56:36<3:14:06, 52.70s/it]

hit1: 0.3407958774692242
hit3: 0.4771829373031778
hit10: 0.6194675064414543
mr: 140.4255940452333
mrr: 0.43478931279604466
损失值为：1.5024001045385376


 27%|███████████████████████▍                                                                | 80/300 [57:30<3:14:25, 53.02s/it]

hit1: 0.341597480675637
hit3: 0.4785571142284569
hit10: 0.6208989407386201
mr: 141.51600343544231
mrr: 0.434985604698632
损失值为：1.4945215780753642


 27%|███████████████████████▊                                                                | 81/300 [58:33<3:24:08, 55.93s/it]

hit1: 0.3407386200973375
hit3: 0.4779845405095906
hit10: 0.6222731176638993
mr: 139.79290008588606
mrr: 0.4345158880974948
损失值为：1.490773972705938


 27%|████████████████████████                                                                | 82/300 [59:23<3:17:43, 54.42s/it]

hit1: 0.34154022330375033
hit3: 0.4764385914686516
hit10: 0.6229602061265388
mr: 138.67265960492412
mrr: 0.43465434569297584
损失值为：1.4898947265464813


 28%|███████████████████████▊                                                              | 83/300 [1:00:14<3:12:15, 53.16s/it]

hit1: 0.34131119381620384
hit3: 0.4764958488405382
hit10: 0.6211852275980533
mr: 141.58465502433438
mrr: 0.434364289111168
损失值为：1.4830785929225385


 28%|████████████████████████                                                              | 84/300 [1:01:09<3:13:57, 53.88s/it]

hit1: 0.3393644431720584
hit3: 0.4762668193529917
hit10: 0.6227311766389922
mr: 141.17257371886632
mrr: 0.433605747485172
损失值为：1.482524967403151


 28%|████████████████████████▎                                                             | 85/300 [1:01:58<3:07:21, 52.29s/it]

hit1: 0.3417692527912969
hit3: 0.476839393071858
hit10: 0.6225021471514457
mr: 141.52092756942457
mrr: 0.4350263873125355
损失值为：1.4800327636767179


 29%|████████████████████████▋                                                             | 86/300 [1:02:45<3:00:47, 50.69s/it]

hit1: 0.3423418265101632
hit3: 0.476839393071858
hit10: 0.6227311766389922
mr: 140.58768966504437
mrr: 0.43516790992191434
损失值为：1.4744756725849584


 29%|████████████████████████▉                                                             | 87/300 [1:03:48<3:13:17, 54.45s/it]

hit1: 0.3416547380475236
hit3: 0.4779845405095906
hit10: 0.6207844259948468
mr: 142.33157744059548
mrr: 0.4354567173173554
损失值为：1.4716593408957124


 29%|█████████████████████████▏                                                            | 88/300 [1:04:37<3:06:56, 52.91s/it]

hit1: 0.34194102490695677
hit3: 0.4791296879473232
hit10: 0.6224448897795591
mr: 142.25880332092757
mrr: 0.43587127924707747
损失值为：1.4706133322324604


 30%|█████████████████████████▌                                                            | 89/300 [1:05:29<3:05:04, 52.63s/it]

hit1: 0.33982250214715143
hit3: 0.47603778986544515
hit10: 0.621528771829373
mr: 143.08663040366446
mrr: 0.43366848339394765
损失值为：1.4682887953240424


 30%|█████████████████████████▊                                                            | 90/300 [1:06:24<3:06:27, 53.27s/it]

hit1: 0.34125393644431723
hit3: 0.47575150300601204
hit10: 0.621528771829373
mr: 142.5746922416261
mrr: 0.43431096227303145
损失值为：1.4648120967904106


 30%|██████████████████████████                                                            | 91/300 [1:07:14<3:02:10, 52.30s/it]

hit1: 0.33982250214715143
hit3: 0.4756942456341254
hit10: 0.6210134554823934
mr: 144.20246206699113
mrr: 0.4335164893039336
损失值为：1.4631343457149342


 31%|██████████████████████████▎                                                           | 92/300 [1:08:20<3:15:56, 56.52s/it]

hit1: 0.3409103922129974
hit3: 0.4773547094188377
hit10: 0.6217005439450329
mr: 143.42679645004296
mrr: 0.434353459243937
损失值为：1.4568656096234918


 31%|██████████████████████████▋                                                           | 93/300 [1:09:09<3:06:41, 54.11s/it]

hit1: 0.3393644431720584
hit3: 0.4764958488405382
hit10: 0.6212997423418265
mr: 145.9541368451188
mrr: 0.4333406321124328
损失值为：1.4618334979750216


 31%|██████████████████████████▉                                                           | 94/300 [1:10:09<3:12:18, 56.01s/it]

hit1: 0.34033781849413114
hit3: 0.47603778986544515
hit10: 0.6214142570855997
mr: 146.0363011737761
mrr: 0.4338737201325822
损失值为：1.4528969189850613


 32%|███████████████████████████▏                                                          | 95/300 [1:11:09<3:15:21, 57.18s/it]

hit1: 0.3401087890065846
hit3: 0.47689665044374463
hit10: 0.6223303750357858
mr: 145.8462639564844
mrr: 0.4340776623790574
损失值为：1.4521248545497656


 32%|███████████████████████████▌                                                          | 96/300 [1:12:06<3:14:13, 57.12s/it]

hit1: 0.33924992842828516
hit3: 0.47598053249355854
hit10: 0.6201545949040939
mr: 147.44975665616948
mrr: 0.4328867232024494
损失值为：1.4493288815720007


 32%|███████████████████████████▊                                                          | 97/300 [1:13:13<3:23:08, 60.04s/it]

hit1: 0.34022330375035786
hit3: 0.4763240767248783
hit10: 0.620612653879187
mr: 146.52711136558833
mrr: 0.4337224749514411
损失值为：1.4465167203452438


 33%|████████████████████████████                                                          | 98/300 [1:14:04<3:12:43, 57.25s/it]

hit1: 0.3390781563126252
hit3: 0.475293444030919
hit10: 0.6217005439450329
mr: 148.06023475522474
mrr: 0.43289049792756196
损失值为：1.4456589707406238


 33%|████████████████████████████▍                                                         | 99/300 [1:14:54<3:04:20, 55.03s/it]

hit1: 0.34016604637847125
hit3: 0.4763240767248783
hit10: 0.6202118522759805
mr: 148.25330661322644
mrr: 0.4335509419282985
损失值为：1.4415743704885244


 33%|████████████████████████████▎                                                        | 100/300 [1:15:49<3:03:40, 55.10s/it]

hit1: 0.33850558259375896
hit3: 0.4763240767248783
hit10: 0.6206699112510736
mr: 147.98963641568852
mrr: 0.4328507264804177
损失值为：1.4432968412293121


 34%|████████████████████████████▌                                                        | 101/300 [1:16:39<2:57:11, 53.42s/it]

hit1: 0.3382765531062124
hit3: 0.4755224735184655
hit10: 0.6196392785571142
mr: 150.4985399370169
mrr: 0.43260911192329854
损失值为：1.4365449228789657


 34%|████████████████████████████▉                                                        | 102/300 [1:17:44<3:08:38, 57.17s/it]

hit1: 0.3381047809905525
hit3: 0.4724878328084741
hit10: 0.6207844259948468
mr: 149.94400229029486
mrr: 0.43184359958892277
损失值为：1.4340675486018881


 34%|█████████████████████████████▏                                                       | 103/300 [1:18:37<3:03:23, 55.86s/it]

hit1: 0.337417692527913
hit3: 0.47414829659318636
hit10: 0.6183796163756083
mr: 151.05508159175494
mrr: 0.43151052695657927
损失值为：1.4344888883642852


 35%|█████████████████████████████▍                                                       | 104/300 [1:19:32<3:01:35, 55.59s/it]

hit1: 0.3363870598339536
hit3: 0.4735184655024334
hit10: 0.6203836243916404
mr: 149.42931577440595
mrr: 0.4312192453337222
损失值为：1.431663048104383


 35%|█████████████████████████████▋                                                       | 105/300 [1:20:25<2:57:42, 54.68s/it]

hit1: 0.33718866304036643
hit3: 0.47391926710563986
hit10: 0.6169481820784426
mr: 152.8549098196393
mrr: 0.4313958492377808
损失值为：1.4294349170522764


 35%|██████████████████████████████                                                       | 106/300 [1:21:01<2:39:20, 49.28s/it]

hit1: 0.33661608932150017
hit3: 0.4738047523618666
hit10: 0.6198110506727741
mr: 154.4372745490982
mrr: 0.43075002602654106
损失值为：1.4257738420274109


 36%|██████████████████████████████▎                                                      | 107/300 [1:21:26<2:14:31, 41.82s/it]

hit1: 0.3374749498997996
hit3: 0.475293444030919
hit10: 0.6170626968222158
mr: 155.93907815631263
mrr: 0.43111949187837817
损失值为：1.4272671254584566


 36%|██████████████████████████████▌                                                      | 108/300 [1:21:50<1:56:44, 36.48s/it]

hit1: 0.33867735470941884
hit3: 0.47466361294016607
hit10: 0.6189521900944747
mr: 155.8249069567707
mrr: 0.43214903973867835
损失值为：1.4278850220143795


 36%|██████████████████████████████▉                                                      | 109/300 [1:22:14<1:44:08, 32.72s/it]

hit1: 0.3354136845118809
hit3: 0.4730031491554538
hit10: 0.6174062410535357
mr: 154.96478671628972
mrr: 0.4299753340872884
损失值为：1.4199537461972795


 37%|███████████████████████████████▏                                                     | 110/300 [1:22:44<1:41:23, 32.02s/it]

hit1: 0.33421127970226167
hit3: 0.4728886344116805
hit10: 0.6193529916976811
mr: 153.50518179215575
mrr: 0.4292308854402229
损失值为：1.4193019836675376


 37%|███████████████████████████████▍                                                     | 111/300 [1:23:38<2:01:46, 38.66s/it]

hit1: 0.3366733466933868
hit3: 0.471113655883195
hit10: 0.6182651016318351
mr: 155.75133123389637
mrr: 0.43036748459639707
损失值为：1.418655544752255


 37%|███████████████████████████████▋                                                     | 112/300 [1:24:32<2:14:52, 43.04s/it]

hit1: 0.3346693386773547
hit3: 0.47323217864300027
hit10: 0.6165473804752362
mr: 155.3640423704552
mrr: 0.42972643849504316
损失值为：1.4168139168759808


 38%|████████████████████████████████                                                     | 113/300 [1:25:27<2:25:30, 46.69s/it]

hit1: 0.33678786143716005
hit3: 0.474205553965073
hit10: 0.6174634984254223
mr: 156.34955625536787
mrr: 0.43112078111044605
损失值为：1.4202384108211845


 38%|████████████████████████████████▎                                                    | 114/300 [1:26:20<2:30:26, 48.53s/it]

hit1: 0.3354709418837675
hit3: 0.4718007443458345
hit10: 0.6168909247065559
mr: 160.8711136558832
mrr: 0.4292391538359572
损失值为：1.413743745884858


 38%|████████████████████████████████▌                                                    | 115/300 [1:27:10<2:31:30, 49.14s/it]

hit1: 0.33610077297452046
hit3: 0.4713426853707415
hit10: 0.6162038362439164
mr: 158.025250501002
mrr: 0.429723090802977
损失值为：1.41194586455822


 39%|████████████████████████████████▊                                                    | 116/300 [1:28:01<2:32:21, 49.68s/it]

hit1: 0.3381620383624392
hit3: 0.47340395075866015
hit10: 0.6164901231033496
mr: 158.53335241912396
mrr: 0.4308742875048799
损失值为：1.4094085964607075


 39%|█████████████████████████████████▏                                                   | 117/300 [1:28:54<2:34:43, 50.73s/it]

hit1: 0.33381047809905523
hit3: 0.4718007443458345
hit10: 0.6177497852848555
mr: 160.71102204408817
mrr: 0.4284874640260654
损失值为：1.4081622103694826


 39%|█████████████████████████████████▍                                                   | 118/300 [1:29:43<2:32:06, 50.14s/it]

hit1: 0.33529916976810764
hit3: 0.4719152590896078
hit10: 0.6180933295161752
mr: 159.99421700543945
mrr: 0.429016477505457


 39%|█████████████████████████████████▍                                                   | 118/300 [1:30:06<2:18:59, 45.82s/it]


KeyboardInterrupt: 

In [9]:
class Tester:
    def __init__(self,dataPro, model_path, device):
        self.model = torch.load(model_path, map_location = device)
        self.test_tail = MyDataSet(dataPro, 'test_tail', 0.0)
        self.test_head = MyDataSet(dataPro, 'test_head', 0.0)
        self.lens = len(dataPro.data['test'])
        self.test_tail_load = DataLoader(self.test_tail, batch_size = 128)
        self.test_head_load = DataLoader(self.test_head, batch_size = 128)
        self.measure = Measure()
        self.device = device
        self.model.eval()
    def test(self):
        with torch.no_grad():
            self.measure.init()
            for i, (triples, labels) in enumerate(self.test_tail_load):
                e1 = triples[:, 0]
                rel = triples[:,1]
                if i == 0:
                    print(labels)
                #print(e1)
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                #print(pred)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            print(self.measure.mrr)
            self.measure.deal(self.lens)
            self.measure.print_()
            self.measure.init()
        
            for i, (triples, labels) in enumerate(self.test_head_load):
            
                e1 = triples[:, 0]
                rel = triples[:,1]
                e1 = e1.to(self.device)
                rel = rel.to(self.device)
                pred = self.model(e1, rel)
                pred = pred.cpu()
                e2 = triples[:, 2]
                self.predict(e2, labels, pred)
            self.measure.deal(self.lens)
            self.measure.print_()
        
            
            
    def predict(self, e2, labels, pred):
        
        n_batch = torch.arange(e2.shape[0])
        target = pred[n_batch, e2]
        pred = torch.where(labels.byte(), torch.zeros_like(pred), pred)
        pred[n_batch, e2] = target
        #print(target)
        pred = pred.detach().numpy()
        e2 = e2.numpy()
        for i in range(e2.shape[0]):
            ele = e2[i]
            one_pred =  pred[i]
            sc = one_pred[ele]
            one_pred = np.delete(one_pred, ele)
            one_pred = np.insert(one_pred,0, sc )
            rank = (one_pred[0] <= one_pred).sum()
            #print(rank)
            self.measure.updata(rank)    
            
    
    
    
  

In [10]:
model_path =  "/home/qiupp/code/ConvE/models/118.pkl"
tester = Tester(dataPro, model_path, device)
tester.test()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
8541.983142886143
hit1: 0.32018958272256426
hit3: 0.46100850190559955
hit10: 0.6100361575295612
mr: 176.86157529561223
mrr: 0.4173743351356466
hit1: 0.13402716700869735
hit3: 0.23624548030880485
hit10: 0.379849506498583
mr: 344.93643115410924
mrr: 0.21511006519040557


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


KeyError: 'a'

0.9688639316269664


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [1]:

import torch
torch.randn((1, 2))

tensor([[-1.7090,  1.3172]])