In [None]:
from src.data.load_csqa import load_csqa
from src.traces_utils.store import TraceStore
import numpy as np
import pandas as pd
from pathlib import Path
import json
import torch

**Data creation using `extract_trace_csqa_gpt.py`**

```bash
python -m src.cli.extract_trace_csqa_gpt --split validation --limit 8 --batch_size 2 --max_seq_len auto
```

- `--limit 8` keeps only 8 rows, for testing purposes.

**Data** — load a few CSQA validation examples and check the prompt / choice format.

In [22]:
# Load a handful of CSQA validation rows and eyeball the prompt / answer format.
df = load_csqa("validation", limit=3)
print(df.columns)
for col in ("text", "answerKey", "csqa_choices"):
    print(df.loc[0, col])

choices0 = df.loc[0, "csqa_choices"]
# Each CSQA question should carry exactly five labelled answer options (A-E).
assert isinstance(choices0, list) and len(choices0) == 5

Index(['example_id', 'text', 'answerKey', 'correct_idx', 'csqa_choices'], dtype='object')
Q: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?
Choices:
A: bank
B: library
C: department store
D: mall
E: new york
Answer:
A
[{'label': 'A', 'text': 'bank'}, {'label': 'B', 'text': 'library'}, {'label': 'C', 'text': 'department store'}, {'label': 'D', 'text': 'mall'}, {'label': 'E', 'text': 'new york'}]


In [23]:
# Run directory produced by the extraction CLI above. Joining path components
# (instead of a raw "traces\\..." string) keeps this portable: on POSIX a
# backslash is an ordinary file-name character, not a separator.
run_dir = Path("traces") / "20260112-175054_gpt2_csqa_validation_n8"

**Parquet file check** (`meta.json` + `tokens.parquet`)

In [24]:
# Load run metadata; the `with` block closes the file handle
# (the previous json.load(open(...)) left it open).
with open(run_dir / "meta.json", encoding="utf-8") as fh:
    meta = json.load(fh)
T = meta["max_seq_len"]

df_tok = pd.read_parquet(run_dir / "tokens.parquet")
print(df_tok.columns)
print("rows:", len(df_tok), "T:", T)

# Key sanity checks: every stored input_ids row should have length T
# (= meta["max_seq_len"]), and no attention mask can mark more than T positions.
lens = df_tok["input_ids"].apply(len)
mask_sums = df_tok["attention_mask"].apply(sum)
print("input_ids lens min/max:", lens.min(), lens.max())
print("mask sums:", mask_sums.describe())
assert (lens == T).all(), f"expected every input_ids row to have length {T}"
assert (mask_sums <= T).all(), "attention_mask has more active positions than max_seq_len"

print(df_tok.loc[0, "text"])
print(df_tok.loc[0, "csqa_choices"])

Index(['example_id', 'text', 'input_ids', 'attention_mask', 'offset_mapping',
       'tokens', 'answerKey', 'csqa_choices'],
      dtype='object')
rows: 8 T: 54
input_ids lens min/max: 54 54
mask sums: count     8.000000
mean     47.625000
std       3.852179
min      43.000000
25%      44.750000
50%      47.000000
75%      49.750000
max      54.000000
Name: attention_mask, dtype: float64
Q: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?
Choices:
A: bank
B: library
C: department store
D: mall
E: new york
Answer:
[{'label': 'A', 'text': 'bank'} {'label': 'B', 'text': 'library'}
 {'label': 'C', 'text': 'department store'} {'label': 'D', 'text': 'mall'}
 {'label': 'E', 'text': 'new york'}]


**Zarr check**

In [25]:
store = TraceStore(run_dir)
print(store.meta)
print(store.arrays())

# First example's decoder self-attention map for layer 0 / head 0, expected (T, T).
eid = store.tokens.loc[0, "example_id"]
A = store.attn(eid, layer=0, head=0, side="dec", kind="self")
nan_count = np.isnan(A).sum()
print("attn shape:", A.shape, "nan:", nan_count)

{'run_id': '20260112-175054_gpt2_csqa_validation_n8', 'model': 'gpt2', 'arch': 'dec', 'dataset': 'csqa', 'split': 'validation', 'n_examples': 8, 'max_seq_len': 54, 'num_layers': 12, 'num_heads': 12, 'head_dim': 64, 'dtype': 'float16', 'capture': ['attn', 'qkv', 'hidden', 'resid'], 'has_targets': None, 'time': '2026-01-12 17:51:03'}
{'dec_self_attn': (8, 12, 12, 54, 54), 'dec_self_q': (8, 12, 12, 54, 64), 'dec_self_k': (8, 12, 12, 54, 64), 'dec_self_v': (8, 12, 12, 54, 64), 'dec_hidden': (8, 13, 54, 768), 'dec_res_embed': (8, 54, 768), 'dec_res_pre_attn': (8, 12, 54, 768), 'dec_res_post_attn': (8, 12, 54, 768), 'dec_res_post_mlp': (8, 12, 54, 768)}
attn shape: (54, 54) nan: 0


