# Urn Task: Bayesian Cognition in LLMs (TransformerLens)

This notebook tests whether a model behaves like a Bayesian observer on a two-urn sequential inference task.

**Key ideas**
- Behavioral readout: infer belief from next-token logits for a forced-choice answer (A vs B).
- Martingale drift: under an i.i.d. (exchangeable) model the Bayes posterior depends only on the X/Y counts, so predictions should be invariant to permuting the observations; deviations indicate order sensitivity.
- Head ablation: identify and remove heads that drive drift.
- Linear probes: decode Bayes posterior from internal states.

**Roadmap**
1) Single-episode demo (behavior) -> 2) Batch metrics
3) Drift test -> 4) Head ablation scan -> 5) Ablated behavior
6) Probes -> 7) Probe vs baseline vs ablated


In [None]:
# Cell 1 - Install dependencies (Colab-safe)
# If running locally and already installed, you can skip this cell.

# !pip -q install transformer_lens circuitsvis einops pandas numpy matplotlib tqdm scikit-learn

# Plotly renderer for interactive outputs (best-effort: plotly may not be installed)
try:
    import plotly.io as pio
    pio.renderers.default = 'notebook_connected'
except Exception as err:
    # Not fatal: the notebook's matplotlib figures do not depend on plotly.
    print('Plotly renderer not set:', err)


In [None]:
# Cell 2 - Imports + global seeds
import os
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import circuitsvis as cv
from IPython.display import display

# Reproducibility: seed every RNG the notebook draws from
# (Python `random` for episode sampling, NumPy for shuffles, torch for the model).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Determinism (best-effort): ask cuDNN for deterministic kernels.
# A failure here is reported instead of being swallowed silently, so the reader
# knows the run may not be bit-reproducible.
try:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
except Exception as err:
    print('cuDNN determinism flags not set:', err)

print('Seeds set to', SEED)


In [None]:
# Cell 3 - Global configuration (user-editable)

# Model to load and where/how to run it (half precision only on GPU).
MODEL_ID = 'meta-llama/Llama-3.2-3B'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16 if DEVICE == 'cuda' else torch.float32

# Task config: P(X | urn) for each urn and the prior probability of urn A.
urn_A_pX = 0.75
urn_B_pX = 0.25
prior_P_A = 0.5

# Episode lengths (observations per episode) and number of batch episodes.
seq_len_single = 20
seq_len_batch = 20
n_batch_episodes = 200

# Key into the TEMPLATES dict defined in the prompt-design cell.
prompt_template_version = 'v1'  # v1 | v2 | v3

# Readout config (forced choice): the model's belief is read from the
# next-token logits of these two answer tokens (leading space included).
ANSWER_TOKEN_A = ' A'
ANSWER_TOKEN_B = ' B'
ANSWER_MEANS = {'A': 'X', 'B': 'Y'}

# Drift config: a fixed multiset of observations (nX X's out of total_N)
# is shuffled n_permutations times; drift is the spread of the predictions.
drift_counts_total_N = 12
drift_counts_nX = 6
n_permutations = 64
drift_metric = 'std'  # 'std' or 'max-min'

# Head search config
head_search_scope = 'all'  # 'all' or 'layers_subset'
layers_subset = list(range(0, 4))  # used only if scope == 'layers_subset'
top_k_heads_to_report = 30
top_k_heads_to_ablate_default = 5

# Ablation config
heads_to_ablate = []  # fill after seeing ranking; list of (layer, head) pairs
ablate_mode = 'zero'  # 'zero' or 'scale'
ablate_scale = 0.0  # multiplier applied to the head output when ablate_mode == 'scale'

# Probe config
probe_target = 'posterior_logodds'  # 'posterior_logodds' or 'predictive_logodds'
probe_act_name = 'resid_post'  # 'resid_post' or 'resid_pre'
probe_position = 'answer_pos'  # last position
probe_train_size = 200  # number of episodes generated for the probe dataset
probe_test_size = 100
probe_epochs = 1  # unused for ridge; kept for API parity
probe_lr = 1e-3
probe_weight_decay = 0.0
probe_batch_size = 32
probe_type = 'ridge'  # 'ridge' or 'logistic' (ridge recommended)

# Output: CSV summaries and probe weights are written here.
output_dir = 'results_urn_task'
os.makedirs(output_dir, exist_ok=True)

print('MODEL_ID:', MODEL_ID)
print('DEVICE:', DEVICE, 'DTYPE:', DTYPE)
print('Task: A_pX=%.2f B_pX=%.2f prior=%.2f' % (urn_A_pX, urn_B_pX, prior_P_A))
print('Prompt template:', prompt_template_version)
print('Output dir:', output_dir)


In [None]:
# Cell 4 - Load model with TransformerLens

# This notebook only runs inference (the probes are fit with scikit-learn), so
# autograd is switched off globally. Without this, every forward pass issued
# outside a `torch.no_grad()` scope keeps its whole activation graph alive.
torch.set_grad_enabled(False)

model = HookedTransformer.from_pretrained(MODEL_ID, device=DEVICE, dtype=DTYPE)
model.eval()

print('Loaded', MODEL_ID)
print('n_layers:', model.cfg.n_layers, 'n_heads:', model.cfg.n_heads, 'd_model:', model.cfg.d_model)
print('tokenizer:', type(model.tokenizer))


In [None]:
# Cell 5 - Tokenization sanity checks (strict)


def assert_single_token(text):
    """Return the id of `text` if the tokenizer encodes it as exactly one token.

    Raises:
        RuntimeError: if `text` encodes to zero or several tokens, because the
            forced-choice readout needs a single logit per answer.
    """
    token_ids = model.tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) == 1:
        return token_ids[0]
    raise RuntimeError(f'Tokenization error: {text!r} is not a single token (ids={token_ids}). Choose another label.')

# Resolve the two answer tokens (raises if either is not a single token)
A_id = assert_single_token(ANSWER_TOKEN_A)
B_id = assert_single_token(ANSWER_TOKEN_B)

print('Chosen answer tokens:')
for letter, tok_text, tok_id in (('A', ANSWER_TOKEN_A, A_id), ('B', ANSWER_TOKEN_B, B_id)):
    print(letter + ' token:', repr(tok_text), 'id=', tok_id, 'decoded=', model.tokenizer.decode([tok_id]))

# Informational: how the bare letters (no leading space) tokenize
for bare in ('A', 'B'):
    print('Tokenization for', repr(bare), '->', model.tokenizer.encode(bare, add_special_tokens=False))

