In [1]:
"""
eval_compare_mistral.py

Compare base Mistral vs LoRA-tuned Mistral on a small eval set and
produce plots for a presentation.

Requires:
  pip install --upgrade torch transformers accelerate bitsandbytes datasets matplotlib peft
"""

# =========================================================
# 1. CONFIG & IMPORTS
# =========================================================
import os
import gc
import random
from typing import List

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
import matplotlib.pyplot as plt

# ----------- EDIT THESE PATHS / SETTINGS AS NEEDED ------------------

# Base model you fine-tuned
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Folder with your exported LoRA adapter (the one we just created)
ADAPTER_DIR = "/blue/cis6930/cj.bowers/llama3_project/outputs/mistral_ai_tutor_lora"

# Eval set: JSONL with the *same format as your training data*:
#   {"text": "###System: ... ###Human: ... ###Assistant: ideal answer ..."}
EVAL_JSONL = "/blue/cis6930/cj.bowers/llama3_project/data/ai_tutor_merged.jsonl"

# How many examples to evaluate on (random subset)
MAX_EVAL_EXAMPLES = 400  # set to 300–500 as you like

# Where to save plots
PLOTS_DIR = "/blue/cis6930/cj.bowers/llama3_project/outputs/plots"

# Generation settings
MAX_NEW_TOKENS = 160
TEMPERATURE = 0.7
TOP_P = 0.9

# Marker used in your data to separate prompt from answer
ASSISTANT_MARK = "###Assistant:"

os.makedirs(PLOTS_DIR, exist_ok=True)
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# =========================================================
# 2. HELPER: LOAD EVAL DATA & SPLIT PROMPT / REFERENCE
# =========================================================

def load_eval_examples(path: str):
    """
    Load JSONL with field 'text' and split into (prompt, reference_answer)
    using ASSISTANT_MARK, same as training format.
    """
    ds = load_dataset("json", data_files=path, split="train")
    prompts = []
    refs = []

    for row in ds:
        text = row["text"]
        if ASSISTANT_MARK in text:
            before, after = text.split(ASSISTANT_MARK, 1)
            prompt = before.strip()
            ref = after.strip()
        else:
            # fallback: whole text is prompt, reference empty
            prompt = text.strip()
            ref = ""
        prompts.append(prompt)
        refs.append(ref)

    print(f"Loaded {len(prompts)} eval examples from {path}")
    return prompts, refs


# =========================================================
# 3. HELPER: SIMPLE TOKEN-LEVEL F1 SCORE
# =========================================================

def token_f1(pred: str, ref: str) -> float:
    """
    Very simple F1 on whitespace tokens; not perfect but fine for comparison.
    """
    pred_tokens = pred.lower().split()
    ref_tokens = ref.lower().split()
    if not ref_tokens:
        return 0.0
    pred_set = set(pred_tokens)
    ref_set = set(ref_tokens)
    inter = pred_set & ref_set
    if not inter:
        return 0.0
    precision = len(inter) / len(pred_set) if pred_set else 0.0
    recall = len(inter) / len(ref_set) if ref_set else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# =========================================================
# 4. MODEL LOADING HELPERS (4-BIT TO FIT ON L4)
# =========================================================

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

def cleanup():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def load_base_model():
    print("\n[LOAD] Base model (4-bit, no LoRA)…")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map={"": 0},   # all modules on GPU 0 in 4-bit
    )
    model.eval()
    return model

def load_tuned_model():
    print("\n[LOAD] Tuned model (4-bit + LoRA)…")
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map={"": 0},
    )
    model = PeftModel.from_pretrained(base, ADAPTER_DIR)
    model.eval()
    return model


# =========================================================
# 5. GENERATION LOOP
# =========================================================

def generate_answers(model, tokenizer, prompts: List[str]) -> List[str]:
    answers = []
    for i, prompt_text in enumerate(prompts, start=1):
        if i % 10 == 0 or i == 1:
            print(f"  Generating {i}/{len(prompts)}…")

        # prompt_text already contains '###Human:' etc from your data
        full_prompt = prompt_text + "\n" + ASSISTANT_MARK

        inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=1024,
        ).to(device)

        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        text = tokenizer.decode(gen[0], skip_special_tokens=True)
        # extract part after ASSISTANT_MARK as model's answer
        if ASSISTANT_MARK in text:
            _, answer = text.split(ASSISTANT_MARK, 1)
        else:
            answer = text
        answers.append(answer.strip())
    return answers


# =========================================================
# 6. MAIN EVALUATION LOGIC
# =========================================================

