# Week 8: Causal Abstraction and Interchange Intervention Analysis

This notebook provides hands-on exercises for understanding and applying causal abstraction, Interchange Intervention Analysis (IIA), and Distributed Alignment Search (DAS) to validate interpretability findings.

**Learning objectives:**
1. Build causal models for algorithmic and linguistic tasks
2. Implement Interchange Intervention Analysis (IIA)
3. Compare IIA with simple activation patching
4. Apply IIA to validate circuit hypotheses
5. Implement Distributed Alignment Search (DAS)
6. Validate SAE features using causal abstraction
7. Test probe findings with intervention analysis

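We'll work through two main examples, then apply the same tools to earlier findings:
- **Addition** (Parts 1-3): a simple algorithmic task to build intuition
- **Subject-verb agreement** (Parts 4-7): a real NLP task demonstrating the full pipeline
- **Validating earlier findings** (Parts 8-10): circuits (Week 5), SAE features (Week 7), and your own project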

In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Callable
from dataclasses import dataclass
import itertools

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Part 1: Building a Causal Model for Addition

We'll start with a simple task: computing `(a + b) + c` for three numbers.

**High-level causal model:**
```
Input₁, Input₂ → S₁ = Input₁ + Input₂
S₁, Input₃ → S₂ = S₁ + Input₃
S₂ → Output
```

**Exercise 1.1:** Implement the high-level causal model.

In [None]:
@dataclass
class AdditionCausalModel:
    """High-level causal model for (a + b) + c."""
    input1: int
    input2: int
    input3: int
    
    def compute_s1(self) -> int:
        """First intermediate sum: S₁ = Input₁ + Input₂."""
        # TODO: Implement
        raise NotImplementedError("Implement computation of S₁")
    
    def compute_s2(self, s1: int) -> int:
        """Second intermediate sum: S₂ = S₁ + Input₃."""
        # TODO: Implement
        raise NotImplementedError("Implement computation of S₂")
    
    def forward(self) -> Dict[str, int]:
        """Run the full causal model, returning all intermediate variables."""
        # TODO: Implement
        # Should return {'S1': ..., 'S2': ..., 'Output': ...}
        raise NotImplementedError("Implement forward pass")

# Test your implementation
causal_model = AdditionCausalModel(input1=5, input2=3, input3=2)  # `model` is reused for GPT-2 in Part 5
result = causal_model.forward()
print(f"Causal model result: {result}")
assert result['Output'] == 10, "Expected (5 + 3) + 2 = 10"

## Part 2: Training a Neural Network for Addition

Now we'll train a small neural network to perform this addition task, and identify which layer might correspond to the intermediate variable S₁.

**Exercise 2.1:** Implement and train a simple feedforward network.

In [None]:
class AdditionNetwork(nn.Module):
    """Neural network for computing (a + b) + c."""
    
    def __init__(self, hidden_size=64):
        super().__init__()
        # Input: 3 numbers, Output: 1 number
        self.layer1 = nn.Linear(3, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, hidden_size)  # Candidate for S₁
        self.layer4 = nn.Linear(hidden_size, 1)
        self.activations = {}  # Store intermediate activations
    
    def forward(self, x):
        h1 = torch.relu(self.layer1(x))
        self.activations['L1'] = h1
        
        h2 = torch.relu(self.layer2(h1))
        self.activations['L2'] = h2
        
        h3 = torch.relu(self.layer3(h2))
        self.activations['L3'] = h3  # Hypothesize this represents S₁
        
        output = self.layer4(h3)
        self.activations['Output'] = output
        
        return output

def generate_addition_dataset(n_samples=1000, max_val=20):
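    """Generate random addition problems.

    Returns X_train, y_train, X_test, y_test as float tensors: each row of X is
    (a, b, c) and the matching y is [(a + b) + c], shape (n, 1) like the network output.
    """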
    # TODO: Generate random (a, b, c) and compute (a + b) + c
    raise NotImplementedError("Implement dataset generation")

def train_addition_network(model, X_train, y_train, epochs=100, lr=0.001):
    """Train the addition network."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    for epoch in range(epochs):
        # TODO: Implement training loop
        raise NotImplementedError("Implement training loop")
    
    return model

# Generate data and train
# X_train, y_train, X_test, y_test = generate_addition_dataset()
# net = AdditionNetwork()
# net = train_addition_network(net, X_train, y_train)

## Part 3: Implementing Interchange Intervention Analysis (IIA)

Now we test whether layer L3 truly represents the intermediate sum S₁ using IIA.

**IIA Procedure:**
1. Run base input through high-level model → get S₁(base)
2. Run source input through high-level model → get S₁(source)
3. Run base input through neural network → get h₃(base)
4. Run source input through neural network → get h₃(source)
5. **Intervene:** Replace h₃(base) with h₃(source) in the base run and continue the forward pass
6. **Check:** Does the patched network's output match the high-level model's output when S₁(base) is replaced by S₁(source)?
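Step 6 needs the high-level model's prediction under the intervention do(S₁ = S₁(source)). The sketch below uses a hypothetical helper `high_level_addition` (a plain function, separate from the `AdditionCausalModel` exercise) that accepts an optional override for S₁. The example pair is chosen so that S₁ differs between base and source, and so that the counterfactual output differs from both the base and the source outputs; otherwise the test cannot tell a correct alignment from a trivial one.

```python
from typing import Optional, Tuple


def high_level_addition(a: int, b: int, c: int, s1_override: Optional[int] = None) -> dict:
    """Run the (a + b) + c causal model, optionally applying do(S1 = s1_override)."""
    s1 = a + b if s1_override is None else s1_override
    s2 = s1 + c
    return {"S1": s1, "S2": s2, "Output": s2}


def counterfactual_output(base: Tuple[int, int, int], source: Tuple[int, int, int]) -> int:
    """Output the high-level model predicts when S1(base) is replaced by S1(source)."""
    s1_source = high_level_addition(*source)["S1"]
    return high_level_addition(*base, s1_override=s1_source)["Output"]


base, source = (5, 3, 2), (7, 4, 6)        # S1 = 8 vs 11; outputs 10 vs 17
print(counterfactual_output(base, source))  # 11 + 2 = 13
```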

**Exercise 3.1:** Implement IIA for the addition task.

In [None]:
def interchange_intervention(
    neural_net: nn.Module,
    causal_model_class: type,
    base_input: Tuple[int, int, int],
    source_input: Tuple[int, int, int],
    intervention_var: str = 'S1',  # Variable to intervene on
    alignment_layer: str = 'L3'     # Neural layer aligned to that variable
) -> Dict[str, float]:
    """
    Perform Interchange Intervention Analysis.
    
    Returns:
        - 'iia_output': Neural network output after intervention
        - 'expected_output': What causal model predicts
        - 'error': Absolute difference
        - 'base_output': Original neural network output (no intervention)
    """
    # TODO: Implement IIA
    # 1. Get S₁(source) from high-level causal model
    # 2. Get h₃(base) and h₃(source) from neural network
    # 3. Intervene: replace h₃(base) with h₃(source), continue forward pass
    # 4. Compare with expected output from causal model with intervened S₁
    raise NotImplementedError("Implement IIA")

# Test IIA
# base = (5, 3, 2)    # S₁ = 8,  output (5+3)+2 = 10
# source = (7, 4, 6)  # S₁ = 11, output (7+4)+6 = 17
# Counterfactual under do(S₁ = 11): 11 + Input₃(base) = 11 + 2 = 13
# result = interchange_intervention(net, AdditionCausalModel, base, source)
# print(f"IIA result: {result}")

**Exercise 3.2:** Compare IIA with simple activation patching.

Simple patching just replaces activations without checking if the intervention matches a causal model prediction.
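The mechanics of patching can be sketched with PyTorch forward hooks: a hook registered with `register_forward_hook` that returns a tensor replaces the module's output. The network here is a hypothetical, untrained `nn.Sequential` stand-in for `AdditionNetwork`, so the numbers themselves are meaningless; only the mechanics matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network standing in for AdditionNetwork (untrained).
toy = nn.Sequential(
    nn.Linear(3, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),  # toy[3]: output of the last hidden layer
    nn.Linear(16, 1),
)
patch_site = toy[3]


def run_with_patch(net, x, site, patched_value=None):
    """Run net on x, caching site's output; optionally replace it with patched_value."""
    cache = {}

    def hook(module, inputs, output):
        cache["act"] = output.detach().clone()  # the unpatched activation
        return patched_value  # a tensor replaces the output; None leaves it unchanged

    handle = site.register_forward_hook(hook)
    try:
        out = net(x)
    finally:
        handle.remove()
    return out, cache["act"]


base_x = torch.tensor([[5.0, 3.0, 2.0]])
source_x = torch.tensor([[7.0, 4.0, 6.0]])

with torch.no_grad():
    base_out, _ = run_with_patch(toy, base_x, patch_site)
    source_out, source_act = run_with_patch(toy, source_x, patch_site)
    patched_out, _ = run_with_patch(toy, base_x, patch_site, patched_value=source_act)

output_change = (patched_out - base_out).item()
print(f"output change from patching: {output_change:.4f}")
# The last hidden layer feeds the output head directly, so swapping all of it
# reproduces the source run's output.
print(torch.allclose(patched_out, source_out))  # True
```

That last line is worth keeping in mind for Exercise 3.1: a full swap of the last hidden layer carries everything the output depends on, including Input₃. For this pair, a well-trained network patched this way would output roughly 17, while the high-level counterfactual is 13, so IIA would score the whole-layer alignment as a failure even though simple patching shows a large effect.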

In [None]:
def simple_activation_patching(
    neural_net: nn.Module,
    base_input: Tuple[int, int, int],
    source_input: Tuple[int, int, int],
    patch_layer: str = 'L3'
) -> Dict[str, float]:
    """
    Simple activation patching (no causal model).
    Just patches and sees what happens.
    """
    # TODO: Implement simple patching
    # 1. Run base input, store activations
    # 2. Run source input, get activation at patch_layer
    # 3. Re-run base with patched activation
    # 4. Return {'base_output': ..., 'patched_output': ..., 'output_change': patched - base}
    raise NotImplementedError("Implement simple patching")

# Compare both methods
# iia_res = interchange_intervention(net, AdditionCausalModel, (5,3,2), (7,4,6))
# patch_res = simple_activation_patching(net, (5,3,2), (7,4,6))
# print(f"IIA error: {iia_res['error']:.4f}")
# print(f"Patch output change: {patch_res['output_change']:.4f}")
# print("IIA provides interpretable error w.r.t. causal model, patching just shows effect")

## Part 4: Subject-Verb Agreement - Building the Causal Model

Now we move to a real NLP task: subject-verb agreement in sentences like:
- "The **key** to the cabinets **is** on the table."

**High-level causal model:**
```
Subject_Number (singular/plural) → Verb_Prediction (is/are)
```

The distractor noun ("cabinets") has no arrow into Verb_Prediction in this model, so it should NOT affect the prediction.

**Exercise 4.1:** Implement the linguistic causal model.

In [None]:
@dataclass
class SubjectVerbCausalModel:
    """High-level causal model for subject-verb agreement."""
    subject_number: str  # 'singular' or 'plural'
    
    def predict_verb(self) -> str:
        """Predict correct verb form based on subject number."""
        # TODO: Implement
        # 'singular' → 'is', 'plural' → 'are'
        raise NotImplementedError("Implement verb prediction")

def generate_agreement_sentences(n_samples=100):
    """
    Generate subject-verb agreement sentences.
    
    Template: "The [SUBJECT] to the [DISTRACTOR] [VERB] on the table."
    
    Returns:
        List of dicts with keys: 'sentence', 'subject_number', 'correct_verb'
    """
    subjects = {
        'singular': ['key', 'book', 'cat', 'student', 'teacher'],
        'plural': ['keys', 'books', 'cats', 'students', 'teachers']
    }
    distractors = {
        'singular': ['cabinet', 'shelf', 'box', 'room', 'desk'],
        'plural': ['cabinets', 'shelves', 'boxes', 'rooms', 'desks']
    }
    
    # TODO: Generate sentences with varying subject/distractor combinations
    # Critical: Include cases where distractor has OPPOSITE number from subject
    raise NotImplementedError("Implement sentence generation")

# Generate dataset
# sentences = generate_agreement_sentences()
# print(f"Example sentences:")
# for s in sentences[:5]:
#     print(f"  {s['sentence']} → {s['correct_verb']} (subject: {s['subject_number']})")

## Part 5: Loading GPT-2 and Extracting Activations

We'll use a pre-trained GPT-2 model and extract activations from specific layers and attention heads.

**Exercise 5.1:** Implement activation extraction with hooks.

In [None]:
# Load model and tokenizer
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()

class ActivationExtractor:
    """Extract and store activations from specific layers/heads."""
    
    def __init__(self, model):
        self.model = model
        self.activations = {}
        self.hooks = []
    
    def register_hooks(self, layer_idx: int, head_indices: List[int] = None):
        """
        Register hooks to extract activations from specified layer and heads.
        
        Args:
            layer_idx: Which transformer layer (0-11 for GPT-2)
            head_indices: Which attention heads (None = all heads)
        """
        def hook_fn(module, input, output):
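            # TODO: Store attention outputs
            # Note: output[0] has shape (batch, seq_len, 768); the heads are already
            # merged and projected by attn.c_proj. For per-head values, hook the
            # *input* of attn.c_proj instead and reshape it to (batch, seq_len, 12, 64).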
            raise NotImplementedError("Implement hook function")
        
        # Register hook on attention layer
        layer = self.model.transformer.h[layer_idx].attn
        hook = layer.register_forward_hook(hook_fn)
        self.hooks.append(hook)
    
    def get_activation(self, sentence: str, position: int, layer: int, head: int):
        """
        Get activation for a specific position/layer/head.
        
        Args:
            sentence: Input sentence
            position: Token position (e.g., subject position)
            layer: Layer index
            head: Head index
        """
        # TODO: Tokenize, run model, extract activation
        raise NotImplementedError("Implement activation extraction")
    
    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

# Test extraction
# extractor = ActivationExtractor(model)
# extractor.register_hooks(layer_idx=5, head_indices=[2, 7, 11])
# act = extractor.get_activation("The key to the cabinets is", position=1, layer=5, head=2)
# print(f"Activation shape: {act.shape}")
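Because `attn` returns merged heads, per-head interventions are easier on the input of the attention output projection, which is the concatenation of the 12 head outputs (64 dims each in GPT-2 small). A minimal sketch, assuming those shapes: the hypothetical `patch_heads` copies selected heads at one position from a source run, and a forward pre-hook (whose returned tuple replaces the module's inputs) applies it. On the real model the hook would go on `model.transformer.h[layer].attn.c_proj`, a Hugging Face `Conv1D` that takes the same `(batch, seq_len, 768)` input; here an `nn.Linear` stands in so the example runs offline. Base and source sentences must tokenize to the same length for positions to line up.

```python
import torch
import torch.nn as nn

N_HEADS, HEAD_DIM = 12, 64  # GPT-2 small: 12 heads x 64 dims = 768
D_MODEL = N_HEADS * HEAD_DIM


def patch_heads(merged, source_merged, heads, position):
    """Copy selected heads at one token position from source_merged into merged.

    Both tensors: (batch, seq_len, n_heads * head_dim), the concatenated head
    outputs *before* the attention output projection.
    """
    b, t, _ = merged.shape
    out = merged.reshape(b, t, N_HEADS, HEAD_DIM).clone()
    src = source_merged.reshape(b, -1, N_HEADS, HEAD_DIM)
    out[:, position, heads, :] = src[:, position, heads, :]
    return out.reshape(b, t, D_MODEL)


def make_pre_hook(source_merged, heads, position):
    def pre_hook(module, args):
        # args holds the module's positional inputs; returning a tuple replaces them.
        return (patch_heads(args[0], source_merged, heads, position),)
    return pre_hook


# Stand-in for model.transformer.h[layer].attn.c_proj so the example runs offline.
torch.manual_seed(0)
c_proj = nn.Linear(D_MODEL, D_MODEL)
base_merged = torch.randn(1, 6, D_MODEL)
source_merged = torch.randn(1, 6, D_MODEL)

handle = c_proj.register_forward_pre_hook(make_pre_hook(source_merged, [2, 7, 11], position=1))
with torch.no_grad():
    patched_proj = c_proj(base_merged)
handle.remove()
```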

## Part 6: Implementing IIA for Subject-Verb Agreement

Now we test whether specific attention heads track subject number. Layer 5, Heads {2, 7, 11} are used below as a hypothetical example; substitute candidates from your own analysis.

**IIA for NLP:**
1. Base sentence: "The **key** to the cabinets" (the model should continue with "is")
2. Source sentence: "The **keys** to the cabinet" (subject and distractor numbers both flipped; expected continuation "are")
3. Extract activations from hypothesized heads
4. Intervene: Replace base head activations with source head activations
5. Check: Does the model now predict "are" instead of "is"?
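Reading off the verb prediction reduces to comparing two next-token logits at the final position. A minimal sketch with a hypothetical helper; in practice the logits would come from `model(input_ids).logits` and the ids from `tokenizer.encode(' is')` and `tokenizer.encode(' are')` (note the leading space; each should be a single GPT-2 token, which is worth asserting).

```python
import torch


def verb_prediction(logits: torch.Tensor, is_id: int, are_id: int):
    """Compare next-token logits for ' is' vs ' are' at the final position.

    logits: (batch, seq_len, vocab_size) from the LM head.
    Returns the preferred verb per batch item (ties count as 'are') and the logit difference.
    """
    last = logits[:, -1, :]
    diff = last[:, is_id] - last[:, are_id]  # > 0 means ' is' is preferred
    preds = ["is" if d > 0 else "are" for d in diff.tolist()]
    return preds, diff
```

Keeping the logit difference alongside the discrete prediction can reveal partial effects that a flip/no-flip criterion misses.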

**Exercise 6.1:** Implement IIA for linguistic task.

In [None]:
def linguistic_iia(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    base_sentence: str,
    source_sentence: str,
    intervention_layer: int,
    intervention_heads: List[int],
    subject_position: int
) -> Dict[str, object]:
    """
    Perform IIA for subject-verb agreement.
    
    Returns:
        - 'base_prediction': Verb predicted without intervention
        - 'intervened_prediction': Verb predicted after intervention
        - 'expected_change': Whether intervention should change prediction
        - 'intervention_success': Whether prediction changed as expected
    """
    # TODO: Implement linguistic IIA
    # 1. Get activations from base and source sentences
    # 2. Intervene on specified heads at subject position
    # 3. Check if verb prediction changes appropriately
    raise NotImplementedError("Implement linguistic IIA")

# Test IIA on example sentence pair
# base = "The key to the cabinets"
# source = "The keys to the cabinet"  # Subject number flipped
# result = linguistic_iia(
#     model, tokenizer, base, source,
#     intervention_layer=5,
#     intervention_heads=[2, 7, 11],
#     subject_position=1  # Position of "key"/"keys"
# )
# print(f"IIA result: {result}")

**Exercise 6.2:** Compute IIA accuracy across many sentence pairs.

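A good alignment should have high IIA accuracy: interventions should consistently produce the output the high-level model predicts under the same intervention (here, the verb required by the source sentence's subject).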
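IIA accuracy is scored against the high-level model's counterfactual, not against the base sentence's label: for each pair, the causal model says which verb should appear once the source's subject number is swapped in. A minimal sketch of the scoring step with a hypothetical helper (the predictions below are made up for illustration, not model results):

```python
from typing import Sequence


def iia_accuracy(intervened_preds: Sequence[str], counterfactual_labels: Sequence[str]) -> float:
    """Fraction of interchange interventions whose output matches the causal model."""
    if len(intervened_preds) != len(counterfactual_labels):
        raise ValueError("need one counterfactual label per intervention")
    if not intervened_preds:
        raise ValueError("no interventions to score")
    hits = sum(p == y for p, y in zip(intervened_preds, counterfactual_labels))
    return hits / len(intervened_preds)


# The counterfactual label is the verb the *source* subject requires.
source_numbers = ["plural", "singular", "plural", "singular"]
labels = ["are" if n == "plural" else "is" for n in source_numbers]
preds = ["are", "is", "is", "is"]   # illustrative intervened predictions
print(iia_accuracy(preds, labels))  # 0.75
```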

In [None]:
def evaluate_iia_accuracy(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    sentence_pairs: List[Tuple[str, str]],
    layer: int,
    heads: List[int]
) -> float:
    """
    Evaluate IIA accuracy over many intervention pairs.
    
    Returns:
        Fraction of interventions that produce expected change
    """
    # TODO: Run IIA on all sentence pairs, compute accuracy
    raise NotImplementedError("Implement IIA evaluation")

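# Generate pairs and evaluate. Truncate each sentence before the verb so the
# model has to predict it.
# def prefix(s):
#     return s['sentence'].rsplit(f" {s['correct_verb']} ", 1)[0]
# pairs = [(prefix(s1), prefix(s2))
#          for s1, s2 in zip(sentences[::2], sentences[1::2])
#          if s1['subject_number'] != s2['subject_number']]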
# accuracy = evaluate_iia_accuracy(model, tokenizer, pairs, layer=5, heads=[2,7,11])
# print(f"IIA accuracy: {accuracy:.2%}")

## Part 7: Distributed Alignment Search (DAS)

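Instead of manually hypothesizing which heads track subject number, we can search for them automatically.

Full DAS (Geiger et al.) learns an orthogonal rotation of a hidden representation by gradient descent and performs interchange interventions on a low-dimensional subspace of the rotated space, so it can find distributed representations that do not line up with individual heads or neurons. This exercise uses a simpler discrete stand-in: a greedy search over (layer, head) components.

**Greedy search algorithm:**
1. Start with an empty set of components
2. For each candidate component (layer, head):
   - Add it to the set
   - Compute IIA accuracy
   - Keep it if accuracy improves
3. Stop after `max_components` components or when no candidate improves accuracy, and return the best-performing set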
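One way to implement this is best-first forward selection: each round tries every unselected (layer, head), adds the one with the highest IIA accuracy, and stops when nothing improves or `max_components` is reached. In the sketch, `score_fn` is a hypothetical stand-in for a closure over `evaluate_iia_accuracy`, and the toy scorer's values are invented so the search has something to find.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Component = Tuple[int, int]  # (layer, head)


def greedy_component_search(
    candidates: Sequence[Component],
    score_fn: Callable[[List[Component]], float],
    max_components: int = 5,
) -> Dict[str, object]:
    """Best-first forward selection of components by IIA accuracy."""
    selected: List[Component] = []
    best = score_fn(selected)  # accuracy with no components patched
    history = [best]
    while len(selected) < max_components:
        remaining = [c for c in candidates if c not in selected]
        if not remaining:
            break
        top_score, top_comp = max((score_fn(selected + [c]), c) for c in remaining)
        if top_score <= best:
            break  # no remaining candidate improves accuracy
        selected.append(top_comp)
        best = top_score
        history.append(best)
    return {"components": selected, "accuracy": best, "search_history": history}


# Toy scorer standing in for evaluate_iia_accuracy; the values are invented.
useful = {(5, 2): 0.5, (5, 7): 0.25}

def toy_score(components: List[Component]) -> float:
    return min(1.0, sum(useful.get(c, 0.0) for c in components))

grid = [(layer, head) for layer in range(12) for head in range(12)]
result = greedy_component_search(grid, toy_score, max_components=5)
print(result["components"], result["accuracy"])  # [(5, 2), (5, 7)] 0.75
```

Each round costs one IIA evaluation per remaining candidate (up to 144 for GPT-2 small), each over all sentence pairs, so caching source activations can help. Scoring the final set on held-out pairs avoids reporting an accuracy that the search itself optimized.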

**Exercise 7.1:** Implement greedy DAS.

In [None]:
def distributed_alignment_search(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    sentence_pairs: List[Tuple[str, str]],
    max_components: int = 5,
    n_layers: int = 12,
    n_heads: int = 12
) -> Dict[str, object]:
    """
    Greedy search for components aligned to subject number.
    
    Returns:
        - 'components': List of (layer, head) tuples
        - 'accuracy': IIA accuracy with these components
        - 'search_history': Accuracy at each step
    """
    selected = []
    best_accuracy = 0.0
    history = []
    
    # TODO: Implement greedy search
    # For each potential component:
    #   1. Try adding it to selected set
    #   2. Compute IIA accuracy
    #   3. Keep if it improves accuracy
    # Stop after max_components or when no improvement
    raise NotImplementedError("Implement DAS")
    
    return {
        'components': selected,
        'accuracy': best_accuracy,
        'search_history': history
    }

# Run DAS
# das_result = distributed_alignment_search(model, tokenizer, pairs, max_components=5)
# print(f"DAS found components: {das_result['components']}")
# print(f"Final IIA accuracy: {das_result['accuracy']:.2%}")
# 
# # Plot search progress
# plt.plot(das_result['search_history'])
# plt.xlabel('Component added')
# plt.ylabel('IIA accuracy')
# plt.title('DAS Search Progress')
# plt.show()
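For comparison with the greedy stand-in, the core intervention in rotation-based DAS can be sketched in a few lines: rotate base and source hidden states with an orthogonal matrix R, swap the first k rotated coordinates, and rotate back. This is a sketch under stated assumptions (one hidden vector per example, a fixed subspace size `k`, a hypothetical class name `RotatedInterchange`); it uses PyTorch's `torch.nn.utils.parametrizations.orthogonal` to keep R orthogonal.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal


class RotatedInterchange(nn.Module):
    """Swap the first k coordinates of a learned orthogonal basis (DAS-style)."""

    def __init__(self, d_model: int, k: int):
        super().__init__()
        self.k = k
        # orthogonal() re-parametrizes the weight so it stays orthogonal while training.
        self.rotate = orthogonal(nn.Linear(d_model, d_model, bias=False))

    def forward(self, base: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
        R = self.rotate.weight                       # (d_model, d_model), R @ R.T = I
        z_base, z_source = base @ R.T, source @ R.T  # coordinates in the rotated basis
        z_mixed = torch.cat([z_source[..., :self.k], z_base[..., self.k:]], dim=-1)
        return z_mixed @ R                           # rotate back: R^-1 = R^T


torch.manual_seed(0)
d = 8
base_h, source_h = torch.randn(4, d), torch.randn(4, d)
swap3 = RotatedInterchange(d, k=3)
mixed = swap3(base_h, source_h)
```

In training, the language model stays frozen and only R is optimized, for example with a cross-entropy loss between the intervened model's next-token distribution and the counterfactual label (here, the verb the source subject requires). With `k` equal to the full width, the intervention reduces to an ordinary full activation swap.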

## Part 8: Validating Circuit Findings from Week 5

Now let's use IIA to validate a circuit hypothesis from Week 5 (e.g., the induction head circuit).

**Exercise 8.1:** Test whether identified induction heads truly perform the induction operation using causal abstraction.

In [None]:
def validate_induction_circuit(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    induction_heads: List[Tuple[int, int]],  # (layer, head) pairs
    n_tests: int = 20
) -> Dict[str, float]:
    """
    Validate induction head circuit using IIA.
    
    Induction task: "A B ... A" → predict B
    
    Causal model:
        Previous_Token_Match → Next_Token_Prediction
    
    Returns:
        Dict with key 'accuracy': IIA accuracy for the induction heads
    """
    # TODO: Implement IIA for induction
    # 1. Generate base: "cat dog ... cat"
    # 2. Generate source: "bird fish ... bird"
    # 3. Intervene on induction heads
    # 4. Check if prediction changes from "dog" to "fish"
    raise NotImplementedError("Implement induction circuit validation")

# Test induction heads (example coordinates)
# induction_heads = [(5, 1), (5, 5), (6, 9)]  # Hypothetical
# result = validate_induction_circuit(model, tokenizer, induction_heads)
# print(f"Induction circuit IIA accuracy: {result['accuracy']:.2%}")
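Base and source prompts for the induction test can share the same filler, so the only differences are the A and B words. A minimal sketch with a hypothetical helper and word list; it assumes each word (with its leading space) is a single GPT-2 token, which should be checked with the tokenizer so that base and source positions line up. Targets are returned without the leading space; the model's next token would be, e.g., `' dog'`.

```python
import random
from typing import Dict, List


def make_induction_pair(words: List[str], n_filler: int, rng: random.Random) -> Dict[str, str]:
    """Build base/source prompts 'A B <filler> A' that share the same filler."""
    a1, b1, a2, b2 = rng.sample(words, 4)  # four distinct words
    pool = [w for w in words if w not in (a1, b1, a2, b2)]
    middle = " ".join(rng.sample(pool, n_filler))  # filler never contains A or B
    return {
        "base": f"{a1} {b1} {middle} {a1}",
        "source": f"{a2} {b2} {middle} {a2}",
        "base_target": b1,    # induction: after the repeated A, predict B
        "source_target": b2,  # counterfactual target if the heads carry the source's info
    }


rng = random.Random(0)
words = ["cat", "dog", "bird", "fish", "tree", "car", "book", "lamp", "door", "rock"]
pair = make_induction_pair(words, n_filler=4, rng=rng)
print(pair)
```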

## Part 9: Validating SAE Features from Week 7

We can use causal abstraction to test whether an SAE feature causally represents a concept.

**Exercise 9.1:** Validate an SAE feature using intervention.

In [None]:
def validate_sae_feature(
    model: GPT2LMHeadModel,
    sae_encoder: nn.Module,  # Trained SAE from Week 7
    sae_decoder: nn.Module,
    feature_idx: int,
    concept_examples: List[Tuple[str, str]],  # (has_concept, lacks_concept) pairs
    layer: int
) -> Dict[str, float]:
    """
    Test if SAE feature causally represents a concept.
    
    Causal model:
        Concept_Present → Feature_Active → Behavior_Change
    
    Procedure:
        1. Find examples where feature activates (concept present)
        2. Find examples where feature doesn't activate (concept absent)
        3. Intervene: set the feature's activation to its value on the
           paired example (concept present <-> concept absent)
        4. Check if behavior changes as expected
    
    Returns:
        Dict with key 'causal_effect': how much the intervention changes the output
    """
    # TODO: Implement SAE feature validation
    # 1. Encode activations to get SAE features
    # 2. Intervene on the feature (interchange with the paired example's value, or ablate to 0 / its mean)
    # 3. Decode back to activation space
    # 4. Continue model forward pass
    # 5. Measure output change
    raise NotImplementedError("Implement SAE feature validation")

# Example: Validate a "plural subject" feature
# sae_encoder = ...  # Load from Week 7
# sae_decoder = ...
# plural_examples = [
#     ("The keys are", "The key is"),  # plural vs singular
#     ("The books are", "The book is"),
# ]
# result = validate_sae_feature(model, sae_encoder, sae_decoder, 
#                               feature_idx=42, concept_examples=plural_examples, layer=5)
# print(f"Causal effect of feature 42: {result['causal_effect']:.3f}")
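Steps 1-4 of the TODO can be sketched as one function. It assumes the SAE computes features as `ReLU(encoder(x))` and reconstructs with `decoder(features)`, which may differ from your Week 7 architecture; the function name is hypothetical. A common choice, used here, is to add the SAE's reconstruction error back after decoding, so the intervention changes the activation only along the edited feature's decoder direction instead of also replacing x with its imperfect reconstruction. The toy SAE is random and untrained.

```python
import torch
import torch.nn as nn


def intervene_on_feature(x, encoder, decoder, feature_idx, new_value):
    """Set one SAE feature to new_value and map back to activation space.

    Assumes features = ReLU(encoder(x)) and reconstruction = decoder(features).
    The reconstruction error is added back, so the only change to x comes from
    the edited feature.
    """
    features = torch.relu(encoder(x))
    error = x - decoder(features)
    edited = features.clone()
    edited[..., feature_idx] = new_value
    return decoder(edited) + error


# Toy, untrained SAE (d_model=16, 64 features) standing in for the Week 7 one.
torch.manual_seed(0)
enc, dec = nn.Linear(16, 64), nn.Linear(64, 16)
x = torch.randn(2, 5, 16)         # (batch, seq_len, d_model) activations at `layer`
x_paired = torch.randn(2, 5, 16)  # activations on the paired (contrasting) inputs

with torch.no_grad():
    paired_value = torch.relu(enc(x_paired))[..., 3]  # interchange: the paired input's value
    x_swapped = intervene_on_feature(x, enc, dec, 3, paired_value)
    x_ablated = intervene_on_feature(x, enc, dec, 3, 0.0)  # zero-ablation alternative
```

To continue the forward pass, the edited activation would replace the original at `layer` through a forward hook; the causal effect could then be measured as the change in, for example, the logit difference between " are" and " is".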

## Part 10: Project Template - Validating Your Concept

Now apply these techniques to validate your own research project from previous weeks.

**Project Workflow:**
1. Define your concept's causal model
2. Identify candidate neural components (circuits, SAE features, or probe-identified layers)
3. Implement IIA to test alignment
4. Use DAS if you need to search for components
5. Report IIA accuracy and causal effect sizes
6. Make rigorous claims backed by intervention evidence
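The IIA accuracy reported in step 5 is easier to interpret next to a control. A minimal sketch of one such control, matching the summary's "unrelated components" recommendation: score random sets of the same size drawn from components outside the candidate set, and report where the candidate falls. The helper name is hypothetical, and the scorer is a toy stand-in for your IIA evaluation with invented values.

```python
import random
from typing import Callable, List, Sequence, Tuple

Component = Tuple[int, int]  # e.g. (layer, head)


def random_component_baseline(candidate: List[Component], pool: Sequence[Component],
                              score_fn: Callable[[List[Component]], float],
                              n_samples: int = 100, seed: int = 0) -> dict:
    """Compare a candidate set's IIA accuracy with random same-size sets (a control)."""
    rng = random.Random(seed)
    observed = score_fn(candidate)
    others = [c for c in pool if c not in candidate]
    null = [score_fn(rng.sample(others, len(candidate))) for _ in range(n_samples)]
    # Empirical p-value with +1 smoothing: share of random sets doing at least as well.
    p_value = (1 + sum(s >= observed for s in null)) / (1 + n_samples)
    return {"observed": observed, "null_mean": sum(null) / n_samples, "p_value": p_value}


# Toy scorer standing in for your IIA evaluation (made-up values, not results).
useful = {(5, 2): 0.5, (5, 7): 0.25}

def toy_score(components: List[Component]) -> float:
    return sum(useful.get(c, 0.0) for c in components)

grid = [(layer, head) for layer in range(12) for head in range(12)]
report = random_component_baseline([(5, 2), (5, 7)], grid, toy_score, n_samples=50)
print(report)
```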

**Exercise 10.1:** Fill in this template for your project.

In [None]:
# ========== PROJECT TEMPLATE ==========

# 1. Define your concept and causal model
MY_CONCEPT = "[Your concept here, e.g., 'temporal reasoning', 'politeness', 'causality']"

@dataclass
class MyConceptCausalModel:
    """High-level causal model for [YOUR CONCEPT]."""
    # TODO: Define variables in your causal model
    pass

# 2. Generate test examples
def generate_concept_examples():
    """Generate sentence pairs that differ in your concept."""
    # TODO: Create examples where your concept varies
    # Return list of (base, source) pairs
    raise NotImplementedError("Generate examples for your concept")

# 3. Identify candidate components
# From Week 5 (circuits), Week 6 (probes), or Week 7 (SAE features)
CANDIDATE_COMPONENTS = [
    # TODO: List (layer, head) or (layer, neuron) or (layer, sae_feature)
]

# 4. Run IIA validation
def validate_my_concept(
    model: GPT2LMHeadModel,
    components: List[Tuple[int, int]],
    examples: List[Tuple[str, str]]
) -> Dict[str, object]:
    """Validate that components represent your concept."""
    # TODO: Implement IIA for your concept
    # 1. For each example pair:
    #    - Extract activations from components
    #    - Intervene on components
    #    - Check if output changes as causal model predicts
    # 2. Compute IIA accuracy
    # 3. Compute effect sizes
    raise NotImplementedError("Implement validation for your concept")

# 5. Run DAS if needed
def search_for_concept_components(model, examples):
    """Search for components aligned to your concept."""
    # TODO: Run DAS to find best components
    raise NotImplementedError("Run DAS for your concept")

# 6. Report results
# examples = generate_concept_examples()
# validation_results = validate_my_concept(model, CANDIDATE_COMPONENTS, examples)
# 
# print(f"Concept: {MY_CONCEPT}")
# print(f"IIA Accuracy: {validation_results['accuracy']:.2%}")
# print(f"Average Causal Effect: {validation_results['avg_effect']:.3f}")
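# threshold = 0.8  # illustrative cutoff; justify your own choice in the paper
# verdict = 'DO' if validation_results['accuracy'] > threshold else 'DO NOT'
# print(f"\nConclusion: These components {verdict} causally represent {MY_CONCEPT}")
# print(f"            (IIA accuracy: {validation_results['accuracy']:.2%})")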

## Summary and Research Guidelines

**Key Takeaways:**

1. **Causal abstraction** formalizes the relationship between high-level concepts and neural mechanisms
2. **IIA** provides quantitative validation of interpretability claims (vs. qualitative inspection)
3. **DAS** automates the search for components aligned to concepts
4. **Integration:** Use IIA to validate findings from circuits (Week 5), probes (Week 6), and SAEs (Week 7)

**For Your Research Paper:**

✅ **DO:**
- Define explicit causal models for your concept
- Report IIA accuracy and effect sizes
- Compare multiple alignment hypotheses
- Use intervention controls (e.g., unrelated components)

❌ **DON'T:**
- Make causal claims based only on correlation (e.g., probe accuracy)
- Cherry-pick examples without systematic evaluation
- Ignore cases where interventions fail

**Next Steps:**
1. Apply these techniques to your project
2. Iterate between exploratory methods (Weeks 5-7) and validation (Week 8)
3. Build rigorous evidence for NeurIPS submission