# Neurons That Panic — Core Experiment Notebook

This notebook reproduces the main experiments from the paper:

- adversarial trigger generation
- activation extraction (clean vs adversarial)
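- |Δnorm| computation (shift in mean activation at the trigger position, normalized by the clean-activation standard deviation) and component ranking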
- positional and layerwise analyses
- causal patching

Running this notebook top-to-bottom will generate all figures and CSV outputs used in the paper. Sanity checks are handled in a separate notebook.


In [1]:
!pip install transformer_lens



In [2]:
!pip install -q torch datasets matplotlib seaborn

## 1. Setup & Imports

In [3]:
import torch
import numpy as np
import random
import math
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformer_lens import HookedTransformer
from datasets import load_dataset
from typing import List, Tuple, Dict, Optional

import warnings
warnings.filterwarnings('ignore')

In [4]:
# Global seed shared by every RNG source the notebook touches (python, numpy, torch CPU/CUDA).
SEED = 42

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

## 2. Load Model and Prompts

In [5]:
# Primary checkpoint, plus a smaller one tried if the primary fails to load.
MODEL_NAME = "EleutherAI/pythia-160m-deduped"
FALLBACK_MODEL = "EleutherAI/pythia-70m-deduped"

_HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if _HAS_CUDA else "cpu"
# Half precision only on GPU; CPU runs stay in float32.
DTYPE = torch.float16 if _HAS_CUDA else torch.float32

In [6]:
# Download the primary model's weights; if that fails, fall back to FALLBACK_MODEL
# (and record it in MODEL_NAME). A failure of both re-raises the fallback's error.
# LayerNorm folding and weight/unembed centering are disabled.
_load_kwargs = dict(
    device=DEVICE,
    dtype=DTYPE,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

try:
    model = HookedTransformer.from_pretrained(MODEL_NAME, **_load_kwargs)
except Exception as primary_err:
    print(f"Failed to load {MODEL_NAME}: {primary_err}")
    try:
        model = HookedTransformer.from_pretrained(FALLBACK_MODEL, **_load_kwargs)
    except Exception as fallback_err:
        print(f"Failed to load fallback model: {fallback_err}")
        raise
    MODEL_NAME = FALLBACK_MODEL

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer


In [7]:
# Alias for the tokenizer attached to the loaded model.
tokenizer = model.tokenizer

In [9]:
# SST-2 validation split; the first N_PROMPTS sentences are wrapped in the
# "Review: ...\nSentiment:" template to form the clean prompts.
N_PROMPTS = 200
dataset = load_dataset("glue", "sst2", split="validation")

_n_selected = min(N_PROMPTS, len(dataset))
clean_prompts = [
    f"Review: {sentence}\nSentiment:"
    for sentence in dataset[:_n_selected]["sentence"]
]

## 3. Generate Adversarial Triggers

In [10]:
# Gold labels aligned index-for-index with clean_prompts.
labels = list(dataset[:len(clean_prompts)]["label"])

In [11]:
def build_candidate_token_list(model, n_tokens=100):
    """Return up to ``n_tokens`` vocabulary ids whose decoded text looks like a word.

    Ids are scanned over fixed vocabulary ranges with a per-range stride. An id is kept
    when it is not a special token and its decoded string:
    - is non-blank;
    - does not start with "<|" or "[";
    - is shorter than 15 characters;
    - contains at least one alphanumeric character.
    Ids the tokenizer fails to decode are skipped.

    Args:
        model: Object exposing ``tokenizer``, ``cfg.d_vocab`` and ``to_string``.
        n_tokens: Maximum number of ids to return; ``<= 0`` yields ``[]``.

    Returns:
        List of token ids in scan order.
    """
    if n_tokens <= 0:
        return []

    special_token_ids = {0}  # Token 0 is <|endoftext|>
    special_token_ids.update(getattr(model.tokenizer, "all_special_ids", None) or [])

    # (start, stop, stride): sample more densely in the higher-id ranges for diversity.
    ranges_to_check = [
        (1000, 5000, 20),
        (5000, 15000, 10),
        (15000, 30000, 5),
        (30000, 45000, 3),
    ]

    good_tokens = []
    for vocab_start, vocab_end, step in ranges_to_check:
        for token_id in range(vocab_start, min(vocab_end, model.cfg.d_vocab), step):
            if token_id in special_token_ids:
                continue

            try:
                token_str = model.to_string(torch.tensor([token_id]))
            except Exception:
                # Undecodable id: skip it, but let KeyboardInterrupt/SystemExit through.
                continue

            if (token_str
                    and token_str.strip()
                    and not token_str.startswith(('<|', '['))
                    and len(token_str) < 15
                    and any(c.isalnum() for c in token_str)):
                good_tokens.append(token_id)
                if len(good_tokens) >= n_tokens:
                    return good_tokens

    return good_tokens

In [12]:
def get_label_token_ids(model):
    """Map sentiment labels (0 = negative, 1 = positive) to a single token id each.

    For each label name, the function tries the following in order:
    1. " <name>" as a single token.
    2. "<name>" as a single token.
    3. The first sub-token of " <name>" from ``model.tokenizer.encode``.
    A label is left out of the result if the encoding in step 3 is empty.

    Returns:
        Dict ``{label_value: token_id}``.
    """
    label_tokens = {}
    for label_val, label_name in [(0, "negative"), (1, "positive")]:
        for candidate in (f" {label_name}", label_name):
            try:
                label_tokens[label_val] = model.to_single_token(candidate)
                break
            except Exception:
                # Not a single token (or lookup failed); try the next spelling.
                continue
        else:
            tokens = model.tokenizer.encode(f" {label_name}", add_special_tokens=False)
            if tokens:
                label_tokens[label_val] = tokens[0]
    return label_tokens

In [13]:
def insert_trigger(prompt, trigger_tokens):
    """Insert trigger strings right after the first word that starts with "Sentiment".

    If no such word exists, the trigger goes after the first word of the prompt.
    The rest of the prompt is left byte-for-byte intact, including the newline before
    "Sentiment:" and any other whitespace runs.

    Args:
        prompt: Prompt text, e.g. a "Review: ... Sentiment:" template.
        trigger_tokens: Strings to insert; they are joined with single spaces and
            separated from the anchor word by one space.

    Returns:
        The prompt with the trigger inserted. The prompt is returned unchanged when
        ``trigger_tokens`` is empty, and just the joined trigger when the prompt has
        no words.
    """
    import re

    if not trigger_tokens:
        return prompt

    trigger_text = " ".join(trigger_tokens)
    words = list(re.finditer(r"\S+", prompt))
    if not words:
        return trigger_text

    anchor = next((m for m in words if m.group().startswith("Sentiment")), words[0])
    cut = anchor.end()
    return prompt[:cut] + " " + trigger_text + prompt[cut:]


def evaluate_candidates_batched(model, base_tokens, position, candidate_tokens, correct_label_token, batch_size=32):
    """Score every single-token insertion of ``candidate_tokens`` at ``position``.

    Each candidate id is inserted before index ``position`` of the 1-D ``base_tokens``; id 0
    is skipped. The objective is ``-log(p(correct_label_token) + 1e-10)`` at the last
    position, so higher means the correct label is less likely.

    Logits are upcast to float32 before the softmax. In float16 the 1e-10 epsilon rounds
    to 0 and small probabilities underflow to 0, which would make the objective +inf.

    Args:
        model: Callable mapping a ``[batch, seq]`` token tensor to ``[batch, seq, vocab]`` logits.
        base_tokens: 1-D token tensor to insert into.
        position: Insertion index into ``base_tokens``.
        candidate_tokens: Candidate token ids.
        correct_label_token: Token id whose probability is scored.
        batch_size: Number of candidates per forward pass.

    Returns:
        List of float objectives, one per non-zero candidate in candidate order
        (empty if there are none).
    """
    valid_candidates = [t for t in candidate_tokens if t != 0]
    if not valid_candidates:
        return []

    prefix = base_tokens[:position]
    suffix = base_tokens[position:]
    objectives = []

    for start in range(0, len(valid_candidates), batch_size):
        chunk = valid_candidates[start:start + batch_size]
        n_rows = len(chunk)
        inserted = torch.tensor(chunk, device=base_tokens.device, dtype=base_tokens.dtype).unsqueeze(1)
        # Every row has the same length, so no padding is needed.
        batch_tensor = torch.cat(
            [prefix.expand(n_rows, -1), inserted, suffix.expand(n_rows, -1)], dim=1
        )

        with torch.no_grad():
            last_logits = model(batch_tensor)[:, -1, :].float()  # [batch, vocab], float32
            label_probs = torch.softmax(last_logits, dim=-1)[:, correct_label_token]
            objectives.extend((-torch.log(label_probs + 1e-10)).cpu().tolist())

    return objectives


def _adv_insert_token(seq, position, token_id):
    """Return a copy of 1-D token tensor ``seq`` with ``token_id`` inserted before index ``position``."""
    new_token = torch.tensor([token_id], device=seq.device, dtype=seq.dtype)
    return torch.cat([seq[:position], new_token, seq[position:]])


def _adv_label_objective(model, seq, label_token):
    """Return ``-log(p(label_token) + 1e-10)`` for the next token after 1-D ``seq``.

    The softmax is computed in float32 to avoid half-precision rounding of small probabilities.
    """
    with torch.no_grad():
        last_logits = model(seq.unsqueeze(0))[0, -1, :].float()
        prob = torch.softmax(last_logits, dim=-1)[label_token].item()
    return -np.log(prob + 1e-10)


def _adv_neighbour_positions(anchor, seq_len):
    """Insertion slots anchor-1 .. anchor+2 around a previous insertion, restricted to 0 < p < seq_len."""
    return [p for p in (anchor - 1, anchor, anchor + 1, anchor + 2) if 0 < p < seq_len]


def _adv_best_insertion(model, seq, positions, candidate_tokens, label_token, batch_size, floor, margin):
    """Scan every (position, candidate) pair and return ``(objective, token, position)`` of the best.

    The running best starts at ``floor``. A pair replaces it only if its objective is
    larger by more than ``margin``; ties keep the earliest pair. ``token`` and
    ``position`` are None when nothing beats the floor.
    """
    # Same filter evaluate_candidates_batched applies, so objectives align with these ids.
    valid_candidates = [t for t in candidate_tokens if t != 0]
    best_objective, best_token, best_position = floor, None, None
    for position in positions:
        objectives = evaluate_candidates_batched(
            model, seq, position, candidate_tokens, label_token, batch_size
        )
        for token_id, objective in zip(valid_candidates, objectives):
            if objective > best_objective + margin:
                best_objective, best_token, best_position = objective, token_id, position
    return best_objective, best_token, best_position


def _adv_final_indices(insertion_positions):
    """Map sequential insertion positions to each inserted token's index in the final sequence.

    Inserting at index p shifts every previously inserted token sitting at index >= p one
    slot to the right.
    """
    final = []
    for position in insertion_positions:
        final = [idx + 1 if idx >= position else idx for idx in final]
        final.append(position)
    return final


def _adv_decode_without_bos(model, seq):
    """Decode a 1-D token tensor to text, dropping a leading BOS token if present."""
    ids = seq.tolist()
    bos_id = getattr(model.tokenizer, "bos_token_id", None)
    if ids and bos_id is not None and ids[0] == bos_id:
        ids = ids[1:]
    # Keep spacing exactly as tokenized so re-tokenizing reproduces the sequence as closely as possible.
    return model.tokenizer.decode(ids, clean_up_tokenization_spaces=False)


def generate_adv_prompt(model, prompt, label, candidate_tokens, positions, label_tokens, batch_size=32):
    """Greedily insert 1-3 tokens into ``prompt`` to make the correct label less likely.

    The objective is ``-log(p(correct label) + 1e-10)`` at the last position.
    - Step 1 tries every candidate at every position in ``positions`` (only
      ``0 < p < seq_len`` is used) and always keeps the best.
    - Steps 2 and 3 only try slots adjacent to the previously inserted token.
      Each is kept only if it raises the objective by more than 1e-6.

    ``adv_prompt`` is decoded from the optimized token sequence itself, so the trigger tokens
    sit where the search placed them. ``insertion_positions`` therefore index into
    ``model.to_tokens(adv_prompt)[0]``, up to BPE re-tokenization drift around the inserted
    tokens.

    Args:
        model: HookedTransformer-like model (``to_tokens``, ``tokenizer``, callable forward).
        prompt: Clean prompt text.
        label: Gold label; must be a key of ``label_tokens`` for an attack to run.
        candidate_tokens: Candidate token ids to insert (id 0 is ignored).
        positions: Candidate token indices for the first insertion.
        label_tokens: Dict label -> label token id.
        batch_size: Candidates per forward pass.

    Returns:
        An 8-tuple ``(adv_prompt, trigger_pos, num_tokens, inserted_tokens,
        insertion_positions, objective_k, relative_improvements, ablation_marginals)``.
        - ``trigger_pos`` is ``insertion_positions[0] + 1``.
        - ``objective_k`` holds the objective after 1/2/3 insertions (``None`` if not reached).
        - ``relative_improvements`` holds (k_n - k_{n-1}) / k_{n-1}.
        - ``ablation_marginals``, for 3 tokens, is the relative objective change when each
          token is removed from the final sequence.
        When no attack is attempted (unknown label, no valid position, no candidates) the
        result is ``(prompt, 0, 0, [], [], {k1/k2/k3: None}, {r2/r3: None}, [])``.
    """
    no_attack = (
        prompt, 0, 0, [], [],
        {"k1": None, "k2": None, "k3": None},
        {"r2": None, "r3": None},
        [],
    )

    if label not in label_tokens:
        return no_attack
    correct_label_token = label_tokens[label]

    tokens = model.to_tokens(prompt)[0]
    valid_positions = [p for p in positions if 0 < p < tokens.shape[0]]
    if not valid_positions:
        return no_attack

    # Step 1: best single insertion over all requested positions (always accepted).
    _, first_token, first_position = _adv_best_insertion(
        model, tokens, valid_positions, candidate_tokens, correct_label_token,
        batch_size, floor=float("-inf"), margin=0.0,
    )
    if first_token is None:
        return no_attack

    adv_tokens = _adv_insert_token(tokens, first_position, first_token)
    # Re-evaluate with a standalone forward pass; this value is the acceptance floor for step 2.
    objective_k = {
        "k1": _adv_label_objective(model, adv_tokens, correct_label_token),
        "k2": None,
        "k3": None,
    }
    inserted_tokens = [first_token]
    insertion_positions = [first_position]

    # Steps 2-3: grow next to the previous insertion while the objective improves by > 1e-6.
    for prev_key, next_key in (("k1", "k2"), ("k2", "k3")):
        floor = objective_k[prev_key]
        best_objective, token_id, position = _adv_best_insertion(
            model, adv_tokens,
            _adv_neighbour_positions(insertion_positions[-1], adv_tokens.shape[0]),
            candidate_tokens, correct_label_token, batch_size, floor=floor, margin=1e-6,
        )
        if token_id is None or not best_objective > floor + 1e-6:
            break
        adv_tokens = _adv_insert_token(adv_tokens, position, token_id)
        objective_k[next_key] = _adv_label_objective(model, adv_tokens, correct_label_token)
        inserted_tokens.append(token_id)
        insertion_positions.append(position)

    num_tokens = len(inserted_tokens)
    adv_prompt = _adv_decode_without_bos(model, adv_tokens)
    trigger_pos = insertion_positions[0] + 1

    k1, k2, k3 = objective_k["k1"], objective_k["k2"], objective_k["k3"]
    relative_improvements = {"r2": None, "r3": None}
    if k2 is not None and k1 > 0:
        relative_improvements["r2"] = (k2 - k1) / k1
    if k2 is not None and k3 is not None and k2 > 0:
        relative_improvements["r3"] = (k3 - k2) / k2

    if num_tokens == 3:
        # Remove each inserted token from the *final* sequence at its tracked index.
        ablation_marginals = []
        for final_idx in _adv_final_indices(insertion_positions):
            ablated = torch.cat([adv_tokens[:final_idx], adv_tokens[final_idx + 1:]])
            objective_ablated = _adv_label_objective(model, ablated, correct_label_token)
            ablation_marginals.append((objective_ablated - k3) / k3 if k3 > 1e-10 else 0.0)
    elif num_tokens == 2:
        # Marginal of the 2nd token relative to the 1-token objective.
        ablation_marginals = [0.0, (k2 - k1) / k1 if k1 > 1e-10 else 0.0]
    else:
        ablation_marginals = [0.0]

    return (
        adv_prompt,
        trigger_pos,
        num_tokens,
        inserted_tokens,
        insertion_positions,
        objective_k,
        relative_improvements,
        ablation_marginals,
    )

In [14]:
def _adv_empty_metadata(idx, prompt, label):
    """Metadata record for a prompt that received no trigger (``chosen_k == 0``)."""
    return {
        "id": idx,
        "clean_prompt": prompt,
        "clean_label": int(label),
        "adv_tokens": [],
        "adv_decoded": "",
        "positions": [],
        "objective_k": {"k1": None, "k2": None, "k3": None},
        "relative_improvements": {"r2": None, "r3": None},
        "ablation_marginals": [],
        "chosen_k": 0,
    }


def _adv_candidate_positions(max_insertion_pos):
    """First-insertion slots: index 1, the middle of the review, and its last index (at most three)."""
    positions = set()
    if max_insertion_pos > 1:
        positions.add(1)
    if max_insertion_pos > 3:
        positions.add(min(max_insertion_pos // 2, max_insertion_pos - 1))
    if max_insertion_pos > 2:
        positions.add(max_insertion_pos)
    return sorted(p for p in positions if 0 < p <= max_insertion_pos)[:3]


def _adv_optional_float(value):
    """``float(value)``, passing ``None`` through so the record stays JSON-serializable."""
    return None if value is None else float(value)


def generate_all_adv_prompts(model, clean_prompts, labels, n_candidates=500, batch_size=32, artifacts_dir="artifacts"):
    """Generate an adversarial prompt for every clean prompt (greedy 1-3 token insertion).

    One JSON line per prompt is appended to ``<artifacts_dir>/adv_metadata.jsonl`` as it
    finishes, so a partial run leaves usable output. Any existing file is removed first.
    Skipped prompts are written too: their review is shorter than 2 tokens after BOS, and
    they keep the clean prompt with ``chosen_k == 0``. This keeps the file aligned
    line-for-line with the returned metadata.

    Args:
        model: HookedTransformer-like model.
        clean_prompts: Prompts in the "Review: ...\\nSentiment:" template.
        labels: Gold labels aligned with ``clean_prompts``.
        n_candidates: Size of the shared candidate token list.
        batch_size: Candidates per forward pass.
        artifacts_dir: Output directory (created if missing).

    Returns:
        ``(adv_prompts, trigger_positions, metadata)``. ``metadata`` has the keys
        ``num_tokens``, ``inserted_tokens``, ``insertion_positions`` and
        ``full_metadata`` (the per-prompt records).

    Raises:
        ValueError: if no candidate tokens could be built.
    """
    import json
    import os

    os.makedirs(artifacts_dir, exist_ok=True)

    label_tokens = get_label_token_ids(model)

    # Build the candidate token list once and reuse it for all prompts.
    candidate_tokens = build_candidate_token_list(model, n_candidates)
    if not candidate_tokens:
        raise ValueError("Failed to build candidate token list")

    adv_prompts = []
    trigger_positions = []
    metadata_list = []

    metadata_file = os.path.join(artifacts_dir, "adv_metadata.jsonl")
    if os.path.exists(metadata_file):
        os.remove(metadata_file)

    for i, (prompt, label) in enumerate(zip(clean_prompts, labels)):
        # Last token index of the review text. This assumes to_tokens prepends BOS at index 0
        # (TransformerLens default); verify if prepend_bos settings change.
        review_text = prompt.split("\nSentiment:")[0]
        max_insertion_pos = model.to_tokens(review_text)[0].shape[0] - 1

        positions = _adv_candidate_positions(max_insertion_pos) if max_insertion_pos >= 2 else []

        if positions:
            (adv_prompt, trigger_pos, num_tokens, inserted_tokens, insertion_positions,
             objective_k, relative_improvements, ablation_marginals) = generate_adv_prompt(
                model, prompt, label, candidate_tokens, positions, label_tokens, batch_size
            )
            metadata_entry = {
                "id": i,
                "clean_prompt": prompt,
                "clean_label": int(label),
                "adv_tokens": [int(tid) for tid in inserted_tokens],
                "adv_decoded": " ".join(model.to_string(torch.tensor([tid])) for tid in inserted_tokens),
                "positions": [int(pos) for pos in insertion_positions],
                "objective_k": {k: _adv_optional_float(objective_k[k]) for k in ("k1", "k2", "k3")},
                "relative_improvements": {
                    k: _adv_optional_float(relative_improvements[k]) for k in ("r2", "r3")
                },
                "ablation_marginals": [float(m) for m in ablation_marginals],
                "chosen_k": int(num_tokens),
            }
        else:
            adv_prompt, trigger_pos = prompt, 0
            metadata_entry = _adv_empty_metadata(i, prompt, label)

        # Append immediately so progress survives an interrupted run.
        with open(metadata_file, "a") as f:
            f.write(json.dumps(metadata_entry) + "\n")

        adv_prompts.append(adv_prompt)
        trigger_positions.append(trigger_pos)
        metadata_list.append(metadata_entry)

        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{len(clean_prompts)} prompts")

    metadata = {
        'num_tokens': [m['chosen_k'] for m in metadata_list],
        'inserted_tokens': [m['adv_tokens'] for m in metadata_list],
        'insertion_positions': [m['positions'] for m in metadata_list],
        'full_metadata': metadata_list,
    }

    return adv_prompts, trigger_positions, metadata


In [15]:
# Greedy 1-3 token trigger search over every clean prompt; per-prompt metadata lands in artifacts/.
adv_prompts, trigger_positions, metadata = generate_all_adv_prompts(
    model,
    clean_prompts,
    labels,
    batch_size=32,
    n_candidates=500,
)

  Processed 50/200 prompts
  Processed 100/200 prompts
  Processed 150/200 prompts
  Processed 200/200 prompts


In [16]:
# Save clean and adversarial prompts to artifacts/.
# The .txt files keep their original format: each prompt followed by "\n". Prompts can contain
# embedded newlines (the clean template is "Review: ...\nSentiment:"), so the .txt files are
# NOT one-prompt-per-line. The matching .json files store the exact prompt lists losslessly.
import json
import os

os.makedirs("artifacts", exist_ok=True)

for split_name, split_prompts in (("clean", clean_prompts), ("adv", adv_prompts)):
    with open(f"artifacts/{split_name}_prompts.txt", "w", encoding="utf-8") as f:
        for prompt in split_prompts:
            f.write(prompt + "\n")
    with open(f"artifacts/{split_name}_prompts.json", "w", encoding="utf-8") as f:
        json.dump(list(split_prompts), f, ensure_ascii=False, indent=1)

## 4. Extract Clean and Adversarial Activations

In [17]:
def extract_activations(model, prompts, hook_points=None, batch_size=16):
    """Run ``prompts`` through ``model`` in batches and collect cached activations per hook.

    Args:
        model: HookedTransformer-like model (``to_tokens``, ``run_with_cache``, ``cfg.n_layers``).
        prompts: List of prompt strings.
        hook_points: Hook names to collect. The default is ``blocks.{l}.hook_mlp_out`` and
            ``blocks.{l}.attn.hook_z`` for every layer.
        batch_size: Prompts per forward pass.

    Returns:
        Dict mapping hook name to a CPU tensor ``[len(prompts), max_seq_len, ...]``. The value
        is ``None`` if the hook never appeared in the cache. Positions at or beyond a prompt's
        own token length are zero, so the result does not depend on how prompts were batched.
        This assumes batched tokenization pads on the right (TransformerLens default), which
        the downstream position indexing also relies on; TODO confirm if padding_side changes.
    """
    if hook_points is None:
        hook_points = []
        for l in range(model.cfg.n_layers):
            hook_points.append(f"blocks.{l}.hook_mlp_out")
            hook_points.append(f"blocks.{l}.attn.hook_z")

    activations = {hook: [] for hook in hook_points}

    # Single-prompt tokenization has no padding, so these are the true per-prompt lengths.
    prompt_lengths = [model.to_tokens(prompt).shape[1] for prompt in prompts]
    max_seq_len = max(prompt_lengths, default=0)

    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]
        batch_lengths = torch.tensor(prompt_lengths[start:start + batch_size])
        batch_tokens = model.to_tokens(batch_prompts)

        with torch.no_grad():
            # Cache only the requested hooks instead of every activation in the model.
            _, cache = model.run_with_cache(batch_tokens, names_filter=hook_points)

        for hook in hook_points:
            if hook not in cache:
                continue

            hook_acts = cache[hook].cpu()  # [batch, seq_len, ...]
            n_rows, seq_len = hook_acts.shape[0], hook_acts.shape[1]

            # Zero the within-batch pad positions to match the zero padding added below.
            pad_mask = torch.arange(seq_len)[None, :] >= batch_lengths[:n_rows, None]
            pad_mask = pad_mask.view(n_rows, seq_len, *([1] * (hook_acts.dim() - 2)))
            hook_acts = hook_acts.masked_fill(pad_mask, 0)

            if seq_len < max_seq_len:
                pad_shape = list(hook_acts.shape)
                pad_shape[1] = max_seq_len - seq_len
                padding = torch.zeros(pad_shape, dtype=hook_acts.dtype, device=hook_acts.device)
                hook_acts = torch.cat([hook_acts, padding], dim=1)

            activations[hook].append(hook_acts)

    for hook in hook_points:
        activations[hook] = torch.cat(activations[hook], dim=0) if activations[hook] else None

    return activations


## 5. Compute |Δnorm| and Rank Components

In [18]:
# Cache activations for both prompt sets with the same batch size, then persist them under artifacts/.
_activation_jobs = (
    ("clean", clean_prompts, "artifacts/activations_clean.pt"),
    ("adv", adv_prompts, "artifacts/activations_adv.pt"),
)
activations_by_split = {
    split: extract_activations(model, split_prompts, batch_size=16)
    for split, split_prompts, _ in _activation_jobs
}
activations_clean = activations_by_split["clean"]
activations_adv = activations_by_split["adv"]

for split, _, out_path in _activation_jobs:
    torch.save(activations_by_split[split], out_path)


## 6. Causal Patching Experiments

In [19]:
def compute_delta_metrics(activations_clean, activations_adv, trigger_positions, metadata, model):
    """Compute the mean activation shift at the trigger position for every component.

    A component is one ``hook_mlp_out`` dimension, or one attention head of ``hook_z``
    averaged over ``d_head``. For each component:
    - ``delta`` = mean over attacked prompts of (adv - clean) at the trigger index;
    - ``delta_norm`` = ``delta / (std_clean + 1e-10)``, where ``std_clean`` is the population
      std of the clean activation over all prompts and all positions.

    The trigger index of a prompt is ``metadata['insertion_positions'][i][0]``, or
    ``trigger_pos - 1`` if that list is empty. Prompts with ``trigger_pos <= 0`` are ignored
    for the means.

    NOTE(review): ``std_clean`` also includes zero-padded positions beyond each prompt's
    length, as the original per-element computation did.

    Args:
        activations_clean / activations_adv: Dict hook -> tensor ``[n_prompts, seq_len, ...]``
            (or ``None``).
        trigger_positions: Per-prompt trigger positions from the attack.
        metadata: Dict with an ``'insertion_positions'`` list aligned with the prompts.
        model: Unused; kept for interface compatibility.

    Returns:
        DataFrame with columns ``layer, type, index, delta, delta_norm, position``. It is
        empty but still has these columns when nothing could be computed.
    """
    columns = ['layer', 'type', 'index', 'delta', 'delta_norm', 'position']
    eps = 1e-10

    # Trigger token index per prompt (None = prompt was not attacked). Independent of hook/component.
    trigger_index = []
    for trigger_pos, insertion_positions in zip(trigger_positions, metadata['insertion_positions']):
        if trigger_pos > 0:
            trigger_index.append(insertion_positions[0] if insertion_positions else trigger_pos - 1)
        else:
            trigger_index.append(None)

    def _as_float64(t):
        return t.detach().cpu().to(torch.float64).numpy()

    components = []
    for hook, clean_raw in activations_clean.items():
        adv_raw = activations_adv.get(hook)
        if clean_raw is None or adv_raw is None:
            continue

        layer = int(hook.split(".")[1])

        if "mlp" in hook:
            comp_type = "mlp"
            clean_acts = _as_float64(clean_raw)  # [n_prompts, seq_len, d_model]
            adv_acts = _as_float64(adv_raw)
        elif "attn" in hook and "hook_z" in hook:
            comp_type = "attn"
            # [n_prompts, seq_len, n_heads, d_head] -> average over d_head
            clean_acts = _as_float64(clean_raw).mean(axis=-1)
            adv_acts = _as_float64(adv_raw).mean(axis=-1)
        else:
            continue

        n_components = clean_acts.shape[-1]
        n_rows = min(len(clean_acts), len(adv_acts), len(trigger_index))
        usable_len = min(clean_acts.shape[1], adv_acts.shape[1])
        rows = [
            i for i in range(n_rows)
            if trigger_index[i] is not None and 0 <= trigger_index[i] < usable_len
        ]
        if not rows:
            continue
        cols = [trigger_index[i] for i in rows]

        mean_clean_trigger = clean_acts[rows, cols].mean(axis=0)  # [n_components]
        mean_adv_trigger = adv_acts[rows, cols].mean(axis=0)
        std_clean_all = clean_acts[:n_rows].reshape(-1, n_components).std(axis=0)

        delta = mean_adv_trigger - mean_clean_trigger
        delta_norm = delta / (std_clean_all + eps)

        for comp_idx in range(n_components):
            components.append({
                'layer': layer,
                'type': comp_type,
                'index': comp_idx,
                'delta': float(delta[comp_idx]),
                'delta_norm': float(delta_norm[comp_idx]),
                'position': 'trigger',
            })

    return pd.DataFrame(components, columns=columns)


In [20]:
# Rank every component by |delta_norm| (largest first) and persist the ranking.
panic_components_df = (
    compute_delta_metrics(activations_clean, activations_adv, trigger_positions, metadata, model)
    .assign(abs_delta_norm=lambda frame: frame['delta_norm'].abs())
    .sort_values('abs_delta_norm', ascending=False)
)

panic_components_df.to_csv("artifacts/panic_components.csv", index=False)


## 7. Save All Figures and CSV Outputs


In [21]:
def _component_hook_map(components: pd.DataFrame) -> Dict[str, List[int]]:
    """Group component indices by the TransformerLens hook point that carries them.

    Rows with type ``mlp`` map to ``blocks.{layer}.hook_mlp_out`` and rows with
    type ``attn`` map to ``blocks.{layer}.attn.hook_z``; any other type is ignored.
    The ``index`` column is used as the index into the third dimension of that
    hook's activation (the head for ``hook_z``).
    NOTE(review): ``hook_mlp_out`` is d_model-wide, so an ``mlp`` index here is a
    residual-stream dimension, not an MLP neuron — confirm this matches how
    ``compute_delta_metrics`` indexed ``mlp`` components.
    """
    hook_map: Dict[str, List[int]] = {}
    for layer, comp_type, comp_idx in zip(components['layer'], components['type'], components['index']):
        if comp_type == "mlp":
            hook = f"blocks.{int(layer)}.hook_mlp_out"
        elif comp_type == "attn":
            hook = f"blocks.{int(layer)}.attn.hook_z"
        else:
            continue
        hook_map.setdefault(hook, []).append(int(comp_idx))
    return hook_map


def _clean_to_adv_positions(clean_len: int, adv_len: int,
                            insertion_positions) -> Tuple[List[int], List[int]]:
    """Pair each clean token position with its position in the adversarial sequence.

    Each clean position is shifted right by the number of insertion positions that
    are <= it, and pairs whose shifted position falls outside the adversarial
    sequence are dropped. With no insertions, positions are paired one-to-one up
    to the shorter length.
    NOTE(review): assumes insertion positions are in clean-token coordinates and
    that each insertion adds exactly one token — confirm against the trigger
    generation code.

    Returns:
        (clean_idx, adv_idx): equal-length lists of paired positions.
    """
    if not insertion_positions:
        shared = list(range(min(clean_len, adv_len)))
        return shared, list(shared)

    sorted_insertions = sorted(insertion_positions)
    clean_idx: List[int] = []
    adv_idx: List[int] = []
    for clean_pos in range(clean_len):
        adv_pos = clean_pos + sum(1 for ins_pos in sorted_insertions if ins_pos <= clean_pos)
        if adv_pos < adv_len:
            clean_idx.append(clean_pos)
            adv_idx.append(adv_pos)
    return clean_idx, adv_idx


def _make_patch_hook(clean_acts: torch.Tensor, comp_indices: List[int],
                     clean_idx: List[int], adv_idx: List[int]):
    """Build a forward hook that copies clean values of the chosen components into the adversarial run.

    ``clean_acts`` is the cached clean activation for the same hook point, with a
    batch dimension of 1. Only the first three dimensions are indexed, so for
    ``attn.hook_z`` the whole trailing per-head slice is copied.
    """
    device = clean_acts.device
    comps = torch.tensor(comp_indices, dtype=torch.long, device=device)[None, :]
    clean_rows = torch.tensor(clean_idx, dtype=torch.long, device=device)[:, None]
    adv_rows = torch.tensor(adv_idx, dtype=torch.long, device=device)[:, None]

    def patch_hook(activations, hook):
        if adv_rows.numel() == 0 or comps.numel() == 0:
            return activations
        patched = activations.clone()
        # One broadcast (positions x components) copy instead of a Python double loop.
        patched[0, adv_rows, comps] = clean_acts[0, clean_rows, comps]
        return patched

    return patch_hook


def _label_prob(logits: torch.Tensor, label, label_tokens) -> float:
    """Softmax probability of the label's token at the final position of a single-prompt run.

    Returns 0.0 when ``label`` has no entry in ``label_tokens``.
    """
    if label not in label_tokens:
        return 0.0
    # Softmax in float32: the model may run in float16 on GPU.
    probs = torch.softmax(logits[0, -1, :].float(), dim=-1)
    return probs[label_tokens[label]].item()


def _patched_label_prob(model, adv_tokens, clean_cache, hook_map: Dict[str, List[int]],
                        clean_idx: List[int], adv_idx: List[int], label, label_tokens) -> float:
    """Run the adversarial prompt with ``hook_map``'s components patched from the clean cache and score the label."""
    fwd_hooks = [
        (hook_name, _make_patch_hook(clean_cache[hook_name], comp_indices, clean_idx, adv_idx))
        for hook_name, comp_indices in hook_map.items()
        if hook_name in clean_cache
    ]
    patched_logits = model.run_with_hooks(adv_tokens, fwd_hooks=fwd_hooks)
    return _label_prob(patched_logits, label, label_tokens)


def _recovery(clean: float, adv: float, patched: float, eps: float) -> float:
    """Fraction of the clean-vs-adversarial probability gap closed by patching (0.0 when there is no gap)."""
    denominator = clean - adv + eps
    return (patched - adv) / denominator if denominator > eps else 0.0


def causal_patching(model, clean_prompts, adv_prompts, labels, panic_components_df,
                    metadata, top_k=20, n_random_trials=5, batch_size=16):
    """Perform causal patching on the top-K panic components and on random-K baselines.

    For each prompt pair, the following runs are made:
    - the clean prompt once, caching only the hook points that will be patched;
    - the adversarial prompt once unpatched;
    - the adversarial prompt once with the top-K components patched from the clean run;
    - the adversarial prompt once per random trial.
    Every run uses a single unpadded prompt, so position -1 is always the
    prompt's real last token for the clean, adversarial and patched scores alike.

    Each score is the softmax probability of the label's token at the last
    position. Per-prompt recovery is ``(patched - adv) / (clean - adv + eps)``
    when ``clean - adv + eps > eps``, otherwise 0.0.

    Args:
        model: the HookedTransformer under study.
        clean_prompts: clean prompts, paired index-by-index with ``adv_prompts``.
        adv_prompts: adversarial prompts.
        labels: per-prompt labels. Labels missing from ``get_label_token_ids(model)`` score 0.0.
        panic_components_df: components ordered most-important first, with
            ``layer``, ``type`` and ``index`` columns.
        metadata: dict whose ``'insertion_positions'`` entry holds one list (or
            empty/None) per prompt.
        top_k: number of leading components to patch, and the size of each random draw.
        n_random_trials: number of random-K baseline draws.
        batch_size: kept for backward compatibility and no longer used. Batched
            tokenization right-pads shorter prompts, so reading position -1
            scored pad tokens for them.

    Returns:
        dict with ``top_k``, ``mean_recovery``, ``mean_random_recovery``,
        ``std_random_recovery``, ``recoveries`` (per prompt) and
        ``random_recoveries`` (mean recovery of each random trial).

    Raises:
        ValueError: if the prompt, label and insertion-position lists do not line up.
    """
    label_tokens = get_label_token_ids(model)
    eps = 1e-10

    insertion_lists = metadata['insertion_positions']
    n_prompts = len(adv_prompts)
    if len(clean_prompts) != n_prompts or len(labels) != n_prompts or len(insertion_lists) < n_prompts:
        raise ValueError(
            "clean_prompts, adv_prompts and labels must have equal length, and "
            "metadata['insertion_positions'] must cover every prompt"
        )

    top_hooks = _component_hook_map(panic_components_df.head(top_k))

    # Random-K baselines. Each draw samples uniformly from ALL ranked components
    # (top-K included), not from matching layers. Draws use NumPy's global RNG
    # (seeded at setup) and are taken up front. Forward passes do not touch that
    # RNG, so these are the same draws as sampling trial by trial.
    n_random = min(top_k, len(panic_components_df))
    random_hook_maps = [
        _component_hook_map(panic_components_df.sample(n=n_random))
        for _ in range(n_random_trials)
    ]
    needed_hooks = set(top_hooks).union(*random_hook_maps)

    clean_scores: List[float] = []
    adv_scores: List[float] = []
    patched_scores: List[float] = []
    random_patched_scores: List[List[float]] = [[] for _ in random_hook_maps]

    with torch.no_grad():
        for i in range(n_prompts):
            label = labels[i]
            clean_tokens = model.to_tokens(clean_prompts[i])
            adv_tokens = model.to_tokens(adv_prompts[i])

            # One clean pass serves the baseline score and every patching run below.
            clean_logits, clean_cache = model.run_with_cache(
                clean_tokens, names_filter=lambda name: name in needed_hooks
            )
            adv_logits = model(adv_tokens)
            clean_scores.append(_label_prob(clean_logits, label, label_tokens))
            adv_scores.append(_label_prob(adv_logits, label, label_tokens))

            clean_idx, adv_idx = _clean_to_adv_positions(
                clean_tokens.shape[1], adv_tokens.shape[1], insertion_lists[i]
            )
            patched_scores.append(_patched_label_prob(
                model, adv_tokens, clean_cache, top_hooks, clean_idx, adv_idx, label, label_tokens
            ))
            for trial_scores, hook_map in zip(random_patched_scores, random_hook_maps):
                trial_scores.append(_patched_label_prob(
                    model, adv_tokens, clean_cache, hook_map, clean_idx, adv_idx, label, label_tokens
                ))

    recoveries = [
        _recovery(clean, adv, patched, eps)
        for clean, adv, patched in zip(clean_scores, adv_scores, patched_scores)
    ]
    random_recoveries = [
        np.mean([
            _recovery(clean, adv, patched, eps)
            for clean, adv, patched in zip(clean_scores, adv_scores, trial_scores)
        ])
        for trial_scores in random_patched_scores
    ]

    return {
        'top_k': top_k,
        'mean_recovery': np.mean(recoveries),
        'mean_random_recovery': np.mean(random_recoveries),
        'std_random_recovery': np.std(random_recoveries),
        'recoveries': recoveries,
        'random_recoveries': random_recoveries,
    }


In [22]:
# Patch the top-K panic components (plus random-K baselines) and keep only
# the scalar summary on disk; the figure cells read it back from CSV.
patch_results = causal_patching(
    model,
    clean_prompts,
    adv_prompts,
    labels,
    panic_components_df,
    metadata,
    top_k=20,
    n_random_trials=5,
    batch_size=16,
)

summary_keys = ('top_k', 'mean_recovery', 'mean_random_recovery', 'std_random_recovery')
patch_df = pd.DataFrame([{key: patch_results[key] for key in summary_keys}])
patch_df.to_csv("artifacts/patch_results.csv", index=False)


In [23]:
# Save run manifest: a JSON record of the model, library versions, hardware and
# experiment settings. The figure cells read it back, e.g. for top_k.
import importlib.metadata
import json
import platform
from datetime import datetime


def _installed_version(*dist_names: str) -> str:
    """Return the installed version of the first distribution name found, or "unknown"."""
    for dist_name in dist_names:
        try:
            return importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            continue
    return "unknown"


run_manifest = {
    "model_id": MODEL_NAME,
    "model_config": {
        "n_layers": model.cfg.n_layers,
        "d_model": model.cfg.d_model,
        "d_mlp": model.cfg.d_mlp,
        "n_heads": model.cfg.n_heads,
        "d_vocab": model.cfg.d_vocab,
        "n_ctx": model.cfg.n_ctx
    },
    "versions": {
        "pytorch": torch.__version__,
        # Read from the environment: the install cell does not pin transformer_lens,
        # so a hard-coded version string could misreport what actually ran.
        "transformer_lens": _installed_version("transformer_lens", "transformer-lens"),
        "python": platform.python_version()
    },
    "hardware": {
        "device": DEVICE,
        "dtype": str(DTYPE),
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    },
    "experiment_config": {
        "n_prompts": len(clean_prompts),
        # NOTE(review): n_candidates and batch_size are literals here; confirm they
        # match the values used by the trigger-generation / extraction cells.
        "n_candidates": 500,
        "batch_size": 32,
        "seed": SEED,
        # Taken from the patching run itself so the manifest cannot drift from it.
        "top_k_patching": patch_results['top_k'],
        "n_random_trials": len(patch_results['random_recoveries'])
    },
    "timestamp": datetime.now().isoformat()
}

with open("artifacts/run_manifest.json", "w") as f:
    json.dump(run_manifest, f, indent=2)

In [24]:
import json
import os

# Output folders for figures and tables (no-op when they already exist).
for output_dir in ("artifacts/plots", "artifacts/tables"):
    os.makedirs(output_dir, exist_ok=True)

# Shared, print-quality defaults for every figure below.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 300,
})


In [25]:
# Load all artifacts from disk, so the figures below are built from exactly
# what the earlier cells saved.
panic_components_df = pd.read_csv("artifacts/panic_components.csv")
patch_results_df = pd.read_csv("artifacts/patch_results.csv")

# JSON Lines: one object per line. Blank lines are skipped instead of making
# json.loads('') raise.
with open("artifacts/adv_metadata.jsonl", "r") as f:
    adv_metadata = [json.loads(line) for line in f if line.strip()]

# NOTE(review): these rebind the in-memory clean_prompts / adv_prompts used by the
# patching cells. Lines are stripped and empty ones dropped, so prompts that were
# empty or contained newlines will not round-trip — avoid re-running earlier
# cells after this one without checking.
with open("artifacts/clean_prompts.txt", "r") as f:
    clean_prompts = [line.strip() for line in f if line.strip()]
with open("artifacts/adv_prompts.txt", "r") as f:
    adv_prompts = [line.strip() for line in f if line.strip()]

with open("artifacts/run_manifest.json", "r") as f:
    run_manifest = json.load(f)


## 8a. Layerwise Sensitivity Analysis

In [26]:
# Layerwise Mean |Δnorm| Analysis: average the per-component |Δnorm| within each
# layer to see where the adversarial trigger perturbs the network most.
layerwise_stats = (
    panic_components_df
    .groupby('layer')['abs_delta_norm']
    .mean()
    .sort_index()
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(
    layerwise_stats.index,
    layerwise_stats.values,
    color='steelblue',
    alpha=0.7,
    edgecolor='black',
    linewidth=1.2,
)
ax.set_xlabel('Layer Index', fontsize=12, fontweight='bold')
ax.set_ylabel('Mean |Δnorm|', fontsize=12, fontweight='bold')
ax.set_title('Layerwise Mean |Δnorm| Across All Components', fontsize=14, fontweight='bold')
ax.set_xticks(layerwise_stats.index)
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Annotate each bar with its mean value.
for layer, mean_value in layerwise_stats.items():
    ax.text(layer, mean_value, f'{mean_value:.2f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig("artifacts/plots/layerwise_delta_norm.png", dpi=300, bbox_inches='tight')
plt.close()


## 8b. Positional Localization Analysis

In [32]:
# Position-Localization Analysis: distribution of trigger-position |Δnorm| among
# the top-K components, with their mean marked.
top_k = run_manifest['experiment_config']['top_k_patching']
top_components = panic_components_df.head(top_k)
top_abs_delta = top_components['abs_delta_norm']
top_abs_mean = top_abs_delta.mean()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(top_abs_delta, bins=20, color='coral', alpha=0.7, edgecolor='black', linewidth=1.2)
ax.set_xlabel('|Δnorm| at Trigger Position', fontsize=12, fontweight='bold')
ax.set_ylabel('Number of Components', fontsize=12, fontweight='bold')
ax.set_title(f'Distribution of |Δnorm| for Top-{top_k} Panic Components', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.axvline(
    top_abs_mean,
    color='red',
    linestyle='--',
    linewidth=2,
    label=f'Mean: {top_abs_mean:.2f}',
)
ax.legend()

plt.tight_layout()
plt.savefig("artifacts/plots/position_localization.png", dpi=300, bbox_inches='tight')
plt.close()


In [28]:
# Δnorm distribution: MLP vs Attention, shown as overlaid histograms and side-by-side boxes.
mlp_delta_norm = panic_components_df.loc[panic_components_df['type'] == 'mlp', 'abs_delta_norm']
attn_delta_norm = panic_components_df.loc[panic_components_df['type'] == 'attn', 'abs_delta_norm']

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Histogram comparison
ax1.hist(mlp_delta_norm, bins=50, alpha=0.6, label='MLP', color='steelblue', edgecolor='black', linewidth=0.5)
ax1.hist(attn_delta_norm, bins=50, alpha=0.6, label='Attention', color='coral', edgecolor='black', linewidth=0.5)
ax1.set_xlabel('|Δnorm|', fontsize=12, fontweight='bold')
ax1.set_ylabel('Number of Components', fontsize=12, fontweight='bold')
ax1.set_title('Distribution of |Δnorm|: MLP vs Attention', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(axis='y', alpha=0.3, linestyle='--')

# Box plot comparison
box_data = [mlp_delta_norm.values, attn_delta_norm.values]
bp = ax2.boxplot(box_data, patch_artist=True,
                 boxprops=dict(facecolor='lightblue', alpha=0.7),
                 medianprops=dict(color='red', linewidth=2))
# Boxes sit at x = 1..n by default. Label them through the axis rather than
# boxplot's `labels=` argument: Matplotlib 3.9 renamed it to `tick_labels` and
# deprecated the old name, and the install cell does not pin Matplotlib.
ax2.set_xticks([1, 2])
ax2.set_xticklabels(['MLP', 'Attention'])
ax2.set_ylabel('|Δnorm|', fontsize=12, fontweight='bold')
ax2.set_title('Box Plot: MLP vs Attention |Δnorm|', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig("artifacts/plots/delta_norm_distributions.png", dpi=300, bbox_inches='tight')
plt.close()


In [29]:
# Causal Patching Visualization: top-K recovery vs. the random-K baseline (mean ± std).
top_k_recovery = patch_results_df['mean_recovery'].iloc[0]
random_k_mean = patch_results_df['mean_random_recovery'].iloc[0]
random_k_std = patch_results_df['std_random_recovery'].iloc[0]

fig, ax = plt.subplots(figsize=(8, 6))
categories = ['Top-K\nComponents', 'Random-K\nBaseline']
values = [top_k_recovery, random_k_mean]
errors = [0, random_k_std]
colors = ['steelblue', 'coral']

bars = ax.bar(categories, values, yerr=errors, capsize=10, color=colors, alpha=0.7,
              edgecolor='black', linewidth=1.5, error_kw={'elinewidth': 2, 'capthick': 2})

# Label each bar just beyond its end: above for non-negative values, below for negative ones.
for i, (val, err) in enumerate(zip(values, errors)):
    positive = val >= 0
    if i == 0:
        label_text, anchor = f'{val:.3f}', val
    else:
        label_text = f'{val:.3f}±{err:.3f}'
        anchor = val + err if positive else val - err
    ax.text(i, anchor, label_text, ha='center', va='bottom' if positive else 'top',
            fontsize=11, fontweight='bold')

ax.set_ylabel('Recovery Metric', fontsize=12, fontweight='bold')
ax.set_title('Causal Patching Results: Top-K vs Random-K Baseline', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Recovery = (patched - adv) / (clean - adv) goes negative whenever patching makes
# the label less likely. Pin the axis at 0 only if nothing (including error bars)
# dips below it; otherwise negative results would be clipped out of the figure.
lowest = np.nanmin([val - err for val, err in zip(values, errors)])
ax.set_ylim(bottom=min(0.0, 1.1 * lowest) if np.isfinite(lowest) else 0.0)

# Add horizontal line at y=0 for reference
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

plt.tight_layout()
plt.savefig("artifacts/plots/causal_patching_results.png", dpi=300, bbox_inches='tight')
plt.close()


In [30]:
# Top-K panic components table. The variable and file names say "20", but K is
# read from the run manifest.
top_k = run_manifest['experiment_config']['top_k_patching']
top_20_components = panic_components_df.head(top_k).copy()

# Source column -> display header, in output order.
display_headers = {
    'layer': 'Layer',
    'type': 'Type',
    'index': 'Index',
    'delta_norm': 'Δnorm',
    'abs_delta_norm': '|Δnorm|',
}
top_20_table = (
    top_20_components[list(display_headers)]
    .rename(columns=display_headers)
    .assign(**{
        'Type': lambda frame: frame['Type'].str.upper(),
        'Δnorm': lambda frame: frame['Δnorm'].round(3),
        '|Δnorm|': lambda frame: frame['|Δnorm|'].round(3),
    })
)

top_20_table.to_csv("artifacts/tables/top_20_components.csv", index=False)


In [31]:
# Patch Recovery Summary Table: headline recovery numbers, formatted to 4 decimals.
top_k_recovery = patch_results_df['mean_recovery'].iloc[0]
random_k_mean = patch_results_df['mean_random_recovery'].iloc[0]
random_k_std = patch_results_df['std_random_recovery'].iloc[0]
difference = top_k_recovery - random_k_mean

metric_rows = [
    ('Top-K Recovery', top_k_recovery),
    ('Random-K Recovery (mean)', random_k_mean),
    ('Random-K Recovery (std)', random_k_std),
    ('Difference (Top-K - Random-K)', difference),
]
summary_data = {
    'Metric': [metric for metric, _ in metric_rows],
    'Value': [f'{number:.4f}' for _, number in metric_rows],
}

patch_summary_df = pd.DataFrame(summary_data)
patch_summary_df.to_csv("artifacts/tables/patch_recovery_summary.csv", index=False)
