In [2]:
# whether you are using a GPU to run this Colab
use_gpu = True
# whether you are using a custom GCE env to run the Colab (uses different CUDA)
custom_GCE_env = False

# Installations for pyTorch geometry

In [3]:
if custom_GCE_env:
  !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
  !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
else:
  !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
  !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl size=3662083 sha256=ef4c7e301b8f429fe50f6e671e2db09009f7b992f5b0bfc77c51669dc02ffc92
  Stored in directory: /root/.cache/pip/wheels/92/f1/2b/3b46d54b134259f58c8363568569053248040859b1a145b3ce
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     

# Import Libraries

In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import math
from torch_geometric.data import InMemoryDataset, DataLoader
import torch_geometric

# Step 1: Data Preprocessing and Preparation

In [7]:
class FB15kDataset(torch_geometric.data.InMemoryDataset):
    r"""FB15-237 dataset from Freebase.
    Follows similar structure to torch_geometric.datasets.rel_link_pred_dataset

    Args:
      root (string): Root directory where the dataset should be saved.
      transform (callable, optional): A function/transform that takes in an
          :obj:`torch_geometric.data.Data` object and returns a transformed
          version. The data object will be transformed before every access.
          (default: :obj:`None`)
      pre_transform (callable, optional): A function/transform that takes in
          an :obj:`torch_geometric.data.Data` object and returns a
          transformed version. The data object will be transformed before
          being saved to disk. (default: :obj:`None`)
    """
    data_path = 'https://raw.githubusercontent.com/DeepGraphLearning/' \
                'KnowledgeGraphEmbedding/master/data/FB15k-237'

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['train.txt', 'valid.txt', 'test.txt',
                'entities.dict', 'relations.dict']

    @property
    def processed_file_names(self):
        return ['data.pt']

    @property
    def raw_dir(self):
        return os.path.join(self.root, 'raw')

    def download(self):
        for file_name in self.raw_file_names:
            torch_geometric.data.download_url(f'{self.data_path}/{file_name}',
                                              self.raw_dir)

    def process(self):
        with open(os.path.join(self.raw_dir, 'entities.dict'), 'r') as f:
            lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
            entities_dict = {key: int(value) for value, key in lines}

        with open(os.path.join(self.raw_dir, 'relations.dict'), 'r') as f:
            lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
            relations_dict = {key: int(value) for value, key in lines}

        kwargs = {}
        for split in ['train', 'valid', 'test']:
            with open(os.path.join(self.raw_dir, f'{split}.txt'), 'r') as f:
                lines = [row.split('\t') for row in f.read().split('\n')[:-1]]
                heads = [entities_dict[row[0]] for row in lines]
                relations = [relations_dict[row[1]] for row in lines]
                tails = [entities_dict[row[2]] for row in lines]
                kwargs[f'{split}_edge_index'] = torch.tensor([heads, tails])
                kwargs[f'{split}_edge_type'] = torch.tensor(relations)

        _data = torch_geometric.data.Data(num_entities=len(entities_dict),
                                          num_relations=len(relations_dict),
                                          **kwargs)

        if self.pre_transform is not None:
            _data = self.pre_transform(_data)

        data, slices = self.collate([_data])

        torch.save((data, slices), self.processed_paths[0])

# Load dataset
FB15k_dset = FB15kDataset(root='FB15k')
data = FB15k_dset[0]


Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/train.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/valid.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/test.txt
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/entities.dict
Downloading https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/relations.dict
Processing...
Done!
  self.data, self.slices = torch.load(self.processed_paths[0])


# Step 2: Define the TransE Model

In [9]:
class TransE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim, margin, distance_metric='L1', visualize=False):
        super(TransE, self).__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        self.margin = margin
        self.distance_metric = distance_metric
        self.visualize = visualize

        # Initialize embeddings using TransE paper's method
        uniform_max = 6 / np.sqrt(embedding_dim)
        self.entity_embeddings.weight.data.uniform_(-uniform_max, uniform_max)
        self.relation_embeddings.weight.data.uniform_(-uniform_max, uniform_max)

    def forward(self, edge_index, negative_edge_index, edge_type):
        positive_distance = self.distance(edge_index, edge_type)
        negative_distance = self.distance(negative_edge_index, edge_type)
        return self.loss(positive_distance, negative_distance)

    def predict(self, edge_index, edge_type):
        return self.distance(edge_index, edge_type)

    def distance(self, edge_index, edge_type):
        heads = edge_index[0, :]
        tails = edge_index[1, :]
        return (self.entity_embeddings(heads) + self.relation_embeddings(edge_type) -
                self.entity_embeddings(tails)).norm(p=2., dim=1, keepdim=True)  # L2 norm of h + r - t

    # def loss(self, positive_distance, negative_distance):
    #     y = torch.tensor([-1], dtype=torch.long, device=self.entity_embeddings.weight.device)
    #     criterion = nn.MarginRankingLoss(margin=self.margin)
    #     return criterion(positive_distance, negative_distance, y)

    def loss(self, positive_distance, negative_distance):
      batch_size = positive_distance.size(0)
      y = torch.full((batch_size, 1), -1, dtype=torch.float, device=self.entity_embeddings.weight.device)
      criterion = nn.MarginRankingLoss(margin=self.margin)
      return criterion(positive_distance, negative_distance, y)



