<a href="https://colab.research.google.com/github/SanjeevaRDodlapati/test/blob/main/Simple_Schemes_for_Knowledge_Graph_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Simple Schemes for Knowledge Graph Embedding**

This colab presents code for training knowledge graph embeddings on the FB15k-237 knowledge base using two KG embedding techniques (RotatE and TransE) and evaluates them with the MRR and Hits@K metrics. The colab can be run to save entity and relation embeddings every 10 epochs.

Also included in the colab is code to preprocess the data and display the dimensional embeddings to illustrate the geometrical properties of the embedding schemas (Visualization Tools and MID to Entity Name Mapping sections). However, this portion of the code was run on a GCP VM and does not run out of the box on colab.

In [None]:
# Download the FB15k-237 dictionaries and the train/valid/test splits from the
# DeepGraphLearning/KnowledgeGraphEmbedding GitHub repository. With
# `-r -nH --cut-dirs=2` wget drops the host name and the first two path
# components, so the files are saved under master/data/FB15k-237/.
!wget -r -nH --cut-dirs=2 -np https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/entities.dict
!wget -r -nH --cut-dirs=2 -np https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/relations.dict
!wget -r -nH --cut-dirs=2 -np https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/test.txt
!wget -r -nH --cut-dirs=2 -np https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/train.txt
!wget -r -nH --cut-dirs=2 -np https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/valid.txt

--2021-12-10 06:51:26--  https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/entities.dict
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 218551 (213K) [text/plain]
Saving to: ‘master/data/FB15k-237/entities.dict’


2021-12-10 06:51:26 (6.96 MB/s) - ‘master/data/FB15k-237/entities.dict’ saved [218551/218551]

