In [13]:
import torch
from torch import nn
from torch.utils import data
import numpy as np
import tqdm

In [14]:
class TripleDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, triple_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = triple_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation, tail = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        tail_id = self.ent2id[tail]
        return head_id, relation_id, tail_id

class TestDataset(data.Dataset):
    def __init__(self, ent2id, rel2id, test_data_list):
        self.ent2id = ent2id
        self.rel2id = rel2id
        self.data = test_data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        head, relation = self.data[index]
        head_id = self.ent2id[head]
        relation_id = self.rel2id[relation]
        return head_id, relation_id

In [15]:
class TransE(nn.Module):

    def __init__(self, entity_num, relation_num, norm=1, dim=100):
        super(TransE, self).__init__()
        self.norm = norm
        self.dim = dim
        self.entity_num = entity_num
        self.entities_emb = self._init_emb(entity_num)
        self.relations_emb = self._init_emb(relation_num)

    def _init_emb(self, num_embeddings):
        embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=self.dim)
        uniform_range = 6 / np.sqrt(self.dim)
        embedding.weight.data.uniform_(-uniform_range, uniform_range)
        embedding.weight.data = torch.div(embedding.weight.data, embedding.weight.data.norm(p=2, dim=1, keepdim=True))
        return embedding

    def forward(self, positive_triplets: torch.LongTensor, negative_triplets: torch.LongTensor):
        positive_distances = self._distance(positive_triplets.cuda())
        negative_distances = self._distance(negative_triplets.cuda())
        return positive_distances, negative_distances

    def _distance(self, triplets):
        heads = self.entities_emb(triplets[:, 0])
        relations = self.relations_emb(triplets[:, 1])
        tails = self.entities_emb(triplets[:, 2])
        return (heads + relations - tails).norm(p=self.norm, dim=1)

    def link_predict(self, head, relation, tail=None, k=10):
        h_add_r = self.entities_emb(head) + self.relations_emb(relation)
        h_add_r = torch.unsqueeze(h_add_r, dim=1)
        h_add_r = h_add_r.expand(h_add_r.shape[0], self.entity_num, self.dim)
        embed_tail = self.entities_emb.weight.data.expand(h_add_r.shape[0], self.entity_num, self.dim)
        values, indices = torch.topk(torch.norm(h_add_r - embed_tail, dim=2), k=self.entity_num, dim=1, largest=False)
        if tail is not None:
            tail = tail.view(-1, 1)
            rank_num = torch.eq(indices, tail).nonzero().permute(1, 0)[1]+1
            rank_num[rank_num > 9] = 10000
            mrr = torch.sum(1/rank_num)
            hits_1_num = torch.sum(torch.eq(indices[:, :1], tail)).item()
            hits_3_num = torch.sum(torch.eq(indices[:, :3], tail)).item()
            hits_10_num = torch.sum(torch.eq(indices[:, :10], tail)).item()
            return mrr, hits_1_num, hits_3_num, hits_10_num
        return indices[:, :k]

    def evaluate(self, data_loader, dev_num=5000.0):
        mrr_sum = hits_1_nums = hits_3_nums = hits_10_nums = 0
        for heads, relations, tails in tqdm.tqdm(data_loader):
            mrr_sum_batch, hits_1_num, hits_3_num, hits_10_num = self.link_predict(heads.cuda(), relations.cuda(), tails.cuda())
            mrr_sum += mrr_sum_batch
            hits_1_nums += hits_1_num
            hits_3_nums += hits_3_num
            hits_10_nums += hits_10_num
        return mrr_sum/dev_num, hits_1_nums/dev_num, hits_3_nums/dev_num, hits_10_nums/dev_num

In [48]:
train_batch_size = 1500000
dev_batch_size = 20 
test_batch_size = 20
epochs = 40
margin = 1
print_frequency = 5 
validation = True 
dev_interval = 5  
best_mrr = 0
learning_rate = 0.001 
distance_norm = 3  
embedding_dim = 128 

