<a href="https://colab.research.google.com/github/DougBR/KGs/blob/main/Knowledge_Graph_Embeddings_Simplistic_and_Powerful_Representations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CS224W Final Project - Knowledge Graph Embeddings**

*written by Mark Endo*

In this Colab, we will write a pipeline to learn knowledge graph (KG) embeddings for the task of predicting missing triples. Knowledge graphs are an example of heterogeneous graphs that capture entities, types, and relationships. We will be working with the [Freebase](https://paperswithcode.com/dataset/fb15k) dataset (FB15k-237) and implementing [TransE](https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf) in order to embed the graph.

In this notebook, we first learn how to process the FB15k-237 dataset using PyTorch Geometric, and we analyze the graph using NetworkX and an external tool called [Gephi](https://gephi.org/). Then, we dive into building the TransE architecture with PyTorch. Lastly, we train and evaluate our model for the task of predicting missing tails. We measure our method's performance using the Hits @ 10, Mean Rank, and MRR (mean reciprocal rank) metrics.

# Device

To get good results in a short period of time, we recommend running this Colab on a GPU. If you are using a GPU, make sure to set the variable `use_gpu` to `True`.
(So that this Colab also runs in a short amount of time without a GPU, we train for fewer epochs and evaluate less often in that case.)

In [None]:
# whether you are using a GPU to run this Colab
use_gpu = False
# whether you are using a custom GCE env to run the Colab (uses different CUDA)
custom_GCE_env = False

# Installation

In [None]:
if custom_GCE_env:
  !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
  !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
else:
  !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
  !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric



In [None]:
import os
import torch
import torch_geometric
import numpy as np
import math
torch_geometric.__version__

'2.0.2'

# Processing FB15k-237 Dataset

FB15k-237 captures entities, types, and relations from the Freebase dataset. There are a total of 310,116 triplets, with 14,541 entities and 237 relations. Each triplet is represented as $(h, l, t)$, where $h$ is the head entity, $t$ is the tail entity, and $l$ is the relation type. We will store this dataset as a PyTorch Geometric `InMemoryDataset`. The advantage of using this class is that we can load the data all at once and easily access it later. To define the `FB15kDataset` class as inheriting from `InMemoryDataset`, we have to implement the following members: `raw_file_names`, `processed_file_names`, `download()`, and `process()`. More information about them can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html).

Importantly, in the `process()` method we will load the dataset into train, valid, and test splits, where each has a Tensor representing its `edge_index` and `edge_type`. Later on, we can access the head entities with `edge_index[0,:]` and the tail entities with `edge_index[1,:]`. 
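
To make this layout concrete, here is a minimal pure-Python sketch of the same parsing logic applied to three made-up triples. The entity and relation names and the `toy_*` variables are hypothetical, and plain lists stand in for the tensors that `process()` builds.

```python
# Toy lookup tables; the real ones are read from entities.dict and relations.dict.
toy_entities = {'/m/alice': 0, '/m/bob': 1, '/m/paris': 2}
toy_relations = {'/people/knows': 0, '/people/lives_in': 1}

# Each line is "head<TAB>relation<TAB>tail", as in train.txt.
toy_lines = [
    '/m/alice\t/people/knows\t/m/bob',
    '/m/bob\t/people/lives_in\t/m/paris',
    '/m/alice\t/people/lives_in\t/m/paris',
]

rows = [line.split('\t') for line in toy_lines]
heads = [toy_entities[row[0]] for row in rows]
edge_type = [toy_relations[row[1]] for row in rows]
tails = [toy_entities[row[2]] for row in rows]

# Row 0 holds the heads and row 1 holds the tails, one column per triple.
edge_index = [heads, tails]

print(edge_index)  # [[0, 1, 0], [1, 2, 2]]
print(edge_type)   # [0, 1, 1]
```

Column `i` of `edge_index` together with entry `i` of `edge_type` describes one triple, which is why the heads and tails can later be read off as the two rows.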


In [None]:
class FB15kDataset(torch_geometric.data.InMemoryDataset):
  r"""FB15k-237 dataset from Freebase.
  Follows similar structure to torch_geometric.datasets.rel_link_pred_dataset
  
  Args: 
    root (string): Root directory where the dataset should be saved.
    transform (callable, optional): A function/transform that takes in an
        :obj:`torch_geometric.data.Data` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform (callable, optional): A function/transform that takes in
        an :obj:`torch_geometric.data.Data` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
  """
  data_path = 'https://raw.githubusercontent.com/DeepGraphLearning/' \
              'KnowledgeGraphEmbedding/master/data/FB15k-237'
  def __init__(self, root, transform=None, pre_transform=None):
    super().__init__(root, transform, pre_transform)
    self.data, self.slices = torch.load(self.processed_paths[0])

  @property
  def raw_file_names(self):
    return ['train.txt', 'valid.txt', 'test.txt', 
            'entities.dict', 'relations.dict']
  
  @property
  def processed_file_names(self):
    return ['data.pt']

  @property
  def raw_dir(self):
    return os.path.join(self.root, 'raw')
  
  def download(self):
    for file_name in self.raw_file_names:
      torch_geometric.data.download_url(f'{self.data_path}/{file_name}',
                                        self.raw_dir)

  def process(self):
    with open(os.path.join(self.raw_dir, 'entities.dict'), 'r') as f:
      lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
      entities_dict = {key: int(value) for value, key in lines}

    with open(os.path.join(self.raw_dir, 'relations.dict'), 'r') as f:
      lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
      relations_dict = {key: int(value) for value, key in lines}

    kwargs = {}
    for split in ['train', 'valid', 'test']:
      with open(os.path.join(self.raw_dir, f'{split}.txt'), 'r') as f:
        lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
        heads = [entities_dict[row[0]] for row in lines]
        relations = [relations_dict[row[1]] for row in lines]
        tails = [entities_dict[row[2]] for row in lines]
        kwargs[f'{split}_edge_index'] = torch.tensor([heads, tails])
        kwargs[f'{split}_edge_type'] = torch.tensor(relations)

    _data = torch_geometric.data.Data(num_entities=len(entities_dict), 
                                      num_relations=len(relations_dict),
                                      **kwargs)
    
    if self.pre_transform is not None:
      _data = self.pre_transform(_data)

    data, slices = self.collate([_data])

    torch.save((data, slices), self.processed_paths[0])

FB15k_dset = FB15kDataset(root='FB15k')
data = FB15k_dset[0]

# Visualizing Our Data

Now that we have easy access to our data, we will analyze it and export it into a format that can be visualized using tools such as [Gephi](https://gephi.org/). First, let's take a look at the size of our graph and its different splits.

In [None]:
print(f'The graph has a total of {data.num_entities} entities and {data.num_relations} relations.')
print(f'The train split has {data.train_edge_type.size()[0]} relation triples.')
print(f'The valid split has {data.valid_edge_type.size()[0]} relation triples.')
print(f'The test split has {data.test_edge_type.size()[0]} relation triples.')

The graph has a total of 14541 entities and 237 relations.
The train split has 272115 relation triples.
The valid split has 17535 relation triples.
The test split has 20466 relation triples.


Notice how this dataset has a very large number of entities and relation triples. This means that the graph will be very hard to visualize using standard tools like NetworkX and matplotlib. So, we will export the data so it can be visualized using tools made for large graphs. Even with these tools, the graph can still be quite difficult to manage, so we will sample the graph before exporting it.

In [None]:
import random
import networkx as nx

def visualize_graph(data, out_path, reduction=200):
  nodes = range(data.num_entities)
  edges = []
  edge_relations = []

  num_edges = 0
  for i in range(data.train_edge_type.size()[0]):
    if random.randint(1, reduction) == 1: # only take random subset of edges
      edges.append((int(data.train_edge_index[0,i]),
                    int(data.train_edge_index[1,i])))
      edge_relations.append(int(data.train_edge_type[i]))
      num_edges += 1
  
  G = nx.DiGraph()
  G.add_edges_from(edges)
  for i, (head, tail) in enumerate(edges):
    G[head][tail]["relation"] = edge_relations[i]
  nx.write_gexf(G, path=out_path)

visualize_graph(data, 'FB15k_sampled_dset.gexf')

This is what our resulting graph looks like in Gephi. The most common relationship types are at the top left.

![](https://drive.google.com/uc?export=view&id=12h5fuXznj0llqSDPbvA4jmaRVDO7p5dr) 

# TransE Implementation

Now, we will lay out the architecture for TransE. As stated earlier, the edges in KGs are represented as triples $(h, l, t)$. In TransE, we model both the entities and the relations in the same embedding space and adjust the embeddings so that $\textbf h + \textbf l \approx \textbf t$ for triples in the graph (where bold letters are embeddings). Formally, the loss is:

$\sum_{((h, l, t), (h', l, t')) \in T_{batch}} [\gamma + d(\textbf{h} + \textbf{l}, \textbf t) - d(\textbf{h'} + \textbf l, \textbf{t'})]_+$

where $\gamma$ is the margin, $[x]_+ = \max(0, x)$, and $(h', l, t')$ is a corrupted triple formed by replacing either the head or the tail with a random entity.

In total, the TransE algorithm is as follows:

![](https://drive.google.com/uc?export=view&id=18XD6duW40-jddEXgZzGHzyvjAstGv34P) 

In terms of model implementation, we will initialize $\textbf l$ and $\textbf e$ according to the pseudocode above. To calculate $d(\textbf{h} + \textbf{l}, \textbf t)$, we take the L2 norm of $\textbf h + \textbf l - \textbf t$.

*Note: we normalize $\textbf e$ every epoch instead of every minibatch because it led to improved performance, so it does not appear in our model code.*
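
As a small worked example of the distance and the margin term, the sketch below evaluates both on hand-picked 2-dimensional embeddings. It uses plain Python lists rather than PyTorch, and the helper names `l2_distance` and `margin_loss` are hypothetical; the term it computes for one pair is the same hinge that `torch.nn.MarginRankingLoss` applies in the model code.

```python
import math

def l2_distance(h, l, t):
    """L2 norm of h + l - t for plain Python lists."""
    return math.sqrt(sum((hi + li - ti) ** 2 for hi, li, ti in zip(h, l, t)))

def margin_loss(pos_dist, neg_dist, margin=1.0):
    """Hinge term max(0, margin + d_pos - d_neg) for one (positive, corrupted) pair."""
    return max(0.0, margin + pos_dist - neg_dist)

h = [1.0, 0.0]
l = [0.0, 1.0]
t = [1.0, 1.0]        # h + l lands exactly on t
t_far = [4.0, 5.0]    # corrupted tail, far from h + l
t_near = [1.0, 1.5]   # corrupted tail, close to h + l

pos = l2_distance(h, l, t)            # 0.0
neg_far = l2_distance(h, l, t_far)    # 5.0
neg_near = l2_distance(h, l, t_near)  # 0.5

print(margin_loss(pos, neg_far))   # 0.0, the corrupted triple is already far enough away
print(margin_loss(pos, neg_near))  # 0.5, the corrupted triple violates the margin
```

Only corrupted triples that sit within the margin of the true triple contribute to the loss, so training effort goes to the negatives the model still confuses with positives.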

In [None]:
class TransE(torch.nn.Module):
  def __init__(self, num_entities, num_relations, device, embedding_dim=50, 
               margin=1.0, visualize=False):
    super(TransE, self).__init__()
    self.device = device
    self.num_entities = num_entities
    self.num_relations = num_relations
    self.embedding_dim = embedding_dim

    self.entities_emb = self.init_emb(num_entities, embedding_dim, emb_type='entity')
    self.relations_emb = self.init_emb(num_relations, embedding_dim, emb_type='relation')

    self.criterion = torch.nn.MarginRankingLoss(margin=margin, reduction='none')

    self.visualize = visualize
    # used for visualization
    if visualize:
      self.pca = PCA(n_components=2)


  def init_emb(self, size, dim, emb_type='relation'):
    emb = torch.nn.Embedding(num_embeddings=size, 
                             embedding_dim=dim)
    uniform_max = 6 / np.sqrt(dim) # 6 / sqrt(k)
    emb.weight.data.uniform_(-uniform_max, uniform_max)
    if emb_type == 'relation':
      emb_norm = torch.norm(emb.weight.data, dim=1, keepdim=True)
      emb.weight.data = emb.weight.data / emb_norm # l = l / ||l||
    return emb

  def forward(self, edge_index, negative_edge_index, edge_type):
    positive_distance = self.distance(edge_index, edge_type)
    negative_distance = self.distance(negative_edge_index, edge_type)
    return self.loss(positive_distance, negative_distance)
  
  def predict(self, edge_index, edge_type):
    return self.distance(edge_index, edge_type)

  def distance(self, edge_index, edge_type):
    heads = edge_index[0,:]
    tails = edge_index[1,:]
    return (self.entities_emb(heads) + self.relations_emb(edge_type) - \
            self.entities_emb(tails)).norm(p=2., dim=1, keepdim=True) # l2 norm of h + l - t
  
  def loss(self, positive_distance, negative_distance):
    y = torch.tensor([-1], dtype=torch.long, device=self.device)
    return self.criterion(positive_distance, negative_distance, y).sum()

# Training and Testing

Now that we have implemented data processing and the model, it is time to train and test on the task of predicting missing tails.

One key aspect of training our model is creating corrupted triples by replacing either the head or the tail with a random entity. We do this once per epoch, choosing at random for each triple whether to replace its head or its tail.

In [None]:
def create_corrupted_edge_index(edge_index, edge_type, num_entities):
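  # sample on the same device as the inputs instead of relying on a global `device`
  corrupt_head_or_tail = torch.randint(high=2, size=edge_type.size(),
                                       device=edge_type.device)
  random_entities = torch.randint(high=num_entities, 
                                  size=edge_type.size(), device=edge_type.device)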
  # corrupt when 1, otherwise regular head
  heads = torch.where(corrupt_head_or_tail == 1, random_entities, 
                      edge_index[0,:])
  # corrupt when 0, otherwise regular tail
  tails = torch.where(corrupt_head_or_tail == 0, random_entities, 
                      edge_index[1,:])
  return torch.stack([heads, tails], dim=0)
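
The same idea can be sketched without tensors. The pure-Python version below is only an illustration: `corrupt_triples` is a hypothetical helper and the triples are made up. As in the tensor version, the sampled entity may happen to equal the one it replaces, so a "corrupted" triple is occasionally identical to the original.

```python
import random

def corrupt_triples(triples, num_entities, rng):
    """Replace either the head or the tail of each (h, l, t) with a random entity."""
    corrupted = []
    for h, l, t in triples:
        random_entity = rng.randrange(num_entities)
        if rng.random() < 0.5:
            corrupted.append((random_entity, l, t))  # corrupt the head
        else:
            corrupted.append((h, l, random_entity))  # corrupt the tail
    return corrupted

rng = random.Random(0)  # seeded only to make the sketch repeatable
triples = [(0, 0, 1), (1, 1, 2), (0, 1, 2)]
for original, corrupted in zip(triples, corrupt_triples(triples, 3, rng)):
    print(original, '->', corrupted)
```

The relation is never changed, and at most one of the two entities differs from the original triple.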

Other than corrupting samples, the training process is pretty standard. One thing to keep in mind is that the training samples are reshuffled at the start of every epoch.

In [None]:
def train(model, data, optimizer, device, epochs=50, batch_size=128,
          eval_batch_size=256, valid_freq=5):
  train_edge_index = data.train_edge_index.to(device)
  train_edge_type = data.train_edge_type.to(device)

  best_valid_score = 0
  valid_scores = None
  test_scores = None
  for epoch in range(epochs):
    model.train()
    # e = e / ||e||
    entities_norm = torch.norm(model.entities_emb.weight.data, dim=1, keepdim=True)
    model.entities_emb.weight.data = model.entities_emb.weight.data / entities_norm

    # shuffle the training triples once per epoch
    num_triples = data.train_edge_type.size()[0]
    shuffled_triples_order = np.arange(num_triples)
    np.random.shuffle(shuffled_triples_order)
    shuffled_edge_index = train_edge_index[:, shuffled_triples_order]
    shuffled_edge_type = train_edge_type[shuffled_triples_order]

    negative_edge_index = create_corrupted_edge_index(shuffled_edge_index,
                                                      shuffled_edge_type,
                                                      data.num_entities)
    
    total_loss = 0
    total_size = 0
    for batch_idx in range(math.ceil(num_triples / batch_size)):
      batch_start = batch_idx * batch_size
      batch_end = (batch_idx + 1) * batch_size
      batch_edge_index = shuffled_edge_index[:,batch_start:batch_end]
      batch_negative_edge_index = negative_edge_index[:,batch_start:batch_end]
      batch_edge_type = shuffled_edge_type[batch_start:batch_end]
      loss = model(batch_edge_index, batch_negative_edge_index, batch_edge_type)
      total_loss += loss.item()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      total_size += batch_edge_type.size()[0]
      if model.visualize  and epoch == 0 \
      and batch_idx % 100 == 0:
        visualize_emb(model, batch_idx)
    print(f'Epoch {epoch}, train loss equals {total_loss / total_size}')
    if (epoch + 1) % valid_freq == 0:
      mrr_score, mr_score, hits_at_10 = eval(model, data.valid_edge_index.to(device),
                                             data.valid_edge_type.to(device),
                                             data.num_entities, device)
      print(f'Validation score equals {mrr_score}, {mr_score}, {hits_at_10}')
      if mrr_score > best_valid_score:
        # remember the best validation MRR and score that model on the test split
        best_valid_score = mrr_score
        valid_scores = (mrr_score, mr_score, hits_at_10)
        test_mrr_score, test_mr_score, test_hits_at_10 = \
                                        eval(model, data.test_edge_index.to(device),
                                             data.test_edge_type.to(device),
                                             data.num_entities, device)
        test_scores = (test_mrr_score, test_mr_score, test_hits_at_10)
    # break
  print(f'Test scores from best model (mmr, mr, h@10): {test_scores}')


For metrics, we measure Hits @ 10, Mean Rank, and MRR (mean reciprocal rank).

First, Hits @ 10 = $\frac{|\{r \in P | r \leq 10\}|}{|P|}$, where $|P|$ is the number of rank scores and $r$ is rank.

Second, Mean Rank = $\frac{1}{|P|}\sum_{r \in P}r$

Third, MRR = $\frac{1}{|P|}\sum_{r \in P}\frac{1}{r}$

MRR is the most trusted out of the three metrics. Hits @ 10 suffers from the fact that any rank above 10 is not differentiated, and Mean Rank suffers from the number of entities highly affecting the overall score. There is more information about the metrics [here](https://arxiv.org/pdf/2002.06914.pdf).

For our implementations, we divide by $|P|$ after adding up all the scores in the minibatches.
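
As a quick sanity check on the three formulas, the sketch below evaluates them on a small hand-made list of ranks. The helper names and the ranks are hypothetical and are not outputs of the model.

```python
def hits_at_k(ranks, k=10):
    """Fraction of ranks that are at most k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)

def mean_rank(ranks):
    """Average rank; lower is better."""
    return sum(ranks) / len(ranks)

def mean_reciprocal_rank(ranks):
    """Average of 1 / rank; higher is better."""
    return sum(1.0 / r for r in ranks) / len(ranks)

ranks = [1, 3, 12, 50]
print(hits_at_k(ranks))                        # 0.5
print(mean_rank(ranks))                        # 16.5
print(round(mean_reciprocal_rank(ranks), 4))   # 0.3592
```

The single rank of 50 dominates Mean Rank, while MRR is driven mostly by the two well-ranked predictions.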

In [None]:
def mrr(predictions, gt):
  """MRR score adapted from 
  https://github.com/mklimasz/TransE-PyTorch/blob/master/metric.py
  """
  indices = predictions.argsort()
  return (1.0 / (indices == gt).nonzero()[:, 1].float().add(1.0)).sum().item()

def mr(predictions, gt):
  indices = predictions.argsort()
  return ((indices == gt).nonzero()[:, 1].float().add(1.0)).sum().item()

def hit_at_k(predictions, gt, device, k=10):
  """Hit @ k score adapted from 
  https://github.com/mklimasz/TransE-PyTorch/blob/master/metric.py
  """
  zero_tensor = torch.tensor([0], device=device)
  one_tensor = torch.tensor([1], device=device)
  _, indices = predictions.topk(k=k, largest=False)
  return torch.where(indices == gt, one_tensor, zero_tensor).sum().item()

For testing, we remove each tail and replace it with every entity in the graph. We then calculate the distances between $\textbf h + \textbf l$ and every entity, and the entities are sorted in ascending order. 
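
The ranking step can be sketched on a toy example. The function `tail_rank` and the 2-dimensional embeddings below are hypothetical. Like the evaluation in this notebook, the sketch ranks the true tail against every entity and does not filter out other entities that would also form a correct triple.

```python
import math

def tail_rank(h_emb, l_emb, entity_embs, true_tail):
    """1-based rank of true_tail when all entities are sorted by ||h + l - e||."""
    target = [hi + li for hi, li in zip(h_emb, l_emb)]
    distances = [math.dist(target, e) for e in entity_embs]
    # ascending order: the closest entity to h + l is ranked first
    order = sorted(range(len(entity_embs)), key=lambda i: distances[i])
    return order.index(true_tail) + 1

entity_embs = [[0.0, 0.0], [1.0, 1.0], [1.0, 2.0], [5.0, 5.0]]
h_emb = entity_embs[0]
l_emb = [1.0, 1.8]

print(tail_rank(h_emb, l_emb, entity_embs, true_tail=2))  # 1
print(tail_rank(h_emb, l_emb, entity_embs, true_tail=1))  # 2
```

These ranks are exactly the values $r$ that feed into Hits @ 10, Mean Rank, and MRR.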

In [None]:
@torch.no_grad()  # evaluation needs no gradients, which saves memory and time
def eval(model, edge_index, edge_type, num_entities, device, eval_batch_size=64):
  model.eval()
  num_triples = edge_type.size()[0]
  mrr_score = 0
  mr_score = 0
  hits_at_10 = 0
  num_predictions = 0

  for batch_idx in range(math.ceil(num_triples / eval_batch_size)):
    batch_start = batch_idx * eval_batch_size
    batch_end = (batch_idx + 1) * eval_batch_size
    batch_edge_index = edge_index[:,batch_start:batch_end]
    batch_edge_type = edge_type[batch_start:batch_end]
    batch_size = batch_edge_type.size()[0] # can be different on last batch

    all_entities = torch.arange(end=num_entities, 
                                device=device).unsqueeze(0).repeat(batch_size, 1)
    head_repeated = batch_edge_index[0,:].reshape(-1, 1).repeat(1, num_entities)
    relation_repeated = batch_edge_type.reshape(-1, 1).repeat(1, num_entities)

    head_squeezed = head_repeated.reshape(-1)
    relation_squeezed = relation_repeated.reshape(-1)
    all_entities_squeezed = all_entities.reshape(-1)

    entity_index_replaced_tail = torch.stack((head_squeezed,all_entities_squeezed))
    predictions = model.predict(entity_index_replaced_tail, relation_squeezed)
    predictions = predictions.reshape(batch_size, -1)
    gt = batch_edge_index[1,:].reshape(-1, 1)

    mrr_score += mrr(predictions, gt)
    mr_score += mr(predictions, gt)
    hits_at_10 += hit_at_k(predictions, gt, device=device, k=10)
    num_predictions += predictions.size()[0]

  mrr_score = mrr_score / num_predictions
  mr_score = mr_score / num_predictions
  hits_at_10 = hits_at_10 / num_predictions
  return mrr_score, mr_score, hits_at_10

### Visualize embeddings

Since we are using shallow embeddings to represent entities and relations, it is straightforward to visualize the embeddings during training. Here, we use PCA to project the 50-dimensional relation embeddings down to 2D.

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from matplotlib import cm

def visualize_emb(model, idx):
  # move the relation embeddings to the CPU as a NumPy array for scikit-learn
  relations = model.relations_emb.weight.detach().cpu().numpy()
  if idx == 0:
    out = model.pca.fit_transform(relations)  # fit PCA once, on the first call
  else:
    out = model.pca.transform(relations)
  cmap = plt.get_cmap('tab20')
  fig, ax = plt.subplots(figsize=(8,8))
  for relation_id in range(np.shape(out)[0]):
    ax.scatter(out[relation_id,0], out[relation_id,1], color=cmap(relation_id % 20))
  plt.savefig('visualize_relations_emb_' + str(idx) + '.png')
  plt.show()

### Start Training!

Now that we have everything set, we can start training.

*Note: evaluation will take quite a while without a GPU (~ 5 minutes)*

In [None]:
lr = 0.01
if use_gpu:
  epochs = 50
  valid_freq = 5
else:
  epochs = 10
  valid_freq = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransE(data.num_entities, data.num_relations, device, visualize=False).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
train(model, data, optimizer, device, epochs=epochs, valid_freq=valid_freq)

Epoch 0, train loss equals 0.8450540051001942
Epoch 1, train loss equals 0.7025755662292331
Epoch 2, train loss equals 0.607560314087836
Epoch 3, train loss equals 0.5218951126641717
Epoch 4, train loss equals 0.44190907410429464
Epoch 5, train loss equals 0.37600184632441114
Epoch 6, train loss equals 0.32341098243928285
Epoch 7, train loss equals 0.2838035914757863
Epoch 8, train loss equals 0.2542663452536475
Epoch 9, train loss equals 0.2328555125673947
Validation score equals 0.2427326449037176, 607.9158254918734, 0.3835186769318506
Test scores from best model (mmr, mr, h@10): (0.2427326449037176, 607.9158254918734, 0.3835186769318506)


Using a GPU, the test MRR should be around .25, and the Hits @ 10 should be around .40. Without a GPU, the test MRR should be around .24 and the Hits @ 10 should be around .38.

# Exploring Weaknesses of TransE for FB15k-237

Although TransE is a powerful method for learning knowledge graph embeddings, it cannot model symmetric relations. In this last section, we look at the FB15k-237 dataset and approximate how often symmetric relations occur.
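
To see where this limitation comes from, suppose a relation $l$ holds in both directions between $h$ and $t$, and TransE fits both triples exactly:

$\textbf h + \textbf l = \textbf t \quad \text{and} \quad \textbf t + \textbf l = \textbf h$

Adding the two equations gives $\textbf h + \textbf t + 2\textbf l = \textbf t + \textbf h$, so $\textbf l = \textbf 0$ and therefore $\textbf h = \textbf t$. In other words, a symmetric relation can only be fit exactly by collapsing the relation to the zero vector and placing the two entities at the same point, which makes them indistinguishable.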

*Note: this cell takes about 2 minutes to complete*

In [None]:
from tqdm import tqdm
def find_symmetric_relations(edge_index, edge_type, sample=50):
  # (h, l, t) ⇒ (t, l, h)
  num_triplets = edge_type.size()[0]
  shuffled_triplets_order = np.arange(num_triplets)
  np.random.shuffle(shuffled_triplets_order)
  shuffled_triplets_order = shuffled_triplets_order[:sample]

  num_symmetric_relations = 0
  # total_relations = 0
  for i in list(shuffled_triplets_order):
    head = edge_index[0, i]
    tail = edge_index[1, i]
    relation = edge_type[i]
    for j in range(num_triplets):
      if i == j:
        continue
      _head = edge_index[0, j]
      _tail = edge_index[1, j]
      _relation = edge_type[j]
      if (head == _tail and tail == _head and relation == _relation):
        num_symmetric_relations += 1
  print(f'When sampling {sample} triplets, found {num_symmetric_relations} '\
           'symmetric relations present in the graph.')

find_symmetric_relations(data.train_edge_index, data.train_edge_type)

When sampling 50 triplets, found 6 symmetric relations present in the graph.
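
The nested loop above compares each sampled triplet against every other triplet. As an alternative sketch (not the method used above), a set lookup can check every triplet in a single pass. The helper `count_reversed_triples` and the toy triples are hypothetical. Unlike the sampled version, it counts over all of the triplets it is given and ignores duplicate triplets.

```python
def count_reversed_triples(triples):
    """Count triples (h, l, t) with h != t whose reverse (t, l, h) is also present."""
    triple_set = set(triples)
    return sum(1 for h, l, t in triple_set
               if h != t and (t, l, h) in triple_set)

toy_triples = [
    (0, 0, 1), (1, 0, 0),   # a symmetric pair under relation 0
    (1, 1, 2),              # no reverse present
    (2, 0, 3), (3, 0, 2),   # another symmetric pair under relation 0
]
print(count_reversed_triples(toy_triples))  # 4
```

Assuming the tensors from this notebook, the input could be built with `zip(edge_index[0].tolist(), edge_type.tolist(), edge_index[1].tolist())`.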
