# Week 4 Exercise: Causal Mediation Analysis

In this exercise, you'll gain hands-on experience with:
- Activation patching (noise and clean)
- Computing causal effects
- ROME-style causal tracing
- Gradient-based attribution
- Average Indirect Effect (AIE)
- Function vector extraction
- Counterfactual dataset design

## Setup

In [None]:
!pip install transformers torch numpy matplotlib seaborn -q

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load GPT-2
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = model.to(device)
model.eval()

print(f"Model: {model_name}")
print(f"Layers: {model.config.n_layer}")
print(f"Hidden size: {model.config.n_embd}")

## Part 1: Basic Activation Patching

Let's implement the core patching functionality.

In [None]:
def run_with_cache(model, prompt, layer_idx=None):
    """
    Run model and cache activations.
    
    Args:
        model: The language model
        prompt: Input text
        layer_idx: If specified, only cache this layer
    
    Returns:
        logits: Model output logits
        cache: Dict of cached activations
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    cache = {}
    
    # Hook to save activations
    def save_activation(name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                cache[name] = output[0].detach()
            else:
                cache[name] = output.detach()
        return hook
    
    # Register hooks
    handles = []
    if layer_idx is None:
        # Cache all layers
        for i, layer in enumerate(model.transformer.h):
            handle = layer.register_forward_hook(save_activation(f"layer_{i}"))
            handles.append(handle)
    else:
        # Cache specific layer
        handle = model.transformer.h[layer_idx].register_forward_hook(
            save_activation(f"layer_{layer_idx}")
        )
        handles.append(handle)
    
    # Run model
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Clean up hooks
    for handle in handles:
        handle.remove()
    
    return outputs.logits, cache

# Test
logits, cache = run_with_cache(model, "The capital of France is", layer_idx=6)
print(f"Output shape: {logits.shape}")
print(f"Cached activations: {list(cache.keys())}")
print(f"Layer 6 activation shape: {cache['layer_6'].shape}")

### Implement Activation Patching

In [None]:
def run_with_patch(model, prompt, patch_layer, patch_activations, patch_position=None):
    """
    Run model with patched activations.
    
    Args:
        model: The language model
        prompt: Input text (corrupted)
        patch_layer: Which layer to patch
        patch_activations: Activations to patch in (from clean run)
        patch_position: If specified, only patch this token position
    
    Returns:
        logits: Model output with patching applied
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Hook to patch activations
    def patch_hook(module, input, output):
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output
        
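        # Patch activations; clean and corrupted prompts must tokenize to the
        # same length so that positions line up
        if patch_activations.shape != hidden_states.shape:
            raise ValueError(
                f"Shape mismatch: patch {tuple(patch_activations.shape)} "
                f"vs. run {tuple(hidden_states.shape)}"
            )
        if patch_position is None:
            # Patch all positions
            hidden_states[:] = patch_activations
        else:
            # Patch specific position
            hidden_states[:, patch_position, :] = patch_activations[:, patch_position, :]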
        
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states
    
    # Register hook on target layer
    handle = model.transformer.h[patch_layer].register_forward_hook(patch_hook)
    
    # Run model
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Clean up
    handle.remove()
    
    return outputs.logits

# Test patching
clean_prompt = "The capital of France is"
corrupted_prompt = "The capital of Germany is"

# Get clean activations
_, clean_cache = run_with_cache(model, clean_prompt, layer_idx=6)

# Patch them into corrupted run
patched_logits = run_with_patch(
    model, 
    corrupted_prompt, 
    patch_layer=6, 
    patch_activations=clean_cache['layer_6']
)

print("Patching test successful!")
print(f"Patched output shape: {patched_logits.shape}")

### Test Causal Effect of Patching

In [None]:
def get_top_prediction(logits, top_k=5):
    """Get top k predictions from logits."""
    probs = torch.softmax(logits[0, -1, :], dim=0)
    top_probs, top_indices = torch.topk(probs, top_k)
    
    results = []
    for prob, idx in zip(top_probs, top_indices):
        token = tokenizer.decode([idx.item()])
        results.append((token, prob.item()))
    return results

# Compare outputs
clean_logits, _ = run_with_cache(model, clean_prompt)
corrupted_logits, _ = run_with_cache(model, corrupted_prompt)
patched_logits = run_with_patch(model, corrupted_prompt, 6, clean_cache['layer_6'])

print("Clean prompt: 'The capital of France is'")
print("Top predictions:")
for token, prob in get_top_prediction(clean_logits):
    print(f"  {token:15s}: {prob:.4f}")

print("\nCorrupted prompt: 'The capital of Germany is'")
print("Top predictions:")
for token, prob in get_top_prediction(corrupted_logits):
    print(f"  {token:15s}: {prob:.4f}")

print("\nPatched (Germany prompt + France layer 6):")
print("Top predictions:")
for token, prob in get_top_prediction(patched_logits):
    print(f"  {token:15s}: {prob:.4f}")

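**Question:** Does patching layer 6 transfer the clean answer? Note that with `patch_position=None` the patch overwrites the block output at every token position, which replaces the whole residual stream at that layer, so the patched run should reproduce the clean prediction no matter which layer you pick. Rerun with `patch_position` set to a single token index (for example, the country token) before drawing conclusions about layer 6's role.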
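
To see why position-specific patching is the informative case, here is a toy sketch in plain Python. `toy_forward` is a hypothetical stand-in for a transformer, not GPT-2: each position holds one float, and each "layer" adds half the mean of the current and earlier positions, which loosely imitates causal attention writing into the residual stream.

```python
def toy_forward(tokens, n_layers=4, patch=None):
    """Toy causal 'residual stream' with one float per position.

    patch: optional (layer, {position: value}) that overwrites the output
    of `layer` at the listed positions, mimicking activation patching.
    Returns (final value at the last position, list of per-layer states).
    """
    state = [float(t) for t in tokens]
    per_layer = []
    for layer in range(n_layers):
        # Each position adds half the mean of itself and all earlier positions.
        state = [
            state[pos] + 0.5 * sum(state[: pos + 1]) / (pos + 1)
            for pos in range(len(state))
        ]
        if patch is not None and patch[0] == layer:
            for pos, value in patch[1].items():
                state[pos] = value
        per_layer.append(list(state))
    return state[-1], per_layer


clean_tokens = [1, 2, 3]
corrupted_tokens = [1, 5, 3]  # only the middle "token" differs

clean_out, clean_states = toy_forward(clean_tokens)
corrupted_out, _ = toy_forward(corrupted_tokens)

layer = 1
# Patch every position: the whole state is replaced, so the output is clean.
full_patch = dict(enumerate(clean_states[layer]))
full_out, _ = toy_forward(corrupted_tokens, patch=(layer, full_patch))

# Patch only the last position: the corrupted middle position still leaks in.
last = len(clean_tokens) - 1
partial_out, _ = toy_forward(
    corrupted_tokens, patch=(layer, {last: clean_states[layer][last]})
)

print(f"clean={clean_out:.3f}  corrupted={corrupted_out:.3f}")
print(f"full patch={full_out:.3f}  partial patch={partial_out:.3f}")
```

The full patch matches the clean output exactly, while the single-position patch lands between the clean and corrupted outputs. Only the second kind of result can distinguish one layer or position from another.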

## Part 2: Noise vs Clean Patching

In [None]:
def noise_patch(model, prompt, patch_layer, noise_scale=1.0):
    """
    Patch with random noise (ablation).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    def noise_hook(module, input, output):
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output
        
        # Add noise
        noise = torch.randn_like(hidden_states) * noise_scale
        hidden_states = hidden_states + noise
        
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states
    
    handle = model.transformer.h[patch_layer].register_forward_hook(noise_hook)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    handle.remove()
    return outputs.logits

# Compare noise vs clean patching
original_logits, _ = run_with_cache(model, clean_prompt)
noise_logits = noise_patch(model, clean_prompt, patch_layer=6, noise_scale=0.1)

print("Original (no patching):")
for token, prob in get_top_prediction(original_logits)[:3]:
    print(f"  {token:15s}: {prob:.4f}")

print("\nNoise patched (layer 6):")
for token, prob in get_top_prediction(noise_logits)[:3]:
    print(f"  {token:15s}: {prob:.4f}")

### Exercise 2.1: Compare Strategies

Test both noise and clean patching across layers.

In [None]:
def compare_patching_strategies(clean_prompt, corrupted_prompt, target_token):
    """
    Compare noise vs clean patching across layers.
    """
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    results = {'layer': [], 'noise': [], 'clean': []}
    
    for layer in range(model.config.n_layer):
        # Clean activations
        _, clean_cache = run_with_cache(model, clean_prompt, layer_idx=layer)
        
        # Noise patching
        noise_logits = noise_patch(model, corrupted_prompt, layer, noise_scale=0.1)
        noise_prob = torch.softmax(noise_logits[0, -1, :], dim=0)[target_id].item()
        
        # Clean patching
        clean_patched_logits = run_with_patch(
            model, corrupted_prompt, layer, clean_cache[f'layer_{layer}']
        )
        clean_prob = torch.softmax(clean_patched_logits[0, -1, :], dim=0)[target_id].item()
        
        results['layer'].append(layer)
        results['noise'].append(noise_prob)
        results['clean'].append(clean_prob)
    
    return results

# Test
results = compare_patching_strategies(
    "The capital of France is",
    "The capital of Germany is",
    " Paris"
)

# Plot
plt.figure(figsize=(12, 5))
plt.plot(results['layer'], results['noise'], marker='o', label='Noise Patching')
plt.plot(results['layer'], results['clean'], marker='s', label='Clean Patching')
plt.xlabel('Layer')
plt.ylabel('P("Paris")')
plt.title('Patching Effect Across Layers')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

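**Question:** Which layers show the strongest clean patching effect? What does this suggest about where information is processed? As written, clean patching overwrites all token positions, so expect a nearly flat curve at the clean probability. Passing `patch_position` to `run_with_patch` restricts the patch to one token and gives a layer-dependent picture.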

## Part 3: Computing Causal Effects

Let's formalize our measurements as causal effects.

In [None]:
def compute_ace(model, prompt, patch_layer, clean_activations, target_token):
    """
    Compute Average Causal Effect of patching a layer.
    
    ACE = E[Y | do(X=clean)] - E[Y | do(X=corrupted)]
    """
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    # Corrupted (no patch)
    corrupted_logits, _ = run_with_cache(model, prompt)
    p_corrupted = torch.softmax(corrupted_logits[0, -1, :], dim=0)[target_id].item()
    
    # Clean (with patch)
    patched_logits = run_with_patch(model, prompt, patch_layer, clean_activations)
    p_patched = torch.softmax(patched_logits[0, -1, :], dim=0)[target_id].item()
    
    ace = p_patched - p_corrupted
    
    return ace, p_corrupted, p_patched

# Example
clean_prompt = "The capital of France is"
corrupted_prompt = "The capital of Germany is"
target = " Paris"

_, clean_cache = run_with_cache(model, clean_prompt, layer_idx=6)
ace, p_before, p_after = compute_ace(
    model, corrupted_prompt, 6, clean_cache['layer_6'], target
)

print(f"P(Paris | corrupted, no patch) = {p_before:.4f}")
print(f"P(Paris | corrupted, patch layer 6) = {p_after:.4f}")
print(f"\nAverage Causal Effect = {ace:.4f}")
direction = 'increases' if ace > 0 else 'decreases'
print(f"\nInterpretation: Patching layer 6 {direction} the probability of 'Paris' by {abs(ace):.4f}")

### Total Effect and Indirect Effect

In [None]:
def compute_total_effect(model, clean_prompt, corrupted_prompt, target_token):
    """
    Total effect of changing the prompt.
    """
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    clean_logits, _ = run_with_cache(model, clean_prompt)
    corrupted_logits, _ = run_with_cache(model, corrupted_prompt)
    
    p_clean = torch.softmax(clean_logits[0, -1, :], dim=0)[target_id].item()
    p_corrupted = torch.softmax(corrupted_logits[0, -1, :], dim=0)[target_id].item()
    
    total_effect = p_clean - p_corrupted
    
    return total_effect, p_clean, p_corrupted

# Compute
te, p_clean, p_corrupt = compute_total_effect(
    model, clean_prompt, corrupted_prompt, target
)

print(f"P(Paris | France) = {p_clean:.4f}")
print(f"P(Paris | Germany) = {p_corrupt:.4f}")
print(f"\nTotal Effect = {te:.4f}")
print(f"\nThis is the total causal effect of changing France→Germany")
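
The indirect effect of a patch is easier to compare across prompts when it is expressed as a share of the total effect. The sketch below assumes the three probabilities have already been measured as in the cells above. `fraction_recovered` is a hypothetical helper name, and the numbers in the usage line are illustrative placeholders rather than model outputs.

```python
import math


def fraction_recovered(p_patched, p_corrupted, p_clean):
    """Share of the total effect that a patch restores.

    indirect effect = p_patched - p_corrupted
    total effect    = p_clean   - p_corrupted
    Returns NaN when the total effect is zero (nothing to recover).
    """
    total_effect = p_clean - p_corrupted
    if total_effect == 0:
        return math.nan
    return (p_patched - p_corrupted) / total_effect


# Illustrative placeholder probabilities, not measured values
print(fraction_recovered(p_patched=0.30, p_corrupted=0.10, p_clean=0.50))
```

A value near 1 means the patched component carries most of the difference between the clean and corrupted runs. A value near 0 means the patch barely moves the output.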

## Part 4: ROME-Style Causal Tracing

Replicate ROME's methodology at a small scale.

In [None]:
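def causal_trace(model, clean_prompt, target_token, position=-1, noise_scale=0.5):
    """
    Simplified ROME-style causal tracing: corrupt the token embeddings with
    Gaussian noise, then restore the clean block output one layer at a time.

    Simplifications in this exercise: ROME corrupts only the subject tokens and
    sweeps every (layer, position) pair; here every token embedding is noised
    and only `position` is restored (try the last subject token as well).
    ROME scales its noise relative to the embedding standard deviation;
    `noise_scale=0.5` is an arbitrary starting point.
    """
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    inputs = tokenizer(clean_prompt, return_tensors="pt").to(device)
    
    # Clean run: cache the output of every layer
    _, clean_cache = run_with_cache(model, clean_prompt)
    
    # Sample the embedding noise once so every run sees the same corruption
    seq_len = inputs["input_ids"].shape[1]
    noise = noise_scale * torch.randn(1, seq_len, model.config.n_embd, device=device)
    
    def run_corrupted(restore_layer=None):
        """Run with noised embeddings, optionally restoring one clean state."""
        def noise_hook(module, input, output):
            return output + noise
        
        def restore_hook(module, input, output):
            hidden_states = output[0] if isinstance(output, tuple) else output
            clean_state = clean_cache[f"layer_{restore_layer}"]
            hidden_states[:, position, :] = clean_state[:, position, :]
            if isinstance(output, tuple):
                return (hidden_states,) + output[1:]
            return hidden_states
        
        handles = [model.transformer.wte.register_forward_hook(noise_hook)]
        if restore_layer is not None:
            handles.append(
                model.transformer.h[restore_layer].register_forward_hook(restore_hook)
            )
        with torch.no_grad():
            logits = model(**inputs).logits
        for handle in handles:
            handle.remove()
        return torch.softmax(logits[0, -1, :], dim=0)[target_id].item()
    
    # Baseline: fully corrupted run with nothing restored
    p_corrupted = run_corrupted()
    
    results = []
    for layer in range(model.config.n_layer):
        p_restored = run_corrupted(restore_layer=layer)
        results.append({
            'layer': layer,
            'restoration': p_restored,
            'ace': p_restored - p_corrupted  # indirect effect of restoring this layer
        })
    
    return results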

# Run causal trace
trace_results = causal_trace(model, "The Eiffel Tower is located in", " Paris")

# Visualize
layers = [r['layer'] for r in trace_results]
restoration = [r['restoration'] for r in trace_results]

plt.figure(figsize=(12, 5))
plt.plot(layers, restoration, marker='o', linewidth=2)
plt.xlabel('Layer')
plt.ylabel('P(correct answer)')
plt.title('ROME-style Causal Trace: Which layers are critical?')
plt.grid(True, alpha=0.3)
plt.axhline(y=max(restoration), color='r', linestyle='--', alpha=0.5, label='Peak performance')
plt.legend()
plt.show()

# Find critical layers
top_layers = sorted(trace_results, key=lambda x: x['restoration'], reverse=True)[:3]
print("\nTop 3 most important layers:")
for r in top_layers:
    print(f"  Layer {r['layer']}: P(Paris) = {r['restoration']:.4f}")

**Question:** Which layers show the highest restoration? How does this compare to ROME's findings about middle layers?

## Part 5: Gradient-Based Attribution

Use gradients to efficiently estimate importance.

In [None]:
def gradient_attribution(model, clean_prompt, corrupted_prompt, target_token):
    """
    Compute gradient-based attribution for each layer.
    
    Attribution ≈ gradient × (clean - corrupted)
    """
    target_id = tokenizer.encode(target_token, add_special_tokens=False)[0]
    
    # Get clean and corrupted activations
    _, clean_cache = run_with_cache(model, clean_prompt)
    _, corrupted_cache = run_with_cache(model, corrupted_prompt)
    
    # Run corrupted with gradients enabled
    model.zero_grad()
    inputs = tokenizer(corrupted_prompt, return_tensors="pt").to(device)
    
    activations = {}
    
    def save_activation_with_grad(name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                act = output[0]
            else:
                act = output
            act.requires_grad_(True)
            act.retain_grad()
            activations[name] = act
        return hook
    
    # Register hooks
    handles = []
    for i in range(model.config.n_layer):
        handle = model.transformer.h[i].register_forward_hook(
            save_activation_with_grad(f"layer_{i}")
        )
        handles.append(handle)
    
    # Forward pass
    outputs = model(**inputs)
    logits = outputs.logits
    
    # Backpropagate from the target-token logit itself (not its negative), so a
    # positive attribution means the clean activation should raise the logit
    target_logit = logits[0, -1, target_id]
    target_logit.backward()
    
    # Compute attributions
    attributions = {}
    for i in range(model.config.n_layer):
        grad = activations[f"layer_{i}"].grad
        delta = clean_cache[f"layer_{i}"] - corrupted_cache[f"layer_{i}"]
        
        # Attribution = grad · delta (element-wise product, then sum).
        # Both tensors already live on `device`, so no .cpu() call is needed.
        attribution = (grad * delta).sum().item()
        attributions[i] = attribution
    
    # Clean up
    for handle in handles:
        handle.remove()
    
    return attributions

# Compute attributions
attrs = gradient_attribution(model, clean_prompt, corrupted_prompt, target)

# Visualize
plt.figure(figsize=(12, 5))
plt.bar(range(len(attrs)), [attrs[i] for i in range(len(attrs))])
plt.xlabel('Layer')
plt.ylabel('Attribution Score')
plt.title('Gradient-Based Attribution: Which layers matter most?')
plt.grid(True, alpha=0.3, axis='y')
plt.show()

# Top attributed layers
top_attrs = sorted(attrs.items(), key=lambda x: abs(x[1]), reverse=True)[:5]
print("\nTop 5 attributed layers:")
for layer, score in top_attrs:
    print(f"  Layer {layer}: {score:+.4f}")

**Question:** Do gradient attributions match the actual patching results from earlier? Why might they differ?
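
One reason for a mismatch is that gradient × (clean − corrupted) is a first-order Taylor approximation of the patching effect, so it degrades when the activation change is large or the function is strongly nonlinear. The sketch below illustrates this on a one-dimensional toy function. `tanh` is an assumed stand-in for the model, chosen only because its derivative is easy to write down.

```python
import math


def f(x):
    """Toy nonlinear 'model output' as a function of a single activation."""
    return math.tanh(x)


def grad_f(x):
    """Analytic derivative of tanh."""
    return 1.0 - math.tanh(x) ** 2


def compare(x_corrupted, x_clean):
    """Return (first-order attribution, true patching effect)."""
    attribution = grad_f(x_corrupted) * (x_clean - x_corrupted)
    patching_effect = f(x_clean) - f(x_corrupted)
    return attribution, patching_effect


for x_clean in (0.1, 2.0):
    attr, effect = compare(x_corrupted=0.0, x_clean=x_clean)
    print(f"delta={x_clean:.1f}: attribution={attr:.3f}, patching effect={effect:.3f}")
```

For the small change the two numbers nearly agree. For the large change the attribution overshoots, because the function saturates while the linear estimate keeps growing.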

## Part 6: Average Indirect Effect (AIE)

Systematically measure causal importance across components.

In [None]:
def compute_aie(model, clean_examples, corrupted_examples, target_tokens):
    """
    Compute Average Indirect Effect across multiple examples.
    
    AIE_layer = mean over examples of: P(correct | patch layer) - P(correct | no patch)
    """
    n_layers = model.config.n_layer
    aie_scores = {layer: [] for layer in range(n_layers)}
    
    for clean_prompt, corrupted_prompt, target in zip(clean_examples, corrupted_examples, target_tokens):
        target_id = tokenizer.encode(target, add_special_tokens=False)[0]
        
        # Baseline (corrupted, no patch)
        baseline_logits, _ = run_with_cache(model, corrupted_prompt)
        p_baseline = torch.softmax(baseline_logits[0, -1, :], dim=0)[target_id].item()
        
        # Test each layer
        for layer in range(n_layers):
            # Get clean activations
            _, clean_cache = run_with_cache(model, clean_prompt, layer_idx=layer)
            
            # Patch and measure
            patched_logits = run_with_patch(
                model, corrupted_prompt, layer, clean_cache[f'layer_{layer}']
            )
            p_patched = torch.softmax(patched_logits[0, -1, :], dim=0)[target_id].item()
            
            # Indirect effect for this example
            ie = p_patched - p_baseline
            aie_scores[layer].append(ie)
    
    # Average across examples
    aie = {layer: np.mean(scores) for layer, scores in aie_scores.items()}
    
    return aie

# Test AIE on multiple examples
clean_examples = [
    "The capital of France is",
    "The Eiffel Tower is located in",
    "Paris is the capital of",
]

corrupted_examples = [
    "The capital of Germany is",
    "The Eiffel Tower is located in",  # Same (as a control)
    "Berlin is the capital of",
]

targets = [" Paris", " Paris", " France"]

aie_results = compute_aie(model, clean_examples, corrupted_examples, targets)

# Visualize
plt.figure(figsize=(12, 5))
plt.bar(range(len(aie_results)), [aie_results[i] for i in range(len(aie_results))],
       color=['red' if v < 0 else 'green' for v in aie_results.values()])
plt.xlabel('Layer')
plt.ylabel('Average Indirect Effect')
plt.title('AIE Across Layers: Which mediate the causal effect?')
plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
plt.grid(True, alpha=0.3, axis='y')
plt.show()

# Top AIE layers
top_aie = sorted(aie_results.items(), key=lambda x: x[1], reverse=True)[:5]
print("\nTop 5 layers by AIE:")
for layer, score in top_aie:
    print(f"  Layer {layer}: AIE = {score:+.4f}")
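
With only a handful of examples, a mean can hide a lot of spread, so it may help to report a standard error next to each layer's AIE. The following is a minimal sketch using only the standard library. `summarize_effects` is a hypothetical helper, and the values are illustrative placeholders rather than results from the model.

```python
import math
import statistics


def summarize_effects(effects):
    """Return (mean, standard error) of per-example indirect effects."""
    mean = statistics.mean(effects)
    if len(effects) < 2:
        return mean, math.nan  # standard error is undefined for one example
    sem = statistics.stdev(effects) / math.sqrt(len(effects))
    return mean, sem


# Illustrative placeholder values for one layer, not measured results.
# The middle entry mimics a control pair whose clean and corrupted prompts match.
per_example_ie = [0.20, 0.00, 0.10]
mean, sem = summarize_effects(per_example_ie)
print(f"AIE = {mean:.3f} ± {sem:.3f}")
```

Control pairs with identical clean and corrupted prompts contribute an indirect effect of zero, which pulls the average down. Consider reporting them separately from the real counterfactual pairs.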

## Part 7: Function Vectors

Extract vectors that encode specific functions.

In [None]:
def extract_function_vector(model, positive_examples, negative_examples, layer_idx=-1):
    """
    Extract a function vector from contrastive pairs.
    
    Similar to steering vectors, but for specific functions.
    """
    pos_activations = []
    neg_activations = []
    
    for pos_prompt, neg_prompt in zip(positive_examples, negative_examples):
        # Positive
        _, pos_cache = run_with_cache(model, pos_prompt, layer_idx=layer_idx)
        pos_activations.append(pos_cache[f'layer_{layer_idx}'][0, -1, :].cpu())
        
        # Negative  
        _, neg_cache = run_with_cache(model, neg_prompt, layer_idx=layer_idx)
        neg_activations.append(neg_cache[f'layer_{layer_idx}'][0, -1, :].cpu())
    
    # Mean difference
    pos_mean = torch.stack(pos_activations).mean(dim=0)
    neg_mean = torch.stack(neg_activations).mean(dim=0)
    
    function_vector = pos_mean - neg_mean
    
    return function_vector

# Example: Extract "comparison reversal" function
comparison_pairs = [
    ("Paris is larger than Lyon", "Lyon is larger than Paris"),
    ("Ten is greater than five", "Five is greater than ten"),
    ("The elephant is bigger than the mouse", "The mouse is bigger than the elephant"),
]

pos_examples = [p[0] for p in comparison_pairs]
neg_examples = [p[1] for p in comparison_pairs]

comparison_vector = extract_function_vector(model, pos_examples, neg_examples, layer_idx=8)

print(f"Extracted comparison function vector")
print(f"Shape: {comparison_vector.shape}")
print(f"Magnitude: {comparison_vector.norm():.4f}")

### Test Function Vector

In [None]:
def apply_function_vector(model, prompt, function_vector, layer_idx, alpha=1.0):
    """
    Apply a function vector to a prompt.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    function_vector = function_vector.to(device)
    
    def function_hook(module, input, output):
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output
        
        # Add function vector to last position
        hidden_states[0, -1, :] += alpha * function_vector
        
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states
    
    handle = model.transformer.h[layer_idx].register_forward_hook(function_hook)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    handle.remove()
    
    return outputs.logits

# Test on new example
test_prompt = "A mountain is taller than a"

print("Original predictions:")
original_logits, _ = run_with_cache(model, test_prompt)
for token, prob in get_top_prediction(original_logits)[:5]:
    print(f"  {token:15s}: {prob:.4f}")

print("\nWith comparison reversal function (+):")
reversed_logits = apply_function_vector(model, test_prompt, comparison_vector, 8, alpha=2.0)
for token, prob in get_top_prediction(reversed_logits)[:5]:
    print(f"  {token:15s}: {prob:.4f}")
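
A mean-difference vector is only meaningful if the individual pairs point in roughly the same direction. One quick sanity check, sketched here with plain Python lists in place of tensors, is the cosine similarity between each pair's difference and the mean difference. The helper names and the 2-D toy activations are assumptions for illustration only.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def direction_consistency(pos_vectors, neg_vectors):
    """Return (mean difference vector, cosine of each pair's difference to it)."""
    diffs = [
        [p - n for p, n in zip(pos, neg)]
        for pos, neg in zip(pos_vectors, neg_vectors)
    ]
    dim = len(diffs[0])
    mean_diff = [sum(d[i] for d in diffs) / len(diffs) for i in range(dim)]
    return mean_diff, [cosine(d, mean_diff) for d in diffs]


# Toy 2-D "activations" (placeholders, not model outputs)
pos = [[1.0, 0.0], [2.0, 1.0], [0.0, 1.0]]
neg = [[0.0, 0.0], [1.0, 1.0], [0.0, 2.0]]

mean_diff, sims = direction_consistency(pos, neg)
print("mean difference:", [round(x, 3) for x in mean_diff])
print("per-pair cosine:", [round(s, 3) for s in sims])
```

Pairs with low or negative cosine are candidates for removal or closer inspection, since they may be encoding something other than the intended function.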

### Exercise 7.1: Extract Your Own Function Vector

Design pairs for a function relevant to your concept.

In [None]:
# TODO: Define your function pairs
my_function_pairs = [
    # ("with function", "without function"),
]

# Extract and test your function vector

## Part 8: Designing Counterfactual Datasets

Practice creating effective minimal pairs.

In [None]:
def validate_counterfactual_pair(model, pair, expected_difference="outputs should differ"):
    """
    Check if a counterfactual pair produces different outputs.
    """
    prompt1, prompt2 = pair
    
    logits1, _ = run_with_cache(model, prompt1)
    logits2, _ = run_with_cache(model, prompt2)
    
    pred1 = get_top_prediction(logits1, top_k=1)[0]
    pred2 = get_top_prediction(logits2, top_k=1)[0]
    
    # Check structural similarity
    tokens1 = tokenizer.tokenize(prompt1)
    tokens2 = tokenizer.tokenize(prompt2)
    length_diff = abs(len(tokens1) - len(tokens2))
    
    print(f"Pair validation:")
    print(f"  Prompt 1: '{prompt1}'")
    print(f"    → {pred1[0]} ({pred1[1]:.4f})")
    print(f"  Prompt 2: '{prompt2}'")
    print(f"    → {pred2[0]} ({pred2[1]:.4f})")
    print(f"\n  Length difference: {length_diff} tokens")
    print(f"  Predictions differ: {pred1[0] != pred2[0]}")
    
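    # Good pair criteria: equal token length (so positions align for patching)
    # and different top predictions
    same_length = length_diff == 0
    differs = pred1[0] != pred2[0]
    if not same_length:
        print("  ⚠️  Warning: Token lengths differ - positions will not align for patching")
    if not differs:
        print("  ⚠️  Warning: Same prediction - pair may not test your hypothesis")
    if same_length and differs:
        print("  ✓ Good pair: minimal change, different predictions")
    
    return differs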

# Test examples
good_pair = ("The capital of France is", "The capital of Germany is")
bad_pair = ("France's capital is", "The capital of Germany is")

print("GOOD PAIR:")
print("="*60)
validate_counterfactual_pair(model, good_pair)

print("\n\nBAD PAIR (structure differs):")
print("="*60)
validate_counterfactual_pair(model, bad_pair)
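
Beyond matching length, a minimal pair should differ in as few token positions as possible. The sketch below lists the positions where two token sequences differ. `differing_positions` is a hypothetical helper, and whitespace splitting stands in for `tokenizer.tokenize` so the example runs without a model.

```python
def differing_positions(tokens1, tokens2):
    """Return the indices where two equal-length token lists differ.

    Raises ValueError if the lengths differ, since positions would not align.
    """
    if len(tokens1) != len(tokens2):
        raise ValueError(
            f"Length mismatch: {len(tokens1)} vs {len(tokens2)} tokens"
        )
    return [i for i, (a, b) in enumerate(zip(tokens1, tokens2)) if a != b]


# Whitespace splitting is an assumption standing in for the real tokenizer
clean = "The capital of France is".split()
corrupted = "The capital of Germany is".split()
print(differing_positions(clean, corrupted))  # [3]
```

The returned indices are also natural choices for `patch_position` when patching a single token.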

### Exercise 8.1: Design Dataset for Your Concept

Create and validate 10-15 pairs for your project.

In [None]:
# TODO: Design your counterfactual pairs
my_pairs = [
    # ("with concept", "without concept"),
]

# Validate each pair
# for pair in my_pairs:
#     validate_counterfactual_pair(model, pair)
#     print("\n" + "="*60 + "\n")

## Reflection Questions

Answer these for your project writeup:

1. **Patching Strategy**: When should you use noise patching vs clean patching for your concept?

2. **Localization**: Which layers are most important for your concept? How does this compare to ROME (middle MLPs) or entity tracking (attention heads)?

3. **Attribution**: Do gradient attributions match actual patching results? What might explain discrepancies?

4. **AIE**: Which components show high AIE? What does this tell you about the causal pathway?

5. **Function Vectors**: Can your concept be captured as a function vector? Why or why not?

6. **Counterfactuals**: What makes a good counterfactual pair for your concept? What confounds did you have to avoid?

## Next Steps

For your assignment:
1. Design comprehensive counterfactual dataset (15-25 pairs)
2. Run AIE analysis to localize important layers/components
3. Apply gradient attribution for fine-grained identification
4. Validate top findings with direct patching
5. Extract function vectors if applicable
6. Build mechanistic story of how your concept is processed

These causal methods will be essential for finding complete circuits in Week 5!