FINISHED --2021-12-10 06:51:26--
Total wall clock time: 0.5s
Downloaded: 1 files, 213K in 0.03s (6.96 MB/s)
--2021-12-10 06:51:26--  https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/FB15k-237/relations.dict
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (r

**Custom DataLoader**

In [None]:
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# Allow for reproducible results: NumPy's global RNG drives the negative
# sampling in DataGenerator, and PyTorch's RNG initializes the embeddings.
np.random.seed(234)
torch.manual_seed(234)

<torch._C.Generator at 0x7f7895f09610>

In [None]:
class DataGenerator(Dataset):
    """Dataset of knowledge-graph triples paired with corrupted heads and tails.

    In ``"train"`` mode every item carries ``num_negative_samples`` randomly
    drawn corrupted heads and corrupted tails; entities that are known (from
    ``triples``) to complete the triple correctly are excluded.

    In any other mode every item carries one candidate per entity id so the
    true head/tail can be ranked against all entities. Candidates that form
    another known-true triple are replaced by the true entity and given a
    filter bias of -1.

    Args:
        triples: Sequence of ``(head, relation, tail)`` integer-id triples.
        num_entities: Number of entities; candidate ids are taken from
            ``range(num_entities)``.
        num_negative_samples: Corrupted heads (and tails) per training item.
        all_triples: Known-true triples used for filtering in evaluation
            mode. May be an iterable of triples or an iterable of splits
            (each an iterable of triples); it is flattened into a set once.
            ``None`` disables filtering.
        data_type: ``"train"`` selects negative sampling; any other value
            selects the evaluation layout.
    """

    def __init__(self, triples, num_entities, num_negative_samples, all_triples=None, data_type="train"):
        super().__init__()
        self.triples = triples
        self.num_entities = num_entities
        self.num_negative_samples = num_negative_samples
        self.all_triples = all_triples
        self.data_type = data_type
        self.len = len(triples)
        self.true_head_relation, self.true_relation_tail = self._get_true_head_tail_lists()
        # Flat set of known-true triples: O(1) membership tests during
        # evaluation, regardless of how the caller grouped the splits.
        self._known_triples = self._as_triple_set(all_triples)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive_sample, corrupted_heads, corrupted_tails, filter_bias)``.

        ``filter_bias`` is a zero tensor in train mode and a
        ``(head_bias, tail_bias)`` pair of tensors in evaluation mode.
        """
        triple = tuple(self.triples[idx])
        head, relation, tail = triple
        positive_sample = torch.LongTensor(triple)

        if self.data_type == "train":
            corrupted_heads = self._sample_negatives(self.true_relation_tail[(relation, tail)])
            corrupted_tails = self._sample_negatives(self.true_head_relation[(head, relation)])
            filter_bias = torch.LongTensor([0] * len(positive_sample))
        else:
            # We cannot empirically say that one head is better than another
            # for a valid (head, relation, tail) triplet, e.g.
            # (Bob, friend, Joe) and (Jack, friend, Joe). Such alternates are
            # replaced by the current true entity and biased by -1 so that
            # they rank below it and do not distort HITS@K and MRR.
            corrupted_heads, head_bias = self._filtered_candidates(triple, 0)
            corrupted_tails, tail_bias = self._filtered_candidates(triple, 2)
            filter_bias = (head_bias, tail_bias)

        return positive_sample, corrupted_heads, corrupted_tails, filter_bias

    def _sample_negatives(self, excluded):
        """Draw ``num_negative_samples`` entity ids that are not in ``excluded``."""
        wanted = self.num_negative_samples
        if len(excluded) >= self.num_entities:
            # Every entity is a true answer, so no valid negative exists;
            # fall back to unfiltered sampling rather than looping forever.
            return torch.from_numpy(np.random.randint(self.num_entities, size=wanted)).long()

        # A set must be turned into an array first: NumPy would otherwise
        # treat it as a single opaque object and filter nothing.
        excluded = np.fromiter(excluded, dtype=np.int64, count=len(excluded))

        batches = []
        collected = 0
        while collected < wanted:
            # Oversample and resample on every pass until enough candidates
            # survive the filter.
            candidates = np.random.randint(self.num_entities, size=wanted * 2)
            candidates = candidates[np.isin(candidates, excluded, invert=True)]
            batches.append(candidates)
            collected += candidates.size

        if not batches:
            return torch.LongTensor([])
        return torch.from_numpy(np.concatenate(batches)[:wanted]).long()

    def _filtered_candidates(self, triple, position):
        """Build the evaluation candidates for one slot of ``triple``.

        ``position`` is 0 to corrupt the head or 2 to corrupt the tail.
        Returns ``(candidates, bias)``, each of length ``num_entities``.
        """
        true_entity = triple[position]
        candidates = list(range(self.num_entities))
        bias = [0] * self.num_entities
        probe = list(triple)
        for entity in range(self.num_entities):
            if entity == true_entity:
                continue
            probe[position] = entity
            if tuple(probe) in self._known_triples:
                candidates[entity] = true_entity
                bias[entity] = -1
        return torch.LongTensor(candidates), torch.LongTensor(bias)

    @staticmethod
    def _as_triple_set(all_triples):
        """Flatten ``all_triples`` (triples, or splits of triples) into a set of tuples."""
        known = set()
        if all_triples is None:
            return known
        for item in all_triples:
            first = next(iter(item), None)
            if len(item) == 3 and not isinstance(first, (list, tuple)):
                known.add(tuple(item))
            else:
                known.update(tuple(triple) for triple in item)
        return known

    # We need to be able to get a list of true heads and tails
    # quickly during negative sampling
    def _get_true_head_tail_lists(self):
        """Index ``self.triples`` as (head, relation) -> tails and (relation, tail) -> heads."""
        true_head_relation = defaultdict(set)
        true_relation_tail = defaultdict(set)
        for head, relation, tail in self.triples:
            true_head_relation[(head, relation)].add(tail)
            true_relation_tail[(relation, tail)].add(head)
        return true_head_relation, true_relation_tail

In [None]:
def get_data_loader(triples, num_entities, num_negative_samples, batch_size,
        all_triples=None, data_type="train", num_workers=2):
    """Wrap ``triples`` in a DataGenerator and return a shuffling DataLoader over it.

    ``all_triples`` and ``data_type`` are forwarded to DataGenerator;
    ``batch_size`` and ``num_workers`` are forwarded to DataLoader.
    """
    dataset = DataGenerator(
        triples,
        num_entities,
        num_negative_samples,
        all_triples=all_triples,
        data_type=data_type,
    )
    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": num_workers,
    }
    return DataLoader(dataset, **loader_options)

**RotatE**





In [None]:
def rotatE(head, relation, tail, embedding_range, sample_type, margin):
    """Return the RotatE score ``margin - sum_d |h_d * r_d - t_d|``.

    ``head`` and ``tail`` hold their real and imaginary halves concatenated
    along dim 2; ``relation`` holds one phase parameter per complex
    dimension. The distance is summed over dim 2. ``sample_type`` is
    accepted for signature parity with the caller and is not used.
    """
    # Entity embeddings are 2 x hidden_dim wide: first half real, second
    # half imaginary
    re_head, im_head = head.chunk(2, dim=2)
    re_tail, im_tail = tail.chunk(2, dim=2)

    # Rescale so that values in [-embedding_range, embedding_range] become
    # phases in [-pi, pi]
    phase = relation / (embedding_range / math.pi)
    re_rot, im_rot = torch.cos(phase), torch.sin(phase)

    # Complex product h * r, (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
    # minus the tail; real and imaginary parts are stacked on a new dim 0
    difference = torch.stack(
        [
            (re_head * re_rot - im_head * im_rot) - re_tail,
            (re_head * im_rot + im_head * re_rot) - im_tail,
        ],
        dim=0,
    )

    # Modulus of each complex difference, then total over all dimensions
    modulus = torch.linalg.norm(difference, dim=0)
    distance = modulus.sum(dim=2)

    # Close triples get a score near the margin; far ones go negative
    return margin - distance

**TransE**





In [None]:
def transE(head, relation, tail, sample_type, margin):
  """Return the TransE score: margin minus the L1 norm of head + relation - tail.

  The norm is taken over dim 2. ``sample_type`` is accepted for signature
  parity with the caller and is not used.
  """
  translation_error = head + relation - tail
  l1_distance = torch.linalg.norm(translation_error, ord=1, dim=2)
  return margin - l1_distance

**Generic Knowledge Graph Embedding Model**

In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class KGEmbedding(nn.Module):
  """Entity and relation embedding tables scored with RotatE or TransE.

  Args:
    num_entities: Number of rows in the entity table.
    num_relations: Number of rows in the relation table.
    hidden_dim: Relation embedding width. Entity embeddings are twice as
      wide for ``'rotatE'`` (real and imaginary halves).
    margin: Margin handed to the scoring function.
    model: ``'rotatE'`` or ``'transE'``.
  """

  def __init__(self, num_entities, num_relations, hidden_dim, margin, model):
    super().__init__()

    self.num_entities = num_entities
    self.num_relations = num_relations
    # Entity embed needs to have out dim 2 x hidden_dim for rotatE
    # b/c each dimension needs to have a real and imaginary component
    self.entity_dim = 2 * hidden_dim if model == 'rotatE' else hidden_dim
    self.relation_dim = hidden_dim
    self.margin = margin
    self.model = model
    self.epsilon = 2.0

    # In the paper (default):
    # margin = 12.0
    # hidden_dim = 500
    # epsilon = 2.0
    # embedding_range = (margin + epsilon) / hidden_dim = 0.028
    # This means that all entity embeddings and relation embeddings
    # are initialized with value between -0.028 and 0.028

    # Initialization of the embedding range to be close to 0 helps prevent
    # crazy initializations. We divide by hidden_dim to reduce variance in
    # per-parameter initialization as we increase total number of dimensions
    # to consider
    self.embedding_range = (self.margin + self.epsilon) / hidden_dim

    # nn.Parameter already sets requires_grad=True
    self.entity_embed = nn.Parameter(torch.zeros(self.num_entities, self.entity_dim))
    nn.init.uniform_(
        self.entity_embed,
        a=-self.embedding_range,
        b=self.embedding_range
    )

    # For rotatE the relation embedding can only affect the phase, not the
    # modulus, of the entity embedding. The modulus is fixed to be |r_i| = 1.
    self.relation_embed = nn.Parameter(torch.zeros(self.num_relations, self.relation_dim))
    nn.init.uniform_(
        self.relation_embed,
        a=-self.embedding_range,
        b=self.embedding_range
    )

  @staticmethod
  def _lookup(table, index):
    """Gather rows of ``table`` into a [batch_size, n, dim] tensor.

    A 1-D ``index`` of shape [batch_size] yields n == 1, so that it
    broadcasts against a 2-D ``index`` of shape [batch_size, n].
    """
    if index.dim() == 1:
      return torch.index_select(table, dim=0, index=index).unsqueeze(1)
    batch_size, num_samples = index.shape
    rows = torch.index_select(table, dim=0, index=index.reshape(-1))
    return rows.reshape(batch_size, num_samples, table.shape[1])

  def forward(self, sample, sample_type):
    """Score a batch of triples.

    Args:
      sample: For ``'positive'``, a tensor [batch_size, 3] of
        (head, relation, tail) ids. For ``'negative-head'`` and
        ``'negative-tail'``, a pair ``(positive_tuple, negative_entities)``
        where ``negative_entities`` is [batch_size, num_neg_samples] and
        replaces the head or the tail of each positive triple.
      sample_type: ``'positive'``, ``'negative-head'`` or ``'negative-tail'``.

    Returns:
      The scores returned by the selected scoring function.

    Raises:
      ValueError: If ``sample_type`` or ``self.model`` is not recognized.
    """
    # For each positive example, the paper has 128 negative examples using
    # the same head + relation or relation + tail but with a corrupted entity
    if sample_type == 'positive':
      head = self._lookup(self.entity_embed, sample[:, 0])
      relation = self._lookup(self.relation_embed, sample[:, 1])
      tail = self._lookup(self.entity_embed, sample[:, 2])
    elif sample_type == 'negative-head':
      positive_tuple, negative_head_entities = sample
      head = self._lookup(self.entity_embed, negative_head_entities)
      relation = self._lookup(self.relation_embed, positive_tuple[:, 1])
      tail = self._lookup(self.entity_embed, positive_tuple[:, 2])
    elif sample_type == 'negative-tail':
      # same as "negative-head" except this time the tail is replaced
      positive_tuple, negative_tail_entities = sample
      head = self._lookup(self.entity_embed, positive_tuple[:, 0])
      relation = self._lookup(self.relation_embed, positive_tuple[:, 1])
      tail = self._lookup(self.entity_embed, negative_tail_entities)
    else:
      raise ValueError(f"unknown sample_type: {sample_type!r}")

    if self.model == 'rotatE':
      return rotatE(head, relation, tail, self.embedding_range, sample_type, self.margin)
    if self.model == 'transE':
      return transE(head, relation, tail, sample_type, self.margin)
    raise ValueError(f"unknown model: {self.model!r}")

**Run Model**

In [None]:
import os
from tqdm import trange

# Directory the wget cell saved the FB15k-237 files into
DATA_DIR = "master/data/FB15k-237/"
# Directory save_model writes checkpoints and exported embeddings to
MODEL_DIR = "models/"

In [None]:
# Read the entities and relations dictionary files
def load_dict(file_path):
    """Load a tab-separated ``<id> <name>`` file into a ``{name: int(id)}`` dict.

    Blank lines are skipped. A non-blank line that does not split into
    exactly two tab-separated fields raises ``ValueError``.
    """
    loaded_dict = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines, e.g. an extra newline at end of file
                continue
            uid, val = line.split('\t')
            loaded_dict[val] = int(uid)
    return loaded_dict

# Read the KG triples
def load_triples(file_path, entity2id, relation2id):
    """Load tab-separated ``head relation tail`` lines as a list of id triples.

    Each name is mapped through ``entity2id`` / ``relation2id``; an unknown
    name raises ``KeyError``. Blank lines are skipped.
    """
    triples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines, e.g. an extra newline at end of file
                continue
            head, relation, tail = line.split('\t')
            triples.append((entity2id[head], relation2id[relation], entity2id[tail]))
    return triples

In [None]:
def save_model(model, optimizer, scheduler, epoch):
    """Write a checkpoint and the raw embedding matrices for ``epoch`` to MODEL_DIR.

    Three files are produced, each suffixed with ``epoch``: a ``torch.save``
    checkpoint holding the model, optimizer and scheduler state dicts (plus
    the epoch number), and the entity and relation embeddings saved with
    ``np.save``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(MODEL_DIR, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, f'checkpoint_{epoch}'))

    # Detach and move to host memory before converting to NumPy
    entity_embedding = model.entity_embed.detach().cpu().numpy()
    np.save(os.path.join(MODEL_DIR, f'entity_embedding_{epoch}'), entity_embedding)

    relation_embedding = model.relation_embed.detach().cpu().numpy()
    np.save(os.path.join(MODEL_DIR, f'relation_embedding_{epoch}'), relation_embedding)

