# JAX Heat Simulator Benchmark

This notebook tests and benchmarks the JAX-based heat simulator against the NumPy implementation.

In [None]:
import sys
import os
# Make the project modules importable. Later cells import simulator,
# jax_simulator, optimizer and jax_optimizer — presumably from ../src
# (TODO confirm). The paths are relative to the current working directory, so
# the notebook must be launched from a directory that sits next to src/ and data/.
# insert(0, ...) gives these directories precedence over installed packages.
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'data', 'Heat_Signature_zero-starter_notebook'))

import numpy as np
import jax
import jax.numpy as jnp
import time
import pickle
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Report the active JAX backend and devices so the timing results below can be
# interpreted (CPU vs accelerator).
print(f"JAX version: {jax.__version__}")
print(f"Backend: {jax.default_backend()}")
print(f"Devices: {jax.devices()}")

## 1. Load Test Data

In [None]:
# Load the test dataset: a dict holding shared 'meta' settings and a list of 'samples'.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
data_path = os.path.join(
    os.getcwd(), '..', 'data',
    'Heat_Signature_zero-starter_notebook', 'test_data.pkl',
)
with open(data_path, 'rb') as f:
    test_data = pickle.load(f)

meta = test_data['meta']
samples = test_data['samples']

n_loaded = len(samples)
print(f"Loaded {n_loaded} samples")
print(f"Meta: dt={meta['dt']}, q_range={meta['q_range']}")

## 2. Test JAX Simulator Correctness

First, we verify that the JAX simulator reproduces the NumPy solver's sensor readings for the same configuration (up to a small numerical tolerance).

In [None]:
# Reference NumPy solver and the JAX re-implementation under test.
from simulator import Heat2D
from jax_simulator import JAXHeatSimulator, check_gpu

# Check GPU availability: print whatever jax_simulator.check_gpu reports
# (its contents are defined by that helper).
gpu_info = check_gpu()
print("GPU Info:", gpu_info)

In [None]:
# Shared configuration for the NumPy-vs-JAX correctness check and speed benchmark.
Lx, Ly = 2.0, 1.0       # domain extent
nx, ny = 100, 50        # grid resolution
kappa = 0.01            # diffusion coefficient passed to both solvers
dt = 0.01               # time step
nt = 100                # number of time steps
bc = 'dirichlet'        # boundary condition type
T0 = 0.0                # initial temperature

# One unit-strength source at the centre of the domain, watched by three sensors
# placed along the diagonal.
sources = [dict(x=1.0, y=0.5, q=1.0)]
sensors_xy = np.array([
    [0.5, 0.25],
    [1.0, 0.5],
    [1.5, 0.75],
])

print("\n".join([
    "Test configuration:",
    f"  Grid: {nx}x{ny}",
    f"  Domain: {Lx}x{Ly}",
    f"  Time steps: {nt} (dt={dt})",
    f"  Source: {sources[0]}",
]))

In [None]:
# Reference run with the NumPy solver: solve for the temperature fields, then
# read the sensor values out of every field that solve() returned.
solver_np = Heat2D(Lx, Ly, nx, ny, kappa, bc=bc)
times_np, Us_np = solver_np.solve(dt=dt, nt=nt, T0=T0, sources=sources)
sensor_rows = [solver_np.sample_sensors(U, sensors_xy) for U in Us_np]
Y_np = np.array(sensor_rows)

print("NumPy result shape: {}".format(Y_np.shape))
print("NumPy final temps at sensors: {}".format(Y_np[-1]))

In [None]:
# Same scenario through the JAX simulator; simulate_and_sample yields the
# sensor readings that are compared against Y_np below.
solver_jax = JAXHeatSimulator(Lx, Ly, nx, ny)
Y_jax = solver_jax.simulate_and_sample(
    sources, sensors_xy,
    kappa, dt, nt, bc, T0,
)

print("JAX result shape: {}".format(Y_jax.shape))
print("JAX final temps at sensors: {}".format(Y_jax[-1]))

In [None]:
# Compare the NumPy and JAX sensor traces (rows = time steps, columns = sensors).
MATCH_ATOL = 0.1  # absolute tolerance on the max difference for declaring a match

# Work on plain NumPy data regardless of the array type the JAX simulator returns.
Y_jax_np = np.asarray(Y_jax)
# Without this check, broadcasting could silently compare mismatched arrays.
if Y_np.shape != Y_jax_np.shape:
    raise ValueError(
        f"Shape mismatch: NumPy {Y_np.shape} vs JAX {Y_jax_np.shape}; "
        "the solvers disagree on the number of time steps or sensors"
    )

diff = np.abs(Y_np - Y_jax_np)
max_abs_diff = np.max(diff)
print("\nComparison:")
print(f"  Max absolute difference: {max_abs_diff:.6e}")
print(f"  Mean absolute difference: {np.mean(diff):.6e}")
# Max error scaled by the peak reference temperature; 1e-10 guards an all-zero reference.
print(f"  Relative error: {max_abs_diff / (np.max(np.abs(Y_np)) + 1e-10):.6e}")

if max_abs_diff < MATCH_ATOL:
    print("\n✓ JAX and NumPy results match within tolerance!")
else:
    print("\n✗ Results differ significantly - check implementation")

In [None]:
# Visualize comparison: overlay NumPy (solid) and JAX (dashed) traces, one
# colour per sensor, so any divergence between the solvers stands out.
fig = make_subplots(rows=1, cols=1)

colors = ['blue', 'red', 'green']
# Hand plotly plain NumPy data regardless of the JAX output type.
Y_jax_plot = np.asarray(Y_jax)
for i in range(len(sensors_xy)):
    # Cycle the palette so configurations with more than three sensors do not
    # raise IndexError.
    color = colors[i % len(colors)]
    fig.add_trace(go.Scatter(
        x=times_np, y=Y_np[:, i],
        mode='lines',
        name=f'NumPy Sensor {i+1}',
        line=dict(color=color, dash='solid')
    ))
    fig.add_trace(go.Scatter(
        x=times_np, y=Y_jax_plot[:, i],
        mode='lines',
        name=f'JAX Sensor {i+1}',
        line=dict(color=color, dash='dash')
    ))

fig.update_layout(
    title='NumPy vs JAX Simulation Comparison',
    xaxis_title='Time',
    yaxis_title='Temperature',
    height=400
)
fig.show()

## 3. Benchmark: JAX vs NumPy Speed

In [None]:
def _block_until_ready(result):
    """Wait for an asynchronously dispatched JAX result and return it unchanged.

    Values without a ``block_until_ready`` method (e.g. NumPy arrays) are
    returned immediately.
    """
    if hasattr(result, 'block_until_ready'):
        result.block_until_ready()
    return result


def benchmark_simulation(n_runs=10):
    """Benchmark simulation speed of the NumPy and JAX solvers.

    Runs the forward problem configured earlier in the notebook ``n_runs``
    times per backend. The configuration comes from the module-level
    ``solver_np``, ``solver_jax``, ``sources``, ``sensors_xy``, ``kappa``,
    ``dt``, ``nt``, ``bc`` and ``T0``.

    Args:
        n_runs: number of timed repetitions per backend; must be >= 1.

    Returns:
        dict with the mean time per simulation in milliseconds
        (``'numpy_ms'``, ``'jax_ms'``) and ``'speedup'`` = NumPy time / JAX time.

    Raises:
        ValueError: if ``n_runs`` < 1 (the per-run average would be undefined).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    def run_numpy():
        # Same work as the NumPy correctness cell: solve, then sample every field.
        _, fields = solver_np.solve(dt=dt, nt=nt, T0=T0, sources=sources)
        return np.array([solver_np.sample_sensors(U, sensors_xy) for U in fields])

    def run_jax():
        # JAX dispatches work asynchronously; block so the timer measures the
        # computation rather than just the dispatch.
        return _block_until_ready(
            solver_jax.simulate_and_sample(sources, sensors_xy, kappa, dt, nt, bc, T0)
        )

    # Warm up JAX (JIT compilation). Blocking here keeps the warm-up's own
    # execution from spilling into the first timed iteration.
    print("Warming up JAX (JIT compilation)...")
    run_jax()

    # perf_counter is monotonic and high-resolution; time.time can jump when
    # the system clock is adjusted.
    print(f"\nRunning {n_runs} NumPy simulations...")
    start = time.perf_counter()
    for _ in range(n_runs):
        run_numpy()
    numpy_time = (time.perf_counter() - start) / n_runs
    print(f"NumPy: {numpy_time*1000:.2f} ms per simulation")

    print(f"\nRunning {n_runs} JAX simulations...")
    start = time.perf_counter()
    for _ in range(n_runs):
        run_jax()
    jax_time = (time.perf_counter() - start) / n_runs
    print(f"JAX: {jax_time*1000:.2f} ms per simulation")

    speedup = numpy_time / jax_time
    print(f"\nSpeedup: {speedup:.2f}x")

    return {'numpy_ms': numpy_time*1000, 'jax_ms': jax_time*1000, 'speedup': speedup}

results = benchmark_simulation(n_runs=10)

## 4. Test JAX Automatic Differentiation

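A key advantage of JAX is automatic differentiation: we can compute gradients of the objective with respect to the source parameters that are exact up to floating-point round-off (for the discretised simulator), instead of approximating them with finite differences.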

In [None]:
# Get a real sample: the first test-set entry drives the gradient checks below.
sample = samples[0]
Y_observed, n_sources = sample['Y_noisy'], sample['n_sources']
sensors_xy_sample = np.array(sample['sensors_xy'])

print("Sample: {}".format(sample['sample_id']))
print("  n_sources: {}".format(n_sources))
print("  n_sensors: {}".format(len(sensors_xy_sample)))
print("  Y_observed shape: {}".format(Y_observed.shape))

In [None]:
# Test gradient computation: unpack the sample's simulation settings.
sample_meta = sample['sample_metadata']
kappa_s, bc_s, T0_s, nt_s = (sample_meta[key] for key in ('kappa', 'bc', 'T0', 'nt'))

# Initial guess: one source at the domain centre, plus a second one when the
# sample contains two sources.
test_sources = [{'x': 1.0, 'y': 0.5, 'q': 1.0}]
if n_sources == 2:
    test_sources.append({'x': 0.5, 'y': 0.3, 'q': 0.8})

print(f"Computing objective and gradient for sources: {test_sources}")

# Both calls take the same arguments after the source list.
objective_args = (Y_observed, sensors_xy_sample, kappa_s, meta['dt'], nt_s, bc_s, T0_s)

# Compute objective
rmse = solver_jax.compute_objective(test_sources, *objective_args)
print(f"RMSE: {rmse:.6f}")

# Compute gradient
grads = solver_jax.compute_gradient(test_sources, *objective_args)
print(f"Gradients shape: {grads.shape}")
print(f"Gradients: {grads}")

In [None]:
# Verify gradient with finite differences. Perturb each (x, y, q) entry of each
# source by ±eps and form the central difference
#     dF/dp ≈ (F(p + eps) - F(p - eps)) / (2 * eps).
# NOTE(review): if the JAX simulator runs in float32 (JAX's default unless
# 64-bit mode is enabled), round-off in F is amplified by 1/eps. Small
# mismatches here may therefore be finite-difference noise, not a gradient bug.
print("\nVerifying gradients with finite differences...")
eps = 1e-4
sources_arr = np.array([[s['x'], s['y'], s['q']] for s in test_sources])


def _objective_at(params):
    """Return the RMSE objective for rows [x, y, q] of ``params`` as a float.

    Uses the sample settings unpacked in the previous cell.
    """
    params_dicts = [{'x': p[0], 'y': p[1], 'q': p[2]} for p in params]
    return float(solver_jax.compute_objective(
        params_dicts, Y_observed, sensors_xy_sample,
        kappa_s, meta['dt'], nt_s, bc_s, T0_s
    ))


fd_grads = np.zeros_like(sources_arr)
for i in range(sources_arr.shape[0]):
    for j in range(sources_arr.shape[1]):
        sources_plus = sources_arr.copy()
        sources_plus[i, j] += eps
        sources_minus = sources_arr.copy()
        sources_minus[i, j] -= eps
        fd_grads[i, j] = (_objective_at(sources_plus) - _objective_at(sources_minus)) / (2 * eps)

# Bring the autodiff gradient into the same (n_sources, 3) layout as fd_grads.
# Otherwise a flat gradient vector would broadcast against fd_grads.
# NOTE(review): assumes compute_gradient orders entries as [x, y, q] per source.
ad_grads = np.asarray(grads, dtype=float)
if ad_grads.size != fd_grads.size:
    raise ValueError(
        f"Gradient size mismatch: autodiff {ad_grads.shape} vs finite-diff {fd_grads.shape}"
    )
ad_grads = ad_grads.reshape(fd_grads.shape)
abs_diff = np.abs(ad_grads - fd_grads)
# Scale-free measure: max absolute error relative to the largest FD component.
rel_diff = np.max(abs_diff) / (np.max(np.abs(fd_grads)) + 1e-12)

print(f"\nAutomatic diff gradients:\n{grads}")
print(f"\nFinite diff gradients:\n{fd_grads}")
print(f"\nDifference:\n{abs_diff}")
print(f"\nMax difference: {np.max(abs_diff):.6e}")
print(f"Max relative difference: {rel_diff:.6e}")

## 5. Test JAX Optimizer

In [None]:
from jax_optimizer import JAXOptimizer

# Create the JAX optimizer (used below via its Adam-based estimator) on the
# same grid as the simulator above.
jax_opt = JAXOptimizer(Lx, Ly, nx, ny)

In [None]:
# Run optimization on a sample: estimate its sources with the Adam-based
# estimator (3 restarts, 100 iterations each).
sample_bc = sample['sample_metadata']['bc']
print(f"Optimizing sample: {sample['sample_id']}")
print(f"  n_sources: {sample['n_sources']}")
print(f"  bc: {sample_bc}")

adam_settings = dict(n_restarts=3, max_iter=100, learning_rate=0.05, verbose=True)
est_sources, rmse = jax_opt.estimate_sources_adam(
    sample, meta, q_range=meta['q_range'], **adam_settings
)

print("\nResults:")
print(f"  Estimated sources: {est_sources}")
print(f"  RMSE: {rmse:.6f}")

## 6. Compare JAX vs NumPy Optimizer Performance

In [None]:
from optimizer import HeatSourceOptimizer

# NumPy optimizer on the same grid, for comparison against the JAX optimizer.
np_opt = HeatSourceOptimizer(Lx, Ly, nx, ny)

# Prefer the first single-source sample (simpler comparison); otherwise fall
# back to the first sample overall.
single_source_samples = list(filter(lambda smp: smp['n_sources'] == 1, samples))
if single_source_samples:
    test_sample = single_source_samples[0]
else:
    test_sample = samples[0]
print(f"Testing on sample: {test_sample['sample_id']} (n_sources={test_sample['n_sources']})")

In [None]:
# Benchmark NumPy optimizer. Wall time is measured with perf_counter, which is
# monotonic and high-resolution (time.time can jump with clock adjustments).
print("Running NumPy optimizer (L-BFGS-B, 3 restarts)...")
start = time.perf_counter()
np_sources, np_rmse = np_opt.estimate_sources(
    test_sample, meta,
    q_range=meta['q_range'],
    method='L-BFGS-B',
    n_restarts=3,
    max_iter=100
)
np_time = time.perf_counter() - start
print(f"NumPy: {np_time:.2f}s, RMSE={np_rmse:.6f}")
print(f"  Sources: {np_sources}")

In [None]:
# Benchmark JAX optimizer with the same monotonic perf_counter timer as the
# NumPy run.
# NOTE(review): no warm-up run precedes this timing. Any JIT (re)compilation
# triggered by this sample is therefore counted in jax_time.
print("\nRunning JAX optimizer (Adam, 3 restarts)...")
start = time.perf_counter()
jax_sources, jax_rmse = jax_opt.estimate_sources_adam(
    test_sample, meta,
    q_range=meta['q_range'],
    n_restarts=3,
    max_iter=100,
    learning_rate=0.05,
    verbose=False
)
jax_time = time.perf_counter() - start
print(f"JAX: {jax_time:.2f}s, RMSE={jax_rmse:.6f}")
print(f"  Sources: {jax_sources}")

In [None]:
# Summary of the optimizer benchmark.
print("\n" + "="*50)
print("OPTIMIZATION BENCHMARK SUMMARY")
print("="*50)
print(f"NumPy (L-BFGS-B): {np_time:.2f}s, RMSE={np_rmse:.6f}")
print(f"JAX (Adam):       {jax_time:.2f}s, RMSE={jax_rmse:.6f}")

# Report the ratio as slower/faster so it is always >= 1 and agrees with the
# "faster"/"slower" label (np_time/jax_time alone reads "0.50x slower").
if jax_time < np_time:
    speed_msg = f"JAX is {np_time / jax_time:.2f}x faster"
elif jax_time > np_time:
    speed_msg = f"JAX is {jax_time / np_time:.2f}x slower"
else:
    speed_msg = "JAX and NumPy took the same time"
print(f"\nTime ratio: {speed_msg}")

# Lower RMSE is better; an exact tie is reported as such rather than "worse by 0".
rmse_gap = abs(jax_rmse - np_rmse)
if jax_rmse < np_rmse:
    rmse_msg = f"JAX is better by {rmse_gap:.6f}"
elif jax_rmse > np_rmse:
    rmse_msg = f"JAX is worse by {rmse_gap:.6f}"
else:
    rmse_msg = "JAX and NumPy are tied"
print(f"RMSE comparison: {rmse_msg}")

## 7. Visualize Optimization Results

In [None]:
# Simulate with the JAX-estimated sources and overlay the predicted sensor
# traces on the noisy observations of test_sample.
# Dedicated names avoid overwriting Section 4's Y_observed / sample_meta, which
# belong to a different sample; clobbering them would break a re-run of Section 4.
Y_observed_test = test_sample['Y_noisy']
sensors_test = np.array(test_sample['sensors_xy'])
test_meta = test_sample['sample_metadata']

# Simulate with JAX estimated sources (each estimate is indexed as [x, y, q]).
jax_est_dicts = [{'x': s[0], 'y': s[1], 'q': s[2]} for s in jax_sources]
# Hand plotly plain NumPy data regardless of the JAX output type.
Y_jax_pred = np.asarray(solver_jax.simulate_and_sample(
    jax_est_dicts, sensors_test,
    test_meta['kappa'], meta['dt'], test_meta['nt'],
    test_meta['bc'], test_meta['T0']
))

# Plot comparison: one subplot per sensor.
times = np.arange(len(Y_observed_test)) * meta['dt']
n_sensors = Y_observed_test.shape[1]

fig = make_subplots(rows=n_sensors, cols=1,
                    subplot_titles=[f'Sensor {i+1}' for i in range(n_sensors)])

for i in range(n_sensors):
    row = i + 1
    fig.add_trace(
        go.Scatter(x=times, y=Y_observed_test[:, i], mode='lines',
                   name=f'Observed {row}', line=dict(color='blue')),
        row=row, col=1
    )
    fig.add_trace(
        go.Scatter(x=times, y=Y_jax_pred[:, i], mode='lines',
                   name=f'JAX Pred {row}', line=dict(color='red', dash='dash')),
        row=row, col=1
    )

fig.update_layout(
    title=f'Observed vs JAX Predicted (RMSE={jax_rmse:.4f})',
    height=200*n_sensors,
    showlegend=True
)
fig.show()

## Summary

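Key findings (check them against the outputs above — they depend on hardware and data):
1. **Correctness**: the JAX simulator's sensor traces should match the NumPy solver within the tolerance checked in Section 2
2. **Speed**: once JIT-compiled, JAX can be faster than NumPy; the measured speedup (or slowdown) varies by hardware and problem size
3. **Autodiff**: JAX computes gradients by automatic differentiation; Section 4 compares them against central finite differences
4. **Optimization**: the Adam-based JAX optimizer can reach an RMSE competitive with the NumPy L-BFGS-B optimizer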

Note: GPU acceleration requires a CUDA-enabled JAX installation; the backend and devices printed in the first code cell show whether one is active.