# Knowledge Graph Embedding

In this tutorial, I briefly introduce knowledge graph embeddings and implement [TransE](https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html) (Bordes et al., 2013)
from scratch.

# Brief Background of Learning Embeddings


+ [Bengio et al. (2003)](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf) suggest that **learning a distributed representation for words** can be an effective means of tackling the curse of dimensionality when learning the joint probability of sequences of words in a language.

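+ In this context, the impact of **the curse of dimensionality** can be shown as follows. To model the joint probability of **n** words in a language with vocabulary **V**, every one of the $|V|^n$ possible word sequences has to be accounted for; the cell below counts these sequences for $n = 3$ and increasing vocabulary sizes.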



In [1]:
from itertools import product


def count_sequences(n, size_of_vocabulary):
    """Count the distinct sequences of ``n`` items over a vocabulary.

    The sequences are enumerated with ``itertools.product`` and counted one by
    one, so nothing is held in memory (the tuples produced by ``product`` are
    already distinct, hence no set is needed to deduplicate them).

    Args:
        n: Length of each sequence.
        size_of_vocabulary: Number of distinct items in the vocabulary.

    Returns:
        The number of sequences, i.e. ``size_of_vocabulary ** n``.
    """
    return sum(1 for _ in product(range(size_of_vocabulary), repeat=n))


# Show how quickly the number of joint outcomes grows with the vocabulary size.
for n, size_of_vocabulary in [(3, 1), (3, 10), (3, 100)]:
    print(f'To compute the joint probability of {n} items of a vocabulary of size {size_of_vocabulary}: '
          f'{count_sequences(n, size_of_vocabulary)}')

To compute the joint probability of 3 items of a vocabulary of size 1: 1
To compute the joint probability of 3 items of a vocabulary of size 10: 1000
To compute the joint probability of 3 items of a vocabulary of size 100: 1000000


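***The number of parameters needed to learn the joint probability distribution does not grow linearly: it is $|V|^n$, i.e. it grows polynomially with the vocabulary size and exponentially with the sequence length $n$.***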

# Translating Embeddings for Modeling Multi-relational Data
[Bordes et al. 2013](https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html) propose to embed entities and relations of multi-relational data in a low-dimensional vector space. In other words, **the goal is to model large knowledge graphs by learning distributed representations for entities and relations based on a translation operation**.


# Workflow
1. Parse input knowledge graph via KG class
2. Generate training dataset via DatasetTriple class
3. Train TransE
4. Report the filtered link prediction results of TransE.

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.simplefilter("ignore", UserWarning)

In [3]:
class KG:
    """Knowledge graph read from ``train.txt``, ``valid.txt`` and ``test.txt``.

    Every non-blank line of a split file is split on whitespace into a
    ``subject predicate object`` triple. Entities and relations are collected
    over all three splits, sorted, and mapped to consecutive integer ids.

    Attributes (all set by ``__init__``):
        train, valid, test: Triples of each split as lists of string lists.
        all_triples: Concatenation ``train + valid + test``.
        entities, relations: Sorted lists of distinct entity / relation names.
        entity_idxs, relation_idxs: Name -> integer id mappings.
        train_idx, valid_idx, test_idx: Splits as ``(s_id, p_id, o_id)`` tuples.
        sp_vocab, so_vocab, po_vocab: For every triple in any split, map
            ``(s, p) -> [o, ...]``, ``(s, o) -> [p, ...]`` and
            ``(p, o) -> [s, ...]``; used to filter known triples at evaluation.
    """

    def __init__(self, data_dir=None):
        """Load and index the dataset and print a short summary.

        Args:
            data_dir: Path prefix of the dataset. File names are appended by
                plain string concatenation, so a directory must end with a
                path separator (e.g. ``'KGs/UMLS/'``).

        Raises:
            TypeError: If ``data_dir`` is ``None``.
        """
        if data_dir is None:
            # Fail with a clear message instead of the opaque "str + NoneType" error.
            raise TypeError("data_dir must be a path prefix such as 'KGs/UMLS/', not None")

        # 1. Parse the benchmark dataset
        s = '------------------- Description of Dataset' + data_dir + '----------------------------'
        print(f'\n{s}')
        self.train = self.load_data(data_dir + 'train.txt', add_reciprical=False)
        self.valid = self.load_data(data_dir + 'valid.txt', add_reciprical=False)
        self.test = self.load_data(data_dir + 'test.txt', add_reciprical=False)

        self.all_triples = self.train + self.valid + self.test
        self.entities = self.get_entities(self.all_triples)
        self.relations = self.get_relations(self.all_triples)

        # 2. Index entities and relations
        self.entity_idxs = {entity: idx for idx, entity in enumerate(self.entities)}
        self.relation_idxs = {relation: idx for idx, relation in enumerate(self.relations)}

        print(f'Number of triples: {len(self.all_triples)}')
        print(f'Number of entities: {len(self.entities)}')
        print(f'Number of relations: {len(self.relations)}')
        print(f'Number of triples on train set: {len(self.train)}')
        print(f'Number of triples on valid set: {len(self.valid)}')
        print(f'Number of triples on test set: {len(self.test)}')
        s = len(s) * '-'
        print(f'{s}\n')

        # 3. Index train, validation and test sets
        self.train_idx = self._index_triples(self.train)
        self.valid_idx = self._index_triples(self.valid)
        self.test_idx = self._index_triples(self.test)

        # 4. Create mappings for the filtered link prediction
        self.sp_vocab = dict()
        self.po_vocab = dict()
        self.so_vocab = dict()

        # Same order as all_triples (train, then valid, then test).
        for s_idx, p_idx, o_idx in self.train_idx + self.valid_idx + self.test_idx:
            self.sp_vocab.setdefault((s_idx, p_idx), []).append(o_idx)
            self.so_vocab.setdefault((s_idx, o_idx), []).append(p_idx)
            self.po_vocab.setdefault((p_idx, o_idx), []).append(s_idx)

    def _index_triples(self, triples):
        """Translate string triples into ``(s_id, p_id, o_id)`` integer tuples."""
        return [(self.entity_idxs[s], self.relation_idxs[p], self.entity_idxs[o]) for s, p, o in triples]

    @staticmethod
    def load_data(data_dir, add_reciprical=True):
        """Read whitespace-separated triples from a text file.

        Args:
            data_dir: Path of the file to read (despite the name, a file path).
            add_reciprical: If true, append for every triple ``(s, p, o)`` the
                inverse triple ``(o, p + "_reverse", s)``.

        Returns:
            A list of token lists, one per non-blank line. An empty file
            yields an empty list.
        """
        with open(data_dir, "r", encoding="utf-8") as f:
            # Skip blank lines: they would otherwise become empty "triples"
            # and break entity/relation extraction later on.
            data = [line.split() for line in f if line.strip()]
        if add_reciprical:
            data += [[i[2], i[1] + "_reverse", i[0]] for i in data]
        return data

    @staticmethod
    def get_relations(data):
        """Return the sorted distinct relations (second item of each triple)."""
        return sorted({d[1] for d in data})

    @staticmethod
    def get_entities(data):
        """Return the sorted distinct entities (first and third item of each triple)."""
        return sorted({d[0] for d in data} | {d[2] for d in data})

    @property
    def num_entities(self):
        """Number of distinct entities."""
        return len(self.entities)

    @property
    def num_relations(self):
        """Number of distinct relations."""
        return len(self.relations)

In [4]:
# Load the UMLS benchmark. The trailing slash is required because KG appends
# the split file names to data_dir by string concatenation.
kg = KG(data_dir='KGs/UMLS/')


------------------- Description of DatasetKGs/UMLS/----------------------------
Number of triples: 6529
Number of entities: 135
Number of relations: 46
Number of triples on train set: 5216
Number of triples on valid set: 652
Number of triples on test set: 661
-------------------------------------------------------------------------------



# Creating Dataset from Knowledge Graph

In [5]:
class DatasetTriple(torch.utils.data.Dataset):
    """Dataset of indexed ``(head, relation, tail)`` triples with negative sampling.

    ``__getitem__`` yields the positive triples; ``collate_fn`` (meant to be
    passed to a ``DataLoader``) appends randomly corrupted triples to a batch.
    """

    def __init__(self, data, num_entities=None, nneg=1, **kwargs):
        """Store the triples column-wise.

        Args:
            data: Sequence of ``(head_id, relation_id, tail_id)`` integer triples.
            num_entities: Exclusive upper bound of the entity ids drawn when
                corrupting triples; required by ``collate_fn``.
            nneg: Number of corrupted triples generated per positive triple.
            **kwargs: Ignored; allows passing a shared hyperparameter dict.
        """
        # Build the integer tensor directly. torch.Tensor(data) would create
        # a float32 tensor first, which cannot represent ids above 2**24 exactly.
        data = torch.as_tensor(data, dtype=torch.long)
        if data.numel() == 0:
            # Keep the column slicing below valid for an empty dataset.
            data = data.reshape(0, 3)
        self.head_idx = data[:, 0]
        self.rel_idx = data[:, 1]
        self.tail_idx = data[:, 2]
        self.num_entities = num_entities
        self.nneg = nneg
        assert self.head_idx.shape == self.rel_idx.shape == self.tail_idx.shape

        self.length = len(self.head_idx)

    def __len__(self):
        """Return the number of positive triples."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``idx``-th triple as ``(head, relation, tail)`` tensors."""
        h = self.head_idx[idx]
        r = self.rel_idx[idx]
        t = self.tail_idx[idx]
        return h, r, t

    def collate_fn(self, batch):
        """Stack a batch of positive triples and append corrupted ones.

        For the whole batch either every head or every tail is replaced (chosen
        with probability 0.5) by entity ids drawn uniformly from
        ``[0, num_entities)``. The negatives are laid out as ``nneg``
        consecutive copies of the batch. A corrupted triple is not checked
        against the known triples, so it may coincide with a true one.

        Args:
            batch: Sequence of ``(head, relation, tail)`` items, each an int or
                a zero-dimensional tensor.

        Returns:
            Tuple ``(h, r, t, label)`` of 1-D tensors of length
            ``len(batch) * (1 + nneg)``; ``label`` is ``1`` for the leading
            positive triples and ``-1`` for the corrupted ones.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if len(batch) == 0:
            raise ValueError("collate_fn received an empty batch")
        # Convert item by item so that both plain ints and 0-dim tensors work.
        batch = torch.tensor([[int(h), int(r), int(t)] for h, r, t in batch], dtype=torch.long)
        h, r, t = batch[:, 0], batch[:, 1], batch[:, 2]
        size_of_batch = batch.shape[0]
        label = torch.ones((size_of_batch,))

        # Generate Negative Triples
        corr = torch.randint(0, self.num_entities, (size_of_batch * self.nneg, 1))[:, 0]
        r_corr = r.repeat(self.nneg)
        if torch.rand(1).item() > .5:
            # Head corrupt: random heads, original tails.
            h_corr, t_corr = corr, t.repeat(self.nneg)
        else:
            # Tail corrupt: original heads, random tails.
            h_corr, t_corr = h.repeat(self.nneg), corr
        label_corr = -torch.ones(len(t_corr))

        # Stack True and Corrupted Triples
        h = torch.cat((h, h_corr), 0)
        r = torch.cat((r, r_corr), 0)
        t = torch.cat((t, t_corr), 0)
        label = torch.cat((label, label_corr), 0)
        return h, r, t, label

# Define TransE

In [6]:
class TransE(torch.nn.Module):
    """TransE scoring model: a triple is scored by ``|| h + r - t ||_2``.

    Entity embeddings are L2-normalised in ``forward`` before the distance is
    computed; relation embeddings are used as stored. Lower distance means a
    more plausible triple.
    """

    def __init__(self, embedding_dim, num_entities, num_relations, **kwargs):
        """Create and initialise the embedding tables.

        Args:
            embedding_dim: Size of every entity and relation vector.
            num_entities: Number of rows of the entity table.
            num_relations: Number of rows of the relation table.
            **kwargs: Ignored; allows passing a shared hyperparameter dict.
        """
        super().__init__()
        self.name = 'TransE'
        self.embedding_dim = embedding_dim
        self.num_entities = num_entities
        self.num_relations = num_relations

        self.emb_ent = nn.Embedding(self.num_entities, self.embedding_dim)
        self.emb_rel = nn.Embedding(self.num_relations, self.embedding_dim)

        # Uniform initialisation in [-6/sqrt(k), 6/sqrt(k)] with k = embedding_dim.
        bound = 6.0 / self.embedding_dim ** 0.5
        self.emb_ent.weight.data.uniform_(-bound, bound)
        self.emb_rel.weight.data.uniform_(-bound, bound)

    def forward(self, e1_idx, rel_idx, e2_idx):
        """Return the translation distance of each ``(head, relation, tail)``.

        Args:
            e1_idx: Integer tensor of head entity ids.
            rel_idx: Integer tensor of relation ids, same shape as ``e1_idx``.
            e2_idx: Integer tensor of tail entity ids, same shape as ``e1_idx``.

        Returns:
            Tensor with the shape of the index tensors holding
            ``|| normalize(h) + r - normalize(t) ||_2`` for every triple.
        """
        # (1) Embeddings of head, relation and tail
        emb_head = self.emb_ent(e1_idx)
        emb_rel = self.emb_rel(rel_idx)
        emb_tail = self.emb_ent(e2_idx)
        # (2) Normalize head and tail entities along the embedding axis. The
        # last axis is used so scalar and multi-dimensional index tensors work
        # as well as the usual 1-D batch.
        emb_head = F.normalize(emb_head, p=2, dim=-1)
        emb_tail = F.normalize(emb_tail, p=2, dim=-1)
        # (3) Compute Distance
        return torch.norm((emb_head + emb_rel) - emb_tail, p=2, dim=-1)

In [7]:
# Hyperparameters; the same dict is unpacked into DatasetTriple and TransE,
# which ignore the entries they do not need.
hparam = dict(
    embedding_dim=25,
    num_entities=kg.num_entities,
    num_relations=kg.num_relations,
    gamma=1.0,  # margin for loss
    lr=.01,  # learning rate for optimizer
    batch_size=256,
    num_epochs=100,
)

In [8]:
# Training triples plus a loader whose collate function adds the negatives.
dataset = DatasetTriple(data=kg.train_idx, **hparam)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=hparam['batch_size'],
    num_workers=4,
    shuffle=True,
    drop_last=True,  # the incomplete final batch of every epoch is skipped
    collate_fn=dataset.collate_fn,
)

# Model, margin (kept as a non-trainable tensor) and optimizer.
model = TransE(**hparam)
gamma = nn.Parameter(torch.Tensor([hparam['gamma']]), requires_grad=False)
optimizer = torch.optim.Adam(model.parameters(), lr=hparam['lr'])

In [9]:
# Train with the margin-based ranking loss
#     sum( max(0, gamma + d(positive) - d(negative)) ).
model.train()
# Epochs are numbered from 1, so the upper bound must be num_epochs + 1 for
# exactly num_epochs passes over the data.
for e in range(1, hparam['num_epochs'] + 1):
    epoch_loss = .0
    for h, r, t, labels in dataloader:
        optimizer.zero_grad()

        # Compute Distance based on translation,i.e. h + r \approx t provided that h,r,t \in G.
        # Call the module rather than .forward() so registered hooks run.
        distance = model(h, r, t)

        pos_distance = distance[labels == 1]
        neg_distance = distance[labels == -1]
        # The negatives are nneg consecutive copies of the batch; reshaping them
        # to (nneg, batch) lets every copy broadcast against the positives.
        # With nneg == 1 this yields the same loss as a plain subtraction.
        neg_distance = neg_distance.view(-1, pos_distance.shape[0])

        loss = F.relu(gamma + pos_distance - neg_distance).sum()

        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    if e % 10 == 0:
        print(f'{e}.th epoch sum of loss: {epoch_loss}')

10.th epoch sum of loss: 3351.2478790283203
20.th epoch sum of loss: 2201.5757446289062
30.th epoch sum of loss: 1927.5072174072266
40.th epoch sum of loss: 1828.472312927246
50.th epoch sum of loss: 1835.4241943359375
60.th epoch sum of loss: 1812.6691436767578
70.th epoch sum of loss: 1829.0188751220703
80.th epoch sum of loss: 1798.6147384643555
90.th epoch sum of loss: 1772.907299041748


# The filtered Link Prediction Evaluation

In [10]:
def _filtered_rank(scores, known_idxs, target_idx):
    """Return the 1-based filtered rank of ``target_idx`` (lower score is better).

    All entities in ``known_idxs`` (entities that form a known triple with the
    query) are pushed to the end of the ranking by setting their score to
    +inf, except the target itself. ``scores`` is modified in place.
    """
    target_score = scores[target_idx].item()
    scores[known_idxs] = float('inf')
    scores[target_idx] = target_score
    _, order = torch.sort(scores, descending=False)
    # Position of the target in the ascending order, shifted to start at 1.
    return int((order == target_idx).nonzero()[0].item()) + 1


if len(kg.test_idx) == 0:
    raise ValueError('Cannot evaluate link prediction on an empty test set.')

hits = dict()
reciprocal_ranks = []

# 1. Candidate entities are the same for every test triple; build them once.
all_entities = torch.arange(0, dataset.num_entities).long()

model.eval()
# No gradients are needed for evaluation; this also avoids building autograd graphs.
with torch.no_grad():
    for s_idx, p_idx, o_idx in kg.test_idx:
        subjects = torch.tensor(s_idx).repeat(dataset.num_entities)
        predicates = torch.tensor(p_idx).repeat(dataset.num_entities)
        objects = torch.tensor(o_idx).repeat(dataset.num_entities)

        # 2. Compute tail distances \forall x \in Entities: TransE(s,p,x)
        predictions_tails = model(e1_idx=subjects, rel_idx=predicates, e2_idx=all_entities)

        # 3. Compute head distances \forall x \in Entities: TransE(x,p,o)
        predictions_heads = model(e1_idx=all_entities, rel_idx=predicates, e2_idx=objects)

        # 4. Filtered ranks (starting at 1) for the missing tail and head entity
        filt_tail_entity_rank = _filtered_rank(predictions_tails, kg.sp_vocab[(s_idx, p_idx)], o_idx)
        filt_head_entity_rank = _filtered_rank(predictions_heads, kg.po_vocab[(p_idx, o_idx)], s_idx)

        # 5. Store reciprocal ranks (head and tail prediction summed per triple).
        reciprocal_ranks.append(1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank))

        # 6. Compute Hit@N: number of the two predictions ranked within the top N
        for hits_level in range(1, 11):
            I = 1 if filt_head_entity_rank <= hits_level else 0
            I += 1 if filt_tail_entity_rank <= hits_level else 0
            hits.setdefault(hits_level, []).append(I)

# 7. Aggregate once after the loop; every test triple contributes two predictions.
num_predictions = float(len(kg.test_idx) * 2)
mean_reciprocal_rank = sum(reciprocal_ranks) / num_predictions
hit_1 = sum(hits[1]) / num_predictions
hit_3 = sum(hits[3]) / num_predictions
hit_10 = sum(hits[10]) / num_predictions

results = {'H@1': hit_1, 'H@3': hit_3, 'H@10': hit_10, 'MRR': mean_reciprocal_rank}

In [11]:
results

{'H@1': 0.4296520423600605,
 'H@3': 0.6565809379727685,
 'H@10': 0.8381240544629349,
 'MRR': 0.5741972594591919}