# Neurons That Panic — Sanity Checks Notebook

This notebook contains validation checks for the adversarial trigger generation pipeline:

- Check A: Objective monotonicity (R2, R3 improvements)
- Check B: Tokenization sanity (pathological token detection)
- Check C: Greedy vs beam/exhaustive search comparison
- Check D: Marginal ablation analysis
- Check E: Patch sanity pilot (1-token vs 3-token comparison)

All checks validate that the adversarial generation process is working correctly and producing meaningful results.


In [None]:
# Install dependencies (for Colab)
# The leading `!` is an IPython shell escape, so this line only works inside a
# notebook kernel; `-q` asks pip for quiet output.
!pip install -q transformer_lens torch datasets matplotlib seaborn


In [None]:
import torch
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
from transformer_lens import HookedTransformer
from typing import List, Tuple, Dict, Optional
from datetime import datetime

import warnings
warnings.filterwarnings('ignore')


In [None]:
# Setup: Seeds and constants
SEED = 42

# Locations and model identifiers shared by every check below.
ARTIFACTS_DIR = "artifacts"
MODEL_NAME = "EleutherAI/pythia-160m-deduped"
FALLBACK_MODEL = "EleutherAI/pythia-70m-deduped"

# Seed each random number generator the notebook uses so that sampling is
# repeatable from run to run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# GPU runs use half precision; CPU runs stay in float32.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

In [None]:
# Load model (same configuration as main notebook)
# Try the primary checkpoint first. fold_ln, center_writing_weights and
# center_unembed are all passed as False on both attempts.
try:
    model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        device=DEVICE,
        dtype=DTYPE,
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
    )
except Exception as e:
    print(f"Failed to load {MODEL_NAME}: {e}")
    # Primary load failed: retry once with FALLBACK_MODEL, using the same
    # device, dtype and flags.
    try:
        model = HookedTransformer.from_pretrained(
            FALLBACK_MODEL,
            device=DEVICE,
            dtype=DTYPE,
            fold_ln=False,
            center_writing_weights=False,
            center_unembed=False,
        )
        # Record which checkpoint is actually loaded.
        MODEL_NAME = FALLBACK_MODEL
    except Exception as e2:
        print(f"Failed to load fallback model: {e2}")
        # Neither checkpoint could be loaded; re-raise the fallback error.
        raise

In [None]:
# Load metadata from artifacts/adv_metadata.jsonl
metadata_file = os.path.join(ARTIFACTS_DIR, "adv_metadata.jsonl")

if not os.path.exists(metadata_file):
    raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

# One JSON object per non-blank line. UTF-8 is requested explicitly so the
# read does not depend on the platform's default encoding.
metadata_list = []
with open(metadata_file, 'r', encoding='utf-8') as f:
    for line_number, line in enumerate(f, start=1):
        if not line.strip():
            continue
        try:
            metadata_list.append(json.loads(line))
        except json.JSONDecodeError as err:
            # Point at the offending line instead of failing anonymously.
            raise ValueError(
                f"Invalid JSON on line {line_number} of {metadata_file}: {err}"
            ) from err

# validate schema and build DataFrame
required_fields = ['id', 'clean_prompt', 'clean_label', 'adv_tokens', 'positions',
                   'objective_k', 'relative_improvements', 'ablation_marginals', 'chosen_k']

# Entries missing any required field are reported and dropped, so the checks
# below only ever see complete records.
valid_entries = []
for entry in metadata_list:
    missing = [field for field in required_fields if field not in entry]
    if missing:
        print(f"Warning: Entry {entry.get('id', 'unknown')} missing fields: {missing}")
    else:
        valid_entries.append(entry)

metadata_list = valid_entries
metadata_df = pd.DataFrame(metadata_list)


In [None]:
# Helper functions (copied from main notebook for self-contained execution)

def build_candidate_token_list(model, n_tokens=500):
    """Build a diverse list of candidate tokens that decode to actual words.

    Token ids are sampled at fixed strides from four vocabulary ranges, in
    ascending order, and kept when their decoded string passes the filters
    below.

    Args:
        model: Object exposing ``tokenizer``, ``cfg.d_vocab`` and
            ``to_string(tensor)``.
        n_tokens: Maximum number of token ids to return. Values <= 0 yield
            an empty list.

    Returns:
        A list of at most ``n_tokens`` token ids.
    """
    if n_tokens <= 0:
        return []

    special_token_ids = {0}  # Token 0 is <|endoftext|>
    if hasattr(model.tokenizer, 'all_special_ids'):
        special_token_ids.update(model.tokenizer.all_special_ids)

    # (start, end, step) strides over the vocabulary; later ranges are
    # sampled more densely.
    ranges_to_check = [
        (1000, 5000, 20),
        (5000, 15000, 10),
        (15000, 30000, 5),
        (30000, 45000, 3),
    ]

    good_tokens = []
    for vocab_start, vocab_end, step in ranges_to_check:
        # Clamp to the model's vocabulary size so no out-of-range id is decoded.
        for token_id in range(vocab_start, min(vocab_end, model.cfg.d_vocab), step):
            if token_id in special_token_ids:
                continue

            try:
                token_str = model.to_string(torch.tensor([token_id]))
            except Exception:
                # Best effort: an id that cannot be decoded is simply skipped.
                # Unlike a bare ``except``, KeyboardInterrupt still propagates.
                continue

            # Keep short, non-blank strings that contain at least one
            # alphanumeric character and do not look like markup.
            if (token_str and
                    len(token_str.strip()) > 0 and
                    not token_str.startswith('<|') and
                    not token_str.startswith('[') and
                    len(token_str) < 15 and
                    any(c.isalnum() for c in token_str)):
                good_tokens.append(token_id)
                if len(good_tokens) >= n_tokens:
                    return good_tokens

    return good_tokens


