# Tutorial 3: Comparing Architectures

**Time:** 25 minutes  
**Goal:** Understand when to use SIREN vs Fourier Features vs ReLU MLP

---

## The Three Architectures

| Architecture | Key Feature | Best For |
|--------------|-------------|----------|
| **ReLU MLP** | ReLU activations on raw coordinates | Baseline (usually poor) |
| **Fourier Features** | Random Fourier mapping | Most tasks, good default |
| **SIREN** | Sine activations | Smooth signals, derivatives |

Let's compare them head-to-head!
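
For intuition, the "random Fourier mapping" in the table can be sketched in a few lines of NumPy. This is a minimal sketch of the standard formulation, γ(v) = [cos(2πBv), sin(2πBv)] with the rows of B drawn from a Gaussian. It is an assumption that `FourierFeaturesMLP` does something similar internally; the names `fourier_features`, `num_features`, and `scale` are illustrative and not part of `inr_toolkit`.

```python
import numpy as np

def fourier_features(points, B):
    """Map (N, in_dim) coordinates to (N, 2 * num_features) Fourier features."""
    proj = 2.0 * np.pi * points @ B.T              # (N, num_features)
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)

rng = np.random.default_rng(0)
num_features, in_dim, scale = 128, 2, 10.0         # illustrative values
# Random frequency matrix; a larger scale means higher frequencies
B = rng.normal(0.0, scale, size=(num_features, in_dim))

demo_points = rng.uniform(-1.0, 1.0, size=(5, in_dim))
features = fourier_features(demo_points, B)
print(features.shape)  # (5, 256)
```

In this formulation the standard deviation of `B` plays the role that `fourier_scale` appears to play in this tutorial: small values bias the network toward smooth outputs, large values let it represent fine detail at the risk of noise.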

In [None]:
# Imports
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import time

from inr_toolkit.models import ReLUMLP, FourierFeaturesMLP, SIREN
from inr_toolkit.training import Trainer
from inr_toolkit.utils import get_image_coordinates, psnr, load_image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

## Test Image: Synthetic Pattern

We'll build a synthetic test image with both smooth and detailed regions:
a radial ripple in red, a grid pattern in green, and a Gaussian blob in blue.

In [None]:
# Create a test image with mixed frequencies
height, width = 128, 128

x = np.linspace(-2, 2, width)
y = np.linspace(-2, 2, height)
X, Y = np.meshgrid(x, y)

# Create interesting pattern
R = np.sqrt(X**2 + Y**2)
image = np.stack([
    0.5 + 0.5 * np.sin(5 * R) / (1 + R),           # Ripple in red
    0.5 + 0.3 * np.cos(3 * X) * np.cos(3 * Y),    # Grid in green
    0.5 + 0.4 * np.exp(-R**2 / 2),                # Gaussian in blue
], axis=-1)

plt.figure(figsize=(6, 6))
plt.imshow(np.clip(image, 0, 1))
plt.title('Test Image: Mixed Frequencies')
plt.axis('off')
plt.show()

# Prepare data
coords = get_image_coordinates(height, width)
colors = torch.from_numpy(image.reshape(-1, 3).astype(np.float32))

## Benchmark Setup

We'll train all three architectures with the same hidden width, depth, learning rate, and number of epochs, then compare:
1. **Quality** (PSNR)
2. **Training Time**
3. **Visual Results**
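
PSNR is the quality metric for this benchmark. As a reference, here is a minimal sketch of the usual definition, PSNR = 10 · log10(MAX² / MSE), for pixel values in [0, 1]. It is an assumption that `inr_toolkit.utils.psnr` follows the same convention; the helper name `psnr_reference` is hypothetical.

```python
import math

def psnr_reference(pred, target, max_val=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return float('inf')  # identical signals
    return 10.0 * math.log10(max_val ** 2 / mse)

target = [0.0, 0.5, 1.0, 0.5]
pred = [0.1, 0.4, 0.9, 0.6]   # every pixel is off by 0.1
print(f'{psnr_reference(pred, target):.2f} dB')  # MSE = 0.01, so 20.00 dB
```

Because the scale is logarithmic, every tenfold reduction in MSE adds 10 dB, so a gap of a few dB between models is a substantial difference in reconstruction error.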

In [None]:
# Configuration (same for all models)
config = {
    'hidden_dim': 256,
    'num_layers': 4,
    'lr': 1e-3,
    'epochs': 1000,
}

# Create models
models = {
    'ReLU MLP': ReLUMLP(
        in_dim=2, out_dim=3,
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers']
    ),
    'Fourier Features': FourierFeaturesMLP(
        in_dim=2, out_dim=3,
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        fourier_scale=10.0
    ),
    'SIREN': SIREN(
        in_dim=2, out_dim=3,
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers']
    ),
}

# Show parameter counts
print('Model Parameter Counts:')
for name, model in models.items():
    print(f'  {name:20s}: {model.count_parameters():,} params')
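
To sanity-check the printed counts, the parameters of a plain fully connected network can be tallied by hand. The sketch below assumes `num_layers` counts hidden layers and that the Fourier mapping uses 128 frequencies (256 input features); both are assumptions about `inr_toolkit`, and `mlp_param_count` is a hypothetical helper, so the actual numbers may differ.

```python
def mlp_param_count(in_dim, hidden_dim, num_hidden_layers, out_dim):
    """Count weights and biases in a fully connected MLP."""
    dims = [in_dim] + [hidden_dim] * num_hidden_layers + [out_dim]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))

plain = mlp_param_count(in_dim=2, hidden_dim=256, num_hidden_layers=4, out_dim=3)
# Hypothetical: 128 random frequencies -> 256 input features (cos and sin)
fourier = mlp_param_count(in_dim=256, hidden_dim=256, num_hidden_layers=4, out_dim=3)
print(f'{plain:,}')    # 198,915
print(f'{fourier:,}')  # 263,939
```

Under these assumptions the only difference is the first layer, which grows because it consumes the wider Fourier-mapped input. The random frequency matrix itself is usually fixed rather than trained, so it would not add to the count.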

## Train All Models

This will take a few minutes. Watch the progress bars!

In [None]:
results = {}

for name, model in models.items():
    print(f'\n{"="*60}')
    print(f'Training: {name}')
    print(f'{"="*60}')
    
    # Train
    trainer = Trainer(model, lr=config['lr'], device=device)
    
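    start_time = time.perf_counter()
    trainer.fit(coords, colors, epochs=config['epochs'], log_every=250)
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    train_time = time.perf_counter() - start_time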
    
    # Evaluate
    model.eval()
    with torch.no_grad():
        output = model(coords.to(device)).cpu().numpy()
        output = output.reshape(height, width, 3)
    
    # Cast the float64 reference to float32 so both tensors share a dtype
    psnr_val = psnr(torch.from_numpy(output), torch.from_numpy(image).float())
    
    results[name] = {
        'output': output,
        'psnr': psnr_val,
        'time': train_time,
        'params': model.count_parameters()
    }
    
    print(f'\n✅ {name}: PSNR = {psnr_val:.2f} dB, Time = {train_time:.1f}s')

print(f'\n{"="*60}')
print('Training complete!')
print(f'{"="*60}')

## Visual Comparison

Let's see the outputs side by side!

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 1: ground truth followed by all three reconstructions
axes[0, 0].imshow(np.clip(image, 0, 1))
axes[0, 0].set_title('Ground Truth', fontsize=14)
axes[0, 0].axis('off')

for ax, (name, res) in zip(axes[0, 1:], results.items()):
    ax.imshow(np.clip(res['output'], 0, 1))
    ax.set_title(f"{name}\nPSNR: {res['psnr']:.2f} dB", fontsize=14)
    ax.axis('off')

# Row 2: absolute error maps, each directly below its reconstruction
axes[1, 0].axis('off')  # no error map for the ground truth
for ax, (name, res) in zip(axes[1, 1:], results.items()):
    error = np.abs(image - res['output'])
    ax.imshow(np.clip(error, 0, 1))
    ax.set_title(f"{name}\nError Map", fontsize=12)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Quantitative Results

Let's see the numbers!

In [None]:
import pandas as pd

# Create comparison table
comparison = []
for name, res in results.items():
    comparison.append({
        'Model': name,
        'PSNR (dB)': f"{res['psnr']:.2f}",
        'Training Time (s)': f"{res['time']:.1f}",
        'Parameters': f"{res['params']:,}"
    })

df = pd.DataFrame(comparison)
print('\n' + '='*70)
print('BENCHMARK RESULTS')
print('='*70)
print(df.to_string(index=False))
print('='*70)

# Determine winner
best_model = max(results.items(), key=lambda x: x[1]['psnr'])[0]
print(f"\n🏆 Winner (by PSNR): {best_model}")

## Zoomed-In Comparison

Let's zoom into a detailed region to see the differences more clearly.

In [None]:
# Define crop region
crop_slice = (slice(40, 88), slice(40, 88))

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Ground truth
axes[0].imshow(np.clip(image[crop_slice], 0, 1))
axes[0].set_title('Ground Truth\n(Zoomed)', fontsize=14)
axes[0].axis('off')

# Model outputs
for ax, (name, res) in zip(axes[1:], results.items()):
    ax.imshow(np.clip(res['output'][crop_slice], 0, 1))
    ax.set_title(f"{name}\n{res['psnr']:.2f} dB", fontsize=14)
    ax.axis('off')

plt.tight_layout()
plt.show()

print('Look closely at the fine details!')
print('ReLU MLP is typically blurrier, while SIREN and Fourier Features tend to be sharper.')

## When to Use Each Architecture?

### 🔴 ReLU MLP
**Use when:** You want a baseline to compare against  
**Avoid when:** You need good quality (almost always!)  
**Pro:** Simple  
**Con:** Poor quality due to spectral bias  

---

### 🟢 Fourier Features (Recommended Default)
**Use when:** General purpose INR tasks  
**Best for:** Most images, signals, volumes  
**Pro:** Great quality, easy to tune (just adjust `fourier_scale`)  
**Con:** Slightly more parameters than the others at the same width and depth, because the first layer takes the larger Fourier-mapped input  

---

### 🔵 SIREN
**Use when:** You need accurate derivatives or are fitting very smooth signals  
**Best for:** Physics simulations, signed distance functions  
**Pro:** Excellent for smooth functions, good derivatives  
**Con:** Can be sensitive to initialization, harder to tune  
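
The sensitivity to initialization comes from the scheme proposed in the original SIREN paper (Sitzmann et al., 2020): first-layer weights are drawn uniformly from [-1/n, 1/n], later layers from [-sqrt(6/n)/w0, sqrt(6/n)/w0], where n is the layer's input width and w0 = 30 is the commonly used frequency factor. The sketch below only computes those ranges. Whether `inr_toolkit`'s `SIREN` uses exactly this scheme is an assumption, and the helper names are hypothetical.

```python
import math
import random

def siren_weight_bound(fan_in, is_first, w0=30.0):
    """Half-width of the uniform init range for one SIREN layer."""
    if is_first:
        return 1.0 / fan_in
    return math.sqrt(6.0 / fan_in) / w0

def init_siren_layer(fan_in, fan_out, is_first, w0=30.0, seed=0):
    """Sample a (fan_out x fan_in) weight matrix as nested lists."""
    rng = random.Random(seed)
    bound = siren_weight_bound(fan_in, is_first, w0)
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)]
            for _ in range(fan_out)]

first = siren_weight_bound(fan_in=2, is_first=True)
hidden = siren_weight_bound(fan_in=256, is_first=False)
print(f'{first:.4f}')   # 0.5000
print(f'{hidden:.4f}')  # 0.0051
```

Each layer then computes sin(w0 · (Wx + b)). Using a generic initializer or a poorly chosen w0 in place of these ranges is a common reason a SIREN fails to train well.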

---

## 💡 **Practical Recommendation:**
**Start with Fourier Features** (`fourier_scale=10.0`).  
If you need better derivatives or very smooth outputs, try SIREN.

## Summary

**What you learned:**
1. ✅ How to benchmark different INR architectures
2. ✅ ReLU MLPs are poor for INRs (spectral bias)
3. ✅ Fourier Features are a strong general-purpose default
4. ✅ SIREN excels at smooth signals and derivatives

**Decision tree:**
```
Do you need an INR?
├─ Yes → Use Fourier Features (start here!)
│   ├─ Need derivatives? → Try SIREN
│   └─ General task? → Stick with Fourier Features
└─ No → Use standard methods
```

---

## Congratulations! 🎉

You've completed all three tutorials and now understand:
- What INRs are and how they work
- Why Fourier features are necessary
- When to use each architecture

**Next steps:**
- Try the [examples](../examples/) on real images
- Read [architecture deep dives](../docs/architectures.md)
- Explore [benchmarks](../benchmarks/) for more comparisons
- Build your own INR applications!

**Questions?** Open an issue on GitHub!