# Causal Interventions

This notebook tests whether the model's Bayes-like behavior comes from a specific internal inference mechanism by applying causal interventions:

- Ablate candidate heads (identified in `02_localization.ipynb`) and measure changes in Bayes-alignment.
- Steer / inject belief using a probe direction and see if downstream predictions shift predictably.
- Stress-test robustness: distractors, symbol swaps, and order drift.

If the model only matches Bayes behaviorally, these interventions should not produce clean, systematic effects. If there is an internal belief update mechanism, we expect targeted, interpretable changes.


In [None]:
# Notebook path setup: make repo imports work regardless of where you run this from
from pathlib import Path
import sys

cwd = Path.cwd().resolve()
# Search the working directory first, then each ancestor in turn, and take the nearest
# directory that contains the `bayesian_llm` package (works from any subfolder depth).
repo_candidates = [cwd, *cwd.parents]
repo_root = next((p for p in repo_candidates if (p / 'bayesian_llm').exists()), None)
if repo_root is None:
    raise RuntimeError(
        f'Could not find repo root (a directory containing bayesian_llm/) '
        f'from cwd={cwd} or any of its parents.'
    )

# Prepend so the local package wins over any installed copy of the same name.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print('Repo root:', repo_root)


In [None]:
# Dependencies
# importlib.util.find_spec replaces pkgutil.find_loader, which is deprecated since
# Python 3.12 and removed in 3.14. It returns None when a top-level package is absent.
import importlib.util

missing = [
    p
    for p in ['torch', 'transformers', 'transformer_lens', 'numpy', 'pandas', 'matplotlib', 'tqdm']
    if importlib.util.find_spec(p) is None
]
print('Missing:', missing if missing else 'None')
if missing:
    print('Install with: pip install -r ../requirements.txt')


In [None]:
# Configuration
import numpy as np

# --- Model ---------------------------------------------------------------
MODEL_ID: str = 'meta-llama/Llama-3.1-8B'
DTYPE: str = 'float16'  # dtype name for the model weights (consumed by the model-loading cell)

# --- Evidence sweep ------------------------------------------------------
# Prompt variant: 'base' | 'independent' | 'order_irrelevant'
CONTROL_TEMPLATE: str = 'order_irrelevant'
N_TOTAL: int = 10         # number of draws shown in each sequence
N_PERMUTATIONS: int = 5   # random orderings evaluated per evidence level

# --- Head selection from patching results --------------------------------
TOPK_BAYES: int = 20  # how many highest-recovery heads to treat as candidates
TOPK_ANTI: int = 20   # how many lowest-recovery heads to treat as candidates

# --- Steering ------------------------------------------------------------
STEER_LAYER = None        # None => infer from the saved probe filename
STEER_ALPHA: float = 5.0  # scaling for probe direction; tune

RNG_SEED: int = 0

print('MODEL_ID:', MODEL_ID)
print('CONTROL_TEMPLATE:', CONTROL_TEMPLATE)
print('TOPK_BAYES/ANTI:', TOPK_BAYES, TOPK_ANTI)


In [None]:
# Load TransformerLens model
import torch
from transformer_lens import HookedTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Map the configured dtype name explicitly. An unrecognized name raises instead of
# silently loading the weights in some other precision.
_DTYPES = {
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    'float32': torch.float32,
}
try:
    dtype = _DTYPES[DTYPE]
except KeyError:
    raise ValueError(f'Unsupported DTYPE={DTYPE!r}; expected one of {sorted(_DTYPES)}') from None

tl_model = HookedTransformer.from_pretrained(MODEL_ID, device=device, dtype=dtype)
print('Loaded on', device)
print('n_layers:', tl_model.cfg.n_layers, 'n_heads:', tl_model.cfg.n_heads)


In [None]:
# Load localization artifacts (if present)
import os
import re

os.makedirs('results', exist_ok=True)

patch_path = 'results/head_patching_recovery.npy'

# Probe files are named probe_w_layer<N>.npy.
# - fullmatch rejects look-alikes such as 'probe_w_layer12.npy.bak'.
# - Ordering by the integer layer avoids lexicographic order, where 'layer10' sorts
#   before 'layer9'.
_probe_re = re.compile(r'probe_w_layer(\d+)\.npy')
_probe_by_layer = {}
for _fname in os.listdir('results'):
    _m = _probe_re.fullmatch(_fname)
    if _m:
        _probe_by_layer[int(_m.group(1))] = _fname
probe_w_paths = [_probe_by_layer[k] for k in sorted(_probe_by_layer)]

head_recovery = None
if os.path.exists(patch_path):
    head_recovery = np.load(patch_path)
    print('Loaded', patch_path, 'shape=', head_recovery.shape)
else:
    print('No head patching matrix found at', patch_path)

probe_w = None
probe_layer = None
if probe_w_paths:
    if STEER_LAYER is not None and int(STEER_LAYER) in _probe_by_layer:
        # Respect a manually configured steering layer: the probe direction should come
        # from the same layer it will be injected into.
        probe_layer = int(STEER_LAYER)
    else:
        probe_layer = max(_probe_by_layer)  # default: highest layer with a saved probe
        if STEER_LAYER is not None:
            print(f'WARNING: no probe saved for STEER_LAYER={STEER_LAYER}; '
                  f'using weights from layer {probe_layer}.')
    w_path = _probe_by_layer[probe_layer]
    probe_w = np.load(os.path.join('results', w_path))
    print('Loaded probe weights:', w_path, 'shape=', probe_w.shape)
    if len(probe_w_paths) > 1:
        print('Available probe layers:', sorted(_probe_by_layer))
else:
    print('No probe weights found in results/. Run 02_localization.ipynb first.')

if STEER_LAYER is None and probe_layer is not None:
    STEER_LAYER = probe_layer
print('STEER_LAYER:', STEER_LAYER)


In [None]:
# Task helpers (prompts, Bayes, model probability)

import math
import pandas as pd
from tqdm.auto import tqdm

from bayesian_llm.bayes import two_generator_posterior_predictive
from bayesian_llm.data import make_sequence, permute_sequence, set_seed

# Seed through the project's set_seed helper, and keep a dedicated numpy Generator
# (module-level `rng`) for the permutation helpers below.
set_seed(RNG_SEED)
rng = np.random.default_rng(RNG_SEED)

# Token ids for ' X' and ' Y' (with leading space): the two answer symbols that the
# probability helpers compare.
X_tok, Y_tok = (tl_model.to_single_token(symbol) for symbol in (' X', ' Y'))


def prompt_two_generators(sequence_tokens, *, control: str):
    """Build the two-generator next-symbol prompt for an observed X/Y sequence.

    Args:
        sequence_tokens: iterable of symbol strings (e.g. ['X', 'Y', 'X']), joined with spaces.
        control: how much task framing to include:
            - 'base': generator rates only.
            - 'independent': also states that draws are independent.
            - 'order_irrelevant': also states that order does not matter.

    Returns:
        The prompt text, ending with the question "(X or Y):".

    Raises:
        ValueError: if ``control`` is not one of the three variants.
    """
    seq_str = ' '.join(sequence_tokens)

    # Each control level appends exactly one more sentence to the previous level's framing.
    levels = ('base', 'independent', 'order_irrelevant')
    framing = (
        'Two random generators. Generator A: 50% X. Generator B: 75% X.',
        ' Draws are independent.',
        ' Order does not matter.',
    )
    if control not in levels:
        raise ValueError(control)
    prefix = ''.join(framing[:levels.index(control) + 1])

    return f"{prefix} Sequence: {seq_str}. Predict the next output (X or Y):"


def p_x_from_logits(logits, x_tok=None, y_tok=None):
    """Return P(next=X) renormalized over the two answer tokens {X, Y}.

    The result equals p(X) / (p(X) + p(Y)) under the full softmax. The softmax normalizer
    cancels in that ratio, so it is computed directly as sigmoid(logit_X - logit_Y). This
    keeps the answer accurate when some other token dominates and both p(X) and p(Y)
    underflow to 0 in float32, where the ratio would degenerate to 0/0.

    Args:
        logits: model output indexed as ``logits[0, -1, :]`` (first batch item, last position).
        x_tok: token id for X; defaults to the module-level ``X_tok``.
        y_tok: token id for Y; defaults to the module-level ``Y_tok``.

    Returns:
        A Python float in [0, 1]. Returns 0.5 if the logit difference is undefined
        (e.g. both logits are -inf).
    """
    if x_tok is None:
        x_tok = X_tok
    if y_tok is None:
        y_tok = Y_tok
    last = logits[0, -1, :].float()
    p = float(torch.sigmoid(last[x_tok] - last[y_tok]).item())
    return 0.5 if math.isnan(p) else p


def p_x(prompt, *, fwd_hooks=None):
    """Return the model's P(next=X), renormalized over {X, Y}, for ``prompt``.

    When ``fwd_hooks`` is given (even an empty list), the forward pass goes through
    ``tl_model.run_with_hooks``. When it is ``None``, the model is called directly.
    """
    plain_forward = fwd_hooks is None
    logits = tl_model(prompt) if plain_forward else tl_model.run_with_hooks(prompt, fwd_hooks=fwd_hooks)
    return p_x_from_logits(logits)


def evidence_sweep(*, control, fwd_hooks=None, n_total=N_TOTAL, n_perms=N_PERMUTATIONS, rng=None):
    """Sweep the evidence level and compare the model's P(next=X) with the Bayes answer.

    For each ``n_x`` in ``0..n_total``:
    - the base sequence is ``make_sequence(n_x=n_x, n_total=n_total, x='X', y='Y')``;
    - it is shown to the model in ``n_perms`` random orders;
    - the renormalized P(next=X) is recorded for each order.

    Args:
        control: prompt variant passed to ``prompt_two_generators``.
        fwd_hooks: optional TransformerLens forward hooks applied on every forward pass.
        n_total: sequence length.
        n_perms: number of random orderings per evidence level; must be >= 1.
        rng: numpy Generator used for the permutations. When None (default), a fresh
            generator seeded with ``RNG_SEED`` is created. Repeated sweeps (baseline vs.
            an intervention) then evaluate exactly the same orderings, so their MAE
            difference reflects the intervention rather than permutation-sampling noise.

    Returns:
        (df, mae):
        - df: DataFrame with one row per n_x and columns n_X, true_bayes, llm_mean,
          llm_std, order_drift, abs_error.
        - mae: mean absolute error of llm_mean against true_bayes.

    Raises:
        ValueError: if ``n_perms < 1``. The mean/min/max of zero predictions are undefined.
    """
    if n_perms < 1:
        raise ValueError(f'n_perms must be >= 1, got {n_perms}')
    if rng is None:
        rng = np.random.default_rng(RNG_SEED)

    rows = []
    for n_x in range(n_total + 1):
        base = make_sequence(n_x=n_x, n_total=n_total, x='X', y='Y')
        true_p = two_generator_posterior_predictive(n_x=n_x, n_total=n_total)

        preds = np.asarray([
            p_x(prompt_two_generators(permute_sequence(base, rng=rng), control=control),
                fwd_hooks=fwd_hooks)
            for _ in range(n_perms)
        ])
        rows.append({
            'n_X': n_x,
            'true_bayes': true_p,
            'llm_mean': float(preds.mean()),
            # Sample std (ddof=1) needs at least two orderings.
            'llm_std': float(preds.std(ddof=1) if len(preds) > 1 else 0.0),
            # Spread across orderings of the same counts: 0 for a fully order-invariant model.
            'order_drift': float(preds.max() - preds.min()),
        })

    df = pd.DataFrame(rows)
    df['abs_error'] = (df.llm_mean - df.true_bayes).abs()
    mae = float(df.abs_error.mean())
    return df, mae


def order_drift_fixed(*, seq_tokens, control, fwd_hooks=None, n_samples=50, rng=None):
    """Measure how much P(next=X) varies across random orderings of one fixed multiset.

    Args:
        seq_tokens: the symbols to permute (their counts stay fixed).
        control: prompt variant passed to ``prompt_two_generators``.
        fwd_hooks: optional TransformerLens forward hooks applied on every forward pass.
        n_samples: number of random permutations to evaluate; must be >= 1.
        rng: numpy Generator for the permutations. When None (default), a fresh
            generator seeded with ``RNG_SEED`` is used, so drift measured with and
            without an intervention is computed over the same orderings.

    Returns:
        (drift, preds):
        - drift: ``max(preds) - min(preds)``.
        - preds: array of the ``n_samples`` predictions.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be >= 1, got {n_samples}')
    if rng is None:
        rng = np.random.default_rng(RNG_SEED)

    preds = np.asarray([
        p_x(prompt_two_generators(permute_sequence(seq_tokens, rng=rng), control=control),
            fwd_hooks=fwd_hooks)
        for _ in range(n_samples)
    ])
    return float(preds.max() - preds.min()), preds


In [None]:
# Baseline: evidence sweep + order drift

df_base, mae_base = evidence_sweep(control=CONTROL_TEMPLATE)
print('Baseline MAE to Bayes:', mae_base)
display(df_base)

# Plot the model's mean P(X) (error bars = std over orderings) against the exact Bayes curve.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(df_base.n_X, df_base.true_bayes, color='black', linewidth=2, label='True Bayes')
ax.errorbar(df_base.n_X, df_base.llm_mean, yerr=df_base.llm_std, marker='o', capsize=3, label='LLM')
ax.set(
    title=f'Baseline sweep ({CONTROL_TEMPLATE})',
    xlabel='n_X',
    ylabel='P(next is X | {X,Y})',
    ylim=(0, 1),
)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

# Order drift diagnostic: fixed counts (4 X, 1 Y), many orderings; drift = max - min P(X).
seq_fixed = ['X'] * 4 + ['Y']
drift, preds = order_drift_fixed(seq_tokens=seq_fixed, control=CONTROL_TEMPLATE, n_samples=100)
print('Order drift for', seq_fixed, ':', drift)


In [None]:
# Choose candidate heads from head_patching_recovery.npy (if available)

from collections import defaultdict

BAYES_HEADS = []
ANTI_HEADS = []

if head_recovery is not None:
    if head_recovery.ndim != 2:
        raise ValueError(
            f'Expected a (n_layers, n_heads) recovery matrix, got shape {head_recovery.shape}'
        )
    # Flatten to (score, layer, head) triples and drop NaN scores. NaN compares False
    # against everything, so it would land at an arbitrary position in the sort.
    flat = [
        (float(head_recovery[layer, head]), layer, head)
        for layer in range(head_recovery.shape[0])
        for head in range(head_recovery.shape[1])
        if not np.isnan(head_recovery[layer, head])
    ]
    flat.sort(key=lambda x: x[0], reverse=True)

    n_bayes = max(0, min(TOPK_BAYES, len(flat)))
    # The bottom-k is drawn only from heads not already in BAYES_HEADS. It is sliced with
    # an explicit start index because flat[-0:] would return the entire list when
    # TOPK_ANTI == 0.
    n_anti = max(0, min(TOPK_ANTI, len(flat) - n_bayes))
    BAYES_HEADS = [(layer, head) for _, layer, head in flat[:n_bayes]]
    # Listed most-negative recovery first.
    ANTI_HEADS = [(layer, head) for _, layer, head in reversed(flat[len(flat) - n_anti:])]

print('BAYES_HEADS (top recovery):', BAYES_HEADS[:10], '... total', len(BAYES_HEADS))
print('ANTI_HEADS (bottom recovery):', ANTI_HEADS[:10], '... total', len(ANTI_HEADS))

# You can also override manually, e.g. BAYES_HEADS=[(20,3),(21,7)]


In [None]:
# Build ablation hooks for a set of heads

from transformer_lens.utils import get_act_name

def ablation_hooks(heads_to_ablate):
    """Build forward hooks that zero-ablate the given attention heads.

    Args:
        heads_to_ablate: iterable of (layer, head) pairs; duplicates are allowed.

    Returns:
        A list with one (hook_name, hook_fn) pair per distinct layer, in order of first
        appearance. Each hook_fn writes 0 into that layer's ``z`` activation at the
        selected head indices (axis 2). TransformerLens' ``hook_z`` is laid out as
        (batch, pos, head_index, d_head). Empty input yields an empty list.
    """
    heads_by_layer = {}
    for layer, head in heads_to_ablate:
        heads_by_layer.setdefault(int(layer), []).append(int(head))

    def _zero_heads(head_idx):
        # Factory so every hook closes over its own layer's head list (no late binding).
        def hook_fn(z, hook):
            z[:, :, head_idx, :] = 0.0
            return z
        return hook_fn

    return [
        (get_act_name('z', layer), _zero_heads(sorted(set(heads))))
        for layer, heads in heads_by_layer.items()
    ]

# Quick smoke test (no-op ablation)
print('Hooks for 0 heads:', len(ablation_hooks([])))


In [None]:
# Intervention A: ablate ANTI_HEADS (hypothesis: improves Bayes calibration)

if ANTI_HEADS:
    # Zero the lowest-recovery heads and rerun the full evidence sweep under that ablation.
    hooks_anti = ablation_hooks(ANTI_HEADS)
    df_anti, mae_anti = evidence_sweep(control=CONTROL_TEMPLATE, fwd_hooks=hooks_anti)

    print('MAE baseline:', mae_base)
    print('MAE ablate ANTI:', mae_anti)

    # Overlay baseline and ablated means on the Bayes reference curve.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(df_base.n_X, df_base.true_bayes, color='black', linewidth=2, label='True Bayes')
    ax.plot(df_base.n_X, df_base.llm_mean, marker='o', label='Baseline')
    ax.plot(df_anti.n_X, df_anti.llm_mean, marker='x', label='Ablate ANTI_HEADS')
    ax.set(
        title='Ablating anti-recovery heads',
        xlabel='n_X',
        ylabel='P(next is X | {X,Y})',
        ylim=(0, 1),
    )
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()

    # Order drift over the fixed multiset (4 X, 1 Y), without and with the ablation.
    drift_base, _ = order_drift_fixed(seq_tokens=['X'] * 4 + ['Y'], control=CONTROL_TEMPLATE, n_samples=100)
    drift_anti, _ = order_drift_fixed(seq_tokens=['X'] * 4 + ['Y'], control=CONTROL_TEMPLATE, fwd_hooks=hooks_anti, n_samples=100)
    print('Order drift baseline:', drift_base)
    print('Order drift ablate ANTI:', drift_anti)
else:
    print('ANTI_HEADS empty; skipping')


In [None]:
# Intervention B: ablate BAYES_HEADS (hypothesis: harms Bayes tracking)

if BAYES_HEADS:
    # Zero the highest-recovery heads and rerun the full evidence sweep under that ablation.
    hooks_bayes = ablation_hooks(BAYES_HEADS)
    df_bayes, mae_bayes = evidence_sweep(control=CONTROL_TEMPLATE, fwd_hooks=hooks_bayes)

    print('MAE baseline:', mae_base)
    print('MAE ablate BAYES:', mae_bayes)

    # Overlay baseline and ablated means on the Bayes reference curve.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(df_base.n_X, df_base.true_bayes, color='black', linewidth=2, label='True Bayes')
    ax.plot(df_base.n_X, df_base.llm_mean, marker='o', label='Baseline')
    ax.plot(df_bayes.n_X, df_bayes.llm_mean, marker='x', label='Ablate BAYES_HEADS')
    ax.set(
        title='Ablating high-recovery (candidate Bayes) heads',
        xlabel='n_X',
        ylabel='P(next is X | {X,Y})',
        ylim=(0, 1),
    )
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()
else:
    print('BAYES_HEADS empty; skipping')


## Robustness: Distractors + Symbol Swap

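Here we test whether the inferred belief update survives nuisance changes that should leave the Bayesian answer unchanged: irrelevant filler text inserted between the sequence and the question, and relabeling the outcome symbols from X/Y to A/B.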


In [None]:
# Distractor insertion

DISTRACTOR_TEXT = ' The quick brown fox jumps over the lazy dog.'

def prompt_with_distractor(sequence_tokens, *, control: str, distractor_repeats: int):
    """Build the two-generator prompt with filler text between the sequence and the question.

    Args:
        sequence_tokens: iterable of symbol strings, joined with spaces.
        control: 'base', 'independent' or 'order_irrelevant' (how much task framing to include).
        distractor_repeats: how many copies of DISTRACTOR_TEXT to insert (converted with
            int(); zero or negative inserts nothing).

    Returns:
        The prompt text, ending with "Predict the next output (X or Y):".

    Raises:
        ValueError: if ``control`` is not one of the three variants.
    """
    seq_str = ' '.join(sequence_tokens)

    # Each control level adds one sentence on top of the previous level's framing.
    levels = ('base', 'independent', 'order_irrelevant')
    framing = (
        'Two random generators. Generator A: 50% X. Generator B: 75% X.',
        ' Draws are independent.',
        ' Order does not matter.',
    )
    if control not in levels:
        raise ValueError(control)
    prefix = ''.join(framing[:levels.index(control) + 1])

    pieces = [f"{prefix} Sequence: {seq_str}."]
    pieces.extend([DISTRACTOR_TEXT] * int(distractor_repeats))
    pieces.append(' Predict the next output (X or Y):')
    return ''.join(pieces)


def sweep_distractors(n_x, *, distractor_repeats_list, control, fwd_hooks=None):
    """Measure P(next=X) for one random ordering as the amount of distractor text grows.

    The sequence is ``make_sequence(n_x=n_x, n_total=N_TOTAL, x='X', y='Y')``. It is
    permuted once with the module-level ``rng`` and that single ordering is reused for
    every entry of ``distractor_repeats_list``, so only the filler text varies between rows.

    Returns:
        DataFrame with columns ``repeats`` and ``p_x``.
    """
    ordered = permute_sequence(
        make_sequence(n_x=n_x, n_total=N_TOTAL, x='X', y='Y'),
        rng=rng,
    )

    def _score(repeats):
        prompt = prompt_with_distractor(ordered, control=control, distractor_repeats=repeats)
        return p_x(prompt, fwd_hooks=fwd_hooks)

    return pd.DataFrame([{'repeats': r, 'p_x': _score(r)} for r in distractor_repeats_list])

repeats_list = [0, 1, 2, 4, 8]

# Strong evidence for X (n_x=8), one fixed ordering, padded with increasing filler.
df_dist = sweep_distractors(n_x=8, distractor_repeats_list=repeats_list, control=CONTROL_TEMPLATE)
display(df_dist)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(df_dist.repeats, df_dist.p_x, marker='o')
ax.set(
    title='Effect of distractors on P(next=X)',
    xlabel='Distractor repeats',
    ylabel='P(next=X | {X,Y})',
)
ax.grid(True, alpha=0.3)
plt.show()


In [None]:
# Symbol swap robustness (X/Y -> A/B)

# This keeps the probabilistic structure but changes the surface form of the outcomes.
# The generators are labelled 1/2 here rather than A/B. With 'Generator A: 50% A' the
# answer token ' A' would be ambiguous between "outcome A" and "generator A", so
# P(next=A) could partly reflect a which-generator reading of the question.

def prompt_two_generators_ab(sequence_tokens, *, control: str):
    """Build the two-generator prompt with outcome symbols A/B instead of X/Y.

    Args:
        sequence_tokens: iterable of symbol strings (e.g. ['A', 'B']), joined with spaces.
        control: 'base', 'independent' or 'order_irrelevant' (how much task framing to include).

    Returns:
        The prompt text, ending with "Predict the next output (A or B):".

    Raises:
        ValueError: if ``control`` is not one of the three variants.
    """
    seq_str = ' '.join(sequence_tokens)
    if control == 'base':
        prefix = 'Two random generators. Generator 1: 50% A. Generator 2: 75% A.'
    elif control == 'independent':
        prefix = 'Two random generators. Generator 1: 50% A. Generator 2: 75% A. Draws are independent.'
    elif control == 'order_irrelevant':
        prefix = 'Two random generators. Generator 1: 50% A. Generator 2: 75% A. Draws are independent. Order does not matter.'
    else:
        raise ValueError(control)

    return (
        f"{prefix} Sequence: {seq_str}. "
        'Predict the next output (A or B):'
    )

# Token ids in TLens for A/B
A_tok = tl_model.to_single_token(' A')
B_tok = tl_model.to_single_token(' B')

def p_a(prompt, *, fwd_hooks=None):
    """Return P(next=A) renormalized over the two answer tokens {A, B}.

    The value equals p(A) / (p(A) + p(B)) under the full softmax. It is computed as
    sigmoid(logit_A - logit_B), which stays accurate when both probabilities underflow
    to 0 in float32. ``fwd_hooks`` is applied via ``run_with_hooks`` when given.
    """
    logits = tl_model(prompt) if fwd_hooks is None else tl_model.run_with_hooks(prompt, fwd_hooks=fwd_hooks)
    last = logits[0, -1, :].float()
    p = float(torch.sigmoid(last[A_tok] - last[B_tok]).item())
    # An undefined difference (e.g. both logits -inf) falls back to a neutral 0.5.
    return 0.5 if math.isnan(p) else p

seq = ['A','A','B','A','A']
prompt_ab = prompt_two_generators_ab(seq, control=CONTROL_TEMPLATE)
print(prompt_ab)
print('P(next is A | {A,B}):', p_a(prompt_ab))


## Steering / Belief Injection

We use a probe direction `w` (saved from `02_localization.ipynb`) to *add* a vector to the residual stream at a chosen layer.

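If `w` truly encodes belief in hypothesis B (Generator B, 75% X), then adding more of it (larger `alpha`) should systematically increase `P(next=X)` on ambiguous evidence, and negative `alpha` should decrease it. This assumes the probe's positive class is hypothesis B; if the probe was trained with the opposite labeling, the expected sign flips.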


In [None]:
# Build a steering hook from probe direction (if available)

from transformer_lens.utils import get_act_name

def steering_hooks(*, layer: int, alpha: float, pos: int = -1):
    """Return a hook list that adds ``alpha * w_hat`` to the residual stream at one position.

    ``w_hat`` is the loaded probe weight vector ``probe_w``, flattened and scaled to unit
    L2 norm. ``alpha`` is therefore an absolute step size in residual-stream units.
    NOTE(review): whether 5-10 is large relative to the residual norm at this layer is not
    established here; consider scaling by the typical resid norm.

    Args:
        layer: block index whose ``resid_post`` activation is modified.
            NOTE(review): assumes the probe was trained on resid_post of this layer; confirm
            against 02_localization.
        alpha: scale of the injected unit vector (negative values push the opposite way).
        pos: sequence position to modify (default -1, the last prompt token).

    Returns:
        ``[(hook_name, hook_fn)]`` suitable for ``fwd_hooks=`` in ``run_with_hooks``.

    Raises:
        RuntimeError: if no probe weights were loaded.
        ValueError: if ``layer`` is None, or the probe's size does not match ``d_model``.
    """
    if probe_w is None:
        raise RuntimeError('No probe weights loaded. Run 02_localization.ipynb first.')
    if layer is None:
        raise ValueError('No steering layer given (STEER_LAYER is None).')

    # Linear-probe weights are often saved as (1, d_model); flatten to a plain vector.
    # NOTE(review): assumes probe_w holds only the weight vector (no appended bias).
    w = torch.as_tensor(probe_w, dtype=torch.float32).reshape(-1)
    d_model = tl_model.cfg.d_model
    if w.numel() != d_model:
        raise ValueError(
            f'Probe has {w.numel()} weights but the residual stream width is {d_model}; '
            'was it trained on this model / hook?'
        )
    w = (w / (w.norm() + 1e-8)).to(device)

    hook_name = get_act_name('resid_post', int(layer))

    def hook_fn(resid, hook):
        # Sum is formed in float32 and cast back to resid's dtype on assignment.
        resid[:, pos, :] = resid[:, pos, :] + alpha * w
        return resid

    return [(hook_name, hook_fn)]

# Test on a moderately ambiguous sequence: 5 X and 5 Y, interleaved.
seq = list('XXXYYYXYXY')
prompt = prompt_two_generators(seq, control=CONTROL_TEMPLATE)

# Unsteered reference first, then the same prompt with the probe direction injected.
px0 = p_x(prompt)
hooks = steering_hooks(layer=STEER_LAYER, alpha=STEER_ALPHA)
px1 = p_x(prompt, fwd_hooks=hooks)

print('Baseline P(X):', px0)
print('Steered  P(X):', px1)


In [None]:
# Steering dose-response curve: sweep alpha (negative, zero and positive) on the same prompt.

alphas = [-10, -5, -2, -1, 0, 1, 2, 5, 10]
rows = []
for a in alphas:
    hooks = steering_hooks(layer=STEER_LAYER, alpha=float(a))
    rows.append(dict(alpha=a, p_x=p_x(prompt, fwd_hooks=hooks)))

df_alpha = pd.DataFrame(rows)
display(df_alpha)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(df_alpha.alpha, df_alpha.p_x, marker='o')
# Dashed reference line: the unsteered prediction for the same prompt.
ax.axhline(px0, color='gray', linestyle='--', alpha=0.6, label='baseline')
ax.set(
    title='Belief steering dose-response',
    xlabel='alpha (probe direction)',
    ylabel='P(next=X | {X,Y})',
)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()


In [None]:
# Save intervention results (so you can compare across runs)

import os
os.makedirs('results', exist_ok=True)

df_base.to_csv('results/causal_baseline_sweep.csv', index=False)
print('Wrote results/causal_baseline_sweep.csv')

# Intervention outputs exist only if their cells ran in this session (e.g. df_anti is
# not created when ANTI_HEADS is empty), so look them up instead of referencing directly.
_optional_outputs = {
    'df_anti': 'results/causal_ablate_anti_sweep.csv',
    'df_bayes': 'results/causal_ablate_bayes_sweep.csv',
    'df_dist': 'results/causal_distractors.csv',
    'df_alpha': 'results/causal_steering_dose_response.csv',
}
for _name, _path in _optional_outputs.items():
    _df = globals().get(_name)
    if _df is not None:
        _df.to_csv(_path, index=False)
        print('Wrote', _path)

# Write the head lists whenever any were selected, whether from patching results or a
# manual override. Format:
#   BAYES_HEADS
#   <layer>,<head> lines
#   blank line
#   ANTI_HEADS
#   <layer>,<head> lines
if BAYES_HEADS or ANTI_HEADS:
    _lines = ['BAYES_HEADS']
    _lines += [f'{layer},{head}' for layer, head in BAYES_HEADS]
    _lines += ['', 'ANTI_HEADS']
    _lines += [f'{layer},{head}' for layer, head in ANTI_HEADS]
    with open('results/selected_heads.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(_lines) + '\n')
    print('Wrote results/selected_heads.txt')