def get_label_token_ids(model):
    """Get token IDs for 'negative' and 'positive' labels.

    For each label the lookup tries, in order: the space-prefixed word as a
    single token, the bare word as a single token, and finally the first
    token of the tokenizer's encoding of the space-prefixed word.

    Args:
        model: Object exposing ``to_single_token(text)`` and
            ``tokenizer.encode(text, add_special_tokens=False)``.

    Returns:
        Dict mapping 0 -> token id for "negative" and 1 -> token id for
        "positive". A label is left out when every lookup fails.
    """
    label_tokens = {}
    for label_val, label_name in [(0, "negative"), (1, "positive")]:
        token_id = None
        found = False

        for text in (f" {label_name}", label_name):
            try:
                token_id = model.to_single_token(text)
            except Exception:
                # Not a single token in this form; try the next one.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                continue
            found = True
            break

        if not found:
            # Last resort: first piece of the multi-token encoding.
            pieces = model.tokenizer.encode(f" {label_name}", add_special_tokens=False)
            if pieces:
                token_id = pieces[0]
                found = True

        if found:
            label_tokens[label_val] = token_id
    return label_tokens


def evaluate_candidates_batched(model, base_tokens, position, candidate_tokens, correct_label_token, batch_size=32):
    """Evaluate multiple candidate tokens in a batched forward pass.

    Each non-zero candidate is inserted into ``base_tokens`` at index
    ``position`` and scored by the negative log-probability the model
    assigns to ``correct_label_token`` at the last sequence position.

    Args:
        model: Callable mapping a token tensor to logits; the logits are
            indexed as ``[:, -1, :]``.
        base_tokens: 1-D token tensor to insert into.
        position: Index in ``base_tokens`` at which the candidate is inserted.
        candidate_tokens: Iterable of candidate token ids; id 0 is skipped.
        correct_label_token: Vocabulary index whose probability is scored.
        batch_size: Number of candidates per forward pass.

    Returns:
        List of floats, one per non-zero candidate, in the order those
        candidates appear in ``candidate_tokens``. Empty when no candidate
        remains after filtering.
    """
    # Filter out token 0 and invalid candidates
    valid_candidates = [t for t in candidate_tokens if t != 0]
    if not valid_candidates:
        return []

    objectives = []
    device = base_tokens.device
    dtype = base_tokens.dtype
    prefix, suffix = base_tokens[:position], base_tokens[position:]

    for i in range(0, len(valid_candidates), batch_size):
        batch_candidates = valid_candidates[i:i + batch_size]

        # One row per candidate: prefix + candidate + suffix.
        batch_tensor = torch.stack([
            torch.cat([
                prefix,
                torch.tensor([token_id], device=device, dtype=dtype),
                suffix,
            ])
            for token_id in batch_candidates
        ])

        with torch.no_grad():
            logits = model(batch_tensor)
            # Upcast before the softmax: in float16 the 1e-10 epsilon below
            # rounds to zero, so a vanishing probability would give -log(0)
            # = inf. For float32 logits the cast is a no-op.
            next_token_logits = logits[:, -1, :].float()  # [batch, vocab]
            next_token_probs = torch.softmax(next_token_logits, dim=-1)

            # compute objective
            adv_probs = next_token_probs[:, correct_label_token]
            batch_objectives = -torch.log(adv_probs + 1e-10)
            objectives.extend(batch_objectives.cpu().tolist())

    return objectives


In [None]:
# Check A: Objective Monotonicity
# Gather every recorded R2 / R3 relative improvement, skipping entries that
# do not carry one.
r2_values = [
    entry.get('relative_improvements', {}).get('r2')
    for entry in metadata_list
    if entry.get('relative_improvements', {}).get('r2') is not None
]
r3_values = [
    entry.get('relative_improvements', {}).get('r3')
    for entry in metadata_list
    if entry.get('relative_improvements', {}).get('r3') is not None
]

# Median, and share of examples at or above the 0.05 threshold
# (None / 0.0 when there is nothing to summarise).
median_r2 = np.median(r2_values) if r2_values else None
pct_r2_above_threshold = (
    100 * len([r for r in r2_values if r >= 0.05]) / len(r2_values)
    if r2_values else 0.0
)

median_r3 = np.median(r3_values) if r3_values else None
pct_r3_above_threshold = (
    100 * len([r for r in r3_values if r >= 0.05]) / len(r3_values)
    if r3_values else 0.0
)

# A series passes when its median reaches 0.05 OR at least 75% of its
# examples do; Check A needs both series to pass.
r2_passes = (median_r2 is not None and median_r2 >= 0.05) or (pct_r2_above_threshold >= 75.0)
r3_passes = (median_r3 is not None and median_r3 >= 0.05) or (pct_r3_above_threshold >= 75.0)
check_a_passes = r2_passes and r3_passes

