In [None]:
# ===============================
# 🔧 Setup: Install Packages
# ===============================
# Installs the third-party packages used by the rest of the notebook in one
# quiet (-q) pip call. datasets and peft are pinned to exact versions,
# transformers / accelerate / bitsandbytes only have version bounds, and
# scikit-learn / openpyxl / pandas are left unpinned.
# NOTE(review): openpyxl is presumably what pandas uses for the .xlsx reads and
# writes in later cells — TODO confirm. Unpinned packages can change between
# sessions; pin them if exact reproducibility of the environment matters.
!pip install -q \
  "transformers>=4.41,<5" \
  "datasets==2.19.1" \
  "peft==0.10.0" \
  "accelerate>=0.34.2" \
  "bitsandbytes>=0.43.3" \
  "scikit-learn" \
  "openpyxl" \
  "pandas"

In [None]:
import torch, sys, subprocess
# Install a Triton release chosen from the installed torch version.
# `mm` is the "major.minor" prefix of torch.__version__ (e.g. "2.4").
mm = ".".join(torch.__version__.split(".")[:2])
# Hard-coded torch -> triton lookup; torch versions missing from the table fall
# back to "3.2.0" via .get() below.
# NOTE(review): confirm every pair against the triton requirement declared by
# that torch release — the "2.5" entry in particular looks one release ahead
# (TODO confirm). A mismatched pin would replace the triton build torch expects.
triton_by_torch = {"2.5":"3.2.0","2.4":"3.0.0","2.3":"2.3.1","2.2":"2.2.0"}
target = triton_by_torch.get(mm, "3.2.0")
print(f"Torch {torch.__version__} → Installing Triton {target}")
# Use the kernel's own interpreter so pip installs into the running environment;
# check_call raises CalledProcessError if pip exits with a non-zero status.
subprocess.check_call([sys.executable, "-m", "pip", "install", f"triton=={target}"])

In [None]:
# ===============================
# Import packages & login
# ===============================
# Mount Google Drive at /content/drive; later cells build their data and model
# paths underneath it.
from google.colab import drive
drive.mount('/content/drive')

# ===============================
# Determinism preamble
# ===============================
# NOTE(review): this is NOT the first cell and torch has already been imported
# by the Triton-install cell above. Presumably the variable only has to be set
# before the first CUDA/cuBLAS call rather than before `import torch` — TODO
# confirm; if not, move these two assignments into the very first cell.
import os

# Required by PyTorch for cuBLAS determinism on CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   # or ":16:8" if you prefer

# Optional: makes hashing stable across processes
# NOTE(review): presumably this only affects processes started after this point;
# the already-running kernel fixed its hash seed at interpreter start-up.
os.environ["PYTHONHASHSEED"] = "0"

import random, torch, pandas as pd, numpy as np
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
    TrainingArguments, Trainer, set_seed
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from huggingface_hub import login

# --------------- Hugging Face token ---------------
# Never store a real token in the notebook. Reuse HF_TOKEN when the runtime
# already provides one (instead of overwriting it with a placeholder);
# otherwise prompt for it without echoing the value.
_hf_token = os.environ.get("HF_TOKEN", "").strip()
if not _hf_token or _hf_token == "YOUR_TOKEN_HERE":
    from getpass import getpass
    _hf_token = getpass("Hugging Face access token: ").strip()
    # Keep the variable populated for code that reads it from the environment.
    os.environ["HF_TOKEN"] = _hf_token
login(_hf_token)

# --------------- Reproducibility ---------------
# transformers.set_seed with the same seed value the generation cell uses.
set_seed(42)

In [None]:
# =========================================================
# Generate 80× responses per cue on validation split
#      • fine-tuned model
#      • base model
#      + save raw outputs, conversation history, and prompts
#      + batched via num_return_sequences with OOM fallback
# =========================================================

import os, gc, re, time, math, random, hashlib
import numpy as np
import torch
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ---------- Reproducibility & deterministic sampling --------------
# One seed for Python's `random`, NumPy and torch; it is also the base of the
# per-cue seeds derived further down in this cell.
GLOBAL_SEED = 42
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
# NOTE(review): torch is already imported at the top of this cell, so this
# repeated import is a no-op.
import torch

torch.manual_seed(GLOBAL_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(GLOBAL_SEED)

# Turn off TF32 to avoid numerics shifting between runs/GPUs
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Deterministic/cuDNN settings
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Enforce deterministic algorithms. The CUBLAS_WORKSPACE_CONFIG variable this
# relies on is set in the import/login cell, not in this one, so that cell has
# to run first.
# NOTE(review): strict mode makes PyTorch raise for operations without a
# deterministic implementation; if sampling during generate() trips this on the
# target torch/CUDA build, `warn_only=True` is the usual fallback — TODO confirm.
torch.use_deterministic_algorithms(True)
# --------------------------------------------------------------------
# change if you use another path
BASE_PATH = r"/content/drive/My Drive/associations-ANLP"

# Cue spreadsheets. generate_dataset() below reads VAL_XLSX_PATH, so it has to
# be defined in this cell for the notebook to run top-to-bottom on a fresh
# kernel. TEST_XLSX_PATH is kept for code that wants the test split.
VAL_XLSX_PATH   = os.path.join(BASE_PATH, r"data/final_processed_SWOW_data/val.xlsx")
TEST_XLSX_PATH  = os.path.join(BASE_PATH, r"data/final_processed_SWOW_data/test.xlsx")

# IMPORTANT: if you re-finetuned, update FINETUNED_MODEL to your current run's merged or adapter path
FINETUNED_MODEL = os.path.join(BASE_PATH, r"full_llama3_8b_system_prompt_lora_SFT_SWOW_tgt_qkvo_tr7194c_val899c_r16_a32_do0p1_lr0.0001_bs16_ga4/merged_model")
BASE_MODEL  = "meta-llama/Meta-Llama-3-8B-Instruct"

# Both models write their prediction workbooks into the same comparison folder;
# the two names are kept separate so either destination can be changed alone.
SAVE_DIR_FT   = os.path.join(BASE_PATH, r"data/fine_tuned_and_base_models_comparisons/free_associations_task")
os.makedirs(SAVE_DIR_FT, exist_ok=True)

SAVE_DIR_BASE = os.path.join(BASE_PATH, r"data/fine_tuned_and_base_models_comparisons/free_associations_task")
os.makedirs(SAVE_DIR_BASE, exist_ok=True)
# --------------------------------------------------------------------

# ---------- Precision / quant (use the SAME for both models) --------
USE_4BIT = False  # set True if you prefer 4-bit generation for both

def _bf16_supported():
    return torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)()

