# Benchmarking Helmholtz Operator Modes

This notebook benchmarks the three operator modes implemented in `HelmholtzOperator` (and used by `HelmholtzSolver`):
1. `matrix`: Applies precomputed 1D matrices along each axis using `vmap`.
2. `stencil`: Fully matrix-free stencil application via manual array shifts.
3. `conv`: Convolution-based operator built on JAX's `jax.lax` convolution primitives.

We compare their execution times, after JIT compilation, across several grid sizes.
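To make the three strategies concrete, here is a minimal sketch that applies a simplified operator (a second-order 5-point Laplacian with zero values outside the grid, with no PML and no `omega**2 * m` term) in each style. The helpers `lap_matrix`, `lap_stencil` and `lap_conv` are hypothetical and are not the `HelmholtzOperator` implementation; they only show that the three approaches compute the same result.

```python
import jax
import jax.numpy as jnp

# Simplified operator: second-order 5-point Laplacian, zero values outside the grid.
# NOT the HelmholtzOperator implementation (no PML, no omega**2 * m term).


def second_diff_matrix(n, h):
    """Dense 1D second-difference matrix (1, -2, 1) / h**2 with zero boundaries."""
    off = jnp.ones(n - 1)
    return (jnp.diag(-2.0 * jnp.ones(n)) + jnp.diag(off, 1) + jnp.diag(off, -1)) / h**2


def lap_matrix(u, h):
    """'matrix' style: apply 1D matrices along each axis with vmap."""
    ny, nx = u.shape
    dx, dy = second_diff_matrix(nx, h), second_diff_matrix(ny, h)
    uxx = jax.vmap(lambda row: dx @ row)(u)                          # every row
    uyy = jax.vmap(lambda col: dy @ col, in_axes=1, out_axes=1)(u)   # every column
    return uxx + uyy


def lap_stencil(u, h):
    """'stencil' style: sum shifted slices of a zero-padded copy."""
    p = jnp.pad(u, 1)
    return (p[1:-1, :-2] + p[1:-1, 2:] + p[:-2, 1:-1] + p[2:, 1:-1] - 4.0 * u) / h**2


def lap_conv(u, h):
    """'conv' style: one 2D convolution with the 5-point kernel."""
    kernel = jnp.array([[0.0, 1.0, 0.0],
                        [1.0, -4.0, 1.0],
                        [0.0, 1.0, 0.0]], dtype=u.dtype) / h**2
    # lax.conv uses NCHW inputs and OIHW kernels; "SAME" pads with zeros.
    out = jax.lax.conv(u[None, None], kernel[None, None], (1, 1), "SAME")
    return out[0, 0]


u = jax.random.normal(jax.random.PRNGKey(0), (32, 48))
a, b, c = lap_matrix(u, 0.1), lap_stencil(u, 0.1), lap_conv(u, 0.1)
print(float(jnp.max(jnp.abs(a - b))), float(jnp.max(jnp.abs(a - c))))
```

The printed differences are at float32 rounding level. Which variant is fastest depends on grid size and hardware, which is what this notebook measures.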

In [None]:
import sys
import os
import time
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Add the project root to path
sys.path.append(os.path.abspath(".."))

from inverse_scattering_jax.src.helmholtz import HelmholtzSolver, HelmholtzOperator, extend_model

## 1. Benchmarking Function

We define a function that builds the operator and solver for a given mode, then times both a single application of the Helmholtz operator and a full GMRES solve.
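The warm-up and `block_until_ready()` pattern used in `benchmark_mode` can be factored into a reusable helper. This is a sketch: `time_jitted` is a hypothetical name, and it reports the median rather than the mean, since the median is less sensitive to occasional slow runs. Blocking matters because JAX dispatches work asynchronously; without it the timer would mostly measure dispatch.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_trials=5):
    """Median wall-clock seconds of fn(*args), excluding the compile-time warm-up.

    Assumes fn returns a single JAX array (index tuple outputs before calling this).
    """
    fn(*args).block_until_ready()  # warm-up: triggers tracing + XLA compilation
    times = []
    for _ in range(n_trials):
        start = time.perf_counter()
        # JAX dispatches asynchronously; blocking makes the timer cover the computation.
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


matmul = jax.jit(lambda x: jnp.sin(x) @ x.T)
x = jnp.ones((256, 256))
t = time_jitted(matmul, x)
print(f"median apply time: {t * 1e3:.3f} ms")
```

For a solver returning `(x, info)`, wrap it first, e.g. `time_jitted(lambda f: solve_jit(f)[0], f_vec)`.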

In [None]:
def benchmark_mode(mode, nxint, nyint, npml, h, omega, sigma_max, order, n_trials=5):
    """Return the mean operator-apply time and mean GMRES solve time (seconds) for one mode."""
    # Padded grid: interior plus npml PML cells on each side
    nx = nxint + 2 * npml
    ny = nyint + 2 * npml

    op = HelmholtzOperator(nx, ny, npml, h, omega, sigma_max, order, mode=mode)
    solver = HelmholtzSolver(op)
    
    # Generate random dummy inputs (independent keys so u and f differ)
    key_u, key_f = jax.random.split(jax.random.PRNGKey(0))
    u_vec = jax.random.normal(key_u, (ny * nx,), dtype=jnp.complex64)
    m_ext = jnp.ones((ny, nx), dtype=jnp.complex64)  # constant model on the padded grid
    f_vec = jax.random.normal(key_f, (ny * nx,), dtype=jnp.complex64)
    
    # JIT the operator and solve
    op_jit = jax.jit(lambda u: solver.operator(u, m_ext))
    solve_jit = jax.jit(lambda f: solver.solve(f, m_ext))
    
    # Warm-up: the first calls trigger JIT compilation, which must be excluded from timing
    _ = op_jit(u_vec).block_until_ready()
    # The solve warm-up can be slow on large grids, but it is needed for the same reason
    _ = solve_jit(f_vec)[0].block_until_ready()
    
    # Time Operator Apply
    t_op = []
    for _ in range(n_trials):
        start = time.perf_counter()
        _ = op_jit(u_vec).block_until_ready()
        t_op.append(time.perf_counter() - start)
    
    # Time Full Solve (GMRES)
    t_solve = []
    for _ in range(n_trials):
        start = time.perf_counter()
        _ = solve_jit(f_vec)[0].block_until_ready()
        t_solve.append(time.perf_counter() - start)
        
    # Host-side timings: plain Python floats are enough, no device round-trip needed
    return sum(t_op) / n_trials, sum(t_solve) / n_trials
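The full solve is timed through `solver.solve(f, m_ext)[0]`, which suggests it returns a tuple whose first entry is the solution. JAX ships a matrix-free GMRES, `jax.scipy.sparse.linalg.gmres`, that follows the same `(x, info)` convention; whether `HelmholtzSolver` uses it is an assumption. The sketch below solves a toy, well-conditioned system `u - Laplacian(u) = b` on a 32x32 grid with a callable operator (the hypothetical `apply_a`), so no matrix is ever formed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

n = 32  # toy n x n grid


def apply_a(u_vec):
    """Matrix-free toy operator A u = u - Laplacian(u), h = 1, zero boundary.

    Symmetric positive definite with eigenvalues in (1, 9), so GMRES converges fast.
    """
    u = u_vec.reshape(n, n)
    p = jnp.pad(u, 1)
    lap = p[1:-1, :-2] + p[1:-1, 2:] + p[:-2, 1:-1] + p[2:, 1:-1] - 4.0 * u
    return (u - lap).reshape(-1)


b = jax.random.normal(jax.random.PRNGKey(1), (n * n,))
# gmres accepts a callable operator and returns (x, info).
solve = jax.jit(lambda rhs: gmres(apply_a, rhs, tol=1e-5, restart=30, maxiter=10))
x, info = solve(b)
x.block_until_ready()
rel_res = float(jnp.linalg.norm(apply_a(x) - b) / jnp.linalg.norm(b))
print(f"relative residual: {rel_res:.2e}")
```

For an indefinite, complex-valued Helmholtz operator GMRES generally needs many more iterations, and a preconditioner can be passed through the `M` argument.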

## 2. Running the Benchmarks

We run the benchmarks for interior grid sizes N = 50, 100, 200 and 400; each grid is padded with `npml = 20` PML cells on every side.
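Because the PML padding is added on every side, the number of unknowns the operator actually touches is (N + 2*npml)^2. A quick calculation, using a hypothetical helper `padded_unknowns`, shows how much of each grid is PML:

```python
def padded_unknowns(n_int, npml):
    """Unknowns on an n_int x n_int interior padded with npml PML cells per side."""
    n_tot = n_int + 2 * npml
    return n_tot * n_tot


npml = 20
for n_int in [50, 100, 200, 400]:
    total = padded_unknowns(n_int, npml)
    pml_frac = 1 - n_int**2 / total
    print(f"N = {n_int:3d}: {total:7,d} unknowns, {pml_frac:.0%} in the PML")
```

At N = 50 roughly 69% of the unknowns lie in the PML, falling to about 17% at N = 400, so the smallest grids mostly measure work in the absorbing layer.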

In [None]:
grid_sizes = [50, 100, 200, 400]
modes = ['matrix', 'stencil', 'conv']
results = {mode: {'op': [], 'solve': []} for mode in modes}

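npml = 20          # PML thickness (cells per side)
h = 0.01           # grid spacing
omega = 10.0       # angular frequency
sigma_max = 50.0   # peak PML damping coefficient
order = 4          # finite-difference order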

for size in grid_sizes:
    print(f"Benchmarking grid size {size}x{size}...")
    for mode in modes:
        t_op, t_solve = benchmark_mode(mode, size, size, npml, h, omega, sigma_max, order)
        results[mode]['op'].append(t_op)
        results[mode]['solve'].append(t_solve)
        print(f"  Mode {mode:8}: Op {t_op*1000:6.2f} ms | Solve {t_solve:6.3f} s")

## 3. Visualization

We plot the operator execution time as a function of grid size.
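Since the grid size doubles at each step, a log-log view makes the scaling easier to read. A minimal sketch, using a hypothetical helper `scaling_exponent`, fits `time ~ C * N**p` by least squares on the logarithms; an exponent near 2 means the cost grows linearly with the number of grid points. It is checked here on synthetic timings, not on measured results.

```python
import numpy as np


def scaling_exponent(sizes, times):
    """Least-squares slope p of log(time) against log(size), i.e. time ~ C * size**p."""
    p, _log_c = np.polyfit(np.log(sizes), np.log(times), 1)
    return float(p)


# Synthetic timings that grow exactly like N**2 (not measured data).
sizes = np.array([50, 100, 200, 400])
fake_times = 3e-7 * sizes**2.0
print(round(scaling_exponent(sizes, fake_times), 3))  # -> 2.0
```

On the notebook's data, passing the padded size, e.g. `scaling_exponent(np.array(grid_sizes) + 2 * npml, results['conv']['op'])`, reflects the grid the operator acts on; `plt.loglog` can replace `plt.plot` for the matching figure.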

In [None]:
plt.figure(figsize=(10, 6))
for mode in modes:
    plt.plot(grid_sizes, [t*1000 for t in results[mode]['op']], 'o-', label=f"Mode: {mode}")

plt.xlabel("Interior Grid Size (N)")
plt.ylabel("Operator Apply Time (ms)")
plt.title(f"Helmholtz Operator Benchmark (Order {order})")
plt.legend()
plt.grid(True)
plt.show()

### Conclusion

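- **`matrix`**: Usually fast at moderate sizes, but its performance is bound by memory throughput; if the precomputed 1D matrices are dense, each apply also costs O(N^3) operations on an N x N grid, versus O(N^2) for a stencil.
- **`stencil`**: Truly matrix-free: each apply sums shifted copies of the field, so its cost grows linearly with the number of grid points. Depending on the device (CPU or GPU), it may be faster or slower than `conv` for small stencils.
- **`conv`**: Relies on JAX's optimized convolution primitives and should scale well on GPUs, especially for high-order (wider) stencils.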
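How the dense-matrix and stencil approaches should scale can be estimated with a back-of-the-envelope count of multiply-adds per apply. This sketch rests on hypothetical assumptions: the `matrix` mode stores dense 1D matrices, the order-4 stencil is a 9-point cross (5 points per axis), and PML coefficient work is ignored. `macs_per_apply` is a hypothetical helper.

```python
def macs_per_apply(n, order=4):
    """Rough multiply-add count for one operator apply on an n x n grid.

    Hypothetical assumptions, not taken from HelmholtzOperator:
    - dense: one dense n x n 1D matrix applied along each axis
      (n**2 MACs per row or column, 2 * n such lines in total);
    - stencil: centered order-`order` second derivative with order + 1 points
      per axis, i.e. a cross of 2 * order + 1 points in 2D.
    PML coefficient work is ignored in both counts.
    """
    dense = 2 * n**3
    stencil = (2 * order + 1) * n**2
    return dense, stencil


npml = 20
for n_int in [50, 100, 200, 400]:
    n = n_int + 2 * npml  # padded grid size
    dense, stencil = macs_per_apply(n)
    print(f"n = {n}: dense {dense:.2e}, stencil {stencil:.2e}, ratio {dense / stencil:.0f}x")
```

Under these assumptions the dense approach does about 2n/9 times more arithmetic than the stencil (roughly 98x on the padded 440x440 grid), although measured times also depend on memory traffic and on how well each mode maps to the hardware.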