# Observation symbol tokens (X/Y), with and without a leading space
obs_tokens = ['X','Y',' X',' Y']
print('Observation tokenization lengths:')
for symbol in obs_tokens:
    symbol_ids = model.tokenizer.encode(symbol, add_special_tokens=False)
    print(repr(symbol), 'len=', len(symbol_ids), 'ids=', symbol_ids)


In [None]:
# Cell 6 - Urn task math engine (ground truth Bayes)


def bayes_update_posterior(prior_P_A, pX_A, pX_B, obs_seq):
    """Return P(urn A | obs_seq) by applying Bayes' rule one observation at a time.

    Args:
        prior_P_A: prior probability of urn A.
        pX_A, pX_B: probability that urn A / urn B emits 'X'.
        obs_seq: iterable of observations; 'X' is treated as X, anything else as Y.

    Returns:
        The posterior probability of urn A (the prior if `obs_seq` is empty).
    """
    belief_A = prior_P_A
    for symbol in obs_seq:
        saw_X = symbol == 'X'
        like_A = pX_A if saw_X else 1 - pX_A
        like_B = pX_B if saw_X else 1 - pX_B
        joint_A = like_A * belief_A
        joint_B = like_B * (1 - belief_A)
        evidence = joint_A + joint_B
        if evidence > 0:
            belief_A = joint_A / evidence
        else:
            # Observation impossible under both urns: fall back to indifference.
            belief_A = 0.5
    return belief_A


def bayes_predictive(posterior_P_A, pX_A, pX_B):
    """Return P(next observation is X): the urn-wise P(X) mixed by the posterior over urns."""
    posterior_P_B = 1 - posterior_P_A
    return posterior_P_A * pX_A + posterior_P_B * pX_B


def logodds(p):
    """Return log(p / (1 - p)) after clamping p into [1e-9, 1 - 1e-9] to keep it finite."""
    clipped = max(1e-9, min(1 - 1e-9, p))
    odds = clipped / (1 - clipped)
    return math.log(odds)


def compute_truth_table(obs_seq, prior_P_A, pX_A, pX_B, laplace=0.0):
    """Build the per-step ground-truth table for one observation sequence.

    Args:
        obs_seq: sequence of observations ('X' or 'Y').
        prior_P_A: prior probability of urn A.
        pX_A, pX_B: probability that urn A / urn B emits 'X'.
        laplace: additive smoothing for the frequentist estimate (0 = raw frequency).

    Returns:
        DataFrame with one row per step t (1-based) holding the observation, the
        Bayes posterior P(A), the Bayes predictive P(next=X) and its log-odds,
        and the (optionally smoothed) empirical frequency of X so far.
        For an empty `obs_seq` the frame has the same columns and no rows.
    """
    columns = ['step', 'obs', 'posterior_P_A', 'bayes_P_next_X', 'bayes_logodds', 'freq_Phat_X']
    rows = []
    pA = prior_P_A
    nX = 0
    for t, obs in enumerate(obs_seq, start=1):
        if obs == 'X':
            nX += 1
        # Incremental update: feeding the running posterior back in as the prior
        # performs exactly the same arithmetic as re-running the whole prefix,
        # but in O(1) per step instead of O(t).
        pA = bayes_update_posterior(pA, pX_A, pX_B, [obs])
        p_next_X = bayes_predictive(pA, pX_A, pX_B)
        denom = t + 2 * laplace if laplace > 0 else t
        p_hat = (nX + laplace) / denom if denom > 0 else 0.5
        rows.append({
            'step': t,
            'obs': obs,
            'posterior_P_A': pA,
            'bayes_P_next_X': p_next_X,
            'bayes_logodds': logodds(p_next_X),
            'freq_Phat_X': p_hat,
        })
    return pd.DataFrame(rows, columns=columns)

# Smoke test of the math engine on a three-observation sequence
example_truth = compute_truth_table(['X','Y','X'], prior_P_A, urn_A_pX, urn_B_pX)
print('Example truth table:')
print(example_truth)


In [None]:
# Cell 7 - Prompt design (behavioral, not calculation)

# Prompt templates, keyed by version id. Each template is filled with
# {pA}, {pB} (P(X) under each urn), {prior} (P(A)) and {seq} (the observations),
# and ends with 'Answer:' so that the next token is the forced-choice label.
# Line breaks are written as explicit '\n' escapes: a single-quoted literal
# cannot span physical lines.
TEMPLATES = {
    'v1': (
        'Two urns produce X/Y. Urn A makes X with probability {pA:.2f}. '
        'Urn B makes X with probability {pB:.2f}. Prior P(A)={prior:.2f}. '
        'Observed sequence: {seq}.\n'
        'Next output is more likely:\n'
        'A) X\n'
        'B) Y\n'
        'Answer:'
    ),
    'v2': (
        'We have two generators. A: P(X)={pA:.2f}. B: P(X)={pB:.2f}. '
        'Assume i.i.d. draws. Prior P(A)={prior:.2f}.\n'
        'Sequence so far: {seq}.\n'
        'Next is more likely?\n'
        'A) X\n'
        'B) Y\n'
        'Answer:'
    ),
    'v3': (
        'Urn A: {pA:.2f} of X. Urn B: {pB:.2f} of X. Prior(A)={prior:.2f}.\n'
        'Evidence: {seq}.\n'
        'Choose the more likely next symbol.\n'
        'A) X\n'
        'B) Y\n'
        'Answer:'
    ),
}


def build_prompt(obs_seq, pA, pB, prior, template_id='v1'):
    """Fill the chosen prompt template with the task parameters and observations.

    Observations are joined with single spaces; an empty sequence is rendered
    as the placeholder '(none yet)'.
    """
    if len(obs_seq) > 0:
        seq_str = ' '.join(obs_seq)
    else:
        seq_str = '(none yet)'
    template = TEMPLATES[template_id]
    return template.format(pA=pA, pB=pB, prior=prior, seq=seq_str)

