## Dependencies

In [None]:
import os
import json
import pickle
import random
import warnings
import gc
from itertools import product

import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import rearrange
from scipy.stats import spearmanr
from sklearn.linear_model import Ridge, RidgeClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold, train_test_split
from sklearn.neural_network import MLPRegressor
from scipy.stats import pearsonr, spearmanr
from tqdm.auto import tqdm
from IPython.display import HTML, display

from baukit import Trace, TraceDict
from utils import *

warnings.filterwarnings("ignore")
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'  # Suppress transformer warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

transformers.logging.set_verbosity_error()

SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(SEED)

In [7]:
MODELS = [
    "/project/jevans/tzhang3/models/Llama-2-7b-chat-hf",
    "/project/jevans/tzhang3/models/Llama-3.1-8B-Instruct",
    "/project/jevans/tzhang3/models/Qwen2.5-7B-Instruct",
    "/project/jevans/tzhang3/models/Qwen2.5-14B-Instruct"
]

## Data

In [8]:
df = pd.read_csv('./data/HS116_members.csv')
df = df.loc[pd.notnull(df.nominate_dim1)].reset_index(drop=True)
df['first_middle_name'] = df['bioname'].apply(lambda x: x.split(', ')[1])
df['first_name'] = df['first_middle_name'].apply(lambda x: x.split(' ')[0] + ' ' if ' ' in x else x + ' ')
df['middle_name'] = df['first_middle_name'].apply(lambda x: x.split(' ')[1][0] + '. ' if ' ' in x else '')
df['last_name'] = df['bioname'].apply(lambda x: x.split(', ')[0].lower().capitalize().strip())
df['fullname'] = df['first_name'] + df['middle_name'] + df['last_name']

In [9]:
adfontesmedia = pd.read_csv('./data/adfontesmedia.csv')

In [10]:
def build_politician_prompt(fullname: str) -> str:
    return (
        f"USER: Generate a statement by {fullname}, a politician in the United States.\n"
        f"ASSISTANT: In 2019, {fullname} said that"
    )

def build_news_prompt(source_name: str) -> str:
    return (
        f"USER: Generate a statement from a news source in the United States.\n"
        f"ASSISTANT: {source_name} reported that"
    )

def build_vote_prompt(fullname: str) -> str:
    return (
        f"USER: In the context of 2019 U.S. politics, if you were {fullname}, "
        f"who would you most likely vote for in the next U.S. presidential election: Joe Biden or Donald Trump?\n"
        f"ASSISTANT: The person would most likely vote for"
    )

# Generate prompts
prompts_politicians = df["fullname"].apply(build_politician_prompt)
labels = df["nominate_dim1"].astype(float).values
df_prompts_politicians = pd.DataFrame({"prompt": prompts_politicians, "label": labels})

prompts_news = adfontesmedia["moniker_name"].apply(build_news_prompt)
labels = adfontesmedia["bias_mean"].astype(float).values
df_prompts_news = pd.DataFrame({"prompt": prompts_news, "label": labels})

# Generate vote-style prompts for politicians
prompts_dw_vote = df["fullname"].apply(build_vote_prompt)
labels_dw_vote = df["nominate_dim1"].astype(float).values  # Reuse same labels (liberal-conservative)
df_prompts_dw_vote = pd.DataFrame({"prompt": prompts_dw_vote, "label": labels_dw_vote})


In [11]:
df_prompts_vote = pd.read_csv("data/vote_2020_dataset.csv")
df_prompts_vote

Unnamed: 0,prompt,label_text,label
0,"USER: Given the 2020 economic climate, who was...",Trump,1
1,"USER: With all factors of 2020 in mind, which ...",Trump,1
2,"USER: Reflecting on 2020, who should have led ...",Biden,-1
3,"USER: Reflecting on 2020, who should have led ...",Biden,-1
4,"USER: During the 2020 contest, who was your pr...",Biden,-1
...,...,...,...
595,USER: Pretend you're in the voting booth in 20...,Trump,1
596,"USER: Considering their 2020 platforms, who de...",Biden,-1
597,"USER: When deciding in 2020, which name would ...",Biden,-1
598,USER: Who would you vote for in the 2020 U.S. ...,Trump,1


In [12]:
df_anes = pd.read_csv("data/anes.csv")

# Prompt builder
def build_prompt(persona, query):
    return (
        f"USER: How would {persona} in the United States in 2019 answer the following question. "
        f"{query.strip()} "
        f"ASSISTANT: As {persona}, I would answer the question with a score of "
    )

