In [1]:
# Jupyter Notebook cell: R-GCN training and evaluation

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# 1. Load triples
def load_triples(path):
    triples = []
    with open(path, 'r') as f:
        for line in f:
            h, r, t = line.strip().split()[:3]
            triples.append((h, r, t))
    return triples

train_triples = load_triples("../graph_data/train.tsv")
dev_triples =  load_triples("../graph_data/dev.tsv")
test_triples =  load_triples("../graph_data/test.tsv")

# 2. Build vocab and graph edges
entities = sorted({e for (h, _, t) in train_triples + dev_triples + test_triples for e in (h, t)})
relations = sorted({r for (_, r, _) in train_triples + dev_triples + test_triples})
ent2id = {e: i for i, e in enumerate(entities)}
rel2id = {r: i for i, r in enumerate(relations)}
num_nodes = len(entities)
num_rels = len(relations)

# Build edge index and edge types
edge_index = [[], []]
edge_type = []
for h, r, t in train_triples:
    src = ent2id[h]
    dst = ent2id[t]
    edge_index[0].append(src)
    edge_index[1].append(dst)
    edge_type.append(rel2id[r])
# convert to tensors
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_type = torch.tensor(edge_type, dtype=torch.long)



data = Data(num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

data = data.to(device)

In [3]:
# 3. Model definition
class RGCN(nn.Module):
    def __init__(self, num_nodes, num_rels, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.convs = nn.ModuleList()
        # input to first layer
        self.convs.append(RGCNConv(num_nodes, hidden_dim, num_rels, num_bases=30))
        for _ in range(num_layers-1):
            self.convs.append(RGCNConv(hidden_dim, hidden_dim, num_rels, num_bases=30))
        self.dropout = dropout
        # relation embeddings for scoring
        self.rel_emb = nn.Embedding(num_rels, hidden_dim)
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, x, edge_index, edge_type):
        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

# 4. Negative sampling for triples

def negative_sampling(batch, num_entities):
    neg = batch.clone()
    mask = torch.rand(batch.size(0), device=batch.device) < 0.5
    num_h = mask.sum().item()
    neg[mask, 0] = torch.randint(0, num_entities, (num_h,), device=batch.device)
    num_t = (~mask).sum().item()
    neg[~mask, 2] = torch.randint(0, num_entities, (num_t,), device=batch.device)
    return neg

# 5. Encode triples into index tensors

def encode(triples):
    return torch.tensor([[ent2id[h], rel2id[r], ent2id[t]] for h, r, t in triples], dtype=torch.long)

In [None]:
train_data = encode(train_triples).to(device)
dev_data   = encode(dev_triples).to(device)
test_data  = encode(test_triples).to(device)

# 6. Instantiate model and optimizer
model = RGCN(num_nodes, num_rels, hidden_dim=100, num_layers=2).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)


In [5]:
# 7. Training and evaluation functions

def score_triples(node_emb, triples):
    hs, rs, ts = triples[:,0], triples[:,1], triples[:,2]
    re = node_emb[hs] * model.rel_emb(rs)
    score = (re * node_emb[ts]).sum(dim=-1)
    return score

def train_epoch():
    model.train()
    total_loss = 0; count=0
    embeddings = torch.eye(num_nodes, device=device)  # one-hot initial
    for batch in tqdm(DataLoader(train_data, batch_size=1024, shuffle=True), desc='Train', leave=False):
        neg = negative_sampling(batch, num_nodes)
        node_emb = model(embeddings, data.edge_index, data.edge_type)
        pos_score = score_triples(node_emb, batch)
        neg_score = score_triples(node_emb, neg)
        loss = - (F.logsigmoid(pos_score).mean() + F.logsigmoid(-neg_score).mean())
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item() * batch.size(0); count += batch.size(0)
    return total_loss/count

