# Tutorial 2: Fourier Features - Why They're Magic

**Time:** 20 minutes  
**Goal:** Understand the spectral bias problem and how Fourier features solve it

---

## The Problem: Spectral Bias

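Standard MLPs with ReLU activations have a **spectral bias** - during training they fit the low-frequency components of a signal much faster than the high-frequency ones, so with a realistic training budget the fine details come out blurry.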
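You can see this effect without a GPU using a simplified stand-in for a ReLU MLP. The sketch is a wide one-hidden-layer ReLU network whose hidden weights are frozen at random values, and only the output weights are trained by gradient descent (a random-features model, not `ReLUMLP` itself). It fits a 1-D signal made of one slow and one fast sinusoid of equal amplitude, and uses the FFT to track how much error remains at each frequency. All names and sizes are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# 1-D signal on [-1, 1): equal-amplitude sinusoids with 1 and 16 cycles across the interval
n_pts, k_low, k_high = 256, 1, 16
t = np.arange(n_pts) / n_pts            # in [0, 1), so FFT bin k == k cycles
xs_1d = 2 * t - 1                       # network input in [-1, 1)
signal = np.sin(2 * np.pi * k_low * t) + np.sin(2 * np.pi * k_high * t)

# Random-features stand-in for a ReLU MLP: hidden layer frozen, only output weights trained
n_hidden = 512
w = rng.normal(size=n_hidden)
b = rng.uniform(-1.0, 1.0, size=n_hidden)
phi = np.maximum(0.0, np.outer(xs_1d, w) + b)   # (n_pts, n_hidden) hidden activations

# Gradient descent on (half) mean squared error; step = 1 / largest curvature, so it is stable
lr = n_pts / np.linalg.norm(phi, 2) ** 2
a = np.zeros(n_hidden)


def residual_amplitudes(weights):
    """Amplitude of the remaining error at the low and at the high frequency."""
    spec = np.abs(np.fft.rfft(signal - phi @ weights)) / (n_pts / 2)
    return spec[k_low], spec[k_high]


history = []
for step in range(3001):
    if step % 1000 == 0:
        history.append((step, *residual_amplitudes(a)))
    a -= lr * phi.T @ (phi @ a - signal) / n_pts

for step, low, high in history:
    print(f'step {step:4d}: low-freq error {low:.2f}, high-freq error {high:.2f}')
```

Both errors start at 1.00. Expect the low-frequency error to fall quickly while the high-frequency error barely moves. That leftover high-frequency error is what shows up as blur.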

Let's see this problem in action, then fix it with Fourier features!

In [None]:
# Imports
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt

from inr_toolkit.models import ReLUMLP, FourierFeaturesMLP
from inr_toolkit.training import Trainer
from inr_toolkit.utils import get_image_coordinates, psnr

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

## Create a Test Image with High Frequencies

Let's create an image with both low- and high-frequency content: a checkerboard, whose sharp black/white edges are built from many high-frequency components.
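One row of the checkerboard is a square wave with a period of 2 × 8 = 16 pixels, i.e. 8 cycles across a 128-pixel row. The short NumPy check below (variable names are illustrative) shows that its discrete Fourier transform is nonzero at the fundamental and at its odd harmonics. The harmonics shrink only slowly, and that slow decay is why sharp edges need high frequencies.

```python
import numpy as np

row_width, half_period = 128, 8   # same sizes as the image below
row = ((np.arange(row_width) // half_period) % 2).astype(np.float64)   # 8 black, 8 white, ...
spectrum = np.abs(np.fft.rfft(row - row.mean()))                     # |DFT| per bin (cycles per row)
active = np.flatnonzero(spectrum > 1e-9)
print('nonzero frequency bins:', active)                              # bins 8, 24, 40, 56
print('relative magnitudes:', np.round(spectrum[active] / spectrum[active[0]], 2))
```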

In [None]:
# Create checkerboard pattern
height, width = 128, 128
checker_size = 8

x = np.arange(width)
y = np.arange(height)
X, Y = np.meshgrid(x, y)

# Create checkerboard
checker = ((X // checker_size) + (Y // checker_size)) % 2
image = np.stack([checker, checker, checker], axis=-1).astype(np.float32)

plt.figure(figsize=(6, 6))
plt.imshow(image, cmap='gray')
plt.title('Target: Checkerboard Pattern\n(High frequency details)')
plt.axis('off')
plt.show()

# Prepare training data
coords = get_image_coordinates(height, width)
colors = torch.from_numpy(image.reshape(-1, 3))
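The tutorial doesn't show what `get_image_coordinates` returns. A common INR convention is to map pixel positions to [-1, 1] and flatten them in row-major order, so that row `i` of the coordinate array lines up with row `i` of `image.reshape(-1, 3)`. The helper below, `make_coordinates`, is a hypothetical stand-in written under that assumption. The toolkit's value range and (x, y) ordering may differ.

```python
import numpy as np


def make_coordinates(height, width):
    """Hypothetical stand-in for get_image_coordinates: (H*W, 2) array of (x, y) in [-1, 1]."""
    ys = np.linspace(-1.0, 1.0, height)
    xs = np.linspace(-1.0, 1.0, width)
    yy, xx = np.meshgrid(ys, xs, indexing='ij')     # both (height, width); yy varies by row
    return np.stack([xx, yy], axis=-1).reshape(-1, 2).astype(np.float32)


grid = make_coordinates(128, 128)
print(grid.shape)                        # (16384, 2)
print(grid[0], grid[127], grid[128])     # first pixel, end of row 0, start of row 1
```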

## Experiment 1: Standard ReLU MLP (Will Fail!)

Let's try a standard MLP and watch it struggle with the high frequencies.

In [None]:
# Create standard ReLU MLP
relu_model = ReLUMLP(
    in_dim=2,
    out_dim=3,
    hidden_dim=256,
    num_layers=4
)

print(f'ReLU MLP has {relu_model.count_parameters():,} parameters')

# Train
trainer = Trainer(relu_model, lr=1e-3, device=device)
print('\nTraining ReLU MLP...')
trainer.fit(coords, colors, epochs=1000, log_every=200)

In [None]:
# Visualize results
relu_model.eval()
with torch.no_grad():
    relu_output = relu_model(coords.to(device)).cpu().numpy()
    relu_output = relu_output.reshape(height, width, 3)

relu_psnr = psnr(torch.from_numpy(relu_output), torch.from_numpy(image))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image, cmap='gray')
axes[0].set_title('Target Image')
axes[0].axis('off')

axes[1].imshow(np.clip(relu_output, 0, 1), cmap='gray')
axes[1].set_title(f'ReLU MLP Output\nPSNR: {relu_psnr:.2f} dB\n(Blurry!)')
axes[1].axis('off')

axes[2].imshow(np.abs(image - relu_output).mean(axis=-1), cmap='hot', vmin=0, vmax=1)
axes[2].set_title('Error Map\n(Brighter = larger error)')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print(f'\n❌ ReLU MLP PSNR: {relu_psnr:.2f} dB (Low = Bad)')
print('Notice how it only learned a blurry version!')
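PSNR is conventionally defined as PSNR = 10·log10(MAX² / MSE) in dB, where MAX is the largest possible pixel value (1 for images in [0, 1]). We haven't checked which MAX `inr_toolkit.utils.psnr` uses. `psnr_numpy` below is a hypothetical stand-in that assumes MAX = 1. Note that the cell above passes the unclipped network output, so values outside [0, 1] also count toward the MSE.

```python
import numpy as np


def psnr_numpy(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB, assuming pixel values in [0, max_val]."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((pred - target) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)


demo_target = np.zeros((4, 4, 3))
demo_pred = demo_target + 0.1            # every pixel off by 0.1, so MSE = 0.01
print(f'{psnr_numpy(demo_pred, demo_target):.2f} dB')   # 20.00 dB
```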

## Experiment 2: Fourier Features MLP (Will Succeed!)

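Now let's pass the coordinates through a Fourier feature mapping before the MLP, lifting each 2-D coordinate to a higher-dimensional vector of sinusoids.

**Key idea:** `x → [sin(Bx), cos(Bx)] → MLP`, where each row of `B` is a frequency vector

Because the input features already oscillate at the frequencies in `B`, the MLP only has to combine ready-made sinusoids. It no longer has to build high frequencies out of ReLU pieces, which is exactly the part that spectral bias makes slow.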
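In Tancik et al. (2020), which introduced this recipe, the mapping is γ(x) = [cos(2πBx), sin(2πBx)]. The entries of `B` are drawn once from N(0, σ²) and then kept fixed. The NumPy sketch below implements that version, with the sin/cos halves swapped to match the notation above (the order doesn't matter). It is an illustration, not the toolkit's code. We haven't verified whether `FourierFeaturesMLP` includes the 2π factor or whether `fourier_scale` is exactly σ. `fourier_features`, `m`, and `sigma` are hypothetical names.

```python
import numpy as np


def fourier_features(points, B):
    """Map (N, d) points to (N, 2m) features [sin(2*pi*B x), cos(2*pi*B x)]."""
    proj = 2.0 * np.pi * points @ B.T        # (N, m): one phase per frequency row of B
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)


rng = np.random.default_rng(0)
m, sigma = 128, 10.0                         # number of frequencies, Gaussian std (assumed = scale)
B = rng.normal(0.0, sigma, size=(m, 2))      # fixed random frequencies for 2-D inputs
sample_xy = rng.uniform(-1.0, 1.0, size=(5, 2))
feats = fourier_features(sample_xy, B)
print(feats.shape)                           # (5, 256)
```

The MLP's input width becomes 2m, independent of the 2-D input. Each sin/cos pair satisfies sin² + cos² = 1, so the features stay bounded however large the frequencies get.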

In [None]:
# Create Fourier Features MLP
fourier_model = FourierFeaturesMLP(
    in_dim=2,
    out_dim=3,
    hidden_dim=256,
    num_layers=4,
    fourier_scale=10.0  # Controls frequency range
)

print(f'Fourier MLP has {fourier_model.count_parameters():,} parameters')

# Train
trainer = Trainer(fourier_model, lr=1e-3, device=device)
print('\nTraining Fourier Features MLP...')
trainer.fit(coords, colors, epochs=1000, log_every=200)
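For intuition about what a model like `FourierFeaturesMLP` might contain, here is a minimal PyTorch sketch. It is not the `inr_toolkit` source. Several details are assumptions: the class name `FourierFeatureMLPSketch`, the `fourier_dim` argument and its default, sampling `B` from a Gaussian with standard deviation `fourier_scale`, the 2π factor, and reading `num_layers` as the total number of linear layers.

```python
import math

import torch
import torch.nn as nn


class FourierFeatureMLPSketch(nn.Module):
    """Hypothetical Fourier-feature MLP for illustration; not the inr_toolkit implementation."""

    def __init__(self, in_dim=2, out_dim=3, hidden_dim=256, num_layers=4,
                 fourier_dim=256, fourier_scale=10.0):
        super().__init__()
        # Random frequencies B ~ N(0, fourier_scale^2), stored as a buffer so they are never trained
        self.register_buffer('B', torch.randn(fourier_dim, in_dim) * fourier_scale)
        layers, dim = [], 2 * fourier_dim          # sin and cos halves double the width
        for _ in range(num_layers - 1):            # assumption: num_layers counts all Linear layers
            layers += [nn.Linear(dim, hidden_dim), nn.ReLU()]
            dim = hidden_dim
        layers.append(nn.Linear(dim, out_dim))     # linear output layer
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        proj = (2 * math.pi * x) @ self.B.T        # (N, fourier_dim)
        feats = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
        return self.mlp(feats)


sketch = FourierFeatureMLPSketch(fourier_dim=64)
with torch.no_grad():
    out = sketch(torch.rand(8, 2) * 2 - 1)         # 8 random coordinates in [-1, 1)
print(out.shape)                                   # torch.Size([8, 3])
```

Keeping `B` as a buffer rather than a parameter means it moves with `.to(device)` and is saved in the `state_dict`, but the optimizer never updates it.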

In [None]:
# Visualize results
fourier_model.eval()
with torch.no_grad():
    fourier_output = fourier_model(coords.to(device)).cpu().numpy()
    fourier_output = fourier_output.reshape(height, width, 3)

fourier_psnr = psnr(torch.from_numpy(fourier_output), torch.from_numpy(image))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image, cmap='gray')
axes[0].set_title('Target Image')
axes[0].axis('off')

axes[1].imshow(np.clip(fourier_output, 0, 1), cmap='gray')
axes[1].set_title(f'Fourier Features Output\nPSNR: {fourier_psnr:.2f} dB\n(Sharp!)')
axes[1].axis('off')

axes[2].imshow(np.abs(image - fourier_output).mean(axis=-1), cmap='hot', vmin=0, vmax=1)
axes[2].set_title('Error Map\n(Brighter = larger error)')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print(f'\n✅ Fourier MLP PSNR: {fourier_psnr:.2f} dB (High = Good)')
print('Notice how it captured the sharp edges!')

## Side-by-Side Comparison

Let's see the dramatic difference!

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(image, cmap='gray')
axes[0].set_title('Target', fontsize=16)
axes[0].axis('off')

axes[1].imshow(np.clip(relu_output, 0, 1), cmap='gray')
axes[1].set_title(f'❌ ReLU MLP\nPSNR: {relu_psnr:.2f} dB', fontsize=16)
axes[1].axis('off')

axes[2].imshow(np.clip(fourier_output, 0, 1), cmap='gray')
axes[2].set_title(f'✅ Fourier Features\nPSNR: {fourier_psnr:.2f} dB', fontsize=16)
axes[2].axis('off')

plt.tight_layout()
plt.show()

improvement = fourier_psnr - relu_psnr
print(f'\nImprovement: {improvement:+.2f} dB')
print('This is the power of Fourier features! 🚀')
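A PSNR difference translates directly into an MSE ratio. With the same MAX for both models, ΔPSNR = 10·log10(MAX²/MSE_Fourier) − 10·log10(MAX²/MSE_ReLU) = 10·log10(MSE_ReLU / MSE_Fourier). So the MSE drops by a factor of 10^(ΔPSNR/10). The helper name `mse_ratio_from_db` is illustrative.

```python
def mse_ratio_from_db(delta_db):
    """Factor by which MSE drops for a PSNR gain of delta_db (same peak value for both)."""
    return 10 ** (delta_db / 10)


for delta in (3.0, 10.0, 20.0):
    print(f'+{delta:.0f} dB -> MSE {mse_ratio_from_db(delta):.1f}x lower')
# +3 dB -> MSE 2.0x lower
# +10 dB -> MSE 10.0x lower
# +20 dB -> MSE 100.0x lower
```

In the notebook, `mse_ratio_from_db(improvement)` gives the factor for this run. A negative `improvement` yields a factor below 1, meaning the Fourier model's MSE is higher.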

## Understanding the Fourier Scale Parameter

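The `fourier_scale` parameter controls the frequency range of the features.
- **Higher scale** → captures finer details, but too high a scale can make the output noisy and grainy
- **Lower scale** → smoother outputs, which can miss fine details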

Let's experiment!
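Why does the scale matter? In the Gaussian Fourier features recipe, `B` is a fixed random matrix multiplied by the scale, so every frequency grows in proportion to it. We assume, but haven't verified, that `fourier_scale` plays this role. The NumPy sketch below reuses one standard-normal matrix at scales 1 and 10. It counts how often each feature sin(2π b·x) changes sign along a horizontal line across the image. The names `B_unit` and `crossings_by_scale` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
B_unit = rng.standard_normal((64, 2))        # 64 random frequency vectors with std 1
n_line = 1001
line = np.stack([np.linspace(-1.0, 1.0, n_line), np.zeros(n_line)], axis=-1)   # y = 0, x in [-1, 1]

crossings_by_scale = {}
for s in (1.0, 10.0):
    B_s = s * B_unit                                   # assumed role of the scale: multiply all frequencies
    positive = np.sin(2 * np.pi * line @ B_s.T) > 0    # (n_line, 64) sign of each sin feature
    changes = np.count_nonzero(np.diff(positive.astype(np.int8), axis=0), axis=0)
    crossings_by_scale[s] = changes.mean()
    print(f'scale={s:>4}: {changes.mean():.1f} sign changes per feature on average')
```

Ten times the scale gives roughly ten times as many oscillations across the image. That helps with fine detail until the features oscillate faster than anything in the signal.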

In [None]:
# Try different Fourier scales
scales = [1.0, 5.0, 10.0, 20.0]
results = []

for scale in scales:
    print(f'\nTraining with fourier_scale={scale}...')
    
    model = FourierFeaturesMLP(
        in_dim=2, out_dim=3,
        hidden_dim=128, num_layers=3,
        fourier_scale=scale
    )
    
    trainer = Trainer(model, lr=1e-3, device=device)
    trainer.fit(coords, colors, epochs=500, log_every=500)
    
    model.eval()
    with torch.no_grad():
        output = model(coords.to(device)).cpu().numpy()
        output = output.reshape(height, width, 3)
    
    psnr_val = psnr(torch.from_numpy(output), torch.from_numpy(image))
    results.append((scale, output, psnr_val))

# Visualize
fig, axes = plt.subplots(1, len(scales), figsize=(20, 5))

for ax, (scale, output, psnr_val) in zip(axes, results):
    ax.imshow(np.clip(output, 0, 1), cmap='gray')
    ax.set_title(f'Scale={scale}\nPSNR: {psnr_val:.2f} dB', fontsize=14)
    ax.axis('off')

plt.tight_layout()
plt.show()

print('\nConclusion: higher scales capture more detail until the output starts to get noisy (for this image, scale=10-20 should work best; compare the PSNRs above)')

## Summary

**What you learned:**
1. ✅ Standard MLPs have **spectral bias** - they struggle with high frequencies
2. ✅ **Fourier features** solve this by mapping `x → [sin(Bx), cos(Bx)]`
3. ✅ The `fourier_scale` parameter controls frequency range
4. ✅ On signals with fine detail, such as this checkerboard, Fourier features give a large PSNR improvement over a plain ReLU MLP

**Key takeaway:**
```python
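from inr_toolkit.models import ReLUMLP, FourierFeaturesMLP

# Avoid this for signals with fine detail:
model = ReLUMLP(in_dim=2, out_dim=3, hidden_dim=256, num_layers=4)  # ❌ Will be blurry

# Do this instead:
model = FourierFeaturesMLP(
    in_dim=2, out_dim=3, hidden_dim=256, num_layers=4,
    fourier_scale=10.0,  # tune per signal
)  # ✅ Sharp and detailed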
```

---

## Experiment!

Try:
- Different `fourier_scale` values
- Different `fourier_dim` (try 128, 512)
- Your own images with fine details

**Next:** [Tutorial 3 - Comparing Architectures](03_comparing_architectures.ipynb) to see SIREN vs Fourier Features!