prompts_anes = []
for _, row in df_anes.iterrows():
    domain, question = row['domain'], row['query']
    for persona in df['fullname'].tolist():
        prompts_anes.append(build_prompt(persona, question))

In [13]:
# Prompt builder
def build_prompt(alignment, query):
    return (
        f"USER: Generate a statement from the {alignment} on the following question in the United States. "
        # f"{query.strip()}\n"
        # f"Do you favor or oppose transgender bathrooms?\n"
        # f"Do you favor or oppose US military aid to Ukraine?\n"
        f"Do you favor or oppose buying green socks for Christmas?\n"
        f"ASSISTANT: In 2019, the {alignment} believes that"
    )
    
prompts_socks = []
for alignment in ["liberal", "conservative"]:
    prompts_socks.append(build_prompt(alignment, None))

In [14]:
def build_prompt(query):
    return (
        f"USER: In the context of 2019 U.S. politics, is the following statement biased and leaning toward a liberal or conservative viewpoint? "
        f"{query.strip()}\n"
        f"ASSISTANT: The statement is"
    )

df_bias = pd.read_csv("data/policy_statements_240.csv")

prompts_bias = []
for _, row in df_bias.iterrows():
    text = row['response']
    prompts_bias.append(build_prompt(text))

len(prompts_bias)

240

In [15]:
def build_vote_prompts(alignment: str):
    return [
        # 1-20
        f"USER: If you had to vote in the 2019 U.S. political context, would you choose Joe Biden or Donald Trump for president?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: In your opinion, who would be the better choice for the next U.S. president in 2020—Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: In the 2019 U.S. political environment, who aligns more with your values—Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Given the candidates Joe Biden and Donald Trump, who would you vote for in the 2020 U.S. election?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Based on 2019 U.S. politics, which candidate would receive your vote: Donald Trump or Joe Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: As a {alignment}, who would you feel more comfortable voting for in the 2020 U.S. presidential election—Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Reflecting on 2019 U.S. politics, who would a {alignment} likely cast their vote for: Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Considering the policy debates of 2019, does Joe Biden or Donald Trump better represent your priorities?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Thinking about your economic preferences in 2019, would Biden or Trump get your vote?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Comparing leadership styles, who would you endorse for president in 2020: Donald Trump or Joe Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Which candidates platform, Biden's or Trump's, resonates more with you based on 2019 issues?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Who do you see as the stronger commander-in-chief choice for 2020—Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: From a 2019 standpoint, who wins your support: Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: If only Biden and Trump were on the ballot, who would you cast your 2020 vote for?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Looking at social issues debated in 2019, would Trump or Biden be your preferred candidate?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Whose foreign-policy approach in 2019 convinces you to vote for them—Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Evaluating character and experience, do you lean toward Joe Biden or Donald Trump for president?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: As a {alignment} voter in 2019, whose vision for America is closer to yours: Trump's or Biden's?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Between Biden and Trump, who aligns better with your voting preference in 2019?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Given the challenges highlighted in 2019, who would you trust more to handle the 2020 presidency—Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        # 20-40
        f"USER: If faced with the 2020 U.S. presidential ballot, would you choose Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Which candidate do you think reflects your values more in 2019, Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Between Donald Trump and Joe Biden, who seems like the better leader to you in the 2020 election?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Considering the 2019 debates, who would be your choice for president: Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Who would earn your confidence as president in 2020, Trump or Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Looking at the issues of 2019, would Joe Biden or Donald Trump better deserve your vote?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Imagine casting a ballot in 2020—would you pick Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Who would you back as the stronger candidate for the presidency: Donald Trump or Joe Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Thinking about 2019’s political climate, which candidate speaks more to your beliefs, Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: As a {alignment}, which candidate’s vision resonates with you in 2020: Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: When comparing their agendas, who would you prefer as president—Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Based on leadership qualities, who is your pick for the 2020 race: Trump or Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Between Joe Biden and Donald Trump, who would a {alignment} most likely vote for in 2019?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: If you were voting strictly on character, who would you support—Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Considering domestic policy in 2019, would you side with Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Who is the preferable commander-in-chief in your eyes: Joe Biden or Donald Trump?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: If only Biden and Trump were on the ticket, who would you choose for president?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: Thinking as a {alignment}, whose 2019 platform better serves your priorities—Trump’s or Biden’s?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: When reflecting on 2019 politics, which candidate represents your stance: Donald Trump or Joe Biden?\nASSISTANT: As a {alignment}, I would vote for",
        f"USER: In your view, who is the right choice for the White House in 2020: Biden or Trump?\nASSISTANT: As a {alignment}, I would vote for",
    ]

