In [15]:
#!/usr/bin/env python3
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import argparse
import random


In [16]:


def load_triples(path):
    triples = []
    with open(path, 'r') as f:
        for line in f:
            h, r, t = line.strip().split()[:3]
            triples.append((h, r, t))
    return triples


class DistMultModel(torch.nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super().__init__()
        self.entity_emb = torch.nn.Embedding(num_entities, embedding_dim)
        self.rel_emb = torch.nn.Embedding(num_relations, embedding_dim)
        torch.nn.init.xavier_uniform_(self.entity_emb.weight)
        torch.nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, h, r, t):
        # Score: <e_h, R_r, e_t>
        return (self.entity_emb(h) * self.rel_emb(r) * self.entity_emb(t)).sum(dim=-1)


def negative_sampling(batch, num_entities):
    # Corrupt head or tail at random
    neg = batch.clone()
    mask = torch.rand(batch.size(0), device=batch.device) < 0.5
    # corrupt head
    num_h = mask.sum().item()
    neg[mask, 0] = torch.randint(0, num_entities, (num_h,), device=batch.device)
    # corrupt tail
    num_t = (~mask).sum().item()
    neg[~mask, 2] = torch.randint(0, num_entities, (num_t,), device=batch.device)
    return neg


def train_epoch(model, optimizer, dataloader, num_entities, device):
    model.train()
    total_loss = 0
    dataset_size = 0
    for batch_tuple in dataloader:
        batch = batch_tuple[0].to(device)
        dataset_size += batch.size(0)
        pos = batch
        neg = negative_sampling(pos, num_entities)

        pos_score = model(pos[:, 0], pos[:, 1], pos[:, 2])
        neg_score = model(neg[:, 0], neg[:, 1], neg[:, 2])

        loss = - (F.logsigmoid(pos_score).mean() + F.logsigmoid(-neg_score).mean())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.size(0)

    return total_loss / dataset_size


@torch.no_grad()
def evaluate_loss(model, dataloader, device):
    model.eval()
    total_loss = 0

    for batch_tuple in dataloader:
        batch = batch_tuple[0].to(device)
        score = model(batch[:, 0], batch[:, 1], batch[:, 2])
        loss = -F.logsigmoid(score).mean()
        total_loss += loss.item() * batch.size(0)

    return total_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate_ranking(model, data_tensor, num_entities, device, num_neg=100, k_values=(1,3,10)):
    """
    Approximate ranking evaluation: for each positive triple, sample num_neg negatives,
    compute rank of the positive, then aggregate MRR and Hits@k.
    """
    model.eval()
    ranks = []
    for triple in data_tensor:
        h, r, t = triple
        h, r, t = h.to(device), r.to(device), t.to(device)
        pos_score = model(h.unsqueeze(0), r.unsqueeze(0), t.unsqueeze(0))

        # generate negatives
        neg_h = torch.randint(0, num_entities, (num_neg,), device=device)
        neg_t = torch.randint(0, num_entities, (num_neg,), device=device)
        mask = torch.rand(num_neg, device=device) < 0.5
        heads = torch.where(mask, neg_h, h.expand(num_neg))
        tails = torch.where(~mask, neg_t, t.expand(num_neg))
        rels = r.expand(num_neg)
        neg_score = model(heads, rels, tails)

        scores = torch.cat([pos_score, neg_score])
        # higher score is better
        ranks_i = torch.argsort(scores, descending=True)
        rank = (ranks_i == 0).nonzero(as_tuple=False).item() + 1
        ranks.append(rank)

    ranks = torch.tensor(ranks, dtype=torch.float, device=device)
    mrr = (1.0 / ranks).mean().item()
    hits = {(f"Hits@{k}"): (ranks <= k).float().mean().item() for k in k_values}
    return mrr, hits

In [17]:
# Load triples
train_triples = load_triples("../graph_data/train.tsv")
dev_triples =  load_triples("../graph_data/dev.tsv")
test_triples =  load_triples("../graph_data/test.tsv")

# Build vocab
entities = set()
relations = set()
for h, r, t in train_triples + dev_triples:
    entities.update([h, t]); relations.add(r)
ent2id = {e: i for i, e in enumerate(sorted(entities))}
rel2id = {r: i for i, r in enumerate(sorted(relations))}



