In [1]:
# Jupyter Notebook cell: RotatE training and evaluation

import random
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1. Load triples
def load_triples(path):
    triples = []
    with open(path, 'r') as f:
        for line in f:
            h, r, t = line.strip().split()[:3]
            triples.append((h, r, t))
    return triples

train_triples = load_triples("../graph_data/train.tsv")
dev_triples =  load_triples("../graph_data/dev.tsv")
test_triples =  load_triples("../graph_data/test.tsv")

# 2. Build vocab
entities = sorted({e for (h, _, t) in train_triples + dev_triples + test_triples for e in (h, t)})
relations = sorted({r for (_, r, _) in train_triples + dev_triples + test_triples})
ent2id = {e: i for i, e in enumerate(entities)}
rel2id = {r: i for i, r in enumerate(relations)}

In [3]:
rel2id

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6}

In [3]:
# 3. Model definition
class RotatEModel(nn.Module):
    def __init__(self, num_ent, num_rel, dim):
        super().__init__()
        self.dim = dim
        # entity embeddings: real and imag parts
        self.ent_re = nn.Embedding(num_ent, dim)
        self.ent_im = nn.Embedding(num_ent, dim)
        # relation embeddings: phase angles
        self.rel_ph = nn.Embedding(num_rel, dim)
        nn.init.xavier_uniform_(self.ent_re.weight)
        nn.init.xavier_uniform_(self.ent_im.weight)
        nn.init.uniform_(self.rel_ph.weight, a=-3.1415, b=3.1415)

    def forward(self, h, r, t):
        re_h = self.ent_re(h)
        im_h = self.ent_im(h)
        re_t = self.ent_re(t)
        im_t = self.ent_im(t)
        ph_r = self.rel_ph(r)
        # compute relation as complex: cos, sin
        re_r = torch.cos(ph_r)
        im_r = torch.sin(ph_r)
        # rotate head
        re_hr = re_h * re_r - im_h * im_r
        im_hr = re_h * im_r + im_h * re_r
        # distance
        diff_re = re_hr - re_t
        diff_im = im_hr - im_t
        # score: negative L1 norm
        score = - (diff_re.abs() + diff_im.abs()).sum(dim=-1)
        return score

# 4. Negative sampling

def negative_sampling(batch, num_entities):
    neg = batch.clone()
    mask = torch.rand(batch.size(0), device=batch.device) < 0.5
    # corrupt head
    num_h = mask.sum().item()
    neg[mask, 0] = torch.randint(0, num_entities, (num_h,), device=batch.device)
    # corrupt tail
    num_t = (~mask).sum().item()
    neg[~mask, 2] = torch.randint(0, num_entities, (num_t,), device=batch.device)
    return neg

# 5. Encoding

def encode(triples):
    return torch.tensor([[ent2id[h], rel2id[r], ent2id[t]] for h, r, t in triples], dtype=torch.long)

train_data = encode(train_triples)
dev_data   = encode(dev_triples)
test_data  = encode(test_triples)

In [4]:
# 8. Training and evaluation functions

def train_epoch(model, loader, optimizer, num_entities):
    model.train()
    total_loss = 0
    total = 0
    for batch in tqdm(loader, desc='Train', leave=False):
        batch = batch[0].to(device)
        pos = batch
        neg = negative_sampling(pos, num_entities)
        pos_score = model(pos[:,0], pos[:,1], pos[:,2])
        neg_score = model(neg[:,0], neg[:,1], neg[:,2])
        loss = - (F.logsigmoid(pos_score).mean() + F.logsigmoid(-neg_score).mean())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * pos.size(0)
        total += pos.size(0)
    return total_loss / total