# Helper function to create corrupted edges
def create_corrupted_edge_index(edge_index, edge_type, num_entities):
    corrupt_head_or_tail = torch.randint(high=2, size=edge_type.size(),
                                         device=edge_index.device)
    random_entities = torch.randint(high=num_entities,
                                     size=edge_type.size(), device=edge_index.device)
    # corrupt when 1, otherwise regular head
    heads = torch.where(corrupt_head_or_tail == 1, random_entities,
                        edge_index[0, :])
    # corrupt when 0, otherwise regular tail
    tails = torch.where(corrupt_head_or_tail == 0, random_entities,
                        edge_index[1, :])
    return torch.stack([heads, tails], dim=0)

# Step 3: Training the Model

In [11]:
def train(model, data, optimizer, device, epochs=50, batch_size=128, valid_freq=5):
    train_edge_index = data.train_edge_index.to(device)
    train_edge_type = data.train_edge_type.to(device)
    valid_edge_index = data.valid_edge_index.to(device)
    valid_edge_type = data.valid_edge_type.to(device)

    best_valid_score = 0
    valid_scores = None
    test_scores = None

    for epoch in range(epochs):
        model.train()

        # Normalize entity embeddings
        entities_norm = torch.norm(model.entity_embeddings.weight.data, dim=1, keepdim=True)
        model.entity_embeddings.weight.data = model.entity_embeddings.weight.data / entities_norm

        # Shuffle the training data
        num_triples = train_edge_type.size(0)
        shuffled_indices = torch.randperm(num_triples)
        shuffled_edge_index = train_edge_index[:, shuffled_indices]
        shuffled_edge_type = train_edge_type[shuffled_indices]

        negative_edge_index = create_corrupted_edge_index(shuffled_edge_index, shuffled_edge_type, data.num_entities)

        total_loss = 0
        total_size = 0

        for batch_start in range(0, num_triples, batch_size):
            batch_end = min(batch_start + batch_size, num_triples)
            batch_edge_index = shuffled_edge_index[:, batch_start:batch_end]
            batch_negative_edge_index = negative_edge_index[:, batch_start:batch_end]
            batch_edge_type = shuffled_edge_type[batch_start:batch_end]

            loss = model(batch_edge_index, batch_negative_edge_index, batch_edge_type)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * (batch_end - batch_start)
            total_size += batch_end - batch_start

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / total_size:.4f}")

        # Validation at regular intervals
        if (epoch + 1) % valid_freq == 0:
            mrr_score, mr_score, hits_at_10 = evaluate_model(
                model, valid_edge_index, valid_edge_type, data.num_entities, device
            )
            print(f"Validation score: MRR = {mrr_score:.4f}, MR = {mr_score:.4f}, Hits@10 = {hits_at_10:.4f}")

            # Track best validation score
            if mrr_score > best_valid_score:
                best_valid_score = mrr_score
                test_mrr, test_mr, test_hits_at_10 = evaluate_model(
                    model, data.test_edge_index.to(device), data.test_edge_type.to(device), data.num_entities, device
                )
                test_scores = (test_mrr, test_mr, test_hits_at_10)

    print(f"Test scores from the best model (MMR, MR, Hits@10): {test_scores}")

# Metric Functions

In [12]:
def mrr(predictions, gt):
    indices = predictions.argsort()
    return (1.0 / (indices == gt).nonzero()[:, 1].float().add(1.0)).sum().item()

def mr(predictions, gt):
    indices = predictions.argsort()
    return ((indices == gt).nonzero()[:, 1].float().add(1.0)).sum().item()

def hit_at_k(predictions, gt, device, k=10):
    zero_tensor = torch.tensor([0], device=device)
    one_tensor = torch.tensor([1], device=device)
    _, indices = predictions.topk(k=k, largest=False)
    return torch.where(indices == gt, one_tensor, zero_tensor).sum().item()

# Step 4: Prediction and Evaluation

