In [6]:
import torch
from torch import optim, nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from load_fb15k import TrainSet, TestSet
import math

In [7]:

class TranE(nn.Module):
    def __init__(self, entity_num, relation_num, device, dim=50, d_norm=2, gamma=1):
        """
        :param entity_num: number of entities
        :param relation_num: number of relations
        :param dim: embedding dim
        :param device:
        :param d_norm: measure d(h+l, t), either L1-norm or L2-norm
        :param gamma: margin hyperparameter
        """
        super(TranE, self).__init__()
        self.dim = dim
        self.d_norm = d_norm
        self.device = device
        self.gamma = torch.FloatTensor([gamma]).to(self.device)
        self.entity_num = entity_num
        self.relation_num = relation_num
        self.entity_embedding = nn.Embedding.from_pretrained(
            torch.empty(entity_num, self.dim).uniform_(-6 / math.sqrt(self.dim), 6 / math.sqrt(self.dim)), freeze=False)
        self.relation_embedding = nn.Embedding.from_pretrained(
            torch.empty(relation_num, self.dim).uniform_(-6 / math.sqrt(self.dim), 6 / math.sqrt(self.dim)),
            freeze=False)
        # l <= l / ||l||
        relation_norm = torch.norm(self.relation_embedding.weight.data, dim=1, keepdim=True)
        self.relation_embedding.weight.data = self.relation_embedding.weight.data / relation_norm

    def forward(self, pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail):
        """
        :param pos_head: [batch_size]
        :param pos_relation: [batch_size]
        :param pos_tail: [batch_size]
        :param neg_head: [batch_size]
        :param neg_relation: [batch_size]
        :param neg_tail: [batch_size]
        :return: triples loss
        """
        pos_dis = self.entity_embedding(pos_head) + self.relation_embedding(pos_relation) - self.entity_embedding(
            pos_tail)
        neg_dis = self.entity_embedding(neg_head) + self.relation_embedding(neg_relation) - self.entity_embedding(
            neg_tail)
        # return pos_head_and_relation, pos_tail, neg_head_and_relation, neg_tail
        return self.calculate_loss(pos_dis, neg_dis).requires_grad_()

    def calculate_loss(self, pos_dis, neg_dis):
        """
        :param pos_dis: [batch_size, embed_dim]
        :param neg_dis: [batch_size, embed_dim]
        :return: triples loss: [batch_size]
        """
        distance_diff = self.gamma + torch.norm(pos_dis, p=self.d_norm, dim=1) - torch.norm(neg_dis, p=self.d_norm,
                                                                                            dim=1)
        return torch.sum(F.relu(distance_diff))

    def tail_predict(self, head, relation, tail, k=10):
        """
        to do tail prediction hits@k
        :param head: [batch_size]
        :param relation: [batch_size]
        :param tail: [batch_size]
        :param k: hits@k
        :return:
        """
        # head: [batch_size]
        # h_and_r: [batch_size, embed_size] => [batch_size, 1, embed_size] => [batch_size, N, embed_size]
        h_and_r = self.entity_embedding(head) + self.relation_embedding(relation)
        h_and_r = torch.unsqueeze(h_and_r, dim=1)
        h_and_r = h_and_r.expand(h_and_r.shape[0], self.entity_num, self.dim)
        # embed_tail: [batch_size, N, embed_size]
        embed_tail = self.entity_embedding.weight.data.expand(h_and_r.shape[0], self.entity_num, self.dim)
        # indices: [batch_size, k]
        values, indices = torch.topk(torch.norm(h_and_r - embed_tail, dim=2), k, dim=1, largest=False)
        # tail: [batch_size] => [batch_size, 1]
        tail = tail.view(-1, 1)
        return torch.sum(torch.eq(indices, tail)).item()

In [8]:
device = torch.device('cuda')
embed_dim = 50
num_epochs = 50
train_batch_size = 32
test_batch_size = 256
lr = 1e-2
momentum = 0
gamma = 1
d_norm = 2
top_k = 10

In [9]:
train_dataset = TrainSet()
test_dataset = TestSet()
test_dataset.convert_word_to_index(train_dataset.entity_to_index, train_dataset.relation_to_index,
                                       test_dataset.raw_data)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)
transe = TranE(train_dataset.entity_num, train_dataset.relation_num, device, dim=embed_dim, d_norm=d_norm,
                gamma=gamma).to(device)
optimizer = optim.SGD(transe.parameters(), lr=lr, momentum=momentum)

Train set: 14951 entities, 1345 relations, 483142 triplets.
Test set: 59071 triplets


In [10]:
for epoch in range(num_epochs):
    # e <= e / ||e||
    entity_norm = torch.norm(transe.entity_embedding.weight.data, dim=1, keepdim=True)
    transe.entity_embedding.weight.data = transe.entity_embedding.weight.data / entity_norm
    total_loss = 0
    for batch_idx, (pos, neg) in enumerate(train_loader):
        pos, neg = pos.to(device), neg.to(device)
        # pos: [batch_size, 3] => [3, batch_size]
        pos = torch.transpose(pos, 0, 1)
        # pos_head, pos_relation, pos_tail: [batch_size]
        pos_head, pos_relation, pos_tail = pos[0], pos[1], pos[2]
        neg = torch.transpose(neg, 0, 1)
        # neg_head, neg_relation, neg_tail: [batch_size]
        neg_head, neg_relation, neg_tail = neg[0], neg[1], neg[2]
        loss = transe(pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch+1}, loss = {total_loss/train_dataset.__len__()}")
    corrct_test = 0
    for batch_idx, data in enumerate(test_loader):
        data = data.to(device)
        # data: [batch_size, 3] => [3, batch_size]
        data = torch.transpose(data, 0, 1)
        corrct_test += transe.tail_predict(data[0], data[1], data[2], k=top_k)
    print(f"===>epoch {epoch+1}, test accuracy {corrct_test/test_dataset.__len__()}")

epoch 1, loss = 0.8006324084371367
===>epoch 1, test accuracy 0.21287941629564422
epoch 2, loss = 0.6110726407517862
===>epoch 2, test accuracy 0.24372365458516024
epoch 3, loss = 0.465637634008911
===>epoch 3, test accuracy 0.2644783396251968
epoch 4, loss = 0.35256436182743967
===>epoch 4, test accuracy 0.28269370757224355
epoch 5, loss = 0.27879381454206437
===>epoch 5, test accuracy 0.29620287450694927
epoch 6, loss = 0.2320634035086554
===>epoch 6, test accuracy 0.3076298014254033
epoch 7, loss = 0.20112843379949455
===>epoch 7, test accuracy 0.31382573513229844
epoch 8, loss = 0.1795059778154809
===>epoch 8, test accuracy 0.3217653332430465
epoch 9, loss = 0.16349206028165192
===>epoch 9, test accuracy 0.3250156591220734
epoch 10, loss = 0.15096781277536048
===>epoch 10, test accuracy 0.3303990113592118
epoch 11, loss = 0.14093025786640273
===>epoch 11, test accuracy 0.33386941138629783
epoch 12, loss = 0.13270950060343936
===>epoch 12, test accuracy 0.33688273433664573
epoch 13,