# One histogram panel per series, with the threshold and median marked.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for panel, values, median, name, step, hist_style in (
    (axes[0], r2_values, median_r2, 'R2', 'k1→k2', {}),
    (axes[1], r3_values, median_r3, 'R3', 'k2→k3', {'color': 'orange'}),
):
    if not values:
        panel.text(0.5, 0.5, f'No {name} values', ha='center', va='center')
    else:
        panel.hist(values, bins=30, edgecolor='black', alpha=0.7, **hist_style)
        panel.axvline(0.05, color='red', linestyle='--', label='Threshold (0.05)')
        panel.axvline(median, color='green', linestyle='--', label=f'Median ({median:.3f})')
        panel.set_xlabel(f'{name} (Relative Improvement {step})')
        panel.set_ylabel('Frequency')
        panel.legend()
        panel.grid(True, alpha=0.3)
    panel.set_title(f'Distribution of {name} Values')

plt.tight_layout()
plt.savefig(os.path.join(ARTIFACTS_DIR, 'sanity_check_a_histograms.png'), dpi=150, bbox_inches='tight')
plt.close()

# Summary consumed by the final report.
check_a_result = dict(
    status="PASS" if check_a_passes else "FAIL",
    median_r2=None if median_r2 is None else float(median_r2),
    median_r3=None if median_r3 is None else float(median_r3),
    pct_r2_above_threshold=float(pct_r2_above_threshold),
    pct_r3_above_threshold=float(pct_r3_above_threshold),
    r2_passes=r2_passes,
    r3_passes=r3_passes,
    pass_criteria_met=check_a_passes,
)


In [None]:
# Check B: Tokenization Sanity
# Decode every adversarial token individually, flag suspicious decodings, and
# compare the result with the stored `adv_decoded` string when one exists.
pathological_cases = []
normal_cases = []

for entry in metadata_list:
    entry_id = entry.get('id', 'unknown')
    adv_tokens = entry.get('adv_tokens', [])

    # `adv_decoded` is not one of the required metadata fields, so it may be
    # absent. Only compare when it is present; otherwise every such entry
    # would be reported as a mismatch against ''.
    stored_decoded = entry.get('adv_decoded')
    has_stored_decoded = stored_decoded is not None
    adv_decoded = stored_decoded if has_stored_decoded else ''

    # Entries without adversarial tokens are not counted at all.
    if not adv_tokens:
        continue

    # Decode tokens using model
    decoded_tokens = []
    issues = []

    for token_id in adv_tokens:
        try:
            decoded = model.to_string(torch.tensor([token_id]))
        except Exception as e:
            issues.append(f"Decode error: {e}")
            decoded_tokens.append(f"<ERROR:{token_id}>")
            continue

        decoded_tokens.append(decoded)

        # check for pathological patterns (whitespace-only decodings are not flagged)
        stripped = decoded.strip()
        if not stripped:
            continue
        if not any(c.isalnum() for c in stripped):
            # This also covers single non-alphanumeric characters, which is
            # why no separate "very short fragment" rule is needed.
            issues.append(f"Punctuation-only: '{decoded}'")
        elif decoded.startswith('<|') and decoded.endswith('|>'):
            issues.append(f"Special token: '{decoded}'")

    # Compare with adv_decoded field
    # NOTE(review): tokens are joined with a single space, so this only matches
    # stored strings that were built the same way - confirm against the main notebook.
    decoded_str = " ".join(decoded_tokens)
    if has_stored_decoded and decoded_str.strip() != adv_decoded.strip():
        issues.append(f"Mismatch with adv_decoded: decoded='{decoded_str}' vs stored='{adv_decoded}'")

    if issues:
        pathological_cases.append({
            "id": entry_id,
            "tokens": adv_tokens,
            "decoded": decoded_str,
            "stored_decoded": adv_decoded,
            "issues": issues
        })
    else:
        normal_cases.append({
            "id": entry_id,
            "tokens": adv_tokens,
            "decoded": decoded_str
        })

total_cases = len(pathological_cases) + len(normal_cases)
pathological_percentage = 100 * len(pathological_cases) / total_cases if total_cases > 0 else 0.0

# check pass criteria: <10% pathological cases
check_b_passes = pathological_percentage < 10.0

# save examples to file (UTF-8, since decoded tokens may contain non-ASCII text)
examples_file = os.path.join(ARTIFACTS_DIR, 'sanity_check_b_examples.txt')
with open(examples_file, 'w', encoding='utf-8') as f:
    f.write("Tokenization Sanity Check - Pathological Cases\n")
    f.write("=" * 60 + "\n\n")
    f.write(f"Total pathological cases: {len(pathological_cases)} / {total_cases} ({pathological_percentage:.2f}%)\n\n")

    for case in pathological_cases:
        f.write(f"ID: {case['id']}\n")
        f.write(f"Tokens: {case['tokens']}\n")
        f.write(f"Decoded: '{case['decoded']}'\n")
        f.write(f"Stored: '{case['stored_decoded']}'\n")
        f.write(f"Issues: {case['issues']}\n")
        f.write("-" * 60 + "\n")

