# TuckER的pytorch实现


In [2]:
import numpy as np
import torch
from torch.utils import data
from torch.nn.init import xavier_normal_
import torch.nn as nn
import tqdm

## 构建数据集


In [3]:
# 训练集和验证集
class TripleDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, triple_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = triple_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation, tail = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        tail_id = self.ent2id[tail]
        return head_id, relation_id, tail_id

# 测试集    
class TestDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, test_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = test_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        return head_id, relation_id

## TuckER模型

In [4]:
class TuckER(nn.Module):
    def __init__(self, entity_num, relation_num, dim=100, input_dropout=0.3, hidden_dropout1=0.4, hidden_dropout2=0.5):
        # d
        super(TuckER, self).__init__()

        self.dim = dim
        self.entity_num = entity_num

        self.E = nn.Embedding(entity_num, dim)
        self.R = nn.Embedding(relation_num, dim)
        self.W = nn.Parameter(torch.tensor(np.random.uniform(-1, 1, (dim, dim, dim)), 
                                    dtype=torch.float, device="cuda", requires_grad=True))

        self.input_dropout = nn.Dropout(input_dropout)
        self.hidden_dropout1 = nn.Dropout(hidden_dropout1)
        self.hidden_dropout2 = nn.Dropout(hidden_dropout2)
        self.loss = nn.BCELoss()

        self.bn0 = nn.BatchNorm1d(dim)
        self.bn1 = nn.BatchNorm1d(dim)
        

    def init(self):
        xavier_normal_(self.E.weight.data)
        xavier_normal_(self.R.weight.data)

    def forward(self, e1_idx, r_idx):
        e1 = self.E(e1_idx)
        x = self.bn0(e1)
        x = self.input_dropout(x)
        x = x.view(-1, 1, e1.size(1))

        r = self.R(r_idx)
        W_mat = torch.mm(r, self.W.view(r.size(1), -1))
        W_mat = W_mat.view(-1, e1.size(1), e1.size(1))
        W_mat = self.hidden_dropout1(W_mat)

        x = torch.bmm(x, W_mat) 
        x = x.view(-1, e1.size(1))      
        x = self.bn1(x)
        x = self.hidden_dropout2(x)
        x = torch.mm(x, self.E.weight.transpose(1,0))
        pred = torch.sigmoid(x)
        return pred

    def link_predict(self, head, relation, tail=None, k=10):
        e1 = self.E(head)
        r = self.R(relation)
        x = self.bn0(e1)
        x = x.view(-1, 1, e1.size(1))
        W_mat = torch.mm(r, self.W.view(r.size(1), -1))
        W_mat = W_mat.view(-1, e1.size(1), e1.size(1))
        W_mat = self.hidden_dropout1(W_mat)
        x = torch.bmm(x, W_mat)
        x = x.view(-1, e1.size(1))
        x = self.bn1(x)
        x = self.hidden_dropout2(x)
        h_add_r = x
        scores = torch.mm(h_add_r, self.E.weight.transpose(1, 0))
        _, indices = torch.topk(scores, k=k, dim=1, largest=True)
        
        if tail is not None:
            tail = tail.view(-1, 1)
            rank_num = torch.eq(indices, tail).nonzero().permute(1, 0)[1] + 1
            rank_num[rank_num > 9] = 10000
            mrr = torch.sum(1 / rank_num.float())
            hits_1_num = torch.sum(torch.eq(indices[:, :1], tail)).item()
            hits_3_num = torch.sum(torch.eq(indices[:, :3], tail)).item()
            hits_10_num = torch.sum(torch.eq(indices[:, :10], tail)).item()
            return mrr, hits_1_num, hits_3_num, hits_10_num
        
        return indices[:, :k]


    def evaluate(self, data_loader, dev_num=5000.0):
        mrr_sum = hits_1_nums = hits_3_nums = hits_10_nums = 0
        device = next(self.parameters()).device
        
        with torch.no_grad():
            for heads, relations, tails in tqdm.tqdm(data_loader):
                mrr_sum_batch, hits_1_num, hits_3_num, hits_10_num = self.link_predict(heads.to(device), relations.to(device), tails.to(device))
                mrr_sum += mrr_sum_batch
                hits_1_nums += hits_1_num
                hits_3_nums += hits_3_num
                hits_10_nums += hits_10_num
        
        return mrr_sum / dev_num, hits_1_nums / dev_num, hits_3_nums / dev_num, hits_10_nums / dev_num



In [5]:
# batchsize增大，得分略有上升
train_batch_size = 1000
dev_batch_size = 16  # 显存不够就调小
test_batch_size = 16
epochs = 10
print_frequency = 5  # 每多少step输出一次信息
validation = True  # 是否验证，验证比较费时
dev_interval = 5  # 每多少轮验证一次，微调设小一点，会保存最佳权重
best_mrr = 0
learning_rate = 0.0005  # 学习率建议粗调0.01-0.001，精调0.001-0.0001
embedding_dim = 100  # 维度增大可能会有提升，我感觉没用，100维包含的信息足够丰富

