# Research Notebook: Causal Scrubbing & Path Patching

Project: GDSC Interpretability - glassboxllms <br>
Researcher: Ankita Sharma (sharmaankita3387) <br>
Objective: To move beyond correlation and verify the causal mechanisms within Language Models. <br>

### 1. Mathematical and Theoretical Framework
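Causal Scrubbing is a principled approach to testing whether specific parts of a neural network (nodes, layers, or heads) are responsible for a specific behavior. Rather than relying on correlations between activations and outputs, we intervene directly on the network's internals and measure how the behavior changes. A change under intervention is evidence of a causal link, not just an association.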

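We measure the effect of an intervention using the **Logit Difference** $LD$ between the correct and incorrect next token:

$$LD = \text{Logit}(T_{\text{correct}}) - \text{Logit}(T_{\text{incorrect}})$$

A larger $LD$ means the model more strongly prefers the correct token, so intervening on a component the behavior depends on should reduce it.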
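As a minimal sketch, the logit difference can be computed directly from a logits tensor. The function name `logit_difference` and the convention of reading the final sequence position for `[batch, seq, vocab]` inputs are assumptions for illustration, not part of `glassboxllms`. The example reuses the dummy logits from the mock test in Section 3 and treats token 0 as correct and token 2 as incorrect.

```python
import torch

def logit_difference(logits: torch.Tensor, correct_id: int, incorrect_id: int) -> torch.Tensor:
    """LD = logit(correct) - logit(incorrect).

    Accepts a [vocab] vector, a [batch, vocab] matrix, or a [batch, seq, vocab]
    tensor; for 3-D input only the final position is used (assumed convention).
    Returns a scalar for 1-D input and one value per batch row otherwise.
    """
    if logits.dim() == 3:
        logits = logits[:, -1, :]  # keep only the last token's logits
    return logits[..., correct_id] - logits[..., incorrect_id]

# Dummy logits from the mock test: token 0 = correct, token 2 = incorrect
ld = logit_difference(torch.tensor([0.85, 0.05, 0.10]), correct_id=0, incorrect_id=2)
print(round(ld.item(), 2))  # 0.75
```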

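In this notebook, we perform **Path Patching** in its simplest form: patching a single node, also known as activation patching. We run the model on a "clean" prompt and on a "corrupted" prompt, then "scrub" the clean run by replacing one component's activation with that component's activation from the corrupted run. If the logit difference on the clean prompt drops substantially, that component is a causal bottleneck for the behavior. Full path patching refines this by replacing only the effect that flows along a specific path between two components, rather than every downstream use of the activation.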
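The clean/corrupted loop described above can be sketched end to end with plain PyTorch forward hooks. This is a hedged illustration, not the `glassboxllms` implementation: the toy two-layer MLP stands in for a language model, and the token ids `CORRECT`/`INCORRECT` are hypothetical. A forward hook that returns a tensor replaces the module's output, which is exactly the substitution the paragraph describes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a language model: 4 input features -> 3 "vocabulary" logits
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
target = model[1]          # the component to patch (the ReLU output)
CORRECT, INCORRECT = 0, 1  # hypothetical token ids

clean_x, corrupted_x = torch.randn(1, 4), torch.randn(1, 4)
cache = {}

def logit_diff(logits: torch.Tensor) -> float:
    return (logits[:, CORRECT] - logits[:, INCORRECT]).item()

def cache_hook(module, inputs, output):
    cache["act"] = output.detach()  # returning None leaves the output unchanged

def patch_hook(module, inputs, output):
    return cache["act"]             # returning a tensor replaces the output

with torch.no_grad():
    # 1. Corrupted run: record the target activation
    handle = target.register_forward_hook(cache_hook)
    corrupted_logits = model(corrupted_x)
    handle.remove()

    # 2. Clean run with no intervention (the baseline)
    clean_logits = model(clean_x)

    # 3. Clean run with the corrupted activation patched in
    handle = target.register_forward_hook(patch_hook)
    try:
        patched_logits = model(clean_x)
    finally:
        handle.remove()

print(f"LD clean={logit_diff(clean_logits):.3f} "
      f"corrupted={logit_diff(corrupted_logits):.3f} "
      f"patched={logit_diff(patched_logits):.3f}")
```

Because every path from input to output runs through `model[1]` in this toy network, the patched $LD$ equals the corrupted $LD$ exactly. In a transformer, the residual stream gives information other routes around a single head or layer, so patching one node usually produces only a partial drop.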

### 2. Implementation: The Logic
This section implements the two core components of the `glassboxllms` library: an activation store and a causal scrubber.

```python
import torch
from unittest.mock import MagicMock
from typing import Any, Callable, Dict, Optional

# --- ACTIVATION STORE ---
# Manages the storage and retrieval of model internals
class ActivationStore:
    def __init__(self, device: str = "cpu"):
        self.device = device
        self._buffer: Dict[str, torch.Tensor] = {}

    def save(self, layer_name: str, activations: torch.Tensor) -> None:
        # Detach so cached activations don't keep the autograd graph alive
        self._buffer[layer_name] = activations.detach().to(self.device)

    def get(self, layer_name: str) -> Optional[torch.Tensor]:
        return self._buffer.get(layer_name)

# --- CAUSAL SCRUBBING MODULE ---
# The primary API for conducting intervention experiments.
# `hook_manager` is assumed to expose add_hook(name, fn) and remove_all_hooks(),
# where fn(activation, hook) returns the (possibly replaced) activation.
class CausalScrubber:
    def __init__(self, model: Callable, hook_manager: Any, activation_store: ActivationStore):
        self.model = model
        self.hook_manager = hook_manager
        self.store = activation_store

    def scrub_path(self, clean_input, target_node: str, corrupted_input=None, strategy: str = "patch"):
        """
        Runs `clean_input` through the model while intervening on `target_node`.

        strategy="patch": replace the activation with the one from `corrupted_input`.
        strategy="zero":  replace the activation with zeros (zero ablation).
        """
        if strategy not in ("patch", "zero"):
            raise ValueError(f"Unknown strategy: {strategy!r}")
        if strategy == "patch" and corrupted_input is None:
            raise ValueError("strategy='patch' requires a corrupted_input")

        # 1. Corrupted run: cache the target node's activation for patching
        if strategy == "patch":
            def cache_hook(activation, hook):
                self.store.save(target_node, activation)
                return activation  # leave the corrupted run itself unchanged

            self.hook_manager.add_hook(target_node, cache_hook)
            try:
                with torch.no_grad():
                    self.model(corrupted_input)
            finally:
                self.hook_manager.remove_all_hooks()

        # 2. Define the intervention hook for the clean run
        def intervention_hook(activation, hook):
            if strategy == "zero":
                return torch.zeros_like(activation)
            cached = self.store.get(target_node)
            if cached is None:
                raise RuntimeError(f"No cached activation for {target_node!r}")
            return cached.to(activation.device)

        # 3. Apply the hook and run the clean input
        self.hook_manager.add_hook(target_node, intervention_hook)
        try:
            return self.model(clean_input)
        finally:
            # Ensure hooks are always removed, even if the forward pass raises
            self.hook_manager.remove_all_hooks()
```
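The scrubber only assumes a `hook_manager` with `add_hook(name, fn)` and `remove_all_hooks()`, and the notebook mocks it. As a hedged sketch, here is one way that interface could be backed by plain PyTorch, wrapping `nn.Module.register_forward_hook` and resolving dotted module names with `nn.Module.get_submodule` (available in recent PyTorch releases). `TorchHookManager` is a hypothetical name, and passing the module name as the hook's second argument is an assumption, not the `glassboxllms` API.

```python
import torch
import torch.nn as nn
from typing import Callable, Optional

class TorchHookManager:
    """Hypothetical hook manager exposing add_hook(name, fn) / remove_all_hooks()."""

    def __init__(self, model: nn.Module):
        self.model = model
        self._handles: list = []

    def add_hook(self, name: str, fn: Callable[[torch.Tensor, str], Optional[torch.Tensor]]) -> None:
        module = self.model.get_submodule(name)  # e.g. "1" or "blocks.7.attn"

        def forward_hook(mod, inputs, output):
            # Adapt PyTorch's (module, inputs, output) signature to fn(activation, hook).
            # Returning a tensor replaces the module output; returning None keeps it.
            return fn(output, name)

        self._handles.append(module.register_forward_hook(forward_hook))

    def remove_all_hooks(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

# Usage: zero-ablate the ReLU (submodule "1") of a toy MLP
torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
manager = TorchHookManager(toy)
manager.add_hook("1", lambda act, name: torch.zeros_like(act))
with torch.no_grad():
    ablated = toy(torch.randn(1, 4))
manager.remove_all_hooks()
# With the ReLU output zeroed, the logits reduce to the final layer's bias
print(torch.allclose(ablated, toy[2].bias.unsqueeze(0)))  # True
```

Because the adapter passes the hook's return value straight through to PyTorch, a hook that returns `None` leaves the activation unchanged, while one that returns a tensor replaces it.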

### 3. Verification & Mock Testing

```python
# --- MOCKING THE ENVIRONMENT ---
# 1. Create a 'stunt double' model that returns dummy logits
mock_model = MagicMock(return_value=torch.tensor([0.85, 0.05, 0.10]))

# 2. Create a 'stunt double' hook manager
mock_hook_manager = MagicMock()

# 3. Initialize the real objects
store = ActivationStore()
scrubber = CausalScrubber(mock_model, mock_hook_manager, store)

# --- EXECUTION ---
print("Running Causal Scrubbing Verification...")

clean_input = torch.tensor([101, 202, 303])  # Dummy tokens
corrupted_input = torch.tensor([404, 505, 606])

# Test Path Patching on an attention layer
result = scrubber.scrub_path(
    clean_input=clean_input,
    target_node="blocks.layer_7.attn",
    corrupted_input=corrupted_input,
    strategy="patch"
)

# The mock model never invokes the registered hooks, so this cell only checks
# that scrub_path runs end to end and registers a hook; it does not check
# that the intervention changes the logits.
print("\nStatus: SUCCESS")
print(f"Logits produced: {result}")
print(f"Verified: Hook Manager was called? {mock_hook_manager.add_hook.called}")
```

```text
Running Causal Scrubbing Verification...

Status: SUCCESS
Logits produced: tensor([0.8500, 0.0500, 0.1000])
Verified: Hook Manager was called? True
```


### 4. Conclusion and Interpretation
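With this module, we can score the importance of individual model components by how much patching each one changes the logit difference.

- A **Total Drop** in Logit Difference (the patched run falls to roughly the corrupted run's $LD$) indicates the component is part of the "Critical Path."

- A **Minor Drop** suggests the component is either not involved in this task or redundant, meaning other components can compensate when it is patched.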
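One way to turn "total" versus "minor" drop into a number (a common convention in the patching literature, assumed here rather than defined by this notebook) is to normalize the patched logit difference against the two baselines: $\text{score} = (LD_{\text{patched}} - LD_{\text{corrupted}}) / (LD_{\text{clean}} - LD_{\text{corrupted}})$. A score near 0 means the patch pushed the model all the way to its corrupted behavior; a score near 1 means the patch barely mattered. The sketch below sweeps that score over every top-level submodule of a hypothetical toy model with a residual stream, using plain PyTorch forward hooks; `ToyResidual`, `sweep`, and the token ids are illustrative.

```python
import torch
import torch.nn as nn

class ToyResidual(nn.Module):
    """Hypothetical tiny model with a residual stream: embed -> +block_a -> +block_b -> unembed."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 8)
        self.block_a = nn.Linear(8, 8)
        self.block_b = nn.Linear(8, 8)
        self.unembed = nn.Linear(8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)
        h = h + self.block_a(h)  # each block adds its output to the residual stream
        h = h + self.block_b(h)
        return self.unembed(h)

def logit_diff(logits: torch.Tensor, correct: int = 0, incorrect: int = 1) -> float:
    # Hypothetical token ids: 0 = correct, 1 = incorrect
    return (logits[:, correct] - logits[:, incorrect]).item()

def normalized_patch_score(ld_patched: float, ld_clean: float, ld_corrupted: float) -> float:
    """Fraction of the clean-vs-corrupted LD gap that survives the patch (0 = total drop, 1 = no drop)."""
    gap = ld_clean - ld_corrupted
    if abs(gap) < 1e-8:
        raise ValueError("Clean and corrupted LDs are equal; the score is undefined.")
    return (ld_patched - ld_corrupted) / gap

@torch.no_grad()
def sweep(model: nn.Module, clean_x: torch.Tensor, corrupted_x: torch.Tensor) -> dict:
    ld_clean = logit_diff(model(clean_x))
    ld_corrupted = logit_diff(model(corrupted_x))
    scores = {}
    for name, module in model.named_children():
        cache = {}

        def cache_hook(mod, inputs, output):
            cache["act"] = output  # record this module's output on the corrupted prompt

        def patch_hook(mod, inputs, output):
            return cache["act"]    # substitute it during the clean run

        handle = module.register_forward_hook(cache_hook)
        model(corrupted_x)
        handle.remove()

        handle = module.register_forward_hook(patch_hook)
        try:
            ld_patched = logit_diff(model(clean_x))
        finally:
            handle.remove()
        scores[name] = normalized_patch_score(ld_patched, ld_clean, ld_corrupted)
    return scores

torch.manual_seed(0)
scores = sweep(ToyResidual(), torch.randn(1, 4), torch.randn(1, 4))
for name, score in sorted(scores.items(), key=lambda kv: kv[1]):
    print(f"{name:8s} score = {score:.2f}")
```

In this toy model, patching `embed` or `unembed` scores 0 because all information passes through them, while `block_a` and `block_b` are bypassed by the residual connection and typically receive intermediate scores, the pattern the **Minor Drop** bullet describes for components with alternative routes.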

This module provides the necessary infrastructure for the glassboxllms project to move into **Circuit Discovery**.