In [None]:
def _true_entity_ranks(scores, true_entities):
    """Return the 1-based rank of each row's true entity, best score first.

    ``scores`` is [batch_size, num_candidates], with column i holding the
    score of entity id i; ``true_entities`` is [batch_size].
    """
    order = torch.argsort(scores, dim=1, descending=True)
    # Each row of `order` is a permutation of the column indices, so the true
    # entity id matches exactly once per row; column 1 of nonzero() is its
    # 0-based position in the ranking.
    positions = (order == true_entities.unsqueeze(1)).nonzero()[:, 1]
    if positions.numel() != true_entities.numel():
        raise ValueError("a true entity id is outside the candidate range")
    return (positions + 1).tolist()


def eval_model(model, data, num_entities, _num_negative_samples, _batch_size, all_data, data_type):
    """Rank every true head and tail in ``data`` against all entities.

    Prints and returns the averaged MRR, MR and HITS@1/3/10, computed over
    both the head and the tail ranking of every triple.

    Args:
        model: Scoring model called as ``model((positive, candidates), kind)``.
        data: Triples to evaluate.
        num_entities: Number of entities, forwarded to the data loader.
        _num_negative_samples: Forwarded to the data loader.
        _batch_size: Evaluation batch size.
        all_data: Known-true triples, forwarded to the data loader as
            ``all_triples`` for filtering.
        data_type: Label used in the printed line and forwarded to the data
            loader; it must not be ``"train"`` for ranking to be meaningful.

    Returns:
        A dict-like mapping of metric name to mean value (empty when no
        triple was evaluated). The model is left in eval mode.
    """
    model.eval()

    dataloader = get_data_loader(data, num_entities, _num_negative_samples, _batch_size, all_triples=all_data, data_type=data_type)

    # Follow the model's device rather than assuming CUDA whenever it is
    # available, so a CPU model also works on a GPU machine.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    final_metrics = defaultdict(float)
    num_rankings = 0

    with torch.no_grad():
        for positive_sample, corrupted_heads, corrupted_tails, filter_bias in dataloader:
            head_bias, tail_bias = filter_bias
            positive_sample = positive_sample.to(device)
            corrupted_heads = corrupted_heads.to(device)
            corrupted_tails = corrupted_tails.to(device)
            head_bias = head_bias.to(device)
            tail_bias = tail_bias.to(device)

            # When we run eval, this list of "corrupted" values does contain one
            # (or more*) true positive values. We have to include the true triplet to
            # be able to run eval metrics like MRR and Hits@K
            # *If there is more than one true positive value, the others are weighted
            # negatively such that they appear lower in the ranking and will be unlikely
            # to skew the metrics. This is seen in the head_bias and tail_bias below
            corrupted_head_dist = model((positive_sample, corrupted_heads), 'negative-head') + head_bias
            corrupted_tail_dist = model((positive_sample, corrupted_tails), 'negative-tail') + tail_bias

            # Scores are "margin - distance", so larger is better and the
            # ranking sorts in descending order
            head_ranks = _true_entity_ranks(corrupted_head_dist, positive_sample[:, 0])
            tail_ranks = _true_entity_ranks(corrupted_tail_dist, positive_sample[:, 2])

            for head_rank, tail_rank in zip(head_ranks, tail_ranks):
                for rank in (head_rank, tail_rank):
                    final_metrics['MRR'] += 1.0 / rank
                    final_metrics['MR'] += float(rank)
                    final_metrics['HITS@1'] += 1.0 if rank <= 1 else 0.0
                    final_metrics['HITS@3'] += 1.0 if rank <= 3 else 0.0
                    final_metrics['HITS@10'] += 1.0 if rank <= 10 else 0.0
                    num_rankings += 1

    for key in list(final_metrics):
        final_metrics[key] /= num_rankings

    print(f"{data_type} metrics: {final_metrics}")
    return final_metrics

In [None]:
# For training our actual model we ran this on a GCP VM.
# The same code will run on Colab but will likely take much longer.

