In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:

def mobius_add(x, y, eps=1e-5):
    """Möbius addition ``x ⊕ y`` in the unit Poincaré ball.

    Computes ((1 + 2<x,y> + ||y||^2) x + (1 - ||x||^2) y) / (1 + 2<x,y> + ||x||^2 ||y||^2),
    with inner products taken over the last dimension (inputs broadcast).

    Args:
        x, y: tensors whose last dimension holds the point coordinates.
        eps: lower bound applied to the denominator to avoid division by ~0.

    Returns:
        Tensor of the broadcast shape of ``x`` and ``y``.
    """
    def _dot(a, b):
        # Inner product over the coordinate axis, kept for broadcasting.
        return (a * b).sum(dim=-1, keepdim=True)

    inner, x_sq, y_sq = _dot(x, y), _dot(x, x), _dot(y, y)
    coef_x = 1 + 2 * inner + y_sq
    coef_y = 1 - x_sq
    denominator = (1 + 2 * inner + x_sq * y_sq).clamp_min(eps)
    return (coef_x * x + coef_y * y) / denominator

def hyp_distance(x, y, eps=1e-5):
    """Geodesic distance between points of the unit Poincaré ball.

    d(x, y) = acosh(1 + 2 ||x - y||^2 / ((1 - ||x||^2) (1 - ||y||^2))),
    reduced over the last dimension (inputs broadcast).

    Args:
        x, y: tensors whose last dimension holds the point coordinates.
        eps: floor for the conformal denominator, and offset for the acosh
            argument (floored at ``1 + eps``), so identical points yield a
            small positive distance and acosh's gradient stays finite.

    Returns:
        Tensor with the last dimension reduced away.
    """
    def sq_norm(v):
        return (v * v).sum(dim=-1)

    gap = ((x - y) ** 2).sum(dim=-1)
    scale = ((1 - sq_norm(x)) * (1 - sq_norm(y))).clamp_min(eps)
    arg = (1 + 2 * gap / scale).clamp_min(1 + eps)
    return torch.acosh(arg)


class MuRPClassifier(nn.Module):
    """Relation classifier that scores (head, tail) pairs in the Poincaré ball.

    Each relation is a translation vector: the head embedding is Möbius-added
    to it, and the score is the negative hyperbolic distance to the tail.
    Every point is kept strictly inside the unit ball via ``_to_ball``.
    Otherwise ``1 - ||x||^2`` in ``hyp_distance`` goes negative, is clamped
    to ``eps``, and the resulting distance is meaningless.

    Args:
        num_ent: number of entity ids.
        num_rel: number of relation ids (= number of output classes).
        dim: embedding dimension.
    """

    def __init__(self, num_ent, num_rel, dim):
        super().__init__()
        self.ent = nn.Embedding(num_ent, dim)
        self.rel = nn.Embedding(num_rel, dim)
        # Tiny init keeps every embedding near the origin, well inside the ball.
        nn.init.uniform_(self.ent.weight, -0.001, 0.001)
        nn.init.uniform_(self.rel.weight, -0.001, 0.001)

    @staticmethod
    def _to_ball(x, eps=1e-5):
        """Rescale vectors (last dim) whose norm exceeds ``1 - eps`` onto radius ``1 - eps``.

        Vectors already within that radius are multiplied by exactly 1.0,
        i.e. returned unchanged.
        """
        norm = (x * x).sum(dim=-1, keepdim=True).clamp_min(1e-30).sqrt()
        return x * ((1 - eps) / norm).clamp_max(1.0)

    def forward(self, h, t):
        """Score every relation for each (head, tail) pair.

        Args:
            h, t: entity-id tensors; the callers in this notebook pass 1-D
                tensors of length B.

        Returns:
            (B, num_rel) tensor of scores = -d(h ⊕ r, t); higher is better.
        """
        h_e = self._to_ball(self.ent(h))                 # (B, dim)
        t_e = self._to_ball(self.ent(t))                 # (B, dim)
        # Elementwise tanh bounds each coordinate, not the vector norm (which
        # can reach sqrt(dim)), so relation vectors still need projecting.
        ph = self._to_ball(self.rel.weight.tanh())       # (num_rel, dim)
        # Broadcast every head against every relation -> (B, num_rel, dim);
        # re-project in case rounding pushes a sum onto the boundary.
        h_r = self._to_ball(mobius_add(h_e.unsqueeze(1), ph.unsqueeze(0)))
        dist = hyp_distance(h_r, t_e.unsqueeze(1))       # (B, num_rel)
        return -dist


