In [50]:
#!/usr/bin/env python
import re
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# CONFIG
# =========================
CACHE_DIR = "/mnt/raid10/ak-research-01/ak-research-01/codes/.cache"
DATA_PATH = "/mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/latentqa/3_input_ablation/generated_dataset/input_ablation_llama_instruct_train.jsonl"

# Explainer checkpoint. Alternatives used in earlier runs are kept commented out.
EXPLAINER_MODEL_PATH = (
    # "/mnt/raid10/ak-research-01/ak-research-01/"
    # "codes/steer-vector/latentqa/3_input_ablation/models/qwen"
    # "meta-llama/Meta-Llama-3.1-8B-Instruct"
    "Qwen/Qwen3-8B"
)

# True if the explainer is Qwen-based. Derived from the model path (rather than
# hard-coded) so it stays correct when EXPLAINER_MODEL_PATH is switched.
IS_QWEN = "qwen" in EXPLAINER_MODEL_PATH.lower()

VAL_SPLIT = 0.05         # fraction of the JSONL held out for evaluation
SEED = 42                # seed for the train/val split

# Generation budget per example. Gold explanations longer than this many
# tokens cannot be reproduced exactly, which caps the Exact Match metric.
MAX_NEW_TOKENS = 64
BATCH_SIZE = 4           # adjust for VRAM
BATCH_SIZE = 4           # adjust for VRAM

In [51]:

# =========================
# LOAD MODEL & TOKENIZER
# =========================

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Both the tokenizer and the model must read from CACHE_DIR; otherwise the
# tokenizer files are fetched into the default Hugging Face cache instead.
tk_kwargs: Dict[str, Any] = {"cache_dir": CACHE_DIR}
model_kwargs: Dict[str, Any] = {"cache_dir": CACHE_DIR}
if IS_QWEN:
    # Only opt into remote code for Qwen-family checkpoints.
    tk_kwargs["trust_remote_code"] = True
    model_kwargs["trust_remote_code"] = True

tokenizer = AutoTokenizer.from_pretrained(EXPLAINER_MODEL_PATH, **tk_kwargs)
if tokenizer.pad_token is None:
    # Batched generation needs a pad token; reuse EOS when the checkpoint has none.
    tokenizer.pad_token = tokenizer.eos_token
# Left padding makes every prompt end at the same column, so generated tokens
# start at one shared offset (relied upon by generate_batch).
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    EXPLAINER_MODEL_PATH,
    torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    device_map="auto" if use_cuda else None,
    **model_kwargs,
)
model.eval()


# =========================
# LOAD DATA (train/val split)
# =========================

# The JSONL file is loaded as a single split; a seeded validation split is carved out of it.
raw_ds = load_dataset(
    "json",
    data_files={"train": DATA_PATH},
)["train"]

ds = raw_ds.train_test_split(test_size=VAL_SPLIT, seed=SEED)
val_ds = ds["test"]  # evaluation split; ds["train"] is not used in this notebook

# Fail fast if any column read by the evaluation cell is missing.
REQUIRED_COLUMNS = {"explainer_prompt", "explainer_output", "has_changed", "target_answer_nohint"}
missing_columns = REQUIRED_COLUMNS - set(val_ds.column_names)
assert not missing_columns, f"validation split is missing columns: {sorted(missing_columns)}"

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [52]:

# =========================
# HELPERS
# =========================

# "Answer: X" / "Answer - X". The trailing \b stops the letter group from
# matching the first letter of an ordinary word (e.g. "Answer: Because ..."
# would otherwise parse as "B").
LETTER_RE = re.compile(r"Answer\s*[:\-]\s*([ABCD])\b", re.IGNORECASE)
# "<<<Answer: X>>>": the preferred, fully delimited answer format.
TRIPLE_ANGLE_RE = re.compile(r"<<<\s*Answer\s*:\s*([ABCD])\s*>>>", re.IGNORECASE)


@dataclass
class ParsedExplanation:
    """Fields extracted from a single free-text explainer generation.

    Produced by ``parse_explanation``; compared against gold labels in ``evaluate``.
    """

    # Binary verdict: 1 if the explanation says the answer changes, else 0.
    has_changed: int
    # Upper-cased option letter (A-D), or "" when none could be extracted.
    answer_letter: str


def parse_explanation(text: str) -> ParsedExplanation:
    """Extract the has-changed verdict and the answer letter from explainer output.

    Args:
        text: raw generated explanation.

    Returns:
        ParsedExplanation where ``has_changed`` is 1 only if a change cue
        ("change to", "would change", "will change") appears and no
        "remain(s) unchanged" phrase does, else 0; ``answer_letter`` is the
        upper-cased option letter, or "" if none was found.
    """
    lower = text.lower()

    # An explicit "remain unchanged" / "remains unchanged" overrides any change
    # cue elsewhere in the text.
    unchanged = re.search(r"\bremains?\s+unchanged\b", lower) is not None
    change_cue = any(cue in lower for cue in ("change to", "would change", "will change"))
    has_changed = int(change_cue and not unchanged)

    # Try extraction patterns from most to least specific: <<<Answer: X>>>,
    # then "Answer: X", then any standalone capital A-D.
    # NOTE(review): the bare-letter fallback can also hit the English article
    # "A" in prose, so a letter found only that way is a weak signal.
    letter = ""
    for pattern in (TRIPLE_ANGLE_RE, LETTER_RE, re.compile(r"\b([ABCD])\b")):
        match = pattern.search(text)
        if match:
            letter = match.group(1).upper()
            break

    return ParsedExplanation(has_changed=has_changed, answer_letter=letter)