@torch.no_grad()
def evaluate_ranking(triples_tensor, num_entities, num_neg=100, k_list=(1,3,10)):
    model.eval()
    ranks = []
    for idx in tqdm(range(triples_tensor.size(0)), desc='Eval', leave=False):
        h, r, t = triples_tensor[idx]
        h = h.to(device); r = r.to(device); t = t.to(device)
        pos_score = model(h.unsqueeze(0), r.unsqueeze(0), t.unsqueeze(0))
        scores = [pos_score.item()]
        for _ in range(num_neg):
            if random.random() < 0.5:
                h2 = random.randint(0, num_entities-1); t2 = t.item()
            else:
                h2 = h.item(); t2 = random.randint(0, num_entities-1)
            neg_score = model(
                torch.tensor([h2], device=device),
                r.unsqueeze(0),
                torch.tensor([t2], device=device)
            )
            scores.append(neg_score.item())
        sorted_scores = sorted(scores, reverse=True)
        rank = sorted_scores.index(scores[0]) + 1
        ranks.append(rank)
    ranks = torch.tensor(ranks, dtype=torch.float)
    mrr = (1.0/ranks).mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in k_list}
    return mrr, hits



In [5]:
# 6. Dataloaders
batch_size = 512
train_loader = DataLoader(TensorDataset(train_data), batch_size=batch_size, shuffle=True)
dev_loader   = DataLoader(TensorDataset(dev_data),   batch_size=batch_size)

# 7. Instantiate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RotatEModel(len(entities), len(relations), dim=100).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)



In [6]:
# 9. Run training

best_dev_loss = float('inf')
for epoch in range(1, 201):
    train_loss = train_epoch(model, train_loader, optimizer, len(entities))
    mrr_dev, hits_dev = evaluate_ranking(dev_data, len(entities))
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Dev MRR: {mrr_dev:.4f}, Hits@10: {hits_dev['Hits@10']:.4f}")




                                                     

Epoch 01 | Train Loss: 4.4525 | Dev MRR: 0.0702, Hits@10: 0.1087


                                                     

Epoch 02 | Train Loss: 3.9407 | Dev MRR: 0.0921, Hits@10: 0.1957


                                                     

Epoch 03 | Train Loss: 3.5054 | Dev MRR: 0.1399, Hits@10: 0.2609


                                                     

Epoch 04 | Train Loss: 3.1229 | Dev MRR: 0.1500, Hits@10: 0.2609


                                                     

Epoch 05 | Train Loss: 2.7869 | Dev MRR: 0.1416, Hits@10: 0.2609


                                                     

Epoch 06 | Train Loss: 2.4955 | Dev MRR: 0.1708, Hits@10: 0.2391


                                                     

Epoch 07 | Train Loss: 2.2446 | Dev MRR: 0.1739, Hits@10: 0.2826


                                                     

Epoch 08 | Train Loss: 2.0319 | Dev MRR: 0.1796, Hits@10: 0.3261


                                                     

Epoch 09 | Train Loss: 1.8511 | Dev MRR: 0.1863, Hits@10: 0.3043


                                                     

Epoch 10 | Train Loss: 1.7019 | Dev MRR: 0.1851, Hits@10: 0.3043


                                                     

Epoch 11 | Train Loss: 1.5761 | Dev MRR: 0.1860, Hits@10: 0.3261


                                                     

Epoch 12 | Train Loss: 1.4725 | Dev MRR: 0.2005, Hits@10: 0.3696


                                                     

Epoch 13 | Train Loss: 1.3863 | Dev MRR: 0.2048, Hits@10: 0.3696


                                                     

Epoch 14 | Train Loss: 1.3157 | Dev MRR: 0.2022, Hits@10: 0.3913


                                                     

Epoch 15 | Train Loss: 1.2566 | Dev MRR: 0.1915, Hits@10: 0.4348


                                                     

Epoch 16 | Train Loss: 1.2078 | Dev MRR: 0.2417, Hits@10: 0.4348


                                                     

Epoch 17 | Train Loss: 1.1651 | Dev MRR: 0.2311, Hits@10: 0.4130


                                                     

Epoch 18 | Train Loss: 1.1321 | Dev MRR: 0.2292, Hits@10: 0.4565


                                                     

