# SPON Layer Analysis

This notebook provides deep analysis of layer-wise SPON behavior.

**What you'll learn:**
1. Which layers benefit most from SPON biases
2. How to visualize SPON bias magnitudes across layers
3. How far sparsification shifts hidden states (L2 distance between dense and sparse activations)
4. How to rank layers by sensitivity / importance

**Key Research Questions:**
- Do early layers (near embeddings) need more SPON than later layers?
- Can we identify "cornerstone" layers where SPON is critical?
- What do the learned SPON biases encode?

In [None]:
# Put the parent of the current working directory first on sys.path so that the
# project-local `src.*` modules imported in later cells can be resolved.
# This has to run before any of those imports.
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import torch
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): seaborn and tqdm are imported but not referenced in the cells
# visible here — confirm whether they are still needed.
import seaborn as sns
from tqdm import tqdm

# Matplotlib style sheet applied to every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')

# Record the runtime environment next to the results.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

## 1. Load Model and Trained SPON Biases

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model identifier analysed throughout the notebook.
MODEL_NAME = "meta-llama/Llama-3.2-1B"

print(f"Loading {MODEL_NAME}...")

# Tokenizer first; reuse EOS as the padding token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# bfloat16 weights, with device placement delegated to `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
)

# Number of decoder layers; used as the layer count by the analysis cells below.
num_layers = len(model.model.layers)
print(f"Loaded model with {num_layers} layers")

In [None]:
# Check for saved SPON biases produced by an earlier allocation-sweep run.
# `spon_biases` stays None when nothing usable is found, which triggers training
# in the next cell.
checkpoint_dir = Path("../results/allocation_sweep/runs")
spon_biases = None

if checkpoint_dir.exists():
    checkpoints = list(checkpoint_dir.glob("*/checkpoints/*.pt"))
    if checkpoints:
        # Pick the newest file by modification time. A plain lexicographic sort of
        # the paths is not chronological (e.g. "step_1000.pt" < "step_200.pt").
        latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
        print(f"Loading SPON biases from: {latest}")
        # Do not assume a GPU is present when restoring the tensors.
        load_device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a
        # trusted source (consider weights_only=True if the file holds only tensors).
        checkpoint = torch.load(latest, map_location=load_device)
        # A checkpoint without a "biases" entry yields an empty dict, which the
        # next cell treats the same as "no biases".
        spon_biases = checkpoint.get("biases", {})
        print(f"Loaded {len(spon_biases)} bias tensors")
    else:
        print("No checkpoints found - will train fresh biases")
else:
    print("No results directory - will train fresh biases")

In [None]:
# If no saved biases were loaded (None or an empty dict), train them now.
if spon_biases is None or len(spon_biases) == 0:
    from src.allocation import SPONConfig
    from src.spon_trainer import SPONTrainer, TrainingArgs, create_calibration_dataloader

    # Defined once and passed to both the trainer arguments and the calibration
    # dataloader so the two can never disagree.
    block_size = 64
    batch_size = 4
    # Fall back to CPU instead of failing when no GPU is available.
    train_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Train on ALL layers for analysis
    config = SPONConfig(
        name="UNIF-ALL",
        layer_mask=list(range(num_layers)),
        modules=["down_proj"]
    )

    args = TrainingArgs(
        epochs=3,
        learning_rate=1e-4,
        batch_size=batch_size,
        block_size=block_size,
        device=train_device
    )

    print("Creating calibration data...")
    dataloader = create_calibration_dataloader(
        tokenizer, block_size=block_size, batch_size=batch_size, num_samples=256
    )

    print("Training SPON biases on all layers...")
    trainer = SPONTrainer(model, config, sparsity=0.5, args=args)
    trainer.train(dataloader)
    spon_biases = trainer.get_spon_biases()
    print(f"Trained {len(spon_biases)} bias tensors")

## 2. SPON Bias Magnitude Analysis

Larger bias magnitudes suggest the layer needs more compensation for sparsification.