def main():
    """Train a RotatE embedding model on FB15k-237, then evaluate it.

    Loads the dictionaries and splits from DATA_DIR and trains for
    ``_num_epochs`` epochs with a self-contained negative-sampling loss.
    Every 10th epoch (except epoch 0) it evaluates on the validation split
    and saves a checkpoint; after training it evaluates on the test split.
    """
    # List of hyperparameters that we manually tuned using
    # the suggestions of the rotatE paper as a baseline. Training
    # and modeling was done on a GCP VM rather than in Colab
    _num_negative_samples = 128
    _batch_size = 1024
    _test_batch_size = 16
    _lr = 0.0001
    _hidden_dim = 1000
    _margin = 12
    _num_epochs = 400
    _weight_decay = 5e-5
    _cuda = torch.cuda.is_available()

    # FB15k-237 data is presented in terms of "mids" (Ex. /m/23sdf) that
    # represent an entity. To train embeddings for these entities, we need to
    # assign each one an index to easily model in a neural network. These
    # maps are used to translate entities and relations to indices
    entity2id = load_dict(os.path.join(DATA_DIR, 'entities.dict'))
    relation2id = load_dict(os.path.join(DATA_DIR, 'relations.dict'))
    num_entities = len(entity2id)
    num_relations = len(relation2id)

    train_data = load_triples(os.path.join(DATA_DIR, 'train.txt'), entity2id, relation2id)
    eval_data = load_triples(os.path.join(DATA_DIR, 'valid.txt'), entity2id, relation2id)
    test_data = load_triples(os.path.join(DATA_DIR, 'test.txt'), entity2id, relation2id)

    # All known-true triples, used to filter the candidate rankings during
    # evaluation. This has to be one flat set of (head, relation, tail)
    # tuples: a list of the three splits would never match a single triple,
    # and a set keeps every membership test O(1).
    all_data = set(train_data) | set(eval_data) | set(test_data)

    # Get the train dataloader
    dataloader = get_data_loader(train_data, num_entities, _num_negative_samples, _batch_size, data_type="train")

    # Here we initialize the model to use the rotatE embedding loss paradigm. When
    # we trained our transE model, we flipped the final parameter to "transE"
    model = KGEmbedding(num_entities, num_relations, _hidden_dim, _margin, "rotatE")
    if _cuda:
        model = model.cuda()

    # We use Adam as our optimizer, which provides an adaptive learning rate
    # per parameter, but based on empirical testing, using a learning rate
    # scheduler provided slightly better performance
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=_lr, weight_decay=_weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

    print("Start Training...")
    for epoch in trange(_num_epochs, desc="Train", unit="Epoch"):
        total_loss = 0
        model.train()

        for batch, (positive_sample, corrupted_heads, corrupted_tails, _) in enumerate(dataloader):
            if _cuda:
                positive_sample = positive_sample.cuda()
                corrupted_heads = corrupted_heads.cuda()
                corrupted_tails = corrupted_tails.cuda()

            # Gradients accumulate across backward() calls, so they must be
            # cleared before every batch, not once per epoch.
            optimizer.zero_grad()

            positive_sample_dist = model(positive_sample, 'positive')
            positive_score = F.logsigmoid(positive_sample_dist)
            positive_sample_loss = -positive_score.mean()

            # Here we take the negative of the distance values because
            # we want the distances between (head, relation) and (tail)
            # entities for corrupted samples to be far apart. In our model,
            # we use "margin - distance", so a large distance would give
            # a negative value that we then flip to positive. After plugging
            # into the logsigmoid, that would yield a value close to 0
            corrupted_head_dist = model((positive_sample, corrupted_heads), 'negative-head')
            corrupted_head_score = F.logsigmoid(-corrupted_head_dist)
            corrupted_head_loss = -corrupted_head_score.mean()

            corrupted_tail_dist = model((positive_sample, corrupted_tails), 'negative-tail')
            corrupted_tail_score = F.logsigmoid(-corrupted_tail_dist)
            corrupted_tail_loss = -corrupted_tail_score.mean()

            # In the paper, each corrupted_head and corrupted_tail is treated as a separate example.
            # Here, we combine them into one, so we need to weight the positive_sample loss accordingly
            # by adding it twice
            loss = (positive_sample_loss + corrupted_head_loss + positive_sample_loss + corrupted_tail_loss) / 4
            total_loss += loss.item()

            print(f"\nbatch: {batch}, loss: {loss}, pos_loss: {positive_sample_loss}, neg_head_loss: {corrupted_head_loss}, neg_tail_loss: {corrupted_tail_loss}")

            loss.backward()
            optimizer.step()

        # Technically having the scheduler step on evaluation loss would be
        # preferred. However, results did not differ much. Because running
        # eval takes significantly longer (due to increased number of negative
        # examples), we opted to step on total train loss. When selecting the
        # appropriate model to use for the final test metrics,
        # we referenced the eval metrics (MRR, Hits@K)
        scheduler.step(total_loss)
        print(f"\nepoch: {epoch}, avg loss: {total_loss / len(dataloader)}")

        # Evaluate the model on the valid set and save model state for testing or re-loading
        if epoch != 0 and epoch % 10 == 0:
            eval_model(model, eval_data, num_entities, _num_negative_samples, _test_batch_size, all_data, "eval")
            save_model(model, optimizer, scheduler, epoch)

    # Evaluate the model on the test set
    eval_model(model, test_data, num_entities, _num_negative_samples, _test_batch_size, all_data, "test")

In [None]:
# Run the full train / validate / test pipeline defined in main() above.
main()

Start Training...


Train:   0%|          | 0/400 [00:00<?, ?Epoch/s]


batch: 0, loss: 1.3762550354003906, pos_loss: 2.6802444458007812, neg_head_loss: 0.07202023267745972, neg_tail_loss: 0.07251051068305969

batch: 1, loss: 1.3995506763458252, pos_loss: 2.732279062271118, neg_head_loss: 0.06735088676214218, neg_tail_loss: 0.06629366427659988

batch: 2, loss: 1.4148257970809937, pos_loss: 2.7663259506225586, neg_head_loss: 0.06434808671474457, neg_tail_loss: 0.06230316311120987

batch: 3, loss: 1.4102600812911987, pos_loss: 2.7591075897216797, neg_head_loss: 0.06315280497074127, neg_tail_loss: 0.05967229604721069

batch: 4, loss: 1.425262451171875, pos_loss: 2.790980339050293, neg_head_loss: 0.061774589121341705, neg_tail_loss: 0.057315073907375336

batch: 5, loss: 1.407443642616272, pos_loss: 2.755051851272583, neg_head_loss: 0.06286194920539856, neg_tail_loss: 0.05680900067090988

batch: 6, loss: 1.397940993309021, pos_loss: 2.734828472137451, neg_head_loss: 0.06422890722751617, neg_tail_loss: 0.05787811428308487

batch: 7, loss: 1.3723748922348022, po

Train:   0%|          | 1/400 [03:00<20:03:01, 180.90s/Epoch]


batch: 265, loss: 1.9720350503921509, pos_loss: 3.8811581134796143, neg_head_loss: 0.06703301519155502, neg_tail_loss: 0.058790769428014755

epoch: 0, avg loss: 1.3251219145337443

batch: 0, loss: 1.8433564901351929, pos_loss: 3.6225290298461914, neg_head_loss: 0.0677560567855835, neg_tail_loss: 0.06061171367764473

batch: 1, loss: 1.9328656196594238, pos_loss: 3.8047828674316406, neg_head_loss: 0.06400027871131897, neg_tail_loss: 0.057896606624126434

batch: 2, loss: 1.872615098953247, pos_loss: 3.680802822113037, neg_head_loss: 0.06592807173728943, neg_tail_loss: 0.0629267543554306

batch: 3, loss: 1.8249462842941284, pos_loss: 3.5867395401000977, neg_head_loss: 0.06918621063232422, neg_tail_loss: 0.05711967498064041

batch: 4, loss: 1.797040581703186, pos_loss: 3.528383731842041, neg_head_loss: 0.06898609548807144, neg_tail_loss: 0.062408994883298874

batch: 5, loss: 1.8384649753570557, pos_loss: 3.614591598510742, neg_head_loss: 0.06735537201166153, neg_tail_loss: 0.05732154101133