Epoch 19 | Train Loss: 1.1019 | Dev MRR: 0.2394, Hits@10: 0.4565


                                                     

Epoch 20 | Train Loss: 1.0780 | Dev MRR: 0.2865, Hits@10: 0.4565


                                                     

Epoch 21 | Train Loss: 1.0560 | Dev MRR: 0.3020, Hits@10: 0.4783


                                                     

Epoch 22 | Train Loss: 1.0378 | Dev MRR: 0.3231, Hits@10: 0.4783


                                                     

Epoch 23 | Train Loss: 1.0229 | Dev MRR: 0.3285, Hits@10: 0.4783


                                                     

Epoch 24 | Train Loss: 1.0093 | Dev MRR: 0.3434, Hits@10: 0.5000


                                                     

Epoch 25 | Train Loss: 0.9962 | Dev MRR: 0.3016, Hits@10: 0.4783


                                                     

Epoch 26 | Train Loss: 0.9862 | Dev MRR: 0.3176, Hits@10: 0.5000


                                                     

Epoch 27 | Train Loss: 0.9779 | Dev MRR: 0.3444, Hits@10: 0.5000


                                                     

Epoch 28 | Train Loss: 0.9688 | Dev MRR: 0.3016, Hits@10: 0.4783


                                                     

Epoch 29 | Train Loss: 0.9620 | Dev MRR: 0.3292, Hits@10: 0.4565


                                                     

Epoch 30 | Train Loss: 0.9564 | Dev MRR: 0.3573, Hits@10: 0.4783


                                                     

Epoch 31 | Train Loss: 0.9493 | Dev MRR: 0.3691, Hits@10: 0.5000


                                                     

Epoch 32 | Train Loss: 0.9440 | Dev MRR: 0.3863, Hits@10: 0.5000


                                                     

Epoch 33 | Train Loss: 0.9387 | Dev MRR: 0.3814, Hits@10: 0.5000


                                                     

Epoch 34 | Train Loss: 0.9342 | Dev MRR: 0.3901, Hits@10: 0.4565


                                                     

Epoch 35 | Train Loss: 0.9314 | Dev MRR: 0.3836, Hits@10: 0.5217


                                                     

Epoch 36 | Train Loss: 0.9279 | Dev MRR: 0.3674, Hits@10: 0.5000


                                                     

Epoch 37 | Train Loss: 0.9236 | Dev MRR: 0.3957, Hits@10: 0.5217


                                                     

Epoch 38 | Train Loss: 0.9210 | Dev MRR: 0.3930, Hits@10: 0.5435


                                                     

Epoch 39 | Train Loss: 0.9173 | Dev MRR: 0.4425, Hits@10: 0.5217


                                                     

Epoch 40 | Train Loss: 0.9151 | Dev MRR: 0.3935, Hits@10: 0.5435


                                                     

Epoch 41 | Train Loss: 0.9121 | Dev MRR: 0.3695, Hits@10: 0.5217


                                                     

Epoch 42 | Train Loss: 0.9101 | Dev MRR: 0.3843, Hits@10: 0.5652


                                                     

Epoch 43 | Train Loss: 0.9076 | Dev MRR: 0.4353, Hits@10: 0.5217


                                                     

Epoch 44 | Train Loss: 0.9055 | Dev MRR: 0.4088, Hits@10: 0.5217


                                                     

Epoch 45 | Train Loss: 0.9027 | Dev MRR: 0.3894, Hits@10: 0.5435


                                                     

Epoch 46 | Train Loss: 0.9015 | Dev MRR: 0.4164, Hits@10: 0.5435


                                                     

Epoch 47 | Train Loss: 0.8986 | Dev MRR: 0.4002, Hits@10: 0.5870


                                                     

Epoch 48 | Train Loss: 0.8966 | Dev MRR: 0.4199, Hits@10: 0.5217


                                                     

Epoch 49 | Train Loss: 0.8956 | Dev MRR: 0.4282, Hits@10: 0.5652


                                                     

