In [4]:
#!/usr/bin/env python
"""
Build training data for the "Input Ablation: What parts of the input matter?"
task from Li et al. (2025) – Training Language Models to Explain Their Own
Computations.

For each MMLU question q:
    1. Sample a hint option ẋ ∈ {A,B,C,D}.
    2. Query the target LM on q (no hint) → y_nohint ∈ {A,B,C,D}.
    3. Query the target LM on q + hint(ẋ) → y_hint ∈ {A,B,C,D}.
    4. Label whether removing the hint would change the answer and what the
       new answer without the hint would be.
    5. Build an explainer prompt as in Appendix C.1 and a ground-truth
       explanation string:

       - If y_nohint == y_hint:
            "The output would remain unchanged from <<<Answer: y_nohint>>>."
       - Else:
            "The most likely output would change to <<<Answer: y_nohint>>>."

Outputs a JSONL file, one example per line.
"""

import json
import random
import re
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

CACHE_DIR = "/mnt/raid10/ak-research-01/ak-research-01/codes/.cache"

In [5]:
# -----------------------
# Model utilities
# -----------------------

def load_target_model(model_name: str, device: str = None):
    """
    Load an autoregressive LM from HuggingFace.
    Example model_name:
        - "meta-llama/Meta-Llama-3.1-8B-Instruct"
        - "Qwen/Qwen2.5-7B-Instruct"
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        cache_dir=CACHE_DIR
    )
    return model, tokenizer, device


In [6]:

def generate_answer_letter(model, tokenizer, device, prompt: str,
                           max_new_tokens: int = 8,
                           temperature: float = 0.0):
    """
    Run the target model on a prompt and extract a single answer letter A–D.

    Assumes the prompt ends with 'Answer:' (or similar). We then:
      1. Generate a short continuation,
      2. Parse the first A/B/C/D after the word 'Answer' or in the output.
    Returns None if parsing fails.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=(temperature > 0),
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

    gen_text = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                                skip_special_tokens=True)

    # Try to find pattern "Answer: X"
    m = re.search(r"Answer\s*[:\-]\s*([ABCD])", gen_text, re.IGNORECASE)
    if m:
        return m.group(1).upper()

    # Fallback: first standalone capital A–D
    m = re.search(r"\b([ABCD])\b", gen_text)
    if m:
        return m.group(1).upper()

    return None


# -----------------------
# Prompt formatting
# -----------------------

def format_mmlu_question(example):
    """
    Format a MMLU example into a standard MCQ prompt for the *target model*.

    MMLU fields (HF 'cais/mmlu'): question, answer, choices.
    choices is a list of 4 options corresponding to A,B,C,D.
    """
    letters = ["A", "B", "C", "D"]
    lines = [example["question"].strip(), ""]
    for letter, choice in zip(letters, example["choices"]):
        lines.append(f"{letter}. {choice}")
    lines.append("")
    lines.append("Answer:")
    return "\n".join(lines)


def format_hint(letter: str, style: str):
    """
    style ∈ {"llama", "qwen"} to match Appendix E.3.

    - Llama-3.1-8B-Instruct:  'Hint: B'
    - Qwen-3-8B:              'Hint: The answer is B'
    """
    if style == "qwen":
        return f"Hint: The answer is {letter}"
    else:  # default: llama-style
        return f"Hint: {letter}"


def build_explainer_prompt(question_block: str, hint_text: str):
    """
    Build the explainer prompt from Appendix C.1 (Input Ablation).
    We keep it plain-text; if you use chat templates,
    you can wrap this into your chat format externally.
    """
    return (
        "[SYSTEM]\n"
        "The following are multiple choice questions (with a correct answer). "
        "Output only the answer letter (A, B, C, or D) and nothing else, "
        "in the format Answer: x, where x is one of A, B, C, or D.\n\n"
        "[USER]\n"
        "Question:\n"
        f"{question_block}\n\n"
        "Hint:\n"
        f"{hint_text}\n\n"
        "If the hint were removed how would the assistant answer change?\n\n"
        "[ASSISTANT]\n"
    )


def build_ground_truth_explanation(y_with_hint: str, y_without_hint: str):
    """
    Build the natural-language explanation string that the explainer
    should output, following Appendix C.1.
    """
    if y_with_hint is None or y_without_hint is None:
        return None, None

    if y_with_hint == y_without_hint:
        has_changed = 0  # unchanged
        explanation = (
            f"The output would remain unchanged from <<<Answer: {y_without_hint}>>>."
        )
    else:
        has_changed = 1  # changed
        explanation = (
            f"The most likely output would change to <<<Answer: {y_without_hint}>>>."
        )
    return explanation, has_changed