def binary_f1(gold: List[int], pred: List[int]) -> float:
    """F1 score of the positive class (label 1) for two aligned label lists.

    Pairs are formed position-wise; labels other than 0/1 count toward nothing.
    Returns 0.0 whenever there is no true positive (this includes empty input).
    """
    pairs = list(zip(gold, pred))
    true_pos = pairs.count((1, 1))
    if not true_pos:
        return 0.0
    false_pos = pairs.count((0, 1))
    false_neg = pairs.count((1, 0))

    # true_pos > 0 guarantees every denominator below is strictly positive.
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return 2 * precision * recall / (precision + recall)


@torch.no_grad()
def generate_batch(prompts: List[str]) -> List[str]:
    """Greedy-decode one continuation per prompt and return only the new text.

    Args:
        prompts: prompt strings, tokenized as-is (no chat template is applied here).
            NOTE(review): for chat models such as Qwen3 the prompts presumably must
            already be chat-formatted; confirm ``explainer_prompt`` is.

    Returns:
        One stripped string per prompt, in input order; ``[]`` for empty input.
    """
    if not prompts:
        return []

    # Send inputs to the device of the model's first parameters, which is the
    # relevant device when device_map="auto" places the model.
    enc = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    # Greedy decoding. `temperature` is ignored when do_sample=False (and 0.0 is
    # rejected if sampling is ever enabled), so it is not passed at all.
    out = model.generate(
        **enc,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

    # With left padding every row's prompt occupies the same padded width, so
    # the generated tokens start at one shared offset.
    new_tokens = out[:, enc["input_ids"].shape[1]:]
    return [text.strip() for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)]

In [53]:

# =========================
# EVALUATION
# =========================
from tqdm import tqdm

def _collect_predictions() -> Dict[str, List[Any]]:
    """Generate and parse an explanation for every ``val_ds`` row.

    Returns a dict of row-aligned lists:
      - ``exact``: 1 if the stripped generation equals the stripped gold text, else 0
      - ``gold_changed`` / ``pred_changed``: binary has-changed labels
      - ``gold_letters`` / ``pred_letters``: upper-cased option letters
        (``pred_letters`` holds "" where parsing found no letter)
    """
    cols: Dict[str, List[Any]] = {
        key: [] for key in ("exact", "gold_changed", "pred_changed", "gold_letters", "pred_letters")
    }
    n_examples = len(val_ds)

    for start in tqdm(range(0, n_examples, BATCH_SIZE)):
        batch = val_ds[start:min(start + BATCH_SIZE, n_examples)]
        preds = generate_batch(batch["explainer_prompt"])
        rows = zip(
            batch["explainer_output"],
            batch["has_changed"],
            batch["target_answer_nohint"],
            preds,
        )
        for gold_text, gold_changed, gold_letter, pred_text in rows:
            parsed = parse_explanation(pred_text)
            cols["exact"].append(int(pred_text.strip() == gold_text.strip()))
            cols["pred_changed"].append(parsed.has_changed)
            cols["pred_letters"].append(parsed.answer_letter)
            cols["gold_changed"].append(int(gold_changed))
            cols["gold_letters"].append(gold_letter.strip().upper())

    return cols


def _content_match(
    pred_changed: List[int], pred_letters: List[str], gold_letters: List[str]
) -> "tuple[float, int]":
    """Letter accuracy (%) restricted to examples the explainer *predicts* as changed.

    Returns ``(percentage, n_predicted_changed)``; the percentage is 0.0 when
    no example is predicted as changed.
    """
    hits = [
        p_letter == g_letter
        for p_changed, p_letter, g_letter in zip(pred_changed, pred_letters, gold_letters)
        if p_changed == 1
    ]
    percentage = 100.0 * sum(hits) / len(hits) if hits else 0.0
    return percentage, len(hits)


def evaluate() -> None:
    """Evaluate the explainer on ``val_ds`` and print the metrics (returns None).

    Metrics:
      - Exact Match: % of generations identical (after stripping) to ``explainer_output``.
      - Has-Changed F1: binary F1 of parsed verdicts vs. the ``has_changed`` column.
      - Content Match: parsed-letter accuracy vs. ``target_answer_nohint``, computed
        only over examples the explainer predicts as changed.
    """
    N = len(val_ds)
    if N == 0:
        # Without this guard the Exact Match ratio below divides by zero.
        print("===== Input Ablation Explainer Evaluation (val split) =====")
        print("#examples (val): 0 (nothing to evaluate)")
        return

    cols = _collect_predictions()

    exact_match = sum(cols["exact"]) / N * 100.0
    has_changed_f1 = binary_f1(cols["gold_changed"], cols["pred_changed"]) * 100.0
    content_match, denom_content = _content_match(
        cols["pred_changed"], cols["pred_letters"], cols["gold_letters"]
    )

    print("===== Input Ablation Explainer Evaluation (val split) =====")
    print(f"#examples (val): {N}")
    print(f"Exact Match:     {exact_match:.2f}%")
    print(f"Has-Changed F1:  {has_changed_f1:.2f}%")
    print(f"Content Match:   {content_match:.2f}% "
          f"(on {denom_content} predicted 'changed' examples)")


In [54]:
# Full validation pass; evaluate() prints its metrics and returns None.
evaluate()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 155/155 [05:31<00:00,  2.14s/it]

===== Input Ablation Explainer Evaluation (val split) =====
#examples (val): 620
Exact Match:     0.00%
Has-Changed F1:  3.66%
Content Match:   16.67% (on 12 predicted 'changed' examples)