def load_triples(path):
    """Read (head, relation, tail) string triples from a delimited text file.

    Each non-blank line contributes its first three fields; extra columns are
    ignored. Lines containing a tab are split on tabs, so names may contain
    spaces as in proper TSV. Other lines fall back to splitting on any
    whitespace.

    Args:
        path: path to the file, decoded as UTF-8.

    Returns:
        List of ``(head, relation, tail)`` string tuples in file order.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    triples = []
    # Explicit UTF-8: the default encoding of open() is platform/locale dependent.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank and trailing empty lines
            fields = line.split("\t") if "\t" in line else line.split()
            fields = [field.strip() for field in fields]
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected at least 3 fields, got {len(fields)}: {line!r}"
                )
            h, r, t = fields[:3]
            triples.append((h, r, t))
    return triples

train_triples = load_triples("../graph_data/train.tsv")
dev_triples = load_triples("../graph_data/dev.tsv")
test_triples = load_triples("../graph_data/test.tsv")

# Id vocabularies span all three splits, so encode_cls never meets an unknown
# name. Entities that appear only in dev/test are never looked up during
# training (the training loop only iterates train rows).
entities = sorted({node for head, _, tail in train_triples + dev_triples + test_triples for node in (head, tail)})
relations = sorted({rel for _, rel, _ in train_triples + dev_triples + test_triples})
ent2id = dict(zip(entities, range(len(entities))))
rel2id = {name: idx for idx, name in enumerate(relations)}

def encode_cls(triples, ent_vocab=None, rel_vocab=None):
    """Encode string triples as id rows for relation classification.

    Input triples are ``(head, relation, tail)``. Output rows are reordered to
    ``[head_id, tail_id, relation_id]``, so the label is the last column.

    Args:
        triples: iterable of ``(head, relation, tail)`` string tuples.
        ent_vocab: entity -> id mapping; defaults to the notebook-global ``ent2id``.
        rel_vocab: relation -> id mapping; defaults to the notebook-global ``rel2id``.

    Returns:
        ``torch.long`` tensor of shape ``(len(triples), 3)``. This is ``(0, 3)``
        for empty input, so column slicing like ``x[:, 0]`` still works.

    Raises:
        KeyError: if a name is missing from the vocabularies.
    """
    ent_vocab = ent2id if ent_vocab is None else ent_vocab
    rel_vocab = rel2id if rel_vocab is None else rel_vocab
    rows = [[ent_vocab[h], ent_vocab[t], rel_vocab[r]] for h, r, t in triples]
    # torch.tensor([]) would be 1-D of shape (0,); reshape keeps it 2-D.
    return torch.tensor(rows, dtype=torch.long).reshape(-1, 3)

# Integer-encode each split as rows of [head_id, tail_id, relation_id].
train_data, dev_data, test_data = (
    encode_cls(split) for split in (train_triples, dev_triples, test_triples)
)


# Only the training split is shuffled; dev/test loaders keep file order.
# (The evaluation cell below iterates the tensors directly, not these loaders.)
batch_size = 512
train_loader = DataLoader(TensorDataset(train_data), shuffle=True, batch_size=batch_size)
dev_loader, test_loader = (
    DataLoader(TensorDataset(split), batch_size=batch_size) for split in (dev_data, test_data)
)


# Seed before building the model: torch's global RNG drives the embedding
# init and (with no explicit generator) the train DataLoader's shuffling, so
# Restart & Run All reproduces the same run on CPU. Bit-exact results on CUDA
# are not guaranteed.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MuRPClassifier(len(entities), len(relations), dim=100).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

# Softmax cross-entropy over all relations: each (head, tail) pair is scored
# against every relation, and the gold relation id is the target class.
for epoch in range(1, 201):
    model.train()
    total_loss = 0
    for (batch,) in train_loader:
        batch = batch.to(device)
        h, t, r = batch.unbind(dim=1)
        scores = model(h, t)  # (batch, num_rel)
        loss = F.cross_entropy(scores, r)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-weight the batch-mean loss by batch size for an exact epoch mean.
        total_loss += loss.item() * h.size(0)
    if epoch % 10 == 0:
        print(f"Epoch {epoch:02d} | Train Loss: {total_loss/len(train_data):.4f}")




Epoch 10 | Train Loss: 1.3982
Epoch 20 | Train Loss: 0.8642
Epoch 30 | Train Loss: 0.5762
Epoch 40 | Train Loss: 0.4007
Epoch 50 | Train Loss: 0.2907
Epoch 60 | Train Loss: 0.2201
Epoch 70 | Train Loss: 0.1718
Epoch 80 | Train Loss: 0.1394
Epoch 90 | Train Loss: 0.1148
Epoch 100 | Train Loss: 0.0994
Epoch 110 | Train Loss: 0.0881
Epoch 120 | Train Loss: 0.0772
Epoch 130 | Train Loss: 0.0704
Epoch 140 | Train Loss: 0.0629
Epoch 150 | Train Loss: 0.0572
Epoch 160 | Train Loss: 0.0558
Epoch 170 | Train Loss: 0.0495
Epoch 180 | Train Loss: 0.0456
Epoch 190 | Train Loss: 0.0419
Epoch 200 | Train Loss: 0.0423


In [None]:

def eval_relation_metrics(data_tensor, ks=(1,3,10), batch_size=512, net=None, dev=None):
    """Relation-prediction metrics for (head, tail) -> relation classification.

    Each row of ``data_tensor`` is ``[head_id, tail_id, relation_id]`` (the
    layout produced by ``encode_cls``). For each row the model scores all
    relations. The gold relation's 1-based position in the descending score
    order is its rank.

    Args:
        data_tensor: integer tensor of shape (N, 3).
        ks: cut-offs for Hits@k.
        batch_size: number of rows scored per forward pass.
        net: model to evaluate; defaults to the notebook-global ``model``.
        dev: device to run on; defaults to the notebook-global ``device``.

    Returns:
        ``(acc, mrr, hits)``:
        - ``acc`` is the fraction of rows whose argmax relation is the gold one.
        - ``mrr`` is the mean reciprocal rank.
        - ``hits`` maps ``"Hits@k"`` to the fraction of ranks <= k.
        The model is left in eval mode.

    Raises:
        ValueError: if ``data_tensor`` has no rows (metrics are undefined).
    """
    net = model if net is None else net
    dev = device if dev is None else dev
    if len(data_tensor) == 0:
        raise ValueError("eval_relation_metrics: data_tensor is empty")

    net.eval()
    correct = 0
    rank_chunks = []
    with torch.no_grad():
        for start in range(0, len(data_tensor), batch_size):
            rows = data_tensor[start:start + batch_size].to(dev)
            h, t, r = rows[:, 0], rows[:, 1], rows[:, 2]
            scores = net(h, t)  # (B, num_rel)
            correct += (scores.argmax(dim=1) == r).sum().item()
            order = scores.argsort(dim=1, descending=True)
            # Column where the gold relation lands in each sorted row; nonzero
            # yields indices in row-major order, so row alignment is preserved.
            pos = (order == r.unsqueeze(1)).nonzero(as_tuple=True)[1]
            rank_chunks.append((pos + 1).cpu())

    acc = correct / len(data_tensor)
    ranks = torch.cat(rank_chunks).float()
    mrr = (1.0 / ranks).mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in ks}
    return acc, mrr, hits


# Score both held-out splits with the trained model (dev first, then test),
# then report the headline metrics.
(dev_acc, dev_mrr, dev_hits), (test_acc, test_mrr, test_hits) = (
    eval_relation_metrics(split) for split in (dev_data, test_data)
)
print(f"Dev ▶ Acc={dev_acc:.4f} MRR={dev_mrr:.4f} Hits@1={dev_hits['Hits@1']:.4f} Hits@3={dev_hits['Hits@3']:.4f}")
print(f"Test ▶ Acc={test_acc:.4f} MRR={test_mrr:.4f} Hits@1={test_hits['Hits@1']:.4f} Hits@3={test_hits['Hits@3']:.4f}")

Dev ▶ Acc=0.3696 MRR=0.5736 Hits@1=0.3696 Hits@3=0.7391
Test ▶ Acc=0.3261 MRR=0.5557 Hits@1=0.3261 Hits@3=0.7826