# 9. Evaluation
torch.no_grad()
def evaluate(triples_tensor, num_neg=100, k_list=(1,3,10)):
    model.eval()
    ranks = []
    embeddings = torch.eye(num_nodes, device=device)
    node_emb = model(embeddings, data.edge_index, data.edge_type)
    for idx in tqdm(range(triples_tensor.size(0)), desc='Eval', leave=False):
        triple = triples_tensor[idx].unsqueeze(0)
        h, r, t = triple[0]
        pos_score = score_triples(node_emb, triple)[0]
        scores = [pos_score.item()]
        for _ in range(num_neg):
            if random.random() < 0.5:
                h2 = random.randint(0, num_nodes-1); t2 = t.item()
            else:
                h2 = h.item(); t2 = random.randint(0, num_nodes-1)
            neg_triple = torch.tensor([[h2, r, t2]], device=device)
            neg_score = score_triples(node_emb, neg_triple)[0]
            scores.append(neg_score.item())
        sorted_scores = sorted(scores, reverse=True)
        rank = sorted_scores.index(scores[0]) + 1
        ranks.append(rank)
    ranks = torch.tensor(ranks, dtype=torch.float)
    mrr = (1.0 / ranks).mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in k_list}
    return mrr, hits



In [None]:
dev_best = 0
for epoch in range(1, 201):
    loss = train_epoch()
    mrr_dev, hits_dev = evaluate(dev_data)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Dev MRR: {mrr_dev:.4f}, Hits@10: {hits_dev['Hits@10']:.4f}")
    dev_best = max(dev_best, mrr_dev)

                                                     

Epoch 01 | Loss: 1.3859 | Dev MRR: 0.0656, Hits@10: 0.1739


                                                     

Epoch 02 | Loss: 1.3830 | Dev MRR: 0.0725, Hits@10: 0.1522


                                                     

Epoch 03 | Loss: 1.3738 | Dev MRR: 0.0832, Hits@10: 0.2174


                                                     

Epoch 04 | Loss: 1.3611 | Dev MRR: 0.1158, Hits@10: 0.3043


                                                     

Epoch 05 | Loss: 1.3494 | Dev MRR: 0.0979, Hits@10: 0.2391


                                                     

Epoch 06 | Loss: 1.3317 | Dev MRR: 0.1048, Hits@10: 0.3043


                                                     

Epoch 07 | Loss: 1.3229 | Dev MRR: 0.1156, Hits@10: 0.2826


                                                     

Epoch 08 | Loss: 1.2989 | Dev MRR: 0.1169, Hits@10: 0.2609


                                                     

Epoch 09 | Loss: 1.2672 | Dev MRR: 0.0787, Hits@10: 0.2609


                                                     

Epoch 10 | Loss: 1.2443 | Dev MRR: 0.0819, Hits@10: 0.3261


                                                     

Epoch 11 | Loss: 1.1892 | Dev MRR: 0.0961, Hits@10: 0.4130


                                                     

Epoch 12 | Loss: 1.1811 | Dev MRR: 0.0846, Hits@10: 0.3043


                                                     

Epoch 13 | Loss: 1.1285 | Dev MRR: 0.0886, Hits@10: 0.3043


                                                     

Epoch 14 | Loss: 1.1004 | Dev MRR: 0.1183, Hits@10: 0.3478


                                                     

Epoch 15 | Loss: 1.0430 | Dev MRR: 0.1120, Hits@10: 0.3696


                                                     

Epoch 16 | Loss: 1.0024 | Dev MRR: 0.1008, Hits@10: 0.2826


                                                     

Epoch 17 | Loss: 0.9685 | Dev MRR: 0.0982, Hits@10: 0.3261


                                                     

Epoch 18 | Loss: 0.9308 | Dev MRR: 0.1060, Hits@10: 0.2609


                                                     

Epoch 19 | Loss: 0.9294 | Dev MRR: 0.1049, Hits@10: 0.2826


                                                     

Epoch 20 | Loss: 0.8823 | Dev MRR: 0.0851, Hits@10: 0.2826


                                                     

Epoch 21 | Loss: 0.8588 | Dev MRR: 0.1070, Hits@10: 0.2826


                                                     

Epoch 22 | Loss: 0.8590 | Dev MRR: 0.1045, Hits@10: 0.2609


                                                     

Epoch 23 | Loss: 0.8257 | Dev MRR: 0.0942, Hits@10: 0.3261


                                                     

Epoch 24 | Loss: 0.8067 | Dev MRR: 0.1034, Hits@10: 0.3261


                                                     

Epoch 25 | Loss: 0.7985 | Dev MRR: 0.1262, Hits@10: 0.3696


                                                     

Epoch 26 | Loss: 0.7758 | Dev MRR: 0.1097, Hits@10: 0.2826


                                                     

Epoch 27 | Loss: 0.7586 | Dev MRR: 0.1371, Hits@10: 0.3478


                                                     