In [13]:
def evaluate_model(model, edge_index, edge_type, num_entities, device, eval_batch_size=64):
    model.eval()
    num_triples = edge_type.size(0)
    mrr_score = 0
    mr_score = 0
    hits_at_10 = 0
    num_predictions = 0

    with torch.no_grad():
        for batch_idx in range(math.ceil(num_triples / eval_batch_size)):
            batch_start = batch_idx * eval_batch_size
            batch_end = min((batch_idx + 1) * eval_batch_size, num_triples)
            batch_edge_index = edge_index[:, batch_start:batch_end]
            batch_edge_type = edge_type[batch_start:batch_end]
            batch_size = batch_edge_type.size(0)

            all_entities = torch.arange(num_entities, device=device).unsqueeze(0).repeat(batch_size, 1)
            head_repeated = batch_edge_index[0, :].reshape(-1, 1).repeat(1, num_entities)
            relation_repeated = batch_edge_type.reshape(-1, 1).repeat(1, num_entities)

            head_squeezed = head_repeated.reshape(-1)
            relation_squeezed = relation_repeated.reshape(-1)
            all_entities_squeezed = all_entities.reshape(-1)

            entity_index_replaced_tail = torch.stack((head_squeezed, all_entities_squeezed))
            predictions = model.predict(entity_index_replaced_tail, relation_squeezed)
            predictions = predictions.reshape(batch_size, -1)
            gt = batch_edge_index[1, :].reshape(-1, 1)

            mrr_score += mrr(predictions, gt)
            mr_score += mr(predictions, gt)
            hits_at_10 += hit_at_k(predictions, gt, device=device, k=10)
            num_predictions += batch_size

    mrr_score = mrr_score / num_predictions
    mr_score = mr_score / num_predictions
    hits_at_10 = hits_at_10 / num_predictions
    return mrr_score, mr_score, hits_at_10

# Run evaluation on the dataset

learning rate λ for the stochastic gradient descent among {0.001,0.01,0.1}, the margin γ among {1,2,10} and the latent dimension
kamong {20,50} on the validation set of each data set. The dissimilarity measure dwas set either
to the L1 or L2 distance according to validation performance as well. Optimal configurations were:
k = 20, λ = 0.01, γ = 2, and d= L1 on Wordnet; k = 50, λ = 0.01, γ = 1, and d= L1 on
FB15k; k= 50, λ= 0.01, γ = 1, and d= L2 on FB1M

In [21]:
# Start Training
lr = 0.1
use_gpu = torch.cuda.is_available()
if use_gpu:
    epochs = 50
    valid_freq = 5
else:
    epochs = 10
    valid_freq = 10

device = torch.device('cuda' if use_gpu else 'cpu')
model = TransE(data.num_entities, data.num_relations, embedding_dim=400, margin=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

train(model, data, optimizer, device, epochs=epochs, valid_freq=valid_freq)

Epoch 1/50, Loss: 1.9712
Epoch 2/50, Loss: 1.9277
Epoch 3/50, Loss: 1.9035
Epoch 4/50, Loss: 1.8838
Epoch 5/50, Loss: 1.8672
Validation score: MRR = 0.2022, MR = 3396.3455, Hits@10 = 0.2967
Epoch 6/50, Loss: 1.8512
Epoch 7/50, Loss: 1.8364
Epoch 8/50, Loss: 1.8245
Epoch 9/50, Loss: 1.8116
Epoch 10/50, Loss: 1.8001
Validation score: MRR = 0.2205, MR = 2809.7321, Hits@10 = 0.3300
Epoch 11/50, Loss: 1.7886
Epoch 12/50, Loss: 1.7787
Epoch 13/50, Loss: 1.7700
Epoch 14/50, Loss: 1.7593
Epoch 15/50, Loss: 1.7509
Validation score: MRR = 0.2292, MR = 2423.0732, Hits@10 = 0.3502
Epoch 16/50, Loss: 1.7423
Epoch 17/50, Loss: 1.7345
Epoch 18/50, Loss: 1.7263
Epoch 19/50, Loss: 1.7196
Epoch 20/50, Loss: 1.7130
Validation score: MRR = 0.2344, MR = 2113.4531, Hits@10 = 0.3603
Epoch 21/50, Loss: 1.7047
Epoch 22/50, Loss: 1.6982
Epoch 23/50, Loss: 1.6895
Epoch 24/50, Loss: 1.6846
Epoch 25/50, Loss: 1.6775
Validation score: MRR = 0.2379, MR = 1855.6990, Hits@10 = 0.3675
Epoch 26/50, Loss: 1.6719
Epoch 27