In [None]:
import json, torch, random
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from tqdm.auto import tqdm

def load_triples(path):
    """Read whitespace-separated (head, relation, tail) triples from a text file.

    Only the first three whitespace-separated fields of each line are kept; any
    extra columns are ignored. Blank lines are skipped instead of failing on
    tuple unpacking.

    Args:
        path: Path of the triples file (read as UTF-8).

    Returns:
        list[tuple[str, str, str]]: the triples in file order.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    triples = []
    with open(path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()  # split() with no args also strips surrounding whitespace
            if not fields:
                continue  # tolerate stray empty lines
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected at least 3 fields, got {len(fields)}"
                )
            h, r, t = fields[:3]
            triples.append((h, r, t))
    return triples

# Entity id -> metadata dict; KGRelDataset reads only its 'canonical' text.
# JSON is UTF-8 by specification, so don't rely on the platform default encoding.
with open('../graph_data/entities.json', 'r', encoding='utf-8') as f:
    ent2text = json.load(f)


train_triples = load_triples("../graph_data/train.tsv")
dev_triples = load_triples("../graph_data/dev.tsv")
test_triples = load_triples("../graph_data/test.tsv")

# Build relation vocabulary over ALL splits so every relation that appears in
# dev/test has a label id (otherwise building those datasets raises KeyError).
# Sorting makes the relation -> id mapping deterministic across runs.
relations = sorted({r for _, r, _ in train_triples + dev_triples + test_triples})
rel2id = {r: i for i, r in enumerate(relations)}


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
class KGRelDataset(Dataset):
    """Relation-classification dataset built from (head, relation, tail) triples.

    Each example is the tokenized text pair (head canonical text, tail canonical
    text) plus the relation's integer id under the key ``'labels'``. All
    tokenization happens once, up front, in a single batched tokenizer call.

    Args:
        triples: Iterable of (head_id, relation, tail_id) tuples.
        text_tokenizer: HF-style tokenizer callable; defaults to the
            notebook-level ``tokenizer``.
        entity_text: Mapping entity id -> dict with a ``'canonical'`` string;
            defaults to the notebook-level ``ent2text``.
        relation_index: Mapping relation -> integer label; defaults to the
            notebook-level ``rel2id``.
        max_length: Every pair is padded/truncated to exactly this many tokens.

    Raises:
        KeyError: if a triple references an entity or relation missing from
            the lookup tables (the message names the offending triple).
    """

    def __init__(self, triples, text_tokenizer=None, entity_text=None,
                 relation_index=None, max_length=64):
        # Fall back to notebook globals only when no explicit object is given,
        # so existing KGRelDataset(triples) calls behave as before.
        tok = tokenizer if text_tokenizer is None else text_tokenizer
        entities = ent2text if entity_text is None else entity_text
        rel_ids = rel2id if relation_index is None else relation_index

        heads, tails, label_ids = [], [], []
        for h, r, t in triples:
            try:
                heads.append(entities[h]['canonical'])
                tails.append(entities[t]['canonical'])
                label_ids.append(rel_ids[r])
            except KeyError as err:
                raise KeyError(f"triple {(h, r, t)!r}: unknown key {err}") from err

        self.examples = []
        if not heads:
            return  # nothing to tokenize; an empty batch call can fail

        # padding='max_length' gives every row the same fixed length, so one
        # batched call yields the same per-pair rows as N separate calls.
        enc = tok(
            heads,
            tails,
            padding='max_length', truncation=True, max_length=max_length,
            return_tensors='pt'
        )
        for i, label in enumerate(label_ids):
            example = {k: v[i] for k, v in enc.items()}
            example['labels'] = torch.tensor(label, dtype=torch.long)
            self.examples.append(example)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def load_triples(path):
    """Read whitespace-separated (head, relation, tail) triples from a text file.

    NOTE(review): this re-definition shadows the ``load_triples`` defined
    earlier in this cell; the notebook should keep a single definition.

    Only the first three fields of each line are kept. Blank lines are skipped
    rather than being turned into empty tuples that break later ``(h, r, t)``
    unpacking, and short lines are rejected immediately.

    Args:
        path: Path of the triples file (read as UTF-8).

    Returns:
        list[tuple[str, str, str]]: the triples in file order.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    triples = []
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected at least 3 fields, got {len(fields)}"
                )
            triples.append(tuple(fields[:3]))
    return triples


# train/dev/test triples were already loaded from these same paths earlier in
# this cell; re-reading the files here only duplicated disk I/O, so reuse them.
batch_size = 16
train_ds = KGRelDataset(train_triples)
dev_ds = KGRelDataset(dev_triples)
test_ds = KGRelDataset(test_triples)