if USE_4BIT:
    load_kwargs = dict(
        device_map="auto" if torch.cuda.is_available() else None,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if _bf16_supported() else torch.float16,
        ),
        use_safetensors=True,
    )
else:
    if torch.cuda.is_available():
        # A100 → bf16; T4 → fp16
        _dtype = torch.bfloat16 if _bf16_supported() else torch.float16
        load_kwargs = dict(device_map="auto", torch_dtype=_dtype, use_safetensors=True)
    else:
        load_kwargs = dict(torch_dtype=torch.float32, use_safetensors=True)

# ---------- Sampling knobs (unchanged except batching) ---------------
# All five values are read by _generate_many_for_cue below.
NUM_SAMPLES_PER_CUE = 80   # total completions collected for every cue
TRY_AT_ONCE         = 80   # first attempt per call; auto-shrinks on OOM
MAX_NEW_TOKENS      = 50   # passed to generate() as max_new_tokens
TEMPERATURE         = 0.7  # passed to generate() as temperature (do_sample=True)
TOP_P               = 0.9  # passed to generate() as top_p

# System message placed in front of every cue by build_prompt_text and stored
# alongside each generated record. The backslash right after the opening quotes
# keeps the string from starting with a newline.
SYSTEM_PROMPT = """\
Task:
 - You will be provided with an input word: write the first 3 words you associate to it separated by a comma.
 - No additional output text is allowed.

Constraints:
 - no carriage return characters are allowed in the answers.
 - answers should be as short as possible.

Example:
 Input: sea
 Output: water, beach, sun"""