In [None]:
def analyze_bias_magnitudes(spon_biases):
    """Summarise each SPON bias tensor with a few scalar statistics.

    Args:
        spon_biases: mapping from a key of the form ``"<prefix>_<layer>_<module>"``
            (for example ``"layer_3_down_proj"``) to a torch tensor. The second
            underscore-separated field must be an integer layer index; everything
            after it is treated as the module name.

    Returns:
        A list with one dict per entry, ordered by numeric layer index and then
        by module name. Each dict holds ``layer``, ``module``, ``l2_norm``,
        ``mean``, ``std``, ``max_abs``, ``sparsity`` (fraction of entries with
        absolute value below 1e-6) and ``dim`` (length of the bias vector).

    Raises:
        ValueError / IndexError: if a key does not follow the expected format.
    """
    def _split_key(key):
        # "layer_12_down_proj" -> (12, "down_proj")
        parts = key.split('_')
        return int(parts[1]), '_'.join(parts[2:])

    # Sort on the parsed integer index: sorting the raw string keys would place
    # "layer_10_..." ahead of "layer_2_...".
    ordered = sorted((_split_key(key) + (key,)) for key in spon_biases)

    layer_stats = []
    for layer_idx, module, key in ordered:
        # Upcast to float32 first: NumPy has no bfloat16 dtype.
        bias_np = spon_biases[key].float().cpu().numpy()

        layer_stats.append({
            'layer': layer_idx,
            'module': module,
            'l2_norm': np.linalg.norm(bias_np),
            'mean': np.mean(bias_np),
            'std': np.std(bias_np),
            'max_abs': np.max(np.abs(bias_np)),
            'sparsity': np.mean(np.abs(bias_np) < 1e-6),
            'dim': len(bias_np)
        })

    return layer_stats

stats = analyze_bias_magnitudes(spon_biases)
layers = [entry['layer'] for entry in stats]


def _column(field):
    """Return one statistic from every entry of `stats`, in order."""
    return [entry[field] for entry in stats]


# 2x2 grid: L2 norm, mean, standard deviation and max |bias| per layer.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Mean panel: green bars for positive means, red otherwise.
colors = ['green' if value > 0 else 'red' for value in _column('mean')]

# (grid position, statistic, bar colour(s), y-axis label, title)
panels = [
    ((0, 0), 'l2_norm', 'steelblue', 'L2 Norm', 'SPON Bias L2 Norm by Layer'),
    ((0, 1), 'mean', colors, 'Mean Bias Value', 'Mean SPON Bias by Layer'),
    ((1, 0), 'std', 'orange', 'Std Deviation', 'SPON Bias Variability by Layer'),
    ((1, 1), 'max_abs', 'purple', 'Max |Bias|', 'Maximum Absolute Bias by Layer'),
]

for position, field, bar_color, ylabel, title in panels:
    ax = axes[position]
    ax.bar(layers, _column(field), color=bar_color, alpha=0.8)
    ax.set_xlabel('Layer Index')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if field == 'l2_norm':
        # Dashed reference line at the average norm over all entries.
        ax.axhline(y=np.mean(_column('l2_norm')), color='red', linestyle='--', label='Mean')
        ax.legend()
    elif field == 'mean':
        # Thin zero line to separate positive from negative means.
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

fig.tight_layout()
fig.suptitle('SPON Bias Analysis Across Layers', y=1.02, fontsize=14)
plt.show()

## 3. Hidden State Shift Analysis

Measure how much sparsification changes hidden states at each layer.

In [None]:
from src.spon_trainer import create_calibration_dataloader
from src.evaluation import compute_hidden_state_shift

# Create evaluation data
# NOTE(review): this uses the same helper as the calibration set used for
# training — confirm the samples do not overlap, otherwise the drift reduction
# reported below is measured on training data.
eval_dataloader = create_calibration_dataloader(
    tokenizer, block_size=64, batch_size=4, num_samples=64
)

# Use the GPU when present, but do not crash on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same sparsity level for both measurements so the two runs are comparable.
eval_sparsity = 0.5

# Compute shifts WITHOUT SPON
print("Computing hidden state shifts (TEAL only)...")
shifts_no_spon = compute_hidden_state_shift(
    model, eval_dataloader, sparsity=eval_sparsity,
    spon_biases=None, device=device
)

# Compute shifts WITH SPON
print("Computing hidden state shifts (TEAL + SPON)...")
shifts_with_spon = compute_hidden_state_shift(
    model, eval_dataloader, sparsity=eval_sparsity,
    spon_biases=spon_biases, device=device
)

In [None]:
# Parse and plot
def parse_shifts(shifts_dict):
    """Re-key a shift mapping by integer layer index.

    Each key is split on '_' and its second field is converted to ``int`` and
    used as the new key; values are passed through unchanged. If several keys
    share a layer index, the one iterated last wins.
    """
    return {int(name.split('_')[1]): value for name, value in shifts_dict.items()}

# Per-layer drift, keyed by integer layer index.
no_spon = parse_shifts(shifts_no_spon)
with_spon = parse_shifts(shifts_with_spon)

layers = sorted(no_spon.keys())

