In [1]:
import json
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_triples(path):
    triples = []
    with open(path, 'r') as f:
        for line in f:
            h, r, t = line.strip().split()[:3]
            triples.append((h, r, t))
    return triples

with open('../graph_data/entities.json', 'r') as f:
    ent2text = json.load(f)


train_triples = load_triples("../graph_data/train.tsv")
dev_triples =  load_triples("../graph_data/dev.tsv")
test_triples =  load_triples("../graph_data/test.tsv")

relations = sorted({r for _, r, _ in train_triples + dev_triples + test_triples})
rel2id = {r: i for i, r in enumerate(relations)}

In [3]:
train_triples

[('I05-2014_E2151', '0', 'I05-2014_E2159'),
 ('C92-4199_E3803', '4', 'C92-4199_E3804'),
 ('CVPR_2016_406_abs_E2724', '3', 'CVPR_2016_406_abs_E2743'),
 ('C90-3072_E3316', '0', 'C90-3072_E3317'),
 ('NIPS_2014_10_abs_E4516', '0', 'NIPS_2014_10_abs_E4517'),
 ('CVPR_2003_10_abs_E4809', '4', 'CVPR_2003_10_abs_E4808'),
 ('A92-1026_E1305', '0', 'A92-1026_E1304'),
 ('NIPS_2003_10_abs_E3747', '4', 'NIPS_2003_10_abs_E3746'),
 ('NIPS_2015_21_abs_E473', '0', 'NIPS_2015_21_abs_E481'),
 ('NIPS_2015_10_abs_E3328', '2', 'NIPS_2015_10_abs_E3327'),
 ('P01-1004_E504', '0', 'P01-1004_E508'),
 ('A00-1024_E4674', '5', 'A00-1024_E4665'),
 ('ICCV_2005_47_abs_E1146', '2', 'ICCV_2005_47_abs_E1137'),
 ('H90-1016_E1528', '3', 'H90-1016_E1529'),
 ('INTERSPEECH_2014_40_abs_E6358', '0', 'INTERSPEECH_2014_40_abs_E6357'),
 ('C88-1007_E1215', '0', 'C88-1007_E1216'),
 ('ECCV_2012_37_abs_E1937', '0', 'ECCV_2012_37_abs_E1936'),
 ('N04-1024_E4401', '0', 'N04-1024_E4395'),
 ('E91-1043_E3719', '4', 'E91-1043_E3720'),
 ('H01-1

In [4]:
ent2text

{'J87-1003_E0': {'canonical': 'strictly syntactic cross-serial agreement',
  'mentions': [[17, 20], [23, 23]]},
 'J87-1003_E1': {'canonical': 'English', 'mentions': [[0, 0]]},
 'J87-1003_E2': {'canonical': 'coordinations', 'mentions': [[10, 10]]},
 'J87-1003_E3': {'canonical': 'nouns', 'mentions': [[29, 29]]},
 'J87-1003_E4': {'canonical': 'reflexive pronouns', 'mentions': [[31, 32]]},
 'J87-1003_E5': {'canonical': 'grammatical number', 'mentions': [[42, 43]]},
 'J87-1003_E6': {'canonical': 'English', 'mentions': [[45, 45]]},
 'J87-1003_E7': {'canonical': 'grammatical gender', 'mentions': [[48, 49]]},
 'J87-1003_E8': {'canonical': 'languages', 'mentions': [[51, 51]]},
 'J87-1003_E9': {'canonical': 'French', 'mentions': [[54, 54]]},
 'J87-1003_E10': {'canonical': 'Interchange Lemma', 'mentions': [[70, 71]]},
 'J87-1003_E11': {'canonical': 'English', 'mentions': [[86, 86]]},
 'CVPR_2003_18_abs_E12': {'canonical': 'object intrinsic representation',
  'mentions': [[62, 64], [90, 91], [96, 

In [5]:
def make_example(h, r, t, label):
    head_text = ent2text[h]['canonical']
    tail_text = ent2text[t]['canonical']
    encoded = tokenizer(head_text, r + ' ' + tail_text,
                        truncation=True, padding='max_length', max_length=64,
                        return_tensors='pt')
    for k in encoded:
        encoded[k] = encoded[k].squeeze(0)
    encoded['labels'] = torch.tensor(label, dtype=torch.long)
    return encoded

class KGBertDataset(Dataset):
    def __init__(self, triples, neg_ratio=1):
        self.examples = []
        ents = list(ent2text.keys())
        for h, r, t in triples:
            self.examples.append(make_example(h, r, t, 1))
            for _ in range(neg_ratio):
                if random.random() < 0.5:
                    h2 = random.choice(ents)
                    self.examples.append(make_example(h2, r, t, 0))
                else:
                    t2 = random.choice(ents)
                    self.examples.append(make_example(h, r, t2, 0))

    def __len__(self): return len(self.examples)
    def __getitem__(self, idx): return self.examples[idx]


def collate_fn(batch):
    keys = batch[0].keys()
    return {k: torch.stack([ex[k] for ex in batch]) for k in keys}

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.to('cuda')
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
train_loader = DataLoader(KGBertDataset(train_triples, neg_ratio=1), batch_size=16, shuffle=True, collate_fn=collate_fn)
dev_loader   = DataLoader(KGBertDataset(dev_triples, neg_ratio=0), batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(KGBertDataset(test_triples, neg_ratio=0), batch_size=16, shuffle=False, collate_fn=collate_fn)

In [9]:
# Training loop
for epoch in range(3):
    model.train()
    for batch in train_loader:
        batch = {k: v.cuda() for k, v in batch.items()}
        loss = model(**batch).loss
        print(loss)
        optimizer.zero_grad(); loss.backward(); optimizer.step()


    model.eval()
    correct = total = 0
    for batch in dev_loader:
        labels = batch['labels'].cuda()
        inputs = {k: v.cuda() for k, v in batch.items() if k!='labels'}
        with torch.no_grad():
            preds = model(**inputs).logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    print(f'Epoch {epoch}: Dev Accuracy = {correct/total:.4f}')


tensor(0.7303, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7337, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6639, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6865, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7123, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5782, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7228, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6856, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7699, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7867, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7385, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6346, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7614, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7481, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.7032, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6789, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6226, device='cuda:0', grad_fn=

In [None]:
# Ranking evaluation function
def evaluate_ranking(loader, num_neg=100, k_values=(1,3,10)):
    model.eval()
    ranks = []
    with torch.no_grad():
        for batch in loader:

            input_ids = batch['input_ids'].cuda()
            attention_mask = batch['attention_mask'].cuda()
            token_type_ids = batch['token_type_ids'].cuda()

            break
    # Loop over test_triples
    for h, r, t in test_triples:
        # positive
        pos = make_example(h, r, t, 1)
        pos = {k: v.unsqueeze(0).cuda() for k, v in pos.items() if k!='labels'}
        pos_score = model(**pos).logits[:,1]
        # negatives
        scores = [pos_score.item()]
        ents = list(ent2text.keys())
        for _ in range(num_neg):
            if random.random() < 0.5:
                h2 = random.choice(ents); t2 = t
            else:
                h2 = h; t2 = random.choice(ents)
            neg = make_example(h2, r, t2, 0)
            neg = {k: v.unsqueeze(0).cuda() for k, v in neg.items() if k!='labels'}
            scores.append(model(**neg).logits[:,1].item())
        # compute rank
        sorted_scores = sorted(scores, reverse=True)
        rank = sorted_scores.index(scores[0]) + 1
        ranks.append(rank)
    ranks = torch.tensor(ranks, dtype=torch.float)
    mrr = (1.0/ranks).mean().item()
    hits = {f'Hits@{k}': (ranks <= k).float().mean().item() for k in k_values}
    return mrr, hits

# Test ranking eval
mrr, hits = evaluate_ranking(test_loader)
print(f'Test Ranking → MRR: {mrr:.4f}, ' + ', '.join([f"{k}: {v:.4f}" for k, v in hits.items()]))

Test Ranking → MRR: 0.1930, Hits@1: 0.1087, Hits@3: 0.1522, Hits@10: 0.4348