# store results
check_b_result = {
    "status": "PASS" if check_b_passes else "FAIL",
    "total_prompts": total_cases,
    "pathological_cases": len(pathological_cases),
    "pathological_percentage": float(pathological_percentage),
    "pass_criteria_met": check_b_passes,
    "examples_count": len(pathological_cases)
}


In [None]:
# Check C: Greedy vs Beam/Exhaustive Search
# Compare greedy 3-token results with beam search and mini-exhaustive for 1-2 tokens
# Pass if greedy 3-token is superior or equal in ≥90% samples

# Section banner for Check C: rule, title, rule - one per line.
print("=" * 60, "Check C: Greedy vs Beam/Exhaustive Search", "=" * 60, sep="\n")

# Helper function for beam search
def beam_search_adv(model, prompt, label, candidate_tokens, positions, label_tokens,
                    beam_width=5, max_tokens=2, batch_size=32):
    """Beam search for adversarial token insertion (up to max_tokens).

    At every depth each beam entry is extended by inserting one candidate
    token at one of ``positions``; the ``beam_width`` highest-objective
    extensions survive and are re-scored with a single-sequence forward
    pass. The best re-scored sequence seen at ANY depth is returned, so a
    shorter insertion wins when adding another token does not help.

    Args:
        model: Model exposing ``to_tokens(prompt)`` and callable on a token
            tensor to produce logits.
        prompt: Clean prompt text.
        label: Key into ``label_tokens`` selecting the correct label.
        candidate_tokens: Token ids to try; id 0 is skipped.
        positions: Insertion indices. Each index refers to the sequence as
            it stands when that insertion is made.
        label_tokens: Mapping from label to label token id.
        beam_width: Number of sequences kept per depth.
        max_tokens: Maximum number of tokens to insert.
        batch_size: Batch size passed to ``evaluate_candidates_batched``.

    Returns:
        ``(objective, inserted_tokens, inserted_positions)``. When no
        insertion could be evaluated, the baseline objective of the
        unmodified prompt and two empty lists.
    """
    tokens = model.to_tokens(prompt)[0]
    correct_label_token = label_tokens[label]

    def _objective(sequence):
        """Negative log-probability of the correct label for one sequence."""
        with torch.no_grad():
            logits = model(sequence.unsqueeze(0))
            # Softmax in float32 so half-precision models are scored with
            # full resolution; a no-op for float32 logits.
            probs = torch.softmax(logits[0, -1, :].float(), dim=-1)
            return -np.log(probs[correct_label_token].item() + 1e-10)

    baseline_obj = _objective(tokens)

    # Initialize beam with the unmodified sequence
    beam = [{
        'tokens': tokens,
        'inserted': [],
        'positions': [],
        'objective': baseline_obj
    }]

    # evaluate_candidates_batched returns one objective per non-zero
    # candidate, in this order.
    valid_candidates = [t for t in candidate_tokens if t != 0]

    best = None
    for _ in range(max_tokens):
        # Lightweight (objective, parent, position, token) records; the token
        # tensors are only built for the few survivors below.
        candidates = []
        for beam_item in beam:
            base_tokens = beam_item['tokens']
            for pos in positions:
                if pos >= base_tokens.shape[0]:
                    continue
                objectives = evaluate_candidates_batched(
                    model, base_tokens, pos, candidate_tokens, correct_label_token, batch_size
                )
                for new_token, objective in zip(valid_candidates, objectives):
                    candidates.append((objective, beam_item, pos, new_token))

        # Keep top beam_width candidates
        candidates.sort(key=lambda record: record[0], reverse=True)

        beam = []
        for _, parent, pos, new_token in candidates[:beam_width]:
            parent_tokens = parent['tokens']
            new_tokens = torch.cat([
                parent_tokens[:pos],
                torch.tensor([new_token], device=parent_tokens.device, dtype=parent_tokens.dtype),
                parent_tokens[pos:]
            ])
            beam.append({
                'tokens': new_tokens,
                'inserted': parent['inserted'] + [new_token],
                'positions': parent['positions'] + [pos],
                # Re-evaluate individually so the stored value does not
                # depend on the batched pass.
                'objective': _objective(new_tokens)
            })

        if not beam:
            # Nothing could be inserted at this depth, so deeper levels are empty too.
            break

        depth_best = max(beam, key=lambda item: item['objective'])
        # Strict comparison: on ties the shallower (fewer-token) result is kept.
        if best is None or depth_best['objective'] > best['objective']:
            best = depth_best

    if best is not None:
        return best['objective'], best['inserted'], best['positions']
    return baseline_obj, [], []


