In [1]:
import logging
import os

# Configure the root logger so that records emitted anywhere in the notebook
# are written both to 'my.log' and to the cell output.
logger = logging.getLogger()

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Re-running this cell used to stack one more file/console handler pair on the
# root logger each time, so every record was emitted 2x, 3x, ... Handlers
# installed by this cell are therefore named, and any earlier copies are
# removed (and closed) before fresh ones are attached. Handlers installed by
# anything else are left untouched.
_HANDLER_NAMES = ('notebook_file_handler', 'notebook_console_handler')
for _stale in [h for h in logger.handlers if h.get_name() in _HANDLER_NAMES]:
    logger.removeHandler(_stale)
    _stale.close()

# Setup file handler
fhandler = logging.FileHandler('my.log')
fhandler.set_name(_HANDLER_NAMES[0])
fhandler.setLevel(logging.DEBUG)
fhandler.setFormatter(formatter)

# Configure stream handler for the cells
chandler = logging.StreamHandler()
chandler.set_name(_HANDLER_NAMES[1])
chandler.setLevel(logging.DEBUG)
chandler.setFormatter(formatter)

# Add both handlers
logger.addHandler(fhandler)
logger.addHandler(chandler)
logger.setLevel(logging.DEBUG)

# Log Something (smoke test: one record per level that should be visible)
logger.info("Test info")
logger.debug("Test debug")
logger.error("Test error")


2023-03-10 16:37:57,712 - root - INFO - Test info
2023-03-10 16:37:57,715 - root - DEBUG - Test debug
2023-03-10 16:37:57,716 - root - ERROR - Test error


In [14]:
import torch
import torch.nn as nn


class Config:
    """Bundle of hyper-parameters, device choice and file locations.

    All values are fixed defaults set in ``__init__``; the entity/relation
    counts are placeholders that ``init_rel_ent`` overwrites once the dataset
    has been read.
    """

    def __init__(self):
        # Vocabulary sizes (placeholders until init_rel_ent is called).
        self.ent_num, self.rel_num = 10, 10

        # Embedding sizes. `dim` is not read by any code visible in this notebook.
        self.dim = 100
        self.ent_dim, self.rel_dim = 100, 100

        # Sampling and optimisation settings.
        self.neg_ratio = 100      # corrupted triples generated per positive triple
        self.batch_size = 100
        self.lr = 0.001
        self.lambd = 0.00001      # weight applied to the regularisation term
        self.epochs = 30

        # Use the first CUDA device when one is available, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

        # Dataset and checkpoint locations.
        self.data_path = "../data/"
        self.data_name = "FB15k-237"
        self.model_path = "./models"

    def init_rel_ent(self, ent_num, rel_num):
        """Store the entity and relation counts given by the caller."""
        self.ent_num, self.rel_num = ent_num, rel_num


