# Localization

Localizing where Bayesian evidence integration appears inside Llama-3.1-8B.

We run three complementary analyses:
1. Layerwise logit lens: decode the model's implied `P(next=X)` from each layer's hidden state and compare it to the Bayesian posterior predictive.
2. Belief-state probing: train a linear (ridge) probe to predict the Bayesian posterior `P(B | evidence)` from last-token activations, then test whether it generalizes to a different prompt template.
3. Head patching: identify a small set of attention heads whose activations, patched from a clean run into a corrupted run, restore the clean run's X-vs-Y preference.

In [None]:
# Notebook path setup: make repo imports work regardless of where you run this from
from pathlib import Path
import sys

# The notebook may be launched from the repo root or from a direct child
# directory (e.g. notebooks/); the root is the first candidate that contains
# the `bayesian_llm` package.
cwd = Path.cwd().resolve()
repo_candidates = [cwd, cwd.parent]
repo_root = None
for candidate in repo_candidates:
    if (candidate / 'bayesian_llm').exists():
        repo_root = candidate
        break
else:
    raise RuntimeError(f'Could not find repo root from cwd={cwd}.')

# Prepend so the local checkout takes precedence over any installed copy.
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
    sys.path.insert(0, repo_root_str)

print('Repo root:', repo_root)


In [None]:
# Dependency check
import pkgutil

import importlib.util


def module_available(name):
    """Return True if the top-level module *name* can be imported.

    Uses ``importlib.util.find_spec``, the documented replacement for
    ``pkgutil.find_loader`` (deprecated in Python 3.12, removed in 3.14).
    """
    return importlib.util.find_spec(name) is not None


REQUIRED = ['torch', 'transformers', 'accelerate', 'huggingface_hub', 'sklearn', 'numpy', 'pandas', 'matplotlib']
missing = [p for p in REQUIRED if not module_available(p)]
print('Missing core packages:', missing if missing else 'None')

# transformer_lens is only needed for the head-patching section.
HAS_TLENS = module_available('transformer_lens')
print('transformer_lens available:', HAS_TLENS)
if missing or not HAS_TLENS:
    print('Install with: pip install -r ../requirements.txt')


In [None]:
# Configuration
import numpy as np

# Model / loading
MODEL_ID = 'meta-llama/Llama-3.1-8B'
DTYPE = 'float16'  # or 'bfloat16'
DEVICE_MAP = 'auto'  # passed through to the HF loader

# Evidence sweep: every X-count 0..N_TOTAL, each shown in N_PERMUTATIONS orders
N_TOTAL = 10
N_PERMUTATIONS = 5
CONTROL_TEMPLATE = 'order_irrelevant'  # 'base' | 'independent' | 'order_irrelevant'

# Probe dataset sizes and layer subsampling
N_TRAIN = 200
N_TEST = 100
LAYERS_STRIDE = 4  # probe every k-th layer to keep compute manageable

RNG_SEED = 0

for label, value in (
    ('MODEL_ID:', MODEL_ID),
    ('N_TOTAL:', N_TOTAL),
    ('CONTROL_TEMPLATE:', CONTROL_TEMPLATE),
):
    print(label, value)
print('Probe: N_TRAIN/N_TEST:', N_TRAIN, N_TEST)


In [None]:
# Load HuggingFace model (used for hidden-states logit lens + probing)
import torch
from bayesian_llm.llm import load_hf_causal_lm

# Resolve the configured dtype name to a torch dtype (unknown names raise KeyError).
dtype_by_name = {name: getattr(torch, name) for name in ('float16', 'bfloat16', 'float32')}
dtype = dtype_by_name[DTYPE]

loaded = load_hf_causal_lm(MODEL_ID, torch_dtype=dtype, device_map=DEVICE_MAP)
model = loaded.model
tokenizer = loaded.tokenizer

# With device_map='auto' the model may be spread over several devices; this
# reports the first parameter only.
first_param = next(model.parameters())
print('Loaded HF model:', MODEL_ID)
print('dtype:', first_param.dtype)
print('device:', first_param.device)


In [None]:
# Prompt + Bayes helpers
import math
import pandas as pd

from bayesian_llm.bayes import discrete_posterior_predictive, DiscreteHypothesis, two_generator_posterior_predictive
from bayesian_llm.data import make_sequence, permute_sequence, set_seed

# Seed via the project helper (presumably the global Python/NumPy/torch RNGs —
# confirm in bayesian_llm.data) and create the dedicated NumPy generator that
# the permutation and dataset-sampling code below draws from.
set_seed(RNG_SEED)
rng = np.random.default_rng(RNG_SEED)

# Candidate generators: (name, P(X) per draw).
HYPOTHESES = [
    DiscreteHypothesis(name=hyp_name, p_success=hyp_p)
    for hyp_name, hyp_p in (('A', 0.50), ('B', 0.75))
]

# Surface forms of the two answers; token_ids_for_variants keeps only the
# ones that encode to a single token.
X_VARIANTS = [' X', 'X', '\nX']
Y_VARIANTS = [' Y', 'Y', '\nY']


def token_ids_for_variants(variants):
    """Return the sorted, de-duplicated ids of *variants* that encode to exactly one token.

    Variants that encode to zero or several tokens are skipped silently.
    Uses the notebook-global ``tokenizer``.

    Raises:
        ValueError: if no variant encodes to a single token.
    """
    single_ids = set()
    for variant in variants:
        encoded = tokenizer.encode(variant, add_special_tokens=False)
        if len(encoded) != 1:
            continue
        single_ids.add(int(encoded[0]))
    if not single_ids:
        raise ValueError(f'No single-token ids for variants={variants}')
    return sorted(single_ids)

# Pooled single-token ids for each answer; P(X) sums probability over all of them.
X_IDS, Y_IDS = token_ids_for_variants(X_VARIANTS), token_ids_for_variants(Y_VARIANTS)
print('X_IDS:', X_IDS, 'Y_IDS:', Y_IDS)


def prompt_two_generators(sequence_tokens, *, control: str):
    """Build the two-generator next-draw prediction prompt.

    Args:
        sequence_tokens: observed draws (e.g. ``['X', 'Y', ...]``), joined with spaces.
        control: framing variant:
            ``'base'`` - generator descriptions only;
            ``'independent'`` - adds "Draws are independent.";
            ``'order_irrelevant'`` - additionally adds "Order does not matter."

    Returns:
        The prompt text, ending with "Predict the next output (X or Y):".

    Raises:
        ValueError: for an unknown *control* (the message is the control value).
    """
    sequence_text = ' '.join(sequence_tokens)

    description = 'Two random generators. Generator A: 50% X. Generator B: 75% X.'
    extra_by_control = {
        'base': '',
        'independent': ' Draws are independent.',
        'order_irrelevant': ' Draws are independent. Order does not matter.',
    }
    if control not in extra_by_control:
        raise ValueError(control)
    prefix = description + extra_by_control[control]

    return f"{prefix} Sequence: {sequence_text}. Predict the next output (X or Y):"


def bayes_posterior_B(n_x, n_total):
    """Return the posterior probability of hypothesis 'B' after n_x X's in n_total draws.

    Uses the global ``HYPOTHESES`` with equal priors [0.5, 0.5].
    ``discrete_posterior_predictive`` is assumed to return
    ``(predictive, posterior_by_name)``; confirm in bayesian_llm.bayes.
    """
    equal_priors = [0.5, 0.5]
    _predictive, posterior = discrete_posterior_predictive(
        n_success=int(n_x),
        n_total=int(n_total),
        hypotheses=HYPOTHESES,
        priors=equal_priors,
    )
    return float(posterior['B'])


def logit(p):
    """Return log(p / (1 - p)), clamping p to [1e-9, 1 - 1e-9] so the result stays finite."""
    eps = 1e-9
    clamped = min(1 - eps, max(eps, float(p)))
    return math.log(clamped / (1 - clamped))


In [None]:
# Layerwise logit lens on HF hidden states
# We decode each intermediate hidden state using the *final* norm + lm_head.

import torch

@torch.no_grad()
def layerwise_p_x(prompt: str):
    """Compute the logit-lens estimate of P(X) / (P(X) + P(Y)) after every decoder layer.

    Each layer's last-position hidden state is decoded with the model's final
    RMSNorm and ``lm_head``. Probability mass is summed over the notebook-global
    ``X_IDS`` / ``Y_IDS`` and renormalized over {X, Y}. Uses the globals
    ``model`` and ``tokenizer``.

    Args:
        prompt: full prompt text.

    Returns:
        np.ndarray with one entry per decoder layer (entry i is after layer i + 1).
        An entry is 0.5 if both X and Y masses underflow to zero.
    """
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = inputs.to(next(model.parameters()).device)

    out = model(**inputs, output_hidden_states=True, use_cache=False, return_dict=True)
    h_states = out.hidden_states  # tuple: (embeds, layer1, ..., layerN)
    n_layers = len(h_states) - 1

    norm = model.model.norm
    lm_head = model.lm_head
    # With device_map='auto' the layers may be sharded, so each hidden state is
    # moved to the device of the module that consumes it.
    norm_device = norm.weight.device
    head_device = lm_head.weight.device

    probs = []
    for layer_idx, h in enumerate(h_states[1:], start=1):  # skip embeddings
        last = h[0, -1, :]
        if layer_idx < n_layers:
            # Intermediate states are pre-norm: apply the final RMSNorm as the model would.
            last = norm(last.to(norm_device))
        # else: in the HF Llama implementation the last `hidden_states` entry is
        # already the output of the final norm; normalizing it again would make
        # the top-layer readout disagree with the model's actual output logits.
        logits = lm_head(last.to(head_device))
        p = torch.softmax(logits.float(), dim=-1)
        p_x = float(p[X_IDS].sum().item())
        p_y = float(p[Y_IDS].sum().item())
        total = p_x + p_y
        probs.append(p_x / total if total > 0 else 0.5)

    return np.array(probs)

# Smoke test: decode one short sequence (4 X, 1 Y) and report the first/last layer.
prompt = prompt_two_generators(list('XXYXX'), control=CONTROL_TEMPLATE)
px_by_layer = layerwise_p_x(prompt)
first_px, last_px = px_by_layer[0], px_by_layer[-1]
print('n_layers decoded:', len(px_by_layer))
print('first/last:', first_px, last_px)


In [None]:
# Evidence sweep: where (which layers) are most Bayes-aligned?

from tqdm.auto import tqdm

# For every evidence count n_X, average the logit-lens P(X) over random
# orderings of the sequence, then record each layer's gap to the Bayes
# predictive from two_generator_posterior_predictive.
rows = []

for n_x in tqdm(range(N_TOTAL + 1), desc='layerwise_sweep'):
    base_seq = make_sequence(n_x=n_x, n_total=N_TOTAL, x='X', y='Y')
    true_p_next = two_generator_posterior_predictive(n_x=n_x, n_total=N_TOTAL)

    # One per-layer vector per random ordering -> array of shape [perm, layer].
    per_perm = np.stack(
        [
            layerwise_p_x(
                prompt_two_generators(permute_sequence(base_seq, rng=rng), control=CONTROL_TEMPLATE)
            )
            for _ in range(N_PERMUTATIONS)
        ],
        axis=0,
    )
    layer_means = per_perm.mean(axis=0)

    # Layers are numbered from 1, matching hidden_states[1:] in layerwise_p_x.
    rows.extend(
        {
            'n_X': n_x,
            'layer': layer_no,
            'p_x_logit_lens': float(mean_p),
            'true_bayes_p_x': float(true_p_next),
            'error': float(mean_p - true_p_next),
            'abs_error': float(abs(mean_p - true_p_next)),
        }
        for layer_no, mean_p in enumerate(layer_means, start=1)
    )

df_layer = pd.DataFrame(rows)
display(df_layer.head())

# Mean absolute error per layer, lowest (most Bayes-aligned) first.
mae_by_layer = (
    df_layer.groupby('layer')['abs_error']
    .mean()
    .reset_index()
    .sort_values('abs_error')
)
display(mae_by_layer.head(10))

best_layer = int(mae_by_layer['layer'].iloc[0])
print('Best layer by MAE:', best_layer)


In [None]:
# Visualization: MAE-by-layer and error heatmap

import matplotlib.pyplot as plt

# MAE by layer. `mae_by_layer` is sorted by error (best first), so re-sort by
# layer here; otherwise the line would jump between layers in error-rank order.
mae_in_layer_order = mae_by_layer.sort_values('layer')
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(mae_in_layer_order['layer'], mae_in_layer_order['abs_error'], marker='o')
ax.set_title('Logit lens: MAE to Bayes vs layer')
ax.set_xlabel('Layer')
ax.set_ylabel('Mean |P_x(layer)-P_x(Bayes)|')
ax.grid(True, alpha=0.3)
plt.show()

# Heatmap over (layer, n_X); color is the signed error, clipped to +/-0.25.
pivot = df_layer.pivot_table(index='layer', columns='n_X', values='error', aggfunc='mean')
fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(pivot.values, aspect='auto', origin='lower', cmap='RdBu_r', vmin=-0.25, vmax=0.25)
ax.set_title('Logit lens error heatmap (blue = under, red = over vs Bayes)')
ax.set_xlabel('n_X')
ax.set_ylabel('Layer')
ax.set_xticks(range(pivot.shape[1]))
ax.set_xticklabels(pivot.columns.tolist())
# imshow numbers rows from 0, but layers are numbered from 1: label the rows
# with the actual layer numbers (roughly 8 ticks to avoid crowding).
ytick_step = max(1, pivot.shape[0] // 8)
ytick_rows = list(range(0, pivot.shape[0], ytick_step))
ax.set_yticks(ytick_rows)
ax.set_yticklabels([pivot.index[i] for i in ytick_rows])
fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
plt.show()


In [None]:
# Probe: is the Bayes posterior linearly represented? (predict P(B|evidence))

from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

# Probe a strided subset of the 1-based decoder layers (indices into hidden_states).
n_layers = model.config.num_hidden_layers
layers_to_probe = [*range(1, n_layers + 1, LAYERS_STRIDE)]
print('Layers to probe:', layers_to_probe)

# Build dataset of prompts and targets

def make_dataset(n_samples, control):
    """Sample *n_samples* (prompt, posterior) pairs for the belief-state probe.

    For each sample, n_X is drawn uniformly from 0..N_TOTAL with the global
    ``rng``. A sequence with that many X's is built, reordered by
    ``permute_sequence``, and rendered with the *control* template. The target
    is the Bayes posterior P(B | evidence) from ``bayes_posterior_B``.

    Returns:
        (prompts, targets): a list of prompt strings and a float64 array, both of length n_samples.
    """
    samples = []
    for _ in range(n_samples):
        count_x = int(rng.integers(0, N_TOTAL + 1))
        ordered = make_sequence(n_x=count_x, n_total=N_TOTAL, x='X', y='Y')
        shuffled = permute_sequence(ordered, rng=rng)
        samples.append((
            prompt_two_generators(shuffled, control=control),
            bayes_posterior_B(n_x=count_x, n_total=N_TOTAL),
        ))
    prompts = [text for text, _ in samples]
    targets = np.array([target for _, target in samples], dtype=np.float64)
    return prompts, targets

# Train on the plain template and test on CONTROL_TEMPLATE to check whether
# the probe transfers across prompt framings.
train_control = 'base'
train_prompts, y_train = make_dataset(N_TRAIN, control=train_control)
test_prompts, y_test = make_dataset(N_TEST, control=CONTROL_TEMPLATE)

print('Train control:', train_control)
print('Test control:', CONTROL_TEMPLATE)
print('y_train range:', (y_train.min(), y_train.max()))


In [None]:
# Extract representations for selected layers (last-position hidden state)

import torch
from tqdm.auto import tqdm

@torch.no_grad()
def extract_layer_reprs(prompts, layers):
    """Collect last-token hidden states, read after the final RMSNorm, for the given layers.

    Args:
        prompts: prompt strings, run through the global ``model`` one at a time.
        layers: indices into ``hidden_states``; 1..n_layers select the output of
            that decoder layer (0 would select the embedding output).

    Returns:
        dict mapping each layer to a float32 array of shape [len(prompts), hidden_size].
    """
    feats = {layer: [] for layer in layers}
    device = next(model.parameters()).device
    norm = model.model.norm
    # Hidden states of sharded layers (device_map='auto') may live elsewhere.
    norm_device = norm.weight.device

    for prompt in tqdm(prompts, desc='extract_reprs'):
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        out = model(**inputs, output_hidden_states=True, use_cache=False, return_dict=True)
        h_states = out.hidden_states
        final_idx = len(h_states) - 1

        for layer in layers:
            h = h_states[layer]  # 1..n_layers inclusive (after that layer)
            last = h[0, -1, :]
            # In the HF Llama implementation the final entry is already the output
            # of the final norm; only intermediate states need normalizing.
            if layer % len(h_states) != final_idx:
                last = norm(last.to(norm_device))
            feats[layer].append(last.float().cpu().numpy())

    for layer in layers:
        feats[layer] = np.stack(feats[layer], axis=0)
    return feats

# Features for train and test prompts (train first, matching dataset order).
X_train, X_test = (extract_layer_reprs(batch, layers_to_probe) for batch in (train_prompts, test_prompts))

first_layer = layers_to_probe[0]
print('Example layer feature shape:', first_layer, X_train[first_layer].shape)


In [None]:
# Fit per-layer probes and evaluate cross-template generalization

# One ridge probe per layer. With far more features than training prompts
# (hidden size presumably >> N_TRAIN), train R^2 can be near 1; read r2_test.
# random_state only affects the sag/saga solvers in sklearn's Ridge.
probes = {}
rows = []
for layer in layers_to_probe:
    feats_tr, feats_te = X_train[layer], X_test[layer]
    probe = Ridge(alpha=1.0, fit_intercept=True, random_state=0).fit(feats_tr, y_train)
    probes[layer] = probe
    rows.append({
        'layer': layer,
        'r2_train': float(r2_score(y_train, probe.predict(feats_tr))),
        'r2_test': float(r2_score(y_test, probe.predict(feats_te))),
    })

df_probe = pd.DataFrame(rows).sort_values('r2_test', ascending=False)
display(df_probe)

best_probe_layer = int(df_probe['layer'].iloc[0])
print('Best probe layer (by test R^2):', best_probe_layer)


In [None]:
# Plot probe quality by layer

import matplotlib.pyplot as plt

# Put layers in depth order once and reuse it for both curves.
by_layer = df_probe.sort_values('layer')

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(by_layer['layer'], by_layer['r2_test'], marker='o', label='Test R^2')
ax.plot(by_layer['layer'], by_layer['r2_train'], marker='x', label='Train R^2', alpha=0.7)
ax.set(title='Linear probe for Bayes posterior P(B|evidence)', xlabel='Layer', ylabel='R^2')
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()


In [None]:
# Scatter: probe prediction vs true posterior (on test prompts)

from sklearn.metrics import mean_absolute_error

layer = best_probe_layer
reg = probes[layer]
y_hat = reg.predict(X_test[layer])

# Ridge output is unbounded, so predictions can leave [0, 1]; clip them.
# Note that the MAE below is computed on the clipped predictions.
y_hat_clip = np.clip(y_hat, 0, 1)

mae_val = float(mean_absolute_error(y_test, y_hat_clip))
print('Layer', layer, 'Test MAE:', mae_val)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(y_test, y_hat_clip, alpha=0.5)
ax.plot([0, 1], [0, 1], color='black', linewidth=1)  # identity line = perfect probe
ax.set(
    title=f'Probe generalization (layer={layer})',
    xlabel='True P(B|evidence)',
    ylabel='Probe prediction',
)
ax.grid(True, alpha=0.3)
plt.show()

# Save the probe for later causal steering (optional). Its features were read
# after the model's final RMSNorm (see extract_layer_reprs), so the weights
# apply to normalized states. 'results' is relative to the current working directory.
import os
os.makedirs('results', exist_ok=True)
w_path = f'results/probe_w_layer{layer}.npy'
b_path = f'results/probe_b_layer{layer}.npy'
np.save(w_path, reg.coef_.astype(np.float32))
np.save(b_path, np.array([reg.intercept_], dtype=np.float32))
print('Saved probe weights to results/.')


## Head Patching (TransformerLens)

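Patch each attention head's output (`z`) from a clean run (strong evidence) into a corrupted run (balanced evidence), one head at a time, and measure how much of the clean-vs-corrupt X−Y logit difference that head recovers.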

In [None]:
# Load TransformerLens model (optional but recommended for head-level causal tracing)

import pkgutil
import importlib.util

# importlib.util.find_spec replaces pkgutil.find_loader (deprecated in 3.12, removed in 3.14).
HAS_TLENS = importlib.util.find_spec('transformer_lens') is not None
if not HAS_TLENS:
    raise RuntimeError('transformer_lens not installed. Run: pip install -r ../requirements.txt')

import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

# Use CUDA if available; otherwise this will be extremely slow.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Map DTYPE exactly like the HF loading cell, so 'float32' is honored and an
# unknown name raises KeyError instead of silently falling back to bfloat16.
dtype = {
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    'float32': torch.float32,
}[DTYPE]

# NOTE(review): this loads a second full copy of the weights next to the HF
# `model` used above; if GPU memory is tight, free that model first. TransformerLens
# may also warn that from_pretrained's weight processing is less exact in reduced
# precision — confirm whether from_pretrained_no_processing is preferable here.
tl_model = HookedTransformer.from_pretrained(MODEL_ID, device=device, dtype=dtype)
print('Loaded TransformerLens model on', device)
print('n_layers:', tl_model.cfg.n_layers, 'n_heads:', tl_model.cfg.n_heads)


In [None]:
# Define clean vs corrupt prompts (you can change these)

# Clean: 9 X / 1 Y (evidence favoring the 75%-X generator B).
# Corrupt: 5 X / 5 Y (balanced evidence). Both have 10 draws, so the prompts
# presumably tokenize to the same length, which position-aligned patching needs.
clean_seq = ['X'] * 9 + ['Y']
corrupt_seq = ['X'] * 5 + ['Y'] * 5

clean_prompt, corrupt_prompt = (
    prompt_two_generators(draws, control=CONTROL_TEMPLATE) for draws in (clean_seq, corrupt_seq)
)

print('CLEAN:', clean_prompt)
print('CORRUPT:', corrupt_prompt)


In [None]:
# Baseline metrics (logit diff between X and Y)

import torch

# Only the space-prefixed answer tokens are used here (the HF logit-lens cells
# pool several surface variants via X_IDS / Y_IDS).
X_tok = tl_model.to_single_token(' X')
Y_tok = tl_model.to_single_token(' Y')

clean_logits, clean_cache = tl_model.run_with_cache(clean_prompt)
corrupt_logits, corrupt_cache = tl_model.run_with_cache(corrupt_prompt)


def logit_diff(logits):
    """Return logit[X] - logit[Y] at the final position of the first batch item."""
    final_pos = logits[0, -1]
    return float((final_pos[X_tok] - final_pos[Y_tok]).item())


clean_ld, corrupt_ld = logit_diff(clean_logits), logit_diff(corrupt_logits)
print('Clean logit_diff(X-Y):', clean_ld)
print('Corrupt logit_diff(X-Y):', corrupt_ld)


In [None]:
# Head patching sweep (partial by default)

import numpy as np
from tqdm.auto import tqdm

# Start small to ensure runtime is reasonable, then widen.
LAYER_RANGE = range(max(0, tl_model.cfg.n_layers - 8), tl_model.cfg.n_layers)  # last 8 layers
HEAD_RANGE = range(tl_model.cfg.n_heads)

# The hook copies clean z[:, pos] into the corrupt run at the same pos, so both
# prompts must tokenize to the same length. Fail early with a clear message
# instead of a shape error deep inside a hook.
clean_n_tok = tl_model.to_tokens(clean_prompt).shape[-1]
corrupt_n_tok = tl_model.to_tokens(corrupt_prompt).shape[-1]
if clean_n_tok != corrupt_n_tok:
    raise ValueError(
        f'Clean and corrupt prompts tokenize to different lengths ({clean_n_tok} vs '
        f'{corrupt_n_tok}); position-aligned head patching needs equal lengths.'
    )

# recovery = (patched - corrupt) / (clean - corrupt): 0 = no effect, 1 = this head
# alone restores the clean logit diff. The denominator is loop-invariant.
denom = clean_ld - corrupt_ld
if abs(denom) <= 1e-8:
    print('Warning: clean and corrupt logit diffs are (nearly) equal; all recovery scores will be 0.')

patching = np.zeros((tl_model.cfg.n_layers, tl_model.cfg.n_heads), dtype=np.float32)

# Inference only: skip building autograd graphs for the len(LAYER_RANGE) * len(HEAD_RANGE) forward passes.
with torch.no_grad():
    for layer in tqdm(LAYER_RANGE, desc='layers'):
        hook_name = get_act_name('z', layer)
        clean_z = clean_cache[hook_name]  # z: [batch, pos, head, d_head]

        for head in HEAD_RANGE:
            # Bind head / clean_z as defaults so the hook captures this iteration's values.
            def patch_head(z, hook, head=head, clean_z=clean_z):
                z[:, :, head, :] = clean_z[:, :, head, :]
                return z

            patched_logits = tl_model.run_with_hooks(
                corrupt_prompt,
                fwd_hooks=[(hook_name, patch_head)],
            )

            patched_ld = logit_diff(patched_logits)
            recovery = (patched_ld - corrupt_ld) / denom if abs(denom) > 1e-8 else 0.0
            patching[layer, head] = float(recovery)

# Top heads: rank every swept (layer, head) pair by recovery, highest first
# (stable sort, so ties keep layer-major order).
flat = sorted(
    ((patching[layer, head], layer, head) for layer in LAYER_RANGE for head in HEAD_RANGE),
    key=lambda entry: entry[0],
    reverse=True,
)
print('Top 10 recovery heads:')
for r, layer, head in flat[:10]:
    print(f'  layer {layer:02d} head {head:02d}: recovery={r:.3f}')


In [None]:
# Visualize patching heatmap

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
# `patching` is zero-filled for layers outside LAYER_RANGE, which would render
# exactly like "no effect". Mask those rows (drawn light gray) in the plot only;
# the array saved below is unchanged.
unswept = np.ones(patching.shape, dtype=bool)
unswept[list(LAYER_RANGE), :] = False
cmap = plt.get_cmap('RdBu_r').copy()
cmap.set_bad('lightgray')
im = ax.imshow(
    np.ma.masked_array(patching, mask=unswept),
    aspect='auto', origin='lower', cmap=cmap, vmin=-1, vmax=1,
)
ax.set_title('Head patching recovery (clean -> corrupt)')
ax.set_xlabel('Head')
ax.set_ylabel('Layer (0-indexed)')
fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
plt.show()

# Save for causal-intervention notebook ('results' is relative to the working directory).
import os
os.makedirs('results', exist_ok=True)
np.save('results/head_patching_recovery.npy', patching)
print('Saved results/head_patching_recovery.npy')
