# Toy Induction (attn-only-2l): Drift via Circuit Competition

Reproduce martingale drift (order sensitivity) in a minimal circuit where induction heads can exist, and show that ablating those heads collapses the drift toward an order-invariant counting baseline.

In [None]:
# Notebook path setup: make repo imports work regardless of where you run this from
from pathlib import Path
import sys

# Locate the repository root by walking upward from the working directory until
# a directory containing `bayesian_llm` is found. The working directory and its
# parent are still examined first, so launches that resolved before resolve to
# the same root; higher ancestors are only consulted when both of those miss.
cwd = Path.cwd().resolve()
repo_candidates = [cwd, *cwd.parents]
repo_root = next((p for p in repo_candidates if (p / 'bayesian_llm').exists()), None)
if repo_root is None:
    raise RuntimeError(f'Could not find repo root from cwd={cwd}.')

# Put the root at the front of sys.path (once) so `import bayesian_llm` resolves
# to this checkout.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print('Repo root:', repo_root)


In [None]:
# Dependencies (this notebook assumes you already installed requirements)
import importlib.util
import pkgutil  # no longer used in this cell; import retained in case other cells rely on it
import sys


def _is_installed(name):
    """Return True if the top-level module *name* can be located.

    Uses importlib.util.find_spec, the supported replacement for the
    deprecated pkgutil.find_loader. For a top-level name this looks the
    module up without importing it.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec raises ValueError when the module is already in
        # sys.modules with __spec__ set to None; it is loaded in that case.
        return name in sys.modules


REQUIRED = ['torch', 'numpy', 'pandas', 'transformer_lens']
missing = [p for p in REQUIRED if not _is_installed(p)]
print('Missing:', missing if missing else 'None')

# circuitsvis is only needed by the attention-visualisation cells.
HAS_CIRCUITSVIS = _is_installed('circuitsvis')
print('circuitsvis installed:', HAS_CIRCUITSVIS)

if missing:
    print('Install with: pip install -r ../requirements.txt')
if not HAS_CIRCUITSVIS:
    print('Install circuitsvis with: pip install circuitsvis')


In [None]:
# Load attn-only-2l (TransformerLens)
import torch
import numpy as np
import pandas as pd

from transformer_lens import HookedTransformer

# Use the GPU when one is visible; half precision is selected only on GPU,
# the CPU path stays in float32.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16 if DEVICE != 'cpu' else torch.float32

# Load the pretrained 2-layer attention-only model onto the chosen device.
model = HookedTransformer.from_pretrained(
    'attn-only-2l',
    device=DEVICE,
    dtype=DTYPE,
)
print('Loaded attn-only-2l')
print('device:', DEVICE, 'dtype:', DTYPE)
print(
    'n_layers:', model.cfg.n_layers,
    'n_heads:', model.cfg.n_heads,
    'd_model:', model.cfg.d_model,
)


In [None]:
# Define the stimuli and verify tokenization (important!)

# We predict between exactly two single-token candidates: " X" and " Y"
# The two candidate continuations, each written with its leading space.
X_STR, Y_STR = ' X', ' Y'

# to_single_token is used so that each candidate maps to exactly one vocab id.
X_ID = model.to_single_token(X_STR)
Y_ID = model.to_single_token(Y_STR)

print('X_STR token id:', X_ID)
print('Y_STR token id:', Y_ID)

# Round-trip through the tokenizer to confirm the ids decode to the intended strings.
print('decode X:', repr(model.tokenizer.decode([X_ID])))
print('decode Y:', repr(model.tokenizer.decode([Y_ID])))


In [None]:
# Construct sequences with identical counts but different order

# All have 3 X and 3 Y. If the model were a perfect count-based Bayesian updater, predictions should match.
# Three orderings of the same multiset (3 x 'X', 3 x 'Y'); each string is
# expanded into a list of single-character tokens.
SEQUENCES = {
    'patterned': list('XYXYXY'),  # strict alternation
    'clumped': list('XXXYYY'),    # all X first, then all Y
    'mixed': list('XYYXXY'),      # irregular order
}

def counts(seq):
    """Return ``(n_x, n_y)``: how many items of *seq* equal 'X' and 'Y'.

    The sequence is traversed once, so one-shot iterables (generators,
    iterators) are tallied correctly as well as lists. Items that are neither
    'X' nor 'Y' are ignored.
    """
    n_x = 0
    n_y = 0
    for t in seq:
        if t == 'X':
            n_x += 1
        elif t == 'Y':
            n_y += 1
    return n_x, n_y

# Print each stimulus next to its (n_X, n_Y) tally so equal counts can be checked by eye.
for name in SEQUENCES:
    seq = SEQUENCES[name]
    print(name, seq, 'counts=', counts(seq))


In [None]:
# Order-invariant baselines: Frequentist and Bayesian (Beta-Bernoulli)

# Frequentist: p_hat = n_X / (n_X + n_Y)
# Bayesian posterior predictive: Beta(alpha,beta) prior -> (alpha+n_X)/(alpha+beta+n)

# Beta(alpha, beta) prior pseudo-counts, used as the default prior of bayes_p_x.
# alpha = beta = 1 is the uniform prior over P(X).
ALPHA = 1.0
BETA = 1.0

def frequentist_p_x(n_x, n_y):
    """Return the empirical frequency of X, ``n_x / (n_x + n_y)``.

    With no observations (total not greater than zero) the estimate is 0.5.
    """
    total = n_x + n_y
    if total > 0:
        return n_x / total
    return 0.5

def bayes_p_x(n_x, n_y, *, alpha=ALPHA, beta=BETA):
    """Return the Beta-Bernoulli posterior predictive probability of X.

    Computes ``(alpha + n_x) / (alpha + beta + n)`` with ``n = n_x + n_y``;
    *alpha* and *beta* are keyword-only prior pseudo-counts.
    """
    observed = n_x + n_y
    numerator = alpha + n_x
    denominator = alpha + beta + observed
    return numerator / denominator

# Both estimates depend only on the counts, so every sequence should print the same pair.
for name, seq in SEQUENCES.items():
    n_x, n_y = counts(seq)
    print(
        name,
        'freq=', round(frequentist_p_x(n_x, n_y), 4),
        'bayes=', round(bayes_p_x(n_x, n_y), 4),
    )


In [None]:
# Model prediction utilities: logits/probs for next token

import torch

def make_prompt(seq_tokens):
    """Return the prompt text for *seq_tokens*.

    The tokens are space-separated and placed after the fixed prefix
    'Sequence: ', so the first token is preceded by a space just like the rest.
    """
    body = ' '.join(seq_tokens)
    return f'Sequence: {body}'

@torch.no_grad()
def model_p_x(prompt: str):
    """Score the two candidate next tokens after *prompt*.

    Runs the global ``model`` on the prompt, takes the logits at the final
    position, and returns a dict holding the raw X/Y logits, their difference,
    the full-vocabulary softmax probabilities of X and Y, and those two
    probabilities renormalised to sum to one (0.5 each if both are zero).
    """
    # Final-position logits of the only batch item; assumes the model returns [batch, pos, vocab].
    final_logits = model(prompt)[0, -1, :]

    logit_x = final_logits[X_ID].item()
    logit_y = final_logits[Y_ID].item()

    # Softmax is taken in float32 regardless of the model dtype.
    vocab_probs = torch.softmax(final_logits.float(), dim=-1)
    p_x = vocab_probs[X_ID].item()
    p_y = vocab_probs[Y_ID].item()

    # Restrict the distribution to {X, Y}.
    mass = p_x + p_y
    if mass > 0:
        p_x_norm = p_x / mass
        p_y_norm = p_y / mass
    else:
        p_x_norm = p_y_norm = 0.5

    return {
        'logit_X': logit_x,
        'logit_Y': logit_y,
        'logit_diff_XminusY': logit_x - logit_y,
        'p_X_raw': p_x,
        'p_Y_raw': p_y,
        'p_X_norm': p_x_norm,
        'p_Y_norm': p_y_norm,
    }

# Quick smoke test
print(model_p_x(make_prompt(SEQUENCES['patterned'])))


In [None]:
# Baseline: evaluate patterned/clumped/mixed and compute drift

# One row per ordering: the count-based baselines followed by the model's scores.
rows = []
for name, seq in SEQUENCES.items():
    n_x, n_y = counts(seq)
    prompt = make_prompt(seq)
    out = model_p_x(prompt)

    rows.append(dict(
        name=name,
        seq=' '.join(seq),
        n_X=n_x,
        n_Y=n_y,
        freq_p_X=frequentist_p_x(n_x, n_y),
        bayes_p_X=bayes_p_x(n_x, n_y),
        **out,
    ))

df = pd.DataFrame(rows).sort_values('name')
display(df)

# Drift = spread of the renormalised P(X) across orderings with identical counts.
drift = float(df['p_X_norm'].max() - df['p_X_norm'].min())
print('Drift (max-min) over permutations with same counts:', drift)


## Visualize Attention with CircuitsVis

We visualize attention patterns for both layers. Induction heads typically show a distinctive pattern: each query position attends to the token that immediately followed an earlier occurrence of the current token, which lets the head copy that token as its next-token prediction.


In [None]:
# Run with cache and visualize attention patterns for each sequence

import importlib.util
from IPython.display import display

# importlib.util.find_spec replaces the deprecated pkgutil.find_loader.
if importlib.util.find_spec('circuitsvis') is None:
    raise RuntimeError('circuitsvis not installed. Run: pip install circuitsvis')

import circuitsvis as cv

for name, seq in SEQUENCES.items():
    prompt = make_prompt(seq)
    logits, cache = model.run_with_cache(prompt)
    tokens = model.to_str_tokens(prompt)

    print('===', name, '===')
    print('Tokens:', tokens)

    for layer in range(model.cfg.n_layers):
        # cache['pattern', layer] has shape [batch, head, query, key]
        patt = cache['pattern', layer][0].detach().cpu().numpy()
        # NOTE(review): attention_patterns is called with tokens and attention only.
        # The circuitsvis API as I know it has no `title` keyword, so the title
        # is printed above the widget instead - confirm against the installed version.
        print(f'{name}: layer {layer} attention patterns')
        display(cv.attention.attention_patterns(tokens=tokens, attention=patt))


## Identify Induction Heads via Ablation Sweep

We use a purely behavioral criterion: *which head ablation reduces drift the most?*

Metric: drift = max(P(X|{X,Y})) - min(P(X|{X,Y})) across the three sequences.


In [None]:
# Ablation utilities

from transformer_lens.utils import get_act_name

def ablate_heads(heads):
    """Build forward hooks that zero the ``z`` activation of the given heads.

    *heads* is an iterable of ``(layer, head)`` pairs. The result is a list of
    ``(hook_name, hook_fn)`` tuples, one per distinct layer, in order of first
    appearance. Each hook overwrites the selected head slices of ``z`` (indexed
    on the third axis) with zeros in place and returns the tensor, so those
    heads contribute ~0 to the residual stream.
    """
    # Group head indices per layer, de-duplicating as we go.
    heads_by_layer = {}
    for layer, head in heads:
        heads_by_layer.setdefault(int(layer), set()).add(int(head))

    def _zeroing_hook(head_ids):
        # The factory gives each hook its own head list (avoids late binding).
        def hook_fn(z, hook):
            z[:, :, head_ids, :] = 0.0
            return z
        return hook_fn

    return [
        (get_act_name('z', layer), _zeroing_hook(sorted(head_ids)))
        for layer, head_ids in heads_by_layer.items()
    ]

@torch.no_grad()
def eval_sequences(fwd_hooks=None):
    """Score every entry of SEQUENCES, optionally under forward hooks.

    Returns ``(df, drift, mae_to_bayes)``: a DataFrame sorted by name with the
    renormalised P(X), the X-minus-Y logit difference and both count-based
    baselines; the max-min spread of the renormalised P(X); and its mean
    absolute deviation from the Bayesian baseline.
    """
    records = []
    for label, tokens in SEQUENCES.items():
        n_x, n_y = counts(tokens)
        prompt = make_prompt(tokens)

        # Plain forward pass unless hooks were supplied.
        if fwd_hooks is None:
            logits = model(prompt)
        else:
            logits = model.run_with_hooks(prompt, fwd_hooks=fwd_hooks)

        final = logits[0, -1, :]
        vocab_probs = torch.softmax(final.float(), dim=-1)
        p_x = vocab_probs[X_ID].item()
        p_y = vocab_probs[Y_ID].item()
        mass = p_x + p_y
        p_x_norm = p_x / mass if mass > 0 else 0.5

        records.append({
            'name': label,
            'p_X_norm': p_x_norm,
            'logit_diff_XminusY': float(final[X_ID].item() - final[Y_ID].item()),
            'bayes_p_X': bayes_p_x(n_x, n_y),
            'freq_p_X': frequentist_p_x(n_x, n_y),
        })

    table = pd.DataFrame(records)
    p_norm = table['p_X_norm']
    drift = float(p_norm.max() - p_norm.min())
    mae_to_bayes = float((p_norm - table['bayes_p_X']).abs().mean())
    return table.sort_values('name'), drift, mae_to_bayes

# Unablated reference values used by the sweep below.
base_df, base_drift, base_mae = eval_sequences()
print('Baseline drift:', base_drift)
print('Baseline MAE to Bayes:', base_mae)
display(base_df)


In [None]:
# Single-head ablation sweep: which head kills drift?

# Ablate every (layer, head) on its own and record how much the drift shrinks
# relative to the unablated baseline.
results = []
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        hooks = ablate_heads([(layer, head)])
        _, drift, mae = eval_sequences(fwd_hooks=hooks)
        results.append(dict(
            layer=layer,
            head=head,
            drift=drift,
            mae_to_bayes=mae,
            drift_reduction=base_drift - drift,
        ))

# Largest drift reduction first.
df_sweep = pd.DataFrame(results).sort_values(by='drift_reduction', ascending=False)
display(df_sweep.head(10))

best = df_sweep.iloc[0].to_dict()
print('Top drift-reducing head:', best)


In [None]:
# Choose induction-head candidates and re-test

TOPK = 2  # increase if drift remains

# The TOPK best rows of the sweep, as (layer, head) integer pairs.
candidates = [
    (int(layer_idx), int(head_idx))
    for layer_idx, head_idx in df_sweep.head(TOPK)[['layer', 'head']].itertuples(index=False, name=None)
]
print('Ablating candidates:', candidates)

# Ablate all candidates together and compare with the baseline numbers.
hooks = ablate_heads(candidates)
abl_df, abl_drift, abl_mae = eval_sequences(fwd_hooks=hooks)
print('After ablation drift:', abl_drift, '(baseline:', base_drift, ')')
print('After ablation MAE to Bayes:', abl_mae, '(baseline:', base_mae, ')')
display(abl_df)


In [None]:
# Optional: visualize attention again (attention *patterns* won't change under z-ablation)

import importlib.util
from IPython.display import display

# importlib.util.find_spec replaces the deprecated pkgutil.find_loader.
if importlib.util.find_spec('circuitsvis') is None:
    raise RuntimeError('circuitsvis not installed. Run: pip install circuitsvis')

import circuitsvis as cv

name = 'patterned'
seq = SEQUENCES[name]
prompt = make_prompt(seq)

# No ablation hooks are attached to this run, so these are the baseline patterns.
logits, cache = model.run_with_cache(prompt)
tokens = model.to_str_tokens(prompt)

print('Candidates:', candidates)
print('Note: z-ablation changes contributions, not the attention weights themselves.')

for layer in range(model.cfg.n_layers):
    # The batch dimension is dropped with [0] before converting to numpy.
    patt = cache['pattern', layer][0].detach().cpu().numpy()
    # NOTE(review): attention_patterns is called with tokens and attention only.
    # The circuitsvis API as I know it has no `title` keyword, so the title
    # is printed above the widget instead - confirm against the installed version.
    print(f'{name}: layer {layer} attention patterns (baseline)')
    display(cv.attention.attention_patterns(tokens=tokens, attention=patt))