In [8]:
with open('OpenBG500/OpenBG500_entity2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip('\n').split('\t') for line in dat]
ent2id = {line[0]: i for i, line in enumerate(lines)}
id2ent = {i: line[0] for i, line in enumerate(lines)}
with open('OpenBG500/OpenBG500_relation2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip().split('\t') for line in dat]
rel2id = {line[0]: i for i, line in enumerate(lines)}
with open('OpenBG500/OpenBG500_train.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    train = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500/OpenBG500_dev.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    dev = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500/OpenBG500_test.tsv', 'r', encoding='utf-8') as fp:
    test = fp.readlines()
    test = [line.strip('\n').split('\t') for line in test]
# 构建数据集
train_dataset = TripleDataset(ent2id, rel2id, train)
dev_dataset = TripleDataset(ent2id, rel2id, dev)
train_data_loader = data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
dev_data_loader = data.DataLoader(dev_dataset, batch_size=dev_batch_size)
test_dataset = TestDataset(ent2id, rel2id, test)
test_data_loader = data.DataLoader(test_dataset, batch_size=test_batch_size)


import random

# 获取原始数据集的长度
original_size = len(train_dataset)

# 计算切片后的目标大小
target_size = original_size // 100

# 随机抽样得到切片后的索引
sampled_indices = random.sample(range(original_size), target_size)

# 根据抽样后的索引获取切片后的数据集
sampled_train_data = [train_dataset[idx] for idx in sampled_indices]

# 如果需要将切片后的数据集重新构建为 DataLoader
sampled_train_data_loader = data.DataLoader(sampled_train_data, batch_size=train_batch_size, shuffle=True)

In [10]:
# 构建模型
model = TuckER(len(ent2id), len(rel2id), dim=embedding_dim).cuda()
model.init()
# model.load_state_dict(torch.load('TuckER_best.pth'))
# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练
print('start training...')
for epoch in range(epochs):
    all_loss = 0
    for i, (local_heads, local_relations, local_tails) in enumerate(sampled_train_data_loader):

        head = local_heads.cuda()
        relation = local_relations.cuda()
        tail = local_tails.cuda()

        optimizer.zero_grad()
        
        # 计算正样本预测值
        pred = model.forward(head, relation)
        labels = torch.ones(len(local_heads), len(ent2id)).cuda()

        # 计算损失
        loss = model.loss(pred, labels)

        loss.backward()
        
        optimizer.step()
        all_loss += loss.item()
        if i % print_frequency == 0:
            print(
                f"epoch:{epoch}/{epochs}, step:{i}/{len(sampled_train_data_loader)}, loss={loss.item()}, avg_loss={all_loss / (i + 1)}")
    print(f"epoch:{epoch}/{epochs}, all_loss={all_loss}")

    # 验证
    if validation and (epoch + 1) % dev_interval == 0:
        print('testing...')
        improve = ''
        mrr, hits1, hits3, hits10 = model.evaluate(dev_data_loader)
        if mrr >= best_mrr:
            best_mrr = mrr
            improve = '*'
            torch.save(model.state_dict(), 'TuckER_best.pth')
        torch.save(model.state_dict(), 'TuckER_latest.pth')
        print(f'mrr: {mrr}, hit@1: {hits1}, hit@3: {hits3}, hit@10: {hits10}  {improve}')
    if not validation:
        torch.save(model.state_dict(), 'TuckER_latest.pth')

start training...
epoch:0/10, step:0/13, loss=0.6933450698852539, avg_loss=0.6933450698852539
epoch:0/10, step:5/13, loss=0.6932802796363831, avg_loss=0.6933171550432841
epoch:0/10, step:10/13, loss=0.693170964717865, avg_loss=0.6932828697291288
epoch:0/10, all_loss=9.012145519256592
epoch:1/10, step:0/13, loss=0.6927986145019531, avg_loss=0.6927986145019531
epoch:1/10, step:5/13, loss=0.692460298538208, avg_loss=0.6928480466206869
epoch:1/10, step:10/13, loss=0.6921534538269043, avg_loss=0.6926869208162482
epoch:1/10, all_loss=9.004232227802277
epoch:2/10, step:0/13, loss=0.6918404698371887, avg_loss=0.6918404698371887
epoch:2/10, step:5/13, loss=0.6924312710762024, avg_loss=0.6915651361147562
epoch:2/10, step:10/13, loss=0.6901764273643494, avg_loss=0.6914288564161821
epoch:2/10, all_loss=8.986169338226318
epoch:3/10, step:0/13, loss=0.6898820996284485, avg_loss=0.6898820996284485
epoch:3/10, step:5/13, loss=0.6890738606452942, avg_loss=0.6897997458775839
epoch:3/10, step:10/13, loss

100%|██████████| 313/313 [00:34<00:00,  9.05it/s]


mrr: 3.333333370392211e-05, hit@1: 0.0, hit@3: 0.0, hit@10: 0.0002  *
epoch:5/10, step:0/13, loss=0.6807449460029602, avg_loss=0.6807449460029602
epoch:5/10, step:5/13, loss=0.6748770475387573, avg_loss=0.6771734257539114


KeyboardInterrupt: 

In [None]:
predict_all = []
model = TuckER(len(ent2id), len(rel2id), dim=embedding_dim).cuda()
model.load_state_dict(torch.load('TuckER_best.pth'))
for heads, relations in tqdm.tqdm(test_data_loader):
    # 预测的id,结果为tensor(batch_size*10)
    predict_id = model.link_predict(heads.cuda(), relations.cuda())
    # 结果取到cpu并转为一行的list以便迭代
    predict_list = predict_id.cpu().numpy().reshape(1,-1).squeeze(0).tolist()
    # id转为实体
    predict_ent = map(lambda x: id2ent[x], predict_list)
    # 保存结果
    predict_all.extend(predict_ent)
print('prediction finished !')

In [None]:
with open('TuckER_submission.tsv', 'w', encoding='utf-8') as f:
    for i in range(len(test)):
        # 直接writelines没有空格分隔，手工加分割符，得按提交格式来
        list = [x + '\t' for x in test[i]] + [x + '\n' if i == 9 else x + '\t' for i, x in enumerate(predict_all[i*10:i*10+10])]
        f.writelines(list)
print('file saved !')