def build_prompt_text(tokenizer, cue: str) -> str:
    """Return the chat-template prompt string for one cue.

    The conversation holds SYSTEM_PROMPT as the system turn and the stripped
    cue as the user turn. The template is rendered as text (not token ids)
    with the generation prompt appended, so generation starts in the
    assistant turn.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": str(cue).strip()}
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

def build_history(cue: str, assistant_text: str) -> str:
    """Return a human-readable transcript of one exchange.

    Three tagged sections (system, user, assistant) are emitted in that order,
    each as ``<tag>``, the stripped text, ``</tag>`` on separate lines. A
    missing or empty assistant text yields an empty assistant section.
    """
    sections = (
        ("system", SYSTEM_PROMPT.strip()),
        ("user", str(cue).strip()),
        ("assistant", (assistant_text or "").strip()),
    )
    return "\n".join(f"<{tag}>\n{body}\n</{tag}>" for tag, body in sections)

def extract_responses(text: str, n=3):
    """
    Split one generated answer into exactly ``n`` response slots.

    Processing steps:
      1. lowercase the text (``None``/empty is treated as "");
      2. treat line breaks as separators, exactly like commas;
      3. turn any other whitespace into a plain space;
      4. drop every character that is not a word character, comma, hyphen
         or space;
      5. split on commas, strip each piece and discard empty pieces.

    Args:
        text: raw decoded model output.
        n: number of slots to return (default 3).

    Returns:
        A list of length ``n``: the first ``n`` extracted responses, padded
        with empty strings when fewer were found.
    """
    lowered = (text or "").lower()
    # A line break separates answers just like a comma does; simply deleting it
    # would glue the neighbouring words together ("beach\nsun" -> "beachsun").
    lowered = re.sub(r"[\r\n]+", ",", lowered)
    # Same reasoning for tabs and other whitespace: keep the word boundary.
    lowered = re.sub(r"\s", " ", lowered)
    # Keep word characters, spaces, hyphens, and commas only.
    clean = re.sub(r"[^\w,\- ]", "", lowered)
    # robust split for ",", " ,", ", " etc.
    words = [w.strip() for w in re.split(r"\s*,\s*", clean) if w.strip()]
    # Pad with as many blanks as slots requested so the result always has n items.
    return (words + [""] * n)[:n]

def _per_cue_seed(cue_index: int) -> int:
    # Stable per-cue seed independent of Python's hash randomization
    return (GLOBAL_SEED * 1_000_003 + cue_index) & 0x7FFFFFFF

def _generate_many_for_cue(model, tok, cue: str, cue_index: int) -> list[dict]:
    """
    Sample NUM_SAMPLES_PER_CUE completions for one cue.

    Completions are drawn in chunks through ``num_return_sequences``. The
    chunk size starts at TRY_AT_ONCE and is halved each time CUDA runs out of
    memory. Before every chunk the torch RNGs are re-seeded with
    ``_per_cue_seed(cue_index) + <samples already produced for this cue>``.

    Reproducibility: two runs produce the same samples as long as they go
    through the same sequence of chunk sizes. The samples are NOT invariant
    to chunking — a run that had to back off after an OOM draws from
    differently seeded streams than one that did not.
    (No `generator=` kwarg is used, so it's compatible with more transformers builds.)

    Args:
        model: loaded causal LM exposing ``generate`` and ``device``.
        tok: matching tokenizer (chat template, ``eos_token_id``, ``decode``).
        cue: the cue word.
        cue_index: position of the cue in the cue list; feeds the seed.

    Returns:
        One dict per sample with keys cue, R1, R2, R3, raw_output, history,
        system_prompt and user_prompt.

    Raises:
        torch.cuda.OutOfMemoryError: if even a single-sample chunk does not fit.
    """
    prompt_text = build_prompt_text(tok, cue)
    prompt_ids  = tok(prompt_text, return_tensors="pt").to(model.device)
    # Every returned sequence starts with the same prompt; its length is
    # constant, so compute it once instead of once per chunk.
    prompt_len  = prompt_ids["input_ids"].shape[1]

    recs = []
    remaining   = NUM_SAMPLES_PER_CUE
    try_at_once = TRY_AT_ONCE
    produced    = 0

    # Stable per-cue base seed (single source of truth for the formula).
    base_seed = _per_cue_seed(cue_index)

    while remaining > 0:
        cur_n = min(try_at_once, remaining)

        # Reseed per chunk from the number of samples already produced, so a
        # retry of the same chunk after an OOM starts from the same RNG state.
        seed_this_batch = base_seed + produced
        torch.manual_seed(seed_this_batch)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed_this_batch)

        try:
            with torch.no_grad():
                out = model.generate(
                    **prompt_ids,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=True, temperature=TEMPERATURE, top_p=TOP_P,
                    pad_token_id=tok.eos_token_id,
                    eos_token_id=tok.eos_token_id,
                    num_return_sequences=cur_n,
                )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            if cur_n == 1:
                # Nothing smaller left to try.
                raise
            # Halve the size that actually failed. Halving try_at_once instead
            # can leave it >= remaining and retry the very same chunk size.
            try_at_once = cur_n // 2
            continue

        for i in range(out.size(0)):
            # Strip the prompt tokens; keep only what the model generated.
            gen_only = out[i, prompt_len:]
            raw = tok.decode(gen_only, skip_special_tokens=True)
            r1, r2, r3 = extract_responses(raw)
            recs.append(dict(
                cue=cue,
                R1=r1, R2=r2, R3=r3,
                raw_output=raw,
                history=build_history(cue, raw),
                system_prompt=SYSTEM_PROMPT.strip(),
                user_prompt=str(cue).strip(),
            ))

        produced  += cur_n
        remaining -= cur_n

    return recs

def generate_dataset(model_id_or_path: str, tag: str):
    """Load one model, sample responses for every cue, and release the model.

    Cues are the unique, stripped, non-null values of the "cue" column of the
    workbook at VAL_XLSX_PATH; each cue's position in that array is passed on
    as its seed index.

    Args:
        model_id_or_path: hub id or local directory of the model/tokenizer.
        tag: label used in the progress output.

    Returns:
        list[dict] with keys cue, R1, R2, R3, raw_output, history,
        system_prompt and user_prompt — NUM_SAMPLES_PER_CUE entries per cue.
    """
    print(f"\n🔹 Loading {tag} model …")
    tok = AutoTokenizer.from_pretrained(model_id_or_path, use_fast=True)
    tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_id_or_path, **load_kwargs)
    model.eval()

    # ---------- read validation cues from Excel ----------
    cue_column = pd.read_excel(VAL_XLSX_PATH)["cue"]
    cues = cue_column.dropna().astype(str).str.strip().unique()

    recs = []
    progress = tqdm(cues, desc=f"{tag} {NUM_SAMPLES_PER_CUE}× per cue")
    for idx, cue in enumerate(progress):
        recs += _generate_many_for_cue(model, tok, cue, idx)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # light hygiene between cues

    # -------- cleanup to free VRAM --------
    del model, tok
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return recs

# ---------- Fine-tuned model -----------------------------------
# Generate with the fine-tuned checkpoint first, write one row per sample to an
# .xlsx workbook in SAVE_DIR_FT, and report the wall-clock time.
t0 = time.time()
ft_rows = generate_dataset(FINETUNED_MODEL, "Finetuned")
ft_tag  = os.path.basename(os.path.normpath(FINETUNED_MODEL))  # folder name like 'merged_model' or run dir
ft_path = os.path.join(SAVE_DIR_FT, f"full_val_ft_{ft_tag}_predictions.xlsx")
pd.DataFrame(ft_rows).to_excel(ft_path, index=False)  # ← Excel
print(f"Fine-tuned predictions saved → {ft_path} ({len(ft_rows):,} rows)")
print(f"{time.time()-t0:.1f}s elapsed")

# ---------- Base model -----------------------------------------
# Same procedure for the base checkpoint; generate_dataset frees the previous
# model before returning, so the two models are loaded one after the other.
t1 = time.time()
base_rows = generate_dataset(BASE_MODEL, "Base")
base_path = os.path.join(SAVE_DIR_BASE, "llama3_8b_full_val_base_predictions.xlsx")
pd.DataFrame(base_rows).to_excel(base_path, index=False)  # ← Excel
print(f"Base-model predictions saved → {base_path} ({len(base_rows):,} rows)")
print(f"{time.time()-t1:.1f}s elapsed (second run)")


In [None]:
"""
Evaluate fine-tuned (FT) vs. base LLaMA-3 model predictions against human
word association data (SWOW val) on a cue-by-cue basis.

This script:
-----------
1. Loads three datasets:
    • Human associations (val.xlsx)
    • Fine-tuned model predictions (80 runs per cue, .xlsx)
    • Base model predictions (80 runs per cue, .xlsx)

2. Normalizes all responses (lowercase, strip, keep only letters).

3. Converts triplets (R1, R2, R3) into a "long" format for counting.

4. Builds frequency distributions for each cue:
    • Human distribution: counts of each unique word from humans
    • FT model distribution: pooled counts over all runs
    • Base model distribution: pooled counts over all runs

5. For each cue, computes:
    • Top-K overlap (K=10): Precision@K, Recall@K, F1@K, Jaccard@K
    • Cosine similarity between full distributions
    • Jensen–Shannon divergence (JSD) between full distributions
    • Spearman rank correlation of counts over the union vocabulary (a word missing from one side counts as 0)
    • Shannon entropy of each distribution (diversity)
    • Jaccard over ALL unique responses (set overlap, no frequency)

6. Also computes row-level Hit@3:
    • For each model output triplet, check if ≥1 word is in human set.

7. Aggregates results:
    • Macro averages (mean per cue) with bootstrap 95% CIs for FT and Base
    • Δ (FT mean – Base mean) for each metric

8. Outputs:
    • Prints headline macro metrics for quick comparison (covers ALL paired metrics dynamically)
    • Saves full per-cue metrics table to XLSX for deeper analysis
    • Prints sample top-10 lists for first few cues for qualitative inspection

Purpose:
--------
To quantitatively measure how closely each model’s associations match
human associations, both in terms of exact word overlap and in distributional
similarity, enabling a clear comparison between a fine-tuned and a base model.
"""

import os, re, math, json, numpy as np, pandas as pd
from collections import Counter, defaultdict

# ---------- paths (Excel) ----------
# All three paths are built from the project root currently held in BASE_PATH:
# the human validation data and the two prediction workbooks.
# NOTE(review): the last assignment rebinds BASE_PATH itself to the base-model
# predictions *file*. From that line on BASE_PATH no longer names the project
# root, so code that later treats it as the root — or a second run of this
# cell — builds paths underneath the .xlsx file. Giving the predictions path
# its own name (and updating the read_excel call that consumes it) would make
# this cell idempotent.
HUM_PATH  = os.path.join(BASE_PATH, r"data/final_processed_SWOW_data/val.xlsx")
FT_PATH   = os.path.join(BASE_PATH, r"data/fine_tuned_and_base_models_comparisons/free_associations_task/full_val_ft_merged_model_predictions.xlsx")
BASE_PATH = os.path.join(BASE_PATH, r"data/fine_tuned_and_base_models_comparisons/free_associations_task/llama3_8b_full_val_base_predictions.xlsx")

# ---------- helpers ----------
def _norm(w):
    return re.sub(r"[^a-z]", "", str(w).strip().lower())

def melt_triplets(df):
    """Reshape a cue/R1/R2/R3 table into one normalized row per response.

    Column names are matched case-insensitively and lowercased. Every cell is
    normalized with ``_norm``; missing cells are skipped by the mapping so they
    stay missing and are removed by ``dropna()`` instead of being counted as a
    response. Rows whose response or cue normalizes to the empty string are
    dropped as well.

    Args:
        df: DataFrame containing (at least) cue, R1, R2, R3 columns in any case.

    Returns:
        Long DataFrame with columns ``cue``, ``variable`` (r1/r2/r3) and ``resp``.
    """
    cols = [c for c in df.columns if c.lower() in {"cue","r1","r2","r3"}]
    df = df[cols].copy()
    df.columns = [c.lower() for c in df.columns]
    for c in ["cue","r1","r2","r3"]:
        # na_action="ignore": missing cells are never handed to the normalizer,
        # so they remain NaN and the dropna() below really removes them.
        df[c] = df[c].map(_norm, na_action="ignore")
    long = df.melt(id_vars=["cue"], value_vars=["r1","r2","r3"], value_name="resp").dropna()
    # Drop blanks on both sides: an empty response is no response, and an empty
    # cue would pool unrelated rows under one nameless key.
    long = long[(long["resp"] != "") & (long["cue"] != "")]
    return long

def counts_by_cue(long_df):
    """Tally responses per cue.

    Reads the ``cue`` and ``resp`` columns in parallel and returns a
    ``defaultdict`` mapping each cue to a ``Counter`` of its responses.
    """
    tallies = defaultdict(Counter)
    for cue_word, response in zip(long_df["cue"], long_df["resp"]):
        tallies[cue_word].update((response,))
    return tallies

def topk_overlap(h_cnt, m_cnt, k=10):
    """Overlap statistics between the k most common human and model words.

    Args:
        h_cnt: Counter of human responses.
        m_cnt: Counter of model responses.
        k: size of each top list.

    Returns:
        (precision, recall, f1, jaccard). Precision divides the shared count
        by the model list length, recall by the human list length, Jaccard by
        the size of the union; every denominator is floored at 1. F1 is 0
        when precision + recall is 0.
    """
    human_top = [word for word, _count in h_cnt.most_common(k)]
    model_top = [word for word, _count in m_cnt.most_common(k)]
    human_set, model_set = set(human_top), set(model_top)
    shared = len(human_set & model_set)

    precision = shared / max(1, len(model_top))
    recall = shared / max(1, len(human_top))
    jaccard = shared / max(1, len(human_set | model_set))
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1, jaccard

def dist_vectors(h_cnt, m_cnt, vocab=None, smooth=1e-9):
    """Compare two response distributions over a common vocabulary.

    Args:
        h_cnt: Counter of human responses.
        m_cnt: Counter of model responses.
        vocab: ordered word list to align on; defaults to the sorted union of
            both counters' keys.
        smooth: lower clip bound applied to probabilities before taking logs.

    Returns:
        (cosine, jsd, human_entropy, model_entropy, vocab, human_probs,
        model_probs). JSD and the entropies use log base 2. A counter whose
        counts sum to zero on the vocabulary is replaced by a uniform vector.
    """
    if vocab is None:
        vocab = sorted(set(h_cnt) | set(m_cnt))

    def _as_probs(cnt):
        raw = np.array([cnt[w] for w in vocab], dtype=float)
        if raw.sum() == 0:
            raw = np.ones_like(raw)
        return raw / raw.sum()

    def _clipped(p):
        return np.clip(p, smooth, 1.0)

    def _kl_bits(p, q):
        p, q = _clipped(p), _clipped(q)
        return float(np.sum(p * np.log2(p / q)))

    def _entropy_bits(p):
        p = _clipped(p)
        return float(-np.sum(p * np.log2(p)))

    hp = _as_probs(h_cnt)
    mp = _as_probs(m_cnt)

    # cosine
    cos = float(np.dot(hp, mp) / (np.linalg.norm(hp) * np.linalg.norm(mp)))

    # JS divergence (base 2) against the equal-weight mixture
    mixture = 0.5 * (hp + mp)
    jsd = 0.5 * _kl_bits(hp, mixture) + 0.5 * _kl_bits(mp, mixture)

    return cos, jsd, _entropy_bits(hp), _entropy_bits(mp), vocab, hp, mp

def spearman_on_union(h_cnt, m_cnt):
    """Rank correlation of the two count profiles over the union vocabulary.

    Counts are looked up for every word in the sorted union of both counters,
    converted to average ranks, and correlated with Pearson's coefficient.
    Returns NaN when the union holds fewer than two words.
    """
    union = sorted(set(h_cnt).union(m_cnt))
    if len(union) < 2:
        return np.nan
    human_ranks, model_ranks = (
        pd.Series({w: cnt[w] for w in union}).rank(method="average").values
        for cnt in (h_cnt, m_cnt)
    )
    return float(np.corrcoef(human_ranks, model_ranks)[0, 1])

def hit_at_3_per_row(model_df, human_bycue):
    """Share of model rows with at least one response that humans also gave.

    Args:
        model_df: DataFrame with lowercase columns cue, r1, r2, r3 (one row
            per generated triplet).
        human_bycue: mapping cue -> Counter of human responses.

    Returns:
        Fraction of rows (0..1) whose normalized triplet shares a word with
        the human responses for the same cue. Rows whose cue has no human
        data count as misses. NaN when ``model_df`` has no rows.
    """
    columns = ["cue", "r1", "r2", "r3"]
    df = model_df[columns].copy()
    for c in columns:
        # Skip missing cells so an empty response can never turn into a token
        # that might accidentally match a human word.
        df[c] = df[c].map(_norm, na_action="ignore")

    hits = []
    for cue, r1, r2, r3 in zip(df["cue"], df["r1"], df["r2"], df["r3"]):
        # .get() instead of [] so that looking up an unseen cue does not insert
        # an empty entry into the caller's defaultdict.
        human_words = human_bycue.get(cue, ())
        triplet = {r for r in (r1, r2, r3) if isinstance(r, str) and r}
        hits.append(1 if any(word in human_words for word in triplet) else 0)
    return float(np.mean(hits)) if hits else float("nan")

def jaccard_all_unique(h_cnt, m_cnt):
    """Jaccard index between the full sets of unique human and model words.

    Frequencies are ignored; only key membership matters. Returns NaN when
    both counters are empty.
    """
    human_vocab, model_vocab = set(h_cnt.keys()), set(m_cnt.keys())
    combined = human_vocab | model_vocab
    if not combined:
        return np.nan
    return len(human_vocab & model_vocab) / max(1, len(combined))

def bootstrap_ci(values, n_boot=2000, alpha=0.05, rng=None):
    """Mean with a percentile-bootstrap confidence interval.

    Args:
        values: iterable of numbers; NaNs are ignored.
        n_boot: number of bootstrap resamples.
        alpha: two-sided error level (0.05 -> 95% interval).
        rng: optional ``numpy.random.Generator``. When omitted, a fresh
            generator seeded with 0 is created for this call, so the result
            depends only on the inputs and not on how many times the function
            was called before. (A generator built in the signature would be
            shared and advanced by every call.)

    Returns:
        (mean, lower, upper) as floats; three NaNs when no non-NaN value remains.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    vals = np.asarray(values, dtype=float)
    vals = vals[~np.isnan(vals)]
    n = len(vals)
    if n == 0: return (np.nan, np.nan, np.nan)
    # Resample n values with replacement n_boot times and keep each mean.
    boots = [np.mean(vals[rng.integers(0, n, n)]) for _ in range(n_boot)]
    lo, hi = np.quantile(boots, [alpha/2, 1 - alpha/2])
    return float(np.mean(vals)), float(lo), float(hi)

# ---------- load (Excel) ----------
# Read the human data and both prediction workbooks. At this point BASE_PATH
# holds the base-model predictions file (it is rebound in the paths section of
# this cell), not the project root.
hum  = pd.read_excel(HUM_PATH)
ft   = pd.read_excel(FT_PATH)
base = pd.read_excel(BASE_PATH)

# keep the 4 columns; normalize column names
ft   = ft[["cue","R1","R2","R3"]].rename(columns=str.lower)
base = base[["cue","R1","R2","R3"]].rename(columns=str.lower)

# long forms & counts
# One (cue, resp) row per response, then a Counter of responses for every cue.
hum_long  = melt_triplets(hum)
ft_long   = melt_triplets(ft)
base_long = melt_triplets(base)

hum_by  = counts_by_cue(hum_long)
ft_by   = counts_by_cue(ft_long)
base_by = counts_by_cue(base_long)

# Only cues present in all three sources are evaluated, in sorted order.
cues = sorted(set(hum_by) & set(ft_by) & set(base_by))

# ---------- per-cue metrics ----------
# Build one record per cue comparing the human distribution (h) with the
# fine-tuned (f) and base (b) model distributions; suffix _ft / _base marks
# which model a column belongs to.
rows = []
for cue in cues:
    h = hum_by[cue]
    f = ft_by[cue]
    b = base_by[cue]

    # Top-K overlap (K=10 by default)
    Pft, Rft, Fft, Jft = topk_overlap(h, f, k=10)
    Pba, Rba, Fba, Jba = topk_overlap(h, b, k=10)

    # Full-distribution similarities
    # The human entropy is taken from the (human, FT) call; the one returned by
    # the (human, base) call is discarded.
    cos_f, jsd_f, Hh, Hf, _, _, _ = dist_vectors(h, f)
    cos_b, jsd_b, _,  Hb, _, _, _ = dist_vectors(h, b)

    # Rank agreement
    spr_f = spearman_on_union(h, f)
    spr_b = spearman_on_union(h, b)

    # Jaccard on ALL unique responses (set overlap)
    Jall_ft = jaccard_all_unique(h, f)
    Jall_ba = jaccard_all_unique(h, b)

    rows.append({
        "cue": cue,
        # Top-K (K=10)
        "P@10_ft": Pft, "R@10_ft": Rft, "F1@10_ft": Fft, "Jaccard@10_ft": Jft,
        "P@10_base": Pba, "R@10_base": Rba, "F1@10_base": Fba, "Jaccard@10_base": Jba,
        # Full distributions
        "cosine_ft": cos_f, "cosine_base": cos_b,
        "JSD_ft": jsd_f, "JSD_base": jsd_b,
        "spearman_ft": spr_f, "spearman_base": spr_b,
        "H_human": Hh, "H_ft": Hf, "H_base": Hb,
        # Jaccard over ALL unique responses
        "Jaccard_all_ft": Jall_ft,
        "Jaccard_all_base": Jall_ba,
    })

# One row per evaluated cue; the summary code below pairs up the *_ft / *_base
# columns by name.
percue = pd.DataFrame(rows)

# row-level Hit@3
# Computed over every prediction row of each model, not per cue.
ft_hit3   = hit_at_3_per_row(ft,   hum_by)
base_hit3 = hit_at_3_per_row(base, hum_by)

# ---------- headline (macro) metrics with bootstrap CIs ----------
def macro(col_ft, col_ba):
    """Macro-average one paired metric over the per-cue table.

    Args:
        col_ft: name of the fine-tuned model's column in ``percue``.
        col_ba: name of the base model's column in ``percue``.

    Returns:
        Dict with the bootstrap mean and (lower, upper) interval of each
        column plus ``delta`` = fine-tuned mean minus base mean.
    """
    # Fine-tuned column is bootstrapped first, base column second.
    ft_mean, ft_lo, ft_hi = bootstrap_ci(percue[col_ft].values)
    ba_mean, ba_lo, ba_hi = bootstrap_ci(percue[col_ba].values)
    return dict(
        ft_mean=ft_mean,
        ft_CI=(ft_lo, ft_hi),
        base_mean=ba_mean,
        base_CI=(ba_lo, ba_hi),
        delta=ft_mean - ba_mean,
    )

# Dynamic summary: include ALL *_ft/*_base pairs and human-only macros
pretty_name = {"JSD": "JSD (lower better)"}

def _prettify(metric_key: str) -> str:
    # metric_key is without suffix (e.g., "P@10", "cosine", "JSD", "spearman", "H", "Jaccard_all")
    base = metric_key
    if base.lower() == "cosine":   base = "Cosine"
    if base.lower() == "spearman": base = "Spearman"
    # Entropy
    if base in {"H"} or base.startswith("H_") or base == "H":
        return "Entropy"
    # Special label for JSD
    if base == "JSD":
        return pretty_name["JSD"]
    return base

# The summary starts with the row-level Hit@3 scores. Note that this entry uses
# the keys ft/base/delta, while the per-metric entries added below use
# ft_mean/ft_CI/base_mean/base_CI/delta.
summary = {
    "Hit@3 (row-level)": {"ft": ft_hit3, "base": base_hit3, "delta": ft_hit3 - base_hit3}
}

# add every *_ft / *_base pair automatically
for col in percue.columns:
    if col.endswith("_ft"):
        base_col = col[:-3] + "_base"
        # Columns without a *_base partner are skipped.
        if base_col in percue.columns:
            metric_key = col[:-3]  # strip "_ft"
            label = _prettify(metric_key)
            summary[label] = macro(col, base_col)

# add macro means for any human-only columns (e.g., H_human)
human_only = {}
for col in percue.columns:
    if col.endswith("_human"):
        m, lo, hi = bootstrap_ci(percue[col].values)
        # label human metrics nicely
        if col.startswith("H_"):
            lab = "Entropy (human)"
        else:
            lab = f"{col.replace('_', ' ').title()}"
        human_only[lab] = {"mean": m, "CI": (lo, hi)}

# Only attach the section when at least one *_human column exists.
if human_only:
    summary["Human-only macro means"] = human_only

# ---------- print & save ----------
pd.set_option("display.precision", 4)
print("=== Headline metrics (FT vs Base vs Human) ===")
print(json.dumps(summary, indent=2))

# BASE_PATH is rebound to the base-model predictions *file* in the paths section
# of this cell, so joining onto it here would point underneath an .xlsx file.
# The output locations are therefore derived from HUM_PATH and FT_PATH, which
# were built while BASE_PATH still named the project root.
_DATA_DIR = os.path.dirname(os.path.dirname(HUM_PATH))  # <root>/data
_TASK_DIR = os.path.dirname(FT_PATH)                    # folder holding the prediction workbooks

# save detailed per-cue table (Excel)
OUT = os.path.join(_TASK_DIR, "eval_full_val_ft_vs_base_vs_human_per_cue.xlsx")
os.makedirs(os.path.dirname(OUT), exist_ok=True)
percue.sort_values("cue").to_excel(OUT, index=False)
print(f"\nPer-cue metrics saved to: {OUT}")

# quick peek at top-10 lists for some cues (optional)
def top_list(cnt, k=10): return [w for w,_ in cnt.most_common(k)]
peek = []
for cue in cues[:25]:
    peek.append({
        "cue": cue,
        "top10_human": ", ".join(top_list(hum_by[cue], 10)),
        "top10_ft":    ", ".join(top_list(ft_by[cue], 10)),
        "top10_base":  ", ".join(top_list(base_by[cue], 10)),
    })
# `peek` keeps up to 25 cues for interactive inspection; only 10 are printed.
print("\nSample of top-10 lists (first 10 cues):")
print(pd.DataFrame(peek).head(10).to_string(index=False))

# optional: also show the first few per-cue metric rows
print("\n=== All metrics per cue (first 10 cues) ===")
print(percue.head(10).to_string(index=False))

# save summary JSON
SUMMARY_OUT = os.path.join(_DATA_DIR, r"models_predictions_associations_task/full_llama3_8b_system_prompt_lora_SFT_SWOW_tgt_qkvo_tr7194c_val899c_r16_a32_do0p1_lr0.0001_bs16_ga4/eval_full_val_ft_vs_base_vs_human_summary.json")
# Nothing earlier in the notebook creates this run folder, so make sure it
# exists before opening the file for writing.
os.makedirs(os.path.dirname(SUMMARY_OUT), exist_ok=True)
with open(SUMMARY_OUT, "w", encoding="utf-8") as fh:
    json.dump(summary, fh, indent=2)
print(f"Summary JSON saved to: {SUMMARY_OUT}")


In [None]:
# Disconnect the runtime
# NOTE(review): presumably this releases the Colab VM once the evaluation is
# done, which means the cells placed after this one cannot run in the same
# session and need their own setup (imports, BASE_PATH) — TODO confirm.
from google.colab import runtime
runtime.unassign()

In [None]:
# =========================================================
# 🧪  Generate 80× responses per cue on validation split
#      • fine-tuned model
#      • base model
#      + save raw outputs, conversation history, and prompts
#      + batched via num_return_sequences with OOM fallback
# =========================================================


import os, gc, re, time, torch, pandas as pd, math
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ---------- PATHS -------------------------------------------------
# Every path below is joined onto BASE_PATH, which this cell does not define:
# it must already hold the project root when the cell runs.
# NOTE(review): this cell repeats the generation cell further up with different
# model and output locations — confirm which of the two is the current one.
VAL_XLSX_PATH  = os.path.join(BASE_PATH, r"data/final_processed_SWOW_data/val.xlsx")

# TODO: update FINETUNED_MODEL to your current run's merged or adapter path
FINETUNED_MODEL = os.path.join(BASE_PATH, r"data/fine_tuned_and_base_models_comparisons/free_associations_task/full_llama3_8b_system_prompt_lora_SFT_SWOW_tgt_qkvo_tr7194c_val899c_r16_a32_do0p1_lr0.0001_bs16_ga4/merged_model")
BASE_MODEL      = "meta-llama/Meta-Llama-3-8B-Instruct"

# Output folders (created if missing): one per model.
SAVE_DIR_FT   = os.path.join(BASE_PATH, r"data/models_predictions_associations_task/full_llama3_8b_system_prompt_lora_SFT_SWOW_tgt_qkvo_tr7194c_val899c_r16_a32_do0p1_lr0.0001_bs16_ga4")
os.makedirs(SAVE_DIR_FT, exist_ok=True)

SAVE_DIR_BASE = os.path.join(BASE_PATH, r"data/models_predictions_associations_task/llama3_8b_base_model")
os.makedirs(SAVE_DIR_BASE, exist_ok=True)
# --------------------------------------------------------------------

# ---------- Precision / quant (use the SAME for both models) --------
USE_4BIT = False  # set True if you prefer 4-bit generation for both

# Keyword arguments forwarded to from_pretrained for both models. Built up
# step by step: device placement first, then either the 4-bit quantization
# config or the plain bf16 dtype, then the safetensors flag.
load_kwargs = {"device_map": "auto"}
if USE_4BIT:
    load_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # A100-friendly
    )
