# SPON Extensions: Quick Start Demo

This notebook provides an interactive walkthrough of the SPON (Spontaneous Neuron Activation) extension codebase.

**What you'll learn:**
1. How magnitude-based sparsification works (TEAL-style)
2. How SPON biases compensate for information loss
3. How to train SPON biases via KL divergence
4. How to evaluate and compare configurations

**Prerequisites:**
- GPU with at least 8 GB VRAM (for LLaMA-3.2-1B)
- Hugging Face account with access to LLaMA models
- Run `pip install -r requirements.txt` first

In [None]:
# Setup: Add project root to path
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import torch
import matplotlib.pyplot as plt
import numpy as np

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Understanding Magnitude-Based Sparsification

TEAL (Training-Free Activation Sparsity in LLMs) zeros out low-magnitude activations to reduce computation.

**Key idea:** For a sparsity level `s`, zero the fraction `s` of activations with the smallest absolute values and keep the remaining `1 - s`.
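
The idea can be sketched in a few lines. This is a minimal illustration, assuming per-token top-k selection along the hidden dimension; the name `magnitude_sparsify_sketch` is hypothetical, and the project's `magnitude_sparsify` in `src.sparse_forward` may work differently (for example, by using calibrated thresholds rather than an exact top-k).

```python
import torch


def magnitude_sparsify_sketch(x: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero the smallest-magnitude entries along the last dimension.

    Hypothetical stand-in for illustration, not the project's implementation.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    hidden_dim = x.shape[-1]
    num_keep = hidden_dim - int(round(sparsity * hidden_dim))
    # Indices of the num_keep largest |x| values for each token
    _, keep_idx = x.abs().topk(num_keep, dim=-1)
    # Build a 0/1 mask that is 1 only at the kept positions
    mask = torch.zeros_like(x)
    mask.scatter_(-1, keep_idx, 1.0)
    return x * mask


torch.manual_seed(0)
demo = torch.randn(2, 3, 8)                      # (batch, seq_len, hidden_dim)
demo_sparse = magnitude_sparsify_sketch(demo, 0.5)
zero_fraction = (demo_sparse == 0).float().mean().item()
print(f"Fraction of zeros: {zero_fraction:.2f}")
```

Selecting per token means every position keeps the same number of activations, which makes the achieved sparsity exact in this toy setting.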

In [None]:
from src.sparse_forward import magnitude_sparsify, compute_sparsity_stats

# Create sample activations (simulating a hidden state)
torch.manual_seed(42)
x = torch.randn(1, 10, 64)  # (batch=1, seq_len=10, hidden_dim=64)

# Apply different sparsity levels
sparsity_levels = [0.0, 0.25, 0.5, 0.75]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, sparsity in zip(axes, sparsity_levels):
    x_sparse = magnitude_sparsify(x, sparsity)
    stats = compute_sparsity_stats(x_sparse)
    
    # Visualize first token's activations
    ax.bar(range(64), x_sparse[0, 0].numpy(), alpha=0.7)
    ax.set_title(f"Sparsity={sparsity:.0%}\n{stats['sparsity']:.1%} zeros")
    ax.set_xlabel("Hidden dimension")
    ax.set_ylabel("Activation value")
    ax.set_ylim(-3, 3)

plt.tight_layout()
plt.suptitle("Effect of Magnitude-Based Sparsification", y=1.02, fontsize=14)
plt.show()

## 2. The Problem: Information Loss

When we zero activations, we lose information. This degrades the model's output.

Let's quantify this with a simple simulation:

In [None]:
# Simulate a linear layer
torch.manual_seed(42)
W = torch.randn(64, 32)  # Weight matrix (hidden_dim -> output_dim)
x = torch.randn(1, 10, 64)  # Input activations

# Dense output (ground truth)
y_dense = x @ W

# Sparse outputs at different sparsity levels
errors = []
sparsities = np.linspace(0, 0.9, 19)  # 0.05 steps, so index 10 is 50% and index 15 is 75%

for s in sparsities:
    x_sparse = magnitude_sparsify(x, s)
    y_sparse = x_sparse @ W
    
    # L2 error
    error = torch.norm(y_dense - y_sparse, p=2, dim=-1).mean().item()
    errors.append(error)

plt.figure(figsize=(10, 5))
plt.plot(sparsities * 100, errors, 'b-o', linewidth=2, markersize=6)
plt.xlabel("Sparsity (% zeros)", fontsize=12)
plt.ylabel("L2 Error (vs dense output)", fontsize=12)
plt.title("Information Loss from Sparsification", fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='g', linestyle='--', label='Dense (target)')
plt.legend()
plt.show()

print(f"At 50% sparsity: L2 error = {errors[10]:.3f}")
print(f"At 75% sparsity: L2 error = {errors[15]:.3f}")

## 3. The Solution: SPON Biases

SPON adds **learned constant biases** to compensate for information loss.

**Key insight:** The average effect of zeroing activations can be approximated by a constant bias!

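$$Y_{\text{sparse}} = W \cdot S(X) + b_{\text{spon}}$$

Here $S(\cdot)$ is the magnitude sparsifier from Section 1, and $b_{\text{spon}}$ is trained to minimize $\text{KL}(P_{\text{dense}} \parallel P_{\text{sparse+spon}})$, the divergence between the dense model's output distribution and that of the sparsified model with SPON biases.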
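
The KL objective can be tried on a toy problem before touching a real model. The sketch below is an assumption-laden illustration, not the `SPONTrainer` code: a single random projection stands in for the network, its outputs are treated as logits, the sparsifier is a simple per-row median threshold, and all names (`acts`, `proj`, `kl_to_dense`) are made up for this example. Only the bias is trained.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
acts = torch.randn(256, 64) + 0.5        # toy activations with a nonzero mean
proj = torch.randn(64, 32) / 8.0         # toy projection to 32 "logits"

# Toy sparsifier: zero entries whose |value| is at or below the per-row median
threshold = acts.abs().median(dim=-1, keepdim=True).values
acts_sparse = torch.where(acts.abs() > threshold, acts, torch.zeros_like(acts))

# Target distribution P_dense (fixed, no gradient needed)
dense_log_probs = F.log_softmax(acts @ proj, dim=-1)

b_spon = torch.zeros(32, requires_grad=True)     # the only trainable parameter
optimizer = torch.optim.SGD([b_spon], lr=1.0)    # plain full-batch gradient descent


def kl_to_dense(bias: torch.Tensor) -> torch.Tensor:
    sparse_log_probs = F.log_softmax(acts_sparse @ proj + bias, dim=-1)
    # F.kl_div(input, target) computes KL(target || input); both are log-probs here
    return F.kl_div(sparse_log_probs, dense_log_probs,
                    log_target=True, reduction="batchmean")


kl_before = kl_to_dense(b_spon).item()
for _ in range(200):
    optimizer.zero_grad()
    loss = kl_to_dense(b_spon)
    loss.backward()
    optimizer.step()
kl_after = kl_to_dense(b_spon).item()

print(f"KL before training: {kl_before:.4f}")
print(f"KL after training:  {kl_after:.4f}")
```

Because the bias is added to the logits, this loss is convex in `b_spon`, so plain gradient descent is enough here. In a real transformer the bias sits inside the network and the problem is no longer this simple.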

In [None]:
# Simple demonstration: optimal bias = mean of lost activation contribution
sparsity = 0.5
x_sparse = magnitude_sparsify(x, sparsity)
x_lost = x - x_sparse  # What we zeroed out

# The "lost" contribution to output
y_lost = x_lost @ W

# Best constant bias in the mean-squared-error sense (input-independent).
# SPON training learns a bias with the same role, but by minimizing KL divergence.
b_spon_optimal = y_lost.mean(dim=(0, 1))  # Average over batch and sequence

# Compare errors
y_dense = x @ W
y_sparse = x_sparse @ W
y_sparse_spon = y_sparse + b_spon_optimal

error_sparse = torch.norm(y_dense - y_sparse, p=2, dim=-1).mean().item()
error_spon = torch.norm(y_dense - y_sparse_spon, p=2, dim=-1).mean().item()

print(f"L2 Error without SPON: {error_sparse:.3f}")
print(f"L2 Error with SPON:    {error_spon:.3f}")
print(f"Error reduction:       {(1 - error_spon/error_sparse)*100:.1f}%")
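
A short derivation shows why the mean is the natural constant bias, under the assumption that "optimal" means minimizing mean squared error over the $T$ batch-and-sequence positions. The L2 error printed above averages unsquared norms, so for that metric the mean is close to, but not exactly, the minimizer. Writing $r_t$ for the lost contribution `y_lost` at position $t$:

$$\mathcal{L}(b) = \frac{1}{T}\sum_{t=1}^{T} \lVert r_t - b \rVert_2^2, \qquad \nabla_b \mathcal{L} = -\frac{2}{T}\sum_{t=1}^{T} (r_t - b) = 0 \;\Longrightarrow\; b^{*} = \frac{1}{T}\sum_{t=1}^{T} r_t$$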

## 4. Loading a Real Model

Now let's work with an actual LLM. We'll use LLaMA-3.2-1B for fast iteration.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.2-1B"

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Model loaded!")
print(f"  Layers: {len(model.model.layers)}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

## 5. Applying Sparsification with Hooks

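We use **forward hooks** to apply sparsification at run time: the hooks modify activations during the forward pass and leave the model weights untouched, so removing the hooks restores the dense model.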
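
The mechanism can be sketched on a single `nn.Linear` with PyTorch's `register_forward_pre_hook`. This is an illustration under assumptions: the helper `make_sparsify_pre_hook` is hypothetical, it sparsifies the layer's input per token, and the project's `register_sparsification_hooks` may attach its hooks differently.

```python
import torch
import torch.nn as nn


def make_sparsify_pre_hook(sparsity: float):
    """Build a forward pre-hook that zeroes the smallest-|x| inputs of a module."""
    def pre_hook(module: nn.Module, args: tuple):
        (x,) = args
        num_drop = int(round(sparsity * x.shape[-1]))
        if num_drop == 0:
            return None  # returning None leaves the input untouched
        # num_drop-th smallest |x| per token; everything at or below it is zeroed
        threshold = x.abs().kthvalue(num_drop, dim=-1, keepdim=True).values
        return (torch.where(x.abs() > threshold, x, torch.zeros_like(x)),)
    return pre_hook


torch.manual_seed(0)
layer = nn.Linear(16, 4)            # stands in for a projection such as down_proj
inputs = torch.randn(2, 5, 16)
weight_before = layer.weight.detach().clone()

with torch.no_grad():
    dense_out = layer(inputs)
    handle = layer.register_forward_pre_hook(make_sparsify_pre_hook(0.5))
    try:
        sparse_out = layer(inputs)  # the hook sparsifies the input first
    finally:
        handle.remove()             # always detach the hook, even on error
    restored_out = layer(inputs)    # dense behavior is back

print(f"Max |dense - sparse|:   {(dense_out - sparse_out).abs().max().item():.4f}")
print(f"Max |dense - restored|: {(dense_out - restored_out).abs().max().item():.4f}")
```

The `try`/`finally` pattern mirrors the cells in this section: a hook left registered by accident would silently sparsify every later forward pass.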

In [None]:
from src.sparse_forward import register_sparsification_hooks, remove_hooks

# Test input
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Dense forward
with torch.no_grad():
    dense_output = model(**inputs)
    dense_next_token = dense_output.logits[0, -1].argmax()
    print(f"Dense prediction: {tokenizer.decode(dense_next_token)}")

# Sparse forward (50% sparsity)
hooks = register_sparsification_hooks(model, sparsity=0.5, target_modules=["down_proj"])
try:
    with torch.no_grad():
        sparse_output = model(**inputs)
        sparse_next_token = sparse_output.logits[0, -1].argmax()
        print(f"Sparse prediction (50%): {tokenizer.decode(sparse_next_token)}")
finally:
    remove_hooks(hooks)

# Sparse forward (75% sparsity)
hooks = register_sparsification_hooks(model, sparsity=0.75, target_modules=["down_proj"])
try:
    with torch.no_grad():
        sparse_output = model(**inputs)
        sparse_next_token = sparse_output.logits[0, -1].argmax()
        print(f"Sparse prediction (75%): {tokenizer.decode(sparse_next_token)}")
finally:
    remove_hooks(hooks)

## 6. Training SPON Biases

Now let's train SPON biases to compensate for sparsification.

In [None]:
from src.allocation import SPONConfig
from src.spon_trainer import SPONTrainer, TrainingArgs, create_calibration_dataloader

# Create a simple config: SPON on first 4 layers only
config = SPONConfig(
    name="TOP-25",
    layer_mask=[0, 1, 2, 3],  # First 4 of 16 layers
    modules=["down_proj"]
)

# Training args (quick demo)
args = TrainingArgs(
    epochs=2,
    learning_rate=1e-4,
    batch_size=4,
    block_size=64,  # Shorter for demo
    device="cuda"
)

# Create calibration data
print("Creating calibration data...")
dataloader = create_calibration_dataloader(
    tokenizer,
    block_size=64,
    batch_size=4,
    num_samples=128  # Small for demo
)

print(f"Created {len(dataloader)} batches")

In [None]:
# Train SPON biases
trainer = SPONTrainer(
    model=model,
    config=config,
    sparsity=0.5,
    args=args,
    tokenizer=tokenizer
)

print(f"Training {sum(p.numel() for p in trainer.spon_params):,} SPON parameters...")
metrics = trainer.train(dataloader)

print(f"\nFinal loss: {metrics['final_loss']:.4f}")

In [None]:
# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(metrics['training_loss'], 'b-o')
plt.xlabel("Logging Step")
plt.ylabel("KL Divergence Loss")
plt.title("SPON Training Loss")
plt.grid(True, alpha=0.3)
plt.show()

## 7. Evaluating SPON

Let's compare perplexity with and without SPON.
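
As a reminder of what is being measured, perplexity is the exponential of the mean next-token negative log-likelihood. The sketch below computes it from raw logits; `perplexity_from_logits` is a hypothetical helper for illustration, and the project's `compute_perplexity` may batch, mask, or aggregate differently.

```python
import math

import torch
import torch.nn.functional as F


def perplexity_from_logits(logits: torch.Tensor, input_ids: torch.Tensor) -> float:
    """Perplexity = exp(mean next-token negative log-likelihood)."""
    # Position t predicts token t+1, so drop the last logit and the first label
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    nll = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="mean",
    )
    return math.exp(nll.item())


torch.manual_seed(0)
vocab_size = 50
input_ids = torch.randint(0, vocab_size, (2, 12))

# A "model" that knows nothing assigns equal probability to every token,
# so its perplexity equals the vocabulary size.
uniform_logits = torch.zeros(2, 12, vocab_size)
ppl_uniform = perplexity_from_logits(uniform_logits, input_ids)
print(f"Perplexity of uniform predictions: {ppl_uniform:.2f}")
```

Lower is better: a perplexity of `k` roughly means the model is as uncertain as if it were choosing uniformly among `k` tokens at each step.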

In [None]:
from src.evaluation import compute_perplexity

# Get trained biases
spon_biases = trainer.get_spon_biases()
device = torch.device("cuda")

# Evaluation dataloader (smaller)
eval_dataloader = create_calibration_dataloader(
    tokenizer,
    block_size=64,
    batch_size=4,
    num_samples=64
)

# Dense baseline (the already-loaded model is reused as-is; it is not reloaded here)
print("Computing dense PPL...")
dense_result = compute_perplexity(model, eval_dataloader, device, use_sparse=False)
print(f"Dense PPL: {dense_result.perplexity:.2f}")

# TEAL only (no SPON)
print("\nComputing TEAL-only PPL...")
teal_result = compute_perplexity(
    model, eval_dataloader, device,
    use_sparse=True, sparsity=0.5, spon_biases=None
)
print(f"TEAL PPL: {teal_result.perplexity:.2f}")

# TEAL + SPON
print("\nComputing TEAL+SPON PPL...")
spon_result = compute_perplexity(
    model, eval_dataloader, device,
    use_sparse=True, sparsity=0.5, spon_biases=spon_biases
)
print(f"TEAL+SPON PPL: {spon_result.perplexity:.2f}")

# Summary
print("\n" + "="*50)
print("SUMMARY")
print("="*50)
print(f"Dense (baseline):     {dense_result.perplexity:.2f}")
print(f"TEAL only:            {teal_result.perplexity:.2f} ({(teal_result.perplexity/dense_result.perplexity - 1)*100:+.1f}%)")
print(f"TEAL + SPON:          {spon_result.perplexity:.2f} ({(spon_result.perplexity/dense_result.perplexity - 1)*100:+.1f}%)")
improvement = (teal_result.perplexity - spon_result.perplexity) / teal_result.perplexity * 100
print(f"\nSPON improvement:     {improvement:.1f}% PPL reduction vs TEAL")

## 8. Next Steps

You've learned the basics! Next notebooks:

1. **02_visualize_results.ipynb** - Analyze experimental results, plot Pareto frontiers
2. **03_layer_analysis.ipynb** - Understand which layers benefit most from SPON

To run the full experiment suite:
```bash
python experiments/run_allocation_sweep.py \
    --model meta-llama/Llama-3.2-1B \
    --configs BASELINE-TEAL UNIF-ALL TOP-25 TOP-50 \
    --sparsity 0.5 \
    --epochs 10
```