def main():
    # ----- load eval data -----
    prompts, refs = load_eval_examples(EVAL_JSONL)

    # ---- random subset selection ----
    combined = list(zip(prompts, refs))
    random.seed(42)  # reproducible
    random.shuffle(combined)
    n = min(MAX_EVAL_EXAMPLES, len(combined))
    combined = combined[:n]
    prompts, refs = zip(*combined)
    prompts, refs = list(prompts), list(refs)
    print(f"Using {n} random examples for evaluation.\n")

    # tokenizer from adapter dir (includes your special tokens)
    tokenizer = AutoTokenizer.from_pretrained(
        ADAPTER_DIR,
        use_fast=True,
        padding_side="left",
    )

    # ----- BASE MODEL -----
    base_model = load_base_model()
    base_answers = generate_answers(base_model, tokenizer, prompts)
    base_scores = [token_f1(p, r) for p, r in zip(base_answers, refs)]
    base_mean = float(sum(base_scores) / len(base_scores))
    print(f"\n[RESULT] Base model mean F1: {base_mean:.3f}")

    del base_model
    cleanup()

    # ----- TUNED MODEL -----
    tuned_model = load_tuned_model()
    tuned_answers = generate_answers(tuned_model, tokenizer, prompts)
    tuned_scores = [token_f1(p, r) for p, r in zip(tuned_answers, refs)]
    tuned_mean = float(sum(tuned_scores) / len(tuned_scores))
    print(f"[RESULT] Tuned model mean F1: {tuned_mean:.3f}")

    del tuned_model
    cleanup()

    # =====================================================
    # 7. PLOTS FOR PRESENTATION
    # =====================================================
    import numpy as np

    improvements = [t - b for b, t in zip(tuned_scores, base_scores)]

    # --- 1) Bar chart of average scores ---
    plt.figure()
    plt.bar(["Base", "Tuned"], [base_mean, tuned_mean])
    plt.ylabel("Mean token F1 vs reference")
    plt.title(f"Base vs LoRA-tuned Mistral (mean F1 on {n} examples)")
    plt.ylim(0, 1.0)
    plt.savefig(os.path.join(PLOTS_DIR, "mean_f1_base_vs_tuned.png"), bbox_inches="tight")
    plt.close()

    # --- 2) Histogram of per-example improvement ---
    plt.figure()
    plt.hist(improvements, bins=20)
    plt.xlabel("F1(Tuned) - F1(Base)")
    plt.ylabel("# examples")
    plt.title("Per-example improvement (positive = tuned is better)")
    plt.axvline(0.0, color="black", linestyle="--", linewidth=1)
    plt.savefig(os.path.join(PLOTS_DIR, "improvement_histogram.png"), bbox_inches="tight")
    plt.close()

    # --- 3) Boxplot of F1 distributions ---
    plt.figure()
    plt.boxplot(
        [base_scores, tuned_scores],
        labels=["Base", "Tuned"],
        showmeans=True,
    )
    plt.ylabel("Token F1")
    plt.title("Distribution of F1 scores (Base vs Tuned)")
    plt.ylim(0, 1.0)
    plt.savefig(os.path.join(PLOTS_DIR, "f1_boxplot_base_vs_tuned.png"), bbox_inches="tight")
    plt.close()

    # --- 4) Scatter: base vs tuned F1 ---
    plt.figure()
    plt.scatter(base_scores, tuned_scores, alpha=0.5, s=10)
    plt.xlabel("Base F1")
    plt.ylabel("Tuned F1")
    plt.title("Per-example F1: Base vs Tuned")
    plt.plot([0, 1], [0, 1], linestyle="--")  # diagonal line
    plt.xlim(0, 1.0)
    plt.ylim(0, 1.0)
    plt.savefig(os.path.join(PLOTS_DIR, "f1_scatter_base_vs_tuned.png"), bbox_inches="tight")
    plt.close()

    print(f"\nPlots saved to: {PLOTS_DIR}")
    print("  - mean_f1_base_vs_tuned.png")
    print("  - improvement_histogram.png")
    print("  - f1_boxplot_base_vs_tuned.png")
    print("  - f1_scatter_base_vs_tuned.png")


if __name__ == "__main__":
    main()


Using device: cuda
Loaded 4410183 eval examples from /blue/cis6930/cj.bowers/llama3_project/data/ai_tutor_merged.jsonl
Using 400 random examples for evaluation.


[LOAD] Base model (4-bit, no LoRA)…


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  Generating 1/400…
  Generating 10/400…
  Generating 20/400…
  Generating 30/400…
  Generating 40/400…
  Generating 50/400…
  Generating 60/400…
  Generating 70/400…
  Generating 80/400…
  Generating 90/400…
  Generating 100/400…
  Generating 110/400…
  Generating 120/400…
  Generating 130/400…
  Generating 140/400…
  Generating 150/400…
  Generating 160/400…
  Generating 170/400…
  Generating 180/400…
  Generating 190/400…
  Generating 200/400…
  Generating 210/400…
  Generating 220/400…
  Generating 230/400…
  Generating 240/400…
  Generating 250/400…
  Generating 260/400…
  Generating 270/400…
  Generating 280/400…
  Generating 290/400…
  Generating 300/400…
  Generating 310/400…
  Generating 320/400…
  Generating 330/400…
  Generating 340/400…
  Generating 350/400…
  Generating 360/400…
  Generating 370/400…
  Generating 380/400…
  Generating 390/400…
  Generating 400/400…

[RESULT] Base model mean F1: 0.294

[LOAD] Tuned model (4-bit + LoRA)…


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  Generating 1/400…
  Generating 10/400…
  Generating 20/400…
  Generating 30/400…
  Generating 40/400…
  Generating 50/400…
  Generating 60/400…
  Generating 70/400…
  Generating 80/400…
  Generating 90/400…
  Generating 100/400…
  Generating 110/400…
  Generating 120/400…
  Generating 130/400…
  Generating 140/400…
  Generating 150/400…
  Generating 160/400…
  Generating 170/400…
  Generating 180/400…
  Generating 190/400…
  Generating 200/400…
  Generating 210/400…
  Generating 220/400…
  Generating 230/400…
  Generating 240/400…
  Generating 250/400…
  Generating 260/400…
  Generating 270/400…
  Generating 280/400…
  Generating 290/400…
  Generating 300/400…
  Generating 310/400…
  Generating 320/400…
  Generating 330/400…
  Generating 340/400…
  Generating 350/400…
  Generating 360/400…
  Generating 370/400…
  Generating 380/400…
  Generating 390/400…
  Generating 400/400…
[RESULT] Tuned model mean F1: 0.329

Plots saved to: /blue/cis6930/cj.bowers/llama3_project/outputs/plots
  -

  plt.boxplot(