In [49]:
with open('OpenBG500_entity2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip('\n').split('\t') for line in dat]
ent2id = {line[0]: i for i, line in enumerate(lines)}
id2ent = {i: line[0] for i, line in enumerate(lines)}
with open('OpenBG500_relation2text.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    lines = [line.strip().split('\t') for line in dat]
rel2id = {line[0]: i for i, line in enumerate(lines)}
with open('OpenBG500_train.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    train = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500_dev.tsv', 'r', encoding='utf-8') as fp:
    dat = fp.readlines()
    dev = [line.strip('\n').split('\t') for line in dat]
with open('OpenBG500_test.tsv', 'r', encoding='utf-8') as fp:
    test = fp.readlines()
    test = [line.strip('\n').split('\t') for line in test]
train_dataset = TripleDataset(ent2id, rel2id, train)
dev_dataset = TripleDataset(ent2id, rel2id, dev)
train_data_loader = data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
dev_data_loader = data.DataLoader(dev_dataset, batch_size=dev_batch_size)
test_dataset = TestDataset(ent2id, rel2id, test)
test_data_loader = data.DataLoader(test_dataset, batch_size=test_batch_size)

In [57]:
# model = TransE(len(ent2id), len(rel2id), norm=distance_norm, dim=embedding_dim).cuda()
model.load_state_dict(torch.load('transE_best.pth'))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MarginRankingLoss(margin=margin, reduction='mean')
print('start training...')
for epoch in range(epochs):
    all_loss = 0
    for i, (local_heads, local_relations, local_tails) in enumerate(train_data_loader):

        positive_triples = torch.stack((local_heads, local_relations, local_tails), dim=1).cuda()

        # 生成负样本
        head_or_tail = torch.randint(high=2, size=local_heads.size())
        random_entities = torch.randint(high=len(ent2id), size=local_heads.size())
        broken_heads = torch.where(head_or_tail == 1, random_entities, local_heads)
        broken_tails = torch.where(head_or_tail == 0, random_entities, local_tails)
        negative_triples = torch.stack((broken_heads, local_relations, broken_tails), dim=1).cuda()

        optimizer.zero_grad()
        pd, nd = model(positive_triples, negative_triples)
        loss = criterion(pd, nd, torch.tensor([-1], dtype=torch.long).cuda())
        loss.backward()
        all_loss += loss.data
        optimizer.step()
        if i % print_frequency == 0:
            print(
                f"epoch:{epoch}/{epochs}, step:{i}/{len(train_data_loader)}, loss={loss.item()}, avg_loss={all_loss / (i + 1)}")
    print(f"epoch:{epoch}/{epochs}, all_loss={all_loss}")

    # 验证
    if validation and (epoch + 1) % dev_interval == 0:
        print('testing...')
        improve = ''
        mrr, hits1, hits3, hits10 = model.evaluate(dev_data_loader)
        if mrr >= best_mrr:
            best_mrr = mrr
            improve = '*'
            torch.save(model.state_dict(), 'transE_best.pth')
        torch.save(model.state_dict(), 'transE_latest.pth')
        print(f'mrr: {mrr}, hit@1: {hits1}, hit@3: {hits3}, hit@10: {hits10}  {improve}')
    if not validation:
        torch.save(model.state_dict(), 'transE_latest.pth')

start training...
epoch:0/40, step:0/1, loss=0.19845178723335266, avg_loss=0.19845178723335266
epoch:0/40, all_loss=0.19845178723335266
epoch:1/40, step:0/1, loss=0.1952367126941681, avg_loss=0.1952367126941681
epoch:1/40, all_loss=0.1952367126941681
epoch:2/40, step:0/1, loss=0.19206662476062775, avg_loss=0.19206662476062775
epoch:2/40, all_loss=0.19206662476062775
epoch:3/40, step:0/1, loss=0.18893979489803314, avg_loss=0.18893979489803314
epoch:3/40, all_loss=0.18893979489803314
epoch:4/40, step:0/1, loss=0.18585456907749176, avg_loss=0.18585456907749176
epoch:4/40, all_loss=0.18585456907749176
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.61it/s]


mrr: 0.5556533932685852, hit@1: 0.4526, hit@3: 0.6274, hit@10: 0.774  *
epoch:5/40, step:0/1, loss=0.18281076848506927, avg_loss=0.18281076848506927
epoch:5/40, all_loss=0.18281076848506927
epoch:6/40, step:0/1, loss=0.17980754375457764, avg_loss=0.17980754375457764
epoch:6/40, all_loss=0.17980754375457764
epoch:7/40, step:0/1, loss=0.17684368789196014, avg_loss=0.17684368789196014
epoch:7/40, all_loss=0.17684368789196014
epoch:8/40, step:0/1, loss=0.17391838133335114, avg_loss=0.17391838133335114
epoch:8/40, all_loss=0.17391838133335114
epoch:9/40, step:0/1, loss=0.17103147506713867, avg_loss=0.17103147506713867
epoch:9/40, all_loss=0.17103147506713867
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.74it/s]