Epoch 50 | Train Loss: 0.8932 | Dev MRR: 0.4301, Hits@10: 0.5652


                                                     

Epoch 51 | Train Loss: 0.8933 | Dev MRR: 0.4044, Hits@10: 0.5435


                                                     

Epoch 52 | Train Loss: 0.8900 | Dev MRR: 0.4120, Hits@10: 0.5652


                                                     

Epoch 53 | Train Loss: 0.8884 | Dev MRR: 0.4184, Hits@10: 0.5435


                                                     

Epoch 54 | Train Loss: 0.8864 | Dev MRR: 0.4185, Hits@10: 0.5870


                                                     

Epoch 55 | Train Loss: 0.8861 | Dev MRR: 0.4523, Hits@10: 0.5652


                                                     

Epoch 56 | Train Loss: 0.8849 | Dev MRR: 0.4241, Hits@10: 0.5870


                                                     

Epoch 57 | Train Loss: 0.8829 | Dev MRR: 0.4366, Hits@10: 0.5435


                                                     

Epoch 58 | Train Loss: 0.8813 | Dev MRR: 0.4676, Hits@10: 0.5435


                                                     

Epoch 59 | Train Loss: 0.8804 | Dev MRR: 0.4565, Hits@10: 0.5870


                                                     

Epoch 60 | Train Loss: 0.8779 | Dev MRR: 0.4298, Hits@10: 0.5652


                                                     

Epoch 61 | Train Loss: 0.8776 | Dev MRR: 0.4509, Hits@10: 0.5652


                                                     

Epoch 62 | Train Loss: 0.8766 | Dev MRR: 0.4625, Hits@10: 0.5652


                                                     

Epoch 63 | Train Loss: 0.8756 | Dev MRR: 0.4622, Hits@10: 0.5870


                                                     

Epoch 64 | Train Loss: 0.8758 | Dev MRR: 0.4668, Hits@10: 0.5870


                                                     

Epoch 65 | Train Loss: 0.8723 | Dev MRR: 0.4654, Hits@10: 0.5652


                                                     

Epoch 66 | Train Loss: 0.8734 | Dev MRR: 0.4775, Hits@10: 0.5652


                                                     

Epoch 67 | Train Loss: 0.8715 | Dev MRR: 0.4676, Hits@10: 0.5652


                                                     

Epoch 68 | Train Loss: 0.8691 | Dev MRR: 0.5024, Hits@10: 0.5652


                                                     

Epoch 69 | Train Loss: 0.8688 | Dev MRR: 0.4610, Hits@10: 0.5652


                                                     

Epoch 70 | Train Loss: 0.8671 | Dev MRR: 0.4770, Hits@10: 0.5870


                                                     

Epoch 71 | Train Loss: 0.8663 | Dev MRR: 0.4817, Hits@10: 0.5870


                                                     

Epoch 72 | Train Loss: 0.8662 | Dev MRR: 0.5019, Hits@10: 0.5435


                                                     

Epoch 73 | Train Loss: 0.8648 | Dev MRR: 0.4701, Hits@10: 0.5652


                                                     

Epoch 74 | Train Loss: 0.8647 | Dev MRR: 0.4698, Hits@10: 0.5652


                                                     

Epoch 75 | Train Loss: 0.8637 | Dev MRR: 0.4946, Hits@10: 0.5870


                                                     

Epoch 76 | Train Loss: 0.8624 | Dev MRR: 0.4681, Hits@10: 0.5435


                                                     

Epoch 77 | Train Loss: 0.8612 | Dev MRR: 0.4808, Hits@10: 0.5217


                                                     

Epoch 78 | Train Loss: 0.8605 | Dev MRR: 0.4845, Hits@10: 0.5652


                                                     

Epoch 79 | Train Loss: 0.8593 | Dev MRR: 0.4985, Hits@10: 0.5652


                                                     

Epoch 80 | Train Loss: 0.8588 | Dev MRR: 0.5033, Hits@10: 0.5435


                                                     

