In [11]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import numpy as np

In [12]:
datapath = "../graph_data/"
# 1. Load files
with open(datapath+'entities.json', 'r') as f:
    entities = json.load(f)             # { entity_str: { 'canonical': text, ... }, ... }
with open(datapath+'relation2id.json', 'r') as f:
    rel2id = json.load(f)               # { relation_str: relation_id, ... }

# Create mappings
entity_names = list(entities.keys())
entity2id = {name: idx for idx, name in enumerate(entity_names)}
id2entity = {idx: name for name, idx in entity2id.items()}
id2rel = {rid: rel for rel, rid in rel2id.items()}

In [13]:
# Load splits
def load_triples(path):
    triples = []
    with open(path) as f:
        for line in f:
            h, r_str, t = line.strip().split('\t')
            triples.append((h, int(r_str), t))
    return triples

train_triples = load_triples(datapath+"triples_train.tsv")
dev_triples   = load_triples(datapath+"triples_dev.tsv")
test_triples  = load_triples(datapath+"triples_test.tsv")

In [14]:
# === 2. Dataset for two-tower ===
class TwoTowerDataset(Dataset):
    def __init__(self, triples, entities, id2rel, entity2id, tokenizer, max_len=128):
        self.triples = triples
        self.entities = entities
        self.id2rel = id2rel
        self.entity2id = entity2id
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        h_str, r_id, t_str = self.triples[idx]
        h_text = self.entities[h_str]['canonical']
        r_text = self.id2rel[r_id]
        t_text = self.entities[t_str]['canonical']

        qr_seq = f"{h_text} [SEP] {r_text}"
        qc_seq = t_text

        qr_enc = self.tokenizer(qr_seq, truncation=True, padding='max_length',
                                 max_length=self.max_len, return_tensors='pt')
        qc_enc = self.tokenizer(qc_seq, truncation=True, padding='max_length',
                                 max_length=self.max_len, return_tensors='pt')

        return {
            'qr_input_ids':      qr_enc.input_ids.squeeze(0),
            'qr_attention_mask': qr_enc.attention_mask.squeeze(0),
            'qr_token_type_ids': qr_enc.token_type_ids.squeeze(0),
            'qc_input_ids':      qc_enc.input_ids.squeeze(0),
            'qc_attention_mask': qc_enc.attention_mask.squeeze(0),
            'qc_token_type_ids': qc_enc.token_type_ids.squeeze(0),
            'tail_id':           torch.tensor(entity2id[t_str], dtype=torch.long),
        }


In [15]:
# === 3. Two-tower model ===
class TwoTowerKGBert(nn.Module):
    def __init__(self, pretrained='bert-base-uncased', dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained)
        self.dropout = nn.Dropout(dropout)

    def encode(self, input_ids, attention_mask, token_type_ids):
        out = self.bert(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
        pooled = out.pooler_output  # [CLS] embedding
        return self.dropout(pooled)

    def forward(self, qr_input_ids, qr_attention_mask, qr_token_type_ids,
                      qc_input_ids, qc_attention_mask, qc_token_type_ids):
        qr_emb = self.encode(qr_input_ids, qr_attention_mask, qr_token_type_ids)  # (B,d)
        qc_emb = self.encode(qc_input_ids, qc_attention_mask, qc_token_type_ids)  # (B,d)
        # similarity matrix: (B, B)
        return torch.matmul(qr_emb, qc_emb.T)

In [16]:
# === 4. Prepare DataLoaders ===
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_ds = TwoTowerDataset(train_triples, entities, id2rel, entity2id, tokenizer)
dev_ds   = TwoTowerDataset(dev_triples,   entities, id2rel, entity2id, tokenizer)
test_ds  = TwoTowerDataset(test_triples,  entities, id2rel, entity2id, tokenizer)

# 2. Instantiate DataLoaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
dev_loader   = DataLoader(dev_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)


In [None]:
# === 5. Initialize model, optimizer ===
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoTowerKGBert().to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-4)



In [18]:
# === 6. Precompute all tail embeddings ===
model.eval()
all_tail_texts = [entities[name]['canonical'] for name in entity_names]
tail_enc = tokenizer(all_tail_texts, truncation=True, padding='longest',
                     return_tensors='pt').to(device)