Epoch 28 | Loss: 0.7490 | Dev MRR: 0.1120, Hits@10: 0.2826


                                                     

Epoch 29 | Loss: 0.7282 | Dev MRR: 0.1245, Hits@10: 0.2609


                                                     

Epoch 30 | Loss: 0.7095 | Dev MRR: 0.1532, Hits@10: 0.2826


                                                     

Epoch 31 | Loss: 0.7091 | Dev MRR: 0.1464, Hits@10: 0.2826


                                                     

Epoch 32 | Loss: 0.7023 | Dev MRR: 0.1271, Hits@10: 0.2609


                                                     

Epoch 33 | Loss: 0.6708 | Dev MRR: 0.1132, Hits@10: 0.3043


                                                     

Epoch 34 | Loss: 0.6619 | Dev MRR: 0.1442, Hits@10: 0.2826


                                                     

Epoch 35 | Loss: 0.6770 | Dev MRR: 0.1088, Hits@10: 0.2609


                                                     

Epoch 36 | Loss: 0.6529 | Dev MRR: 0.0969, Hits@10: 0.2391


                                                     

Epoch 37 | Loss: 0.6548 | Dev MRR: 0.1317, Hits@10: 0.2391


                                                     

Epoch 38 | Loss: 0.6183 | Dev MRR: 0.1163, Hits@10: 0.2609


                                                     

Epoch 39 | Loss: 0.6370 | Dev MRR: 0.1162, Hits@10: 0.2826


                                                     

Epoch 40 | Loss: 0.6086 | Dev MRR: 0.1324, Hits@10: 0.3043


                                                     

Epoch 41 | Loss: 0.5798 | Dev MRR: 0.1248, Hits@10: 0.3478


                                                     

Epoch 42 | Loss: 0.6037 | Dev MRR: 0.1228, Hits@10: 0.2826


                                                     

Epoch 43 | Loss: 0.5992 | Dev MRR: 0.1077, Hits@10: 0.2826


                                                     

Epoch 44 | Loss: 0.5644 | Dev MRR: 0.1341, Hits@10: 0.2609


                                                     

Epoch 45 | Loss: 0.5656 | Dev MRR: 0.1315, Hits@10: 0.3043


                                                     

Epoch 46 | Loss: 0.5554 | Dev MRR: 0.1052, Hits@10: 0.2174


                                                     

Epoch 47 | Loss: 0.5560 | Dev MRR: 0.1311, Hits@10: 0.2826


                                                     

Epoch 48 | Loss: 0.5430 | Dev MRR: 0.1252, Hits@10: 0.2391


                                                     

Epoch 49 | Loss: 0.5448 | Dev MRR: 0.1200, Hits@10: 0.1957


                                                     

Epoch 50 | Loss: 0.5368 | Dev MRR: 0.1128, Hits@10: 0.2609


                                                     

Epoch 51 | Loss: 0.5059 | Dev MRR: 0.1254, Hits@10: 0.2609


                                                     

Epoch 52 | Loss: 0.5469 | Dev MRR: 0.1158, Hits@10: 0.2391


                                                     

Epoch 53 | Loss: 0.5446 | Dev MRR: 0.1194, Hits@10: 0.1957


                                                     

Epoch 54 | Loss: 0.5024 | Dev MRR: 0.1111, Hits@10: 0.1739


                                                     

Epoch 55 | Loss: 0.5111 | Dev MRR: 0.0990, Hits@10: 0.1739


                                                     

Epoch 56 | Loss: 0.5222 | Dev MRR: 0.1098, Hits@10: 0.1739


                                                     

Epoch 57 | Loss: 0.5007 | Dev MRR: 0.1085, Hits@10: 0.2174


                                                     

Epoch 58 | Loss: 0.4828 | Dev MRR: 0.1181, Hits@10: 0.2174


                                                     

Epoch 59 | Loss: 0.4968 | Dev MRR: 0.1226, Hits@10: 0.2609


                                                     

Epoch 60 | Loss: 0.4568 | Dev MRR: 0.1041, Hits@10: 0.2391


                                                     

Epoch 61 | Loss: 0.4614 | Dev MRR: 0.1536, Hits@10: 0.2174


                                                     

Epoch 62 | Loss: 0.4524 | Dev MRR: 0.1213, Hits@10: 0.2826


                                                     