Epoch 81 | Train Loss: 0.8588 | Dev MRR: 0.4935, Hits@10: 0.5217


                                                     

Epoch 82 | Train Loss: 0.8574 | Dev MRR: 0.5122, Hits@10: 0.5870


                                                     

Epoch 83 | Train Loss: 0.8576 | Dev MRR: 0.5022, Hits@10: 0.5435


                                                     

Epoch 84 | Train Loss: 0.8549 | Dev MRR: 0.4943, Hits@10: 0.5870


                                                     

Epoch 85 | Train Loss: 0.8554 | Dev MRR: 0.5145, Hits@10: 0.5652


                                                     

Epoch 86 | Train Loss: 0.8552 | Dev MRR: 0.4947, Hits@10: 0.5652


                                                     

Epoch 87 | Train Loss: 0.8552 | Dev MRR: 0.5086, Hits@10: 0.5652


                                                     

Epoch 88 | Train Loss: 0.8547 | Dev MRR: 0.5267, Hits@10: 0.5870


                                                     

Epoch 89 | Train Loss: 0.8516 | Dev MRR: 0.4917, Hits@10: 0.5652


                                                     

Epoch 90 | Train Loss: 0.8511 | Dev MRR: 0.4991, Hits@10: 0.5435


                                                     

Epoch 91 | Train Loss: 0.8514 | Dev MRR: 0.5037, Hits@10: 0.6087


                                                     

Epoch 92 | Train Loss: 0.8496 | Dev MRR: 0.4933, Hits@10: 0.5652


                                                     

Epoch 93 | Train Loss: 0.8482 | Dev MRR: 0.4936, Hits@10: 0.5870


                                                     

Epoch 94 | Train Loss: 0.8498 | Dev MRR: 0.5038, Hits@10: 0.5652


                                                     

Epoch 95 | Train Loss: 0.8488 | Dev MRR: 0.5307, Hits@10: 0.5652


                                                     

Epoch 96 | Train Loss: 0.8487 | Dev MRR: 0.5197, Hits@10: 0.5435


                                                     

Epoch 97 | Train Loss: 0.8476 | Dev MRR: 0.5281, Hits@10: 0.5652


                                                     

Epoch 98 | Train Loss: 0.8468 | Dev MRR: 0.4887, Hits@10: 0.5435


                                                     

Epoch 99 | Train Loss: 0.8467 | Dev MRR: 0.5099, Hits@10: 0.5652


                                                     

Epoch 100 | Train Loss: 0.8466 | Dev MRR: 0.4920, Hits@10: 0.5652


                                                     

Epoch 101 | Train Loss: 0.8445 | Dev MRR: 0.5227, Hits@10: 0.5652


                                                     

Epoch 102 | Train Loss: 0.8454 | Dev MRR: 0.5446, Hits@10: 0.5870


                                                     

Epoch 103 | Train Loss: 0.8442 | Dev MRR: 0.5101, Hits@10: 0.6087


                                                     

Epoch 104 | Train Loss: 0.8425 | Dev MRR: 0.5043, Hits@10: 0.6087


                                                     

Epoch 105 | Train Loss: 0.8427 | Dev MRR: 0.5190, Hits@10: 0.6304


                                                     

Epoch 106 | Train Loss: 0.8424 | Dev MRR: 0.4991, Hits@10: 0.5870


                                                     

Epoch 107 | Train Loss: 0.8421 | Dev MRR: 0.5367, Hits@10: 0.6087


                                                     

Epoch 108 | Train Loss: 0.8417 | Dev MRR: 0.5178, Hits@10: 0.6087


                                                     

Epoch 109 | Train Loss: 0.8416 | Dev MRR: 0.4922, Hits@10: 0.6087


                                                     

Epoch 110 | Train Loss: 0.8403 | Dev MRR: 0.5012, Hits@10: 0.5652


                                                     

Epoch 111 | Train Loss: 0.8394 | Dev MRR: 0.5138, Hits@10: 0.6304


                                                     

