# Week 10: Skepticism and Interpretability Illusions - Exercises

This notebook contains hands-on exercises for validating interpretability methods and avoiding common pitfalls. We'll implement sanity checks, robustness tests, and multi-method validation strategies.

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr
from typing import List, Tuple, Dict
import copy

# Load model
# The "gpt2" checkpoint and its tokenizer are loaded once here; the exercise
# cells below reuse the module-level `model` and `tokenizer`.
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Switch the model to evaluation mode.
model.eval()

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Model loaded: {model_name}")
print(f"Device: {device}")

## Part 1: Sanity Checks for Saliency Maps (Adebayo et al., 2018)

Implement the model parameter randomization test: compare saliency maps from a trained model vs. a randomly initialized model.

In [None]:
def compute_saliency(model, input_ids, target_position=-1):
    """
    Compute simple gradient-based saliency for each input token.

    The saliency of a token is the L1 norm of the gradient of the next-token
    cross-entropy loss at ``target_position`` with respect to that token's
    input embedding.

    Args:
        model: Causal language model exposing ``device``,
            ``get_input_embeddings()`` and a forward pass that accepts
            ``inputs_embeds`` and returns an object with ``logits``
        input_ids: Token IDs [batch_size, seq_len]; only the first sequence
            contributes to the loss
        target_position: Position whose prediction is explained
            (default: last token). Negative values count from the end.

    Returns:
        saliency: Gradient magnitude for each input token [seq_len]
    """
    input_ids = input_ids.to(model.device)
    seq_len = input_ids.shape[1]

    # Gradients are required here even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        # Detach so the embeddings are a leaf tensor: gradients can then be
        # taken with respect to them without touching the embedding matrix.
        embeddings = model.get_input_embeddings()(input_ids).detach()
        embeddings.requires_grad_(True)

        # Forward pass using embeddings directly
        outputs = model(inputs_embeds=embeddings)
        target_logits = outputs.logits[0, target_position, :]

        # The logits at position p predict the token at position p + 1.
        position = target_position if target_position >= 0 else seq_len + target_position
        if position + 1 < seq_len:
            target_token = input_ids[0, position + 1]
        else:
            # No ground-truth next token exists for the final position, so
            # explain the model's own most likely prediction instead.
            target_token = target_logits.argmax(dim=-1).detach()

        loss = F.cross_entropy(target_logits.unsqueeze(0), target_token.unsqueeze(0))

        # autograd.grad returns the gradient directly and does not accumulate
        # anything into the model parameters' .grad buffers.
        (embedding_grad,) = torch.autograd.grad(loss, embeddings)

    # Get gradient magnitudes
    saliency = embedding_grad.abs().sum(dim=-1).squeeze(0)

    return saliency.detach().cpu().numpy()


def randomize_model_weights(model):
    """
    Create a copy of the model with randomly initialized weights.

    The input model is left untouched. In the copy, weight matrices are drawn
    from a Xavier-uniform distribution and biases are zeroed. Normalization
    layers are reset to their default initialization (scale 1, shift 0)
    instead: zeroing their scale would force every normalized activation to
    zero, making the network output constant and all input gradients zero.

    Returns:
        random_model: Model with random weights (same architecture)
    """
    random_model = copy.deepcopy(model)

    norm_layers = (nn.LayerNorm, nn.GroupNorm, nn.modules.batchnorm._NormBase)

    for module in random_model.modules():
        if isinstance(module, norm_layers):
            module.reset_parameters()
            continue

        # recurse=False: each parameter is handled by the module that owns it,
        # so normalization parameters are never reached by the generic branch.
        for param in module.parameters(recurse=False):
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

    return random_model


def sanity_check_model_randomization(trained_model, input_text):
    """
    Model-parameter randomization test for the gradient saliency map.

    Saliency is computed for the same input on the given model and on a
    weight-randomized copy of it; the two maps are plotted and compared with
    a Spearman rank correlation. The input is tokenized with the module-level
    ``tokenizer``.

    Returns:
        correlation: Spearman correlation between trained and random saliency
    """
    token_ids = tokenizer.encode(input_text, return_tensors="pt")

    # Saliency of the model as given
    with torch.enable_grad():
        saliency_trained = compute_saliency(trained_model, token_ids)

    # Saliency of a copy whose weights were re-drawn at random
    scrambled = randomize_model_weights(trained_model)
    scrambled.eval()
    with torch.enable_grad():
        saliency_random = compute_saliency(scrambled, token_ids)

    correlation, p_value = spearmanr(saliency_trained, saliency_random)

    # One bar chart per model, stacked vertically
    labels = tokenizer.convert_ids_to_tokens(token_ids[0])
    xs = range(len(labels))
    fig, panels = plt.subplots(2, 1, figsize=(12, 6))
    panel_specs = (
        (saliency_trained, 'Trained Model Saliency'),
        (saliency_random, 'Random Model Saliency'),
    )
    for panel, (values, title) in zip(panels, panel_specs):
        panel.bar(xs, values)
        panel.set_xticks(xs)
        panel.set_xticklabels(labels, rotation=45, ha='right')
        panel.set_title(title)
        panel.set_ylabel('Gradient Magnitude')

    plt.tight_layout()
    plt.show()

    # A low absolute correlation (< 0.3) counts as a pass
    if abs(correlation) < 0.3:
        verdict = 'PASS'
    else:
        verdict = 'FAIL'
    print(f"Spearman correlation: {correlation:.3f} (p={p_value:.3f})")
    print(f"{verdict}: Saliency should differ for random model")

    return correlation


# Test
# Runs the randomization sanity check once on a single prompt, using the
# module-level `model`.
test_text = "The capital of France is Paris, which is known for"
correlation = sanity_check_model_randomization(model, test_text)

## Part 2: ROAR Benchmark (Hooker et al., 2019)

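Implement a simplified version of RemOve And Retrain (ROAR): test whether removing "important" tokens hurts performance more than removing random tokens. The full benchmark retrains the model after removal; this exercise skips retraining and only measures perplexity on the masked inputs.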

In [None]:
def identify_important_tokens(model, tokenizer, text, method='gradient', top_k=3):
    """
    Identify the most important tokens according to a method.

    Args:
        model: Language model; must expose ``device``
        tokenizer: Tokenizer whose ``encode(text, return_tensors="pt")``
            returns a [1, seq_len] tensor of token IDs
        text: Input text
        method: 'gradient', 'attention', or 'random'
        top_k: Number of tokens to identify. Values larger than the sequence
            length are clamped to it; values <= 0 select nothing.

    Returns:
        important_positions: Indices of important tokens (NumPy array). For
            'gradient' and 'attention' they are in ascending order of
            importance.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
    seq_len = input_ids.shape[1]

    # Clamp so that short texts cannot break sampling without replacement,
    # and so that top_k == 0 selects nothing ([-0:] would select everything).
    k = max(0, min(int(top_k), seq_len))

    if method == 'gradient':
        with torch.enable_grad():
            saliency = compute_saliency(model, input_ids)
        ranked = np.argsort(saliency)
        important_positions = ranked[len(ranked) - k:]

    elif method == 'attention':
        with torch.no_grad():
            outputs = model(input_ids, output_attentions=True)
            # Average attention across heads and layers
            attentions = torch.stack(outputs.attentions)  # [n_layers, 1, n_heads, seq_len, seq_len]
            avg_attention = attentions.mean(dim=(0, 2))  # [1, seq_len, seq_len]
            # Sum attention received by each token
            importance = avg_attention.sum(dim=1).squeeze(0)  # [seq_len]
        ranked = torch.argsort(importance)
        important_positions = ranked[len(ranked) - k:].cpu().numpy()

    elif method == 'random':
        # Random baseline
        important_positions = np.random.choice(seq_len, size=k, replace=False)

    else:
        raise ValueError(f"Unknown method: {method}")

    return important_positions


def mask_tokens(text, positions, tokenizer):
    """
    Mask tokens at specified positions.

    Masking is done on token IDs and the result is decoded back to text, so
    the output is independent of the tokenizer's sub-word marker convention.
    Positions refer to the sequence returned by ``tokenizer.encode(text)``.
    The replacement is the tokenizer's mask token if it has one, otherwise
    its unknown token, otherwise its end-of-sequence token.

    Args:
        text: Input text
        positions: Iterable of token indices to mask; out-of-range indices
            are ignored
        tokenizer: Tokenizer providing ``encode`` and ``decode``

    Returns:
        masked_text: Text with tokens masked

    Raises:
        ValueError: If something must be masked but the tokenizer defines
            none of the mask, unknown or end-of-sequence tokens.
    """
    token_ids = list(tokenizer.encode(text))
    n_tokens = len(token_ids)

    valid_positions = [int(pos) for pos in positions if -n_tokens <= int(pos) < n_tokens]

    if valid_positions:
        mask_id = None
        for attr in ('mask_token_id', 'unk_token_id', 'eos_token_id'):
            mask_id = getattr(tokenizer, attr, None)
            if mask_id is not None:
                break
        if mask_id is None:
            raise ValueError("Tokenizer defines no mask, unknown or end-of-sequence token to mask with")

        for pos in valid_positions:
            token_ids[pos] = mask_id

    # Keep the decoded text verbatim; cleanup would alter spacing/punctuation.
    masked_text = tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

    return masked_text


def compute_perplexity(model, tokenizer, texts):
    """
    Compute average perplexity on a list of texts.

    Each text's loss (as reported by the model when the inputs are also
    passed as labels) is weighted by its number of predicted tokens
    (``seq_len - 1``). Texts shorter than two tokens are skipped. When no
    text contributes, the result is infinity.

    Returns:
        perplexity: Average perplexity
    """
    nll_sum = 0
    n_predicted = 0

    with torch.no_grad():
        for text in texts:
            ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

            # Number of next-token predictions available in this text
            n_targets = ids.shape[1] - 1
            if n_targets < 1:
                continue

            nll_sum += model(ids, labels=ids).loss.item() * n_targets
            n_predicted += n_targets

    if n_predicted > 0:
        mean_nll = nll_sum / n_predicted
    else:
        mean_nll = float('inf')

    return np.exp(mean_nll)


def roar_test(model, tokenizer, test_texts, methods=('gradient', 'attention', 'random'), top_k=3):
    """
    ROAR test: Compare performance degradation when removing important tokens.

    For every method, the ``top_k`` tokens it ranks as most important are
    masked in each text and the perplexity of the masked texts is compared
    with the perplexity of the originals. No retraining is performed here.

    Args:
        model: Language model
        tokenizer: Matching tokenizer
        test_texts: List of texts to evaluate
        methods: Importance methods to compare (see identify_important_tokens)
        top_k: Number of tokens masked per text

    Returns:
        results: Dict mapping method to a dict with keys 'perplexity' and
            'degradation' (perplexity increase over the baseline)
    """
    baseline_ppl = compute_perplexity(model, tokenizer, test_texts)
    print(f"Baseline perplexity: {baseline_ppl:.2f}")

    results = {}

    for method in methods:
        masked_texts = []

        for text in test_texts:
            positions = identify_important_tokens(model, tokenizer, text, method=method, top_k=top_k)
            masked_texts.append(mask_tokens(text, positions, tokenizer))

        method_ppl = compute_perplexity(model, tokenizer, masked_texts)
        degradation = method_ppl - baseline_ppl

        results[method] = {
            'perplexity': method_ppl,
            'degradation': degradation
        }

        print(f"{method}: ppl={method_ppl:.2f}, degradation={degradation:.2f}")

    # Visualize
    plt.figure(figsize=(8, 5))
    methods_list = list(results.keys())
    degradations = [results[m]['degradation'] for m in methods_list]

    plt.bar(methods_list, degradations)
    plt.ylabel('Perplexity Increase')
    plt.title('ROAR Test: Performance Degradation by Method')
    plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
    plt.show()

    # The verdict compares gradient against random, so it is only meaningful
    # (and only computable) when both methods were actually run.
    if 'gradient' in results and 'random' in results:
        beats_random = results['gradient']['degradation'] > results['random']['degradation']
        print(f"\n{'PASS' if beats_random else 'FAIL'}: "
              f"Gradient method should outperform random")
    else:
        print("\nSKIPPED: verdict requires both 'gradient' and 'random' in methods")

    return results


# Test
# Three short sentences used as the evaluation set for the ROAR comparison.
test_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning models can be difficult to interpret.",
    "Paris is the capital of France and a major European city."
]

# Runs all default methods (gradient, attention, random) with top_k=3.
results = roar_test(model, tokenizer, test_texts)

## Part 3: Attention vs. Gradient Importance (Jain & Wallace, 2019)

Test whether attention weights correlate with gradient-based feature importance.

In [None]:
def compare_attention_vs_gradient(model, tokenizer, text):
    """
    Compare attention weights with gradient-based importance.

    Per-token gradient saliency and per-token attention (averaged over
    layers, heads and query positions) are each normalized to sum to one,
    plotted, and compared with a Spearman rank correlation.

    Returns:
        correlation: Spearman correlation between the two methods
    """
    token_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

    # Gradient importance
    with torch.enable_grad():
        grad_scores = compute_saliency(model, token_ids)

    # Attention importance
    with torch.no_grad():
        attn_maps = torch.stack(
            model(token_ids, output_attentions=True).attentions
        )  # [n_layers, 1, n_heads, seq_len, seq_len]
        # Collapse layers, heads and queries; one value per attended token
        attn_scores = attn_maps.mean(dim=(0, 2, 3)).squeeze(0).cpu().numpy()  # [seq_len]

    # Put both score vectors on the same scale
    grad_scores = grad_scores / grad_scores.sum()
    attn_scores = attn_scores / attn_scores.sum()

    correlation, p_value = spearmanr(grad_scores, attn_scores)

    # Two bar charts followed by a scatter plot
    labels = tokenizer.convert_ids_to_tokens(token_ids[0])
    xs = range(len(labels))
    fig, axes = plt.subplots(3, 1, figsize=(12, 10))

    bar_specs = (
        (axes[0], grad_scores, 'Gradient-Based Importance'),
        (axes[1], attn_scores, 'Attention-Based Importance'),
    )
    for panel, values, title in bar_specs:
        panel.bar(xs, values)
        panel.set_title(title)
        panel.set_xticks(xs)
        panel.set_xticklabels(labels, rotation=45, ha='right')

    scatter_panel = axes[2]
    scatter_panel.scatter(grad_scores, attn_scores, alpha=0.6)
    scatter_panel.set_xlabel('Gradient Importance')
    scatter_panel.set_ylabel('Attention Importance')
    scatter_panel.set_title(f'Correlation: {correlation:.3f} (p={p_value:.3f})')
    # Identity line as a visual reference
    diagonal = [0, max(grad_scores.max(), attn_scores.max())]
    scatter_panel.plot(diagonal, diagonal, 'r--', alpha=0.3)

    plt.tight_layout()
    plt.show()

    if abs(correlation) < 0.3:
        strength = 'Weak'
    elif abs(correlation) < 0.6:
        strength = 'Moderate'
    else:
        strength = 'Strong'

    print(f"Spearman correlation: {correlation:.3f} (p={p_value:.3f})")
    print(f"Interpretation: {strength} correlation")
    print("Low correlation suggests attention may not explain feature importance.")

    return correlation


# Test on multiple examples
# Note: this rebinds the module-level `test_texts` used by the ROAR cell.
test_texts = [
    "The cat sat on the mat.",
    "Machine learning is a subset of artificial intelligence.",
    "She didn't want to go to the party because she was tired."
]

# Collect one Spearman correlation per sentence, then report their mean.
correlations = []
for text in test_texts:
    print(f"\nText: {text}")
    corr = compare_attention_vs_gradient(model, tokenizer, text)
    correlations.append(corr)

print(f"\nAverage correlation: {np.mean(correlations):.3f}")

## Part 4: Adversarial Robustness for SAE Features

Test whether SAE feature activations are robust to small input perturbations.

In [None]:
def simulate_sae_feature(model, tokenizer, concept_text, layer=6):
    """
    Build a stand-in for an SAE feature from a concept text.
    (Simplified version for demonstration)

    The direction is the hidden state at ``layer`` averaged over all token
    positions of ``concept_text``, scaled to unit length.

    Returns:
        feature_direction: Direction in activation space [hidden_dim]
    """
    token_ids = tokenizer.encode(concept_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        layer_states = model(token_ids, output_hidden_states=True).hidden_states[layer]  # [1, seq_len, hidden_dim]
        # Average over the sequence to get a single vector
        mean_state = layer_states.mean(dim=1).squeeze(0)  # [hidden_dim]

    # Unit-normalize
    return mean_state / mean_state.norm()


def compute_feature_activation(model, tokenizer, text, feature_direction, layer=6):
    """
    Measure how strongly a text activates a feature direction.

    The hidden state at ``layer`` is averaged over token positions and
    projected (dot product) onto ``feature_direction``.

    Returns:
        activation: Scalar activation value
    """
    token_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        layer_states = model(token_ids, output_hidden_states=True).hidden_states[layer]
        pooled = layer_states.mean(dim=1).squeeze(0)  # [hidden_dim]

    # Projection onto the feature direction, as a Python float
    return torch.dot(pooled, feature_direction).item()


def adversarial_perturbation(model, tokenizer, text, feature_direction, layer=6, epsilon=0.01, steps=10):
    """
    Find small perturbation to input embeddings that maximally changes feature activation.

    Signed-gradient descent on the feature activation is run for ``steps``
    iterations with step size ``epsilon``. After every step the perturbation
    is projected back onto the L-infinity ball of radius ``epsilon`` around
    the clean embeddings, so no embedding coordinate ever moves by more than
    ``epsilon``.

    Args:
        model: Causal language model exposing ``device``,
            ``get_input_embeddings()`` and a forward pass that accepts
            ``inputs_embeds`` / ``output_hidden_states``
        tokenizer: Matching tokenizer
        text: Input text to perturb
        feature_direction: Direction in activation space [hidden_dim]
        layer: Index into the model's ``hidden_states``
        epsilon: Step size and maximum per-coordinate perturbation
        steps: Number of gradient steps

    Returns:
        perturbed_activation: Feature activation after perturbation
        original_activation: Feature activation before perturbation
        output_change: Mean absolute difference between the perturbed and the
            clean logits
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
    clean_embeddings = model.get_input_embeddings()(input_ids).detach()

    def run(embeds):
        # Forward pass; returns (scalar feature activation, model outputs).
        outputs = model(inputs_embeds=embeds, output_hidden_states=True)
        pooled = outputs.hidden_states[layer].mean(dim=1).squeeze(0)
        return torch.dot(pooled, feature_direction), outputs

    # Clean reference: activation and logits before any perturbation
    with torch.no_grad():
        clean_activation, clean_outputs = run(clean_embeddings)
    original_activation = clean_activation.item()

    embeddings = clean_embeddings.clone()

    # Gradient descent to minimize feature activation
    for _ in range(steps):
        embeddings.requires_grad_(True)
        with torch.enable_grad():
            activation, _ = run(embeddings)
            # autograd.grad avoids accumulating gradients on model parameters
            (grad,) = torch.autograd.grad(activation, embeddings)

        with torch.no_grad():
            stepped = embeddings - epsilon * grad.sign()
            # Project back to epsilon ball around the clean embeddings
            delta = (stepped - clean_embeddings).clamp(-epsilon, epsilon)
            embeddings = (clean_embeddings + delta).detach()

    # Final perturbed activation
    with torch.no_grad():
        final_activation, outputs_perturbed = run(embeddings)
        perturbed_activation = final_activation.item()

        # Measure output change (simplified: just look at logit change)
        logit_diff = (outputs_perturbed.logits - clean_outputs.logits).abs().mean().item()

    return perturbed_activation, original_activation, logit_diff


def test_sae_robustness(model, tokenizer, concept_texts, test_texts, layer=6):
    """
    Probe how easily an adversarial embedding perturbation moves a feature.

    A feature direction is derived from the concatenated concept texts. Each
    test text is then attacked, and the change in feature activation is
    reported next to the change in model output. A large activation change
    paired with a small output change indicates a fragile feature.

    Returns:
        results: List with one dict per test text holding the text, the
            original and perturbed activations, the activation change and
            the output change
    """
    # Learn feature from concept examples
    print("Learning feature direction from concept examples...")
    joined_concepts = " ".join(concept_texts)
    feature_direction = simulate_sae_feature(model, tokenizer, joined_concepts, layer)

    results = []

    for text in test_texts:
        after, before, out_delta = adversarial_perturbation(
            model, tokenizer, text, feature_direction, layer
        )
        act_delta = abs(after - before)

        record = {
            'text': text,
            'original_activation': before,
            'perturbed_activation': after,
            'activation_change': act_delta,
            'output_change': out_delta
        }
        results.append(record)

        # Robust if the activation moved by less than half its original size
        if act_delta < 0.5 * abs(before):
            status = 'PASS'
        else:
            status = 'FAIL'

        print(f"\nText: {text[:50]}...")
        print(f"  Original activation: {before:.3f}")
        print(f"  Perturbed activation: {after:.3f}")
        print(f"  Activation change: {act_delta:.3f}")
        print(f"  Output change: {out_delta:.3f}")
        print(f"  Robustness: {status}")

    # Visualize
    ys = [record['activation_change'] for record in results]
    xs = [record['output_change'] for record in results]
    median_change = np.median(ys)

    plt.figure(figsize=(8, 6))
    plt.scatter(xs, ys, alpha=0.6, s=100)
    plt.xlabel('Output Change (Logit Difference)')
    plt.ylabel('Feature Activation Change')
    plt.title('SAE Feature Robustness Test')
    plt.axhline(y=median_change, color='r', linestyle='--',
                label=f'Median activation change: {median_change:.3f}')
    plt.legend()
    plt.show()

    print("\nSummary: Large activation changes with small output changes suggest fragile features.")

    return results


# Test
# Texts used to derive the feature direction.
concept_texts = [
    "The Golden Gate Bridge is in San Francisco.",
    "I visited the Golden Gate Bridge last summer."
]

# Texts whose feature activation is attacked.
# Note: this rebinds the module-level `test_texts` again.
test_texts = [
    "San Francisco is a beautiful city on the west coast.",
    "The bridge connects the city to Marin County.",
    "California has many famous landmarks and attractions."
]

results = test_sae_robustness(model, tokenizer, concept_texts, test_texts)

## Part 5: Multi-Method Validation

Validate a concept using three independent methods and check for agreement.

In [None]:
def validate_concept_multimethod(model, tokenizer, concept_name, positive_examples, negative_examples, layer=6):
    """
    Validate a concept using multiple methods:
    1. Linear probe
    2. Causal intervention (steering)
    3. Feature attribution

    The three per-method scores are currently fixed placeholder values; the
    agreement score is one minus their standard deviation.

    Returns:
        agreement_score: How much methods agree (0-1)
    """
    print(f"Validating concept: {concept_name}")
    print(f"Positive examples: {len(positive_examples)}")
    print(f"Negative examples: {len(negative_examples)}")

    # TODO: replace each placeholder score with a real measurement
    # (heading, metric label, score)
    method_reports = (
        ("Method 1: Linear Probe", "Probe accuracy", 0.75),
        ("Method 2: Causal Steering", "Steering effectiveness", 0.70),
        ("Method 3: Feature Attribution", "Attribution consistency", 0.65),
    )

    scores = []
    for heading, metric, value in method_reports:
        print(f"\n{heading}")
        print(f"  {metric}: {value:.2%}")
        scores.append(value)

    # Lower spread between the scores means higher agreement
    agreement_score = 1.0 - np.std(scores)

    if agreement_score > 0.8:
        verdict = 'PASS'
    else:
        verdict = 'WARNING'
    print(f"\nAgreement score: {agreement_score:.3f}")
    print(f"{verdict}: Methods should agree for robust concept")

    # Visualize
    mean_score = np.mean(scores)
    plt.figure(figsize=(8, 5))
    plt.bar(['Probe', 'Steering', 'Attribution'], scores)
    plt.ylabel('Validation Score')
    plt.title(f'Multi-Method Validation: {concept_name}')
    plt.ylim([0, 1])
    plt.axhline(y=mean_score, color='r', linestyle='--', label=f'Mean: {mean_score:.2f}')
    plt.legend()
    plt.show()

    return agreement_score


# Test
# Example sentences that mention animals ...
positive_examples = [
    "The cat sat on the mat.",
    "A dog chased the ball."
]
# ... and example sentences that do not.
negative_examples = [
    "The theory of relativity explains gravity.",
    "Machine learning uses neural networks."
]

score = validate_concept_multimethod(
    model, tokenizer, "animals", positive_examples, negative_examples
)

## Part 6: Circuit Faithfulness with Multiple Ablation Methods

Test how circuit faithfulness scores change with different ablation methods.

In [None]:
def ablate_attention_head(model, layer, head, method='zero'):
    """
    Ablate a specific attention head using different methods.

    The returned hook must be registered as a forward hook on the attention
    output projection, ``model.transformer.h[layer].attn.c_proj``. The input
    of that projection is the concatenation of all head outputs
    [batch, seq_len, hidden], so one head corresponds to one contiguous slice
    of the last dimension. The hook replaces that slice and recomputes the
    projection. (The output of the attention module itself has already mixed
    all heads together and cannot be ablated per head.)

    Args:
        model: Model whose ``config.num_attention_heads`` gives the head count
        layer: Layer index; not used to build the hook, the caller decides
            where the hook is registered
        head: Index of the head to ablate
        method: 'zero' (replace with zeros), 'mean' (replace with the head's
            mean over sequence positions), 'random' (replace with standard
            normal noise)

    Returns:
        hook: Function to register as forward hook

    Raises:
        ValueError: If ``method`` is unknown or ``head`` is out of range.
    """
    if method not in ('zero', 'mean', 'random'):
        raise ValueError(f"Unknown ablation method: {method}")

    n_heads = model.config.num_attention_heads
    if not 0 <= head < n_heads:
        raise ValueError(f"Head index {head} out of range for {n_heads} heads")

    def ablation_hook(module, inputs, output):
        merged_heads = inputs[0]  # [batch, seq_len, hidden]
        head_dim = merged_heads.shape[-1] // n_heads
        start, stop = head * head_dim, (head + 1) * head_dim

        # Work on a copy so the tensor seen by other code is left untouched.
        patched = merged_heads.clone()
        head_values = merged_heads[..., start:stop]

        if method == 'zero':
            patched[..., start:stop] = 0
        elif method == 'mean':
            # Replace with mean across sequence
            patched[..., start:stop] = head_values.mean(dim=1, keepdim=True)
        else:
            # Replace with random values from standard normal
            patched[..., start:stop] = torch.randn_like(head_values)

        # Call forward() directly: going through module(...) would re-enter
        # this hook. The returned value replaces the projection's output.
        return module.forward(patched)

    return ablation_hook


def test_circuit_faithfulness(model, tokenizer, text, layer, head, ablation_methods=('zero', 'mean', 'random')):
    """
    Test how faithfulness scores depend on ablation method.

    For each ablation method, the chosen head is ablated and the KL
    divergence from the unablated next-token distribution to the ablated one
    is measured at the last position.

    Args:
        model: GPT-2 style model (layers under ``model.transformer.h``)
        tokenizer: Matching tokenizer
        text: Prompt to evaluate
        layer: Layer index of the head
        head: Head index within the layer
        ablation_methods: Methods to compare (see ablate_attention_head)

    Returns:
        results: Dict mapping ablation method to faithfulness score
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

    # Baseline (no ablation)
    with torch.no_grad():
        baseline_logits = model(input_ids).logits[0, -1, :]
        baseline_log_probs = F.log_softmax(baseline_logits, dim=-1)

    results = {}

    for method in ablation_methods:
        # Register hook on the output projection, where heads are separable
        hook = ablate_attention_head(model, layer, head, method)
        handle = model.transformer.h[layer].attn.c_proj.register_forward_hook(hook)

        try:
            # Forward pass with ablation
            with torch.no_grad():
                ablated_logits = model(input_ids).logits[0, -1, :]
                # log_softmax avoids log(0) = -inf for vanishing probabilities
                ablated_log_probs = F.log_softmax(ablated_logits, dim=-1)
        finally:
            # Always detach the hook, even if the forward pass failed
            handle.remove()

        # Compute faithfulness (KL divergence), KL(baseline || ablated)
        kl_div = F.kl_div(
            ablated_log_probs, baseline_log_probs, reduction='sum', log_target=True
        ).item()

        results[method] = kl_div

        print(f"{method} ablation: KL divergence = {kl_div:.4f}")

    # Visualize
    plt.figure(figsize=(8, 5))
    methods_list = list(results.keys())
    kl_divs = list(results.values())

    plt.bar(methods_list, kl_divs)
    plt.ylabel('KL Divergence (Faithfulness)')
    plt.title(f'Circuit Faithfulness: Layer {layer}, Head {head}')
    plt.show()

    # Check sensitivity
    max_diff = max(kl_divs) - min(kl_divs)
    print(f"\nMax difference: {max_diff:.4f}")
    print(f"{'WARNING' if max_diff > 0.1 else 'OK'}: Large differences suggest conclusion depends on ablation choice")

    return results


# Test
# Compares the three default ablation methods on head 0 of layer 6.
test_text = "The capital of France is Paris, which is known for"
results = test_circuit_faithfulness(model, tokenizer, test_text, layer=6, head=0)

## Part 7: Red Team Your Project Concept

Apply skeptical validation to your own project concept.

In [None]:
class ConceptValidator:
    """
    Comprehensive validation suite for interpretability claims.

    Runs six checks for one named concept and stores their outcomes in
    ``validation_results``. Several checks are still placeholders that
    return fixed values (marked with TODO below).
    """

    def __init__(self, model, tokenizer, concept_name):
        self.model = model
        self.tokenizer = tokenizer
        self.concept_name = concept_name
        # Filled by run_all_checks(); maps check name to its outcome
        self.validation_results = {}

    def run_all_checks(self, positive_examples, negative_examples, layer=6):
        """
        Run complete validation suite.

        Args:
            positive_examples: Texts that contain the concept; the first one
                is used for the sanity check
            negative_examples: Texts that do not contain the concept
            layer: Layer index forwarded to the layer-dependent checks

        Returns:
            report: Dict with all validation results
        """
        print("=" * 60)
        print(f"Red Team Validation: {self.concept_name}")
        print("=" * 60)

        # 1. Sanity checks
        print("\n[1/6] Sanity Checks...")
        sanity_pass = self.sanity_checks(positive_examples[0])
        self.validation_results['sanity_checks'] = sanity_pass

        # 2. Multi-method agreement
        print("\n[2/6] Multi-Method Validation...")
        agreement = self.multimethod_validation(positive_examples, negative_examples, layer)
        self.validation_results['method_agreement'] = agreement

        # 3. Robustness tests
        print("\n[3/6] Robustness Tests...")
        robustness = self.robustness_tests(positive_examples, layer)
        self.validation_results['robustness'] = robustness

        # 4. Causal validation
        print("\n[4/6] Causal Validation...")
        causal_pass = self.causal_validation(positive_examples)
        self.validation_results['causal_validation'] = causal_pass

        # 5. Baseline comparisons
        print("\n[5/6] Baseline Comparisons...")
        baseline_pass = self.baseline_comparisons(positive_examples, negative_examples)
        self.validation_results['baseline_comparison'] = baseline_pass

        # 6. Alternative explanations
        print("\n[6/6] Testing Alternative Explanations...")
        alternatives = self.test_alternative_explanations(positive_examples)
        self.validation_results['alternative_explanations'] = alternatives

        # Summary
        self.print_summary()

        return self.validation_results

    def sanity_checks(self, example_text):
        """Check if interpretation changes for random model."""
        print("  Testing on random model...")
        correlation = sanity_check_model_randomization(self.model, example_text)
        # Pass when saliency on the random model is nearly uncorrelated
        passed = abs(correlation) < 0.3
        print(f"  Result: {'PASS' if passed else 'FAIL'}")
        return passed

    def multimethod_validation(self, positive_examples, negative_examples, layer):
        """Test agreement across methods."""
        print("  Comparing probe, steering, and attribution...")
        agreement = validate_concept_multimethod(
            self.model, self.tokenizer, self.concept_name,
            positive_examples, negative_examples, layer
        )
        return agreement

    def robustness_tests(self, examples, layer):
        """Test adversarial robustness."""
        # TODO: Implement
        print("  Testing adversarial robustness...")
        # Simplified: just return a score
        return 0.75

    def causal_validation(self, examples):
        """Test if interventions have predicted effects."""
        # TODO: Implement
        print("  Testing causal interventions...")
        return True

    def baseline_comparisons(self, positive_examples, negative_examples):
        """Compare against random and simple baselines."""
        # TODO: Implement ROAR-style test
        print("  Comparing to random baseline...")
        return True

    def test_alternative_explanations(self, examples):
        """Check if simpler explanations fit the data."""
        # TODO: Test word frequency, recency, etc.
        print("  Testing word frequency baseline...")
        print("  Testing positional bias...")
        return ['word_frequency', 'positional_bias']

    def print_summary(self):
        """
        Print validation summary.

        Reads ``validation_results``, so every key written by
        run_all_checks() must be present.
        """
        print("\n" + "=" * 60)
        print("VALIDATION SUMMARY")
        print("=" * 60)

        # Numeric results are turned into pass/fail with fixed thresholds
        checks = [
            ('Sanity Checks', self.validation_results['sanity_checks']),
            ('Method Agreement', self.validation_results['method_agreement'] > 0.8),
            ('Robustness', self.validation_results['robustness'] > 0.7),
            ('Causal Validation', self.validation_results['causal_validation']),
            ('Baseline Comparison', self.validation_results['baseline_comparison']),
        ]

        passed = sum(1 for _, result in checks if result)
        total = len(checks)

        for check_name, result in checks:
            status = '✓ PASS' if result else '✗ FAIL'
            print(f"  {check_name:.<40} {status}")

        print(f"\nOverall: {passed}/{total} checks passed")

        if passed == total:
            print("\n🎉 Concept validation looks strong!")
        elif passed >= total * 0.7:
            print("\n⚠️  Concept validation is promising but needs improvement.")
        else:
            print("\n❌ Concept validation has significant issues. Rethink your approach.")

        print("\nAlternative explanations to rule out:")
        for alt in self.validation_results['alternative_explanations']:
            print(f"  - {alt}")


# Example usage
# Replace the placeholder concept name and example texts with your own.
validator = ConceptValidator(model, tokenizer, "YOUR_CONCEPT_NAME")

positive_examples = [
    # TODO: Add your concept's positive examples
    "Example text with your concept",
    "Another example with your concept",
]

negative_examples = [
    # TODO: Add negative examples
    "Example without your concept",
    "Another example without your concept",
]

# Runs all six checks with the default layer and prints the summary.
results = validator.run_all_checks(positive_examples, negative_examples)

## Part 8: Publication-Ready Validation Checklist

Create a checklist for your paper's interpretability claims.

### Validation Checklist for Your Paper

Before submitting, ensure you can answer YES to these questions:

#### 1. Sanity Checks
- [ ] Tested interpretation method on randomly initialized model
- [ ] Tested interpretation method on model trained with random labels
- [ ] Results differ significantly from random model (correlation < 0.3)

#### 2. Method Validation
- [ ] Used at least 3 independent interpretability methods
- [ ] Methods show consistent results (agreement > 70%)
- [ ] Reported where methods disagree and investigated why

#### 3. Causal Validation
- [ ] Performed causal interventions (steering, ablation, or editing)
- [ ] Interventions had predicted effects on model behavior
- [ ] Tested on held-out examples not used during interpretation discovery

#### 4. Robustness
- [ ] Tested robustness to input perturbations
- [ ] Tested robustness to prompt variations
- [ ] Tested across multiple model checkpoints or architectures

#### 5. Baseline Comparisons
- [ ] Compared against random baselines
- [ ] Compared against simple heuristics (word frequency, position, etc.)
- [ ] Results significantly outperform baselines (p < 0.05)

#### 6. Alternative Explanations
- [ ] Identified at least 3 alternative explanations for observations
- [ ] Designed tests to distinguish between alternatives
- [ ] Provided evidence ruling out simpler explanations

#### 7. Quantitative Evaluation
- [ ] Used quantitative metrics (not just visualization)
- [ ] Reported confidence intervals or statistical significance
- [ ] Tested on sufficient sample size (n > 30 for statistical tests)

#### 8. Methodological Transparency
- [ ] Specified all hyperparameters (layer, alpha, threshold, etc.)
- [ ] Specified ablation method if using circuit analysis
- [ ] Reported how many experiments were run (avoid p-hacking)
- [ ] Included negative results or failure cases

#### 9. Human Validation (if applicable)
- [ ] Conducted human study with appropriate controls
- [ ] Measured whether explanations help humans (not just make sense)
- [ ] Tested for confirmation bias in human evaluations

#### 10. Limitations
- [ ] Included explicit limitations section
- [ ] Discussed what interpretation cannot tell us
- [ ] Acknowledged methodological dependencies
- [ ] Suggested future work to address limitations

**Score:** ___/10 categories completed

**Recommendation:**
- 9-10: Ready to submit
- 7-8: Address gaps before submission
- <7: Significant validation work needed

## Summary

In this notebook, you've learned to:

1. **Implement sanity checks** to detect when interpretations are model-independent
2. **Use ROAR benchmarks** to quantitatively validate feature importance
3. **Compare attention vs. gradient importance** and understand when they disagree
4. **Test adversarial robustness** of learned features and concepts
5. **Validate concepts with multiple methods** and check for agreement
6. **Measure circuit faithfulness** with different ablation methods
7. **Red team your own work** with comprehensive validation suite
8. **Follow best practices** for publication-ready validation

Remember: **Every interpretability claim is a hypothesis that needs rigorous testing.**

The bar is high, but meeting it is what separates true mechanistic understanding from interpretability theater.