# Show one rendered prompt so the template can be eyeballed
demo_prompt = build_prompt(['X','Y','X'], urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
print(demo_prompt)


In [None]:
# Cell 8 - Extract model belief from logits (forced-choice readout)

@torch.no_grad()
def get_ab_logprobs(prompt_str, token_id_A, token_id_B):
    """Read the model's forced-choice belief from the next-token logits.

    Args:
        prompt_str: full prompt; the readout is taken at its last position.
        token_id_A, token_id_B: vocabulary ids of the two answer tokens.

    Returns:
        dict with the raw logits ('logit_A', 'logit_B'), the two-way normalised
        probabilities ('pA_norm', 'pB_norm') and their log-odds ('logodds_AB').

    The two-way softmax is computed directly from the logit difference:
    p_A / (p_A + p_B) == sigmoid(logit_A - logit_B). This avoids going through
    the full-vocabulary softmax, where either probability can underflow to 0 and
    the log-odds would otherwise be silently reported as 0.0.
    """
    logits = model(prompt_str)
    last = logits[0, -1, :].float()
    logit_A = float(last[token_id_A].item())
    logit_B = float(last[token_id_B].item())

    logodds_AB = logit_A - logit_B
    # Numerically stable sigmoid: only ever exponentiate a non-positive number.
    decay = math.exp(-abs(logodds_AB))
    p_major = 1.0 / (1.0 + decay)
    p_minor = decay / (1.0 + decay)
    if logodds_AB >= 0:
        pA_norm, pB_norm = p_major, p_minor
    else:
        pA_norm, pB_norm = p_minor, p_major

    return {
        'logit_A': logit_A,
        'logit_B': logit_B,
        'pA_norm': pA_norm,
        'pB_norm': pB_norm,
        'logodds_AB': logodds_AB,
    }

# Smoke test of the readout on a short prompt
prompt = build_prompt(['X','Y','X'], urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
demo_readout = get_ab_logprobs(prompt, A_id, B_id)
print(demo_readout)


In [None]:
# Cell 9 - Single-episode demo (dynamic sequence)

# Sample one episode: pick the hidden urn from the prior, then draw observations.
hidden_urn = 'A' if random.random() < prior_P_A else 'B'
true_pX = urn_A_pX if hidden_urn == 'A' else urn_B_pX

obs_seq = ['X' if random.random() < true_pX else 'Y' for _ in range(seq_len_single)]

# Ground truth for every prefix in one pass: row t-1 of the full table is the
# state after the first t observations, so there is no need to rebuild the
# table for each prefix.
truth = compute_truth_table(obs_seq, prior_P_A, urn_A_pX, urn_B_pX)

rows = []
for t in range(1, seq_len_single + 1):
    obs_prefix = obs_seq[:t]
    prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
    model_out = get_ab_logprobs(prompt, A_id, B_id)
    truth_t = truth.iloc[t - 1]

    rows.append({
        'step': t,
        'obs_t': obs_prefix[-1],
        'true_hidden_urn': hidden_urn,
        'bayes_P_next_X': float(truth_t['bayes_P_next_X']),
        'model_P_next_X': model_out['pA_norm'],  # answer 'A' means "next is X"
        'freq_Phat_X': float(truth_t['freq_Phat_X']),
        'bayes_logodds': float(truth_t['bayes_logodds']),
        'model_logodds': model_out['logodds_AB'],
    })

single_df = pd.DataFrame(rows)
display(single_df.head())

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(single_df.step, single_df.bayes_P_next_X, label='Bayes', marker='o')
ax.plot(single_df.step, single_df.model_P_next_X, label='Model', marker='x')
ax.plot(single_df.step, single_df.freq_Phat_X, label='Frequentist', linestyle='--')

# Annotate each step with the symbol that was observed (Y in red).
for step, tok in zip(single_df.step, single_df.obs_t):
    ax.annotate(tok, (step, 0.45), fontsize=10, ha='center',
                color='black' if tok == 'X' else 'red')

ax.set_title('Single-episode tracking')
ax.set_xlabel('step')
ax.set_ylabel('P(next is X)')
ax.set_ylim(0.0, 1.0)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

mae = float((single_df.model_P_next_X - single_df.bayes_P_next_X).abs().mean())
bias = float((single_df.model_P_next_X - single_df.bayes_P_next_X).mean())
corr = float(np.corrcoef(single_df.model_P_next_X, single_df.bayes_P_next_X)[0, 1])
print('MAE:', mae, 'Bias:', bias, 'Corr:', corr)


In [None]:
# Cell 10 - Batch evaluation (no ablation)

rows = []
for ep in tqdm(range(n_batch_episodes), desc='episodes'):
    hidden_urn = 'A' if random.random() < prior_P_A else 'B'
    true_pX = urn_A_pX if hidden_urn == 'A' else urn_B_pX
    # Use a batch-specific name: `obs_seq` holds the single-episode sequence
    # that later single-episode cells read, and must not be overwritten here.
    obs_seq_batch = ['X' if random.random() < true_pX else 'Y' for _ in range(seq_len_batch)]

    # One truth table per episode; row t-1 is the state after t observations.
    truth = compute_truth_table(obs_seq_batch, prior_P_A, urn_A_pX, urn_B_pX)

    for t in range(1, seq_len_batch + 1):
        obs_prefix = obs_seq_batch[:t]
        prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
        model_out = get_ab_logprobs(prompt, A_id, B_id)

        rows.append({
            'ep': ep,
            't': t,
            'bayes_P_next_X': float(truth.iloc[t - 1]['bayes_P_next_X']),
            'model_P_next_X': model_out['pA_norm'],
        })

all_df = pd.DataFrame(rows)

# Agreement between the model readout and the Bayes predictive over all steps.
mae = float((all_df.model_P_next_X - all_df.bayes_P_next_X).abs().mean())
bias = float((all_df.model_P_next_X - all_df.bayes_P_next_X).mean())
corr = float(np.corrcoef(all_df.model_P_next_X, all_df.bayes_P_next_X)[0, 1])

summary = pd.DataFrame([{'MAE': mae, 'Bias': bias, 'Corr': corr, 'n': len(all_df)}])
display(summary)
summary.to_csv(os.path.join(output_dir, 'batch_summary_baseline.csv'), index=False)


In [None]:
# Cell 11 - Martingale drift setup (permutation invariance test)

# Fixed multiset of observations; only the order changes between permutations.
n_Y = drift_counts_total_N - drift_counts_nX
base_seq = ['X'] * drift_counts_nX + ['Y'] * n_Y

perm_seqs = []
perm_prompts = []
rng = np.random.default_rng(SEED)

for _ in range(n_permutations):
    seq = list(base_seq)
    rng.shuffle(seq)
    perm_seqs.append(seq)
    perm_prompts.append(build_prompt(seq, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version))

# Model readout for every permutation (one forward pass per prompt).
preds = []
for prompt in tqdm(perm_prompts, desc='perm eval'):
    out = get_ab_logprobs(prompt, A_id, B_id)
    preds.append(out['pA_norm'])

perm_df = pd.DataFrame({
    'seq': [' '.join(s) for s in perm_seqs],
    'model_P_next_X': preds,
})

# Drift = spread of the predictions across orderings of the same evidence.
perm_preds = perm_df.model_P_next_X
drift_val = float(np.std(perm_preds)) if drift_metric == 'std' else float(perm_preds.max() - perm_preds.min())

# The Bayes answer depends only on the counts, so any ordering gives the same value.
posterior_P_A = bayes_update_posterior(prior_P_A, urn_A_pX, urn_B_pX, base_seq)
bayes_pred = bayes_predictive(posterior_P_A, urn_A_pX, urn_B_pX)

print('Drift metric:', drift_metric, 'value:', drift_val)
print('Bayes predictive P(next=X):', bayes_pred)

perm_df_sorted = perm_df.sort_values('model_P_next_X')
print('Lowest predictions:')
display(perm_df_sorted.head(3))
print('Highest predictions:')
display(perm_df_sorted.tail(3))


In [None]:
# Cell 12 - Head ablation utilities (TransformerLens hooks)

from collections import defaultdict


def make_head_ablation_hooks(head_indices, mode='zero', scale=0.0):
    """Build TransformerLens forward hooks that ablate the given attention heads.

    Args:
        head_indices: iterable of (layer, head) pairs.
        mode: 'zero' to zero the head output, 'scale' to multiply it by `scale`.
        scale: multiplier used when mode == 'scale'.

    Returns:
        List of (hook_name, hook_fn) pairs, one per affected layer, suitable for
        `model.run_with_hooks(..., fwd_hooks=...)`.

    Raises:
        ValueError: if `mode` is neither 'zero' nor 'scale'.

    The hooks attach to `attn.hook_z` (per-head output before the output
    projection, [batch, pos, head, d_head]). `attn.hook_result` is only computed
    when `cfg.use_attn_result` is enabled, so hooks placed there are never
    called on a default model and the ablation would silently do nothing.
    Scaling z for a head scales that head's contribution to the attention
    output by the same factor, because the output projection is linear per head.
    """
    if mode not in ('zero', 'scale'):
        raise ValueError('Unknown mode')

    # Group heads by layer (deduplicated) so each layer gets a single hook.
    layer_to_heads = {}
    for layer, head in head_indices:
        layer_to_heads.setdefault(int(layer), set()).add(int(head))

    hooks = []
    for layer, head_set in layer_to_heads.items():
        hook_name = f'blocks.{layer}.attn.hook_z'
        heads = sorted(head_set)

        # Defaults bind the per-layer values now; a plain closure would see the
        # loop variables' final values.
        def hook_fn(act, hook, heads=heads, mode=mode, scale=scale):
            # act: [batch, pos, head, d_head]
            if mode == 'zero':
                act[:, :, heads, :] = 0.0
            else:
                act[:, :, heads, :] = act[:, :, heads, :] * scale
            return act

        hooks.append((hook_name, hook_fn))
    return hooks


@torch.no_grad()
def get_pA_for_prompts_batch(prompts, hooks=None):
    """Return the two-way normalised P(answer A) for each prompt.

    Args:
        prompts: list of prompt strings (a single string is also accepted).
        hooks: optional list of (hook_name, hook_fn) forward hooks to apply.

    Returns:
        float32 NumPy array of shape (len(prompts),), in the order of `prompts`.

    Prompts are batched per token length so that no padding is ever introduced:
    with padded batches, reading `logits[:, -1]` would pick up a padding
    position for every prompt shorter than the longest one.
    P(A) / (P(A) + P(B)) is computed as sigmoid(logit_A - logit_B), which is the
    same quantity without the 0/0 risk of normalising underflowed probabilities.
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    token_rows = [model.to_tokens(p, prepend_bos=True)[0] for p in prompts]
    by_length = {}
    for i, row in enumerate(token_rows):
        by_length.setdefault(int(row.shape[0]), []).append(i)

    pA_norm = np.zeros(len(token_rows), dtype=np.float32)
    for idxs in by_length.values():
        tokens = torch.stack([token_rows[i] for i in idxs], dim=0)
        logits = model(tokens) if hooks is None else model.run_with_hooks(tokens, fwd_hooks=hooks)
        last = logits[:, -1, :].float()
        pA_norm[idxs] = torch.sigmoid(last[:, A_id] - last[:, B_id]).cpu().numpy()
    return pA_norm


def run_model_with_optional_ablation(prompts, heads_to_ablate):
    """Return normalised P(answer A) per prompt, ablating `heads_to_ablate` if any.

    With `heads_to_ablate` None or empty the model runs unmodified; otherwise
    hooks are built from the global `ablate_mode` / `ablate_scale` settings.
    """
    hooks = None
    if heads_to_ablate is not None and len(heads_to_ablate) > 0:
        hooks = make_head_ablation_hooks(heads_to_ablate, mode=ablate_mode, scale=ablate_scale)
    return get_pA_for_prompts_batch(prompts, hooks=hooks)

print('Ablation utilities ready')


In [None]:
# Cell 13 - Identify drift-driving heads by ablation scan

def _drift_of(values):
    """Spread of the predictions under the configured drift metric."""
    if drift_metric == 'std':
        return float(np.std(values))
    return float(values.max() - values.min())


baseline_preds = get_pA_for_prompts_batch(perm_prompts)
drift_baseline = _drift_of(baseline_preds)

print('Baseline drift:', drift_baseline)

# Layers to scan: every layer, or the configured subset.
candidate_layers = list(range(model.cfg.n_layers)) if head_search_scope == 'all' else list(layers_subset)

# Ablate one head at a time and record how much the drift changes.
rows = []
for layer in tqdm(candidate_layers, desc='layer scan'):
    for head in range(model.cfg.n_heads):
        hooks = make_head_ablation_hooks([(layer, head)], mode='zero')
        preds = get_pA_for_prompts_batch(perm_prompts, hooks=hooks)
        drift = _drift_of(preds)
        rows.append({
            'layer': layer,
            'head': head,
            'drift_baseline': drift_baseline,
            'drift_ablated': drift,
            'drift_reduction': drift_baseline - drift,
            'mean_P_next_X_baseline': float(baseline_preds.mean()),
            'mean_P_next_X_ablated': float(preds.mean()),
        })

# Heads whose removal reduces drift the most come first.
head_rank = pd.DataFrame(rows).sort_values('drift_reduction', ascending=False)

display(head_rank.head(top_k_heads_to_report))

head_rank.to_csv(os.path.join(output_dir, 'head_drift_ranking.csv'), index=False)


In [None]:
# Cell 14 - Choose heads to ablate (edit this cell after viewing ranking)

# Option 1 (default): manual selection - list (layer, head) pairs here.
manual_heads_to_ablate = []  # e.g. [(12, 3), (14, 7)]

# Option 2: set to True to take the top-K heads from the drift ranking instead.
auto_select_top_heads = False

if auto_select_top_heads:
    heads_to_ablate = [(int(r.layer), int(r.head)) for r in head_rank.head(top_k_heads_to_ablate_default).itertuples()]
else:
    heads_to_ablate = list(manual_heads_to_ablate)

print('heads_to_ablate:', heads_to_ablate)


In [None]:
# Cell 15 - Single-episode demo with ablation (baseline vs ablated vs Bayes)

# Replay the episode recorded by the single-episode demo. It is read back from
# `single_df` rather than from the global `obs_seq`, which other cells may have
# reassigned in the meantime.
episode_obs = list(single_df['obs_t'])

# One truth table for the whole episode; row t-1 is the state after t observations.
truth = compute_truth_table(episode_obs, prior_P_A, urn_A_pX, urn_B_pX)

rows = []
for t in range(1, len(episode_obs) + 1):
    obs_prefix = episode_obs[:t]
    prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)

    base = get_ab_logprobs(prompt, A_id, B_id)
    abl = None
    if heads_to_ablate:
        # The shared helper runs the hooked forward pass under torch.no_grad().
        abl = float(run_model_with_optional_ablation([prompt], heads_to_ablate)[0])

    truth_t = truth.iloc[t - 1]
    rows.append({
        'step': t,
        'obs_t': obs_prefix[-1],
        'bayes_P_next_X': float(truth_t['bayes_P_next_X']),
        'model_P_next_X_base': base['pA_norm'],
        'model_P_next_X_ablated': abl,
        'freq_Phat_X': float(truth_t['freq_Phat_X']),
    })

abl_df = pd.DataFrame(rows)
display(abl_df.head())

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(abl_df.step, abl_df.bayes_P_next_X, label='Bayes', marker='o')
ax.plot(abl_df.step, abl_df.model_P_next_X_base, label='Baseline', marker='x')
if heads_to_ablate:
    ax.plot(abl_df.step, abl_df.model_P_next_X_ablated, label='Ablated', marker='^')
ax.plot(abl_df.step, abl_df.freq_Phat_X, label='Frequentist', linestyle='--')

# Annotate each step with the symbol that was observed (Y in red).
for step, tok in zip(abl_df.step, abl_df.obs_t):
    ax.annotate(tok, (step, 0.45), fontsize=10, ha='center',
                color='black' if tok == 'X' else 'red')

ax.set_title('Single-episode with ablation')
ax.set_xlabel('step')
ax.set_ylabel('P(next is X)')
ax.set_ylim(0.0, 1.0)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

mae_base = float((abl_df.model_P_next_X_base - abl_df.bayes_P_next_X).abs().mean())
bias_base = float((abl_df.model_P_next_X_base - abl_df.bayes_P_next_X).mean())

print('Baseline MAE:', mae_base, 'Bias:', bias_base)
if heads_to_ablate:
    mae_abl = float((abl_df.model_P_next_X_ablated - abl_df.bayes_P_next_X).abs().mean())
    bias_abl = float((abl_df.model_P_next_X_ablated - abl_df.bayes_P_next_X).mean())
    print('Ablated MAE:', mae_abl, 'Bias:', bias_abl)


In [None]:
# Cell 16 - Batch evaluation with ablation (baseline vs ablated)

rows = []
for ep in tqdm(range(n_batch_episodes), desc='episodes'):
    hidden_urn = 'A' if random.random() < prior_P_A else 'B'
    true_pX = urn_A_pX if hidden_urn == 'A' else urn_B_pX
    obs_seq_batch = ['X' if random.random() < true_pX else 'Y' for _ in range(seq_len_batch)]

    # One truth table per episode; row t-1 is the state after t observations.
    truth = compute_truth_table(obs_seq_batch, prior_P_A, urn_A_pX, urn_B_pX)

    for t in range(1, seq_len_batch + 1):
        obs_prefix = obs_seq_batch[:t]
        prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
        base = get_ab_logprobs(prompt, A_id, B_id)['pA_norm']

        abl = None
        if heads_to_ablate:
            # The shared helper runs the hooked forward pass under torch.no_grad().
            abl = float(run_model_with_optional_ablation([prompt], heads_to_ablate)[0])

        rows.append({
            'bayes_P_next_X': float(truth.iloc[t - 1]['bayes_P_next_X']),
            'model_P_next_X_base': base,
            'model_P_next_X_ablated': abl,
        })

all_df2 = pd.DataFrame(rows)

# Agreement with Bayes for the baseline model and, if any heads were chosen,
# for the ablated model.
summary_rows = []
mae_base = float((all_df2.model_P_next_X_base - all_df2.bayes_P_next_X).abs().mean())
bias_base = float((all_df2.model_P_next_X_base - all_df2.bayes_P_next_X).mean())
corr_base = float(np.corrcoef(all_df2.model_P_next_X_base, all_df2.bayes_P_next_X)[0, 1])
summary_rows.append({'model': 'baseline', 'MAE': mae_base, 'Bias': bias_base, 'Corr': corr_base})

if heads_to_ablate:
    mae_abl = float((all_df2.model_P_next_X_ablated - all_df2.bayes_P_next_X).abs().mean())
    bias_abl = float((all_df2.model_P_next_X_ablated - all_df2.bayes_P_next_X).mean())
    corr_abl = float(np.corrcoef(all_df2.model_P_next_X_ablated, all_df2.bayes_P_next_X)[0, 1])
    summary_rows.append({'model': 'ablated', 'MAE': mae_abl, 'Bias': bias_abl, 'Corr': corr_abl})

summary = pd.DataFrame(summary_rows)
display(summary)

# Permutation drift before and after ablation on the fixed set of permuted prompts.
base_preds = get_pA_for_prompts_batch(perm_prompts)
if drift_metric == 'std':
    drift_base = float(np.std(base_preds))
else:
    drift_base = float(base_preds.max() - base_preds.min())

if heads_to_ablate:
    ablated_preds = run_model_with_optional_ablation(perm_prompts, heads_to_ablate)
    if drift_metric == 'std':
        drift_abl = float(np.std(ablated_preds))
    else:
        drift_abl = float(ablated_preds.max() - ablated_preds.min())
    print('Drift baseline:', drift_base, 'Drift ablated:', drift_abl, 'Reduction:', drift_base - drift_abl)

summary.to_csv(os.path.join(output_dir, 'batch_summary_ablation.csv'), index=False)


In [None]:
# Cell 17 - CircuitVis attention visualization (baseline and ablated)

# Visualise attention on the two permutations with the lowest / highest prediction.
min_seq = perm_df_sorted.head(1).iloc[0]['seq'].split(' ')
max_seq = perm_df_sorted.tail(1).iloc[0]['seq'].split(' ')


def _is_pattern_hook(name):
    """Cache filter: keep only the attention-pattern activations."""
    return name.endswith('hook_pattern')


for label, seq in [('min', min_seq), ('max', max_seq)]:
    prompt = build_prompt(seq, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
    print('===', label, 'prompt ===')
    print(prompt)
    tokens = model.to_str_tokens(prompt)

    # Only the attention patterns are needed, so restrict the cache to them
    # and skip autograd bookkeeping.
    with torch.no_grad():
        logits, cache = model.run_with_cache(prompt, names_filter=_is_pattern_hook)
    for layer in range(model.cfg.n_layers):
        patt = cache['pattern', layer][0].detach().cpu().numpy()
        display(cv.attention.attention_patterns(tokens=tokens, attention=patt, title=f'{label}: layer {layer} (baseline)'))

    if heads_to_ablate:
        hooks = make_head_ablation_hooks(heads_to_ablate, mode=ablate_mode, scale=ablate_scale)
        # run_with_cache forwards unknown keyword arguments to the model's
        # forward(), so the ablation hooks are installed through the
        # model.hooks(...) context manager instead of a `fwd_hooks=` argument.
        with torch.no_grad(), model.hooks(fwd_hooks=hooks):
            logits2, cache2 = model.run_with_cache(prompt, names_filter=_is_pattern_hook)
        for layer in range(model.cfg.n_layers):
            patt2 = cache2['pattern', layer][0].detach().cpu().numpy()
            display(cv.attention.attention_patterns(tokens=tokens, attention=patt2, title=f'{label}: layer {layer} (ablated)'))

print('Note: ablation changes outputs; attention patterns may be similar.')


In [None]:
# Cell 18 - Build probe dataset (activations + labels)

from collections import defaultdict


def generate_episode(pA, pB, prior, length):
    """Sample one episode: a hidden urn drawn from the prior, then `length` observations.

    Args:
        pA, pB: probability that urn A / urn B emits 'X'.
        prior: probability that the hidden urn is A.
        length: number of observations to draw.

    Returns:
        (hidden, obs): the hidden urn label ('A' or 'B') and the list of 'X'/'Y' draws.
    """
    if random.random() < prior:
        hidden, p_emit_X = 'A', pA
    else:
        hidden, p_emit_X = 'B', pB

    obs = []
    for _ in range(length):
        obs.append('X' if random.random() < p_emit_X else 'Y')
    return hidden, obs


def get_target_logodds(obs_prefix):
    """Return the probe regression target for a prefix of observations.

    With probe_target == 'posterior_logodds' this is the log-odds of the Bayes
    posterior P(A | prefix); for any other setting it is the log-odds of the
    Bayes predictive P(next = X | prefix).
    """
    posterior_A = bayes_update_posterior(prior_P_A, urn_A_pX, urn_B_pX, obs_prefix)
    if probe_target != 'posterior_logodds':
        return logodds(bayes_predictive(posterior_A, urn_A_pX, urn_B_pX))
    return logodds(posterior_A)


def batch_by_length(prompts):
    """Group prompt indices by token count so each group can be batched together.

    Returns:
        defaultdict mapping token length (tokenizer length without special
        tokens, plus one) to the list of indices into `prompts` of that length.
    """
    groups = defaultdict(list)
    for position, text in enumerate(prompts):
        # +1 accounts for the BOS token that is prepended when tokenising for the model.
        n_tokens = 1 + len(model.tokenizer.encode(text, add_special_tokens=False))
        groups[n_tokens].append(position)
    return groups

# Build (prompt, target) pairs: every prefix of every sampled episode.
train_prompts = []
train_labels = []

for _ in range(probe_train_size):
    _, obs = generate_episode(urn_A_pX, urn_B_pX, prior_P_A, seq_len_batch)
    for t in range(1, seq_len_batch + 1):
        prefix = obs[:t]
        train_prompts.append(build_prompt(prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version))
        train_labels.append(get_target_logodds(prefix))

train_labels = np.array(train_labels, dtype=np.float32)

# Activations are written into preallocated arrays at each prompt's own index.
# Prompts are processed grouped by length, i.e. NOT in dataset order, so
# appending batch outputs and concatenating would leave row i of X paired with
# the label of a different prompt.
n_prompts = len(train_prompts)
n_layers = model.cfg.n_layers
X_by_layer = {layer: np.zeros((n_prompts, model.cfg.d_model), dtype=np.float32) for layer in range(n_layers)}

# Cache only the activations the probe reads, not every hook in the model.
probe_hook_names = [get_act_name(probe_act_name, layer) for layer in range(n_layers)]
idx_groups = batch_by_length(train_prompts)

for length, idxs in tqdm(idx_groups.items(), desc='probe extract (train)'):
    # Equal-length prompts are batched in chunks of probe_batch_size to bound memory.
    for start in range(0, len(idxs), probe_batch_size):
        chunk = idxs[start:start + probe_batch_size]
        batch_prompts = [train_prompts[i] for i in chunk]
        tokens = model.to_tokens(batch_prompts, prepend_bos=True)
        with torch.no_grad():
            logits, cache = model.run_with_cache(tokens, names_filter=probe_hook_names)
        for layer in range(n_layers):
            # Last position = the answer position, right after 'Answer:'.
            X_by_layer[layer][chunk] = cache[probe_act_name, layer][:, -1, :].float().cpu().numpy()

assert all(X.shape[0] == len(train_labels) for X in X_by_layer.values()), 'features/labels length mismatch'

print('Probe dataset shape example:', X_by_layer[0].shape, 'labels:', train_labels.shape)


In [None]:
# Cell 19 - Train linear probes for all layers

from sklearn.linear_model import Ridge

def _sigmoid(v):
    """Map log-odds to probabilities."""
    return 1 / (1 + np.exp(-v))


results = []
probe_weights = {}

# Random 80/20 split over (prompt, label) rows, shared by all layers.
n_total = len(train_labels)
idx = np.arange(n_total)
np.random.shuffle(idx)
split = int(0.8 * n_total)
train_idx = idx[:split]
test_idx = idx[split:]

for layer in tqdm(range(model.cfg.n_layers), desc='probe train'):
    X = X_by_layer[layer]
    X_tr = X[train_idx]
    X_te = X[test_idx]
    y_tr = train_labels[train_idx]
    y_te = train_labels[test_idx]

    if probe_type != 'ridge':
        raise ValueError('Only ridge implemented in this notebook')
    reg = Ridge(alpha=1.0, fit_intercept=True)
    reg.fit(X_tr, y_tr)
    w = reg.coef_
    b = float(reg.intercept_)

    # Evaluate on held-out rows, both in log-odds space and in probability space.
    pred = X_te @ w + b
    pred_p = _sigmoid(pred)
    true_p = _sigmoid(y_te)

    mae = float(np.mean(np.abs(pred_p - true_p)))
    mse = float(np.mean((pred - y_te) ** 2))
    corr = float(np.corrcoef(pred_p, true_p)[0, 1])

    results.append({'layer': layer, 'MAE': mae, 'MSE_logodds': mse, 'Corr': corr})
    probe_weights[layer] = (w, b)

    # Persist the probe so later cells can reload the chosen layer from disk.
    np.save(os.path.join(output_dir, f'probe_w_layer{layer}.npy'), w.astype(np.float32))
    np.save(os.path.join(output_dir, f'probe_b_layer{layer}.npy'), np.array([b], dtype=np.float32))

# Best (lowest-MAE) layers first.
res_df = pd.DataFrame(results).sort_values('MAE')
display(res_df.head(10))
res_df.to_csv(os.path.join(output_dir, 'probe_results.csv'), index=False)


In [None]:
# Cell 20 - Choose probe layer and run single-episode with probe overlay

# Use the layer whose probe had the lowest held-out MAE.
chosen_probe_layer = int(res_df.iloc[0]['layer'])

w = np.load(os.path.join(output_dir, f'probe_w_layer{chosen_probe_layer}.npy'))
b = float(np.load(os.path.join(output_dir, f'probe_b_layer{chosen_probe_layer}.npy'))[0])

# Only the probed activation needs to be cached.
probe_hook_name = get_act_name(probe_act_name, chosen_probe_layer)

# Replay the episode recorded by the single-episode demo. It is read back from
# `single_df` rather than from the global `obs_seq`, which other cells may have
# reassigned in the meantime.
episode_obs = list(single_df['obs_t'])
truth = compute_truth_table(episode_obs, prior_P_A, urn_A_pX, urn_B_pX)

rows = []
for t in range(1, len(episode_obs) + 1):
    obs_prefix = episode_obs[:t]
    prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)

    base = get_ab_logprobs(prompt, A_id, B_id)['pA_norm']
    abl = None
    if heads_to_ablate:
        # The shared helper runs the hooked forward pass under torch.no_grad().
        abl = float(run_model_with_optional_ablation([prompt], heads_to_ablate)[0])

    tokens = model.to_tokens(prompt, prepend_bos=True)
    with torch.no_grad():
        logits, cache = model.run_with_cache(tokens, names_filter=probe_hook_name)
    act = cache[probe_act_name, chosen_probe_layer][0, -1, :].float().cpu().numpy()
    pred_logodds = float(act @ w + b)
    probe_p = 1 / (1 + np.exp(-pred_logodds))
    if probe_target == 'posterior_logodds':
        # The probe decodes the posterior P(A); convert it to the predictive
        # P(next = X) so it is comparable with the other curves.
        probe_p = bayes_predictive(probe_p, urn_A_pX, urn_B_pX)

    rows.append({
        'step': t,
        'obs_t': obs_prefix[-1],
        'bayes_P_next_X': float(truth.iloc[t - 1]['bayes_P_next_X']),
        'baseline_P_next_X': base,
        'ablated_P_next_X': abl,
        'probe_P_next_X': float(probe_p),
    })

probe_df = pd.DataFrame(rows)
display(probe_df.head())

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(probe_df.step, probe_df.bayes_P_next_X, label='Bayes', marker='o')
ax.plot(probe_df.step, probe_df.baseline_P_next_X, label='Baseline', marker='x')
if heads_to_ablate:
    ax.plot(probe_df.step, probe_df.ablated_P_next_X, label='Ablated', marker='^')
ax.plot(probe_df.step, probe_df.probe_P_next_X, label='Probe', marker='s')

# Annotate each step with the symbol that was observed (Y in red).
for step, tok in zip(probe_df.step, probe_df.obs_t):
    ax.annotate(tok, (step, 0.45), fontsize=10, ha='center',
                color='black' if tok == 'X' else 'red')

ax.set_title('Single-episode with probe overlay')
ax.set_xlabel('step')
ax.set_ylabel('P(next is X)')
ax.set_ylim(0.0, 1.0)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

for name in ['baseline_P_next_X', 'probe_P_next_X']:
    mae = float((probe_df[name] - probe_df.bayes_P_next_X).abs().mean())
    bias = float((probe_df[name] - probe_df.bayes_P_next_X).mean())
    print(name, 'MAE', mae, 'Bias', bias)


In [None]:
# Cell 21 - Batch evaluation including probe

# Only the probed activation needs to be cached.
probe_hook_name = get_act_name(probe_act_name, chosen_probe_layer)

rows = []
for ep in tqdm(range(n_batch_episodes), desc='episodes'):
    hidden_urn = 'A' if random.random() < prior_P_A else 'B'
    true_pX = urn_A_pX if hidden_urn == 'A' else urn_B_pX
    obs_seq_batch = ['X' if random.random() < true_pX else 'Y' for _ in range(seq_len_batch)]

    # One truth table per episode; row t-1 is the state after t observations.
    truth = compute_truth_table(obs_seq_batch, prior_P_A, urn_A_pX, urn_B_pX)

    for t in range(1, seq_len_batch + 1):
        obs_prefix = obs_seq_batch[:t]
        prompt = build_prompt(obs_prefix, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)

        base = get_ab_logprobs(prompt, A_id, B_id)['pA_norm']
        abl = None
        if heads_to_ablate:
            # The shared helper runs the hooked forward pass under torch.no_grad().
            abl = float(run_model_with_optional_ablation([prompt], heads_to_ablate)[0])

        tokens = model.to_tokens(prompt, prepend_bos=True)
        with torch.no_grad():
            logits, cache = model.run_with_cache(tokens, names_filter=probe_hook_name)
        act = cache[probe_act_name, chosen_probe_layer][0, -1, :].float().cpu().numpy()
        pred_logodds = float(act @ w + b)
        probe_p = 1 / (1 + np.exp(-pred_logodds))
        if probe_target == 'posterior_logodds':
            # The probe decodes the posterior P(A); convert it to the predictive
            # P(next = X) before comparing it with the Bayes predictive.
            probe_p = bayes_predictive(probe_p, urn_A_pX, urn_B_pX)

        rows.append({
            'bayes_P_next_X': float(truth.iloc[t - 1]['bayes_P_next_X']),
            'baseline_P_next_X': base,
            'ablated_P_next_X': abl,
            'probe_P_next_X': float(probe_p),
        })

all_df3 = pd.DataFrame(rows)

# Score every available predictor against the Bayes predictive.
scored_cols = ['baseline_P_next_X', 'probe_P_next_X']
if heads_to_ablate:
    scored_cols.append('ablated_P_next_X')

summary_rows = []
for col in scored_cols:
    mae = float((all_df3[col] - all_df3.bayes_P_next_X).abs().mean())
    bias = float((all_df3[col] - all_df3.bayes_P_next_X).mean())
    corr = float(np.corrcoef(all_df3[col], all_df3.bayes_P_next_X)[0, 1])
    summary_rows.append({'model': col, 'MAE': mae, 'Bias': bias, 'Corr': corr})

summary = pd.DataFrame(summary_rows)
display(summary)
summary.to_csv(os.path.join(output_dir, 'batch_summary_with_probe.csv'), index=False)


In [None]:
# Cell 22 - Optional: belief editing and belief swapping (causal validation)

layer_to_edit = chosen_probe_layer
edit_lambda = 5.0  # step size along the unit-norm probe direction

# Unit vector along the probe weights (epsilon guards against a zero vector).
probe_dir = w / (np.linalg.norm(w) + 1e-8)
probe_dir_t = torch.tensor(probe_dir, device=DEVICE, dtype=torch.float32)


def belief_edit_hook(act, hook, direction, scale):
    """Shift the activation at the last position by `scale * direction` (in place)."""
    act[:, -1, :] = act[:, -1, :] + scale * direction
    return act

example_seq = ['X','Y','X','Y','X','Y']
example_prompt = build_prompt(example_seq, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)

base = get_ab_logprobs(example_prompt, A_id, B_id)['pA_norm']

# The edit must be applied to the same activation the probe was trained on.
if probe_act_name == 'resid_post':
    hook_name = f'blocks.{layer_to_edit}.hook_resid_post'
elif probe_act_name == 'resid_pre':
    hook_name = f'blocks.{layer_to_edit}.hook_resid_pre'
else:
    raise ValueError('Unsupported probe_act_name for editing')

# Inference only: run the edited forward pass without building an autograd graph.
with torch.no_grad():
    logits = model.run_with_hooks(example_prompt, fwd_hooks=[(hook_name, lambda act, hook: belief_edit_hook(act, hook, probe_dir_t, edit_lambda))])
last = logits[0, -1, :]
probs = torch.softmax(last.float(), dim=-1)
edit_pA = float(probs[A_id].item())
edit_pB = float(probs[B_id].item())
edit = edit_pA / (edit_pA + edit_pB) if (edit_pA + edit_pB) > 0 else 0.5

print('Baseline P(X):', base)
print('Edited   P(X):', edit, '(lambda=', edit_lambda, ')')
print('Belief swapping is left as an exercise: patch the probe component between two runs.')


In [None]:
# Cell 23 - Multi-model runner (lightweight)

MODEL_LIST = [
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.2-1B-Instruct',
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-3B-Instruct',
    'meta-llama/Llama-3.1-8B',
    'meta-llama/Llama-3.1-8B-Instruct',
]

RUN_MULTI_MODEL = False  # set True to execute

if not RUN_MULTI_MODEL:
    print('Set RUN_MULTI_MODEL=True to run across models.')
else:
    rows = []
    for mid in MODEL_LIST:
        print('Loading', mid)
        m = HookedTransformer.from_pretrained(mid, device=DEVICE, dtype=DTYPE)
        m.eval()

        # Token check per model: the answer labels must be single tokens.
        def assert_single_token_local(text):
            """Return the id of `text` in this model's vocabulary, or raise if it is not one token."""
            ids = m.tokenizer.encode(text, add_special_tokens=False)
            if len(ids) != 1:
                raise RuntimeError(f'Model {mid}: token {text!r} is not single token (ids={ids})')
            return ids[0]

        A_id_local = assert_single_token_local(ANSWER_TOKEN_A)
        B_id_local = assert_single_token_local(ANSWER_TOKEN_B)

        # Spread of the readout over 20 freshly sampled 10-observation episodes.
        tmp_preds = []
        for _ in range(20):
            hidden = 'A' if random.random() < prior_P_A else 'B'
            pX = urn_A_pX if hidden == 'A' else urn_B_pX
            obs = ['X' if random.random() < pX else 'Y' for _ in range(10)]
            prompt = build_prompt(obs, urn_A_pX, urn_B_pX, prior_P_A, prompt_template_version)
            with torch.no_grad():
                logits = m(prompt)
            last = logits[0, -1, :].float()
            # P(A) / (P(A) + P(B)) == sigmoid(logit_A - logit_B); no 0/0 on underflow.
            tmp_preds.append(float(torch.sigmoid(last[A_id_local] - last[B_id_local]).item()))
        drift = float(np.std(tmp_preds))
        rows.append({'model': mid, 'drift_proxy_std': drift})

        # Release this model before loading the next one; otherwise every model
        # in the list stays resident at the same time.
        del m, logits, last
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    multi_df = pd.DataFrame(rows)
    display(multi_df)