def collate_fn(batch):
    """Merge a list of per-example tensor dicts into one dict of stacked tensors.

    The key set (and its order) comes from the first example; every example is
    expected to carry the same keys with tensors of identical shape.
    """
    keys = batch[0].keys()
    return {key: torch.stack([example[key] for example in batch]) for key in keys}

# Only the training split is shuffled; dev/test are iterated in file order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dev_loader, test_loader = (
    DataLoader(split_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    for split_ds in (dev_ds, test_ds)
)



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Fine-tune BERT as a relation classifier over (head text, tail text) pairs.
SEED = 42
NUM_EPOCHS = 5
# Seeding before model creation fixes the new classifier-head init, dropout,
# and the DataLoader shuffle order (it draws from the global torch RNG).
random.seed(SEED)
torch.manual_seed(SEED)  # also seeds every CUDA device

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Size the classification head from the relation vocabulary instead of a
# hard-coded count, so it always matches the label ids produced by rel2id.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=len(rel2id)
).to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6
# matches that class's default epsilon.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4, eps=1e-6)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size for a per-example epoch average
        total_loss += loss.item() * labels.size(0)
    print(f"Epoch {epoch:02d} | Train Loss: {total_loss/len(train_ds):.4f}")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                   

Epoch 01 | Train Loss: 1.3830


                                                                   

Epoch 02 | Train Loss: 1.0273


                                                                   

Epoch 03 | Train Loss: 0.7544


                                                                   

Epoch 04 | Train Loss: 0.5336


                                                                   

Epoch 05 | Train Loss: 0.3588




In [None]:

@torch.no_grad()
def eval_metrics(loader, eval_model=None, eval_device=None, ks=(1, 3, 10)):
    """Score relation classification as a ranking over all relation classes.

    For each example, the model's logits rank every relation class. The rank
    of the gold relation feeds MRR and Hits@k, and argmax feeds accuracy.

    Args:
        loader: Iterable of batches, each a dict of tensors holding a
            ``'labels'`` key plus the model's input keys (e.g. a DataLoader
            using ``collate_fn``).
        eval_model: Model to evaluate; defaults to the notebook-level ``model``.
        eval_device: Device to run on; defaults to the notebook-level ``device``.
        ks: Cut-offs reported as ``"Hits@k"``.

    Returns:
        (acc, mrr, hits): accuracy, mean reciprocal rank, and a dict mapping
        ``"Hits@k"`` to the fraction of examples whose gold rank is <= k. All
        values are NaN if ``loader`` yields no examples.

    Notes:
        * Leaves the model in eval mode.
        * Hits@k is trivially 1.0 whenever k >= the number of classes.
        * Ties are ranked optimistically: the rank is 1 + the number of
          classes scoring strictly higher than the gold class.
    """
    net = model if eval_model is None else eval_model
    run_device = device if eval_device is None else eval_device
    net.eval()
    n_correct = 0
    batch_ranks = []
    for batch in loader:
        inputs = {k: v.to(run_device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(run_device)
        logits = net(**inputs).logits
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        # Vectorized rank: count classes that beat the gold score, row by row.
        gold_scores = logits.gather(1, labels.unsqueeze(1))
        batch_ranks.append((logits > gold_scores).sum(dim=1) + 1)

    if not batch_ranks:
        nan = float('nan')
        return nan, nan, {f"Hits@{k}": nan for k in ks}

    ranks = torch.cat(batch_ranks).float()
    # Divide by examples actually seen (correct even with drop_last=True).
    acc = n_correct / ranks.numel()
    mrr = ranks.reciprocal().mean().item()
    hits = {f"Hits@{k}": (ranks <= k).float().mean().item() for k in ks}
    return acc, mrr, hits

def _format_metrics(split, acc, mrr, hits):
    """Render one split's metrics on a single line, listing every Hits@k computed."""
    hits_str = " | ".join(f"{name}: {value:.4f}" for name, value in hits.items())
    return f"{split} Acc: {acc:.4f} | {split} MRR: {mrr:.4f} | {hits_str}"

# Report both splits in the same format. Previously dev showed only Hits@10
# (which saturates at 1.0 when there are <= 10 relation classes) and test
# omitted it, so the two lines could not be compared directly.
dev_acc, dev_mrr, dev_hits = eval_metrics(dev_loader)
print(_format_metrics("Dev", dev_acc, dev_mrr, dev_hits))
test_acc, test_mrr, test_hits = eval_metrics(test_loader)
print(_format_metrics("Test", test_acc, test_mrr, test_hits))


Dev ▶ Acc=0.7391, MRR=0.8232, H@10=1.0000
Test Acc: 0.6304 | Test MRR: 0.7460 | Hits@1: 0.6304 | Hits@3: 0.8043
