# Compare Base vs Trained Models

This notebook loads helper functions from `scripts/compare_models.py` and lets you
interactively configure models, prompts, and evaluation settings.


In [None]:
# Imports and helpers
import json
from types import SimpleNamespace
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from scripts.compare_models import (
    CompareConfig,
    load_model_and_tokenizer,
    generate_text,
    compute_logprobs,
)

# Configure here
# All settings for the comparison live in this single object; edit in place.
cfg = CompareConfig(
    # --- Checkpoints to compare ---
    base_model="/path/to/original",
    trained_model="/path/to/trained",
    # --- Prompt source ---
    # None selects the built-in prompt list below;
    # or a path to .jsonl/.json with {"prompt": ...}
    prompts_path=None,
    # --- Runtime placement / batching (interpreted by the helper module) ---
    device="auto",
    dtype="float16",
    batch_size=8,
    # --- Decoding settings (forwarded to the helpers via cfg) ---
    max_new_tokens=128,
    temperature=0.0,
    top_p=1.0,
    top_k=0,
    # None leaves each tokenizer's own EOS token untouched; any other value
    # is assigned to both tokenizers after loading.
    eos_token=None,
)

# Load prompts
# With no prompts_path, fall back to a small built-in set. Otherwise read a
# JSON Lines file (.jsonl, one JSON value per line) or, for any other suffix,
# a JSON array. In both formats an object entry contributes its "prompt"
# field and any non-object entry is used as-is.
if cfg.prompts_path is None:
    prompts = [
        "Write a short poem about the sea.",
        "What is the capital of France?",
        "Explain the concept of reinforcement learning in one paragraph.",
        "Translate to Spanish: 'Good morning, how are you?'",
    ]
else:
    p = Path(cfg.prompts_path)
    # The context manager closes the handle even if parsing fails. JSON text
    # is UTF-8, so do not rely on the platform's default encoding.
    with p.open(encoding="utf-8") as fh:
        if p.suffix == ".jsonl":
            # Skip blank lines (e.g. a trailing newline at end of file)
            # instead of failing with JSONDecodeError.
            items = [json.loads(line) for line in fh if line.strip()]
        else:  # .json
            items = json.load(fh)
    prompts = [item["prompt"] if isinstance(item, dict) else item for item in items]

# Load models
# Both checkpoints are loaded with the same device and dtype settings so the
# comparison differs only in the weights/tokenizer being loaded.
base_model, base_tok = load_model_and_tokenizer(
    cfg.base_model,
    cfg.device,
    cfg.dtype,
)
trained_model, trained_tok = load_model_and_tokenizer(
    cfg.trained_model,
    cfg.device,
    cfg.dtype,
)

# Optional: enforce same eos_token
# When an override is configured, apply it to both tokenizers.
eos_override = cfg.eos_token
if eos_override is not None:
    for tok in (base_tok, trained_tok):
        tok.eos_token = eos_override

# Generate
# Run the same prompt list through each model with the shared cfg. The results
# are indexed per prompt further down (base_outputs[i] / trained_outputs[i]).
# NOTE(review): base runs before trained; keep this order if the helper
# samples with a shared RNG -- confirm in scripts/compare_models.py.
base_outputs = generate_text(base_model, base_tok, prompts, cfg)
trained_outputs = generate_text(trained_model, trained_tok, prompts, cfg)

# Compute log-probs of generated responses under each model
# Each model scores its OWN generations (base model on base_outputs, trained
# model on trained_outputs); no cross-model scoring happens here.
base_logps = compute_logprobs(base_model, base_tok, prompts, base_outputs, cfg)
trained_logps = compute_logprobs(trained_model, trained_tok, prompts, trained_outputs, cfg)

# Aggregate per-sample metrics
def _mean_nonzero_logp(token_logps):
    """Return sum(token_logps) divided by its count of non-zero entries.

    The divisor is clamped to at least 1 so an all-zero input yields 0.0
    rather than a division by zero.
    NOTE(review): zeros are presumably padded/masked positions -- confirm
    against compute_logprobs.
    """
    counted = (token_logps != 0).sum().clamp(min=1).item()
    total = token_logps.sum().item()
    return float(total / counted)


rows = []
for i, prompt in enumerate(prompts):
    # One record per prompt: both generations side by side with each model's
    # mean log-prob for its own generation.
    base_mean = _mean_nonzero_logp(base_logps[i])
    trained_mean = _mean_nonzero_logp(trained_logps[i])
    record = {
        "prompt": prompt,
        "base_text": base_outputs[i],
        "trained_text": trained_outputs[i],
        "base_mean_logp": base_mean,
        "trained_mean_logp": trained_mean,
    }
    rows.append(record)

rows[:2]  # preview