Epoch 112 | Train Loss: 0.8382 | Dev MRR: 0.4864, Hits@10: 0.6304


                                                     

Epoch 113 | Train Loss: 0.8381 | Dev MRR: 0.5245, Hits@10: 0.6304


                                                     

Epoch 114 | Train Loss: 0.8370 | Dev MRR: 0.5199, Hits@10: 0.6087


                                                     

Epoch 115 | Train Loss: 0.8388 | Dev MRR: 0.5362, Hits@10: 0.6522


                                                     

Epoch 116 | Train Loss: 0.8362 | Dev MRR: 0.5324, Hits@10: 0.6304


                                                     

Epoch 117 | Train Loss: 0.8365 | Dev MRR: 0.5138, Hits@10: 0.6087


                                                     

Epoch 118 | Train Loss: 0.8362 | Dev MRR: 0.5228, Hits@10: 0.6087


                                                     

Epoch 119 | Train Loss: 0.8344 | Dev MRR: 0.4942, Hits@10: 0.5870


                                                     

Epoch 120 | Train Loss: 0.8344 | Dev MRR: 0.5254, Hits@10: 0.6304


                                                     

Epoch 121 | Train Loss: 0.8342 | Dev MRR: 0.5244, Hits@10: 0.6304


                                                     

Epoch 122 | Train Loss: 0.8349 | Dev MRR: 0.5319, Hits@10: 0.6522


                                                     

Epoch 123 | Train Loss: 0.8335 | Dev MRR: 0.5090, Hits@10: 0.6522


                                                     

Epoch 124 | Train Loss: 0.8337 | Dev MRR: 0.5216, Hits@10: 0.6087


                                                     

Epoch 125 | Train Loss: 0.8337 | Dev MRR: 0.5254, Hits@10: 0.6304


                                                     

Epoch 126 | Train Loss: 0.8321 | Dev MRR: 0.5030, Hits@10: 0.6522


                                                     

Epoch 127 | Train Loss: 0.8334 | Dev MRR: 0.5272, Hits@10: 0.6304


                                                     

Epoch 128 | Train Loss: 0.8316 | Dev MRR: 0.5312, Hits@10: 0.6304


                                                     

Epoch 129 | Train Loss: 0.8323 | Dev MRR: 0.5271, Hits@10: 0.6304


                                                     

Epoch 130 | Train Loss: 0.8312 | Dev MRR: 0.5217, Hits@10: 0.6087


                                                     

Epoch 131 | Train Loss: 0.8321 | Dev MRR: 0.5294, Hits@10: 0.6087


                                                     

Epoch 132 | Train Loss: 0.8303 | Dev MRR: 0.5063, Hits@10: 0.6087


                                                     

Epoch 133 | Train Loss: 0.8301 | Dev MRR: 0.5232, Hits@10: 0.6087


                                                     

Epoch 134 | Train Loss: 0.8292 | Dev MRR: 0.5264, Hits@10: 0.6304


                                                     

Epoch 135 | Train Loss: 0.8292 | Dev MRR: 0.5393, Hits@10: 0.6087


                                                     

Epoch 136 | Train Loss: 0.8287 | Dev MRR: 0.5357, Hits@10: 0.6304


                                                     

Epoch 137 | Train Loss: 0.8289 | Dev MRR: 0.5272, Hits@10: 0.6522


                                                     

Epoch 138 | Train Loss: 0.8287 | Dev MRR: 0.5321, Hits@10: 0.6304


                                                     

Epoch 139 | Train Loss: 0.8282 | Dev MRR: 0.5366, Hits@10: 0.6304


                                                     

Epoch 140 | Train Loss: 0.8266 | Dev MRR: 0.5166, Hits@10: 0.6304


                                                     

Epoch 141 | Train Loss: 0.8277 | Dev MRR: 0.5270, Hits@10: 0.6522


                                                     