Train:   0%|          | 2/400 [06:01<19:57:58, 180.60s/Epoch]


batch: 265, loss: 1.7934229373931885, pos_loss: 3.4911134243011475, neg_head_loss: 0.10327009111642838, neg_tail_loss: 0.08819469064474106

epoch: 1, avg loss: 1.516995091411404

batch: 0, loss: 1.7898266315460205, pos_loss: 3.48795223236084, neg_head_loss: 0.10003311932086945, neg_tail_loss: 0.08336932212114334

batch: 1, loss: 1.7931678295135498, pos_loss: 3.497302532196045, neg_head_loss: 0.09667809307575226, neg_tail_loss: 0.08138832449913025

batch: 2, loss: 1.939405083656311, pos_loss: 3.799081325531006, neg_head_loss: 0.09108307212591171, neg_tail_loss: 0.06837446242570877

batch: 3, loss: 1.8978313207626343, pos_loss: 3.7156386375427246, neg_head_loss: 0.08673729747533798, neg_tail_loss: 0.07331018894910812

batch: 4, loss: 1.9472389221191406, pos_loss: 3.818413019180298, neg_head_loss: 0.08426141738891602, neg_tail_loss: 0.06786809861660004

batch: 5, loss: 1.8592925071716309, pos_loss: 3.6411566734313965, neg_head_loss: 0.08207081258296967, neg_tail_loss: 0.07278572767972946

Train:   1%|          | 3/400 [09:01<19:53:47, 180.42s/Epoch]


batch: 265, loss: 1.4976028203964233, pos_loss: 2.670337677001953, neg_head_loss: 0.37640810012817383, neg_tail_loss: 0.27332764863967896

epoch: 2, avg loss: 1.6490236670899212

batch: 0, loss: 1.5630441904067993, pos_loss: 2.8250811100006104, neg_head_loss: 0.3623201549053192, neg_tail_loss: 0.2396939992904663

batch: 1, loss: 1.5527360439300537, pos_loss: 2.8148279190063477, neg_head_loss: 0.34639328718185425, neg_tail_loss: 0.2348947674036026

batch: 2, loss: 1.6037110090255737, pos_loss: 2.92179536819458, neg_head_loss: 0.33339476585388184, neg_tail_loss: 0.23785819113254547

batch: 3, loss: 1.5023399591445923, pos_loss: 2.7130250930786133, neg_head_loss: 0.3500797152519226, neg_tail_loss: 0.23323023319244385

batch: 4, loss: 1.5391252040863037, pos_loss: 2.808725118637085, neg_head_loss: 0.3293285071849823, neg_tail_loss: 0.20972217619419098

batch: 5, loss: 1.5913677215576172, pos_loss: 2.918118476867676, neg_head_loss: 0.3142847418785095, neg_tail_loss: 0.21494901180267334

ba

Train:   1%|          | 4/400 [12:01<19:50:22, 180.36s/Epoch]


batch: 265, loss: 1.4715255498886108, pos_loss: 2.4828202724456787, neg_head_loss: 0.5397549867630005, neg_tail_loss: 0.38070619106292725

epoch: 3, avg loss: 1.3182198411988137

batch: 0, loss: 1.3798258304595947, pos_loss: 2.314262866973877, neg_head_loss: 0.5264474749565125, neg_tail_loss: 0.36433008313179016

batch: 1, loss: 1.545060634613037, pos_loss: 2.6579761505126953, neg_head_loss: 0.5023665428161621, neg_tail_loss: 0.36192357540130615

batch: 2, loss: 1.391930341720581, pos_loss: 2.3592309951782227, neg_head_loss: 0.5022941827774048, neg_tail_loss: 0.34696507453918457

batch: 3, loss: 1.454723834991455, pos_loss: 2.4841620922088623, neg_head_loss: 0.5011959671974182, neg_tail_loss: 0.3493751585483551

batch: 4, loss: 1.4725855588912964, pos_loss: 2.53201961517334, neg_head_loss: 0.5004717111587524, neg_tail_loss: 0.32583147287368774

batch: 5, loss: 1.51736319065094, pos_loss: 2.6400065422058105, neg_head_loss: 0.46559274196624756, neg_tail_loss: 0.323846697807312

batch: 6

Train:   1%|▏         | 5/400 [15:01<19:46:58, 180.30s/Epoch]


batch: 265, loss: 1.5990413427352905, pos_loss: 2.897150993347168, neg_head_loss: 0.3828582763671875, neg_tail_loss: 0.21900495886802673

epoch: 4, avg loss: 1.207111119551766

batch: 0, loss: 1.462993860244751, pos_loss: 2.6235687732696533, neg_head_loss: 0.3761356472969055, neg_tail_loss: 0.22870275378227234

batch: 1, loss: 1.5609999895095825, pos_loss: 2.8168091773986816, neg_head_loss: 0.3981936573982239, neg_tail_loss: 0.2121882140636444

batch: 2, loss: 1.5243823528289795, pos_loss: 2.7655186653137207, neg_head_loss: 0.36483055353164673, neg_tail_loss: 0.20166151225566864

batch: 3, loss: 1.5471564531326294, pos_loss: 2.8056416511535645, neg_head_loss: 0.36862507462501526, neg_tail_loss: 0.20871713757514954

batch: 4, loss: 1.5292465686798096, pos_loss: 2.7644095420837402, neg_head_loss: 0.38928860425949097, neg_tail_loss: 0.19887831807136536

batch: 5, loss: 1.6550400257110596, pos_loss: 3.0413737297058105, neg_head_loss: 0.3481248915195465, neg_tail_loss: 0.18928775191307068


Train:   2%|▏         | 6/400 [18:02<19:44:04, 180.32s/Epoch]


batch: 265, loss: 1.5530931949615479, pos_loss: 2.8173110485076904, neg_head_loss: 0.38131484389305115, neg_tail_loss: 0.19643595814704895

epoch: 5, avg loss: 1.279832056590489

batch: 0, loss: 1.6659986972808838, pos_loss: 3.0578813552856445, neg_head_loss: 0.33890217542648315, neg_tail_loss: 0.20932941138744354

batch: 1, loss: 1.627278208732605, pos_loss: 2.9993231296539307, neg_head_loss: 0.32131677865982056, neg_tail_loss: 0.18915008008480072

batch: 2, loss: 1.7359710931777954, pos_loss: 3.2235267162323, neg_head_loss: 0.3210904598236084, neg_tail_loss: 0.17574076354503632

batch: 3, loss: 1.5657994747161865, pos_loss: 2.8854622840881348, neg_head_loss: 0.3150567412376404, neg_tail_loss: 0.17721673846244812

batch: 4, loss: 1.6738020181655884, pos_loss: 3.1103365421295166, neg_head_loss: 0.3038709759712219, neg_tail_loss: 0.17066442966461182

batch: 5, loss: 1.649315357208252, pos_loss: 3.0620338916778564, neg_head_loss: 0.3015190362930298, neg_tail_loss: 0.17167451977729797

b

Train:   2%|▏         | 7/400 [21:02<19:40:46, 180.27s/Epoch]


batch: 265, loss: 1.6658461093902588, pos_loss: 3.076779842376709, neg_head_loss: 0.3276479244232178, neg_tail_loss: 0.18217676877975464