In [10]:

# -----------------------
# Dataset construction
# -----------------------

def build_input_ablation_dataset(
    model_name: str,
    output_path: str,
    hint_style: str = "llama",
    split: str = "test",           # "test" or "validation"
    max_examples: int = None,      # limit for quick experiments
    balance_classes: bool = True,  # approximate 50/50 changed vs unchanged
    seed: int = 0,
):
    random.seed(seed)

    # Load target LM
    model, tokenizer, device = load_target_model(model_name)

    # Load MMLU
    # We use the light wrapper tasksource/mmlu, but you can swap to cais/mmlu.
    ds = load_dataset("cais/mmlu", "all", split=split,
                     cache_dir=CACHE_DIR)

    letters = ["A", "B", "C", "D"]
    records_changed = []
    records_unchanged = []

    print(len(ds))
    
    for example in tqdm(ds, desc="Building input-ablation data"):
        if max_examples is not None and (
            len(records_changed) + len(records_unchanged) >= max_examples
        ):
            break

        question_block = format_mmlu_question(example)
        print(example)
        print(question_block)
        assert False
        

        # Sample a random hint option
        hint_letter = random.choice(letters)
        hint_text = format_hint(hint_letter, style=hint_style)

        # 1) Prediction WITHOUT hint
        base_prompt = question_block
        y_nohint = generate_answer_letter(
            model, tokenizer, device, base_prompt
        )

        # 2) Prediction WITH hint appended to question for the *target* model
        hinted_prompt = question_block + "\n\n" + hint_text + "\n"
        y_hint = generate_answer_letter(
            model, tokenizer, device, hinted_prompt
        )

        if y_nohint not in letters or y_hint not in letters:
            continue  # skip examples where we couldn't parse an answer

        explanation, has_changed = build_ground_truth_explanation(
            y_with_hint=y_hint, y_without_hint=y_nohint
        )
        if explanation is None:
            continue

        explainer_prompt = build_explainer_prompt(question_block, hint_text)

        record = {
            "subject": example.get("subject", None),
            "question": example["question"],
            "choices": example["choices"],      # list of 4 options
            "correct_answer": example["answer"],  # 0–3 index (MMLU format)
            "correct_letter": letters[example["answer"]],
            "target_model": model_name,
            "hint_letter": hint_letter,
            "hint_text": hint_text,
            "prompt_target_nohint": base_prompt,
            "prompt_target_with_hint": hinted_prompt,
            "target_answer_nohint": y_nohint,
            "target_answer_with_hint": y_hint,
            "has_changed": has_changed,         # 1 if output changes when removing hint
            "explainer_prompt": explainer_prompt,
            "explainer_output": explanation,    # ground-truth explanation
        }

        if has_changed:
            records_changed.append(record)
        else:
            records_unchanged.append(record)

    # Optionally downsample to roughly balance changed / unchanged
    if balance_classes and records_changed and records_unchanged:
        n = min(len(records_changed), len(records_unchanged))
        random.shuffle(records_changed)
        random.shuffle(records_unchanged)
        final_records = records_changed[:n] + records_unchanged[:n]
        random.shuffle(final_records)
    else:
        final_records = records_changed + records_unchanged

    # Save to JSONL
    out_path = Path(output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        for rec in final_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    print(f"Saved {len(final_records)} examples to {out_path}")

In [11]:

build_input_ablation_dataset(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_path="/mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/latentqa/3_input_ablation/model_output/input_ablation_llama_instruct_train.jsonl",
    hint_style="llama",      # or "qwen"
    split="test",            # or "validation"
    max_examples=20000,      # None for all
    balance_classes=True,
    seed=0,
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

all/test-00000-of-00001.parquet:   0%|          | 0.00/3.50M [00:00<?, ?B/s]

all/validation-00000-of-00001.parquet:   0%|          | 0.00/408k [00:00<?, ?B/s]

all/dev-00000-of-00001.parquet:   0%|          | 0.00/76.5k [00:00<?, ?B/s]

all/auxiliary_train-00000-of-00001.parqu(…):   0%|          | 0.00/47.5M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/14042 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1531 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/285 [00:00<?, ? examples/s]

Generating auxiliary_train split:   0%|          | 0/99842 [00:00<?, ? examples/s]

AssertionError: 