logit-restricted multiple-choice evaluation

argmax(logits[" A"," B"," C"," D"," E"])

# Intervention vectors : 
    Based on :
        *correct-wrong
    
        *logit gap

        *entropy

        *tuned-lens-agreement

1. individual metrics separetly example large logit gap
2. combined vector : q=w1​⋅(logit gap)−w2​⋅entropy+w3​⋅tuned-lens agreement+...

**EDA**

In [6]:
from pathlib import Path
import numpy as np
import pandas as pd
import zarr

run_dir = Path(r"traces\20251216-193911_gpt2_csqa_validation_n4")

df = pd.read_parquet(run_dir / "tokens.parquet")

z_hidden = zarr.open(str(run_dir / "dec_hidden.zarr"), mode="r")["h"]
z_pre = zarr.open(str(run_dir / "dec_res_pre_attn.zarr"), mode="r")["x"]
z_post_mlp = zarr.open(str(run_dir / "dec_res_post_mlp.zarr"), mode="r")["x"]

print("df rows:", len(df))
print("dec_hidden:", z_hidden.shape)
print("dec_res_pre_attn:", z_pre.shape)
print("dec_res_post_mlp:", z_post_mlp.shape)


df rows: 4
dec_hidden: (4, 13, 128, 768)
dec_res_pre_attn: (4, 12, 128, 768)
dec_res_post_mlp: (4, 12, 128, 768)


In [7]:
import re

LABELS = ["A","B","C","D","E"]

def parse_csqa_from_text(txt: str):
    q = re.search(r"Q:\s*(.*?)\nChoices:", txt, flags=re.S)
    if not q:
        raise ValueError("Could not parse question")
    question = q.group(1).strip()

    choices = {}
    for lab in LABELS:
        m = re.search(rf"\n{lab}:\s*(.*)", txt)
        if m:
            choices[lab] = m.group(1).strip()
    if len(choices) < 2:
        raise ValueError(f"Parsed only {len(choices)} choices")

    return question, choices

q0, c0 = parse_csqa_from_text(df.loc[0, "text"])
print("question[0]:", q0[:120], "...")
print("choices[0]:", c0)


ValueError: Parsed only 1 choices

# Conditional Intervention

In [4]:
import os, json
from pathlib import Path

import numpy as np
import pandas as pd
import zarr

run_dir = Path(r"traces\20251216-193911_gpt2_csqa_validation_n4")

print("RUN DIR:", run_dir.resolve())
assert run_dir.exists(), f"Not found: {run_dir}"

meta_path = run_dir / "meta.json"
print("\n=== meta.json ===")
with meta_path.open("r", encoding="utf-8") as f:
    meta = json.load(f)
print(json.dumps(meta, indent=2)[:2000], "...\n")

tok_path = run_dir / "tokens.parquet"
print("=== tokens.parquet ===")
df = pd.read_parquet(tok_path)
print("rows:", len(df), "cols:", list(df.columns))
print(df.head(2)[["example_id", "text"]])

i = 0
print("\n--- example[0] tokens preview ---")
print("example_id:", df.loc[i, "example_id"])
print("text:", df.loc[i, "text"][:200], "...")
print("input_ids[:25]:", df.loc[i, "input_ids"][:25])
print("tokens[:25]:", df.loc[i, "tokens"][:25])
print("attention_mask sum:", int(np.sum(df.loc[i, "attention_mask"])))

for col in ["answerKey", "csqa_choices"]:
    if col in df.columns:
        print(f"\n--- CSQA column: {col} ---")
        print(df[col].head(2).tolist())

def describe_zarr_store(store_path: Path):
    root = zarr.open(str(store_path), mode="r")
    if isinstance(root, zarr.hierarchy.Group):
        print(f"\n[{store_path.name}] (Group) keys:", list(root.array_keys()))
        for k in root.array_keys():
            arr = root[k]
            print(f"  - {k}: shape={arr.shape}, chunks={arr.chunks}, dtype={arr.dtype}")
    else:
        arr = root
        print(f"\n[{store_path.name}] (Array) shape={arr.shape}, chunks={arr.chunks}, dtype={arr.dtype}")

def get_array_from_store(store_path: Path, key: str):
    root = zarr.open(str(store_path), mode="r")
    if isinstance(root, zarr.hierarchy.Group):
        return root[key]
    return root  # array store

zarr_paths = sorted(run_dir.glob("*.zarr"))
print("\n=== ZARR stores found ===")
for p in zarr_paths:
    print(" -", p.name)

for p in zarr_paths:
    describe_zarr_store(p)