In [4]:
class NTN(nn.Module):
    """Bilinear-tensor scoring model over (head, relation, tail) index triples.

    For a batch of triples the score is
    ``sum(tanh(h^T W t + h M1 + t M2 + b) * r, dim=-1)`` where ``W`` is the
    ``mr`` parameter viewed as ``(rel_dim, ent_dim, ent_dim)``.

    Args:
        config: object providing ``ent_num``, ``rel_num``, ``ent_dim`` and
            ``rel_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        n_ent, n_rel = config.ent_num, config.rel_num
        d_ent, d_rel = config.ent_dim, config.rel_dim

        # Creation order is kept fixed so seeded runs draw the same values.
        self.ent_emb = nn.Embedding(n_ent, d_ent)
        self.rel_emb = nn.Embedding(n_rel, d_rel)

        # Flattened stack of rel_dim bilinear forms, each ent_dim x ent_dim.
        self.mr = nn.Parameter(torch.randn(d_rel, d_ent * d_ent), requires_grad=True)
        # Linear maps applied to the head and the tail embedding respectively.
        self.mr1 = nn.Parameter(torch.randn(d_ent, d_rel))
        self.mr2 = nn.Parameter(torch.randn(d_ent, d_rel))
        # Bias row; it keeps its randn initialisation (init() does not touch it).
        self.b = nn.Parameter(torch.randn(1, d_rel))

        self.init()
        self.loss = nn.Softplus()

    def init(self):
        """Re-initialise both embedding tables and mr/mr1/mr2 with Xavier-normal."""
        targets = (self.ent_emb.weight, self.rel_emb.weight, self.mr, self.mr1, self.mr2)
        for param in targets:
            nn.init.xavier_normal_(param.data)

    def forward(self, h, r, t):
        """Score a batch of triples.

        Args:
            h, r, t: 1-D index tensors of equal length (head entity ids,
                relation ids, tail entity ids).

        Returns:
            Tensor with one score per triple.
        """
        d_ent, d_rel = self.config.ent_dim, self.config.rel_dim

        head = self.ent_emb(h)        # (batch, ent_dim)
        relation = self.rel_emb(r)    # (batch, rel_dim)
        tail = self.ent_emb(t)        # (batch, ent_dim)

        # Linear terms, each (batch, rel_dim).
        head_linear = torch.mm(head, self.mr1)
        tail_linear = torch.mm(tail, self.mr2)

        # Bilinear term: one copy of the head batch per slice of the tensor.
        head_tiled = head.unsqueeze(0).repeat(d_rel, 1, 1)            # (rel_dim, batch, ent_dim)
        slices = self.mr.view(d_rel, d_ent, d_ent)
        head_times_w = torch.matmul(head_tiled, slices).permute(1, 0, 2)  # (batch, rel_dim, ent_dim)
        bilinear = torch.matmul(head_times_w, tail.unsqueeze(-1)).squeeze(-1)  # (batch, rel_dim)

        hidden = torch.tanh(bilinear + head_linear + tail_linear + self.b)
        return torch.sum(hidden * relation, -1)

    def regularization(self):
        """Return the sum of squared norms of both embedding tables and mr, mr1, mr2."""
        terms = (
            torch.norm(self.ent_emb.weight, 2) ** 2,
            torch.norm(self.rel_emb.weight, 2) ** 2,
            torch.norm(self.mr) ** 2,
            torch.norm(self.mr1) ** 2,
            torch.norm(self.mr2) ** 2,
        )
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total





In [4]:
# config = Config()
# h = torch.zeros(10).long()
# r = torch.zeros(10).long()
# t = torch.zeros(10).long()
# print(config.ent_num, config.rel_num)
# model = NTN(config)
# x = model(h, r, t)
# print(x.shape)

In [5]:
class loadData:
    """Read the train/valid/test triple files of a dataset and index their symbols.

    Each split is read from ``<data_path><data_name>/<split>.txt``. Every
    non-blank line is split on whitespace and its first three fields are taken
    as head entity, relation and tail entity. Entities and relations are
    mapped to consecutive integer ids in order of first appearance
    (train, then valid, then test).

    Attributes:
        path: directory prefix the split files are read from.
        ent2id / rel2id: symbol -> integer id mappings.
        data: ``{'train' | 'valid' | 'test': [(head_id, rel_id, tail_id), ...]}``.
    """

    def __init__(self, congfig):
        # The misspelled parameter name is kept so keyword callers keep working.
        # Dataset directory; change it through the config object.
        self.path = congfig.data_path + congfig.data_name + "/"
        self.rel2id = {}
        self.ent2id = {}
        self.data = {sql: self.read(sql) for sql in ['train', 'valid', 'test']}

    def read(self, file_name):
        """Parse ``<file_name>.txt`` into a list of ``(head_id, rel_id, tail_id)`` tuples.

        Blank or whitespace-only lines are skipped (previously a trailing empty
        line crashed with IndexError). A non-blank line with fewer than three
        fields still raises IndexError; extra fields are ignored.
        """
        triples = []
        # Explicit encoding so the result does not depend on the platform default.
        with open(self.path + file_name + '.txt', 'r', encoding='utf-8') as f:
            for line in f:
                temp = line.split()
                if not temp:
                    continue
                triples.append((self.get_ent(temp[0]), self.get_rel(temp[1]), self.get_ent(temp[2])))
        return triples

    def get_ent(self, ent):
        """Return the id of entity ``ent``, assigning the next free id on first sight."""
        if ent not in self.ent2id:
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]

    def get_rel(self, rel):
        """Return the id of relation ``rel``, assigning the next free id on first sight."""
        if rel not in self.rel2id:
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]

    def ent_num(self):
        """Number of distinct entities seen so far."""
        return len(self.ent2id)

    def rel_num(self):
        """Number of distinct relations seen so far."""
        return len(self.rel2id)


In [6]:
from torch.utils.data import DataLoader, Dataset
import numpy as np
from random import randint, random, shuffle


class MyTrainData(Dataset):
    """Training dataset yielding one positive triple plus freshly corrupted negatives.

    Args:
        loaddata: object whose ``data['train']`` is a sequence of
            ``(head, relation, tail)`` id triples.
        config: object providing ``ent_num`` and ``neg_ratio``.
    """

    def __init__(self, loaddata, config):
        super().__init__()
        self.data = loaddata.data['train']
        self.config = config

    def __len__(self):
        return len(self.data)

    def randValue(self, value):
        """Draw a uniform entity id in ``[0, ent_num - 1]`` that differs from ``value``."""
        upper = self.config.ent_num - 1
        while True:
            candidate = randint(0, upper)
            if candidate != value:
                return candidate

    def __getitem__(self, index):
        """Return ``(positive, negatives)`` as LongTensors.

        ``positive`` is ``[h, r, t, 1]``; ``negatives`` has ``neg_ratio`` rows
        ``[h', r, t', -1]`` where exactly one of head/tail was replaced by a
        different random entity (head with probability 0.5, otherwise tail).
        Negatives are not checked against known facts.
        """
        ratio = self.config.neg_ratio
        positive = np.expand_dims(self.data[index], axis=0)   # (1, 3)
        corrupted = np.repeat(positive, ratio, axis=0)        # (ratio, 3) copies

        for row in corrupted:                                 # rows are writable views
            column = 0 if random() < 0.5 else 2
            row[column] = self.randValue(row[column])

        positive = np.append(positive, 1)                     # flattens to [h, r, t, 1]
        negative_labels = -np.ones((ratio, 1))
        corrupted = np.append(corrupted, negative_labels, axis=1)
        return torch.LongTensor(positive), torch.LongTensor(corrupted)





In [7]:
class MyTestData(Dataset):
    """Evaluation dataset producing filtered head- and tail-replacement candidates.

    Args:
        loaddata: object with ``data`` (dict of split name -> list of id
            triples) and an ``ent_num()`` method.
        data_type: which split to iterate over (a key of ``loaddata.data``).
    """

    def __init__(self, loaddata, data_type):
        self.data = loaddata.data[data_type]
        self.ent_num = loaddata.ent_num()
        self.loaddata = loaddata
        # Every known triple, used to filter corrupted candidates that are true.
        self.all_facts = set(self.get_all_facts())

    def get_all_facts(self):
        """Return the train, valid and test triples concatenated in that order."""
        return [
            fact
            for split in ('train', 'valid', 'test')
            for fact in self.loaddata.data[split]
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(head_candidates, tail_candidates)`` as LongTensors of shape (k, 3).

        Row 0 of each tensor is the true triple. The remaining rows replace the
        head (respectively tail) with every other entity id, minus any
        candidate that appears in ``all_facts``.
        """
        fact = self.data[index]
        h, r, t = fact
        entity_ids = range(0, self.ent_num)

        tail_swaps = [(h, r, e) for e in entity_ids if e != t]
        neg_tail = [fact] + list(set(tail_swaps) - self.all_facts)

        head_swaps = [(e, r, t) for e in entity_ids if e != h]
        neg_head = [fact] + list(set(head_swaps) - self.all_facts)

        return torch.LongTensor(neg_head), torch.LongTensor(neg_tail)