Epoch 63 | Loss: 0.4696 | Dev MRR: 0.1149, Hits@10: 0.2609


                                                     

Epoch 64 | Loss: 0.4347 | Dev MRR: 0.1209, Hits@10: 0.2391


                                                     

Epoch 65 | Loss: 0.4381 | Dev MRR: 0.1019, Hits@10: 0.2609


                                                     

Epoch 66 | Loss: 0.4391 | Dev MRR: 0.1363, Hits@10: 0.2826


                                                     

Epoch 67 | Loss: 0.4307 | Dev MRR: 0.1093, Hits@10: 0.2826


                                                     

Epoch 68 | Loss: 0.4344 | Dev MRR: 0.1151, Hits@10: 0.2609


                                                     

Epoch 69 | Loss: 0.4196 | Dev MRR: 0.1102, Hits@10: 0.2609


                                                     

Epoch 70 | Loss: 0.4170 | Dev MRR: 0.0959, Hits@10: 0.2609


                                                     

Epoch 71 | Loss: 0.4226 | Dev MRR: 0.1375, Hits@10: 0.3043


                                                     

Epoch 72 | Loss: 0.4211 | Dev MRR: 0.1302, Hits@10: 0.2826


                                                     

Epoch 73 | Loss: 0.3912 | Dev MRR: 0.1265, Hits@10: 0.2609


                                                     

Epoch 74 | Loss: 0.4072 | Dev MRR: 0.1009, Hits@10: 0.1957


                                                     

Epoch 75 | Loss: 0.4175 | Dev MRR: 0.1242, Hits@10: 0.2826


                                                     

Epoch 76 | Loss: 0.3896 | Dev MRR: 0.1299, Hits@10: 0.2826


                                                     

Epoch 77 | Loss: 0.4042 | Dev MRR: 0.1222, Hits@10: 0.2609


                                                     

Epoch 78 | Loss: 0.3944 | Dev MRR: 0.1356, Hits@10: 0.2609


                                                     

Epoch 79 | Loss: 0.3929 | Dev MRR: 0.1246, Hits@10: 0.2391


                                                     

Epoch 80 | Loss: 0.3743 | Dev MRR: 0.1453, Hits@10: 0.2609


                                                     

Epoch 81 | Loss: 0.3543 | Dev MRR: 0.1130, Hits@10: 0.2174


                                                     

Epoch 82 | Loss: 0.3732 | Dev MRR: 0.1499, Hits@10: 0.2391


                                                     

Epoch 83 | Loss: 0.3615 | Dev MRR: 0.1170, Hits@10: 0.2391


                                                     

Epoch 84 | Loss: 0.3529 | Dev MRR: 0.1417, Hits@10: 0.2391


                                                     

Epoch 85 | Loss: 0.3527 | Dev MRR: 0.1212, Hits@10: 0.2609


                                                     

Epoch 86 | Loss: 0.3559 | Dev MRR: 0.1449, Hits@10: 0.2391


                                                     

Epoch 87 | Loss: 0.3471 | Dev MRR: 0.1047, Hits@10: 0.2826


                                                     

Epoch 88 | Loss: 0.3378 | Dev MRR: 0.1252, Hits@10: 0.2826


                                                     

Epoch 89 | Loss: 0.3456 | Dev MRR: 0.1301, Hits@10: 0.2609


                                                     

Epoch 90 | Loss: 0.3458 | Dev MRR: 0.1276, Hits@10: 0.2609


                                                     

Epoch 91 | Loss: 0.3447 | Dev MRR: 0.1282, Hits@10: 0.3043


                                                     

Epoch 92 | Loss: 0.3464 | Dev MRR: 0.1071, Hits@10: 0.2826


                                                     

Epoch 93 | Loss: 0.3347 | Dev MRR: 0.1224, Hits@10: 0.2174


                                                     

Epoch 94 | Loss: 0.3347 | Dev MRR: 0.1376, Hits@10: 0.2609


                                                     

Epoch 95 | Loss: 0.3202 | Dev MRR: 0.1293, Hits@10: 0.2826


                                                     

Epoch 96 | Loss: 0.3321 | Dev MRR: 0.1318, Hits@10: 0.2609


                                                     

Epoch 97 | Loss: 0.3089 | Dev MRR: 0.1428, Hits@10: 0.3043


                                                     