with torch.no_grad():
    all_tail_embs = model.encode(tail_enc.input_ids,
                                 tail_enc.attention_mask,
                                 tail_enc.token_type_ids)  # (|E|, d)

In [19]:
# === 7. Training + evaluation loops ===
def train_epoch(train_loader):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        # Move to device
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(batch['qr_input_ids'],
                       batch['qr_attention_mask'],
                       batch['qr_token_type_ids'],
                       batch['qc_input_ids'],
                       batch['qc_attention_mask'],
                       batch['qc_token_type_ids'])  # (B, B)

        labels = torch.arange(logits.size(0), device=device)
        loss = nn.CrossEntropyLoss()(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    ranks = []

    for batch in tqdm(loader, desc="Evaluating"):
        # move to GPU
        for k,v in batch.items():
            batch[k] = v.to(device)
        # encode all queries in the batch
        qr_emb = model.encode(
            batch['qr_input_ids'],
            batch['qr_attention_mask'],
            batch['qr_token_type_ids']
        )  # (B, d)

        # for each example in the batch, score against all tails
        for i, true_t in enumerate(batch['tail_id']):
            q = qr_emb[i]                           # (d,)
            scores = all_tail_embs @ q             # (|E|,)
            sorted_idx = torch.argsort(scores, descending=True)
            rank = (sorted_idx == true_t).nonzero(as_tuple=True)[0].item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr   = np.mean(1.0 / ranks)
    hits1 = np.mean(ranks <= 1)
    hits3 = np.mean(ranks <= 3)
    hits10= np.mean(ranks <= 10)
    return mrr, hits1, hits3, hits10

In [22]:
# Main loop
num_epochs = 10
best_mrr = 0
for epoch in range(1, num_epochs + 1):
    avg_loss = train_epoch(train_loader)
    print("Evaluation")
    dev_mrr, dev_h1, dev_h3, dev_h10 = evaluate(dev_loader)
    print(f"Epoch {epoch} ▶ loss={avg_loss:.4f}  dev_MRR={dev_mrr:.4f}  Hits@10={dev_h10:.4f}")
    if dev_mrr > best_mrr:
        best_mrr = dev_mrr
        torch.save(model.state_dict(), "best_two_tower.pt")

# Final test
model.load_state_dict(torch.load("best_two_tower.pt"))
test_mrr, test_h1, test_h3, test_h10 = evaluate(test_loader)
print(f"Test ▶ MRR={test_mrr:.4f}  Hits@1={test_h1:.4f}  Hits@3={test_h3:.4f}  Hits@10={test_h10:.4f}")


Training: 100%|██████████| 101/101 [00:30<00:00,  3.31it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  7.50it/s]


Epoch 1 ▶ loss=5.0709  dev_MRR=0.0013  Hits@10=0.0022


Training: 100%|██████████| 101/101 [00:30<00:00,  3.29it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  8.81it/s]


Epoch 2 ▶ loss=4.1680  dev_MRR=0.0012  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.50it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.30it/s]


Epoch 3 ▶ loss=3.7513  dev_MRR=0.0011  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.57it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  9.56it/s]


Epoch 4 ▶ loss=3.5986  dev_MRR=0.0010  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:29<00:00,  3.46it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  7.93it/s]


Epoch 5 ▶ loss=3.5032  dev_MRR=0.0011  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:29<00:00,  3.47it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  7.90it/s]


Epoch 6 ▶ loss=3.4946  dev_MRR=0.0010  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.52it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  8.07it/s]


Epoch 7 ▶ loss=3.4794  dev_MRR=0.0009  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.55it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:00<00:00,  8.18it/s]


Epoch 8 ▶ loss=3.4584  dev_MRR=0.0014  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.51it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  7.41it/s]


Epoch 9 ▶ loss=3.5044  dev_MRR=0.0012  Hits@10=0.0000


Training: 100%|██████████| 101/101 [00:28<00:00,  3.57it/s]


Evaluation


Evaluating: 100%|██████████| 8/8 [00:01<00:00,  7.77it/s]
  model.load_state_dict(torch.load("best_two_tower.pt"))


Epoch 10 ▶ loss=3.4254  dev_MRR=0.0011  Hits@10=0.0000


Evaluating: 100%|██████████| 16/16 [00:02<00:00,  7.18it/s]

Test ▶ MRR=0.0025  Hits@1=0.0010  Hits@3=0.0010  Hits@10=0.0031



