In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import math

In [None]:
import numpy as np
import pandas as pd
from collections import Counter
import random

class TrainSet(Dataset):
    """Training triples for TransE, each paired with one pre-sampled corrupted triple.

    Items are ``[positive, negative]`` where both are int64 rows of the form
    ``[head_id, relation_id, tail_id]``. Entity ids are assigned in first-appearance
    order over all heads, followed by tails that never occur as a head; relation ids
    in first-appearance order.
    """

    DEFAULT_PATH = '/kg/transe/fb15k/freebase_mtr100_mte100-train.txt'

    def __init__(self, path=DEFAULT_PATH):
        """Load, index and corrupt the triples stored in ``path``.

        :param path: tab-separated file with one ``head<TAB>relation<TAB>tail`` per line
        """
        super(TrainSet, self).__init__()
        self.raw_data, self.entity_to_index, self.relation_to_index = self.load_text(path)
        self.entity_num, self.relation_num = len(self.entity_to_index), len(self.relation_to_index)
        self.triple_num = self.raw_data.shape[0]
        print(f'Train set: {self.entity_num} entities, {self.relation_num} relations, {self.triple_num} triplets.')
        self.pos_data = self.convert_word_to_index(self.raw_data)
        self.related_dic = self.get_related_entity()
        self.neg_data = self.generate_neg()

    def __len__(self):
        return self.triple_num

    def __getitem__(self, item):
        return [self.pos_data[item], self.neg_data[item]]

    def load_text(self, path=DEFAULT_PATH):
        """Read the triple file and build the entity / relation vocabularies.

        :param path: tab-separated file with one ``head<TAB>relation<TAB>tail`` per line
        :return: (string triples as an ``[n, 3]`` array, entity->id dict, relation->id dict)
        """
        raw = pd.read_csv(path, sep='\t', header=None,
                          names=['head', 'relation', 'tail'],
                          dtype=str, keep_default_na=False, encoding='utf-8')
        # Column-wise strip; DataFrame.applymap is deprecated since pandas 2.1, and
        # dtype=str guarantees .str is valid even if a column looks numeric.
        raw = raw.apply(lambda col: col.str.strip())
        # First-appearance order: all heads, then tails never seen as a head.
        entities = dict.fromkeys(list(raw['head']) + list(raw['tail']))
        relations = dict.fromkeys(raw['relation'])
        entity_dic = {word: idx for idx, word in enumerate(entities)}
        relation_dic = {word: idx for idx, word in enumerate(relations)}
        return raw.values, entity_dic, relation_dic

    def convert_word_to_index(self, data):
        """Map string triples to id rows.

        :param data: iterable of ``(head, relation, tail)`` strings
        :return: int64 array of shape ``[n, 3]`` (``[0, 3]`` for empty input)
        :raises KeyError: if a name is missing from the vocabularies
        """
        index_list = np.array(
            [[self.entity_to_index[h], self.relation_to_index[r], self.entity_to_index[t]] for h, r, t in data],
            dtype=np.int64)
        return index_list.reshape(-1, 3)

    def generate_neg(self):
        """Corrupt every positive triple by replacing its head or its tail (50/50).

        A candidate is rejected when it is linked, by any relation, to the entity that
        stays in the triple; this excludes known positives and the original entity.

        :return: int64 array with the same shape as ``self.pos_data``
        :raises ValueError: if a triple has no admissible replacement on either side
        """
        population = range(self.entity_num)
        candidates, cursor = [], 0
        neg_data = []
        for head, relation, tail in self.pos_data:
            head_banned = self.related_dic[tail]  # forbidden replacement heads
            tail_banned = self.related_dic[head]  # forbidden replacement tails
            # Without this guard the rejection loop below would never terminate.
            if len(head_banned) >= self.entity_num and len(tail_banned) >= self.entity_num:
                raise ValueError(f'No valid negative exists for triple {(int(head), int(relation), int(tail))}')
            while True:
                # Draw candidates in chunks to amortise the cost of random.choices.
                if cursor == len(candidates):
                    candidates, cursor = random.choices(population=population, k=int(1e4)), 0
                neg = candidates[cursor]
                cursor += 1
                if random.randint(0, 1) == 0:
                    # replace head
                    if neg not in head_banned:
                        neg_data.append([neg, relation, tail])
                        break
                elif neg not in tail_banned:
                    # replace tail
                    neg_data.append([head, relation, neg])
                    break

        return np.array(neg_data, dtype=np.int64).reshape(-1, 3)

    def get_related_entity(self):
        """Build an undirected entity adjacency map from ``self.pos_data``, ignoring relation type.

        :return: {entity_id: {related_entity_id_1, related_entity_id_2, ...}}
        """
        related_dic = {}
        for head, _, tail in self.pos_data:
            head, tail = int(head), int(tail)
            related_dic.setdefault(head, set()).add(tail)
            related_dic.setdefault(tail, set()).add(head)
        return related_dic