Epoch 98 | Loss: 0.3293 | Dev MRR: 0.1429, Hits@10: 0.2826


                                                     

Epoch 99 | Loss: 0.3193 | Dev MRR: 0.1486, Hits@10: 0.3478


                                                     

Epoch 100 | Loss: 0.3126 | Dev MRR: 0.1241, Hits@10: 0.3043


                                                     

Epoch 101 | Loss: 0.3223 | Dev MRR: 0.1376, Hits@10: 0.3261


                                                     

Epoch 102 | Loss: 0.3011 | Dev MRR: 0.1329, Hits@10: 0.3478


                                                     

Epoch 103 | Loss: 0.3073 | Dev MRR: 0.1416, Hits@10: 0.3261


                                                     

Epoch 104 | Loss: 0.3009 | Dev MRR: 0.1142, Hits@10: 0.3043


                                                     

Epoch 105 | Loss: 0.2786 | Dev MRR: 0.1643, Hits@10: 0.2826


                                                     

Epoch 106 | Loss: 0.2827 | Dev MRR: 0.1664, Hits@10: 0.2826


                                                     

Epoch 107 | Loss: 0.2996 | Dev MRR: 0.1602, Hits@10: 0.3043


                                                     

Epoch 108 | Loss: 0.3085 | Dev MRR: 0.1475, Hits@10: 0.3261


                                                     

Epoch 109 | Loss: 0.2824 | Dev MRR: 0.1660, Hits@10: 0.3261


                                                     

Epoch 110 | Loss: 0.2697 | Dev MRR: 0.1668, Hits@10: 0.3478


                                                     

Epoch 111 | Loss: 0.2879 | Dev MRR: 0.1828, Hits@10: 0.3261


                                                     

Epoch 112 | Loss: 0.2915 | Dev MRR: 0.1358, Hits@10: 0.3261


                                                     

Epoch 113 | Loss: 0.2842 | Dev MRR: 0.1424, Hits@10: 0.2826


                                                     

Epoch 114 | Loss: 0.2643 | Dev MRR: 0.1564, Hits@10: 0.3261


                                                     

Epoch 115 | Loss: 0.2474 | Dev MRR: 0.1645, Hits@10: 0.3043


                                                     

Epoch 116 | Loss: 0.2898 | Dev MRR: 0.1507, Hits@10: 0.3478


                                                     

Epoch 117 | Loss: 0.2674 | Dev MRR: 0.1747, Hits@10: 0.3043


                                                     

Epoch 118 | Loss: 0.2735 | Dev MRR: 0.1533, Hits@10: 0.3696


                                                     

Epoch 119 | Loss: 0.2672 | Dev MRR: 0.1423, Hits@10: 0.3478


                                                     

Epoch 120 | Loss: 0.2825 | Dev MRR: 0.1558, Hits@10: 0.3261


                                                     

Epoch 121 | Loss: 0.2657 | Dev MRR: 0.1554, Hits@10: 0.3043


                                                     

Epoch 122 | Loss: 0.2741 | Dev MRR: 0.1540, Hits@10: 0.3261


                                                     

Epoch 123 | Loss: 0.2868 | Dev MRR: 0.1715, Hits@10: 0.2826


                                                     

Epoch 124 | Loss: 0.2769 | Dev MRR: 0.1273, Hits@10: 0.3261


                                                     

Epoch 125 | Loss: 0.2432 | Dev MRR: 0.1760, Hits@10: 0.3478


                                                     

Epoch 126 | Loss: 0.2888 | Dev MRR: 0.1619, Hits@10: 0.3043


                                                     

Epoch 127 | Loss: 0.2645 | Dev MRR: 0.1785, Hits@10: 0.3478


                                                     

Epoch 128 | Loss: 0.2562 | Dev MRR: 0.1676, Hits@10: 0.3478


                                                     

Epoch 129 | Loss: 0.2475 | Dev MRR: 0.1621, Hits@10: 0.3261


                                                     

Epoch 130 | Loss: 0.2901 | Dev MRR: 0.1522, Hits@10: 0.3261


                                                     

Epoch 131 | Loss: 0.2467 | Dev MRR: 0.1846, Hits@10: 0.3478


                                                     

Epoch 132 | Loss: 0.2444 | Dev MRR: 0.1645, Hits@10: 0.3696


                                                     