prompts_vote = []
meta_vote = []

for alignment in ["liberal", "conservative"]:
    for i, prompt in enumerate(build_vote_prompts(alignment)):
        prompts_vote.append(prompt)
        meta_vote.append({"prompt_id": i, "alignment": alignment})


In [16]:
def build_prompt(query):
    return (
        f"USER: In the context of 2019 U.S. politics, please rewrite the following statement to make it politically neutral and free from liberal or conservative bias:\n"
        f"\"{query.strip()}\"\n"
        f"ASSISTANT: Here's a neutral version of the text:"
    )

df_bias = pd.read_csv("data/policy_statements_240.csv")

prompts_rewrite = []
meta_rewrite = []

# df_bias = df_bias.sample(1, random_state=42)
for _, row in df_bias.iterrows():
    text = row['response']
    prompts_rewrite.append(build_prompt(text))
    meta_rewrite.append({"name": row["name"], "domain": row["domain"]})

## AIME

In [20]:
# ===============================================
# AIME 2025 evaluation with DW-vote steering
# ===============================================
import os
import re
import json
import math
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple, Optional

# ---------- Config ----------
AIME_PATHS_TRY = [
    "data/aime2025-I.jsonl",  # preferred: one JSON per line with {"question": "...", "answer": 123}
]
DW_PREFIX = "politician"          # ridge prefix trained on DW-vote prompts
ALPHAS = [-20, 0, 20]
K_LIST = [32]                  # number of heads to combine; can sweep more if you like
MAX_NEW_TOKENS = 2048
TEMPERATURE = 0.2
TOP_P = 0.95
SEED = 42

# ---------- Utils ----------
def _set_seed(seed: int = SEED):
    import torch, numpy as np, random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

_set_seed(SEED)

def load_aime_dataset(paths: List[str]) -> List[Dict]:
    """
    Load AIME 2025 dataset from jsonl or csv.
    Expected schema:
      - jsonl: one object per line with keys {"question": str, "answer": int or str}
      - csv  : columns ["question","answer"]
    Returns list of dicts: [{"qid": int, "question": str, "answer": str}]
    """
    for p in paths:
        if os.path.isfile(p):
            ext = os.path.splitext(p)[1].lower()
            records = []
            if ext == ".jsonl":
                with open(p, "r", encoding="utf-8") as f:
                    for i, line in enumerate(f):
                        obj = json.loads(line)
                        q = str(obj["question"]).strip()
                        a = str(obj["answer"]).strip()
                        records.append({"qid": i, "question": q, "answer": a})
            elif ext == ".csv":
                df = pd.read_csv(p)
                assert {"question","answer"}.issubset(df.columns)
                for i, row in df.iterrows():
                    records.append({"qid": i, "question": str(row["question"]).strip(), "answer": str(row["answer"]).strip()})
            print(f"Loaded AIME 2025 from {p}  (N={len(records)})")
            return records

    # Fallback tiny mock set so code runs even without the real file
    print("AIME 2025 file not found. Using a tiny mock dataset (2 items) for a dry run.")
    mock = [
        {
            "qid": 0,
            "question": "If (x+1)^2 = 25 and x is positive, compute x.",
            "answer": "4"
        },
        {
            "qid": 1,
            "question": "Let S be the sum of the first 10 positive integers. What is S?",
            "answer": "55"
        },
    ]
    return mock