In [8]:
class Measure:
    """Running totals for ranking metrics, kept separately for 'head' and 'tail'.

    Tracks hits@1, hits@3, hits@10, the rank sum (``mr``) and the reciprocal
    rank sum (``mrr``); ``total_deal`` turns the totals into averages.
    """

    def __init__(self):
        self.init()

    def updata(self, rank, head_tail):
        """Add one ranking result.

        Args:
            rank: rank of the true triple (1 is best).
            head_tail: ``'head'`` or ``'tail'``, selecting which totals to update.
        """
        if rank == 1:
            self.hit1[head_tail] += 1
        for cutoff, hits in ((3, self.hit3), (10, self.hit10)):
            if rank <= cutoff:
                hits[head_tail] += 1
        self.mr[head_tail] += rank
        self.mrr[head_tail] += 1.0 / rank

    def total_deal(self, fact_num):
        """Log every metric averaged over ``fact_num`` and return the averaged MRR.

        Head and tail totals are added together before dividing.
        """
        print("---------result--------")
        averaged = {}
        tables = (
            ('hit1', self.hit1),
            ('hit3', self.hit3),
            ('hit10', self.hit10),
            ('mr', self.mr),
            ('mrr', self.mrr),
        )
        for label, table in tables:
            averaged[label] = (table['head'] + table['tail']) / fact_num
            logger.info(label + ':' + str(averaged[label]))
        return averaged['mrr']

    def init(self):
        """Reset every total to zero."""
        self.mrr = {'head': 0.0, 'tail': 0.0}
        self.mr = {'head': 0.0, 'tail': 0.0}
        self.hit1 = {'head': 0.0, 'tail': 0.0}
        self.hit3 = {'head': 0.0, 'tail': 0.0}
        self.hit10 = {'head': 0.0, 'tail': 0.0}



