# 2. Geometry of Thought: Mechanistic Interpretability

**Objective:** Visualize how neural networks construct complex functions from simple pieces.

Here we shift from *optimizing* loss to *interpreting* structure. We will:
1. **Decompose a Neural Network:** See the individual ReLU "basis functions" that sum to approximate a sine wave.
2. **Analyze Topology:** Compare how "width" (polynomial feature expansion) and "depth" (stacked layers) solve the spiral classification problem.

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath('..'))

from src.torch_engine import ExplainableReLUNet, SpiralClassifier
from src.data_loader import generate_sine_wave, generate_spiral_data, PolynomialFeatureExpander
from src.visualization import plot_basis_mechanisms, plot_decision_boundary
from src.utils import set_seed

set_seed(2024)

## Part 1: The Sum of ReLUs

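The Universal Approximation Theorem states that a network with a single hidden layer of sufficient width and a non-polynomial activation such as ReLU can approximate any continuous function on a compact domain to arbitrary accuracy. But *how*?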

For a network with one hidden layer of $N$ neurons and a scalar input $x$ (ignoring the output bias), we can write the output as:
$$ f(x) = \sum_{i=1}^{N} w_{out}^{(i)} \cdot \text{ReLU}(w_{in}^{(i)} x + b^{(i)}) $$

Each neuron $i$ learns a **Scale** ($w_{out}^{(i)}$), a **Slope** ($w_{in}^{(i)}$), and a **Kink**: the point $x = -b^{(i)}/w_{in}^{(i)}$ where its ReLU switches on. Let's visualize this.
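To make the roles of scale, slope, and kink concrete, here is a sketch that builds the sum by hand instead of training it. It places 15 kinks at evenly spaced knots $k_i$ on $[0, 2\pi]$, fixes $w_{in}^{(i)} = 1$ and $b^{(i)} = -k_i$, and sets each $w_{out}^{(i)}$ to the change in slope at its knot, plus a constant offset playing the role of an output bias. The result is the piecewise-linear interpolant of $\sin$. The helper names `build_relu_sum` and `evaluate` are hypothetical and not part of `src`.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def build_relu_sum(target, lo, hi, n_neurons):
    """Choose (kinks, w_out, offset) so that
    offset + sum_i w_out[i] * ReLU(x - kinks[i]) linearly interpolates
    `target` at n_neurons + 1 evenly spaced knots on [lo, hi]."""
    knots = np.linspace(lo, hi, n_neurons + 1)
    values = target(knots)
    slopes = np.diff(values) / np.diff(knots)   # slope of each linear segment
    # Neuron i switches on at knots[i]; its scale is the slope change it adds.
    w_out = np.diff(slopes, prepend=0.0)
    return knots[:-1], w_out, values[0]

def evaluate(x, kinks, w_out, offset):
    # Column i holds neuron i's contribution w_out[i] * ReLU(x - kinks[i]).
    contributions = w_out * relu(x[:, None] - kinks[None, :])
    return offset + contributions.sum(axis=1), contributions

x = np.linspace(0, 2 * np.pi, 500)
kinks, w_out, offset = build_relu_sum(np.sin, 0.0, 2 * np.pi, n_neurons=15)
approx, contributions = evaluate(x, kinks, w_out, offset)
print(contributions.shape)  # (500, 15)
print(f"max |error| = {np.max(np.abs(approx - np.sin(x))):.4f}")
```

The trained network below has the same form, but it learns the slopes, kink positions, and scales by gradient descent instead of placing the kinks on a fixed grid.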

In [None]:
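# 1. Data: reshape inputs and targets to column vectors of shape (N, 1)
x_np, y_np = generate_sine_wave(n_samples=200)
x_train = torch.tensor(x_np, dtype=torch.float32).reshape(-1, 1)
y_train = torch.tensor(y_np, dtype=torch.float32).reshape(-1, 1)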

# 2. Model (ExplainableReLUNet)
basis_count = 15
model = ExplainableReLUNet(hidden_dim=basis_count)
optimizer = optim.Adam(model.parameters(), lr=0.02)
criterion = nn.MSELoss()

# 3. Train
print(f"Training with {basis_count} basis functions...")
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(x_train)
    loss = criterion(y_pred, y_train)
    loss.backward()
    optimizer.step()
    
    if epoch % 500 == 0:
        print(f"Epoch {epoch}: Loss {loss.item():.5f}")

# 4. Visualize Decomposition
model.eval()
with torch.no_grad():
    x_plot_np = np.linspace(0, 2*np.pi, 500).reshape(-1, 1)
    x_plot = torch.FloatTensor(x_plot_np)
    
    # Get component-wise breakdown
    final_pred, contributions = model.forward_decomposed(x_plot)

plot_basis_mechanisms(
    x_plot_np, 
    contributions.numpy(), 
    final_pred.numpy(), 
    np.sin(x_plot_np)
)

### Interpretation
In the plot above:
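1. **Top Panel:** Each line is one hidden neuron's contribution, $w_{out}^{(i)} \cdot \text{ReLU}(w_{in}^{(i)} x + b^{(i)})$. Every one is a simple "bent line": zero on one side of its kink and linear on the other.
2. **Bottom Panel:** The red line is the *sum* of the neuron contributions in the top panel. Together they reconstruct the sine wave as a piecewise-linear curve.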
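`ExplainableReLUNet.forward_decomposed` lives in `src/torch_engine.py`, which is not shown here. A minimal sketch of how such a method could work, assuming one hidden layer with scalar input and output plus an output bias $c$ (which the formula above omits). `DecomposableReLUNet` is a hypothetical name, not the project's class:

```python
import math

import torch
import torch.nn as nn

class DecomposableReLUNet(nn.Module):
    """Hypothetical net: f(x) = sum_i w_out_i * ReLU(w_in_i * x + b_i) + c."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden = nn.Linear(1, hidden_dim)  # per-neuron slope w_in and bias b
        self.out = nn.Linear(hidden_dim, 1)     # per-neuron scale w_out and output bias c

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out(torch.relu(self.hidden(x)))

    def forward_decomposed(self, x: torch.Tensor):
        activations = torch.relu(self.hidden(x))        # (N, hidden_dim)
        # out.weight has shape (1, hidden_dim) and broadcasts over the N rows,
        # scaling each neuron's column by its w_out.
        contributions = activations * self.out.weight   # (N, hidden_dim)
        prediction = contributions.sum(dim=1, keepdim=True) + self.out.bias
        return prediction, contributions

torch.manual_seed(0)
net = DecomposableReLUNet(hidden_dim=15)
x = torch.linspace(0, 2 * math.pi, 500).reshape(-1, 1)
with torch.no_grad():
    pred, parts = net.forward_decomposed(x)
    # The decomposition must add back up to the ordinary forward pass.
    assert torch.allclose(pred, net(x), atol=1e-5)
print(parts.shape)  # torch.Size([500, 15])
```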

## Part 2: Depth vs. Width (The Spiral)

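The spiral dataset defeats any linear classifier applied to the raw inputs: its interleaved arms are not linearly separable in the 2D plane. There are two ways around this:
1. **Width (Feature Engineering):** Expand the input space with polynomial features ($x^2, xy, y^2, \dots$), then fit a linear classifier on them.
2. **Depth (Representation Learning):** Stack nonlinear layers that learn their own features, "folding" the space until the classes become separable.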
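`PolynomialFeatureExpander` is defined in `src/data_loader.py` and not shown. As an assumption about what it computes, a common choice is every monomial $x^i y^j$ with $1 \le i + j \le d$. The hypothetical NumPy helper `polynomial_features` below builds that set; for $d = 5$ it yields $2 + 3 + 4 + 5 + 6 = 20$ features. The real expander may also include a constant column or order the columns differently.

```python
import numpy as np

def polynomial_features(X, degree):
    """Hypothetical expander: all monomials x**i * y**j with 1 <= i + j <= degree.
    X has shape (n, 2). The constant term is left out because a linear
    layer already has a bias."""
    x, y = X[:, 0], X[:, 1]
    columns = []
    for total in range(1, degree + 1):
        for j in range(total + 1):          # j = power of y, total - j = power of x
            columns.append(x ** (total - j) * y ** j)
    return np.stack(columns, axis=1)

X = np.array([[1.0, 2.0], [0.5, -1.0]])
features = polynomial_features(X, degree=5)
print(features.shape)   # (2, 20)
print(features[0, :5])  # x, y, x^2, xy, y^2 at (1, 2): [1. 2. 1. 2. 4.]
```

A linear classifier on these columns draws decision boundaries that are polynomial curves of degree at most 5 in the original plane, which is the sense in which "width" buys expressiveness here.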

In [None]:
# Generate Data
X_spiral, y_spiral = generate_spiral_data(n_points=1000, K=3, sigma=0.2)
X_tensor = torch.FloatTensor(X_spiral)
y_tensor = torch.LongTensor(y_spiral)

print("Data Shape:", X_tensor.shape)
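The implementation of `generate_spiral_data` is not shown here. A common construction, popularized by Stanford's CS231n course notes, sweeps each class along its own arm: the radius grows linearly from 0 to 1 while the angle advances, with Gaussian noise of scale `sigma` added to the angle. The sketch below, a hypothetical `make_spiral`, follows that recipe. Whether the project's `n_points` counts points per class or in total is not visible here; the sketch treats it as per class.

```python
import numpy as np

def make_spiral(n_per_class, n_classes, sigma, seed=0):
    """Hypothetical spiral generator: one noisy arm per class."""
    rng = np.random.default_rng(seed)
    X = np.zeros((n_per_class * n_classes, 2))
    y = np.zeros(n_per_class * n_classes, dtype=np.int64)
    for k in range(n_classes):
        idx = slice(k * n_per_class, (k + 1) * n_per_class)
        radius = np.linspace(0.0, 1.0, n_per_class)
        # Each arm covers 4 radians of angle, offset by 4 radians per class.
        theta = np.linspace(4.0 * k, 4.0 * (k + 1), n_per_class)
        theta = theta + rng.normal(scale=sigma, size=n_per_class)
        X[idx] = np.column_stack([radius * np.sin(theta), radius * np.cos(theta)])
        y[idx] = k
    return X, y

X, y = make_spiral(n_per_class=100, n_classes=3, sigma=0.2)
print(X.shape, np.bincount(y))  # (300, 2) [100 100 100]
```

Because the noise perturbs only the angle, every point stays within the unit disk, and the arms interleave so that no single straight line separates any class from the others.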

### Experiment A: The Wide Approach (Polynomials)

In [None]:
# 1. Expand Features
poly = PolynomialFeatureExpander(degree=5)
X_poly = poly.transform(X_tensor)
input_dim = X_poly.shape[1]
print(f"Expanded 2D -> {input_dim}D features")

# 2. Linear Classifier on High-Dim Features
poly_model = nn.Linear(input_dim, 3)
optimizer = optim.Adam(poly_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 3. Train
for epoch in range(1000):
    optimizer.zero_grad()
    out = poly_model(X_poly)
    loss = criterion(out, y_tensor)
    loss.backward()
    optimizer.step()

# 4. Viz helper: the plotter passes raw 2D points, so expand them before the linear layer
class PolyWrapper(nn.Module):
    def __init__(self, model, expander):
        super().__init__()
        self.model = model
        self.expander = expander
    def forward(self, x):
        x_poly = self.expander.transform(x)
        return self.model(x_poly)

print("Wide Model Decision Boundary:")
plot_decision_boundary(PolyWrapper(poly_model, poly), X_spiral, y_spiral)
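`plot_decision_boundary` comes from `src/visualization.py` and is not shown. The usual technique, sketched here as an assumption about what it does, is to evaluate the model on a dense grid covering the data and color each grid cell by its highest-scoring class. This also explains why `PolyWrapper` is needed: the plotting code feeds raw 2D grid points to the model, so the wrapper must expand them first. `decision_grid` is a hypothetical NumPy-only helper; its output can be drawn with `plt.contourf(xx, yy, labels)`.

```python
import numpy as np

def decision_grid(predict_logits, X, resolution=200, margin=0.5):
    """Hypothetical helper: predicted class on a grid covering X (shape (n, 2)).
    predict_logits maps an (m, 2) array to (m, n_classes) scores."""
    x_min, y_min = X.min(axis=0) - margin
    x_max, y_max = X.max(axis=0) + margin
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))
    grid = np.column_stack([xx.ravel(), yy.ravel()])  # (resolution**2, 2)
    labels = predict_logits(grid).argmax(axis=1).reshape(xx.shape)
    return xx, yy, labels

# Toy scorer with scores [-x, x]: class 0 left of x = 0, class 1 right of it.
W = np.array([[-1.0, 1.0], [0.0, 0.0]])
X = np.array([[-1.0, -1.0], [1.0, 1.0]])
xx, yy, labels = decision_grid(lambda P: P @ W, X, resolution=50)
print(labels.shape)  # (50, 50)
```

For a PyTorch model, `predict_logits` could be `lambda P: model(torch.as_tensor(P, dtype=torch.float32)).detach().numpy()`.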

### Experiment B: The Deep Approach (MLP)
Notice how the deep model produces a smoother decision boundary without any manual feature math: it learns its own features directly from the raw 2D coordinates.
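One way to see what "folding" buys is a 1D toy. Two ReLUs make a tent map on $[0, 1]$, $t(x) = 2\,\text{ReLU}(x) - 4\,\text{ReLU}(x - 0.5)$, which folds the interval in half. Composing it $k$ times, which is a depth-$k$ network with only two hidden units per layer, produces $2^k$ linear pieces, whereas a single layer of $N$ ReLUs on a scalar input produces at most $N + 1$. The sketch below counts the pieces numerically. It illustrates a standard depth-separation construction; it is not the mechanism inside `SpiralClassifier`, and the helper names are hypothetical.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def tent(x):
    # Two ReLU units: slope 2 on [0, 0.5], slope -2 on [0.5, 1].
    return 2.0 * relu(x) - 4.0 * relu(x - 0.5)

def folded(x, depth):
    for _ in range(depth):
        x = tent(x)  # each layer folds [0, 1] back onto itself
    return x

def count_monotone_pieces(f, n_grid=2**12):
    # Dyadic grid points are exact in floating point, so every kink of the
    # composed tents (at multiples of 1/2**depth) lands on a grid point.
    # Adjacent pieces alternate slope sign, so counting sign changes
    # counts the linear pieces.
    x = np.linspace(0.0, 1.0, n_grid + 1)
    slopes = np.diff(f(x)) / np.diff(x)
    return 1 + np.count_nonzero(np.diff(np.sign(slopes)))

for depth in range(1, 6):
    print(depth, count_monotone_pieces(lambda x: folded(x, depth)))
# prints: 1 2, 2 4, 3 8, 4 16, 5 32 (one pair per line)
```

Ten ReLUs spent as one wide layer give at most 11 pieces on a line. Spent as five layers of two, they give 32. The spiral MLP exploits the same compounding in 2D, although its learned folds need not be this regular.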

In [None]:
# 1. Deep Model
deep_model = SpiralClassifier(input_dim=2, hidden_dims=[64, 32], output_dim=3)
optimizer = optim.Adam(deep_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 2. Train
for epoch in range(1000):
    optimizer.zero_grad()
    out = deep_model(X_tensor)
    loss = criterion(out, y_tensor)
    loss.backward()
    optimizer.step()

print("Deep Model Decision Boundary:")
plot_decision_boundary(deep_model, X_spiral, y_spiral)
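A fair width-versus-depth comparison should also account for model size. Assuming `SpiralClassifier` is a plain stack of `nn.Linear` layers (2 → 64 → 32 → 3) and that the degree-5 expansion produces 20 features with no constant column (both are assumptions, since neither module is shown), the parameter counts work out as follows:

```python
def linear_params(n_in, n_out):
    # An nn.Linear layer has an (n_out, n_in) weight matrix plus n_out biases.
    return n_in * n_out + n_out

def mlp_params(dims):
    # Sum over consecutive layer pairs, e.g. [2, 64, 32, 3] -> (2,64), (64,32), (32,3).
    return sum(linear_params(a, b) for a, b in zip(dims[:-1], dims[1:]))

poly_params = linear_params(20, 3)        # assumed 20 polynomial features -> 3 classes
deep_params = mlp_params([2, 64, 32, 3])  # assumed SpiralClassifier layout
print(poly_params, deep_params)  # 63 2371
```

To avoid relying on these assumptions, the same totals can be read off the real models with `sum(p.numel() for p in model.parameters())`.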