epoch: 6, avg loss: 1.3839388686911505

batch: 0, loss: 1.6172975301742554, pos_loss: 2.9768028259277344, neg_head_loss: 0.31754112243652344, neg_tail_loss: 0.19804321229457855

batch: 1, loss: 1.6512062549591064, pos_loss: 3.063272476196289, neg_head_loss: 0.30594468116760254, neg_tail_loss: 0.17233552038669586

batch: 2, loss: 1.713794231414795, pos_loss: 3.2046172618865967, neg_head_loss: 0.2912014126777649, neg_tail_loss: 0.15474146604537964

batch: 3, loss: 1.731032371520996, pos_loss: 3.2344000339508057, neg_head_loss: 0.30134275555610657, neg_tail_loss: 0.15398696064949036

batch: 4, loss: 1.719094157218933, pos_loss: 3.2264468669891357, neg_head_loss: 0.270616739988327, neg_tail_loss: 0.1528657078742981

batch: 5, loss: 1.755645513534546, pos_loss: 3.298846960067749, neg_head_loss: 0.2784377932548523, neg_tail_loss: 0.14645010232925415

batc

Train:   2%|▏         | 8/400 [24:02<19:37:54, 180.29s/Epoch]


batch: 265, loss: 1.685313105583191, pos_loss: 3.112835645675659, neg_head_loss: 0.32312002778053284, neg_tail_loss: 0.19246120750904083

epoch: 7, avg loss: 1.4627579287030643

batch: 0, loss: 1.670797348022461, pos_loss: 3.087320566177368, neg_head_loss: 0.31889140605926514, neg_tail_loss: 0.18965697288513184

batch: 1, loss: 1.6117509603500366, pos_loss: 2.973437786102295, neg_head_loss: 0.30767834186553955, neg_tail_loss: 0.19245001673698425

batch: 2, loss: 1.616439938545227, pos_loss: 2.9886631965637207, neg_head_loss: 0.2939911484718323, neg_tail_loss: 0.19444212317466736

batch: 3, loss: 1.5725473165512085, pos_loss: 2.895265579223633, neg_head_loss: 0.31257739663124084, neg_tail_loss: 0.1870810091495514

batch: 4, loss: 1.6740469932556152, pos_loss: 3.104257822036743, neg_head_loss: 0.30039066076278687, neg_tail_loss: 0.18728160858154297

batch: 5, loss: 1.6638867855072021, pos_loss: 3.082747459411621, neg_head_loss: 0.30698078870773315, neg_tail_loss: 0.18307137489318848

ba

Train:   2%|▏         | 9/400 [27:03<19:34:49, 180.28s/Epoch]


batch: 265, loss: 1.0019936561584473, pos_loss: 1.0487233400344849, neg_head_loss: 1.0939749479293823, neg_tail_loss: 0.8165529370307922

epoch: 8, avg loss: 1.289266117981502

batch: 0, loss: 0.8732174634933472, pos_loss: 0.7767853736877441, neg_head_loss: 1.109297752380371, neg_tail_loss: 0.8300014734268188

batch: 1, loss: 0.862839937210083, pos_loss: 0.749805748462677, neg_head_loss: 1.1124935150146484, neg_tail_loss: 0.8392549157142639

batch: 2, loss: 0.913147509098053, pos_loss: 0.8339416980743408, neg_head_loss: 1.1337529420852661, neg_tail_loss: 0.8509538769721985

batch: 3, loss: 0.9093776941299438, pos_loss: 0.8647695779800415, neg_head_loss: 1.0883963108062744, neg_tail_loss: 0.8195753693580627

batch: 4, loss: 0.8962777853012085, pos_loss: 0.8223769664764404, neg_head_loss: 1.0932222604751587, neg_tail_loss: 0.847135066986084

batch: 5, loss: 0.8692982792854309, pos_loss: 0.7574095726013184, neg_head_loss: 1.105346918106079, neg_tail_loss: 0.8570270538330078

batch: 6, lo

Train:   2%|▎         | 10/400 [30:04<19:33:49, 180.59s/Epoch]


batch: 265, loss: 0.7405669093132019, pos_loss: 1.2923609018325806, neg_head_loss: 0.232016459107399, neg_tail_loss: 0.14552949368953705

epoch: 9, avg loss: 0.7652176107679095

batch: 0, loss: 0.6588925123214722, pos_loss: 1.1263175010681152, neg_head_loss: 0.2386338710784912, neg_tail_loss: 0.14430108666419983

batch: 1, loss: 0.6193917393684387, pos_loss: 1.0474127531051636, neg_head_loss: 0.23309122025966644, neg_tail_loss: 0.14965036511421204

batch: 2, loss: 0.6435023546218872, pos_loss: 1.101501703262329, neg_head_loss: 0.23335938155651093, neg_tail_loss: 0.13764657080173492

batch: 3, loss: 0.6394268870353699, pos_loss: 1.088416337966919, neg_head_loss: 0.23183351755142212, neg_tail_loss: 0.14904144406318665

batch: 4, loss: 0.6003158092498779, pos_loss: 1.0138585567474365, neg_head_loss: 0.22683262825012207, neg_tail_loss: 0.1467134654521942

batch: 5, loss: 0.6278892755508423, pos_loss: 1.0741803646087646, neg_head_loss: 0.2237151861190796, neg_tail_loss: 0.13948102295398712

Train:   3%|▎         | 11/400 [41:22<35:58:49, 332.98s/Epoch]


batch: 0, loss: 0.5263964533805847, pos_loss: 0.3116559684276581, neg_head_loss: 0.8515654802322388, neg_tail_loss: 0.6307083368301392

batch: 1, loss: 0.5095170736312866, pos_loss: 0.2975550889968872, neg_head_loss: 0.8219766616821289, neg_tail_loss: 0.6209814548492432

batch: 2, loss: 0.4972337484359741, pos_loss: 0.25200530886650085, neg_head_loss: 0.8351205587387085, neg_tail_loss: 0.6498036980628967

batch: 3, loss: 0.5003272294998169, pos_loss: 0.25321730971336365, neg_head_loss: 0.8613237142562866, neg_tail_loss: 0.6335504651069641

batch: 4, loss: 0.5234430432319641, pos_loss: 0.31946349143981934, neg_head_loss: 0.823891282081604, neg_tail_loss: 0.6309539079666138

batch: 5, loss: 0.5069924592971802, pos_loss: 0.2716831564903259, neg_head_loss: 0.8455794453620911, neg_tail_loss: 0.639024019241333

batch: 6, loss: 0.49239474534988403, pos_loss: 0.2542763948440552, neg_head_loss: 0.8304470777511597, neg_tail_loss: 0.6305791139602661

batch: 7, loss: 0.5129432678222656, pos_loss:

Train:   3%|▎         | 12/400 [44:22<30:51:30, 286.32s/Epoch]


batch: 265, loss: 0.45641446113586426, pos_loss: 0.80561363697052, neg_head_loss: 0.12916827201843262, neg_tail_loss: 0.08526234328746796

epoch: 11, avg loss: 0.4269408089549918

batch: 0, loss: 0.4226773977279663, pos_loss: 0.7399373054504395, neg_head_loss: 0.12900108098983765, neg_tail_loss: 0.0818338394165039

