# Heat Source Optimizer Demo

This notebook demonstrates our **Option 1: Gradient-Based Optimization** approach.

**Note**: Each simulation takes ~0.5-1 seconds, and the optimizer runs many simulations per restart, so optimization is slow. This is expected.

In [None]:
import sys
sys.path.insert(0, '../data/Heat_Signature_zero-starter_notebook')
sys.path.insert(0, '../src')

import numpy as np
import pickle
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from simulator import Heat2D
from utils import make_training_dataset
from optimizer import HeatSourceOptimizer
from visualize import plot_sample_overview, plot_estimation_comparison, plot_source_search_space

print("Imports successful!")

## 1. Generate Synthetic Data with Known Ground Truth

First, let's create a simple test case where we know the true source location.

In [None]:
# Domain parameters (fixed for this challenge)
Lx, Ly = 2.0, 1.0
nx, ny = 100, 50
dt = 0.004

# Generate 1 sample: 1 source, 3 sensors, no noise (easiest case)
dset = make_training_dataset(
    Lx, Ly, nx, ny,
    num_samples=1,
    dt=dt,
    num_sources_sensors={(1, 3): 1.0},
    q_range=(0.5, 2.0),
    noise_std={0.0: 1.0},  # No noise
    nt={400: 1.0},  # Shorter simulation for speed
    kappa={0.1: 1.0},
    bc='dirichlet',
    seed=42,
)

sample = dset['samples'][0]
meta = dset['meta']

print("=== Ground Truth ===")
true_src = sample['true_sources'][0]
print(f"Source location: x={true_src['x']:.3f}, y={true_src['y']:.3f}")
print(f"Source intensity: q={true_src['q']:.3f}")
print(f"\nSensors: {sample['sensors_xy'].shape[0]}")
print(f"Time steps: {sample['sample_metadata']['nt']}")

## 2. Visualize the Sample

In [None]:
# Plot the sample with ground truth
fig = plot_sample_overview(
    sample, meta,
    true_sources=sample['true_sources'],
)
fig.show()

## 3. Run the Optimizer

Now let's try to recover the source parameters from just the sensor readings.

**This will take 1-2 minutes** because each optimization iteration requires running a full simulation.
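
To make the cost structure concrete, here is a minimal sketch of multi-start L-BFGS-B with SciPy. It is not the implementation inside `HeatSourceOptimizer`: `toy_forward` is a hypothetical, cheap stand-in for the real `Heat2D` simulation, and `multi_start_lbfgsb` is an illustrative helper name. When no analytic gradient is passed, SciPy approximates it with finite differences, so every gradient costs extra forward evaluations on top of the loss itself.

```python
import numpy as np
from scipy.optimize import minimize

def toy_forward(params, sensors_xy, width=0.3):
    """Hypothetical stand-in for the PDE solver: Gaussian response per sensor."""
    x, y, q = params
    d2 = (sensors_xy[:, 0] - x) ** 2 + (sensors_xy[:, 1] - y) ** 2
    return q * np.exp(-d2 / (2.0 * width ** 2))

def mse_loss(params, sensors_xy, observed):
    """Mean squared error between predicted and observed sensor values."""
    return np.mean((toy_forward(params, sensors_xy) - observed) ** 2)

def multi_start_lbfgsb(sensors_xy, observed, bounds, n_restarts=4, max_iter=30, seed=0):
    """Run L-BFGS-B from several random starts and keep the best result."""
    rng = np.random.default_rng(seed)
    lower = np.array([b[0] for b in bounds])
    upper = np.array([b[1] for b in bounds])
    best = None
    for _ in range(n_restarts):
        x0 = rng.uniform(lower, upper)  # random start inside the box
        res = minimize(
            mse_loss, x0, args=(sensors_xy, observed),
            method="L-BFGS-B", bounds=bounds, options={"maxiter": max_iter},
        )
        if best is None or res.fun < best.fun:
            best = res
    return best.x, float(np.sqrt(best.fun))  # report RMSE like the notebook

# Synthetic check: 3 sensors, one source with known parameters (x, y, q)
sensors_xy = np.array([[0.5, 0.3], [1.0, 0.7], [1.6, 0.4]])
observed = toy_forward(np.array([1.2, 0.4, 1.5]), sensors_xy)
bounds = [(0.0, 2.0), (0.0, 1.0), (0.5, 2.0)]  # x, y, q

estimate, final_rmse = multi_start_lbfgsb(sensors_xy, observed, bounds)
print("estimate (x, y, q):", estimate, "RMSE:", final_rmse)
```

Because the restarts are independent of each other, they are a natural unit to distribute across CPU cores.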

In [None]:
%%time

# Create optimizer
optimizer = HeatSourceOptimizer(Lx, Ly, nx, ny)

# Run optimization WITH PARALLEL RESTARTS for speedup!
# parallel=True runs multiple restarts across CPU cores simultaneously
estimated, rmse = optimizer.estimate_sources(
    sample, meta,
    q_range=(0.5, 2.0),
    method='L-BFGS-B',
    n_restarts=4,   # More restarts since they run in parallel
    max_iter=30,
    parallel=True,  # Enable parallel execution!
    n_jobs=-1,      # Use all CPU cores (-1 = auto-detect)
)

print("\n=== Results ===")
print(f"True:      x={true_src['x']:.3f}, y={true_src['y']:.3f}, q={true_src['q']:.3f}")
print(f"Estimated: x={estimated[0][0]:.3f}, y={estimated[0][1]:.3f}, q={estimated[0][2]:.3f}")
print(f"RMSE: {rmse:.6f}")

# Position error
pos_error = np.sqrt((true_src['x'] - estimated[0][0])**2 + (true_src['y'] - estimated[0][1])**2)
print(f"Position error: {pos_error:.4f}")

## 4. Visualize Results

In [None]:
# Compare true vs estimated
fig = plot_sample_overview(
    sample, meta,
    estimated_sources=estimated,
    true_sources=sample['true_sources'],
)
fig.update_layout(title="Ground Truth (red) vs Estimated (green)")
fig.show()

In [None]:
# Compare sensor readings
fig = plot_estimation_comparison(
    sample, meta,
    estimated_sources=estimated,
)
fig.show()

## 5. Visualize the Loss Landscape

This shows RMSE as a function of source position (with q fixed). Lower values (darker) indicate better source locations.
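
As a rough illustration of what such a landscape involves, the sketch below evaluates RMSE on a regular grid of candidate positions with `q` held fixed. It does not reproduce `plot_source_search_space`: `toy_forward` is a hypothetical stand-in for the real simulation, and `rmse_grid` is an illustrative name. A 15 x 15 grid means 225 forward evaluations, which is why the real version is slow at ~0.5-1 seconds per simulation.

```python
import numpy as np

def toy_forward(x, y, q, sensors_xy, width=0.3):
    """Hypothetical stand-in for the PDE solver: Gaussian response per sensor."""
    d2 = (sensors_xy[:, 0] - x) ** 2 + (sensors_xy[:, 1] - y) ** 2
    return q * np.exp(-d2 / (2.0 * width ** 2))

def rmse_grid(sensors_xy, observed, q_fixed, Lx=2.0, Ly=1.0, resolution=15):
    """Evaluate RMSE at every (x, y) node of a resolution x resolution grid."""
    xs = np.linspace(0.0, Lx, resolution)
    ys = np.linspace(0.0, Ly, resolution)
    grid = np.empty((resolution, resolution))
    for i, y in enumerate(ys):          # rows index y
        for j, x in enumerate(xs):      # columns index x
            pred = toy_forward(x, y, q_fixed, sensors_xy)
            grid[i, j] = np.sqrt(np.mean((pred - observed) ** 2))
    return xs, ys, grid

# Synthetic check with a known source position and intensity
sensors_xy = np.array([[0.5, 0.3], [1.0, 0.7], [1.6, 0.4]])
observed = toy_forward(1.2, 0.4, 1.0, sensors_xy)

xs, ys, grid = rmse_grid(sensors_xy, observed, q_fixed=1.0)
i_min, j_min = np.unravel_index(np.argmin(grid), grid.shape)
print("grid minimum near x =", xs[j_min], "y =", ys[i_min])
```

The grid minimum only approximates the true position to within the grid spacing, so it is better suited to inspecting the landscape or seeding an optimizer than to final estimates.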

In [None]:
%%time
# This takes ~1-2 minutes to compute the grid
fig = plot_source_search_space(
    sample, meta,
    estimated_sources=estimated,
    true_sources=sample['true_sources'],
    grid_resolution=15,  # Lower for speed
)
if fig is not None:
    fig.show()

## 6. Test on Real Test Data

Now let's try on a sample from the actual test dataset.

In [None]:
# Load test dataset
with open('../data/Heat_Signature_zero-starter_notebook/testdataset_80samples.pkl', 'rb') as f:
    test_data = pickle.load(f)

print(f"Loaded {len(test_data['samples'])} test samples")

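# Pick the first single-source sample; fail loudly if there is none
test_sample = next((s for s in test_data['samples'] if s['n_sources'] == 1), None)
if test_sample is None:
    raise ValueError("No single-source sample found in the test dataset")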

print(f"\nSelected: {test_sample['sample_id']}")
print(f"n_sources: {test_sample['n_sources']}")
print(f"n_sensors: {test_sample['sensors_xy'].shape[0]}")
print(f"noise: {test_sample['sample_metadata']['noise_std']}")

In [None]:
# Visualize the test sample (no ground truth available)
fig = plot_sample_overview(test_sample, test_data['meta'])
fig.show()

In [None]:
%%time
# Run optimizer on test sample
optimizer = HeatSourceOptimizer(Lx, Ly, nx, ny)

estimated_test, rmse_test = optimizer.estimate_sources(
    test_sample, test_data['meta'],
    q_range=(0.5, 2.0),
    method='L-BFGS-B',
    n_restarts=3,
    max_iter=50,
)

print(f"\nEstimated source: x={estimated_test[0][0]:.3f}, y={estimated_test[0][1]:.3f}, q={estimated_test[0][2]:.3f}")
print(f"RMSE: {rmse_test:.6f}")

In [None]:
# Show result
fig = plot_sample_overview(
    test_sample, test_data['meta'],
    estimated_sources=estimated_test,
)
fig.update_layout(title=f"Estimated Source for {test_sample['sample_id']}")
fig.show()

In [None]:
# Compare observed vs predicted
fig = plot_estimation_comparison(
    test_sample, test_data['meta'],
    estimated_sources=estimated_test,
)
fig.show()

## Key Observations

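1. **The optimizer works!** It can recover source parameters from sensor readings alone; the position error and RMSE printed in Section 3 quantify how closely.

2. **Parallel restarts** - Use `parallel=True` to run multiple restarts across CPU cores (up to ~4-8x speedup; the gain is capped by the smaller of `n_restarts` and the number of cores).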

3. **Next steps to improve**:
   - Better initialization (use sensor data heuristics)
   - Coarse-to-fine optimization
   - GPU acceleration with JAX (see below)
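
One possible form of the sensor-data heuristic is sketched below: start the search at the centroid of the sensor positions, weighted by each sensor's peak reading, on the assumption that hotter sensors sit closer to the source. The function name `init_from_sensors` and the `(nt, n_sensors)` layout of `readings` are assumptions for illustration, not part of the existing code.

```python
import numpy as np

def init_from_sensors(sensors_xy, readings, q_range=(0.5, 2.0)):
    """Initial (x, y, q) guess from sensor data.

    sensors_xy: (n_sensors, 2) sensor coordinates.
    readings:   (nt, n_sensors) temperature time series (assumed layout).
    """
    peak = np.max(readings, axis=0)        # hottest value seen by each sensor
    weights = np.clip(peak, 0.0, None)     # ignore negative (noisy) peaks
    if weights.sum() <= 0.0:
        weights = np.ones_like(weights)    # fall back to the plain centroid
    weights = weights / weights.sum()
    x0, y0 = weights @ sensors_xy          # weighted centroid of the sensors
    q0 = 0.5 * (q_range[0] + q_range[1])   # midpoint of the intensity range
    return np.array([x0, y0, q0])

# Example: the second sensor runs hottest, so the guess is pulled toward it
sensors_xy = np.array([[0.5, 0.2], [1.5, 0.8], [1.0, 0.5]])
readings = np.array([[0.0, 0.0, 0.0],
                     [0.1, 0.6, 0.2],
                     [0.2, 0.9, 0.4]])
print(init_from_sensors(sensors_xy, readings))
```

The guess always lies inside the convex hull of the sensors, so it cannot find a source outside that hull on its own; it is meant as one starting point alongside random restarts.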

## GPU Acceleration Options

For even faster optimization, we can use GPU acceleration:

| Option | Speedup | Notes |
|--------|---------|-------|
| **JAX** | 10-50x | Best option - GPU + autodiff for true gradients |
| **CuPy** | 2-5x | Drop-in numpy replacement for GPU |
| **PyTorch** | 5-20x | Good for neural network surrogates |

### Why JAX is ideal for this problem:
1. **GPU acceleration** - Run PDE solver on GPU
2. **Automatic differentiation** - Get exact gradients for optimizer (not finite differences)
3. **JIT compilation** - Compile solver for maximum speed
4. **Vectorization** - Batch multiple simulations together
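
The sketch below shows how points 2 and 3 could fit together on a toy problem. It is not a port of `Heat2D`: the grid size, time step, Gaussian source shape, and the names `simulate` and `loss` are all assumptions chosen for brevity. It steps an explicit finite-difference heat equation with `jax.lax.scan`, then uses `jax.value_and_grad` to get the loss and its gradient with respect to `(x, y, q)` in one compiled call.

```python
import jax
import jax.numpy as jnp

# Toy grid and physics (assumed values, smaller than the notebook's setup)
NX, NY = 20, 10
LX, LY = 2.0, 1.0
DX, DY = LX / (NX - 1), LY / (NY - 1)
KAPPA, DT, NT = 0.1, 0.002, 100
GX, GY = jnp.meshgrid(jnp.linspace(0.0, LX, NX), jnp.linspace(0.0, LY, NY))  # (NY, NX)

def simulate(params, sensor_idx):
    """Explicit heat solve; returns sensor readings of shape (NT, n_sensors)."""
    x, y, q = params[0], params[1], params[2]
    # Smooth Gaussian source so the output is differentiable in x and y
    source = q * jnp.exp(-((GX - x) ** 2 + (GY - y) ** 2) / (2.0 * 0.1 ** 2))

    def step(u, _):
        lap = jnp.zeros_like(u).at[1:-1, 1:-1].set(
            (u[1:-1, 2:] - 2.0 * u[1:-1, 1:-1] + u[1:-1, :-2]) / DX ** 2
            + (u[2:, 1:-1] - 2.0 * u[1:-1, 1:-1] + u[:-2, 1:-1]) / DY ** 2
        )
        u_new = u + DT * (KAPPA * lap + source)
        # Dirichlet boundary: hold the edges at zero
        u_new = u_new.at[0, :].set(0.0).at[-1, :].set(0.0)
        u_new = u_new.at[:, 0].set(0.0).at[:, -1].set(0.0)
        return u_new, u_new[sensor_idx[:, 0], sensor_idx[:, 1]]

    _, readings = jax.lax.scan(step, jnp.zeros((NY, NX)), xs=None, length=NT)
    return readings

def loss(params, sensor_idx, observed):
    """Mean squared error between simulated and observed sensor readings."""
    return jnp.mean((simulate(params, sensor_idx) - observed) ** 2)

# Compiled function returning the loss and its gradient w.r.t. params
loss_and_grad = jax.jit(jax.value_and_grad(loss))

sensor_idx = jnp.array([[3, 5], [6, 12], [4, 16]])   # (row, col) grid indices
true_params = jnp.array([1.2, 0.4, 1.5])
observed = simulate(true_params, sensor_idx)

guess = jnp.array([0.8, 0.6, 1.0])
value, grad = loss_and_grad(guess, sensor_idx, observed)
print("loss:", value, "gradient (d/dx, d/dy, d/dq):", grad)
```

One reverse-mode pass yields all three partial derivatives, whereas a finite-difference gradient needs at least one extra simulation per parameter. Batching several starting points with `jax.vmap` would be the natural next step for point 4.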

To enable JAX, install with: `uv add jax` (this pulls in `jaxlib` automatically), or `uv add "jax[cuda12]"` for NVIDIA GPUs with CUDA 12.