9
def quick_stats(arr, name, slc=None):
    x = arr[:] if slc is None else arr[slc]
    x = np.asarray(x)
    finite = np.isfinite(x)
    print(f"\n{name}: shape={x.shape}, dtype={x.dtype}")
    print("  finite%:", float(finite.mean()) * 100.0)
    if finite.any():
        xf = x[finite]
        print("  min/mean/max:", float(xf.min()), float(xf.mean()), float(xf.max()))
        print("  abs-mean:", float(np.abs(xf).mean()))
    else:
        print("  all non-finite (bad)")


if (run_dir / "dec_hidden.zarr").exists():
    arr = get_array_from_store(run_dir / "dec_hidden.zarr", "h")
    quick_stats(arr, "dec_hidden (sample ex0, layer0)")
    # ex0, layer0, first 5 tokens, first 8 dims
    print("  slice [0,0,:5,:8] =", np.asarray(arr[0,0,:5,:8]))

# 2) attentions
if (run_dir / "dec_self_attn.zarr").exists():
    arr = get_array_from_store(run_dir / "dec_self_attn.zarr", "attn")
    x = np.asarray(arr[0,0,0,:5,:5])  # ex0, layer0, head0, first 5x5
    print("\nattn small slice [0,0,0,:5,:5]:\n", x)
    row_sums = x.sum(axis=-1)
    print("row_sums (should be ~1.0 if not masked in this slice):", row_sums)

# 3) QKV
if (run_dir / "dec_self_qkv.zarr").exists():
    q = get_array_from_store(run_dir / "dec_self_qkv.zarr", "q")
    k = get_array_from_store(run_dir / "dec_self_qkv.zarr", "k")
    v = get_array_from_store(run_dir / "dec_self_qkv.zarr", "v")
    quick_stats(q, "q (ex0, layer0, head0, first 5 tokens, first 8 dims)", slc=(0,0,0,slice(0,5),slice(0,8)))
    quick_stats(k, "k (ex0, layer0, head0, first 5 tokens, first 8 dims)", slc=(0,0,0,slice(0,5),slice(0,8)))
    quick_stats(v, "v (ex0, layer0, head0, first 5 tokens, first 8 dims)", slc=(0,0,0,slice(0,5),slice(0,8)))

# 4) Residual checkpoints
for nm, key in [
    ("dec_res_embed.zarr", "x"),
    ("dec_res_pre_attn.zarr", "x"),
    ("dec_res_post_attn.zarr", "x"),
    ("dec_res_post_mlp.zarr", "x"),
]:
    p = run_dir / nm
    if p.exists():
        arr = get_array_from_store(p, key)
        print(f"\n=== {nm} tiny slice ===")
        print("shape:", arr.shape, "dtype:", arr.dtype)
        print("ex0 slice:", np.asarray(arr[0, :5, :8]) if arr.ndim == 3 else np.asarray(arr[0, 0, :5, :8]))

RUN DIR: C:\Users\mikol\OneDrive\Pulpit\Praca Magisterska\Transformer-Decision-Traces\traces\20251216-193911_gpt2_csqa_validation_n4

=== meta.json ===
{
  "run_id": "20251216-193911_gpt2_csqa_validation_n4",
  "model": "gpt2",
  "arch": "dec",
  "dataset": "csqa",
  "split": "validation",
  "n_examples": 4,
  "max_seq_len": 128,
  "num_layers": 12,
  "num_heads": 12,
  "head_dim": 64,
  "layers_stored": {
    "enc": [],
    "dec": [
      0,
      1,
      2,
      3,
      4,
      5,
      6,
      7,
      8,
      9,
      10,
      11
    ]
  },
  "heads_stored": {
    "enc": [],
    "dec": [
      0,
      1,
      2,
      3,
      4,
      5,
      6,
      7,
      8,
      9,
      10,
      11
    ]
  },
  "dtype": "float16",
  "capture": [
    "attn",
    "qkv",
    "hidden",
    "resid"
  ],
  "has_targets": null,
  "time": "2025-12-16 19:39:17"
} ...

