# EdgePatch / SEAM: Causal Receiver Masking

This notebook runs the EdgePatch experiments on a Colab A100 GPU.

**Key Features:**
- Loads the `uzaymacar/math-rollouts` dataset
- Uses the dataset's chunk boundaries (no re-splitting)
- Computes per-chunk causal importance via ANSWER→CHUNK attention-edge masking
- **CRITICAL**: Verifies that `edge_layers` and `edge_heads` actually change outputs
- **PERSISTENT**: Saves all outputs to Google Drive with smart checkpointing
- **ROBUST**: Real-time progress monitoring and crash recovery

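## Quick Start
1. Run Cell 1 (Mount Drive & Setup)
2. Run Cell 2 (Smoke Test)
3. Run Cell 3 (Layer Toggle Test) - **MUST PASS**
4. Run Cell 4 (Head Toggle Test) - **MUST PASS**
5. Run Cell 5 (Confirm Run) - only after both toggle tests pass
6. Run Cell 6 (Summary & Drive Contents) to review results
7. Optionally run Cell 7 (Clear Checkpoints) to reset and start fresh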

## Scientific Controls
All validation cells (Cells 2–4) use:
- **Pinned example ID**: `problem_1591`, so results are directly comparable across cells
- **Extended scoring span** (`--score-span extended`): avoids saturated targets

In [None]:
# Cell 1: Mount Google Drive & Setup
# All outputs will be saved to Drive for persistence across crashes

import os
import json
import shutil
import subprocess
from pathlib import Path
from datetime import datetime

# Mount Google Drive so every artifact survives a Colab runtime reset.
from google.colab import drive
drive.mount('/content/drive')

# ============================================================
# CONFIGURATION - Edit these paths as needed
# ============================================================
DRIVE_BASE = Path('/content/drive/MyDrive/SEAM')
# Run artifacts, checkpoint JSONs and logs each get their own sub-folder.
DRIVE_RUNS, DRIVE_CHECKPOINTS, DRIVE_LOGS = (
    DRIVE_BASE / sub for sub in ('runs', 'checkpoints', 'logs')
)

# Scientific control: every validation cell runs on this same example.
PINNED_EXAMPLE = "problem_1591"

# Make sure all three folders exist (no-op when they already do).
for _folder in (DRIVE_RUNS, DRIVE_CHECKPOINTS, DRIVE_LOGS):
    _folder.mkdir(parents=True, exist_ok=True)
del _folder

print(f"Drive base: {DRIVE_BASE}")
print(f"Runs will be saved to: {DRIVE_RUNS}")
print(f"Checkpoints: {DRIVE_CHECKPOINTS}")
print(f"Pinned example: {PINNED_EXAMPLE}")

# ============================================================
# UTILITIES
# ============================================================
def stream_command(cmd_list):
    """Launch ``cmd_list`` and echo its merged stdout/stderr line by line.

    Args:
        cmd_list: Argument vector passed straight to ``subprocess.Popen``
            (no shell).

    Returns:
        A ``(return_code, output)`` tuple; ``output`` is every line the
        command printed, concatenated. A warning is printed on non-zero exit.
    """
    print("Running: " + " ".join(cmd_list))
    proc = subprocess.Popen(
        cmd_list,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout
        text=True,
        universal_newlines=True,
        bufsize=1,  # line-buffered so progress shows up immediately
    )

    captured = []
    while True:
        line = proc.stdout.readline()
        if line == '':  # EOF: the child closed its stdout
            break
        print(line, end='')
        captured.append(line)

    proc.stdout.close()
    rc = proc.wait()
    if rc:
        print(f"\n⚠️ Command failed with return code {rc}")
    return rc, ''.join(captured)


class CheckpointManager:
    """Manages checkpoints for crash recovery."""
    
    def __init__(self, checkpoint_dir: Path, run_name: str):
        self.checkpoint_dir = checkpoint_dir
        self.run_name = run_name
        self.checkpoint_file = checkpoint_dir / f"{run_name}_checkpoint.json"
    
    def save(self, state: dict):
        """Save checkpoint to Drive."""
        state['timestamp'] = datetime.now().isoformat()
        state['run_name'] = self.run_name
        with open(self.checkpoint_file, 'w') as f:
            json.dump(state, f, indent=2, default=str)
        print(f"💾 Checkpoint saved: {self.checkpoint_file.name}")
    
    def load(self) -> dict | None:
        """Load checkpoint if exists."""
        if self.checkpoint_file.exists():
            with open(self.checkpoint_file) as f:
                state = json.load(f)
            print(f"📂 Loaded checkpoint from {state.get('timestamp', 'unknown')}")
            return state
        return None
    
    def clear(self):
        """Clear checkpoint after successful completion."""
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()
            print(f"🧹 Checkpoint cleared")
    
    def exists(self) -> bool:
        return self.checkpoint_file.exists()


def sync_to_drive(local_dir: Path, drive_dir: Path):
    """Sync local run directory to Drive."""
    if not local_dir.exists():
        print(f"⚠️ Local dir not found: {local_dir}")
        return
    
    drive_target = drive_dir / local_dir.name
    drive_target.mkdir(parents=True, exist_ok=True)
    
    # Copy all files
    for file in local_dir.glob('*'):
        if file.is_file():
            shutil.copy2(file, drive_target / file.name)
    
    print(f"☁️ Synced to Drive: {drive_target}")
    return drive_target


def get_latest_run(base_dir: Path, prefix: str = 'edgepatch_') -> Path | None:
    """Get the most recent run directory."""
    runs = list(base_dir.glob(f"{prefix}*"))
    if not runs:
        return None
    return sorted(runs)[-1]


# ============================================================
# CLONE AND INSTALL
# ============================================================
# Plain subprocess calls (via stream_command) instead of IPython `!` magics,
# so a failed clone or install stops the cell instead of being followed by
# "SETUP COMPLETE".
import sys

os.chdir('/content')

if not os.path.exists('SEAM'):
    _rc, _ = stream_command(["git", "clone", "https://github.com/MechInterpreter/SEAM.git"])
    if _rc != 0:
        raise RuntimeError("git clone of SEAM failed; see output above")
else:
    print("SEAM already cloned, pulling latest...")
    _rc, _ = stream_command(["git", "-C", "SEAM", "pull"])
    if _rc != 0:
        # Best-effort: the existing checkout is still usable if the pull fails.
        print("⚠️ git pull failed; continuing with the existing checkout")

os.chdir('SEAM')
# Install into the interpreter running this kernel.
_rc, _ = stream_command([sys.executable, "-m", "pip", "install", "-e", ".", "--quiet"])
if _rc != 0:
    raise RuntimeError("pip install -e . failed; see output above")

print("\n" + "="*60)
print("✓ SETUP COMPLETE")
print(f"  Outputs will be saved to: {DRIVE_BASE}")
print(f"  Pinned example: {PINNED_EXAMPLE}")
print("="*60)

In [None]:
# Cell 2: Smoke Test with Drive Persistence
# Quick validation with 1 example, saved to Drive
# PINNED to problem_1591 with extended scoring span

import json
from pathlib import Path

RUN_NAME = "smoke_test"
LOCAL_OUTPUT = Path(f"runs/{RUN_NAME}")

# Check for an existing checkpoint (a completed smoke test is not re-run).
ckpt = CheckpointManager(DRIVE_CHECKPOINTS, RUN_NAME)
prev_state = ckpt.load()

if prev_state and prev_state.get('status') == 'completed':
    print(f"✓ Smoke test already completed at {prev_state.get('timestamp')}")
    print(f"  Results at: {prev_state.get('drive_path')}")
    SMOKE_PASSED = True
else:
    print("Running smoke test...")

    # Remember run directories left by earlier attempts in this Colab session:
    # get_latest_run() would otherwise return one of them (with its old
    # eval_metrics.json) when this invocation fails before creating its own.
    existing_runs = set(LOCAL_OUTPUT.glob("edgepatch_*")) if LOCAL_OUTPUT.exists() else set()

    # Run smoke test with real-time logging
    # PINNED: Same example and scoring span across all validation cells
    stream_command(
        ["python", "scripts/run_edgepatch.py", "smoke",
         "--output-dir", str(LOCAL_OUTPUT),
         "--example-ids", PINNED_EXAMPLE,
         "--score-span", "extended"]
    )

    run_dir = get_latest_run(LOCAL_OUTPUT)
    if run_dir is not None and run_dir in existing_runs:
        print(f"⚠️ No new run directory was created; ignoring stale {run_dir}")
        run_dir = None

    if run_dir and (run_dir / "eval_metrics.json").exists():
        # Sync to Drive
        drive_path = sync_to_drive(run_dir, DRIVE_RUNS / RUN_NAME)

        # Save checkpoint
        ckpt.save({
            'status': 'completed',
            'local_path': str(run_dir),
            'drive_path': str(drive_path),
        })

        print("\n" + "="*60)
        print("✓ SMOKE PASS - Artifacts saved to Drive")
        print(f"  Drive path: {drive_path}")
        print("="*60)
        SMOKE_PASSED = True
    else:
        print("\n" + "="*60)
        print("✗ SMOKE FAIL - No eval_metrics.json")
        print("="*60)
        SMOKE_PASSED = False

In [None]:
# Cell 3: Layer Toggle Test (CRITICAL), persisted to Drive.
# Purpose: masking different edge_layers must yield different scores.
# A failure here means layer masking is not taking effect.
# Every run is pinned to problem_1591 with the extended scoring span.

import json
import numpy as np
from pathlib import Path

RUN_NAME = "layer_toggle"
ckpt = CheckpointManager(checkpoint_dir=DRIVE_CHECKPOINTS, run_name=RUN_NAME)
prev_state = ckpt.load()  # None when this test has no checkpoint yet

def run_with_layers(layers: list, output_name: str):
    """Run EdgePatch (smoke mode) masking only ``layers`` and return its scores.

    The run is pinned to ``PINNED_EXAMPLE`` with the extended scoring span so
    scores from different layer sets are comparable position by position.

    Args:
        layers: Layer indices passed to ``--edge-layers``.
        output_name: Local sub-directory under ``runs/`` for this run.

    Returns:
        ``(scores, drive_path)``: a NumPy array of every ``delta_logp`` in
        ``all_results.json`` (examples in file order, then their scores in
        order) and the Drive copy of the run directory as a string; or
        ``(None, None)`` when this invocation produced no fresh results.
    """
    local_output = Path(f"runs/{output_name}")

    # Snapshot earlier run directories: get_latest_run() would otherwise hand
    # back a previous attempt's results if this invocation fails before
    # creating its own, and the toggle test would compare stale data.
    existing_runs = set(local_output.glob("edgepatch_*")) if local_output.exists() else set()

    # PINNED: Same example and scoring span for scientific comparability
    cmd = [
        "python", "scripts/run_edgepatch.py", "smoke",
        "--output-dir", str(local_output),
        "--example-ids", PINNED_EXAMPLE,
        "--score-span", "extended",
        "--edge-layers", *(str(layer) for layer in layers),
    ]
    stream_command(cmd)

    run_dir = get_latest_run(local_output)
    if not run_dir:
        print(f"ERROR: No run dir for {output_name}")
        return None, None
    if run_dir in existing_runs:
        print(f"ERROR: No new run dir for {output_name} (only stale run {run_dir.name})")
        return None, None

    results_path = run_dir / "all_results.json"
    if not results_path.exists():
        print(f"ERROR: No results for {output_name}")
        return None, None

    # Sync to Drive (RUN_NAME is this cell's global run name)
    drive_path = sync_to_drive(run_dir, DRIVE_RUNS / RUN_NAME)

    with open(results_path) as f:
        results = json.load(f)

    scores = [s["delta_logp"] for ex in results for s in ex["scores"]]
    return np.array(scores), str(drive_path)

if prev_state and prev_state.get('status') == 'passed':
    print(f"✓ Layer toggle already passed at {prev_state.get('timestamp')}")
    print(f"  max_diff = {prev_state.get('max_diff')}")
    LAYER_TOGGLE_PASSED = True
else:
    # Test A: Mask only layer 0
    print("Running with edge_layers=[0]...")
    scores_A, path_A = run_with_layers([0], "layer_test_A")

    # Record A's scores as soon as they exist. NOTE: this 'partial' checkpoint
    # is informational; re-running the cell recomputes A rather than resuming.
    if scores_A is not None:
        ckpt.save({'status': 'partial', 'scores_A': scores_A.tolist(), 'path_A': path_A})

    # Test B: Mask layers 24-31 (late layers)
    print("Running with edge_layers=[24,25,26,27,28,29,30,31]...")
    scores_B, path_B = run_with_layers([24, 25, 26, 27, 28, 29, 30, 31], "layer_test_B")

    if scores_A is None or scores_B is None:
        print("\n" + "="*60)
        print("✗ LAYER TOGGLE FAIL - Could not get scores")
        print("="*60)
        LAYER_TOGGLE_PASSED = False
    elif scores_A.shape != scores_B.shape or scores_A.size == 0:
        # The runs are not comparable: without this guard NumPy would silently
        # broadcast a length-1 array, raise on mismatched lengths, or fail in
        # np.max on empty input.
        ckpt.save({
            'status': 'failed',
            'reason': 'incomparable_scores',
            'shape_A': list(scores_A.shape),
            'shape_B': list(scores_B.shape),
        })
        print("\n" + "="*60)
        print(f"✗ LAYER TOGGLE FAIL - Incomparable scores: shape A={scores_A.shape}, shape B={scores_B.shape}")
        print("="*60)
        LAYER_TOGGLE_PASSED = False
    else:
        abs_diff = np.abs(scores_A - scores_B)
        max_diff = float(np.max(abs_diff))
        mean_diff = float(np.mean(abs_diff))

        print(f"\nScores A (layer 0): {scores_A[:5]}...")
        print(f"Scores B (layers 24-31): {scores_B[:5]}...")
        print(f"Max difference: {max_diff:.6f}")
        print(f"Mean difference: {mean_diff:.6f}")

        # CRITICAL ASSERTION: masking different layers must change the scores.
        if max_diff > 1e-6:
            ckpt.save({
                'status': 'passed',
                'max_diff': max_diff,
                'mean_diff': mean_diff,
                'path_A': path_A,
                'path_B': path_B,
            })
            print("\n" + "="*60)
            print(f"✓ LAYER TOGGLE PASS: max_diff={max_diff:.6f} > 1e-6")
            print("="*60)
            LAYER_TOGGLE_PASSED = True
        else:
            ckpt.save({'status': 'failed', 'max_diff': max_diff})
            print("\n" + "="*60)
            print(f"✗ LAYER TOGGLE FAILED! max_diff={max_diff} <= 1e-6")
            print("="*60)
            LAYER_TOGGLE_PASSED = False

In [None]:
# Cell 4: Head Toggle Test (CRITICAL), persisted to Drive.
# Purpose: masking different edge_heads must yield different scores.
# A failure here means head masking is not taking effect.
# Every run is pinned to problem_1591 with the extended scoring span.

import json
import numpy as np
from pathlib import Path

RUN_NAME = "head_toggle"
ckpt = CheckpointManager(checkpoint_dir=DRIVE_CHECKPOINTS, run_name=RUN_NAME)
prev_state = ckpt.load()  # None when this test has no checkpoint yet

def run_with_heads(heads: list, output_name: str):
    """Run EdgePatch (smoke mode) masking only ``heads`` and return its scores.

    Layers are fixed to 0-3 so that only the head set varies between runs.
    The run is pinned to ``PINNED_EXAMPLE`` with the extended scoring span.

    Args:
        heads: Head indices passed to ``--edge-heads``.
        output_name: Local sub-directory under ``runs/`` for this run.

    Returns:
        ``(scores, drive_path)``: a NumPy array of every ``delta_logp`` in
        ``all_results.json`` (examples in file order, then their scores in
        order) and the Drive copy of the run directory as a string; or
        ``(None, None)`` when this invocation produced no fresh results.
    """
    local_output = Path(f"runs/{output_name}")

    # Snapshot earlier run directories: get_latest_run() would otherwise hand
    # back a previous attempt's results if this invocation fails before
    # creating its own, and the toggle test would compare stale data.
    existing_runs = set(local_output.glob("edgepatch_*")) if local_output.exists() else set()

    # PINNED: Same example and scoring span for scientific comparability
    cmd = [
        "python", "scripts/run_edgepatch.py", "smoke",
        "--output-dir", str(local_output),
        "--example-ids", PINNED_EXAMPLE,
        "--score-span", "extended",
        "--edge-layers", "0", "1", "2", "3",  # Fix layers for comparison
        "--edge-heads", *(str(head) for head in heads),
    ]
    stream_command(cmd)

    run_dir = get_latest_run(local_output)
    if not run_dir:
        print(f"ERROR: No run dir for {output_name}")
        return None, None
    if run_dir in existing_runs:
        print(f"ERROR: No new run dir for {output_name} (only stale run {run_dir.name})")
        return None, None

    results_path = run_dir / "all_results.json"
    if not results_path.exists():
        print(f"ERROR: No results for {output_name}")
        return None, None

    # Sync to Drive (RUN_NAME is this cell's global run name)
    drive_path = sync_to_drive(run_dir, DRIVE_RUNS / RUN_NAME)

    with open(results_path) as f:
        results = json.load(f)

    scores = [s["delta_logp"] for ex in results for s in ex["scores"]]
    return np.array(scores), str(drive_path)

if prev_state and prev_state.get('status') == 'passed':
    print(f"✓ Head toggle already passed at {prev_state.get('timestamp')}")
    print(f"  max_diff = {prev_state.get('max_diff')}")
    HEAD_TOGGLE_PASSED = True
else:
    # Test A: Mask only head 0
    print("Running with edge_heads=[0]...")
    scores_A, path_A = run_with_heads([0], "head_test_A")

    # Record A's scores as soon as they exist. NOTE: this 'partial' checkpoint
    # is informational; re-running the cell recomputes A rather than resuming.
    if scores_A is not None:
        ckpt.save({'status': 'partial', 'scores_A': scores_A.tolist(), 'path_A': path_A})

    # Test B: Mask head 1
    print("Running with edge_heads=[1]...")
    scores_B, path_B = run_with_heads([1], "head_test_B")

    if scores_A is None or scores_B is None:
        print("\n" + "="*60)
        print("✗ HEAD TOGGLE FAIL - Could not get scores")
        print("="*60)
        HEAD_TOGGLE_PASSED = False
    elif scores_A.shape != scores_B.shape or scores_A.size == 0:
        # The runs are not comparable: without this guard NumPy would silently
        # broadcast a length-1 array, raise on mismatched lengths, or fail in
        # np.max on empty input.
        ckpt.save({
            'status': 'failed',
            'reason': 'incomparable_scores',
            'shape_A': list(scores_A.shape),
            'shape_B': list(scores_B.shape),
        })
        print("\n" + "="*60)
        print(f"✗ HEAD TOGGLE FAIL - Incomparable scores: shape A={scores_A.shape}, shape B={scores_B.shape}")
        print("="*60)
        HEAD_TOGGLE_PASSED = False
    else:
        abs_diff = np.abs(scores_A - scores_B)
        max_diff = float(np.max(abs_diff))
        mean_diff = float(np.mean(abs_diff))

        print(f"\nScores A (head 0): {scores_A[:5]}...")
        print(f"Scores B (head 1): {scores_B[:5]}...")
        print(f"Max difference: {max_diff:.6f}")
        print(f"Mean difference: {mean_diff:.6f}")

        # CRITICAL ASSERTION: masking different heads must change the scores.
        if max_diff > 1e-6:
            ckpt.save({
                'status': 'passed',
                'max_diff': max_diff,
                'mean_diff': mean_diff,
                'path_A': path_A,
                'path_B': path_B,
            })
            print("\n" + "="*60)
            print(f"✓ HEAD TOGGLE PASS: max_diff={max_diff:.6f} > 1e-6")
            print("="*60)
            HEAD_TOGGLE_PASSED = True
        else:
            ckpt.save({'status': 'failed', 'max_diff': max_diff})
            print("\n" + "="*60)
            print(f"✗ HEAD TOGGLE FAILED! max_diff={max_diff} <= 1e-6")
            print("="*60)
            HEAD_TOGGLE_PASSED = False

In [None]:
# Cell 5: Confirm Run with 15 Examples (Excluding Prior Run)
# Excludes IDs from prior confirm_5ex run to get NEW examples
# INCREMENTAL SYNC: Saves to Drive after EVERY example for crash protection

import json
import glob
from pathlib import Path

# This run is meant to follow passing layer/head toggle tests (Cells 3-4).
# Set to False to enforce that gate. Bypassing no longer overwrites
# LAYER_TOGGLE_PASSED / HEAD_TOGGLE_PASSED, so the Cell 6 summary keeps
# reporting the toggle tests' real outcome instead of an unconditional pass.
BYPASS_TOGGLE_CHECKS = True

RUN_NAME = "confirm_15ex"
LOCAL_OUTPUT = Path(f"runs/{RUN_NAME}")
DRIVE_SYNC = DRIVE_RUNS / RUN_NAME  # Incremental sync target
ckpt = CheckpointManager(DRIVE_CHECKPOINTS, RUN_NAME)
prev_state = ckpt.load()


def _fmt_metric(value) -> str:
    """Render a metric with 4 decimals, or as plain text when it is not numeric.

    Missing metrics (``None``) render as ``N/A`` instead of crashing the
    ``:.4f`` format after the completion checkpoint has been written.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:.4f}"
    return "N/A" if value is None else str(value)


# ============================================================
# LOAD PRIOR PROCESSED IDs FROM DRIVE
# ============================================================
PRIOR_RUN_PATH = DRIVE_RUNS / "confirm_5ex" / "edgepatch_20260115_161124"
prior_ids = set()

# Load from prior all_results.json
prior_results = PRIOR_RUN_PATH / "all_results.json"
if prior_results.exists():
    with open(prior_results) as f:
        prior_ids.update(ex["example_id"] for ex in json.load(f))
    print(f"Loaded {len(prior_ids)} IDs to exclude from prior run")
    print(f"  Excluding: {sorted(prior_ids)}")
else:
    print(f"Warning: Prior results not found at {prior_results}")

# Also check current checkpoint for growing exclude list
if prev_state and 'processed_ids' in prev_state:
    prior_ids.update(prev_state['processed_ids'])
    print(f"Added {len(prev_state['processed_ids'])} IDs from checkpoint")

print(f"Total IDs to exclude: {len(prior_ids)}")
print(f"Incremental sync target: {DRIVE_SYNC}")

# Toggle flags are read without raising if Cells 3-4 were never run.
toggles_passed = bool(globals().get('LAYER_TOGGLE_PASSED')) and bool(globals().get('HEAD_TOGGLE_PASSED'))

# ============================================================
# RUN CONFIRM WITH EXCLUSION + INCREMENTAL SYNC
# ============================================================
if prev_state and prev_state.get('status') == 'completed':
    print(f"✓ Confirm run already completed at {prev_state.get('timestamp')}")
    print(f"  Results at: {prev_state.get('drive_path')}")

    # Load and display metrics (drive_path may be missing from an old checkpoint)
    prev_drive_path = prev_state.get('drive_path')
    metrics_path = Path(prev_drive_path) / 'eval_metrics.json' if prev_drive_path else None
    if metrics_path is not None and metrics_path.exists():
        with open(metrics_path) as f:
            metrics = json.load(f)
        print("\n" + "="*60)
        print("EVALUATION METRICS (from previous run)")
        print("="*60)
        for k, v in metrics.items():
            if isinstance(v, float):
                print(f"{k}: {v:.4f}")
            else:
                print(f"{k}: {v}")
        print("="*60)
elif not (BYPASS_TOGGLE_CHECKS or toggles_passed):
    print("\n" + "="*60)
    print("✗ CONFIRM SKIPPED - Layer/head toggle tests have not both passed")
    print("  Run Cells 3 and 4 first, or set BYPASS_TOGGLE_CHECKS = True")
    print("="*60)
else:
    print("Running confirm with max_examples=15 (excluding prior IDs)...")
    print("Results sync to Drive after EVERY example for crash protection!")
    print("This may take 20-30 minutes...\n")

    # Build command with exclude list AND Drive sync
    cmd = [
        "python", "scripts/run_edgepatch.py", "confirm",
        "--output-dir", str(LOCAL_OUTPUT),
        "--max-examples", "15",
        "--score-span", "extended",
        "--drive-sync-dir", str(DRIVE_SYNC),  # INCREMENTAL SYNC!
    ]

    # Add exclude IDs if we have any; sorted() keeps the command reproducible.
    if prior_ids:
        cmd += ["--exclude-ids", *sorted(prior_ids)]

    # Snapshot earlier run directories so a stale run from a previous attempt
    # in this session is never mistaken for this invocation's output.
    existing_runs = set(LOCAL_OUTPUT.glob("edgepatch_*")) if LOCAL_OUTPUT.exists() else set()

    stream_command(cmd)

    # Check artifacts
    run_dir = get_latest_run(LOCAL_OUTPUT)
    if run_dir is not None and run_dir in existing_runs:
        print(f"⚠️ No new run directory was created; ignoring stale {run_dir}")
        run_dir = None

    if run_dir:
        metrics_path = run_dir / "eval_metrics.json"
        results_path = run_dir / "all_results.json"

        if metrics_path.exists():
            # Final sync to Drive (full artifacts)
            drive_path = sync_to_drive(run_dir, DRIVE_RUNS / RUN_NAME)

            with open(metrics_path) as f:
                metrics = json.load(f)

            # Collect processed IDs for checkpoint (growing exclude list)
            processed_ids = sorted(prior_ids)  # Start with prior
            if results_path.exists():
                with open(results_path) as f:
                    processed_ids += [
                        ex["example_id"] for ex in json.load(f)
                        if ex["example_id"] not in prior_ids
                    ]

            # Save completion checkpoint with growing ID list
            ckpt.save({
                'status': 'completed',
                'local_path': str(run_dir),
                'drive_path': str(drive_path),
                'metrics': metrics,
                'processed_ids': processed_ids,  # For future runs
            })

            print("\n" + "="*60)
            print("EVALUATION METRICS")
            print("="*60)
            print(f"Spearman ρ:     {_fmt_metric(metrics.get('spearman_rho'))}")
            print(f"Top-1 overlap:  {_fmt_metric(metrics.get('top_1_overlap'))}")
            print(f"Top-3 overlap:  {_fmt_metric(metrics.get('top_3_overlap'))}")
            print(f"PR-AUC@10%:     {_fmt_metric(metrics.get('pr_auc_10'))}")
            print(f"Shuffled ρ:     {_fmt_metric(metrics.get('shuffled_rho'))}")
            print("="*60)
            print(f"✓ CONFIRM PASS - Processed {len(processed_ids) - len(prior_ids)} NEW examples")
            print(f"  Total processed (incl prior): {len(processed_ids)}")
            print(f"  Drive path: {drive_path}")
            print("="*60)
        else:
            print("\n" + "="*60)
            print("✗ CONFIRM FAIL - No eval_metrics.json")
            print("  Check incremental results at: " + str(DRIVE_SYNC / "all_results_incremental.json"))
            print("="*60)
    else:
        print("\n" + "="*60)
        print("✗ CONFIRM FAIL - No run directory")
        print("  Check incremental results at: " + str(DRIVE_SYNC / "all_results_incremental.json"))
        print("="*60)

In [None]:
# Cell 6: Summary & Drive Contents
# Report each validation cell's outcome, then list what is persisted on Drive.

print("\n" + "="*60)
print("FINAL SUMMARY")
print("="*60)

# (padded label, flag name): a flag that was never assigned means its cell never ran.
_NOT_SET = object()
for _label, _flag_name in (
    ("Smoke Test:      ", "SMOKE_PASSED"),
    ("Layer Toggle:    ", "LAYER_TOGGLE_PASSED"),
    ("Head Toggle:     ", "HEAD_TOGGLE_PASSED"),
):
    _flag = globals().get(_flag_name, _NOT_SET)
    if _flag is _NOT_SET:
        print(f"{_label}Not run")
    else:
        print(f"{_label}{'✓ PASS' if _flag else '✗ FAIL'}")

print("="*60)

# Show Drive contents
print("\n📁 DRIVE CONTENTS")
print("-"*60)

print(f"\nCheckpoints ({DRIVE_CHECKPOINTS}):")
for _ckpt_file in sorted(DRIVE_CHECKPOINTS.glob('*.json')):
    print(f"  📄 {_ckpt_file.name}")

print(f"\nRuns ({DRIVE_RUNS}):")
for _run_group in sorted(DRIVE_RUNS.iterdir()):
    if not _run_group.is_dir():
        continue
    print(f"  📁 {_run_group.name}/")
    for _entry in sorted(_run_group.glob('*')):
        _icon, _suffix = ("📁", "/") if _entry.is_dir() else ("📄", "")
        print(f"      {_icon} {_entry.name}{_suffix}")

print("\n" + "="*60)

# Overall verdict on the two critical masking tests.
try:
    if not (LAYER_TOGGLE_PASSED and HEAD_TOGGLE_PASSED):
        print("\n⚠️  Some critical tests failed!")
        print("   Check the masking implementation.")
    else:
        print("\n🎉 All critical tests passed!")
        print("   Edge masking is working correctly.")
        print(f"   All outputs saved to: {DRIVE_BASE}")
except NameError:
    print("\n⚠️  Not all tests have been run.")

In [None]:
# Cell 7: Clear Checkpoints (Optional)
# Flip the flag below and re-run to wipe every checkpoint and start fresh.

CLEAR_CHECKPOINTS = False  # Set to True to clear

if not CLEAR_CHECKPOINTS:
    print("Set CLEAR_CHECKPOINTS = True and re-run to clear checkpoints")
else:
    print("Clearing all checkpoints...")
    for ckpt_path in DRIVE_CHECKPOINTS.glob('*.json'):
        ckpt_path.unlink()
        print(f"  Deleted: {ckpt_path.name}")
    print("✓ All checkpoints cleared")