else:
    load_kwargs["torch_dtype"] = torch.bfloat16  # A100-friendly; consistent across models
load_kwargs["use_safetensors"] = True
# --------------------------------------------------------------------

# ---------- Sampling knobs (unchanged except batching) ---------------
# All five values are read by _generate_many_for_cue below.
NUM_SAMPLES_PER_CUE = 80   # total completions collected for every cue
TRY_AT_ONCE         = 80   # first attempt per call; auto-shrinks on OOM
MAX_NEW_TOKENS      = 50   # passed to generate() as max_new_tokens
TEMPERATURE         = 0.7  # passed to generate() as temperature (do_sample=True)
TOP_P               = 0.9  # passed to generate() as top_p
# --------------------------------------------------------------------

# System message placed in front of every cue by build_prompt_text and stored
# alongside each generated record. The backslash right after the opening quotes
# keeps the string from starting with a newline.
SYSTEM_PROMPT = """\
Task:
 - You will be provided with an input word: write the first 3 words you associate to it separated by a comma.
 - No additional output text is allowed.

Constraints:
 - no carriage return characters are allowed in the answers.
 - answers should be as short as possible.

Example:
 Input: sea
 Output: water, beach, sun"""

def build_prompt_text(tokenizer, cue: str) -> str:
    """Return the chat-template prompt string for one cue.

    The conversation holds SYSTEM_PROMPT as the system turn and the stripped
    cue as the user turn. The template is rendered as text (not token ids)
    with the generation prompt appended, so generation starts in the
    assistant turn.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": str(cue).strip()}
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

def build_history(cue: str, assistant_text: str) -> str:
    """Return a human-readable transcript of one exchange.

    Three tagged sections (system, user, assistant) are emitted in that order,
    each as ``<tag>``, the stripped text, ``</tag>`` on separate lines. A
    missing or empty assistant text yields an empty assistant section.
    """
    sections = (
        ("system", SYSTEM_PROMPT.strip()),
        ("user", str(cue).strip()),
        ("assistant", (assistant_text or "").strip()),
    )
    return "\n".join(f"<{tag}>\n{body}\n</{tag}>" for tag, body in sections)

def extract_responses(text: str, n=3):
    """
    Split one generated answer into exactly ``n`` response slots.

    Processing steps:
      1. lowercase the text (``None``/empty is treated as "");
      2. treat line breaks as separators, exactly like commas;
      3. turn any other whitespace into a plain space;
      4. drop every character that is not a word character, comma, hyphen
         or space;
      5. split on commas, strip each piece and discard empty pieces.

    Args:
        text: raw decoded model output.
        n: number of slots to return (default 3).

    Returns:
        A list of length ``n``: the first ``n`` extracted responses, padded
        with empty strings when fewer were found.
    """
    lowered = (text or "").lower()
    # A line break separates answers just like a comma does; simply deleting it
    # would glue the neighbouring words together ("beach\nsun" -> "beachsun").
    lowered = re.sub(r"[\r\n]+", ",", lowered)
    # Same reasoning for tabs and other whitespace: keep the word boundary.
    lowered = re.sub(r"\s", " ", lowered)
    # Keep word characters, spaces, hyphens, and commas only.
    clean = re.sub(r"[^\w,\- ]", "", lowered)
    # robust split for ",", " ,", ", " etc.
    words = [w.strip() for w in re.split(r"\s*,\s*", clean) if w.strip()]
    # Pad with as many blanks as slots requested so the result always has n items.
    return (words + [""] * n)[:n]

def _generate_many_for_cue(model, tok, cue: str) -> list[dict]:
    """
    Generate NUM_SAMPLES_PER_CUE samples for one cue using num_return_sequences,
    automatically backing off the chunk size if CUDA runs out of memory.

    This variant does not reseed the RNG, so the samples depend on the global
    RNG state at call time.

    Args:
        model: loaded causal LM exposing ``generate`` and ``device``.
        tok: matching tokenizer (chat template, ``eos_token_id``, ``decode``).
        cue: the cue word.

    Returns:
        One dict per sample with keys cue, R1, R2, R3, raw_output, history,
        system_prompt and user_prompt.

    Raises:
        torch.cuda.OutOfMemoryError: if even a single-sample chunk does not fit.
    """
    prompt_text = build_prompt_text(tok, cue)
    prompt_ids  = tok(prompt_text, return_tensors="pt").to(model.device)
    # The prompt length is the same for every returned sequence: compute it once.
    prompt_len  = prompt_ids["input_ids"].shape[1]

    recs = []
    remaining   = NUM_SAMPLES_PER_CUE
    try_at_once = TRY_AT_ONCE

    while remaining > 0:
        cur_n = min(try_at_once, remaining)
        try:
            with torch.no_grad():
                out = model.generate(
                    **prompt_ids,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=True, temperature=TEMPERATURE, top_p=TOP_P,
                    pad_token_id=tok.eos_token_id,
                    eos_token_id=tok.eos_token_id,
                    num_return_sequences=cur_n,      # ← key: many samples at once
                )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            if cur_n == 1:
                # A single sample does not fit: nothing smaller left to try.
                raise
            # Halve the size that actually failed. Halving try_at_once instead
            # can leave it >= remaining and retry the very same chunk size.
            try_at_once = cur_n // 2
            continue

        # out has shape [cur_n, seq_len]; decode each
        for i in range(out.size(0)):
            gen_only = out[i, prompt_len:]
            raw = tok.decode(gen_only, skip_special_tokens=True)

            r1, r2, r3 = extract_responses(raw)
            recs.append(dict(
                cue=cue,
                R1=r1, R2=r2, R3=r3,
                raw_output=raw,
                history=build_history(cue, raw),
                system_prompt=SYSTEM_PROMPT.strip(),
                user_prompt=str(cue).strip(),
            ))

        remaining -= cur_n

    return recs

def generate_dataset(model_id_or_path: str, tag: str):
    """Load one model, sample responses for every cue, and release the model.

    Cues are the unique, stripped, non-null values of the "cue" column of the
    workbook at VAL_XLSX_PATH.

    Args:
        model_id_or_path: hub id or local directory of the model/tokenizer.
        tag: label used in the progress output.

    Returns:
        list[dict] with keys cue, R1, R2, R3, raw_output, history,
        system_prompt and user_prompt — NUM_SAMPLES_PER_CUE entries per cue.
    """
    print(f"\n🔹 Loading {tag} model …")
    tok = AutoTokenizer.from_pretrained(model_id_or_path, use_fast=True)
    tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_id_or_path, **load_kwargs)
    model.eval()

    # ---------- read validation cues from Excel ----------
    cue_column = pd.read_excel(VAL_XLSX_PATH)["cue"]
    cues = cue_column.dropna().astype(str).str.strip().unique()

    recs = []
    progress = tqdm(cues, desc=f"{tag} {NUM_SAMPLES_PER_CUE}× per cue")
    for cue in progress:
        recs += _generate_many_for_cue(model, tok, cue)
        torch.cuda.empty_cache()  # light hygiene between cues

    # -------- cleanup to free VRAM --------
    del model, tok
    gc.collect()
    torch.cuda.empty_cache()
    return recs

# ---------- Fine-tuned model -----------------------------------
# Generate with the fine-tuned checkpoint first, write one row per sample to an
# .xlsx workbook in SAVE_DIR_FT, and report the wall-clock time.
t0 = time.time()
ft_rows = generate_dataset(FINETUNED_MODEL, "Finetuned")
ft_tag  = os.path.basename(os.path.normpath(FINETUNED_MODEL))  # folder name like 'merged_model' or run dir
ft_path = os.path.join(SAVE_DIR_FT, f"full_val_ft_{ft_tag}_predictions.xlsx")
pd.DataFrame(ft_rows).to_excel(ft_path, index=False)  # ← Excel
print(f"Fine-tuned predictions saved → {ft_path} ({len(ft_rows):,} rows)")
print(f"{time.time()-t0:.1f}s elapsed")

# ---------- Base model -----------------------------------------
# Same procedure for the base checkpoint, written to SAVE_DIR_BASE.
t1 = time.time()
base_rows = generate_dataset(BASE_MODEL, "Base")
base_path = os.path.join(SAVE_DIR_BASE, "llama3_8b_full_val_base_predictions.xlsx")
pd.DataFrame(base_rows).to_excel(base_path, index=False)  # ← Excel
print(f"Base-model predictions saved → {base_path} ({len(base_rows):,} rows)")
print(f"{time.time()-t1:.1f}s elapsed (second run)")

In [None]:
# Disconnect the runtime
# NOTE(review): presumably this releases the Colab VM once generation has
# finished so it stops consuming resources — TODO confirm.
from google.colab import runtime
runtime.unassign()