In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

def load_triples(path, sep=None):
    """Read (head, relation, tail) string triples from a delimited text file.

    Each non-blank line is split on ``sep``. The first three fields are taken as
    head, relation and tail, and any extra columns are ignored. Blank lines are
    skipped.

    Args:
        path: Path of the file to read (decoded as UTF-8).
        sep: Field delimiter passed to ``str.split``. ``None`` (the default)
            splits on any run of whitespace. Pass ``"\\t"`` for TSV files whose
            entity names may contain spaces.

    Returns:
        list[tuple[str, str, str]]: the triples, in file order.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    triples = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # tolerate empty lines (e.g. trailing blank lines at EOF)
            fields = [field.strip() for field in stripped.split(sep)]
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected at least 3 fields, got {len(fields)}"
                )
            h, r, t = fields[:3]
            triples.append((h, r, t))
    return triples

# Load the three splits as lists of (head, relation, tail) string triples.
train_triples = load_triples("../graph_data/train.tsv")
dev_triples = load_triples("../graph_data/dev.tsv")
test_triples = load_triples("../graph_data/test.tsv")

# Build the vocabularies over all splits so every id lookup in encode_cls succeeds.
# The combined list is built once and reused for both vocabularies.
# NOTE: entities that occur only in dev/test still get ids, but training on
# train_loader alone gives their embeddings no gradient signal.
all_triples = train_triples + dev_triples + test_triples
entities = sorted({e for h, _, t in all_triples for e in (h, t)})
relations = sorted({r for _, r, _ in all_triples})
ent2id = {e: i for i, e in enumerate(entities)}
rel2id = {r: i for i, r in enumerate(relations)}


def encode_cls(triples, ent_map=None, rel_map=None):
    """Encode string triples as an integer tensor for relation classification.

    Args:
        triples: iterable of (head, relation, tail) string tuples.
        ent_map: entity -> id mapping. Defaults to the notebook-level ``ent2id``.
        rel_map: relation -> id mapping. Defaults to the notebook-level ``rel2id``.

    Returns:
        torch.LongTensor of shape (N, 3). The columns are
        [head_id, tail_id, relation_id]; the relation comes last because it is
        the label. An empty ``triples`` yields shape (0, 3), not (0,), so column
        indexing such as ``batch[:, 0]`` still works.

    Raises:
        KeyError: if an entity or relation is missing from the mappings.
    """
    ent_map = ent2id if ent_map is None else ent_map
    rel_map = rel2id if rel_map is None else rel_map
    rows = [[ent_map[h], ent_map[t], rel_map[r]] for h, r, t in triples]
    # reshape keeps the 2-D layout even when rows is empty
    return torch.tensor(rows, dtype=torch.long).reshape(-1, 3)

# Integer-encode every split; each row is [head_id, tail_id, relation_id].
train_data, dev_data, test_data = (
    encode_cls(split_triples)
    for split_triples in (train_triples, dev_triples, test_triples)
)

batch_size = 512
# Only the training split is shuffled; dev/test are iterated in file order.
train_loader, dev_loader, test_loader = (
    DataLoader(TensorDataset(split_data), batch_size=batch_size, shuffle=do_shuffle)
    for split_data, do_shuffle in ((train_data, True), (dev_data, False), (test_data, False))
)


class DistMultClassifier(nn.Module):
    """DistMult scorer used as a relation classifier.

    For a (head, tail) pair, every relation r is scored with the DistMult
    trilinear product ``score(h, r, t) = sum_d e_h[d] * w_r[d] * e_t[d]``.
    One logit per relation is returned, so the output can go straight into
    ``F.cross_entropy``.

    Args:
        num_ent: number of entities (rows of the entity embedding table).
        num_rel: number of relations (rows of the relation embedding table).
        dim: embedding dimension shared by entities and relations.
    """

    def __init__(self, num_ent, num_rel, dim):
        super().__init__()
        self.ent = nn.Embedding(num_ent, dim)
        self.rel = nn.Embedding(num_rel, dim)
        nn.init.xavier_uniform_(self.ent.weight)
        nn.init.xavier_uniform_(self.rel.weight)

    def forward(self, h, t):
        """Score every relation for each (h, t) pair.

        Args:
            h: LongTensor of head entity ids, any shape ``S``.
            t: LongTensor of tail entity ids, same shape ``S`` as ``h``.

        Returns:
            Tensor of shape ``S + (num_rel,)`` with one score per relation.
        """
        # The DistMult score is linear in the relation vector, so e_h * e_t can
        # be formed once and contracted against all relation vectors in one
        # matmul. This equals the broadcast (B, R, D).sum(-1) formulation
        # without materialising the (B, R, D) intermediate.
        pair = self.ent(h) * self.ent(t)   # S + (dim,)
        return pair @ self.rel.weight.t()  # S + (num_rel,)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Train DistMult as a relation classifier: given (head, tail), predict the relation id.
SEED = 42            # seeds embedding init and the train_loader shuffle order
NUM_EPOCHS = 200
LOG_EVERY = 10       # print the average train loss every LOG_EVERY epochs
EMB_DIM = 100
LR = 2e-4
WEIGHT_DECAY = 1e-4

# torch.manual_seed covers CPU and CUDA generators; DataLoader(shuffle=True)
# without an explicit generator draws from this default generator.
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DistMultClassifier(len(entities), len(relations), dim=EMB_DIM).to(device)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    total_loss = 0.0
    for (batch,) in train_loader:  # TensorDataset yields 1-tuples
        batch = batch.to(device)
        # encode_cls column order is [head_id, tail_id, relation_id]
        h, t, r = batch[:, 0], batch[:, 1], batch[:, 2]
        logits = model(h, t)  # [B, R]
        loss = F.cross_entropy(logits, r)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # weight by batch size so the epoch average is exact for a short last batch
        total_loss += loss.item() * h.size(0)
    if epoch % LOG_EVERY == 0:
        avg_loss = total_loss / len(train_data)
        print(f"Epoch {epoch:02d} | Train Loss: {avg_loss:.4f}")



Epoch 10 | Train Loss: 1.9441
Epoch 20 | Train Loss: 1.9399
Epoch 30 | Train Loss: 1.9304
Epoch 40 | Train Loss: 1.9110
Epoch 50 | Train Loss: 1.8767
Epoch 60 | Train Loss: 1.8231
Epoch 70 | Train Loss: 1.7471
Epoch 80 | Train Loss: 1.6481
Epoch 90 | Train Loss: 1.5272
Epoch 100 | Train Loss: 1.3886
Epoch 110 | Train Loss: 1.2384
Epoch 120 | Train Loss: 1.0847
Epoch 130 | Train Loss: 0.9353
Epoch 140 | Train Loss: 0.7962
Epoch 150 | Train Loss: 0.6712
Epoch 160 | Train Loss: 0.5619
Epoch 170 | Train Loss: 0.4686
Epoch 180 | Train Loss: 0.3900
Epoch 190 | Train Loss: 0.3247
Epoch 200 | Train Loss: 0.2707


In [None]:

def eval_relation_metrics(data_tensor, ks=(1, 3, 10), net=None, batch_size=1024):
    """Evaluate relation prediction with accuracy, MRR and Hits@k.

    Args:
        data_tensor: LongTensor of shape (N, 3) with rows
            [head_id, tail_id, relation_id], as produced by ``encode_cls``.
        ks: cutoffs for Hits@k.
        net: model mapping (h, t) id tensors to (B, num_rel) scores. Defaults to
            the notebook-level ``model``. Inputs are moved to the device of its
            parameters.
        batch_size: number of triples scored per forward pass.

    Returns:
        ``(acc, mrr, hits)``:
        - ``acc`` is the fraction whose argmax prediction equals the true relation.
        - ``mrr`` is the mean reciprocal rank of the true relation.
        - ``hits`` is a dict ``{"Hits@k": fraction with rank <= k}``.
        The rank is 1 + the number of relations scored strictly higher than the
        true one, so ties resolve optimistically.

    Raises:
        ValueError: if ``data_tensor`` is empty (the metrics would be undefined).
    """
    net = model if net is None else net
    n = len(data_tensor)
    if n == 0:
        raise ValueError("eval_relation_metrics: data_tensor is empty")
    first_param = next(net.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = net.training
    net.eval()
    correct = 0
    rank_chunks = []
    try:
        with torch.no_grad():
            for start in range(0, n, batch_size):
                chunk = data_tensor[start:start + batch_size].to(run_device)
                h, t, r = chunk[:, 0], chunk[:, 1], chunk[:, 2]
                scores = net(h, t)  # (B, num_rel)
                correct += (scores.argmax(dim=1) == r).sum().item()
                true_scores = scores.gather(1, r.unsqueeze(1))  # (B, 1)
                rank_chunks.append(((scores > true_scores).sum(dim=1) + 1).cpu())
    finally:
        # restore the caller's mode instead of leaving the model stuck in eval()
        net.train(was_training)

    ranks = torch.cat(rank_chunks).float()
    acc = correct / n
    mrr = (1.0 / ranks).mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in ks}
    return acc, mrr, hits

# Score both held-out splits first, then print one summary line per split.
dev_acc, dev_mrr, dev_hits = eval_relation_metrics(dev_data)
test_acc, test_mrr, test_hits = eval_relation_metrics(test_data)
for split_name, (split_acc, split_mrr, split_hits) in (
    ("Dev", (dev_acc, dev_mrr, dev_hits)),
    ("Test", (test_acc, test_mrr, test_hits)),
):
    print(
        f"{split_name} ▶ Acc={split_acc:.4f} MRR={split_mrr:.4f} "
        f"Hits@1={split_hits['Hits@1']:.4f} Hits@3={split_hits['Hits@3']:.4f}"
    )


Dev ▶ Acc=0.5652 MRR=0.6869 Hits@1=0.5652 Hits@3=0.7174
Test ▶ Acc=0.4783 MRR=0.6583 Hits@1=0.4783 Hits@3=0.7609