fig, ax = plt.subplots(figsize=(12, 6))

# Side-by-side bars: one pair per layer.
x = np.arange(len(layers))
width = 0.35

bars1 = ax.bar(x - width/2, [no_spon[idx] for idx in layers], width,
               label='TEAL only', color='red', alpha=0.7)
bars2 = ax.bar(x + width/2, [with_spon[idx] for idx in layers], width,
               label='TEAL + SPON', color='green', alpha=0.7)

ax.set_xlabel('Layer Index', fontsize=12)
ax.set_ylabel('L2 Shift (vs Dense)', fontsize=12)
ax.set_title('Hidden State Drift from Sparsification', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(layers)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Relative drift reduction: positive means SPON moved the hidden state closer
# to the dense model, negative means it moved it further away.
print("\nSPON Drift Reduction by Layer:")
print("-" * 40)
for layer_idx in layers:
    baseline = no_spon[layer_idx]
    if baseline == 0:
        # No drift to reduce; a percentage would divide by zero.
        print(f"Layer {layer_idx:2d}:   n/a (no baseline drift)")
        continue
    reduction = (baseline - with_spon[layer_idx]) / baseline * 100
    print(f"Layer {layer_idx:2d}: {reduction:5.1f}% reduction")

## 4. Layer Importance Ranking

Which layers contribute most to SPON's effectiveness?

In [None]:
# Importance of a layer = absolute drop in hidden-state drift measured at that
# layer once SPON is enabled. Values are coerced to plain floats so tensor or
# NumPy scalars sum and plot uniformly.
# NOTE(review): the biases were trained on all layers at once, so the drift seen
# at a layer may also reflect corrections applied in earlier layers — confirm
# this is the attribution you want.
importance = {
    layer_idx: float(no_spon[layer_idx] - with_spon[layer_idx])
    for layer_idx in layers
}

# Normalise to percentages. This is only meaningful when SPON reduced drift
# overall; a zero, negative or NaN total would give undefined or sign-flipped shares.
total = sum(importance.values())
if not total > 0:
    raise ValueError(
        "Cannot rank layers: SPON did not reduce total hidden-state drift "
        f"(sum of per-layer reductions = {total})."
    )
importance_pct = {layer_idx: value / total * 100 for layer_idx, value in importance.items()}

# Sort by importance, most important first
sorted_layers = sorted(importance_pct.items(), key=lambda x: x[1], reverse=True)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart
ax1.bar([str(layer_idx) for layer_idx, _ in sorted_layers], [v for _, v in sorted_layers],
        color=plt.cm.RdYlGn(np.linspace(0.8, 0.2, len(sorted_layers))))
ax1.set_xlabel('Layer Index (sorted by importance)')
ax1.set_ylabel('Importance (%)')
ax1.set_title('Layer Importance for SPON')
ax1.tick_params(axis='x', rotation=45)

# Cumulative importance
cumulative = np.cumsum([v for _, v in sorted_layers])
ax2.plot(range(1, len(cumulative)+1), cumulative, 'b-o', linewidth=2, markersize=8)
ax2.axhline(y=80, color='red', linestyle='--', label='80% threshold')
ax2.axhline(y=95, color='orange', linestyle='--', label='95% threshold')
ax2.set_xlabel('Number of Layers (most important first)')
ax2.set_ylabel('Cumulative Importance (%)')
ax2.set_title('Cumulative Layer Importance')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find minimum layers for thresholds.
# The first prefix that reaches the threshold is located with a linear scan:
# when some layers have negative importance the cumulative curve rises above
# 100% and falls back, so it is not sorted and a binary search is not valid.
for threshold in [80, 90, 95]:
    n_layers = next(
        (count for count, share in enumerate(cumulative, start=1) if share >= threshold),
        len(cumulative),
    )
    print(f"{threshold}% importance achieved with top {n_layers} layers ({n_layers/num_layers*100:.0f}%)")

## 5. Bias Distribution Visualization

What do the learned biases look like?

In [None]:
# Histogram the learned biases of the three most and three least important layers.
top_3_layers = [layer_id for layer_id, _ in sorted_layers[:3]]
bottom_3_layers = [layer_id for layer_id, _ in sorted_layers[-3:]]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# (grid row, layers shown in that row, label used in the title, histogram colour)
row_specs = (
    (0, top_3_layers, 'High Importance', 'green'),
    (1, bottom_3_layers, 'Low Importance', 'orange'),
)

for row, row_layers, label, hist_color in row_specs:
    for idx, layer in enumerate(row_layers):
        key = f"layer_{layer}_down_proj"
        if key not in spon_biases:
            # No bias stored for this layer: its panel is left empty.
            continue
        bias = spon_biases[key].float().cpu().numpy()
        ax = axes[row, idx]
        ax.hist(bias, bins=50, alpha=0.7, color=hist_color, edgecolor='black')
        ax.axvline(x=0, color='red', linestyle='--')
        ax.set_title(f'Layer {layer} ({label})\nμ={np.mean(bias):.4f}, σ={np.std(bias):.4f}')
        ax.set_xlabel('Bias Value')
        ax.set_ylabel('Count')

plt.suptitle('SPON Bias Distributions: High vs Low Importance Layers', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 6. Heatmap: Bias Values Across Layers

In [None]:
# Heatmap of signed bias values: one row per layer that has a down_proj bias,
# one column per hidden dimension (only the leading `max_dims` are shown).
max_dims = 100

bias_rows = []
row_layers = []  # layer index of each heatmap row, so skipped layers are not mislabelled
for layer in range(num_layers):
    key = f"layer_{layer}_down_proj"
    if key not in spon_biases:
        continue
    bias = spon_biases[key].float().cpu().numpy()[:max_dims]
    # Zero-pad vectors shorter than max_dims so every row has the same width.
    if len(bias) < max_dims:
        bias = np.pad(bias, (0, max_dims - len(bias)))
    bias_rows.append(bias)
    row_layers.append(layer)

bias_matrix = np.array(bias_rows)

if bias_matrix.size == 0:
    # Nothing to draw; np.max on an empty array would raise.
    print("No down_proj SPON biases available - skipping heatmap")
else:
    fig, ax = plt.subplots(figsize=(14, 8))

    # Symmetric colour limits keep zero at the centre of the diverging colormap.
    vmax = np.max(np.abs(bias_matrix))
    im = ax.imshow(bias_matrix, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax)

    ax.set_xlabel(f'Hidden Dimension (first {max_dims})', fontsize=12)
    ax.set_ylabel('Layer Index', fontsize=12)
    ax.set_title('SPON Bias Values Across Layers and Dimensions', fontsize=14)
    # Label each row with its real layer index rather than its row position.
    ax.set_yticks(range(len(row_layers)))
    ax.set_yticklabels(row_layers)

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Bias Value', fontsize=12)

    plt.tight_layout()
    plt.show()

## 7. Summary and Recommendations

In [None]:
# Text summary of the ranking computed above. Reads `sorted_layers` (layer,
# importance %) in descending order, `cumulative` and `importance_pct`.
print("=" * 60)
print("LAYER ANALYSIS SUMMARY")
print("=" * 60)

print(f"\nModel: {MODEL_NAME}")
print(f"Total layers: {num_layers}")
print("Sparsity: 50%")

print("\n--- Most Important Layers ---")
for rank, (layer, imp) in enumerate(sorted_layers[:5], start=1):
    print(f"  {rank}. Layer {layer}: {imp:.1f}% importance")

print("\n--- Least Important Layers ---")
for layer, imp in sorted_layers[-3:]:
    print(f"  Layer {layer}: {imp:.1f}% importance")

print("\n--- Efficiency Recommendations ---")
for threshold in [80, 90, 95]:
    # Smallest number of top-ranked layers whose cumulative share reaches the
    # threshold. A linear scan is used because the cumulative curve is not
    # sorted when some layers have negative importance.
    n = next(
        (count for count, share in enumerate(cumulative, start=1) if share >= threshold),
        len(cumulative),
    )
    layers_to_use = [layer for layer, _ in sorted_layers[:n]]
    print(f"  For {threshold}% effectiveness: Use layers {sorted(layers_to_use)}")
    print(f"    -> {n}/{num_layers} layers = {(1-n/num_layers)*100:.0f}% parameter savings")

print("\n--- Key Insights ---")
# Split the stack into first quarter / middle half / last quarter.
early_end = num_layers // 4
late_start = 3 * num_layers // 4

early_importance = sum(imp for layer, imp in importance_pct.items() if layer < early_end)
middle_importance = sum(imp for layer, imp in importance_pct.items() if early_end <= layer < late_start)
late_importance = sum(imp for layer, imp in importance_pct.items() if layer >= late_start)

print(f"  Early layers (0-{early_end-1}): {early_importance:.1f}% importance")
print(f"  Middle layers ({early_end}-{late_start-1}): {middle_importance:.1f}% importance")
print(f"  Late layers ({late_start}-{num_layers-1}): {late_importance:.1f}% importance")