class TestSet(Dataset):
    """Evaluation triples.

    Items are raw string triples until ``convert_word_to_index`` is called with the
    training vocabularies; after that they are int64 ``[head_id, relation_id, tail_id]`` rows.
    """

    DEFAULT_PATH = '/kg/transe/fb15k/freebase_mtr100_mte100-test.txt'

    def __init__(self, path=DEFAULT_PATH):
        """Load the triples stored in ``path`` (tab-separated head/relation/tail)."""
        super(TestSet, self).__init__()
        self.raw_data = self.load_text(path)
        self.data = self.raw_data
        print(f"Test set: {self.raw_data.shape[0]} triplets")

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return self.data.shape[0]

    def load_text(self, path=DEFAULT_PATH):
        """Read whitespace-stripped string triples from ``path`` as an ``[n, 3]`` array."""
        raw_data = pd.read_csv(path, sep='\t', header=None,
                               names=['head', 'relation', 'tail'],
                               dtype=str, keep_default_na=False, encoding='utf-8')
        # Column-wise strip; DataFrame.applymap is deprecated since pandas 2.1.
        raw_data = raw_data.apply(lambda col: col.str.strip())
        return raw_data.values

    def convert_word_to_index(self, entity_to_index, relation_to_index, data):
        """Replace ``self.data`` with id rows built from the given vocabularies.

        :param entity_to_index: entity name -> id (normally the training vocabulary)
        :param relation_to_index: relation name -> id
        :param data: iterable of ``(head, relation, tail)`` strings
        :return: the int64 ``[n, 3]`` array now stored in ``self.data``
        :raises KeyError: if a triple mentions a name absent from the vocabularies
        """
        index_list = np.array(
            [[entity_to_index[h], relation_to_index[r], entity_to_index[t]] for h, r, t in data],
            dtype=np.int64).reshape(-1, 3)
        self.data = index_list
        return index_list

In [None]:
class TransE(nn.Module):
    """TransE embedding model: a triple (h, r, t) is scored by ||h + r - t||_p.

    Training minimises a margin ranking loss between positive and corrupted triples;
    evaluation ranks every entity as a candidate tail by the same distance.
    """

    def __init__(self, ent_num, rel_num, device, dim=50, d_norn=2, margin=1):
        '''
        :param ent_num: number of entities
        :param rel_num: number of relations
        :param device: target device, stored as ``self.device`` (parameters are moved by ``.to()``)
        :param dim: embedding dimension
        :param d_norn: p of the L_p distance, used by both the loss and ``tail_predict``
        :param margin: margin hyperparameter of the ranking loss
        '''
        super().__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num

        self.device = device
        self.dim = dim
        self.d_norn = d_norn
        # A Python float broadcasts against tensors on any device, so it cannot be
        # stranded on the wrong device when the module is moved with .to().
        self.margin = float(margin)

        bound = 6 / math.sqrt(self.dim)
        ent_init = torch.empty(ent_num, self.dim).uniform_(-bound, bound)
        rel_init = torch.empty(rel_num, self.dim).uniform_(-bound, bound)
        # Relations are L2-normalised once, at initialisation only.
        rel_init = rel_init / torch.norm(rel_init, dim=1, keepdim=True)

        self.ent_emb = nn.Embedding.from_pretrained(ent_init, freeze=False)
        self.rel_emb = nn.Embedding.from_pretrained(rel_init, freeze=False)

    def forward(self, pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail):
        '''
        :param pos_head: [batch_size] entity ids
        :param pos_relation: [batch_size] relation ids
        :param pos_tail: [batch_size] entity ids
        :param neg_head: [batch_size] entity ids
        :param neg_relation: [batch_size] relation ids
        :param neg_tail: [batch_size] entity ids
        :return: scalar tensor, margin ranking loss summed over the batch
        '''
        pos_dis = self.ent_emb(pos_head) + self.rel_emb(pos_relation) - self.ent_emb(pos_tail)
        neg_dis = self.ent_emb(neg_head) + self.rel_emb(neg_relation) - self.ent_emb(neg_tail)
        # Autograd already tracks this graph; forcing requires_grad_() would only hide detach bugs.
        return self.calculate_loss(pos_dis, neg_dis)

    def calculate_loss(self, pos_dis, neg_dis):
        '''
        :param pos_dis: [batch_size, embed_dim] values of h + r - t for positive triples
        :param neg_dis: [batch_size, embed_dim] values of h + r - t for negative triples
        :return: scalar tensor, sum over the batch of relu(margin + d(pos) - d(neg))
        '''
        distance_diff = self.margin + torch.norm(pos_dis, p=self.d_norn, dim=1) - \
                        torch.norm(neg_dis, p=self.d_norn, dim=1)

        return torch.sum(F.relu(distance_diff))

    @torch.no_grad()
    def tail_predict(self, head, relation, tail, k=10):
        '''
        Count the triples whose true tail is among the k entities closest to h + r.

        :param head: [batch_size] entity ids
        :param relation: [batch_size] relation ids
        :param tail: [batch_size] true tail ids
        :param k: hits@k cut-off (clamped to the number of entities)
        :return: int, number of hits in the batch
        '''
        # query: [batch_size, embed_size]
        query = self.ent_emb(head) + self.rel_emb(relation)
        # dist: [batch_size, N], using the same L_p distance as training. Pairwise
        # distances are computed directly instead of expanding a [batch, N, dim] tensor.
        dist = torch.cdist(query, self.ent_emb.weight, p=self.d_norn,
                           compute_mode='donot_use_mm_for_euclid_dist')
        # indices: [batch_size, k]
        _, indices = torch.topk(dist, min(k, self.ent_num), dim=1, largest=False)
        return torch.sum(torch.eq(indices, tail.view(-1, 1))).item()