def build_math_chat_prompt(tokenizer, question: str) -> str:
    """
    Build a chat-formatted prompt asking for a single final numeric answer.
    You can adjust the instruction style to match your model family.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a careful competition mathematician. Solve the problem step by step, "
                "then give ONLY the final numeric answer on the last line in the form: Final Answer: <number>."
            ),
        },
        {
            "role": "user",
            "content": (
                f"{question}\n\n"
                "Do not include units unless necessary. Do not include extra text after the final answer line."
            ),
        },
    ]
    # Use tokenizer chat template if available; otherwise fall back to a plain prompt
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            pass
    # fallback
    sys = messages[0]["content"]
    usr = messages[1]["content"]
    return f"<|system|>\n{sys}\n<|user|>\n{usr}\n<|assistant|>\n"

_FINAL_ANSWER_PATTERNS = [
    r"Final Answer:\s*([-+]?\d+)",            # Final Answer: 123
    r"\\boxed\{([^}]+)\}",                     # \boxed{123}
    r"Answer:\s*([-+]?\d+)",                   # Answer: 123
    r"=+\s*([-+]?\d+)\s*$",                    # ends with = 123
]

def extract_numeric_answer(text: str) -> Optional[str]:
    """
    Try to extract the final numeric answer (integer) from model text.
    Returns string of integer if found, else None.
    """
    # prioritize explicit "Final Answer"
    for pat in _FINAL_ANSWER_PATTERNS:
        m = re.search(pat, text, flags=re.IGNORECASE | re.MULTILINE)
        if m:
            cand = m.group(1).strip()
            # strip common wrappers
            cand = re.sub(r"[^\d\-+]", "", cand)
            # ensure it's integer-like
            if re.fullmatch(r"[-+]?\d+", cand):
                return str(int(cand))
    # fallback: last integer in the text
    ints = re.findall(r"[-+]?\d+", text)
    if ints:
        return str(int(ints[-1]))
    return None

def grade_pred(gold: str, pred: Optional[str]) -> bool:
    """
    AIME answers are integers (or short numbers). Here we treat exact integer match as correct.
    You can expand with tolerance if your dataset includes non-integers.
    """
    if pred is None:
        return False
    # normalize
    try:
        g = int(str(gold).strip())
        p = int(str(pred).strip())
        return g == p
    except Exception:
        # if not integer, do string compare after minimal cleanup
        return str(gold).strip() == str(pred).strip()

def answer_batch_unsteered(model, tokenizer, questions: List[str], max_new_tokens=256) -> List[str]:
    """
    Generate answers with no interventions.
    """
    import torch
    outs = []
    eos_ids = [tokenizer.eos_token_id]
    # add eot if known
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if isinstance(eot_id, int) and eot_id > 0:
            eos_ids.append(eot_id)
    except Exception:
        pass

    model.eval()
    with torch.no_grad():
        for q in questions:
            prompt = build_math_chat_prompt(tokenizer, q)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            gen = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                eos_token_id=eos_ids,
                pad_token_id=tokenizer.eos_token_id,
            )
            text = tokenizer.decode(gen[0], skip_special_tokens=True)
            # trim prompt if chat template returned in full text
            if text.startswith(prompt):
                text = text[len(prompt):]
            
            print("Out:", text)
            outs.append(text.strip())
    return outs

def answer_batch_steered_dw(
    model, tokenizer, questions: List[str],
    combined_coefs,
    alpha: float,
    max_new_tokens=256
) -> List[str]:
    """
    Generate answers with DW-vote steering applied (head intervention).
    Uses your previously defined `generate_with_head_intervention_gpu`.
    """
    # We generate one question at a time to reuse your hook logic
    results = []
    for q in questions:
        prompt = build_math_chat_prompt(tokenizer, q)
        out = generate_with_head_intervention_gpu(
            model=model,
            tokenizer=tokenizer,
            prompts=[prompt],
            alpha=alpha,
            max_new_tokens=max_new_tokens,
            combined_coefs=combined_coefs,
            return_features=False,
            device=model.device if hasattr(model, "device") else None,
        )
        print("Out", out[0]["answer"])
        results.append(out[0]["answer"])
        
    return results

def evaluate_run(gold_answers: List[str], raw_outputs: List[str]) -> Dict:
    preds = [extract_numeric_answer(t) for t in raw_outputs]
    correct = [grade_pred(g, p) for g, p in zip(gold_answers, preds)]
    return {
        "n": len(gold_answers),
        "correct": int(sum(correct)),
        "acc": float(np.mean(correct)) if len(correct) else 0.0,
        "preds": preds,
        "raw": raw_outputs,
        "is_correct": correct,
    }

def run_aime_dw_compare(model_path: str, aime_records: List[Dict], ks=K_LIST, alphas=ALPHAS) -> pd.DataFrame:
    """
    Runs baseline (unsteered) and DW-steered generations on AIME 2025,
    evaluates accuracy, and returns a results DataFrame.
    Also writes CSV + a PNG plot to ./results/<model_base>/ .
    """
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    questions = [r["question"] for r in aime_records]
    gold      = [r["answer"]   for r in aime_records]

    # Load model/tokenizer
    tokenizer, model = load_tokenizer_and_model(model_path, device=device)
    base_name = model_base_name(model_path)


    # Ensure output dir
    out_dir = os.path.join("results", base_name)
    os.makedirs(out_dir, exist_ok=True)

    # ---------- Baseline ----------
    print("\n[Baseline] Generating without steering...")
    t0 = time.time()
    baseline_text = answer_batch_unsteered(model, tokenizer, questions[:1], max_new_tokens=MAX_NEW_TOKENS)
    base_eval = evaluate_run(gold, baseline_text)
    print(f"Baseline AIME acc = {base_eval['acc']:.3f}  ({base_eval['correct']}/{base_eval['n']})  time={time.time()-t0:.1f}s")

    # ---------- Load DW-vote ridge/perf and feats std (through compute_combined_coefs) ----------
    print("\n[DW-steering] Loading ridge models & performance...")
    feats = load_feats(model_path=model_path, prefix=DW_PREFIX, results_dir="./results", device=device)
    ridge_models, performance = load_ridge_models(model_path, DW_PREFIX)

    # For compute_combined_coefs we need feats OR (k, performance); we’ll use k+performance path here.
    # If you preloaded feats earlier, you can switch to the feats-based call for speed.

    # ---------- Sweep steering settings ----------
    rows = []
    for k in ks:
        # Build steering coefficients [L,H,D] once per k
        print(f"\nBuilding combined_coefs for k={k} ...")
        coefs = compute_combined_coefs(
            feats=feats,                # not using preloaded feats
            ridge_models=ridge_models,
            k=k,
            performance=performance,
            device=model.device if hasattr(model, "device") else None,
        )

        for alpha in alphas:
            print(f"[DW] alpha={alpha}, k={k} → generating ...")
            t0 = time.time()
            steered_text = answer_batch_steered_dw(
                model, tokenizer, questions,
                combined_coefs=coefs,
                alpha=alpha,
                max_new_tokens=MAX_NEW_TOKENS,
            )
            eval_res = evaluate_run(gold, steered_text)
            dt = time.time() - t0
            print(f"   acc={eval_res['acc']:.3f}  ({eval_res['correct']}/{eval_res['n']})  time={dt:.1f}s")
            rows.append({
                "model": base_name,
                "mode": "steered_dw",
                "alpha": alpha,
                "k": k,
                "acc": eval_res["acc"],
                "correct": eval_res["correct"],
                "n": eval_res["n"],
            })

    # Append baseline row
    rows.append({
        "model": base_name,
        "mode": "baseline",
        "alpha": 0,
        "k": 0,
        "acc": base_eval["acc"],
        "correct": base_eval["correct"],
        "n": base_eval["n"],
    })

    df = pd.DataFrame(rows)
    csv_path = os.path.join(out_dir, "aime2025_dw_vote_eval.csv")
    df.to_csv(csv_path, index=False)
    print(f"\nSaved summary CSV → {csv_path}")

    # ---------- Plot ----------
    # Single-plot requirement (no seaborn, no explicit colors)
    fig = plt.figure(figsize=(7, 4.5))
    # Plot baseline as a horizontal line
    y = base_eval["acc"]
    plt.axhline(y=y, linestyle="--", label=f"Baseline acc = {y:.3f}")
    # Plot steered curves per k
    for k in sorted(set(df[df["mode"]=="steered_dw"]["k"])):
        sub = df[(df["mode"]=="steered_dw") & (df["k"]==k)].sort_values("alpha")
        plt.plot(sub["alpha"].tolist(), sub["acc"].tolist(), marker="o", label=f"Steered (k={k})")
    plt.xlabel("alpha (DW-vote steering strength)")
    plt.ylabel("AIME 2025 accuracy")
    plt.title(f"AIME 2025: Baseline vs DW-steered • {base_name}")
    plt.legend()
    plt.tight_layout()
    png_path = os.path.join(out_dir, "aime2025_dw_vote_eval.png")
    plt.savefig(png_path, dpi=150, bbox_inches="tight")
    plt.show()
    print(f"Saved plot PNG → {png_path}")

    # Cleanup VRAM
    del model
    clean_up(device=None)
    return df

# ---------- Entry point helper ----------
def run_all_models_on_aime(models: List[str] = None):
    """
    Convenience runner: iterate over your MODELS list and compare each.
    """
    records = load_aime_dataset(AIME_PATHS_TRY)
    used_models = models if models is not None else MODELS
    all_dfs = []
    for m in used_models:
        print(f"\n===== Evaluating {m} on AIME 2025 =====")
        df = run_aime_dw_compare(model_path=m, aime_records=records, ks=K_LIST, alphas=ALPHAS)
        all_dfs.append(df.assign(model_path=m))
    big = pd.concat(all_dfs, ignore_index=True)
    # Also write a combined CSV
    out = "./results/aime2025_dw_vote_eval_all_models.csv"
    os.makedirs("./results", exist_ok=True)
    big.to_csv(out, index=False)
    print(f"\nWrote combined results: {out}")
    return big

# ---------- Example call ----------
# big_df = run_all_models_on_aime(MODELS)  # uses your MODELS list defined earlier
# Or just one:
df_single = run_aime_dw_compare(MODELS[0], load_aime_dataset(AIME_PATHS_TRY), ks=[32], alphas=[-20, 0, 20])


Loaded AIME 2025 from data/aime2025-I.jsonl  (N=15)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


[Baseline] Generating without steering...
Out: [INST] <<SYS>>
You are a careful competition mathematician. Solve the problem step by step, then give ONLY the final numeric answer on the last line in the form: Final Answer: <number>.
<</SYS>>

Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.

Do not include units unless necessary. Do not include extra text after the final answer line. [/INST]  Great, let's dive into the problem!

Step 1: List all possible integer bases $b>9$

We need to find all integers $b>9$ such that $17_{b}$ is a divisor of $97_{b}$.

Let's list all possible bases $b>9$:

$b = 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20$

Step 2: Check if $17_{b}$ is a divisor of $97_{b}$

For each base $b$ in the list, we need to check if $17_{b}$ is a divisor of $97_{b}$.

Let's check for each base:

$b = 10$: $17_{10} = 253, 97_{10} = 1453$ - $253$ divides $1453$, so $17_{10}$ is a divisor of $97_{10}$

$b = 11$: $17_{11} = 343, 97_{11} = 1353$

KeyboardInterrupt: 

## MATH

In [17]:
from grading.grader import grade_answer

In [None]:
# ===============================================
# MATH evaluation with DW-vote steering
# ===============================================
import os
import re
import json
import glob
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Optional, Tuple

# ---------- Config ----------
MATH_PATHS_TRY = [
    "data/math500.jsonl",     # JSONL with {"problem": "...", "solution": "...", "answer": "..."}
]
DW_PREFIX = "politician"        # or "dw_vote" if you want those heads
ALPHAS = [-20, 0, 20]
K_LIST = [32]
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.2
TOP_P = 0.95
SEED = 42
LIMIT = 100

# ---------- Seed ----------
def _set_seed(seed: int = SEED):
    import torch, numpy as np, random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

_set_seed(SEED)

def extract_predicted_answer(text: str) -> Optional[str]:
    """
    Extract model final answer:
    - Prefer \boxed{...} if present
    - Else look for 'Final Answer:' pattern
    - Else fallback to last inline mathy token snippet
    """
    if text is None:
        return None

    # Common explicit pattern
    m = re.search(r"Final Answer:\s*(.+?)\s*$", text, flags=re.IGNORECASE | re.DOTALL)
    if m:
        return m.group(1)

    return None

def _load_math_from_jsonl(path: str) -> List[Dict]:
    recs = []
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            obj = json.loads(line)
            # Accept several common field names
            q = obj.get("problem") or obj.get("question") or obj.get("prompt") or ""
            a = obj.get("answer") or obj.get("final_answer") or obj.get("target") or ""
            recs.append({"qid": i, "question": str(q).strip(), "answer": str(a).strip()})
    return recs

def _load_math_from_csv(path: str) -> List[Dict]:
    df = pd.read_csv(path)
    # tolerate different header spellings
    qcol = "problem" if "problem" in df.columns else ("question" if "question" in df.columns else "prompt")
    acol = "answer"  if "answer"  in df.columns else ("final_answer" if "final_answer" in df.columns else None)
    if acol is None:
        raise ValueError("CSV must include an 'answer' (or 'final_answer') column.")
    recs = []
    for i, row in df.iterrows():
        recs.append({"qid": i, "question": str(row[qcol]).strip(), "answer": str(row[acol]).strip()})
    return recs

def _load_math_from_dir(path: str) -> List[Dict]:
    """
    Hendrycks MATH layout often has many JSON files with fields {"problem", "solution", "answer"}.
    """
    files = sorted(glob.glob(os.path.join(path, "**", "*.json"), recursive=True))
    recs = []
    for i, p in enumerate(files):
        try:
            with open(p, "r", encoding="utf-8") as f:
                obj = json.load(f)
            q = obj.get("problem") or obj.get("question") or ""
            a = obj.get("answer") or obj.get("final_answer") or ""
            recs.append({"qid": i, "question": str(q).strip(), "answer": str(a).strip()})
        except Exception:
            continue
    return recs

def load_math_dataset(paths: List[str]) -> List[Dict]:
    """
    Try JSONL, then CSV, then directory-of-JSON files. If nothing is found,
    fall back to a tiny mock set (so the pipeline still runs).
    """
    for p in paths:
        if not os.path.exists(p):
            continue
        if os.path.isdir(p):
            recs = _load_math_from_dir(p)
            if recs:
                print(f"Loaded MATH from dir {p} (N={len(recs)})")
                return recs
        else:
            ext = os.path.splitext(p)[1].lower()
            if ext == ".jsonl":
                recs = _load_math_from_jsonl(p)
                print(f"Loaded MATH from {p} (N={len(recs)})")
                return recs
            elif ext == ".csv":
                recs = _load_math_from_csv(p)
                print(f"Loaded MATH from {p} (N={len(recs)})")
                return recs

    print("MATH file not found. Using a tiny mock dataset (2 items) for a dry run.")
    return [
        {"qid": 0, "question": "Compute 7^2 - 4^2.", "answer": r"\boxed{33}"},
        {"qid": 1, "question": "What is \\frac{3}{4} + \\frac{5}{6}? Give your answer as a fraction.", "answer": r"\boxed{\frac{19}{12}}"},
    ]

# ---------- Prompting ----------
def build_math_chat_prompt(tokenizer, question: str) -> str:
    """
    Chat-formatted prompt to elicit one final line as \\boxed{...}.
    """
    messages = [
        {"role": "system",
         "content": ("You are a careful competition mathematician. Solve step by step, "
                     "then put the final result on the last line as:\n\nFinal Answer: {answer}\n")},
        {"role": "user",
         "content": (f"{question}")}
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:
            pass

    sys = messages[0]["content"]
    usr = messages[1]["content"]
    return f"<|system|>\n{sys}\n<|user|>\n{usr}\n<|assistant|>\n"

# ---------- Generation ----------
def answer_batch_unsteered(model, tokenizer, problems: List[str], max_new_tokens=512) -> List[str]:
    import torch
    outs = []
    eos_ids = [tokenizer.eos_token_id]
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if isinstance(eot_id, int) and eot_id > 0:
            eos_ids.append(eot_id)
    except Exception:
        pass

    model.eval()
    with torch.no_grad():
        for q in problems:
            prompt = build_math_chat_prompt(tokenizer, q)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            gen = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                eos_token_id=eos_ids,
                pad_token_id=tokenizer.eos_token_id,
            )
            text = tokenizer.decode(gen[0], skip_special_tokens=True)
            text = text.split("Final Answer: {answer}")[-1]
            outs.append(text.strip())
    return outs

def answer_batch_steered_dw(model, tokenizer, problems: List[str], combined_coefs, alpha: float, max_new_tokens=512) -> List[str]:
    outs = []
    for q in problems:
        prompt = build_math_chat_prompt(tokenizer, q)
        res = generate_with_head_intervention_gpu(
            model=model,
            tokenizer=tokenizer,
            prompts=[prompt],
            alpha=alpha,
            max_new_tokens=max_new_tokens,
            combined_coefs=combined_coefs,
            return_features=False,
            device=model.device if hasattr(model, "device") else None,
        )
        text = res[0]["answer"]
        text = text.split("Final Answer: {answer}")[-1]
        outs.append(text)
    return outs

# ---------- Eval ----------
def evaluate_run(gold: List[str], outputs: List[str]) -> Dict:
    preds = [extract_predicted_answer(t) for t in outputs]
    correct = [grade_answer(g, p) for g, p in zip(gold, preds)]
    for i, t in enumerate(outputs):
        print(preds[i], gold[i], correct[i])

    df = pd.DataFrame({"pred": preds, "gold": gold, "correct": correct, "output": outputs})
    df.to_csv("./results.csv")

    return {
        "n": len(gold),
        "correct": int(sum(correct)),
        "acc": float(np.mean(correct)) if len(correct) else 0.0,
        "preds": preds,
        "raw": outputs,
        "is_correct": correct,
    }

# ---------- Experiment ----------
def run_math_dw_compare(model_path: str,
                        math_records: List[Dict],
                        ks: List[int] = K_LIST,
                        alphas: List[int] = ALPHAS) -> pd.DataFrame:
    """
    Baseline vs DW-steered evaluation on MATH.
    Saves CSV and a simple matplotlib plot under ./results/<model_base>/
    """
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_tokenizer_and_model(model_path, device=device)
    base_name = model_base_name(model_path)

    out_dir = os.path.join("results", base_name)
    os.makedirs(out_dir, exist_ok=True)

    questions = [r["question"] for r in math_records]
    answers   = [r["answer"]   for r in math_records]

    # Load ridge/perf (DW_PREFIX can be 'politician' or 'dw_vote')
    print("\n[DW-steering] Loading ridge models & performance...")
    # Use feats shortcut if you have them; else k+performance path works too
    try:
        feats = load_feats(model_path=model_path, prefix=DW_PREFIX, results_dir="./results", device=device)
    except Exception:
        feats = None
    ridge_models, performance = load_ridge_models(model_path, DW_PREFIX)

    rows = []
    for k in ks:
        print(f"\nBuilding combined_coefs for k={k} ...")
        coefs = compute_combined_coefs(
            feats=feats,
            ridge_models=ridge_models,
            k=k,
            performance=performance,
            device=model.device if hasattr(model, "device") else None,
        )

        for alpha in alphas:
            print(f"[DW] alpha={alpha}, k={k} → generating ...")
            t0 = time.time()
            steered_text = answer_batch_steered_dw(
                model, tokenizer, questions,
                combined_coefs=coefs,
                alpha=alpha,
                max_new_tokens=MAX_NEW_TOKENS,
            )
            ev = evaluate_run(answers, steered_text)
            dt = time.time() - t0
            print(f"   acc={ev['acc']:.3f} ({ev['correct']}/{ev['n']})  time={dt:.1f}s")
            rows.append({
                "model": base_name,
                "mode": "steered_dw",
                "alpha": alpha,
                "k": k,
                "acc": ev["acc"],
                "correct": ev["correct"],
                "n": ev["n"],
            })

    # Baseline
    print("\n[Baseline] Generating without steering...")
    t0 = time.time()
    base_text = answer_batch_unsteered(model, tokenizer, questions, max_new_tokens=MAX_NEW_TOKENS)
    base_eval = evaluate_run(answers, base_text)
    print(f"Baseline MATH acc = {base_eval['acc']:.3f} ({base_eval['correct']}/{base_eval['n']})  time={time.time()-t0:.1f}s")

    rows.append({
        "model": base_name,
        "mode": "baseline",
        "alpha": 0,
        "k": 0,
        "acc": base_eval["acc"],
        "correct": base_eval["correct"],
        "n": base_eval["n"],
    })

    df = pd.DataFrame(rows)
    csv_path = os.path.join(out_dir, "math_dw_vote_eval.csv")
    df.to_csv(csv_path, index=False)
    print(f"\nSaved summary CSV → {csv_path}")

    # Plot (single figure, default colors)
    fig = plt.figure(figsize=(7, 4.5))
    plt.axhline(y=base_eval["acc"], linestyle="--", label=f"Baseline acc = {base_eval['acc']:.3f}")
    for k in sorted(set(df[df["mode"] == "steered_dw"]["k"])):
        sub = df[(df["mode"] == "steered_dw") & (df["k"] == k)].sort_values("alpha")
        plt.plot(sub["alpha"].tolist(), sub["acc"].tolist(), marker="o", label=f"Steered (k={k})")
    plt.xlabel("alpha (DW-vote steering strength)")
    plt.ylabel("MATH accuracy")
    plt.title(f"MATH: Baseline vs DW-steered • {base_name}")
    plt.legend()
    plt.tight_layout()
    png_path = os.path.join(out_dir, "math_dw_vote_eval.png")
    plt.savefig(png_path, dpi=150, bbox_inches="tight")
    plt.show()
    print(f"Saved plot PNG → {png_path}")

    del model
    clean_up(device=None)
    return df

def run_all_models_on_math(models: List[str] = None):
    recs = load_math_dataset(MATH_PATHS_TRY)[:LIMIT]
    used = models if models is not None else MODELS
    outs = []
    for m in used:
        print(f"\n===== Evaluating {m} on MATH =====")
        df = run_math_dw_compare(model_path=m, math_records=recs, ks=K_LIST, alphas=ALPHAS)
        outs.append(df.assign(model_path=m))
    big = pd.concat(outs, ignore_index=True)
    out_csv = "./results/math_dw_vote_eval_all_models.csv"
    os.makedirs("./results", exist_ok=True)
    big.to_csv(out_csv, index=False)
    print(f"\nWrote combined results: {out_csv}")
    return big

# ---------- Example calls ----------
# big_df = run_all_models_on_math(MODELS)
df_single = run_math_dw_compare(MODELS[1], load_math_dataset(MATH_PATHS_TRY)[:LIMIT], ks=[32, 64], alphas=[-20, -10, 10, 20])


Loaded MATH from data/math500.jsonl (N=500)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


[DW-steering] Loading ridge models & performance...

Building combined_coefs for k=32 ...
[DW] alpha=-20, k=32 → generating ...