In [10]:
import os
from tqdm import trange

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class Trainer:
    """Train a triple-scoring model and keep the checkpoint with the best validation MRR.

    Args:
        config: object providing ``device``, ``batch_size``, ``lr``, ``epochs``,
            ``lambd``, ``model_path`` and ``data_name`` (plus whatever
            ``MyTrainData`` reads).
        loaddata: loaded dataset with a ``data`` dict holding the 'train' and
            'valid' splits.
        model: module exposing ``forward(h, r, t)``, ``loss`` and
            ``regularization()``.
    """

    def __init__(self, config, loaddata, model):
        self.config = config
        self.loaddata = loaddata
        self.train_loader, self.valid_loader = self.init_data()
        self.model = model.to(config.device)
        self.measure = Measure()
        self.fact_num = len(loaddata.data['valid'])

    def init_data(self):
        """Build the training and validation DataLoaders.

        Uses ``self.config`` rather than a notebook-global ``config`` so the
        trainer honours the configuration it was constructed with.
        """
        myTrainData = MyTrainData(self.loaddata, self.config)
        train_loader = DataLoader(myTrainData, batch_size=self.config.batch_size, shuffle=True)

        myTestData = MyTestData(self.loaddata, 'valid')
        valid_loader = DataLoader(myTestData, batch_size=1, shuffle=True)
        return train_loader, valid_loader

    def train(self):
        """Run ``config.epochs`` epochs; validate after each and save on improvement."""
        best_acc = 0.0
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        for epoch in trange(1, self.config.epochs + 1):
            self.model.train()
            tot = 0.0
            cn = 0
            for pos, neg in self.train_loader:
                # Flatten (batch, neg_ratio, 4) negatives to rows and mix them
                # with the positives.
                neg = neg.view(-1, neg.shape[-1])
                data = torch.cat([pos, neg], dim=0)
                index = list(range(data.shape[0]))
                shuffle(index)
                data = data[index]
                data = data.to(self.config.device)
                h = data[:, 0]
                r = data[:, 1]
                t = data[:, 2]
                labels = data[:, -1]  # +1 for positives, -1 for negatives

                optimizer.zero_grad()
                scores = self.model(h, r, t)
                loss = torch.sum(self.model.loss(-labels * scores)) \
                    + self.config.lambd * self.model.regularization() / h.shape[0]

                # The optimisation step must happen for every batch. Previously
                # these lines sat outside the batch loop, so only the last
                # batch of each epoch ever updated the model.
                loss.backward()
                optimizer.step()
                tot += loss.item()
                cn += 1
            # max(..., 1) avoids a ZeroDivisionError on an empty training set.
            print("------loss:" + str(tot / max(cn, 1)) + "-------")

            # Validate once per epoch so best_acc really tracks the best epoch.
            acc = self._validate()
            if acc > best_acc:
                best_acc = acc
                self.save_mode()

    def _rank_of_true(self, candidates):
        """Rank of the true triple (row 0) among all candidate rows, 1 = best."""
        candidates = candidates.view(-1, 3).to(self.config.device)
        score = self.model(candidates[:, 0], candidates[:, 1], candidates[:, 2])
        score = score.cpu().numpy()
        # Ties count against the true triple (>=), matching the original metric.
        return (score >= score[0]).sum()

    def _validate(self):
        """Score the validation split and return the averaged MRR."""
        self.model.eval()
        self.measure.init()
        # No gradients are needed for ranking; this avoids building autograd graphs.
        with torch.no_grad():
            for head, tail in self.valid_loader:
                self.measure.updata(self._rank_of_true(head), 'head')
                self.measure.updata(self._rank_of_true(tail), 'tail')
        return self.measure.total_deal(self.fact_num * 2)

    def save_mode(self):
        """Write the model's state dict to ``<model_path>/<data_name>_best_acc.pkl``.

        This is now a real method; it used to be a module-level function, so
        ``self.save_mode()`` raised AttributeError.
        """
        # Directory where model checkpoints are stored.
        save_path = self.config.model_path + '/'
        # makedirs also handles nested paths and an already existing directory.
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.model.state_dict(), save_path + self.config.data_name + "_best_acc.pkl")