In [18]:
# Encode triples
def encode(triples):
    return torch.tensor([[ent2id[h], rel2id[r], ent2id[t]] for h, r, t in triples], dtype=torch.long)

train_data = encode(train_triples)

dev_data = encode(dev_triples)

test_data = encode(test_triples)

In [19]:
embedding_dim = 100 
batch_size = 512
lr = 5e-4 
weight_decay = 1e-4 
epochs = 500
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loaders
train_loader = DataLoader(TensorDataset(train_data), batch_size=512, shuffle=True)
dev_loader = DataLoader(TensorDataset(dev_data), batch_size=512)

# Model & optimizer
num_entities = len(ent2id)
num_relations = len(rel2id)
model = DistMultModel(num_entities, num_relations,  embedding_dim).to( device)
optimizer = AdamW(model.parameters(), lr= lr, weight_decay= weight_decay)

In [20]:
# Training loop
best_dev_loss = float('inf')
for epoch in range(1,  epochs + 1):
    train_loss = train_epoch(model, optimizer, train_loader, num_entities, device)
    dev_loss = evaluate_loss(model, dev_loader, device)
    print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Dev Loss: {dev_loss:.4f}")

    # Early stopping on dev loss
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        torch.save(model.state_dict(), 'best_distmult.pt')

Epoch 001 | Train Loss: 1.3863 | Dev Loss: 0.6932
Epoch 002 | Train Loss: 1.3861 | Dev Loss: 0.6932
Epoch 003 | Train Loss: 1.3859 | Dev Loss: 0.6932
Epoch 004 | Train Loss: 1.3857 | Dev Loss: 0.6932
Epoch 005 | Train Loss: 1.3855 | Dev Loss: 0.6932
Epoch 006 | Train Loss: 1.3853 | Dev Loss: 0.6932
Epoch 007 | Train Loss: 1.3850 | Dev Loss: 0.6931
Epoch 008 | Train Loss: 1.3847 | Dev Loss: 0.6931
Epoch 009 | Train Loss: 1.3843 | Dev Loss: 0.6931
Epoch 010 | Train Loss: 1.3839 | Dev Loss: 0.6931
Epoch 011 | Train Loss: 1.3835 | Dev Loss: 0.6930
Epoch 012 | Train Loss: 1.3829 | Dev Loss: 0.6930
Epoch 013 | Train Loss: 1.3823 | Dev Loss: 0.6929
Epoch 014 | Train Loss: 1.3815 | Dev Loss: 0.6929
Epoch 015 | Train Loss: 1.3807 | Dev Loss: 0.6928
Epoch 016 | Train Loss: 1.3797 | Dev Loss: 0.6927
Epoch 017 | Train Loss: 1.3786 | Dev Loss: 0.6925
Epoch 018 | Train Loss: 1.3774 | Dev Loss: 0.6924
Epoch 019 | Train Loss: 1.3759 | Dev Loss: 0.6922
Epoch 020 | Train Loss: 1.3743 | Dev Loss: 0.6919


In [23]:
# Final evaluation on dev set ranking
model.load_state_dict(torch.load('best_distmult.pt'))
mrr, hits = evaluate_ranking(model, test_data, num_entities, device,  k_values=(1,3,5))
print(f"Dev Ranking → MRR: {mrr:.4f}, " + ", ".join([f"{k}: {v:.4f}" for k,v in hits.items()]))

Dev Ranking → MRR: 0.3187, Hits@1: 0.2826, Hits@3: 0.3043, Hits@5: 0.3478


  model.load_state_dict(torch.load('best_distmult.pt'))


In [32]:
  # Final evaluation on dev set ranking
model.load_state_dict(torch.load('best_distmult.pt'))
mrr, hits = evaluate_ranking(model, test_data, num_entities, device, num_neg=100, k_values=(1,3,10))
print(f"Dev Ranking → MRR: {mrr:.4f}, " + ", ".join([f"{k}: {v:.4f}" for k,v in hits.items()]))

Dev Ranking → MRR: 0.3047, Hits@1: 0.2609, Hits@3: 0.3043, Hits@10: 0.3696


  model.load_state_dict(torch.load('best_distmult.pt'))