mrr: 0.5562563538551331, hit@1: 0.453, hit@3: 0.629, hit@10: 0.7732  *
epoch:10/40, step:0/1, loss=0.16818170249462128, avg_loss=0.16818170249462128
epoch:10/40, all_loss=0.16818170249462128
epoch:11/40, step:0/1, loss=0.16536882519721985, avg_loss=0.16536882519721985
epoch:11/40, all_loss=0.16536882519721985
epoch:12/40, step:0/1, loss=0.16259199380874634, avg_loss=0.16259199380874634
epoch:12/40, all_loss=0.16259199380874634
epoch:13/40, step:0/1, loss=0.1598510593175888, avg_loss=0.1598510593175888
epoch:13/40, all_loss=0.1598510593175888
epoch:14/40, step:0/1, loss=0.15714551508426666, avg_loss=0.15714551508426666
epoch:14/40, all_loss=0.15714551508426666
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.43it/s]


mrr: 0.555634081363678, hit@1: 0.4514, hit@3: 0.629, hit@10: 0.774  
epoch:15/40, step:0/1, loss=0.15447497367858887, avg_loss=0.15447497367858887
epoch:15/40, all_loss=0.15447497367858887
epoch:16/40, step:0/1, loss=0.1518394947052002, avg_loss=0.1518394947052002
epoch:16/40, all_loss=0.1518394947052002
epoch:17/40, step:0/1, loss=0.1492394357919693, avg_loss=0.1492394357919693
epoch:17/40, all_loss=0.1492394357919693
epoch:18/40, step:0/1, loss=0.14667434990406036, avg_loss=0.14667434990406036
epoch:18/40, all_loss=0.14667434990406036
epoch:19/40, step:0/1, loss=0.14414377510547638, avg_loss=0.14414377510547638
epoch:19/40, all_loss=0.14414377510547638
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.40it/s]


mrr: 0.5557294487953186, hit@1: 0.4518, hit@3: 0.6276, hit@10: 0.7742  
epoch:20/40, step:0/1, loss=0.14164771139621735, avg_loss=0.14164771139621735
epoch:20/40, all_loss=0.14164771139621735
epoch:21/40, step:0/1, loss=0.13918627798557281, avg_loss=0.13918627798557281
epoch:21/40, all_loss=0.13918627798557281
epoch:22/40, step:0/1, loss=0.13675960898399353, avg_loss=0.13675960898399353
epoch:22/40, all_loss=0.13675960898399353
epoch:23/40, step:0/1, loss=0.1343674510717392, avg_loss=0.1343674510717392
epoch:23/40, all_loss=0.1343674510717392
epoch:24/40, step:0/1, loss=0.13200916349887848, avg_loss=0.13200916349887848
epoch:24/40, all_loss=0.13200916349887848
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.43it/s]