=== tokens.parquet ===
rows: 4 cols: ['example_id', 'text', 'input_ids', 'attention_mask', 'offset_mapping', 'tokens', 'a

In [None]:
import os, json, re
from pathlib import Path

import numpy as np
import pandas as pd
import zarr

RUN_DIR = r"C:\Users\mikol\OneDrive\Pulpit\Praca Magisterska\Transformer-Decision-Traces\traces\20251216-193911_gpt2_csqa_validation_n4"

run = Path(RUN_DIR)
meta = json.loads((run / "meta.json").read_text(encoding="utf-8"))
df = pd.read_parquet(run / "tokens.parquet")

print("RUN DIR:", run)
print("\n=== meta.json (key fields) ===")
print({k: meta[k] for k in ["run_id","model","arch","dataset","split","n_examples","max_seq_len","num_layers","num_heads","head_dim","dtype","capture"]})
print("\n=== tokens.parquet ===")
print("rows:", len(df), "cols:", list(df.columns))

# OPEN ZARR STORES
attn = zarr.open(str(run / "dec_self_attn.zarr"), mode="r")["attn"]      # (N,L,H,T,T)
qkv  = zarr.open(str(run / "dec_self_qkv.zarr"),  mode="r")
Q = qkv["q"]; K = qkv["k"]; V = qkv["v"]                                 # (N,L,H,T,d)
pre  = zarr.open(str(run / "dec_res_pre_attn.zarr"),  mode="r")["x"]      # (N,L,T,D)
post = zarr.open(str(run / "dec_res_post_attn.zarr"), mode="r")["x"]      # (N,L,T,D)

print("\n=== ZARR SHAPES ===")
print("attn:", attn.shape, attn.dtype)
print("Q:", Q.shape, Q.dtype, "| K:", K.shape, "| V:", V.shape)
print("pre:", pre.shape, pre.dtype, "| post:", post.shape, post.dtype)

def softmax(x, axis=-1):
    x = x.astype(np.float64)  # stability
    x = x - np.max(x, axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis=axis, keepdims=True)

def find_choice_char_spans(text: str):
    txt = text.replace("\\n", "\n").replace("\r\n", "\n")

    pat = re.compile(r"(?:^|\n)\s*([A-E])\s*[:\)\.]\s*")
    matches = list(pat.finditer(txt))
    if len(matches) < 2:
        return {}

    spans = {}
    for j, m in enumerate(matches):
        L = m.group(1)
        start = m.end()
        end = matches[j + 1].start() if j + 1 < len(matches) else len(txt)
        spans[L] = (start, end)
    return spans

def charspan_to_token_idxs(offset_mapping, char_span):
    a, b = char_span
    idxs = []
    for t, (s, e) in enumerate(offset_mapping):
        if s == 0 and e == 0:
            continue
        if e <= a:
            continue
        if s >= b:
            break
        if max(s, a) < min(e, b):
            idxs.append(t)
    return idxs

LETTER_IDS = {
    "A": [32, 317],
    "B": [33, 347],
    "C": [34, 327],
    "D": [35, 360],
    "E": [36, 412],
}
print("\nLETTER_IDS:", LETTER_IDS)

def predict_letter_from_logits(last_logits):
    p = softmax(last_logits, axis=-1)
    letter_probs = {}
    for L, ids in LETTER_IDS.items():
        ids = [i for i in ids if i < p.shape[0]]
        letter_probs[L] = float(np.max(p[ids])) if ids else float("nan")
    pred = max(letter_probs, key=lambda k: letter_probs[k])
    return pred, letter_probs[pred], letter_probs


print("\n=== TEXT DEBUG (repr preview) ===")
print(repr(df.loc[0, "text"][:400]))
print("\n=== TEXT DEBUG (tail) ===")
print(df.loc[0, "text"][-200:])

rows = []
for i in range(len(df)):
    text = df.loc[i, "text"]
    gold = df.loc[i, "answerKey"] if "answerKey" in df.columns else None

    input_ids = np.array(df.loc[i, "input_ids"], dtype=np.int64)
    attn_mask = np.array(df.loc[i, "attention_mask"], dtype=np.int64)
    offsets   = df.loc[i, "offset_mapping"]

    # last non-pad position (for left padding, this is last token where mask=1)
    last = int(np.where(attn_mask == 1)[0][-1])

    # option spans 
    spans = find_choice_char_spans(text)
    token_spans = {L: charspan_to_token_idxs(offsets, sp) for L, sp in spans.items()}
    choices_parsed = sum(1 for L in "ABCDE" if L in token_spans and len(token_spans[L]) > 0)

    # attention tensor: mean over layers/heads at "query token = last"
    # attn[i] shape (L,H,T,T) ; take [:,:,last,:] -> (L,H,T)
    A = np.asarray(attn[i, :, :, last, :], dtype=np.float32)
    A_mean = A.mean(axis=(0, 1))  # (T,)

    attn_mass = {}
    for L in "ABCDE":
        toks = token_spans.get(L, [])
        attn_mass[L] = float(A_mean[toks].sum()) if toks else float("nan")

    # best attention-picked letter
    attn_pick = None
    if any(np.isfinite(attn_mass[L]) for L in "ABCDE"):
        attn_pick = max("ABCDE", key=lambda L: (attn_mass[L] if np.isfinite(attn_mass[L]) else -1e9))

    # residual norms (cast to float32 to avoid overflow)
    pre_i  = np.asarray(pre[i], dtype=np.float32)   # (L,T,D)
    post_i = np.asarray(post[i], dtype=np.float32)
    pre_norm_meanL  = float(np.linalg.norm(pre_i[:, last, :],  axis=-1).mean())
    post_norm_meanL = float(np.linalg.norm(post_i[:, last, :], axis=-1).mean())

    rows.append({
        "i": i,
        "example_id": df.loc[i, "example_id"],
        "gold": gold,
        "attn_pick": attn_pick,
        "attn_gold": attn_mass.get(gold, float("nan")) if gold else float("nan"),
        "attn_pred": attn_mass.get(attn_pick, float("nan")) if attn_pick else float("nan"),
        "pre_norm_meanL": pre_norm_meanL,
        "post_norm_meanL": post_norm_meanL,
        "choices_parsed": choices_parsed,
    })

res = pd.DataFrame(rows)

print("\n=== ATTENTION-TO-OPTIONS SUMMARY ===")
print(res)

# diagnostics
print("\nchoices_parsed distribution:")
print(res["choices_parsed"].value_counts(dropna=False))

if res["choices_parsed"].max() == 0:
    print("\n[RED FLAG] choices_parsed is 0 for all examples e.g. prompt format doesn't match the parser.")

# difference metric 
valid = res[np.isfinite(res["attn_gold"]) & np.isfinite(res["attn_pred"])]
if len(valid) > 0:
    diffs = (valid["attn_pred"] - valid["attn_gold"]).to_list()
    print("\n(attn_pred - attn_gold) per valid example:")
    print(diffs)
else:
    print("\n(attn_pred - attn_gold) per example:")
    print([float("nan")] * len(res))

RUN DIR: C:\Users\mikol\OneDrive\Pulpit\Praca Magisterska\Transformer-Decision-Traces\traces\20251216-193911_gpt2_csqa_validation_n4

=== meta.json (key fields) ===
{'run_id': '20251216-193911_gpt2_csqa_validation_n4', 'model': 'gpt2', 'arch': 'dec', 'dataset': 'csqa', 'split': 'validation', 'n_examples': 4, 'max_seq_len': 128, 'num_layers': 12, 'num_heads': 12, 'head_dim': 64, 'dtype': 'float16', 'capture': ['attn', 'qkv', 'hidden', 'resid']}

=== tokens.parquet ===
rows: 4 cols: ['example_id', 'text', 'input_ids', 'attention_mask', 'offset_mapping', 'tokens', 'answerKey', 'csqa_choices']

=== ZARR SHAPES ===
attn: (4, 12, 12, 128, 128) float16
Q: (4, 12, 12, 128, 64) float16 | K: (4, 12, 12, 128, 64) | V: (4, 12, 12, 128, 64)
pre: (4, 12, 128, 768) float16 | post: (4, 12, 128, 768) float16

LETTER_IDS: {'A': [32, 317], 'B': [33, 347], 'C': [34, 327], 'D': [35, 360], 'E': [36, 412]}

=== TEXT DEBUG (repr preview) ===
'Q: A revolving door is convenient for two direction travel, but it 

In [None]:
from src.data.load_csqa import load_csqa

df = load_csqa(split="validation", limit=2)

print("Columns:", df.columns.tolist())
print("\n--- Prompt preview ---")
print(df.loc[0, "text"][:400])

print("\nanswerKey:", df.loc[0, "answerKey"], "correct_idx:", df.loc[0, "correct_idx"])
print("csqa_choices type:", type(df.loc[0, "csqa_choices"]))
print("csqa_choices len:", len(df.loc[0, "csqa_choices"]))
print("csqa_choices[0]:", df.loc[0, "csqa_choices"][0])

Columns: ['example_id', 'text', 'answerKey', 'correct_idx', 'csqa_choices']

--- Prompt preview ---
Q: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?
Choices:
A: bank
B: library
C: department store
D: mall
E: new york
Answer:

answerKey: A correct_idx: 0
csqa_choices type: <class 'list'>
csqa_choices len: 5
csqa_choices[0]: {'label': 'A', 'text': 'bank'}


In [8]:
from datasets import load_dataset
import json

ds = load_dataset("commonsense_qa", split="validation")

ex = ds[0]
print("hf keys:", ex.keys())
print("id:", ex["id"])
print("question:", ex["question"][:120])
print("choices labels:", ex["choices"]["label"])
print("choices texts head:", [t[:25] for t in ex["choices"]["text"]])
print("answerKey:", ex["answerKey"])


  from .autonotebook import tqdm as notebook_tqdm


hf keys: dict_keys(['id', 'question', 'question_concept', 'choices', 'answerKey'])
id: 1afa02df02c908a558b4036e80242fac
question: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?
choices labels: ['A', 'B', 'C', 'D', 'E']
choices texts head: ['bank', 'library', 'department store', 'mall', 'new york']
answerKey: A
