In [None]:
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm




def load_triples(path):
    """Read whitespace-separated ``head relation tail`` triples from a text file.

    Only the first three fields of each line are kept (extra columns are ignored);
    blank lines are skipped so a trailing empty line does not crash the loader.

    Args:
        path: path to the triple file.

    Returns:
        list of ``(head, relation, tail)`` string tuples, in file order.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    triples = []
    with open(path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()  # split() with no args also strips surrounding whitespace
            if not fields:
                continue  # tolerate blank / trailing empty lines
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 'head relation tail', got {line.rstrip()!r}"
                )
            h, r, t = fields[:3]
            triples.append((h, r, t))
    return triples

train_triples = load_triples("../graph_data/train.tsv")
dev_triples = load_triples("../graph_data/dev.tsv")
test_triples = load_triples("../graph_data/test.tsv")

# Vocabularies span all three splits so every dev/test pair can be encoded.
# Entities that occur only in dev/test never appear in a training batch, so they
# receive no gradient from the training loss.
all_triples = train_triples + dev_triples + test_triples
entities = sorted({node for head, _, tail in all_triples for node in (head, tail)})
relations = sorted({rel for _, rel, _ in all_triples})
ent2id = {name: idx for idx, name in enumerate(entities)}
rel2id = {name: idx for idx, name in enumerate(relations)}



def encode_cls(triples, ent_map=None, rel_map=None):
    """Encode string triples as a ``(N, 3)`` LongTensor of ``(head_id, tail_id, rel_id)``.

    Columns are ordered for classification (entity pair first, relation label last);
    the training/eval code slices them as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``.

    Args:
        triples: iterable of ``(head, relation, tail)`` string tuples.
        ent_map: entity -> id mapping; defaults to the notebook-level ``ent2id``.
        rel_map: relation -> id mapping; defaults to the notebook-level ``rel2id``.

    Returns:
        ``torch.long`` tensor of shape ``(N, 3)``; shape ``(0, 3)`` for empty input.

    Raises:
        KeyError: if an entity or relation is missing from the mappings.
    """
    ent_map = ent2id if ent_map is None else ent_map
    rel_map = rel2id if rel_map is None else rel_map
    rows = [(ent_map[h], ent_map[t], rel_map[r]) for h, r, t in triples]
    if not rows:
        # torch.tensor([]) would be 1-D with shape (0,), breaking the [:, k] slicing downstream.
        return torch.empty((0, 3), dtype=torch.long)
    return torch.tensor(rows, dtype=torch.long)

# Each split becomes an (N, 3) LongTensor of (head_id, tail_id, relation_id).
train_data, dev_data, test_data = (
    encode_cls(split) for split in (train_triples, dev_triples, test_triples)
)

batch_size = 512
# Only the training loader shuffles; dev/test keep file order for evaluation.
train_loader = DataLoader(TensorDataset(train_data), batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(TensorDataset(dev_data), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(test_data), batch_size=batch_size)

In [None]:
# RotatE-style relation classifier: scores every relation for an (head, tail) pair.
dim = 512  # complex embedding size (separate real and imaginary tables)


class RotatEClassifier(nn.Module):
    """Score all relations for (head, tail) entity pairs with a RotatE-style rotation.

    Each entity is a complex vector ``ent_re + i * ent_im``; each relation is a vector
    of phases, and relation ``r`` rotates the head element-wise by ``exp(i * phase_r)``.
    A relation scores higher the closer the rotated head lands to the tail, so the
    output can be fed directly to ``F.cross_entropy`` as logits.

    Args:
        num_ent: number of entity ids (rows of the entity embeddings).
        num_rel: number of relation ids (output classes).
        dim: complex embedding dimension.
    """

    def __init__(self, num_ent, num_rel, dim):
        super().__init__()

        self.ent_re = nn.Embedding(num_ent, dim)
        self.ent_im = nn.Embedding(num_ent, dim)
        # One phase per (relation, dimension); cos/sin turn it into a unit-modulus rotation.
        self.rel_ph = nn.Embedding(num_rel, dim)

        nn.init.xavier_uniform_(self.ent_re.weight)
        nn.init.xavier_uniform_(self.ent_im.weight)
        # Cover the full circle; the truncated literal 3.1415 left a small gap near +/-pi.
        nn.init.uniform_(self.rel_ph.weight, -math.pi, math.pi)

    def forward(self, h, t):
        """Return a ``(batch, num_rel)`` score tensor, one column per relation.

        Args:
            h: LongTensor of head entity ids, shape ``(batch,)``.
            t: LongTensor of tail entity ids, same shape as ``h``.
        """
        re_h, im_h = self.ent_re(h), self.ent_im(h)  # (B, D)
        re_t, im_t = self.ent_re(t), self.ent_im(t)  # (B, D)

        # Score all relations at once: (R, D) phases broadcast against (B, 1, D) entities.
        ph = self.rel_ph.weight
        re_r = torch.cos(ph).unsqueeze(0)  # (1, R, D)
        im_r = torch.sin(ph).unsqueeze(0)

        # Complex product h * r, shape (B, R, D).
        re_h, im_h = re_h.unsqueeze(1), im_h.unsqueeze(1)
        re_hr = re_h * re_r - im_h * im_r
        im_hr = re_h * im_r + im_h * re_r

        diff_re = re_hr - re_t.unsqueeze(1)
        diff_im = im_hr - im_t.unsqueeze(1)

        # NOTE: this is an L1 distance over real and imaginary parts (|Re| + |Im|),
        # not the per-dimension complex modulus sqrt(Re^2 + Im^2) that RotatE uses.
        score = -(diff_re.abs() + diff_im.abs()).sum(dim=-1)
        return score


num_ent = len(entities)
num_rel = len(relations)

# Seed the RNGs before the embeddings are initialised so that weight init and
# DataLoader shuffling (which draws from torch's default generator) are reproducible.
# GPU kernels may still introduce run-to-run nondeterminism.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RotatEClassifier(num_ent, num_rel, dim).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

# Whole encoded splits moved to the device once (test_tuple feeds the ranking evaluation).
dev_tuple = dev_data.to(device)
test_tuple = test_data.to(device)

In [8]:
for epoch in range(1, 201):
    # --- one pass over the shuffled training pairs ---
    model.train()
    total_loss = 0
    for (pairs,) in train_loader:
        pairs = pairs.to(device)
        heads, tails, rels = pairs.unbind(dim=1)
        optimizer.zero_grad()
        logits = model(heads, tails)  # one score per relation
        loss = F.cross_entropy(logits, rels)
        loss.backward()
        optimizer.step()
        # weight by batch size so the epoch average is per example
        total_loss += loss.item() * heads.size(0)

    if epoch % 10 != 0:
        continue

    # --- every 10 epochs: report mean train loss and dev argmax accuracy ---
    train_loss = total_loss / len(train_data)
    model.eval()
    correct = 0
    with torch.no_grad():
        for (pairs,) in dev_loader:
            pairs = pairs.to(device)
            heads, tails, rels = pairs.unbind(dim=1)
            correct += (model(heads, tails).argmax(dim=-1) == rels).sum().item()
    dev_acc = correct / len(dev_data)
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Dev Acc: {dev_acc:.4f}")





Epoch 10 | Train Loss: 0.0274 | Dev Acc: 0.1522
Epoch 20 | Train Loss: 0.0120 | Dev Acc: 0.1522
Epoch 30 | Train Loss: 0.0076 | Dev Acc: 0.1304
Epoch 40 | Train Loss: 0.0058 | Dev Acc: 0.1087
Epoch 50 | Train Loss: 0.0047 | Dev Acc: 0.1522
Epoch 60 | Train Loss: 0.0040 | Dev Acc: 0.1522
Epoch 70 | Train Loss: 0.0037 | Dev Acc: 0.1522
Epoch 80 | Train Loss: 0.0035 | Dev Acc: 0.1522
Epoch 90 | Train Loss: 0.0032 | Dev Acc: 0.1522
Epoch 100 | Train Loss: 0.0031 | Dev Acc: 0.1522
Epoch 110 | Train Loss: 0.0031 | Dev Acc: 0.1739
Epoch 120 | Train Loss: 0.0030 | Dev Acc: 0.1957
Epoch 130 | Train Loss: 0.0026 | Dev Acc: 0.1957
Epoch 140 | Train Loss: 0.0026 | Dev Acc: 0.1957
Epoch 150 | Train Loss: 0.0027 | Dev Acc: 0.1957
Epoch 160 | Train Loss: 0.0026 | Dev Acc: 0.1957
Epoch 170 | Train Loss: 0.0027 | Dev Acc: 0.2174
Epoch 180 | Train Loss: 0.0027 | Dev Acc: 0.2174
Epoch 190 | Train Loss: 0.0026 | Dev Acc: 0.2174
Epoch 200 | Train Loss: 0.0024 | Dev Acc: 0.2174


In [None]:

def evaluate_relation_ranking(data_tensor, ks=(1,3), net=None, target_device=None, batch_size=512):
    """Rank the gold relation among all relations for each (head, tail) pair.

    Puts ``net`` into eval mode (and leaves it there) and scores rows in batches.

    Args:
        data_tensor: ``(N, 3)`` LongTensor of ``(head_id, tail_id, rel_id)`` rows,
            as produced by ``encode_cls``.
        ks: cut-offs for the Hits@k metrics.
        net: model mapping ``(h, t)`` id tensors to ``(batch, num_rel)`` scores;
            defaults to the notebook-level ``model``.
        target_device: device the inputs are moved to; defaults to the notebook-level
            ``device``.
        batch_size: number of rows scored per forward pass.

    Returns:
        ``(mrr, hits)`` where ``hits`` maps ``"Hits@k"`` to the fraction of rows whose
        gold relation has rank ``<= k``. Both are NaN for empty input.

    Note:
        Rank is ``1 + number of relations scored strictly higher`` than the gold one,
        so it does not depend on how ``sort`` happens to order exact ties.
    """
    net = model if net is None else net
    target_device = device if target_device is None else target_device

    if len(data_tensor) == 0:
        nan = float("nan")
        return nan, {f"Hits@{k}": nan for k in ks}

    net.eval()
    rank_chunks = []
    with torch.no_grad():
        for start in range(0, len(data_tensor), batch_size):
            chunk = data_tensor[start:start + batch_size].to(target_device)
            h, t, r = chunk[:, 0], chunk[:, 1], chunk[:, 2]
            scores = net(h, t)                              # (B, num_rel)
            gold = scores.gather(1, r.unsqueeze(1))         # (B, 1) score of the true relation
            rank_chunks.append((scores > gold).sum(dim=1) + 1)

    ranks = torch.cat(rank_chunks).float().cpu()
    mrr = (1.0 / ranks).mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in ks}
    return mrr, hits


# Test-set evaluation. Switch to eval mode explicitly instead of relying on the
# training cell happening to end on an evaluation epoch.
model.eval()
correct = 0
with torch.no_grad():
    for batch in test_loader:
        h,t,r = batch[0][:,0].to(device), batch[0][:,1].to(device), batch[0][:,2].to(device)
        preds = model(h, t).argmax(dim=-1)
        correct += (preds == r).sum().item()
# Argmax accuracy; it should agree with Hits@1 below up to tie-breaking.
acc_test = correct / len(test_data)
mrr_test, hits_test = evaluate_relation_ranking(test_tuple)
print(f"Test Acc: {acc_test:.4f} | Test MRR: {mrr_test:.4f} | Hits@1: {hits_test['Hits@1']:.4f} | Hits@3: {hits_test['Hits@3']:.4f}")

Test Acc: 0.1957 | Test MRR: 0.4598 | Hits@1: 0.1957 | Hits@3: 0.6739