In [None]:
import torch
from torch import nn, optim

# Training / evaluation configuration.
SEED = 42
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embed_dim = 50        # embedding dimension
num_epochs = 50
train_batch_size = 32
test_batch_size = 256
lr = 1e-2
momentum = 0
gamma = 1             # margin of the ranking loss
d_norm = 2            # p of the L_p distance
top_k = 10            # k for hits@k

# Seed every RNG in play: `random` drives negative sampling, torch the embedding init.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def main():
    """Train TransE on the training split and report tail-prediction hits@k on the test split each epoch."""
    train_dataset = TrainSet()
    test_dataset = TestSet()
    # Test triples must use the training vocabulary so ids line up with the embeddings.
    test_dataset.convert_word_to_index(train_dataset.entity_to_index, train_dataset.relation_to_index,
                                       test_dataset.raw_data)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    # The hit count is order-independent, so evaluation needs no shuffling.
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
    transe = TransE(train_dataset.entity_num, train_dataset.relation_num, device,
                    dim=embed_dim, d_norn=d_norm, margin=gamma).to(device)
    optimizer = optim.SGD(transe.parameters(), lr=lr, momentum=momentum)
    for epoch in range(num_epochs):
        if epoch > 0:
            # Draw fresh corrupted triples each epoch instead of reusing the load-time sample forever.
            train_dataset.neg_data = train_dataset.generate_neg()
        # e <- e / ||e||, in place so the optimizer keeps tracking the same parameter storage.
        with torch.no_grad():
            ent_weight = transe.ent_emb.weight
            ent_weight.div_(torch.norm(ent_weight, dim=1, keepdim=True))

        transe.train()
        total_loss = 0.0
        for pos, neg in train_loader:
            # pos / neg: [batch_size, 3] -> [3, batch_size]; rows are head, relation, tail ids.
            pos_head, pos_relation, pos_tail = pos.to(device).t()
            neg_head, neg_relation, neg_tail = neg.to(device).t()
            loss = transe(pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch+1}, loss = {total_loss/len(train_dataset)}")

        # Evaluation needs no autograd graph; building one wastes a lot of memory.
        transe.eval()
        hits = 0
        with torch.no_grad():
            for data in test_loader:
                # data: [batch_size, 3] -> [3, batch_size]
                head, relation, tail = data.to(device).t()
                hits += transe.tail_predict(head, relation, tail, k=top_k)
        print(f"===>epoch {epoch+1}, test hits@{top_k} {hits/len(test_dataset)}")


if __name__ == '__main__':
    main()

Test set: 59071 triplets


KeyboardInterrupt: 