# Backward-compatible alias for code that called the old module-level function.
save_mode = Trainer.save_mode

In [16]:
# Build the default settings and read the train/valid/test splits from
# config.data_path + config.data_name.
config = Config()
loaddata = loadData(config)
# Replace the placeholder vocabulary sizes with the counts found in the data so
# the model's embedding tables are sized to the dataset.
config.init_rel_ent(loaddata.ent_num(), loaddata.rel_num())
print(config.batch_size, config.rel_num, config.device)
model = NTN(config)
# Optional warm start from a previously saved checkpoint (currently disabled).
# model.load_state_dict(torch.load(config.model_path+"/"+config.data_name+"_best_acc.pkl"))

# Trainer moves the model to config.device and builds its own data loaders.
trainer = Trainer(config, loaddata, model)
trainer.train()

100 237 cpu


  0%|          | 0/30 [04:36<?, ?it/s]


KeyboardInterrupt: 

In [11]:
class Tester:
    """Evaluate a trained model on the test split with filtered ranking metrics.

    Args:
        model: module exposing ``forward(h, r, t)``; it is moved to ``config.device``.
        loaddata: loaded dataset with a ``data`` dict holding the 'test' split.
        config: object providing ``device``.
    """

    def __init__(self, model, loaddata, config):
        self.loaddata = loaddata
        self.test_loader = self.loadTest()
        self.measure = Measure()
        self.fact_num = len(loaddata.data['test'])
        self.config = config
        self.model = model.to(config.device)

    def loadTest(self):
        """Build a DataLoader yielding one test triple's candidate sets at a time."""
        myTestData = MyTestData(self.loaddata, 'test')
        test_loader = DataLoader(myTestData, batch_size=1, shuffle=True)
        return test_loader

    def _rank_of_true(self, candidates):
        """Rank of the true triple (row 0) among all candidate rows, 1 = best."""
        candidates = candidates.view(-1, 3).to(self.config.device)
        score = self.model(candidates[:, 0], candidates[:, 1], candidates[:, 2])
        score = score.cpu().numpy()
        # Ties count against the true triple (>=), matching the original metric.
        return (score >= score[0]).sum()

    def test(self):
        """Rank every test triple for head and tail prediction and log the metrics.

        The accumulators are reset first: without this, calling ``test()`` a
        second time on the same instance added to the previous totals and
        produced inflated metrics.
        """
        self.model.eval()
        self.measure.init()
        # No gradients are needed for ranking; this avoids building autograd graphs.
        with torch.no_grad():
            for head, tail in self.test_loader:
                self.measure.updata(self._rank_of_true(head), 'head')
                self.measure.updata(self._rank_of_true(tail), 'tail')
        self.measure.total_deal(self.fact_num * 2)



In [12]:
# Evaluate the saved checkpoint on the test split.
print(config.device)
# The path matches the one save_mode writes to:
# config.model_path + "/" + config.data_name + "_best_acc.pkl".
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): no map_location is passed - presumably the device the checkpoint
# was saved from must be available here; confirm before running on another machine.
model.load_state_dict(torch.load(config.model_path + "/" + config.data_name + "_best_acc.pkl"))
tester = Tester(model, loaddata, config)
tester.test()

cuda:0


2022-05-05 08:20:51,455 - root - INFO - hit1:0.11382292582820287
2022-05-05 08:20:51,456 - root - INFO - hit3:0.22092739177171894
2022-05-05 08:20:51,456 - root - INFO - hit10:0.39130753444737615
2022-05-05 08:20:51,457 - root - INFO - mr:270.7241522525164
2022-05-05 08:20:51,457 - root - INFO - mrr:0.2029397597501765


---------result--------