# Helper function for mini-exhaustive search
def mini_exhaustive_search(model, prompt, label, candidate_tokens, positions, label_tokens,
                          max_candidates=30, max_tokens=2, batch_size=32):
    """Mini-exhaustive search over a truncated candidate list for 1-2 tokens.

    Stage 1 scores every (position, candidate) single insertion and keeps
    the best. Stage 2 (when ``max_tokens >= 2``) fixes that first insertion
    and scores every second insertion on top of it, accepting one only if
    it beats the stage-1 objective. This is exhaustive per stage, not over
    all token pairs.

    Args:
        model: Model exposing ``to_tokens(prompt)`` and callable on a token
            tensor to produce logits.
        prompt: Clean prompt text.
        label: Key into ``label_tokens`` selecting the correct label.
        candidate_tokens: Token ids to try; only the first
            ``max_candidates`` entries (in list order) are used and id 0 is
            skipped.
        positions: Insertion indices. The second token's index refers to the
            sequence that already contains the first inserted token.
        label_tokens: Mapping from label to label token id.
        max_candidates: Number of leading candidates to consider.
        max_tokens: 1 or 2; whether the second stage runs.
        batch_size: Batch size passed to ``evaluate_candidates_batched``.

    Returns:
        ``(objective, tokens, positions)`` for the best insertion found, with
        the objective re-evaluated on the final sequence. ``(None, [], [])``
        when no insertion could be evaluated.
    """
    tokens = model.to_tokens(prompt)[0]
    correct_label_token = label_tokens[label]

    # Limit candidate tokens to the first M entries
    limited_candidates = candidate_tokens[:max_candidates]
    # Objectives come back aligned with the non-zero candidates.
    valid_candidates = [t for t in limited_candidates if t != 0]

    def _insert(sequence, pos, token_id):
        """Return a copy of ``sequence`` with ``token_id`` inserted at ``pos``."""
        return torch.cat([
            sequence[:pos],
            torch.tensor([token_id], device=sequence.device, dtype=sequence.dtype),
            sequence[pos:]
        ])

    best_objective = None
    best_tokens = []
    best_positions = []

    # 1-token search
    for pos in positions:
        if pos >= tokens.shape[0]:
            continue
        objectives = evaluate_candidates_batched(
            model, tokens, pos, limited_candidates, correct_label_token, batch_size
        )
        for token_id, objective in zip(valid_candidates, objectives):
            if best_objective is None or objective > best_objective:
                best_objective = objective
                best_tokens = [token_id]
                best_positions = [pos]

    # 2-token search (if max_tokens >= 2)
    if max_tokens >= 2 and best_tokens:
        first_token, first_pos = best_tokens[0], best_positions[0]
        tokens_1 = _insert(tokens, first_pos, first_token)

        for pos2 in positions:
            if pos2 >= tokens_1.shape[0]:
                continue
            objectives = evaluate_candidates_batched(
                model, tokens_1, pos2, limited_candidates, correct_label_token, batch_size
            )
            for token_id, objective in zip(valid_candidates, objectives):
                if objective > best_objective:
                    best_objective = objective
                    best_tokens = [first_token, token_id]
                    best_positions = [first_pos, pos2]

    # Re-evaluate final best
    if best_tokens:
        # Replay the insertions in search order at the recorded indices. Each
        # index was chosen against the sequence as it stood at that step, so
        # sorting the positions or shifting them would rebuild a different
        # sequence from the one that was scored during the search.
        final_tokens = tokens
        for tok, pos in zip(best_tokens, best_positions):
            final_tokens = _insert(final_tokens, pos, tok)

        with torch.no_grad():
            logits = model(final_tokens.unsqueeze(0))
            # float32 softmax; a no-op for float32 logits.
            next_token_probs = torch.softmax(logits[0, -1, :].float(), dim=-1)
            best_objective = -np.log(next_token_probs[correct_label_token].item() + 1e-10)

    return best_objective, best_tokens, best_positions

In [None]:
# Run Check C: Compare greedy 3-token with beam/exhaustive
# For a fixed-seed sample of prompts, compare the stored greedy 3-token
# objective with the best objective found by beam search and by the
# mini-exhaustive search (both limited to 1-2 inserted tokens).
N_SAMPLES = 20
np.random.seed(SEED)
sample_indices = np.random.choice(len(metadata_list), size=min(N_SAMPLES, len(metadata_list)), replace=False)

label_tokens = get_label_token_ids(model)
candidate_tokens = build_candidate_token_list(model, n_tokens=500)

comparison_results = []
greedy_superior_count = 0

