In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from src.go_ast_tokenizer.dataset_builder import LABEL_NAMES
from src.go_ast_tokenizer.train import HParams, Llama3Classifier

# Choose the inference device. Apple's MPS backend is tried first (the value
# this notebook used to hard-code), then CUDA, then the CPU, so the cells
# below also run on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Placeholder: replace with the path of the trained checkpoint to load.
CHECKPOINT = "path to checkpoint"

# Rebuild the classifier from the checkpoint using default hyper-parameters,
# put it in evaluation mode, and move it to the selected device.
params = HParams()
model = Llama3Classifier.load_from_checkpoint(CHECKPOINT, params=params)
model.eval()
model.to(DEVICE)

# The tokenizer comes from the same model id the hyper-parameters name.
tokenizer = AutoTokenizer.from_pretrained(params.model_id)
# Use EOS as the padding token: the inference cell tokenizes with
# padding="max_length", which needs a pad token to be set.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Load the test split of the dataset from the Hugging Face hub.
data = load_dataset(path="aholovko/go-critic-style", split="test")

# Inner feature of the "labels" column; used below (via int2str) to turn
# integer label ids back into their names.
labels_feature = data.features["labels"]
class_label = labels_feature.feature

In [None]:
# Index of the test-set example to inspect.
idx = 10

# Select the row first and then its columns. Indexing the column first
# (data["code"][idx]) pulls the entire column just to read one element.
example = data[idx]
snippet = example["code"]
labels = example["labels"]

# Show the source snippet together with its ground-truth label names.
print(snippet)
print(f"labels: {[class_label.int2str(label) for label in labels]}")

In [None]:
# Tokenize the single snippet, padded and truncated to the configured length,
# and return PyTorch tensors.
encoding = tokenizer(
    snippet,
    max_length=params.max_length,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

# Move every tensor in the encoding to the device the model lives on.
inputs = {name: tensor.to(DEVICE) for name, tensor in encoding.items()}

# Forward pass with gradient tracking disabled.
with torch.no_grad():
    logits = model(inputs["input_ids"], inputs["attention_mask"])

# Element-wise sigmoid over the logits; an entry counts as predicted when its
# probability exceeds 0.5.
probs = torch.sigmoid(logits)
preds = probs.gt(0.5).long()

# Report the first (and only) example: label name, probability, and a
# tick/cross for the thresholded decision.
first_probs = probs[0].cpu().tolist()
first_preds = preds[0].cpu().tolist()
# strict=False: if LABEL_NAMES and the model outputs differ in length, zip
# stops at the shortest input instead of raising.
for label, prob, pred in zip(LABEL_NAMES, first_probs, first_preds, strict=False):
    mark = "✔" if pred else "✘"
    print(f"{label:<25}{prob:>8.4f}   {mark}")