In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
from itertools import product

import torch
import dotenv
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from tqdm import tqdm

from llm_ol.dataset import data_model
from llm_ol.utils import batch

# Load variables from a .env file (if one is found) into the process environment.
dotenv.load_dotenv()

In [None]:
# Training graph, plus materialised node and edge lists so that
# `random.sample` can draw from them directly.
G = data_model.load_graph(
    "out/data/wikipedia/v2/train_eval_split/train_graph.json"
)
nodes, edges = list(G.nodes()), list(G.edges())

In [None]:
def title(node):
    """Return the ``"title"`` attribute stored for ``node`` in the graph ``G``."""
    node_attrs = G.nodes[node]
    return node_attrs["title"]


def sample_batch(batch_size: int):
    """Draw a balanced batch of title pairs from the graph ``G``.

    Positives are ``batch_size`` edges sampled without replacement from
    ``edges``. Negatives are ``batch_size`` pairs of distinct nodes for which
    ``G.has_edge(src, dst)`` is false, found by rejection sampling.

    Args:
        batch_size: Number of positive pairs and, separately, of negative pairs.

    Returns:
        ``(samples, labels)`` where ``samples`` holds ``2 * batch_size``
        ``(source title, destination title)`` tuples, positives first, and
        ``labels`` holds ``1`` for each positive followed by ``0`` for each
        negative.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``batch_size``
            exceeds ``len(edges)`` (or is negative).
    """
    positives = random.sample(edges, batch_size)

    # Rejection sampling: keep drawing ordered node pairs until enough of
    # them turn out not to be edges of G.
    negatives = []
    while len(negatives) < batch_size:
        head, tail = random.sample(nodes, 2)
        if G.has_edge(head, tail):
            continue
        negatives.append((head, tail))

    # Node ids are replaced by their titles; positives stay ahead of negatives
    # so the label list below lines up with the samples.
    samples = [(title(head), title(tail)) for head, tail in positives + negatives]
    labels = [1] * batch_size + [0] * batch_size
    return samples, labels

In [None]:
# Location the classifier and tokenizer are loaded from. Uncomment the
# alternative below to load "bert-base-uncased" instead.
# model_id = "bert-base-uncased"
model_id = "out/experiments/link_prediction/debug/checkpoint-1000"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=2,  # two-way classification head
    device_map="cuda",  # place the weights on the GPU
    torch_dtype=torch.bfloat16,
)

In [None]:
# Smoke test: push one positive and one negative pair through the classifier
# and inspect the shapes of the logits and the loss.
samples, labels = sample_batch(1)
heads, tails = zip(*samples)

# The tokenizer produces CPU tensors, but the model was loaded with
# device_map="cuda". Inputs and labels must be moved to the model's device,
# otherwise the forward pass fails with a device-mismatch error.
inputs = tokenizer(
    heads, tails, return_tensors="pt", padding=True, truncation=True
).to(model.device)

with torch.no_grad():
    output = model(**inputs, labels=torch.tensor(labels, device=model.device))

output.logits.shape, output.loss.shape

In [None]:
# Test graph whose node pairs are scored further down the notebook.
G_test = data_model.load_graph(
    "out/data/wikipedia/v2/train_test_split/test_graph.json"
)
nodes_test = list(G_test.nodes())

In [None]:
# Score every ordered pair of test-graph nodes (including self-pairs) with the
# classifier, in batches of 2048 pairs. `weights` ends up as a list holding one
# 1-D tensor of probabilities per batch, in `product(nodes_test, nodes_test)` order.
weights = []
for uv_batch in batch(
    tqdm(product(nodes_test, nodes_test), total=len(nodes_test) ** 2), 2048
):
    us, vs = zip(*uv_batch)
    inputs = tokenizer(
        [G_test.nodes[u]["title"] for u in us],
        [G_test.nodes[v]["title"] for v in vs],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    with torch.no_grad():
        output = model(**inputs)
        probs = torch.softmax(output.logits, dim=1)
        # NOTE(review): column 0 is the class that `sample_batch` labels as a
        # non-edge (label 0). Confirm that P(no edge) is the intended weight
        # and not `probs[:, 1]`.
        # Copy the slice to CPU before storing it: a bare `probs[:, 0]` is a
        # view that keeps the whole per-batch CUDA tensor alive, so GPU memory
        # would grow over all len(nodes_test) ** 2 pairs.
        weights.append(probs[:, 0].cpu())