**Sanity-check: recompute attention from stored Q,K and compare to stored attn**

In [None]:
# Use the first stored example (by its example_id) for the Q/K-vs-attention check.
ex_id = store.tokens.loc[0, "example_id"]
enc = store.encodings(ex_id)
mask = np.array(enc["attention_mask"], dtype=np.int64)  # 1 = real token, 0 = padding
T = int(mask.shape[0])
n_active = int(mask.sum())

print("ex_id:", ex_id)
print("T:", T, "mask.sum:", n_active)
print("arrays:", store.arrays())

def causal_mask(T: int) -> np.ndarray:
    """Return a boolean (T, T) mask whose entry [i, j] is True iff j <= i.

    Query i may attend to itself and to every earlier key, never to later ones.
    """
    everything = np.full((T, T), True)
    return np.tril(everything)

def pad_key_mask(mask_1d: np.ndarray) -> np.ndarray:
    """Turn a 1-D attention mask into a (1, T) boolean row, False at padding keys.

    Keys whose mask value is 0 are blocked for every query; the leading singleton
    axis lets the row broadcast across all query rows of a (T, T) score matrix.
    """
    keep = mask_1d.astype(bool)
    return np.expand_dims(keep, 0)

def softmax_np(x: np.ndarray, axis=-1) -> np.ndarray:
    """Numerically stable softmax along `axis`.

    The per-slice maximum is subtracted before exponentiating so large scores
    cannot overflow; the result is unchanged because softmax is shift-invariant.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(x - peak)
    total = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / total

def attn_from_qk(q: np.ndarray, k: np.ndarray, mask_1d: np.ndarray) -> np.ndarray:
    """
    Recompute single-head self-attention probabilities from stored Q and K.

    q, k: (T, d) arrays for one layer/head; pass float64 to get a float64 result.
    mask_1d: (T,) attention mask, nonzero = real token, 0 = padding.

    Returns a (T, T) array of attention probabilities with GPT-2-style masking:
      - causal mask: query i only attends to keys j <= i
      - padding keys (mask_1d == 0) are blocked for every query
    All query rows are returned; restricting a comparison to active (non-padding)
    query rows is left to the caller. A row whose keys are all blocked comes out
    uniform, because every score in it is replaced by the same large negative value.

    Raises:
        ValueError: if q, k and mask_1d disagree on the sequence length T.
    """
    q = np.asarray(q)
    k = np.asarray(k)
    mask_1d = np.asarray(mask_1d)

    # Derive T from the inputs rather than reading a notebook-global `T`, which
    # silently mis-masks (or fails) when reused for a different example.
    T = q.shape[0]
    if k.shape[0] != T or mask_1d.shape != (T,):
        raise ValueError(
            f"shape mismatch: q {q.shape}, k {k.shape}, mask_1d {mask_1d.shape}"
        )

    d = q.shape[-1]
    scores = (q @ k.T) / np.sqrt(d)  # (T, T) scaled dot products

    # Combined mask, True where attention is allowed; (1, T) key row broadcasts over queries.
    allow = causal_mask(T) & pad_key_mask(mask_1d)
    # Large negative value instead of -inf keeps fully blocked rows finite
    # (for float32/float64 scores).
    scores = np.where(allow, scores, -1e9)

    return softmax_np(scores, axis=-1)

def compare_one(layer: int, head: int, verbose=True, *, example_id=None, attention_mask=None):
    """
    Compare stored attention for one decoder self-attention head against attention
    recomputed (via attn_from_qk) from the stored Q and K of the same head.

    Args:
        layer, head: which layer / head to check.
        verbose: if True, print a one-line report of error metrics and row sums.
        example_id: example to check; defaults to the notebook-level `ex_id`.
        attention_mask: 1-D mask (nonzero = real token) for that example. If omitted,
            the notebook-level `mask` is used for the default example; for an explicit
            `example_id` the mask is loaded from `store.encodings(example_id)`.

    Returns:
        dict with keys "layer", "head", "max_abs", "mean_abs", "rmse"; the error
        metrics are computed over active (real-token) query rows only.

    Raises:
        ValueError: if the attention mask has no active positions.
    """
    # Resolve the example and its mask explicitly instead of silently depending on
    # whatever `ex_id` / `mask` currently hold in the notebook namespace.
    eid = ex_id if example_id is None else example_id
    if attention_mask is not None:
        mask_1d = np.asarray(attention_mask, dtype=np.int64)
    elif example_id is None:
        mask_1d = mask
    else:
        mask_1d = np.array(store.encodings(eid)["attention_mask"], dtype=np.int64)

    # Stored arrays, upcast to float64 before any arithmetic.
    A_stored = store.attn(eid, layer=layer, head=head, side="dec", kind="self").astype(np.float64)  # (T,T)
    q = store.qkv(eid, which="q", layer=layer, head=head, side="dec", kind="self").astype(np.float64)  # (T,d)
    k = store.qkv(eid, which="k", layer=layer, head=head, side="dec", kind="self").astype(np.float64)  # (T,d)

    A_calc = attn_from_qk(q, k, mask_1d)

    # Compare only active query rows (real tokens); padding query rows are excluded.
    q_active = np.asarray(mask_1d).astype(bool)
    if not q_active.any():
        raise ValueError(f"example {eid!r} has no active positions in its attention mask")
    calc_rows = A_calc[q_active]
    stored_rows = A_stored[q_active]
    diff = calc_rows - stored_rows
    abs_diff = np.abs(diff)

    max_abs = float(abs_diff.max())
    mean_abs = float(abs_diff.mean())
    rmse = float(np.sqrt(np.mean(diff ** 2)))

    # Row-sum sanity: each probability row should sum to ~1.
    row_sum_calc = calc_rows.sum(axis=-1)
    row_sum_stored = stored_rows.sum(axis=-1)

    if verbose:
        print(f"[L{layer} H{head}] max_abs={max_abs:.6g} mean_abs={mean_abs:.6g} rmse={rmse:.6g} "
              f"| row_sum(calc) min/mean/max={row_sum_calc.min():.6g}/{row_sum_calc.mean():.6g}/{row_sum_calc.max():.6g} "
              f"| row_sum(stored) min/mean/max={row_sum_stored.min():.6g}/{row_sum_stored.mean():.6g}/{row_sum_stored.max():.6g}")

    return {"layer": layer, "head": head, "max_abs": max_abs, "mean_abs": mean_abs, "rmse": rmse}

# Spot-check a handful of (layer, head) pairs spread across the network.
tests = [(0, 0), (0, 5), (5, 0), (11, 11)]
results = [compare_one(L, H) for (L, H) in tests]

# Pass/fail heuristic for float16 traces. Small differences are expected from the
# float16 storage format and possible attention scaling/masking conventions.
# Acceptable if max_abs <= MAX_ABS_TOL and mean_abs <= MEAN_ABS_TOL; the
# thresholds are named here so the check and this comment cannot drift apart.
MAX_ABS_TOL = 1e-2
MEAN_ABS_TOL = 1e-3

max_abs_all = max(r["max_abs"] for r in results)
mean_abs_all = sum(r["mean_abs"] for r in results) / len(results)

print("\nSummary:")
print("max_abs_all:", max_abs_all)
print("mean_abs_all:", mean_abs_all)

if max_abs_all <= MAX_ABS_TOL and mean_abs_all <= MEAN_ABS_TOL:
    print("All good, stored attention matches attention recomputed from stored Q,K (within tolerance).")
else:
    print("Potential mismatch, differences are larger than expected.")


ex_id: 701fac8b8c04ab56c4394b2e7b2aa8df
T: 54 mask.sum: 54
arrays: {'dec_self_attn': (8, 12, 12, 54, 54), 'dec_self_q': (8, 12, 12, 54, 64), 'dec_self_k': (8, 12, 12, 54, 64), 'dec_self_v': (8, 12, 12, 54, 64), 'dec_hidden': (8, 13, 54, 768), 'dec_res_embed': (8, 54, 768), 'dec_res_pre_attn': (8, 12, 54, 768), 'dec_res_post_attn': (8, 12, 54, 768), 'dec_res_post_mlp': (8, 12, 54, 768)}
[L0 H0] max_abs=0.00012804 mean_abs=4.81548e-06 rmse=1.18393e-05 | row_sum(calc) min/mean/max=1/1/1 | row_sum(stored) min/mean/max=0.999891/1/1.00015
[L0 H5] max_abs=0.000257051 mean_abs=3.26902e-06 rmse=1.93978e-05 | row_sum(calc) min/mean/max=1/1/1 | row_sum(stored) min/mean/max=0.999753/1.00001/1.0003
[L5 H0] max_abs=0.000787193 mean_abs=6.44366e-06 rmse=3.91147e-05 | row_sum(calc) min/mean/max=1/1/1 | row_sum(stored) min/mean/max=0.999755/0.999991/1.00024
[L11 H11] max_abs=0.000344378 mean_abs=5.38771e-06 rmse=2.37184e-05 | row_sum(calc) min/mean/max=1/1/1 | row_sum(stored) min/mean/max=0.999687/0.99