Epoch 133 | Loss: 0.2399 | Dev MRR: 0.1784, Hits@10: 0.3261


                                                     

Epoch 134 | Loss: 0.2264 | Dev MRR: 0.1662, Hits@10: 0.3261


                                                     

Epoch 135 | Loss: 0.2544 | Dev MRR: 0.1582, Hits@10: 0.3478


                                                     

Epoch 136 | Loss: 0.2284 | Dev MRR: 0.1444, Hits@10: 0.3478


                                                     

Epoch 137 | Loss: 0.2289 | Dev MRR: 0.1407, Hits@10: 0.3261


                                                     

Epoch 138 | Loss: 0.2397 | Dev MRR: 0.1463, Hits@10: 0.3043


                                                     

Epoch 139 | Loss: 0.2327 | Dev MRR: 0.1628, Hits@10: 0.3261


                                                     

Epoch 140 | Loss: 0.2303 | Dev MRR: 0.1732, Hits@10: 0.3478


                                                     

Epoch 141 | Loss: 0.2622 | Dev MRR: 0.1903, Hits@10: 0.3478


                                                     

Epoch 142 | Loss: 0.2417 | Dev MRR: 0.1503, Hits@10: 0.3478


                                                     

Epoch 143 | Loss: 0.2447 | Dev MRR: 0.1668, Hits@10: 0.3913


                                                     

Epoch 144 | Loss: 0.2284 | Dev MRR: 0.1887, Hits@10: 0.3478


                                                     

Epoch 145 | Loss: 0.2299 | Dev MRR: 0.1707, Hits@10: 0.3261


                                                     

Epoch 146 | Loss: 0.2254 | Dev MRR: 0.1612, Hits@10: 0.3261


                                                     

Epoch 147 | Loss: 0.2202 | Dev MRR: 0.1787, Hits@10: 0.3478


                                                     

Epoch 148 | Loss: 0.2321 | Dev MRR: 0.1904, Hits@10: 0.3261


                                                     

Epoch 149 | Loss: 0.2183 | Dev MRR: 0.1934, Hits@10: 0.3478


                                                     

Epoch 150 | Loss: 0.2319 | Dev MRR: 0.1563, Hits@10: 0.3478


                                                     

Epoch 151 | Loss: 0.2461 | Dev MRR: 0.1797, Hits@10: 0.3261


                                                     

Epoch 152 | Loss: 0.2181 | Dev MRR: 0.1799, Hits@10: 0.3478


                                                     

Epoch 153 | Loss: 0.2253 | Dev MRR: 0.2145, Hits@10: 0.3261


                                                     

Epoch 154 | Loss: 0.2327 | Dev MRR: 0.1649, Hits@10: 0.3261


                                                     

Epoch 155 | Loss: 0.2438 | Dev MRR: 0.1757, Hits@10: 0.3478


                                                     

Epoch 156 | Loss: 0.2347 | Dev MRR: 0.2221, Hits@10: 0.3478


                                                     

Epoch 157 | Loss: 0.2317 | Dev MRR: 0.1903, Hits@10: 0.3478


                                                     

Epoch 158 | Loss: 0.2049 | Dev MRR: 0.2193, Hits@10: 0.3696


                                                     

Epoch 159 | Loss: 0.2514 | Dev MRR: 0.2192, Hits@10: 0.3696


                                                     

Epoch 160 | Loss: 0.2345 | Dev MRR: 0.2152, Hits@10: 0.3696


                                                     

Epoch 161 | Loss: 0.2202 | Dev MRR: 0.1634, Hits@10: 0.3043


                                                     

Epoch 162 | Loss: 0.2169 | Dev MRR: 0.2168, Hits@10: 0.3478


                                                     

Epoch 163 | Loss: 0.2212 | Dev MRR: 0.1697, Hits@10: 0.3478


                                                     

Epoch 164 | Loss: 0.2147 | Dev MRR: 0.1725, Hits@10: 0.3261


                                                     

Epoch 165 | Loss: 0.2085 | Dev MRR: 0.1833, Hits@10: 0.3043


                                                     

Epoch 166 | Loss: 0.2063 | Dev MRR: 0.1746, Hits@10: 0.3261


                                                     

Epoch 167 | Loss: 0.2405 | Dev MRR: 0.1758, Hits@10: 0.3478


                                                     

Epoch 168 | Loss: 0.2023 | Dev MRR: 0.2001, Hits@10: 0.3478


                                                     