for idx in sample_indices:
    entry = metadata_list[idx]
    prompt = entry['clean_prompt']
    label = entry['clean_label']

    # Get greedy 3-token objective from metadata
    greedy_obj = entry['objective_k'].get('k3')
    if greedy_obj is None:
        continue

    # determine positions: only the review part of the prompt (the text before
    # "\nSentiment:") is tokenized here; the search helpers tokenize the full
    # prompt themselves.
    review_text = prompt.split("\nSentiment:")[0]
    review_tokens = model.to_tokens(review_text)[0]
    max_insertion_pos = review_tokens.shape[0] - 1

    # Up to three insertion points: index 1, the middle, and the last review index.
    positions = []
    if max_insertion_pos > 1:
        positions.append(1)
    if max_insertion_pos > 3:
        positions.append(min(max_insertion_pos // 2, max_insertion_pos - 1))
    if max_insertion_pos > 2:
        positions.append(max_insertion_pos)
    positions = sorted(set([p for p in positions if 0 < p <= max_insertion_pos]))[:3]

    # Run beam search (1-2 tokens)
    beam_obj, beam_tokens, beam_positions = beam_search_adv(
        model, prompt, label, candidate_tokens, positions, label_tokens,
        beam_width=5, max_tokens=2, batch_size=32
    )

    # Run mini-exhaustive (1-2 tokens)
    exhaustive_obj, exhaustive_tokens, exhaustive_positions = mini_exhaustive_search(
        model, prompt, label, candidate_tokens, positions, label_tokens,
        max_candidates=30, max_tokens=2, batch_size=32
    )

    # Compare: greedy 3-token should be >= best alternative. A sample is
    # skipped when either search produced no objective.
    if beam_obj is None or exhaustive_obj is None:
        continue
    best_alternative = max(beam_obj, exhaustive_obj)

    # Plain Python types keep numpy scalars out of the CSV and the report.
    greedy_superior = bool(greedy_obj >= best_alternative)
    if greedy_superior:
        greedy_superior_count += 1

    comparison_results.append({
        'prompt_id': entry['id'],
        'greedy_obj': greedy_obj,
        'beam_obj': float(beam_obj),
        'exhaustive_obj': float(exhaustive_obj),
        'best_alternative': float(best_alternative),
        'greedy_superior': greedy_superior
    })

# With no comparable samples the percentage is 0.0, so the check reports FAIL.
greedy_superior_percentage = 100 * greedy_superior_count / len(comparison_results) if comparison_results else 0.0

# check pass criteria: >=90% samples where greedy 3-token >= alternatives
check_c_passes = greedy_superior_percentage >= 90.0

# save comparison results
comparison_df = pd.DataFrame(comparison_results)
comparison_file = os.path.join(ARTIFACTS_DIR, 'sanity_check_c_comparison.csv')
comparison_df.to_csv(comparison_file, index=False)

# generate plot: one group of three bars per sampled prompt
if comparison_results:
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    prompt_ids = [r['prompt_id'] for r in comparison_results]
    greedy_objs = [r['greedy_obj'] for r in comparison_results]
    beam_objs = [r['beam_obj'] for r in comparison_results]
    exhaustive_objs = [r['exhaustive_obj'] for r in comparison_results]

    x = np.arange(len(prompt_ids))
    width = 0.25

    ax.bar(x - width, greedy_objs, width, label='Greedy 3-token', color='green', alpha=0.7)
    ax.bar(x, beam_objs, width, label='Beam Search (1-2 tokens)', color='blue', alpha=0.7)
    ax.bar(x + width, exhaustive_objs, width, label='Mini-Exhaustive (1-2 tokens)', color='orange', alpha=0.7)

    ax.set_xlabel('Prompt ID')
    ax.set_ylabel('Objective (negative log-prob)')
    ax.set_title('Greedy 3-token vs Beam/Exhaustive Search Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(prompt_ids, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(os.path.join(ARTIFACTS_DIR, 'sanity_check_c_comparison.png'), dpi=150, bbox_inches='tight')
    plt.close()

# store results
check_c_result = {
    "status": "PASS" if check_c_passes else "FAIL",
    "n_samples": len(comparison_results),
    "greedy_superior_count": greedy_superior_count,
    "greedy_superior_percentage": float(greedy_superior_percentage),
    "pass_criteria_met": check_c_passes
}


In [None]:
# Check D: Marginal Ablation
# filter to 3-token sequences only
three_token_entries = [entry for entry in metadata_list if entry.get('chosen_k') == 3]

# extract ablation marginals
marginals_token1 = []
marginals_token2 = []
marginals_token3 = []

for entry in three_token_entries:
    # `or []` also covers an explicit null in the metadata.
    ablation_marginals = entry.get('ablation_marginals') or []

    if len(ablation_marginals) >= 3:
        marginals_token1.append(ablation_marginals[0])
        marginals_token2.append(ablation_marginals[1])
        marginals_token3.append(ablation_marginals[2])
    elif len(ablation_marginals) == 2:
        # If only 2 values, assume they're for token2 and token3
        marginals_token2.append(ablation_marginals[0])
        marginals_token3.append(ablation_marginals[1])
    elif len(ablation_marginals) == 1:
        # A single value is attributed to token3.
        marginals_token3.append(ablation_marginals[0])

# compute averages
avg_marginal_token1 = np.mean(marginals_token1) if marginals_token1 else None
avg_marginal_token2 = np.mean(marginals_token2) if marginals_token2 else None
avg_marginal_token3 = np.mean(marginals_token3) if marginals_token3 else None

# check pass criteria: |avg marginal fraction of token3| >= 0.10
avg_marginal_token3_abs = abs(avg_marginal_token3) if avg_marginal_token3 is not None else None
check_d_passes = avg_marginal_token3_abs is not None and avg_marginal_token3_abs >= 0.10

# generate plot
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Each token position keeps its own colour even when another position has no data.
box_specs = [
    ('Token 1', marginals_token1, 'lightblue'),
    ('Token 2', marginals_token2, 'lightgreen'),
    ('Token 3', marginals_token3, 'lightcoral'),
]
populated_specs = [spec for spec in box_specs if spec[1]]

if populated_specs:
    labels = [spec[0] for spec in populated_specs]
    data_to_plot = [spec[1] for spec in populated_specs]
    colors = [spec[2] for spec in populated_specs]

    # Tick labels are set separately: boxplot's `labels=` keyword is deprecated
    # in Matplotlib 3.9, and its replacement is unavailable in older releases.
    bp = ax.boxplot(data_to_plot, patch_artist=True)
    ax.set_xticklabels(labels)

    # Color the boxes
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    ax.axhline(0.10, color='red', linestyle='--', label='Threshold (0.10)')
    ax.axhline(-0.10, color='red', linestyle='--', alpha=0.5, label='Threshold (-0.10)')
    if avg_marginal_token3 is not None:
        ax.axhline(avg_marginal_token3, color='green', linestyle='--',
                  label=f'Token 3 Mean ({avg_marginal_token3:.3f})')
        # Drawn at the absolute value so the line matches its label whatever
        # the sign of the mean.
        ax.axhline(abs(avg_marginal_token3), color='green', linestyle=':', alpha=0.5,
                  label=f'|Token 3 Mean| ({abs(avg_marginal_token3):.3f})')

    ax.set_ylabel('Marginal Contribution')
    ax.set_title('Marginal Ablation Contributions by Token Position')
    ax.legend()
    ax.grid(True, alpha=0.3)
else:
    ax.text(0.5, 0.5, 'No marginal data available', ha='center', va='center')
    ax.set_title('Marginal Ablation Contributions')

plt.tight_layout()
plt.savefig(os.path.join(ARTIFACTS_DIR, 'sanity_check_d_marginals.png'), dpi=150, bbox_inches='tight')
plt.close()

# store results
check_d_result = {
    "status": "PASS" if check_d_passes else "FAIL",
    "n_3token_sequences": len(three_token_entries),
    "avg_marginal_token1": float(avg_marginal_token1) if avg_marginal_token1 is not None else None,
    "avg_marginal_token2": float(avg_marginal_token2) if avg_marginal_token2 is not None else None,
    "avg_marginal_token3": float(avg_marginal_token3) if avg_marginal_token3 is not None else None,
    "avg_marginal_token3_abs": float(avg_marginal_token3_abs) if avg_marginal_token3_abs is not None else None,
    "pass_criteria_met": check_d_passes,
    "note": "Using absolute value: negative marginals indicate tokens are helpful"
}


In [None]:
# Check E: Patch Sanity Pilot (N=30)
# For a fixed-seed sample of prompts, find the strongest single-token
# insertion and compare its objective with the stored 3-token objective.
N_SAMPLES_E = 30
np.random.seed(SEED)
sample_indices_e = np.random.choice(len(metadata_list), size=min(N_SAMPLES_E, len(metadata_list)), replace=False)

# Label ids and candidate pool used for the 1-token search.
label_tokens_e = get_label_token_ids(model)
candidate_tokens_e = build_candidate_token_list(model, n_tokens=500)

one_token_prompts = []
one_token_objectives = []
three_token_objectives = []

for idx in sample_indices_e:
    entry = metadata_list[idx]
    prompt = entry['clean_prompt']
    label = entry['clean_label']

    # Entries without a stored 3-token objective cannot be compared.
    three_token_obj = entry['objective_k'].get('k3')
    if three_token_obj is None:
        continue

    # Insertions are restricted to the review text (everything before
    # "\nSentiment:"), while scoring uses the full prompt.
    tokens = model.to_tokens(prompt)[0]
    review_tokens = model.to_tokens(prompt.split("\nSentiment:")[0])[0]
    max_insertion_pos = review_tokens.shape[0] - 1

    # Up to three insertion points: index 1, the middle, and the last review index.
    proposed_positions = [
        pos
        for wanted, pos in (
            (max_insertion_pos > 1, 1),
            (max_insertion_pos > 3, min(max_insertion_pos // 2, max_insertion_pos - 1)),
            (max_insertion_pos > 2, max_insertion_pos),
        )
        if wanted
    ]
    positions = sorted({p for p in proposed_positions if 0 < p <= max_insertion_pos})[:3]

    correct_label_token = label_tokens_e[label]
    # The batched scorer drops token 0, so scores line up with this list.
    nonzero_candidates = [t for t in candidate_tokens_e if t != 0]

    # Exhaustive 1-token search: keep the first strictly-best (position, token).
    best_obj = best_token = best_pos = None
    for pos in positions:
        if pos >= tokens.shape[0]:
            continue
        scores = evaluate_candidates_batched(
            model, tokens, pos, candidate_tokens_e, correct_label_token, batch_size=32
        )
        for cand, score in zip(nonzero_candidates, scores):
            if best_obj is None or score > best_obj:
                best_obj, best_token, best_pos = score, cand, pos

    if best_token is None or best_pos is None:
        continue

    # Re-score the winning sequence on its own rather than trusting the batched value.
    tokens_1 = torch.cat([
        tokens[:best_pos],
        torch.tensor([best_token], device=tokens.device, dtype=tokens.dtype),
        tokens[best_pos:]
    ])
    with torch.no_grad():
        logits_1 = model(tokens_1.unsqueeze(0))
        next_token_probs_1 = torch.softmax(logits_1[0, -1, :], dim=-1)
        best_obj = -np.log(next_token_probs_1[correct_label_token].item() + 1e-10)

    one_token_prompts.append(model.to_string(tokens_1))
    one_token_objectives.append(best_obj)
    three_token_objectives.append(three_token_obj)

if one_token_objectives and three_token_objectives:
    avg_obj_1token = np.mean(one_token_objectives)
    avg_obj_3token = np.mean(three_token_objectives)

    # Pass when the 3-token average is at least as strong as the 1-token average.
    check_e_passes = avg_obj_3token >= avg_obj_1token

    check_e_result = dict(
        status="PASS" if check_e_passes else "FAIL",
        n_samples=len(one_token_prompts),
        avg_objective_1token=float(avg_obj_1token),
        avg_objective_3token=float(avg_obj_3token),
        pass_criteria_met=bool(check_e_passes),
    )
else:
    # Nothing could be compared: report an error rather than a pass/fail verdict.
    check_e_result = dict(
        status="ERROR",
        reason="Insufficient 1-token prompts generated",
        pass_criteria_met=None,
    )


In [None]:
# Generate Final Sanity Report
# helper function to convert numpy types to native Python types for JSON serialization
def convert_to_json_serializable(obj):
    """Recursively convert numpy types and other non-serializable types to Python native types.

    Args:
        obj: Any value; dicts, lists, tuples and numpy arrays are walked
            recursively.

    Returns:
        The same structure with numpy scalars replaced by ``int`` /
        ``float`` / ``bool``, arrays and tuples replaced by lists, and
        missing scalar values (``None``, NaN, ``pd.NaT``, ``pd.NA``)
        replaced by ``None``. Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_json_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # Recurse into the list form so NaN elements are normalised as well.
        return convert_to_json_serializable(obj.tolist())
    # pd.isna returns an element-wise array for array-likes, which has no
    # single truth value, so it is only asked about scalars. Running this
    # before the float conversion maps numpy NaN to None just like Python NaN.
    if pd.api.types.is_scalar(obj) and pd.isna(obj):
        return None
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj

# Collect every check's result; checks C and E fall back to NOT_RUN when
# their cells were never executed.
all_checks = {
    "check_a": check_a_result,
    "check_b": check_b_result,
    "check_c": globals().get('check_c_result', {"status": "NOT_RUN"}),
    "check_d": check_d_result,
    "check_e": globals().get('check_e_result', {"status": "NOT_RUN"}),
}

# Only PASS / FAIL verdicts count toward the overall status; anything else
# is reported as skipped.
check_statuses = [
    check_data.get('status', 'UNKNOWN')
    for check_data in all_checks.values()
    if check_data.get('status', 'UNKNOWN') in ('PASS', 'FAIL')
]

overall_passes = check_statuses.count('PASS')
overall_fails = check_statuses.count('FAIL')
# Overall PASS needs at least one passing check and no failing one.
overall_status = "PASS" if overall_fails == 0 and overall_passes > 0 else "FAIL"

# Recommendation follows directly from the overall status.
recommendation = "Proceed" if overall_status == "PASS" else "Adjust search / Apply stopping rule"

# Assemble the report: per-check results first, then the roll-up.
sanity_report = {
    **all_checks,
    "overall_status": overall_status,
    "recommendation": recommendation,
    "timestamp": datetime.now().isoformat(),
    "summary": {
        "total_checks": len(check_statuses),
        "passed": overall_passes,
        "failed": overall_fails,
        "skipped": len(all_checks) - len(check_statuses),
    },
}

# Strip numpy types before serialising.
sanity_report = convert_to_json_serializable(sanity_report)

# Write the report next to the other artifacts.
report_file = os.path.join(ARTIFACTS_DIR, 'sanity_report.json')
with open(report_file, 'w') as f:
    json.dump(sanity_report, f, indent=2)

# Summary dashboard: one horizontal bar per check, coloured by verdict.
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

check_names = ['Check A\n(Monotonicity)', 'Check B\n(Tokenization)',
               'Check C\n(Greedy vs Beam)', 'Check D\n(Marginal)',
               'Check E\n(Patch Sanity)']
check_statuses_viz = [
    all_checks['check_a']['status'],
    all_checks['check_b']['status'],
    all_checks['check_c'].get('status', 'NOT_RUN'),
    all_checks['check_d']['status'],
    all_checks['check_e'].get('status', 'NOT_RUN')
]

# Green for PASS, red for FAIL, gray for everything else.
colors = [{'PASS': 'green', 'FAIL': 'red'}.get(status, 'gray') for status in check_statuses_viz]
# Checks with a verdict get a full-length bar, the rest a half-length one.
bar_lengths = [1 if status in ('PASS', 'FAIL') else 0.5 for status in check_statuses_viz]

bars = ax.barh(check_names, bar_lengths, color=colors, alpha=0.7)
ax.set_xlim(0, 1.2)
ax.set_xlabel('Status')
ax.set_title('Sanity Checks Summary Dashboard')
ax.set_xticks([0, 0.5, 1])
ax.set_xticklabels(['', '', ''])

# Print each verdict just past the end of its bar.
for bar, status in zip(bars, check_statuses_viz):
    ax.text(bar.get_width() + 0.05, bar.get_y() + bar.get_height() / 2, status,
            ha='left', va='center', fontweight='bold')

# Overall verdict and recommendation, placed below the axes.
ax.text(0.5, -0.3, f'Overall: {overall_status} | Recommendation: {recommendation}',
        ha='center', va='top', fontsize=12, fontweight='bold',
        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5),
        transform=ax.transAxes)

plt.tight_layout()
plt.savefig(os.path.join(ARTIFACTS_DIR, 'sanity_check_summary.png'), dpi=150, bbox_inches='tight')
plt.close()