batch: 1, loss: 0.40889355540275574, pos_loss: 0.7149906158447266, neg_head_loss: 0.12255547940731049, neg_tail_loss: 0.08303748816251755

batch: 2, loss: 0.4121030867099762, pos_loss: 0.7227481007575989, neg_head_loss: 0.11862142384052277, neg_tail_loss: 0.08429465442895889

batch: 3, loss: 0.40595555305480957, pos_loss: 0.7084940671920776, neg_head_loss: 0.12708576023578644, neg_tail_loss: 0.07974836975336075

batch: 4, loss: 0.41318362951278687, pos_loss: 0.7232720255851746, neg_head_loss: 0.1262444257736206, neg_tail_loss: 0.07994601130485535

batch: 5, loss: 0.4146536588668823, pos_loss: 0.7287713289260864, neg_head_loss: 0.12149234116077423, neg_tail_loss: 0.079579599

**Visualization Tools**

In [None]:
# Note: This was only tested on a GCP VM; it will not
# run on Google Colab without additional tinkering.
# In addition, auxiliary pickle files are needed that map
# FB15k-237 "mid" values to their corresponding true
# names and index values

import math
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import cm
from matplotlib.patches import Circle
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Seed NumPy's global RNG so the np.random.choice sampling of triples in
# plot_couple_dimensions picks the same examples on every run.
np.random.seed(9654924)

def load_pickle(file_name):
    """Open `file_name` in binary mode and return the object unpickled from it."""
    with open(file_name, 'rb') as handle:
        return pickle.load(handle)

def load_numpy(file_name):
    """Open `file_name` in binary mode and return whatever `np.load` reads from it."""
    with open(file_name, 'rb') as handle:
        return np.load(handle)

def _draw_concentric_circles(ax):
    """Overlay dotted red circles centred on the origin of `ax`.

    The ring spacing starts at 1 and is divided by 10 until it no longer
    exceeds the largest absolute axis limit; rings are drawn at every multiple
    of that spacing that is still smaller than the limit.
    """
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    extent = max(abs(x_high), abs(y_high), abs(x_low), abs(y_low))

    # A zero/NaN extent would never leave the scaling loop (or draw nothing),
    # and an infinite one would never leave the drawing loop.
    if not 0 < extent < float("inf"):
        return

    # Scale a *copy* of the extent: the drawing loop below must compare
    # against the real axis extent, not the rescaled value.
    interval = 1
    scaled = extent
    while scaled < 1:
        scaled *= 10
        interval /= 10

    multiplier = 1
    while multiplier * interval < extent:
        ax.add_patch(Circle((0, 0), multiplier * interval, color='r', fill=False, linestyle='dotted'))
        multiplier += 1


def plot_couple_dimensions(path, model, embedding_path, examples_to_use=3, dims=(2,2)):
    """ Given a number of triples, plot their positions on a graph

    Randomly samples `examples_to_use` triples from ``used_tuples.pickle`` and
    scatters their head and tail entities, one subplot per selected embedding
    dimension (or pair of dimensions), then saves the figure with
    ``plt.savefig(f'{model}_embeddings')``.

    Args:
        path: Directory holding ``mid_to_index_mapping.pickle``,
            ``mid_to_name_mapping.pickle``, ``relation_to_index_mapping.pickle``
            and ``used_tuples.pickle``.
        model: Either ``'rotatE'`` or ``'transE'``.
            * rotatE: the entity embedding is split in half along axis 1; the
              first half is plotted on x and the second half on y for each of
              the highest-variance dimensions, with concentric circles added.
            * transE: the ``2 * rows * cols`` highest-variance dimensions are
              paired up (two per subplot) and the relation embedding is drawn
              as an arrow starting at every head entity.
        embedding_path: File readable by ``np.load`` with the entity embedding.
            For transE the relation embedding is read from the same path with
            "entity_embedding" replaced by "relation_embedding".
        examples_to_use: Number of triples sampled (without replacement).
        dims: ``(rows, cols)`` of the subplot grid.

    Raises:
        ValueError: If `model` is not a supported type, or if a transE
            embedding has fewer than ``2 * rows * cols`` dimensions.
    """
    # Fail fast, before any file is read. (The original `raise(<str>)` itself
    # failed with TypeError because a str is not an exception.)
    if model not in ('rotatE', 'transE'):
        raise ValueError(f"Invalid model type for plotting: {model}")

    mid_to_index_mapping = load_pickle(os.path.join(path, "mid_to_index_mapping.pickle"))
    mid_to_name_mapping = load_pickle(os.path.join(path, "mid_to_name_mapping.pickle"))
    relation_to_index_mapping = load_pickle(os.path.join(path, "relation_to_index_mapping.pickle"))
    used_tuples_list = load_pickle(os.path.join(path, "used_tuples.pickle"))

    num_row, num_col = dims
    dims_to_plot = num_row * num_col
    entity_embedding = load_numpy(embedding_path)

    # Randomly select
    idxs = np.random.choice(len(used_tuples_list), examples_to_use, replace=False)
    selected_triples = np.array(used_tuples_list)[idxs]

    # Assumes all selected head/tail pairs share the relation of the first
    # triple (the tuples are expected to come from a single-relation query).
    relevant_relation = int(relation_to_index_mapping[selected_triples[0][1]])

    heads = [head for head, _, _ in selected_triples]
    tails = [tail for _, _, tail in selected_triples]
    selected_entities = heads + tails

    # dim_pairs[i] = (x_dim, y_dim): the column of X used for the x coordinate
    # and the column of Y used for the y coordinate in subplot i.
    if model == 'rotatE':
        # First half of the columns is treated as the real part, second half
        # as the imaginary part; both are indexed by the same dimension.
        real, imag = np.split(entity_embedding, 2, axis=1)
        variances = (np.var(real, axis=0) + np.var(imag, axis=0)) / 2
        X, Y = real, imag
        dim_pairs = [(dim, dim) for dim in np.argsort(variances)[-dims_to_plot:]]
        r = None
    else:
        # Best to plot in at least 2D, so we select the dimensions with
        # highest variance and assign 2 dimensions per subplot.
        variances = np.var(entity_embedding, axis=0)
        double_dims = np.argsort(variances)[-(dims_to_plot*2):]
        if len(double_dims) < dims_to_plot * 2:
            raise ValueError(
                f"transE plotting needs at least {dims_to_plot * 2} embedding "
                f"dimensions, got {len(double_dims)}")

        # Read both coordinates straight from the embedding. The original
        # aliased X and Y to the same array and overwrote columns in place,
        # which corrupted the data and left every point on the y == x line.
        X = Y = entity_embedding
        dim_pairs = [(double_dims[ind], double_dims[-(ind+1)]) for ind in range(dims_to_plot)]
        r = load_numpy(embedding_path.replace("entity_embedding", "relation_embedding"))

    # squeeze=False keeps `axis` two-dimensional even for 1xN / Nx1 grids, and
    # figsize is passed here so it applies to the figure that is saved
    # (a separate plt.figure(...) call only creates an unused empty figure).
    fig, axis = plt.subplots(num_row, num_col, figsize=(32,20), squeeze=False)
    for ind, (dim, dim_2) in enumerate(dim_pairs):
        row, col = divmod(ind, num_col)
        ax = axis[row][col]
        ax.axhline(0, color='black')
        ax.axvline(0, color='black')
        for mid in selected_entities:
            index = int(mid_to_index_mapping[mid])
            x = X[index][dim]
            y = Y[index][dim_2]
            ax.scatter(x, y, marker="o", s=50)
            ax.annotate(mid_to_name_mapping[mid], (x, y), fontsize=6)

        if model == 'rotatE':
            # Add concentric circles to illustrate rotational nature
            _draw_concentric_circles(ax)
            ax.set_title(f'rotatE Embedding Dim {dim}', fontsize=10)
        else:
            # Draw relations as vectors starting at each head entity
            for head in heads:
                index = int(mid_to_index_mapping[head])
                start_x = X[index][dim]
                start_y = Y[index][dim_2]
                dist_x = r[relevant_relation][dim]
                dist_y = r[relevant_relation][dim_2]
                ax.arrow(start_x, start_y, dist_x, dist_y)
            ax.set_title(f'transE Embedding Dims {dim} and {dim_2}', fontsize=10)

    fig.tight_layout()
    fig.savefig(f'{model}_embeddings')