mrr: 0.5543740391731262, hit@1: 0.4498, hit@3: 0.628, hit@10: 0.7732  
epoch:25/40, step:0/1, loss=0.12968449294567108, avg_loss=0.12968449294567108
epoch:25/40, all_loss=0.12968449294567108
epoch:26/40, step:0/1, loss=0.12739354372024536, avg_loss=0.12739354372024536
epoch:26/40, all_loss=0.12739354372024536
epoch:27/40, step:0/1, loss=0.12513600289821625, avg_loss=0.12513600289821625
epoch:27/40, all_loss=0.12513600289821625
epoch:28/40, step:0/1, loss=0.12291167676448822, avg_loss=0.12291167676448822
epoch:28/40, all_loss=0.12291167676448822
epoch:29/40, step:0/1, loss=0.12072039395570755, avg_loss=0.12072039395570755
epoch:29/40, all_loss=0.12072039395570755
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.42it/s]


mrr: 0.553005039691925, hit@1: 0.4472, hit@3: 0.627, hit@10: 0.773  
epoch:30/40, step:0/1, loss=0.11856217682361603, avg_loss=0.11856217682361603
epoch:30/40, all_loss=0.11856217682361603
epoch:31/40, step:0/1, loss=0.1164371594786644, avg_loss=0.1164371594786644
epoch:31/40, all_loss=0.1164371594786644
epoch:32/40, step:0/1, loss=0.11434491723775864, avg_loss=0.11434491723775864
epoch:32/40, all_loss=0.11434491723775864
epoch:33/40, step:0/1, loss=0.11228471249341965, avg_loss=0.11228471249341965
epoch:33/40, all_loss=0.11228471249341965
epoch:34/40, step:0/1, loss=0.11025642603635788, avg_loss=0.11025642603635788
epoch:34/40, all_loss=0.11025642603635788
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.12it/s]


mrr: 0.5521336793899536, hit@1: 0.445, hit@3: 0.6282, hit@10: 0.773  
epoch:35/40, step:0/1, loss=0.10825997591018677, avg_loss=0.10825997591018677
epoch:35/40, all_loss=0.10825997591018677
epoch:36/40, step:0/1, loss=0.10629507899284363, avg_loss=0.10629507899284363
epoch:36/40, all_loss=0.10629507899284363
epoch:37/40, step:0/1, loss=0.10436109453439713, avg_loss=0.10436109453439713
epoch:37/40, all_loss=0.10436109453439713
epoch:38/40, step:0/1, loss=0.10245858132839203, avg_loss=0.10245858132839203
epoch:38/40, all_loss=0.10245858132839203
epoch:39/40, step:0/1, loss=0.10058709979057312, avg_loss=0.10058709979057312
epoch:39/40, all_loss=0.10058709979057312
testing...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:05<00:00, 49.10it/s]


mrr: 0.5508900880813599, hit@1: 0.4428, hit@3: 0.6292, hit@10: 0.774  


In [58]:
predict_all = []
model.load_state_dict(torch.load('transE_best.pth'))
for heads, relations in tqdm.tqdm(test_data_loader):
    predict_id = model.link_predict(heads.cuda(), relations.cuda())
    predict_list = predict_id.cpu().numpy().reshape(1,-1).squeeze(0).tolist()
    predict_ent = map(lambda x: id2ent[x], predict_list)
    predict_all.extend(predict_ent)
print('prediction finished !')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:04<00:00, 50.47it/s]

prediction finished !





In [59]:
with open('submission.tsv', 'w', encoding='utf-8') as f:
    for i in range(len(test)):
        list = [x + '\t' for x in test[i]] + [x + '\n' if i == 9 else x + '\t' for i, x in enumerate(predict_all[i*10:i*10+10])]
        f.writelines(list)
print('file saved !')

file saved !