Epoch 169 | Loss: 0.1961 | Dev MRR: 0.1835, Hits@10: 0.3261


                                                     

Epoch 170 | Loss: 0.2184 | Dev MRR: 0.1632, Hits@10: 0.3478


                                                     

Epoch 171 | Loss: 0.2116 | Dev MRR: 0.1833, Hits@10: 0.3478


                                                     

Epoch 172 | Loss: 0.2050 | Dev MRR: 0.1876, Hits@10: 0.3478


                                                     

Epoch 173 | Loss: 0.2098 | Dev MRR: 0.1701, Hits@10: 0.3478


                                                     

Epoch 174 | Loss: 0.2046 | Dev MRR: 0.1528, Hits@10: 0.3261


                                                     

Epoch 175 | Loss: 0.2078 | Dev MRR: 0.1808, Hits@10: 0.3043


                                                     

Epoch 176 | Loss: 0.2234 | Dev MRR: 0.1583, Hits@10: 0.3261


                                                     

Epoch 177 | Loss: 0.2173 | Dev MRR: 0.1454, Hits@10: 0.3261


                                                     

Epoch 178 | Loss: 0.2017 | Dev MRR: 0.1785, Hits@10: 0.3261


                                                     

Epoch 179 | Loss: 0.2055 | Dev MRR: 0.1861, Hits@10: 0.3478


                                                     

Epoch 180 | Loss: 0.2033 | Dev MRR: 0.2071, Hits@10: 0.3043


                                                     

Epoch 181 | Loss: 0.2104 | Dev MRR: 0.1733, Hits@10: 0.3478


                                                     

Epoch 182 | Loss: 0.2046 | Dev MRR: 0.1950, Hits@10: 0.3043


                                                     

Epoch 183 | Loss: 0.2027 | Dev MRR: 0.1842, Hits@10: 0.3696


                                                     

Epoch 184 | Loss: 0.1921 | Dev MRR: 0.1857, Hits@10: 0.3478


                                                     

Epoch 185 | Loss: 0.1976 | Dev MRR: 0.1878, Hits@10: 0.3478


                                                     

Epoch 186 | Loss: 0.2107 | Dev MRR: 0.1758, Hits@10: 0.3478


                                                     

Epoch 187 | Loss: 0.1867 | Dev MRR: 0.1772, Hits@10: 0.3261


                                                     

Epoch 188 | Loss: 0.1788 | Dev MRR: 0.1751, Hits@10: 0.3478


                                                     

Epoch 189 | Loss: 0.2064 | Dev MRR: 0.1789, Hits@10: 0.3478


                                                     

Epoch 190 | Loss: 0.2038 | Dev MRR: 0.1796, Hits@10: 0.3261


                                                     

Epoch 191 | Loss: 0.1732 | Dev MRR: 0.2004, Hits@10: 0.3478


                                                     

Epoch 192 | Loss: 0.1891 | Dev MRR: 0.1704, Hits@10: 0.3478


                                                     

Epoch 193 | Loss: 0.1827 | Dev MRR: 0.1714, Hits@10: 0.3478


                                                     

Epoch 194 | Loss: 0.2001 | Dev MRR: 0.1927, Hits@10: 0.3478


                                                     

Epoch 195 | Loss: 0.2017 | Dev MRR: 0.2217, Hits@10: 0.3261


                                                     

Epoch 196 | Loss: 0.1977 | Dev MRR: 0.1800, Hits@10: 0.3478


                                                     

Epoch 197 | Loss: 0.2016 | Dev MRR: 0.2002, Hits@10: 0.3261


                                                     

Epoch 198 | Loss: 0.1887 | Dev MRR: 0.1854, Hits@10: 0.3261


                                                     

Epoch 199 | Loss: 0.1776 | Dev MRR: 0.1955, Hits@10: 0.3478


                                                     

Epoch 200 | Loss: 0.1911 | Dev MRR: 0.1794, Hits@10: 0.3261




In [14]:
# 9. Final test evaluation
mrr_test, hits_test = evaluate(test_data)
print(f"Test MRR: {mrr_test:.4f}   " + ", ".join([f"{k}: {v:.4f}" for k,v in hits_test.items()]))

                                                     

Test MRR: 0.1331   Hits@1: 0.0435, Hits@3: 0.1522, Hits@10: 0.3261


