In [11]:
import json
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import BertModel
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)

In [12]:
datapath = "../graph_data/"
# 1. Load files
# JSON is UTF-8 by specification; the encoding is passed explicitly so that
# decoding does not depend on the platform's default locale encoding.
with open(datapath+'entities.json', 'r', encoding='utf-8') as f:
    entities = json.load(f)             # { entity_str: { 'canonical': text, ... }, ... }
with open(datapath+'relation2id.json', 'r', encoding='utf-8') as f:
    rel2id = json.load(f)               # { relation_str: relation_id, ... }

# Build lookup tables.
# entity2id assigns each entity key its position in the entities file;
# id2entity and id2rel are the inverse maps (integer id -> string).
entity2id = {ent: i for i, ent in enumerate(entities.keys())}
id2entity = {i: ent for ent, i in entity2id.items()}
id2rel    = {rid: rel for rel, rid in rel2id.items()}

In [13]:
def load_triples(path):
    """Read a tab-separated triples file.

    Every non-blank line must hold exactly three tab-separated fields:
    head entity string, integer relation id, tail entity string. Leading and
    trailing whitespace on the line is stripped before splitting. Blank lines
    (for example a trailing empty line at the end of the file) are skipped.

    Args:
        path: Path of the TSV file to read (decoded as UTF-8).

    Returns:
        A list of ``(head_str, relation_id, tail_str)`` tuples in file order,
        where ``relation_id`` is an ``int``.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields,
            or its relation id is not an integer.
    """
    triples = []
    with open(path, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Tolerate empty lines instead of failing on tuple unpacking.
                continue
            fields = stripped.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{line_no}: expected 3 tab-separated fields, "
                    f"got {len(fields)}: {stripped!r}"
                )
            h_str, r_id_str, t_str = fields
            triples.append((h_str, int(r_id_str), t_str))
    return triples



# Load the three splits. Each is a list of (head_str, relation_id, tail_str)
# tuples as returned by load_triples.
train_triples = load_triples(datapath+"triples_train.tsv")
dev_triples   = load_triples(datapath+"triples_dev.tsv")
test_triples  = load_triples(datapath+"triples_test.tsv")

In [14]:
import random

# def negative_sample_tails(triples, num_entities, neg_rate=1):
#     augmented = []
#     for h, r, t in triples:
#         # positive
#         augmented.append((h, r, t, 1))
#         # k negatives by replacing t
#         for _ in range(neg_rate):
#             t_neg = random.randrange(num_entities)
#             while t_neg == t:
#                 t_neg = random.randrange(num_entities)
#             augmented.append((h, r, t_neg, 0))
#     return augmented

# train_examples = negative_sample_tails(train_triples, len(entity2id), neg_rate=1)
# # dev/test: no negatives, so label=1 for all
# dev_examples  = [(h, r, t, 1) for (h, r, t) in dev_triples]
# test_examples = [(h, r, t, 1) for (h, r, t) in test_triples]


def negative_sample_tails(triples, neg_rate=1, entity_names=None):
    """Label triples as positive and add tail-corrupted negatives.

    For every ``(h, r, t)`` the output contains ``(h, r, t, 1)`` followed by
    ``neg_rate`` tuples ``(h, r, t_neg, 0)``, where ``t_neg`` is drawn
    uniformly (via the global ``random`` module) from the candidate entities
    and re-drawn until it differs from ``t``.

    Negatives are only checked against the triple's own tail. A sampled
    ``(h, r, t_neg)`` may therefore still be a true triple that occurs
    elsewhere in the data.

    Args:
        triples: Iterable of ``(head_str, relation_id, tail_str)``.
        neg_rate: Number of negatives generated per positive.
        entity_names: Optional iterable of candidate tail names. Defaults to
            the keys of the module-level ``entity2id`` mapping.

    Returns:
        List of ``(head_str, relation_id, tail_str, label)`` tuples.

    Raises:
        ValueError: If negatives are requested but no candidate other than
            the true tail exists (the rejection loop could never terminate).
    """
    if entity_names is None:
        entity_names = entity2id.keys()
    # Materialise the candidate list once; rebuilding it for every draw made
    # each sample cost O(num_entities).
    candidates = list(entity_names)
    num_entities = len(candidates)
    distinct = set(candidates)

    out = []
    for h, r, t in triples:
        out.append((h, r, t, 1))
        if neg_rate > 0 and (not distinct or (len(distinct) == 1 and t in distinct)):
            raise ValueError(
                f"cannot sample a negative tail for {(h, r, t)!r}: "
                "no candidate entity differs from the true tail"
            )
        for _ in range(neg_rate):
            t_neg_str = candidates[random.randrange(num_entities)]
            while t_neg_str == t:
                t_neg_str = candidates[random.randrange(num_entities)]
            out.append((h, r, t_neg_str, 0))
    return out

# Seed the sampler so the corrupted tails, and therefore the dev/test sets and
# the metrics computed on them, are identical across kernel restarts.
random.seed(42)

train_examples = negative_sample_tails(train_triples, neg_rate=1)
dev_examples = negative_sample_tails(dev_triples, neg_rate=1)
test_examples = negative_sample_tails(test_triples, neg_rate=1)
# dev/test also receive one sampled negative per positive, so evaluation is a
# balanced binary classification task.
# Earlier variant (positives only, label=1 for all), kept for reference:
# dev_examples  = [(h, r, t, 1) for (h, r, t) in dev_triples]
# test_examples = [(h, r, t, 1) for (h, r, t) in test_triples]


In [15]:
from torch.utils.data import Dataset, DataLoader

class KGTailPredDataset(Dataset):
    """Dataset of labelled (head, relation, tail) examples rendered as text.

    Each example ``(head_str, relation_id, tail_str, label)`` is turned into
    the string ``"<head> [SEP] <relation> [SEP] <tail>"`` using the entities'
    ``'canonical'`` text and the relation name, then passed to the tokenizer
    with truncation and padding to ``max_len``.
    """

    def __init__(self, examples, entities, entity2id, id2rel, tokenizer, max_len=128):
        self.examples, self.entities = examples, entities
        self.entity2id, self.id2rel = entity2id, id2rel
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the tokenizer outputs plus id/label tensors for example ``idx``."""
        head, rel_id, tail, target = self.examples[idx]

        # Textual form of the triple fed to the encoder.
        seq = "{} [SEP] {} [SEP] {}".format(
            self.entities[head]['canonical'],
            self.id2rel[rel_id],
            self.entities[tail]['canonical'],
        )
        encoded = self.tokenizer(
            seq,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )

        # The tokenizer returns a leading batch axis of size 1; drop it so the
        # DataLoader can stack examples itself.
        item = {}
        for field, tensor in encoded.items():
            item[field] = tensor.squeeze(0)

        item['head_id'] = torch.tensor(self.entity2id[head], dtype=torch.long)
        item['relation_id'] = torch.tensor(rel_id, dtype=torch.long)
        item['tail_id'] = torch.tensor(self.entity2id[tail], dtype=torch.long)
        item['label'] = torch.tensor(target, dtype=torch.float)
        return item



In [16]:
# 5. Instantiate
# Tokenizer matching the 'bert-base-uncased' checkpoint name.
tokenizer   = BertTokenizer.from_pretrained('bert-base-uncased')

# One dataset per split; max_len is left at the class default (128).
train_ds    = KGTailPredDataset(train_examples, entities, entity2id, id2rel, tokenizer)
dev_ds      = KGTailPredDataset(dev_examples,   entities, entity2id, id2rel, tokenizer)
test_ds     = KGTailPredDataset(test_examples,  entities, entity2id, id2rel, tokenizer)

# Only the training loader shuffles; dev/test keep a fixed order and use a
# larger batch size.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
dev_loader   = DataLoader(dev_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)


In [17]:
class KGBertTailPredictor(nn.Module):
    """BERT encoder with a single-logit head that scores a textual triple.

    The pooled encoder output is passed through dropout and a linear layer
    producing one logit per example; a sigmoid of that logit is the
    probability that the triple is true.

    Args:
        pretrained_model: Checkpoint name or path given to
            ``BertModel.from_pretrained``.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, pretrained_model="bert-base-uncased", dropout=0.1):
        super().__init__()
        # 1) Backbone
        self.bert = BertModel.from_pretrained(pretrained_model)
        # 2) A small head on the pooled output
        self.dropout = nn.Dropout(dropout)
        # Single logit -> score "true" vs "false"
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Score a batch of tokenized triples.

        Args:
            input_ids: Token ids, as produced by KGTailPredDataset.
            attention_mask: Padding mask matching ``input_ids``.
            token_type_ids: Optional segment ids. May be omitted (``None``)
                for tokenizers that do not emit them; the encoder then uses
                its own default.
            labels: Optional tensor of 0/1 targets. When given, the loss is
                computed as well.

        Returns:
            Dict with:
              - ``"logits"``: tensor of shape (batch,)
              - ``"loss"``: BCEWithLogitsLoss value, only if ``labels`` is given
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = outputs.pooler_output            # (batch, hidden_size)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled).squeeze(-1)  # (batch,)

        out = {"logits": logits}
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            out["loss"] = loss_fn(logits, labels.float())
        return out

In [18]:
def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``scheduler`` and
    ``device``. The scheduler is advanced once per batch.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    batch_losses = []
    progress = tqdm(train_loader, desc="Train")
    for raw_batch in progress:
        # Move every tensor of the batch to the training device.
        on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        result = model(
            input_ids=on_device["input_ids"],
            attention_mask=on_device["attention_mask"],
            token_type_ids=on_device["token_type_ids"],
            labels=on_device["label"],
        )
        step_loss = result["loss"]
        step_loss.backward()
        optimizer.step()
        scheduler.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
        batch_losses.append(step_loss.item())
    return np.mean(batch_losses)

@torch.no_grad()
def eval_binary(loader):
    """Evaluate the module-level ``model`` as a binary classifier on ``loader``.

    Probabilities are the sigmoid of the model logits; hard predictions use a
    0.5 threshold.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, roc_auc)``.
    """
    model.eval()
    gold, scores = [], []
    for raw_batch in tqdm(loader, desc="Eval (binary)"):
        on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        result = model(
            input_ids=on_device["input_ids"],
            attention_mask=on_device["attention_mask"],
            token_type_ids=on_device["token_type_ids"],
        )
        # Logits have shape (batch_size,); convert to probabilities on CPU.
        batch_scores = torch.sigmoid(result["logits"]).cpu().numpy()
        batch_gold = on_device["label"].cpu().numpy()

        scores.extend(batch_scores)
        gold.extend(batch_gold)

    # Strictly greater than 0.5 counts as a positive prediction.
    hard_preds = [int(score > 0.5) for score in scores]

    acc = accuracy_score(gold, hard_preds)
    prec, rec, f1, _ = precision_recall_fscore_support(
        gold, hard_preds, average="binary"
    )
    auc = roc_auc_score(gold, scores)
    return acc, prec, rec, f1, auc



In [19]:

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = KGBertTailPredictor().to(device)

num_epochs = 10
optimizer = AdamW(model.parameters(), lr=5e-5)
# train_epoch steps the scheduler once per batch, so the schedule length is
# (batches per epoch) * (number of epochs).
total_steps = len(train_loader) * num_epochs
# Linear warmup over the first 10% of steps, then linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps,
)


In [20]:
# --- Main loop ---
# Train for num_epochs, evaluating on the dev split after every epoch and
# keeping the weights of the epoch with the highest dev AUC.
best_dev_auc = 0.0
for epoch in range(1, num_epochs + 1):
    avg_loss = train_epoch()
    acc, prec, rec, f1, auc = eval_binary(dev_loader)

    print(f"Epoch {epoch:02d} ▶ train_loss={avg_loss:.4f}  "
          f"dev_acc={acc:.4f}  precision={prec:.4f}  recall={rec:.4f}  f1={f1:.4f}  auc={auc:.4f}")

    # save best by AUC (or any metric you choose)
    # The checkpoint file is overwritten each time a new best is reached.
    if auc > best_dev_auc:
        best_dev_auc = auc
        torch.save(model.state_dict(), "best_kgbt_tailpred.pt")




Train: 100%|██████████| 202/202 [00:32<00:00,  6.30it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.20it/s]


Epoch 01 ▶ train_loss=0.6942  dev_acc=0.5736  precision=0.5827  recall=0.5187  f1=0.5488  auc=0.6000


Train: 100%|██████████| 202/202 [00:32<00:00,  6.24it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  8.95it/s]


Epoch 02 ▶ train_loss=0.6730  dev_acc=0.5758  precision=0.5739  recall=0.5890  f1=0.5813  auc=0.6038


Train: 100%|██████████| 202/202 [00:32<00:00,  6.29it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.25it/s]


Epoch 03 ▶ train_loss=0.6169  dev_acc=0.6077  precision=0.6109  recall=0.5934  f1=0.6020  auc=0.6441


Train: 100%|██████████| 202/202 [00:32<00:00,  6.28it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.15it/s]


Epoch 04 ▶ train_loss=0.4894  dev_acc=0.5681  precision=0.5714  recall=0.5451  f1=0.5579  auc=0.6107


Train: 100%|██████████| 202/202 [00:32<00:00,  6.28it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.16it/s]


Epoch 05 ▶ train_loss=0.3399  dev_acc=0.5791  precision=0.5763  recall=0.5978  f1=0.5868  auc=0.6214


Train: 100%|██████████| 202/202 [00:32<00:00,  6.25it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.15it/s]


Epoch 06 ▶ train_loss=0.2179  dev_acc=0.5879  precision=0.5985  recall=0.5341  f1=0.5645  auc=0.6276


Train: 100%|██████████| 202/202 [00:32<00:00,  6.17it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00,  9.11it/s]


Epoch 07 ▶ train_loss=0.1259  dev_acc=0.5659  precision=0.5714  recall=0.5275  f1=0.5486  auc=0.6099


Train: 100%|██████████| 202/202 [00:32<00:00,  6.20it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00, 10.19it/s]


Epoch 08 ▶ train_loss=0.0763  dev_acc=0.5758  precision=0.5860  recall=0.5165  f1=0.5491  auc=0.6165


Train: 100%|██████████| 202/202 [00:31<00:00,  6.43it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00, 10.11it/s]


Epoch 09 ▶ train_loss=0.0527  dev_acc=0.5692  precision=0.5734  recall=0.5407  f1=0.5566  auc=0.6127


Train: 100%|██████████| 202/202 [00:31<00:00,  6.45it/s]
Eval (binary): 100%|██████████| 15/15 [00:01<00:00, 10.11it/s]

Epoch 10 ▶ train_loss=0.0347  dev_acc=0.5725  precision=0.5778  recall=0.5385  f1=0.5575  auc=0.6099





In [None]:
# --- Final test evaluation ---
# Restore the checkpoint saved at the best dev AUC before scoring the test set.
# map_location=device lets a checkpoint written on GPU load on a CPU-only
# machine. weights_only=True limits unpickling to tensors and plain
# containers, which is all a state_dict contains.
model.load_state_dict(
    torch.load("best_kgbt_tailpred.pt", map_location=device, weights_only=True)
)
acc, prec, rec, f1, auc = eval_binary(test_loader)
print(f"Test ▶ acc={acc:.4f}  precision={prec:.4f}  recall={rec:.4f}  f1={f1:.4f}  auc={auc:.4f}")

  model.load_state_dict(torch.load("best_kgbt_tailpred.pt"))
Eval (binary): 100%|██████████| 31/31 [00:03<00:00,  9.34it/s]

Test ▶ acc=0.5637  precision=0.5628  recall=0.5708  f1=0.5668  auc=0.6064





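Not used — the cells below are an earlier ranking-based evaluation (MRR / Hits@k), kept commented out for reference.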

In [None]:
# def train_epoch():
#     model.train()
#     losses = []
#     for batch in tqdm(train_loader, desc="Train"):
#         batch = {k: v.to(device) for k, v in batch.items()}
#         out = model(
#             input_ids      = batch["input_ids"],
#             attention_mask = batch["attention_mask"],
#             token_type_ids = batch["token_type_ids"],
#             labels         = batch["label"],
#         )
#         loss = out["loss"]
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()
#         losses.append(loss.item())
#     return np.mean(losses)


# @torch.no_grad()
# def evaluate(triples):
#     model.eval()
#     ranks = []
#     # for each true (h,r,t), score all candidate tails
#     for h_id, r_id, t_id in tqdm(triples, desc="Eval"):
#         h_text = entities[h_id]["canonical"]
#         r_text = id2rel[r_id]

#         # build all candidate sequences in one batch
#         seqs = [
#             f"{h_text} [SEP] {r_text} [SEP] {entities[id2entity[t2]]['canonical']}"
#             for t2 in range(len(entity2id))
#         ]
#         enc = tokenizer(
#             seqs,
#             padding=True,
#             truncation=True,
#             max_length=128,
#             return_tensors="pt",
#         ).to(device)

#         logits = model(
#             input_ids      = enc.input_ids,
#             attention_mask = enc.attention_mask,
#             token_type_ids = enc.token_type_ids,
#         )["logits"]              # (num_entities,)

#         scores = torch.sigmoid(logits).cpu()
#         # get descending ranking
#         sorted_idx = torch.argsort(scores, descending=True)
#         rank = (sorted_idx == entity2id[t_id]).nonzero(as_tuple=False).item() + 1
#         ranks.append(rank)

#     ranks = np.array(ranks)
#     mrr    = np.mean(1.0 / ranks)
#     hits1  = np.mean(ranks <= 1)
#     hits3  = np.mean(ranks <= 3)
#     hits10 = np.mean(ranks <= 10)
#     return mrr, {"hits@1": hits1, "hits@3": hits3, "hits@10": hits10}

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = KGBertTailPredictor("distilbert-base-uncased").to(device)

# num_epochs = 1
# # optimizer + optional scheduler
# optimizer = AdamW(model.parameters(), lr=5e-5)
# total_steps = len(train_loader) * num_epochs
# scheduler = get_linear_schedule_with_warmup(
#     optimizer,
#     num_warmup_steps=int(0.1 * total_steps),
#     num_training_steps=total_steps,
# )


You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertModel were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.inte

In [None]:
# # --- Main training loop ---
# best_dev_mrr = 0.0
# for epoch in range(1, num_epochs+1):
#     # avg_train_loss = train_epoch()
#     dev_mrr, dev_hits = evaluate(dev_triples)
#     print(f"Epoch {epoch:02d}  train_loss={avg_train_loss:.4f}  dev_mrr={dev_mrr:.4f}  hits@10={dev_hits['hits@10']:.4f}")

#     # save best
#     if dev_mrr > best_dev_mrr:
#         best_dev_mrr = dev_mrr
#         torch.save(model.state_dict(), "best_kgbt_tailpred.pt")




Eval:   0%|          | 0/455 [00:00<?, ?it/s]

Eval:   0%|          | 1/455 [01:06<8:23:10, 66.50s/it]

In [None]:
# # --- Final test evaluation ---
# model.load_state_dict(torch.load("best_kgbt_tailpred.pt"))
# test_mrr, test_hits = evaluate(test_triples)
# print(f"Test MRR: {test_mrr:.4f}  Hits@1: {test_hits['hits@1']:.4f}  Hits@3: {test_hits['hits@3']:.4f}  Hits@10: {test_hits['hits@10']:.4f}")