#plot_couple_dimensions("sports_data", "rotatE", "rotatE_lr_decay_embeddings/entity_embedding_200.npy")

**MID to Entity Name Mapping**

In [None]:
# Similar to above, this was run on a GCP VM, so will not run
# directly on Colab without making adjustments

import json
import os
import pickle
import re
import requests
import sys
import tqdm
import random

# Relation whose triples are looked up; every MID appearing in those triples
# is resolved to its index, URL and name further down.
relation = "/award/award_nominee/award_nominations./award/award_nomination/nominated_for"

# Input locations: the FB15k-237 data directory and the Freebase decode file.
path = "data/FB15k-237"
decode_file = 'fb2w.nt'
entities_file = os.path.join(path, 'entities.dict')
relations_file = os.path.join(path, 'relations.dict')

# Lookup tables keyed by MID, filled in as the script runs.
mid_to_index_mapping = {}
mid_to_name_mapping = {}
mid_to_url_mapping = {}

def process_grep_out(out):
    """Turn raw grep output into a list with one tuple of tab-separated fields per line.

    The text is split on newlines only, so a trailing newline yields a final
    one-element tuple holding an empty string.
    """
    return [tuple(line.split('\t')) for line in out.split('\n')]

def add_mappings(mid):
    """Make sure `mid` has both an entity index and a URL recorded.

    Missing entries are looked up by grepping `entities_file` (for the index)
    and `decode_file` (for the URL) and stored in the module-level
    `mid_to_index_mapping` / `mid_to_url_mapping` dicts.

    Args:
        mid: Entity identifier as it appears in the dataset files
            (e.g. of the form "/m/xxxx").

    Returns:
        True only if `mid` ends up present in BOTH mappings. When no URL can
        be found the index entry is removed again, because an entity without
        a URL cannot be used. When no index can be found the URL lookup is
        skipped entirely.
    """
    if mid not in mid_to_index_mapping:
        # Context manager closes the pipe to grep once its output is read.
        with os.popen(f'grep {mid} {entities_file}') as stream:
            out = stream.read()
        for entity_and_id in process_grep_out(out):
            if len(entity_and_id) == 2:
                eid, entity = entity_and_id
                # grep matches substrings, so keep only the exact entity
                if entity == mid:
                    mid_to_index_mapping[mid] = eid

    # Without an index the entity is unusable; report failure instead of
    # letting a later successful URL lookup mask the miss.
    if mid not in mid_to_index_mapping:
        return False

    if mid not in mid_to_url_mapping:
        # "/m/xxxx" -> "m.xxxx", the form looked up in decode_file
        converted_mid = mid[1:2] + '.' + mid[3:]
        with os.popen(f'grep {converted_mid} {decode_file}') as stream:
            out = stream.read()
        for mid_to_url in process_grep_out(out):
            if len(mid_to_url) == 3:
                fb, w3, wiki = mid_to_url
                # Strip the first/last character of the first field and take
                # its fifth '/'-separated component as the candidate MID.
                parts = fb[1:-1].split('/')
                if len(parts) > 4 and parts[4] == converted_mid:
                    mid_to_url_mapping[mid] = wiki[1:-3]

    if mid not in mid_to_url_mapping:
        # If didn't find a URL, then remove the mid from the
        # mid_to_index_mapping map b/c it can't be used. pop() with a
        # default never raises, whatever state the mapping is in.
        mid_to_index_mapping.pop(mid, None)
        return False
    return True

def generate_relation_map():
    """ Build the relation-name -> relation-id map and pickle it.

    Reads `relations_file`, keeps every line made of exactly two tab-separated
    fields (id, relation name), and writes the resulting dict to
    "relation_to_index_mapping.pickle". Only needs to be generated once.
    """
    with open(relations_file, 'r') as source:
        rows = [line.split('\t') for line in source.read().split('\n')]

    # Lines without exactly two fields (e.g. the empty last line) are skipped.
    relation_to_index_mapping = {fields[1]: fields[0] for fields in rows if len(fields) == 2}

    with open("relation_to_index_mapping.pickle", 'wb') as sink:
        pickle.dump(relation_to_index_mapping, sink)

# Collect every line of the test split that mentions the chosen relation.
test_file = os.path.join(path, 'test.txt')
with os.popen(f'grep {relation} {test_file}') as stream:
    out = stream.read()
tuples = process_grep_out(out)

if os.path.exists('mid_to_index_mapping.pickle') and os.path.exists('mid_to_url_mapping.pickle'):
    # Reuse the mappings cached by an earlier run.
    with open('mid_to_index_mapping.pickle', 'rb') as f:
        mid_to_index_mapping = pickle.load(f)
    with open('mid_to_url_mapping.pickle', 'rb') as f:
        mid_to_url_mapping = pickle.load(f)
else:
    # Keep only the triples whose head AND tail could both be resolved.
    all_tuples = []
    for tup in tqdm.tqdm(tuples, desc="Index and URL For Tuple"):
        if len(tup) == 3:
            # The middle field is ignored on purpose: unpacking it into
            # `relation` would overwrite the module-level setting.
            head, _, tail = tup
            if add_mappings(head) and add_mappings(tail):
                all_tuples.append(tup)

    with open('mid_to_index_mapping.pickle', 'wb') as f:
        pickle.dump(mid_to_index_mapping, f)
    with open('mid_to_url_mapping.pickle', 'wb') as f:
        pickle.dump(mid_to_url_mapping, f)
    with open('used_tuples.pickle', 'wb') as f:
        pickle.dump(all_tuples, f)

# Fetch a display name for every MID that has a URL.
for mid, url in tqdm.tqdm(mid_to_url_mapping.items(), desc="Name For Tuple"):
    # A timeout keeps one unresponsive request from hanging the whole run.
    request = requests.get(url, timeout=30)
    request.raise_for_status()
    x = json.loads(request.text)['entities']
    keys = list(x.keys())
    try:
        name = x[keys[0]]['labels']['en']['value']
    except KeyError:
        # No English label available: fall back to the raw MID so a single
        # entity does not abort the run and discard every name fetched so far.
        name = mid
    mid_to_name_mapping[mid] = name

with open('mid_to_name_mapping.pickle', 'wb') as f:
    pickle.dump(mid_to_name_mapping, f)

generate_relation_map()