Epoch 142 | Train Loss: 0.8265 | Dev MRR: 0.5385, Hits@10: 0.6304


                                                     

Epoch 143 | Train Loss: 0.8263 | Dev MRR: 0.5380, Hits@10: 0.6304


                                                     

Epoch 144 | Train Loss: 0.8257 | Dev MRR: 0.5348, Hits@10: 0.6522


                                                     

Epoch 145 | Train Loss: 0.8263 | Dev MRR: 0.5144, Hits@10: 0.6304


                                                     

Epoch 146 | Train Loss: 0.8247 | Dev MRR: 0.5463, Hits@10: 0.6304


                                                     

Epoch 147 | Train Loss: 0.8253 | Dev MRR: 0.5560, Hits@10: 0.6304


                                                     

Epoch 148 | Train Loss: 0.8253 | Dev MRR: 0.5277, Hits@10: 0.6087


                                                     

Epoch 149 | Train Loss: 0.8262 | Dev MRR: 0.5350, Hits@10: 0.6304


                                                     

Epoch 150 | Train Loss: 0.8249 | Dev MRR: 0.5323, Hits@10: 0.6522


                                                     

Epoch 151 | Train Loss: 0.8237 | Dev MRR: 0.5448, Hits@10: 0.6087


                                                     

Epoch 152 | Train Loss: 0.8235 | Dev MRR: 0.5354, Hits@10: 0.6304


                                                     

Epoch 153 | Train Loss: 0.8236 | Dev MRR: 0.5515, Hits@10: 0.6522


                                                     

Epoch 154 | Train Loss: 0.8234 | Dev MRR: 0.5562, Hits@10: 0.6304


                                                     

Epoch 155 | Train Loss: 0.8234 | Dev MRR: 0.5576, Hits@10: 0.6304


                                                     

Epoch 156 | Train Loss: 0.8234 | Dev MRR: 0.5290, Hits@10: 0.6304


                                                     

Epoch 157 | Train Loss: 0.8226 | Dev MRR: 0.5358, Hits@10: 0.6087


                                                     

Epoch 158 | Train Loss: 0.8219 | Dev MRR: 0.5448, Hits@10: 0.6087


                                                     

Epoch 159 | Train Loss: 0.8218 | Dev MRR: 0.5172, Hits@10: 0.6087


                                                     

Epoch 160 | Train Loss: 0.8211 | Dev MRR: 0.5428, Hits@10: 0.6522


                                                     

Epoch 161 | Train Loss: 0.8214 | Dev MRR: 0.5372, Hits@10: 0.5870


                                                     

Epoch 162 | Train Loss: 0.8209 | Dev MRR: 0.5462, Hits@10: 0.6087


                                                     

Epoch 163 | Train Loss: 0.8204 | Dev MRR: 0.5247, Hits@10: 0.6304


                                                     

Epoch 164 | Train Loss: 0.8206 | Dev MRR: 0.5213, Hits@10: 0.6087


                                                     

Epoch 165 | Train Loss: 0.8201 | Dev MRR: 0.5347, Hits@10: 0.6087


                                                     

Epoch 166 | Train Loss: 0.8195 | Dev MRR: 0.5372, Hits@10: 0.6522


                                                     

Epoch 167 | Train Loss: 0.8197 | Dev MRR: 0.5287, Hits@10: 0.6304


                                                     

Epoch 168 | Train Loss: 0.8190 | Dev MRR: 0.5395, Hits@10: 0.6739


                                                     

Epoch 169 | Train Loss: 0.8183 | Dev MRR: 0.5193, Hits@10: 0.6522


                                                     

Epoch 170 | Train Loss: 0.8191 | Dev MRR: 0.5410, Hits@10: 0.6304


                                                     

Epoch 171 | Train Loss: 0.8196 | Dev MRR: 0.5353, Hits@10: 0.6522


                                                     

Epoch 172 | Train Loss: 0.8183 | Dev MRR: 0.5248, Hits@10: 0.6304


                                                     

