# Mechanistic Interpretability of CTM on Maze Navigation

## Research Hypothesis

The Continuous Thought Machine (CTM) constructs a **"Virtual Coordinate System"** dynamically within the Synchronization Matrix ($S_t$) when solving 2D mazes **without positional embeddings**.

Specific clusters of neurons likely fire only when the agent "imagines" itself at specific $(x,y)$ coordinates in the maze - similar to **Place Cells** in the hippocampus.

---

## Notebook Overview

1. **Setup**: Load model and data
2. **Visualization**: Explore internal states across ticks
3. **Place Cell Analysis**: Find neurons that correlate with positions
4. **Intervention**: Test causal role of position-encoding neurons

In [None]:
# Standard imports
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from pathlib import Path
from collections import defaultdict

# Add project root to path
PROJECT_ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# Project imports
from models.ctm import ContinuousThoughtMachine
from data.custom_datasets import MazeImageFolder

# wandb is an optional dependency: record whether the import succeeded.
try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True

# Plot defaults shared by every figure created below.
sns.set_style('darkgrid')
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'figure.dpi': 100,
})

# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## 1. Configuration

In [None]:
# Notebook-wide settings, grouped by what they control.
# NOTE(review): the checkpoint filename says 'mazeslarge' while maze_size is
# 'medium' — confirm this pairing is intended.
CONFIG = dict(
    # Model
    checkpoint_path=str(PROJECT_ROOT / 'checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt'),

    # Data
    data_root=str(PROJECT_ROOT / 'data/mazes'),
    maze_size='medium',  # 'small', 'medium', or 'large'
    num_samples=50,
    batch_size=8,

    # Analysis
    num_top_neurons=20,
)

# Echo the settings so the run is self-describing.
print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 2. Load Model

The CTM architecture for maze solving:
- **Backbone**: ResNet-based feature extractor (no positional embeddings!)
- **Synapse Model**: U-Net style MLP for cross-neuron communication
- **NLMs**: Private MLPs per neuron processing activation history
- **Output**: Synchronization-based predictions

Key insight: `positional_embedding_type = 'none'` confirms no explicit position info!

