# Compare model checkpoints

This notebook loads each model checkpoint listed under `checkpoints/` in turn, runs the same chat-style prompt through it, and shows the responses side by side in a table.

In [None]:
import os
import gc
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
# Conversation sent to every checkpoint: an optional system turn plus one user turn.
system_prompt = "You are a helpful assistant."
prompt = "Explain the difference between supervised learning and reinforcement learning."

# Checkpoints to compare; each entry is passed to `from_pretrained` as-is.
model_paths = ["checkpoints/llama3.1-8b-hard", "checkpoints/llama3.1-8b-soft"]

# Upper bound on the number of tokens generated per response.
max_new_tokens = 256

# Hugging Face access token; None when the HF_TOKEN variable is not set.
hf_token = os.getenv("HF_TOKEN")


In [None]:
def build_chat_prompt(tokenizer, messages):
    """Render ``messages`` as one prompt string that ends at the assistant turn.

    When the tokenizer has a non-empty ``chat_template`` the rendering is
    delegated to ``tokenizer.apply_chat_template`` (untokenized, with the
    generation prompt appended). Otherwise a plain transcript is produced:
    one ``"<Role>: <content>"`` line per message followed by ``"Assistant:"``.

    Args:
        tokenizer: Object that may expose ``chat_template`` and
            ``apply_chat_template``.
        messages: Sequence of dicts with a ``"content"`` key and an optional
            ``"role"`` key; a missing or unrecognised role is labelled "User".

    Returns:
        The prompt text as a string.
    """
    if not getattr(tokenizer, "chat_template", None):
        # Fallback transcript for tokenizers that ship without a template.
        labels = {"system": "System", "assistant": "Assistant"}
        transcript = [
            f"{labels.get(turn.get('role', 'user'), 'User')}: {turn['content']}"
            for turn in messages
        ]
        return "\n".join([*transcript, "Assistant:"])

    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

def generate_response(model_id, messages):
    """Load one checkpoint, greedily generate a reply to ``messages``, and free it.

    Reads the module-level ``hf_token`` and ``max_new_tokens`` settings and
    uses ``build_chat_prompt`` to turn ``messages`` into the prompt text.

    Args:
        model_id: Path or hub id passed to ``from_pretrained`` for both the
            tokenizer and the model.
        messages: Chat messages (dicts with ``"role"`` / ``"content"``).

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    prompt_text = build_chat_prompt(tokenizer, messages)

    # A rendered chat template already contains the special tokens it needs
    # (e.g. BOS); letting the tokenizer add them again would duplicate them.
    # The plain-text fallback has none, so it keeps the tokenizer default.
    templated = bool(getattr(tokenizer, "chat_template", None))

    # Pre-bind so the cleanup below is valid even if loading fails part-way.
    model = inputs = output = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            add_special_tokens=not templated,
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                # Explicit pad id avoids the per-call "Setting pad_token_id" warning.
                pad_token_id=tokenizer.pad_token_id,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()
    finally:
        # Drop every reference to device tensors first, then collect garbage,
        # and only then ask the CUDA allocator to release its cached blocks;
        # otherwise the next checkpoint may not fit in memory.
        del model, inputs, output
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


In [None]:
# Build the conversation: the user turn is always present, the system turn
# is placed in front of it only when a system prompt is configured.
messages = [{"role": "user", "content": prompt}]
if system_prompt:
    messages.insert(0, {"role": "system", "content": system_prompt})

# Run the same conversation through every checkpoint, one at a time.
rows = []
for model_id in model_paths:
    print("Running", model_id)
    response = generate_response(model_id, messages)
    rows.append(dict(model=model_id, response=response))

# Last expression of the cell: one row per checkpoint.
pd.DataFrame(rows)