Epoch 173 | Train Loss: 0.8180 | Dev MRR: 0.5342, Hits@10: 0.6304


                                                     

Epoch 174 | Train Loss: 0.8181 | Dev MRR: 0.5461, Hits@10: 0.6304


                                                     

Epoch 175 | Train Loss: 0.8175 | Dev MRR: 0.5104, Hits@10: 0.6087


                                                     

Epoch 176 | Train Loss: 0.8172 | Dev MRR: 0.5200, Hits@10: 0.6304


                                                     

Epoch 177 | Train Loss: 0.8172 | Dev MRR: 0.5507, Hits@10: 0.6087


                                                     

Epoch 178 | Train Loss: 0.8169 | Dev MRR: 0.5627, Hits@10: 0.6304


                                                     

Epoch 179 | Train Loss: 0.8164 | Dev MRR: 0.5359, Hits@10: 0.6087


                                                     

Epoch 180 | Train Loss: 0.8159 | Dev MRR: 0.5406, Hits@10: 0.6739


                                                     

Epoch 181 | Train Loss: 0.8165 | Dev MRR: 0.5492, Hits@10: 0.6304


                                                     

Epoch 182 | Train Loss: 0.8165 | Dev MRR: 0.5525, Hits@10: 0.6087


                                                     

Epoch 183 | Train Loss: 0.8155 | Dev MRR: 0.5548, Hits@10: 0.6304


                                                     

Epoch 184 | Train Loss: 0.8155 | Dev MRR: 0.5334, Hits@10: 0.5870


                                                     

Epoch 185 | Train Loss: 0.8159 | Dev MRR: 0.5363, Hits@10: 0.6304


                                                     

Epoch 186 | Train Loss: 0.8146 | Dev MRR: 0.5399, Hits@10: 0.6087


                                                     

Epoch 187 | Train Loss: 0.8153 | Dev MRR: 0.5550, Hits@10: 0.6304


                                                     

Epoch 188 | Train Loss: 0.8138 | Dev MRR: 0.5233, Hits@10: 0.6087


                                                     

Epoch 189 | Train Loss: 0.8140 | Dev MRR: 0.5608, Hits@10: 0.6087


                                                     

Epoch 190 | Train Loss: 0.8140 | Dev MRR: 0.5433, Hits@10: 0.6304


                                                     

Epoch 191 | Train Loss: 0.8138 | Dev MRR: 0.5474, Hits@10: 0.6304


                                                     

Epoch 192 | Train Loss: 0.8135 | Dev MRR: 0.5370, Hits@10: 0.6087


                                                     

Epoch 193 | Train Loss: 0.8138 | Dev MRR: 0.5626, Hits@10: 0.6304


                                                     

Epoch 194 | Train Loss: 0.8129 | Dev MRR: 0.5297, Hits@10: 0.6304


                                                     

Epoch 195 | Train Loss: 0.8138 | Dev MRR: 0.5390, Hits@10: 0.6522


                                                     

Epoch 196 | Train Loss: 0.8130 | Dev MRR: 0.5119, Hits@10: 0.6304


                                                     

Epoch 197 | Train Loss: 0.8133 | Dev MRR: 0.5330, Hits@10: 0.6304


                                                     

Epoch 198 | Train Loss: 0.8127 | Dev MRR: 0.5223, Hits@10: 0.6522


                                                     

Epoch 199 | Train Loss: 0.8132 | Dev MRR: 0.5433, Hits@10: 0.6304


                                                     

Epoch 200 | Train Loss: 0.8128 | Dev MRR: 0.5468, Hits@10: 0.6304




In [8]:
# 10. Final test evaluation
mrr_test, hits_test = evaluate_ranking(test_data, len(entities))
print(f"Test MRR: {mrr_test:.4f}   " + ", ".join([f"{k}: {v:.4f}" for k,v in hits_test.items()]))


                                                     

Test MRR: 0.5107   Hits@1: 0.4565, Hits@3: 0.5435, Hits@10: 0.5652