In [None]:
def load_model(checkpoint_path, device):
    """Load a CTM maze model from a training checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint holding an ``'args'`` entry and
            the weights under ``'state_dict'`` or ``'model_state_dict'``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        Tuple ``(model, model_args)``: the model switched to eval mode and the
        argument object read from the checkpoint (with legacy fields filled in).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint contains neither weights key.
    """
    print(f"Loading checkpoint: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path}\n"
            f"Download from: https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg"
        )

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects
    # (presumably required for the stored `args` object). Only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_args = checkpoint['args']

    # Handle legacy arguments: older checkpoints lack these fields.
    if not hasattr(model_args, 'backbone_type'):
        model_args.backbone_type = f'{model_args.resnet_type}-{model_args.resnet_feature_scales[-1]}'
    if not hasattr(model_args, 'neuron_select_type'):
        model_args.neuron_select_type = 'first-last'
    if not hasattr(model_args, 'n_random_pairing_self'):
        model_args.n_random_pairing_self = 0

    # Print key config
    print(f"\nModel Configuration:")
    print(f"  d_model (neurons): {model_args.d_model}")
    print(f"  iterations (ticks): {model_args.iterations}")
    print(f"  memory_length: {model_args.memory_length}")
    print(f"  positional_embedding_type: {model_args.positional_embedding_type}")

    # The flat output is viewed as (route steps, 5 move classes).
    prediction_reshaper = [model_args.out_dims // 5, 5]

    model = ContinuousThoughtMachine(
        iterations=model_args.iterations,
        d_model=model_args.d_model,
        d_input=model_args.d_input,
        heads=model_args.heads,
        n_synch_out=model_args.n_synch_out,
        n_synch_action=model_args.n_synch_action,
        synapse_depth=model_args.synapse_depth,
        memory_length=model_args.memory_length,
        deep_nlms=model_args.deep_memory,
        memory_hidden_dims=model_args.memory_hidden_dims,
        do_layernorm_nlm=model_args.do_normalisation,
        backbone_type=model_args.backbone_type,
        positional_embedding_type=model_args.positional_embedding_type,
        out_dims=model_args.out_dims,
        prediction_reshaper=prediction_reshaper,
        dropout=0,
        neuron_select_type=model_args.neuron_select_type,
        n_random_pairing_self=model_args.n_random_pairing_self,
    ).to(device)

    # Locate the weights; fail with an explicit message if neither key exists.
    for candidate_key in ('state_dict', 'model_state_dict'):
        if candidate_key in checkpoint:
            state_dict_key = candidate_key
            break
    else:
        raise KeyError(
            "Checkpoint has neither 'state_dict' nor 'model_state_dict'; "
            f"available keys: {sorted(map(str, checkpoint.keys()))}"
        )

    # strict=False tolerates key mismatches, so surface them instead of letting
    # an analysis run silently on partially initialised weights.
    load_result = model.load_state_dict(checkpoint[state_dict_key], strict=False)
    missing_keys = list(getattr(load_result, 'missing_keys', []))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
    if missing_keys:
        print(f"  WARNING: {len(missing_keys)} model parameter(s) not found in checkpoint, e.g. {missing_keys[:5]}")
    if unexpected_keys:
        print(f"  WARNING: {len(unexpected_keys)} checkpoint entr(ies) unused by the model, e.g. {unexpected_keys[:5]}")

    model.eval()

    return model, model_args

# Load the model. If the checkpoint is missing, report it and leave both names
# defined as None so later cells can test them instead of raising NameError.
try:
    model, model_args = load_model(CONFIG['checkpoint_path'], device)
except FileNotFoundError as e:
    print(f"\n✗ {e}")
    model = None
    model_args = None
else:
    print(f"\n✓ Model loaded successfully with {sum(p.numel() for p in model.parameters()):,} parameters")

## 3. Load Maze Data

Maze encoding:
- **Red (1,0,0)**: Start position
- **Green (0,1,0)**: Goal position
- **Black (0,0,0)**: Walls
- **White (1,1,1)**: Walkable path

Output: Sequence of moves [0=Up, 1=Down, 2=Left, 3=Right, 4=Wait]

In [None]:
def load_maze_data(data_root, maze_size, num_samples, batch_size):
    """Build a DataLoader over at most ``num_samples`` test mazes.

    Args:
        data_root: Directory containing one sub-directory per maze size.
        maze_size: Name of the size sub-directory to read.
        num_samples: Upper bound on the number of mazes kept (the first ones).
        batch_size: Batch size of the returned loader.

    Returns:
        Tuple ``(loader, dataset)``; the loader iterates in dataset order.

    Raises:
        FileNotFoundError: If ``<data_root>/<maze_size>/test`` does not exist.
    """
    data_path = f"{data_root}/{maze_size}/test"

    found = os.path.exists(data_path)
    if not found:
        raise FileNotFoundError(
            f"Maze data not found: {data_path}\n"
            f"Download from: https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view"
        )

    # Ask the dataset to truncate only when a small sample is requested.
    wants_truncation = bool(num_samples < 1000)
    dataset = MazeImageFolder(
        root=data_path,
        which_set='test',
        maze_route_length=100,
        expand_range=True,
        trunc=wants_truncation,
    )

    # Keep just the leading num_samples items when the dataset is larger.
    if len(dataset) > num_samples:
        kept_indices = list(range(num_samples))
        dataset = torch.utils.data.Subset(dataset, kept_indices)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    return loader, dataset

# Load the mazes. If the data is missing, report it and leave both names
# defined as None so later cells can test them instead of raising NameError.
try:
    dataloader, dataset = load_maze_data(
        CONFIG['data_root'], CONFIG['maze_size'],
        CONFIG['num_samples'], CONFIG['batch_size']
    )
except FileNotFoundError as e:
    print(f"\n✗ {e}")
    dataloader = None
    dataset = None
else:
    print(f"\n✓ Loaded {len(dataset)} mazes")

## 4. Visualize Sample Maze

In [None]:
if dataloader is not None:
    # Take the first batch; only its first maze is shown.
    inputs, targets = next(iter(dataloader))

    # Rescale from [-1, 1] to [0, 1] and move channels last for imshow.
    maze_img = np.transpose((inputs[0].numpy() + 1) / 2, (1, 2, 0))
    solution = targets[0].numpy()

    # Number of non-Wait moves; equals the path length when Wait (4) only
    # appears as trailing padding.
    path_length = (solution != 4).sum()

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    maze_ax, moves_ax = axes

    # Left panel: the maze image itself.
    maze_ax.imshow(maze_img)
    maze_ax.set_title(f'Maze ({CONFIG["maze_size"]})')
    maze_ax.axis('off')

    # Right panel: how often each move occurs in the target route.
    move_names = ['Up', 'Down', 'Left', 'Right', 'Wait']
    move_counts = [(solution == move_id).sum() for move_id in range(5)]

    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
    moves_ax.bar(move_names, move_counts, color=colors)
    moves_ax.set_title(f'Solution Moves (Path length: {path_length})')
    moves_ax.set_ylabel('Count')

    plt.tight_layout()
    plt.show()

    print(f"\nMaze shape: {maze_img.shape}")
    print(f"Solution: {solution[:path_length].tolist()}")

## 5. Collect Internal States

Run the model with `track=True` to capture:
- **Pre-activations** ($a_t$): Input to NLMs
- **Post-activations** ($z_t$): Output of NLMs - **KEY for our analysis!**
- **Synchronization** ($S_t$): Pairwise neuron correlations
- **Attention**: Where the model looks in the input

In [None]:
def collect_internal_states(model, dataloader, device, max_batches=None):
    """Run the model with tracking enabled and gather per-batch internals.

    Args:
        model: Callable as ``model(inputs, track=True)`` returning
            ``(predictions, certainties, (synch_out, synch_action),
            pre_act, post_act, attention)``.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to before the forward pass.
        max_batches: Stop after this many batches; a falsy value (None or 0)
            means no limit.

    Returns:
        Dict mapping each tracked name to a list with one entry per batch.
        Tracked values are stored exactly as the model returned them;
        predictions, mazes and solutions are stored as NumPy arrays.
    """
    # Layouts below are taken from the notebook's own notes — TODO confirm
    # against the model: activations (T, B, D), synch_out (T, B, S),
    # attention (T, B, H, Hf, Wf), predictions (B, out_dims, T),
    # mazes (B, H, W, 3), solutions (B, route_len).
    tracked_names = (
        'pre_activations',
        'post_activations',
        'synch_out',
        'attention',
        'predictions',
        'mazes',
        'solutions',
    )
    all_states = {tracked_name: [] for tracked_name in tracked_names}

    model.eval()
    with torch.no_grad():
        progress = tqdm(dataloader, desc="Collecting states")
        for batch_idx, (inputs, targets) in enumerate(progress):
            if max_batches and batch_idx >= max_batches:
                break

            inputs = inputs.to(device)

            # Forward pass with internal-state tracking switched on.
            tracked = model(inputs, track=True)
            predictions, _certainties, (synch_out, _synch_action), pre_act, post_act, attention = tracked

            batch_record = {
                'pre_activations': pre_act,
                'post_activations': post_act,
                'synch_out': synch_out,
                'attention': attention,
                'predictions': predictions.cpu().numpy(),
                # Back to [0, 1] with channels last, ready for plotting.
                'mazes': ((inputs.cpu().numpy() + 1) / 2).transpose(0, 2, 3, 1),
                'solutions': targets.numpy(),
            }
            for tracked_name, value in batch_record.items():
                all_states[tracked_name].append(value)

    return all_states

if model is None or dataloader is None:
    print("Model or data not loaded. Skipping state collection.")
else:
    # Five batches are enough for the exploratory plots that follow.
    states = collect_internal_states(model, dataloader, device, max_batches=5)

    # Report the per-batch shape of every tracked quantity that was collected.
    print("\nCollected State Shapes:")
    for key in ('pre_activations', 'post_activations', 'synch_out', 'attention'):
        if not states[key]:
            continue
        print(f"  {key}: {states[key][0].shape} (per batch)")

## 6. Visualize Neuron Dynamics Across Ticks

How do neuron activations evolve as the model "thinks" through the maze?

In [None]:
if 'states' not in dir() or not states['post_activations']:
    print("No states collected yet.")
else:
    # First collected batch, first maze in it.
    post_acts = states['post_activations'][0]  # (T, B, D) per the notebook's notes
    T, B, D = post_acts.shape

    sample_idx = 0
    activations = post_acts[:, sample_idx, :]  # (T, D)

    # Rank neurons by how much their activation varies over the ticks and
    # keep the ten most variable ones (ascending order of variance).
    neuron_variance = np.var(activations, axis=0)
    top_neurons = np.argsort(neuron_variance)[-10:]

    fig, axes = plt.subplots(2, 1, figsize=(14, 10))
    heatmap_ax, trace_ax = axes

    # Upper panel: heatmap with one row per selected neuron.
    im = heatmap_ax.imshow(activations[:, top_neurons].T, aspect='auto', cmap='RdBu_r')
    heatmap_ax.set_xlabel('Tick (t)')
    heatmap_ax.set_ylabel('Neuron Index')
    heatmap_ax.set_title('Top 10 Most Variable Neurons Over Time')
    heatmap_ax.set_yticks(range(10))
    heatmap_ax.set_yticklabels(top_neurons)
    plt.colorbar(im, ax=heatmap_ax, label='Activation')

    # Lower panel: activation traces of the first five selected neurons.
    for neuron_idx in top_neurons[:5]:
        trace_ax.plot(activations[:, neuron_idx], label=f'Neuron {neuron_idx}', alpha=0.8)

    trace_ax.set_xlabel('Tick (t)')
    trace_ax.set_ylabel('Activation')
    trace_ax.set_title('Activation Traces of Top 5 Neurons')
    trace_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

    print(f"\nModel has {D} neurons (d_model)")
    print(f"Running for {T} internal ticks")

## 7. Place Cell Analysis

Core question: Do specific neurons fire at specific maze positions?

We'll compute a **spatial information score** for each neuron:
- High score = neuron fires consistently at specific (x,y) locations
- Low score = neuron fires uniformly regardless of position

In [None]:
def trace_path(maze_img, solution):
    """Replay a move sequence from the maze's start pixel.

    Args:
        maze_img: Array indexed as ``[row, col, channel]`` with at least three
            channels; the start is the first pixel (row-major) whose channels
            satisfy R > 0.9, G < 0.1, B < 0.1.
        solution: Iterable of move codes: 0=Up, 1=Down, 2=Left, 3=Right.
            Any other code (including 4=Wait) leaves the position unchanged.

    Returns:
        List of ``(row, col)`` tuples: the start followed by one entry per
        move, or an empty list when no start pixel is found. Positions are not
        checked against walls or the image bounds.
    """
    red = maze_img[:, :, 0]
    green = maze_img[:, :, 1]
    blue = maze_img[:, :, 2]
    start_coords = np.argwhere((red > 0.9) & (green < 0.1) & (blue < 0.1))

    if start_coords.shape[0] == 0:
        return []

    row, col = start_coords[0]

    # (row, col) offset of each directional move; everything else is a no-op.
    step_for_move = {
        0: (-1, 0),  # Up
        1: (1, 0),   # Down
        2: (0, -1),  # Left
        3: (0, 1),   # Right
    }

    visited = [(row, col)]
    for move in solution:
        d_row, d_col = step_for_move.get(int(move), (0, 0))
        row, col = row + d_row, col + d_col
        visited.append((row, col))

    return visited

if 'states' in dir() and states['post_activations']:
    # position -> neuron index -> list of activations observed at that position
    position_neuron_activations = defaultdict(lambda: defaultdict(list))

    batches = zip(states['post_activations'], states['mazes'], states['solutions'])
    for post_acts, mazes, solutions in batches:
        T, B, D = post_acts.shape  # post_acts is (T, B, D)

        for sample_idx in range(B):
            positions = trace_path(mazes[sample_idx], solutions[sample_idx])
            if not positions:
                continue
            n_positions = len(positions)

            for t in range(T):
                # Spread the T ticks evenly over the traced route: tick t is
                # paired with the route entry at the same relative offset.
                pos_idx = min(t * n_positions // T, n_positions - 1)
                acts_here = position_neuron_activations[positions[pos_idx]]

                # Record every neuron's activation under this position.
                for neuron_idx in range(D):
                    acts_here[neuron_idx].append(post_acts[t, sample_idx, neuron_idx])

    print(f"Collected activations at {len(position_neuron_activations)} unique positions")
else:
    print("No states collected yet.")

In [None]:
if 'position_neuron_activations' in dir() and position_neuron_activations:
    # Neuron count, read from the first position's per-neuron table.
    D = len(next(iter(position_neuron_activations.values())))

    neuron_scores = {}
    neuron_peak_positions = {}

    for neuron_idx in tqdm(range(D), desc="Computing place cell scores"):
        pos_means = {}
        pos_vars = {}

        # Mean and variance of this neuron's activation at each position.
        for pos, neuron_acts in position_neuron_activations.items():
            if neuron_idx not in neuron_acts:
                continue
            acts = neuron_acts[neuron_idx]
            pos_means[pos] = np.mean(acts)
            pos_vars[pos] = 0 if len(acts) <= 1 else np.var(acts)

        if not pos_means:
            continue

        # Score = variance of the per-position means divided by the mean
        # within-position variance (epsilon keeps the ratio finite).
        spatial_variance = np.var(list(pos_means.values()))
        within_variance = np.mean(list(pos_vars.values())) + 1e-6

        score = spatial_variance / within_variance
        peak_pos = max(pos_means, key=pos_means.get)

        neuron_scores[neuron_idx] = score
        neuron_peak_positions[neuron_idx] = peak_pos

    # Highest score first.
    ranked_neurons = sorted(neuron_scores, key=neuron_scores.get, reverse=True)

    print(f"\nTop 10 Place Cell Neurons:")
    for i, n in enumerate(ranked_neurons[:10]):
        print(f"  {i+1}. Neuron {n}: score={neuron_scores[n]:.4f}, peak={neuron_peak_positions[n]}")
else:
    print("No position data collected yet.")

## 8. Visualize Place Fields

Create heatmaps showing where each top neuron fires strongest.

In [None]:
if 'ranked_neurons' not in dir() or not ranked_neurons:
    print("No ranking available yet.")
else:
    # Side length of the square place-field grid.
    # NOTE(review): 39 / 99 are hard-coded; positions outside the grid are
    # silently dropped below — confirm these match the maze image sizes.
    maze_size = 99 if CONFIG['maze_size'] not in ['small', 'medium'] else 39

    # One panel per top-ranked neuron, four panels per row.
    top_n = min(12, len(ranked_neurons))
    n_cols = 4
    n_rows = -(-top_n // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows))
    axes = axes.flatten()

    for ax, neuron_idx in zip(axes, ranked_neurons[:top_n]):
        # Mean activation per visited cell; unvisited cells get masked out.
        place_field = np.zeros((maze_size, maze_size))
        visited = np.zeros((maze_size, maze_size), dtype=bool)

        for pos, neuron_acts in position_neuron_activations.items():
            if neuron_idx not in neuron_acts:
                continue
            row, col = pos
            inside = 0 <= row < maze_size and 0 <= col < maze_size
            if inside:
                place_field[row, col] = np.mean(neuron_acts[neuron_idx])
                visited[row, col] = True

        place_field = np.ma.masked_where(~visited, place_field)

        im = ax.imshow(place_field, cmap='hot', aspect='equal')
        ax.set_title(f"Neuron {neuron_idx}\nScore: {neuron_scores[neuron_idx]:.2f}")
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Panels beyond the last plotted neuron stay hidden.
    for unused_ax in axes[top_n:]:
        unused_ax.set_visible(False)

    plt.suptitle('Place Fields: Neuron Activation vs. Maze Position', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

## 9. Synchronization vs Activations: Where is Position Encoded?

**Key Experiment**: Compare linear probe decoding of (x,y) position from:
- **Z_t**: Raw neuron activations (2048 dims)
- **S_out**: Synchronization output (2080 dims) - captures neuron *correlations*

This tests the CTM paper's core claim that the Synchronization Matrix is the key representation.

## 10. Activation Patching: Causal Test of Position Neurons

**Question**: Do the neurons we identified as "position-encoding" actually cause behavioral changes?

**Method**: 
1. Identify "position neurons" (high activation variance across positions)
2. Patch their activations mid-inference (at internal tick $t = 5$)
3. Compare behavior change to random neuron baseline

**Key insight**: Comparing to random neurons is critical - without this baseline, we can't know if any effect is meaningful.

In [None]:
# Visualize probe results.
# Runs only once the probing cell has defined all three R² scores; the scatter
# panels also read y_test_* / y_pred_* produced by that same cell.
if 'r2_z' in dir() and 'r2_s' in dir() and 'r2_combined' in dir():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Bar chart comparison
    representations = ['Z_t\n(Activations)', 'S_out\n(Synchronization)', 'Combined']
    r2_scores = [r2_z, r2_s, r2_combined]
    colors = ['#FF6B6B', '#4ECDC4', '#96CEB4']

    bars = axes[0].bar(representations, r2_scores, color=colors, edgecolor='black', linewidth=1.5)
    axes[0].set_ylabel('R² Score', fontsize=12)
    axes[0].set_title('Position Decoding: Synchronization vs Activations', fontsize=12)
    axes[0].set_ylim(0, 1)

    # Add value labels
    for bar, score in zip(bars, r2_scores):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{score:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

    # Scatter plots of true vs. predicted coordinates for each representation.
    # Positions come from trace_path as (row, col), so column 0 is the row
    # (vertical) coordinate and column 1 the column (horizontal) coordinate.
    scatter_panels = [
        (axes[1], y_test_z, y_pred_z, f'Z_t Predictions (R²={r2_z:.3f})'),
        (axes[2], y_test_s, y_pred_s, f'S_out Predictions (R²={r2_s:.3f})'),
    ]
    for ax, y_true, y_hat, panel_title in scatter_panels:
        ax.scatter(y_true[:, 0], y_hat[:, 0], alpha=0.5, label='Row (y)', c='blue')
        ax.scatter(y_true[:, 1], y_hat[:, 1], alpha=0.5, label='Col (x)', c='red')
        # Identity line; extend it past 60 when the maze is larger than that.
        diag_end = max(60, float(np.max(y_true)))
        ax.plot([0, diag_end], [0, diag_end], 'k--', alpha=0.5)
        ax.set_xlabel('True Position')
        ax.set_ylabel('Predicted Position')
        ax.set_title(panel_title)
        ax.legend()

    plt.tight_layout()
    # savefig does not create directories, so make sure the target exists.
    figure_dir = PROJECT_ROOT / 'experiments/interpretability/outputs'
    os.makedirs(figure_dir, exist_ok=True)
    plt.savefig(str(figure_dir / 'sync_vs_activations.png'), dpi=150)
    plt.show()

    # Only announce the hypothesised finding when the numbers support it.
    print("\n" + "="*50)
    if r2_s > r2_z:
        print("KEY FINDING: Position is encoded in CORRELATIONS")
    else:
        print("RESULT: S_out does NOT decode position better than Z_t")
    print("="*50)
    print(f"S_out (synchronization) R² = {r2_s:.3f}")
    print(f"Z_t (activations) R² = {r2_z:.3f}")
    # A relative change is only meaningful against a positive baseline.
    if r2_z > 0:
        print(f"Improvement: {((r2_s - r2_z) / r2_z * 100):.1f}%")
    else:
        print(f"Improvement: n/a (Z_t R² is not positive; absolute difference = {r2_s - r2_z:.3f})")

In [None]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

def _to_numpy(value):
    """Return ``value`` as a NumPy array, accepting torch tensors or array-likes."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def collect_probe_data(model, dataloader, device, num_samples=500):
    """Collect final-tick Z_t and S_out with a position label for probing.

    Args:
        model: Callable as ``model(inputs, track=True)`` returning
            ``(predictions, certainties, (synch_out, synch_action),
            pre_act, post_act, attention)``. The tracked states may be torch
            tensors or NumPy arrays; other cells of this notebook treat them
            as NumPy arrays, so both are accepted here.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to before the forward pass.
        num_samples: Maximum number of mazes to collect.

    Returns:
        Tuple ``(z_t, s_out, positions)`` of NumPy arrays with one row per
        collected maze; ``positions`` holds ``(row, col)`` labels. Mazes for
        which ``trace_path`` finds no start pixel are skipped.
    """
    z_t_list = []
    s_out_list = []
    positions_list = []

    model.eval()
    samples_collected = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Collecting probe data"):
            if samples_collected >= num_samples:
                break

            inputs = inputs.to(device)
            B = inputs.shape[0]

            # Get internal states
            results = model(inputs, track=True)
            predictions, certainties, (synch_out, synch_action), pre_act, post_act, attention = results

            # Final-tick slices: tick is the leading axis, so [-1] gives (B, D)
            # and (B, S). Convert once per batch, whatever type was returned.
            z_t = _to_numpy(post_act[-1])
            s_out = _to_numpy(synch_out[-1])

            # Maze images back in [0, 1], channels last, for start-pixel lookup.
            mazes_np = ((inputs.cpu().numpy() + 1) / 2).transpose(0, 2, 3, 1)
            solutions_np = targets.numpy()

            for i in range(B):
                if samples_collected >= num_samples:
                    break

                positions = trace_path(mazes_np[i], solutions_np[i])
                if not positions:
                    continue

                # Use the middle entry of the traced route as the label.
                # NOTE(review): trace_path also emits one entry per padded
                # Wait move, so for short routes this index may already be
                # the end position — confirm that is the intended label.
                mid_pos = positions[len(positions) // 2]

                z_t_list.append(z_t[i])
                s_out_list.append(s_out[i])
                positions_list.append(mid_pos)
                samples_collected += 1

    return np.array(z_t_list), np.array(s_out_list), np.array(positions_list)

def train_position_probe(X, y, name=""):
    """Fit a Ridge probe on an 80/20 split and score it on the held-out part.

    Args:
        X: Feature matrix, one row per sample.
        y: Regression targets aligned with ``X``.
        name: Label used as the prefix of the printed score line.

    Returns:
        Tuple ``(r2, probe, y_test, y_pred)``: held-out R², the fitted probe,
        and the held-out targets with their predictions.
    """
    # Fixed seed so repeated calls split identically.
    split = train_test_split(X, y, test_size=0.2, random_state=42)
    X_fit, X_held_out, y_fit, y_held_out = split

    probe = Ridge(alpha=1.0)
    probe.fit(X_fit, y_fit)

    held_out_pred = probe.predict(X_held_out)
    r2 = r2_score(y_held_out, held_out_pred)

    print(f"{name} R² = {r2:.4f}")
    return r2, probe, y_held_out, held_out_pred

# Run the sync matrix vs activations comparison
if model is not None and dataloader is not None:
    print("Collecting data for probing experiment...")
    # NOTE(review): the loader is capped at CONFIG['num_samples'] mazes, so
    # fewer than the 200 requested here may actually be collected.
    z_t_data, s_out_data, positions = collect_probe_data(model, dataloader, device, num_samples=200)

    print(f"\nCollected {len(positions)} samples")
    print(f"Z_t shape: {z_t_data.shape}")
    print(f"S_out shape: {s_out_data.shape}")
    print(f"Positions shape: {positions.shape}")

    # Train probes
    print("\n" + "="*50)
    print("Position Decoding Results:")
    print("="*50)

    r2_z, probe_z, y_test_z, y_pred_z = train_position_probe(z_t_data, positions, "Z_t (activations)")
    r2_s, probe_s, y_test_s, y_pred_s = train_position_probe(s_out_data, positions, "S_out (synchronization)")

    # Combined: both representations concatenated feature-wise.
    combined = np.concatenate([z_t_data, s_out_data], axis=1)
    r2_combined, _, _, _ = train_position_probe(combined, positions, "Combined")

    # A relative improvement is only meaningful against a positive baseline;
    # dividing by a zero or negative R² would fail or flip the sign.
    if r2_z > 0:
        print(f"\nImprovement: S_out is {((r2_s - r2_z) / r2_z * 100):.1f}% better than Z_t")
    else:
        print(f"\nImprovement: n/a (Z_t R² = {r2_z:.4f} is not positive; absolute difference = {r2_s - r2_z:.4f})")
else:
    print("Model or data not loaded.")

In [None]:
# Persist the place-cell ranking, but only when the analysis produced one.
if 'ranked_neurons' in dir() and ranked_neurons:
    output_dir = str(PROJECT_ROOT / 'experiments/interpretability/outputs')
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the two dicts are passed to np.savez as plain dict
    # objects; reading them back presumably needs allow_pickle=True — confirm.
    results_to_save = {
        'ranked_neurons': ranked_neurons,
        'neuron_scores': dict(neuron_scores),
        'neuron_peak_positions': dict(neuron_peak_positions),
    }
    np.savez(f"{output_dir}/notebook_results.npz", **results_to_save)

    print(f"Results saved to: {output_dir}/notebook_results.npz")