# Inverse Scattering Reconstruction: Gaussian Bumps

This notebook demonstrates how to use the `inverse_scattering_jax` library to reconstruct a perturbation in the refractive index (scatterer) from measurements of scattered fields produced by multiple incident waves.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from inverse_scattering_jax import (
    HelmholtzOperator, 
    HelmholtzSolver, 
    IncomingDirections, 
    get_projection_op, 
    create_forward_with_adjoint,
    solve_inverse_problem,
    GMRESOptions
)

# Default size (width, height in inches) for every figure created in this notebook
plt.rcParams.update({'figure.figsize': [8, 6]})

## 1. Setup Parameters
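We define the interior grid size, the PML (perfectly matched layer) parameters (`npml`, `sigma_max`), the grid spacing `h`, the angular frequency $\omega$, and the number of incident directions.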

In [None]:
# Interior (physical) grid size in points per axis, and PML thickness in points
nxint = nyint = 64
npml = 10

# Full computational grid = interior plus one PML layer on each side
nx = nxint + 2 * npml
ny = nyint + 2 * npml

# Grid spacing chosen so that the interior spans exactly unit length
h = 1.0 / (nxint - 1)

# Angular frequency and PML parameter
omega = 25.0
sigma_max = 20.0

# Multi-view setup: number of incident directions
n_theta = 8

## 2. Initialize Solver and Incident Waves
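We build the Helmholtz operator in `stencil` mode for efficient matrix-free application. We use a relaxed GMRES tolerance (`tol=1e-3`, at most 500 iterations) to speed up the many forward solves required during the optimization. Finally, we set up the incident waves for the `n_theta` incident directions.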

In [None]:
# Helmholtz operator on the full (interior + PML) grid, applied matrix-free via a stencil
op = HelmholtzOperator(nx, ny, npml, h, omega, sigma_max, mode="stencil")

# Loose GMRES tolerance: many forward solves are needed during the inversion
solver = HelmholtzSolver(
    op,
    gmres_options=GMRESOptions(tol=1e-3, maxiter=500),
)

# Incident fields for the n_theta illumination directions
inc = IncomingDirections(nx, ny, npml, h, omega, n_theta)

## 3. Define Ground Truth (Gaussian Bumps)
We create a true perturbation model consisting of three Gaussian bumps with different centres, widths, and amplitudes, defined on the interior (non-PML) grid.

In [None]:
# Interior node coordinates, shifted so that index nxint // 2 (resp. nyint // 2) sits at 0
x_int = h * (jnp.arange(nxint) - nxint // 2)
y_int = h * (jnp.arange(nyint) - nyint // 2)
# 'ij' indexing: first array axis follows x, second follows y
X_int, Y_int = jnp.meshgrid(x_int, y_int, indexing='ij')

def gaussian_bump(X, Y, x0, y0, sigma, amplitude):
    """Evaluate an isotropic 2-D Gaussian bump elementwise.

    Returns ``amplitude * exp(-r**2 / (2 * sigma**2))``, where ``r`` is the
    distance from the centre ``(x0, y0)`` to the point ``(X, Y)``. ``X`` and
    ``Y`` may be scalars or arrays of a common (broadcastable) shape.
    """
    dx = X - x0
    dy = Y - y0
    r_squared = dx**2 + dy**2
    return amplitude * jnp.exp(-r_squared / (2 * sigma**2))

# Ground-truth perturbation: superposition of three Gaussian bumps,
# each given as (x0, y0, sigma, amplitude)
eta_true = sum(
    gaussian_bump(X_int, Y_int, x0, y0, sigma, amp)
    for (x0, y0, sigma, amp) in [
        (-0.2, -0.1, 0.05, 0.10),
        (0.1, 0.2, 0.04, 0.15),
        (0.1, -0.2, 0.06, 0.12),
    ]
)

# imshow's extent refers to the outer pixel edges; pad by half a grid cell so
# pixel centres line up with the grid nodes x_int / y_int.
extent = [
    float(x_int[0] - h / 2), float(x_int[-1] + h / 2),
    float(y_int[0] - h / 2), float(y_int[-1] + h / 2),
]

fig, ax = plt.subplots(figsize=(6, 5))
# Transpose because eta_true is indexed [ix, iy] ('ij' meshgrid) while imshow expects [row=y, col=x]
im = ax.imshow(eta_true.T, extent=extent, origin='lower')
# Raw string: '\e' is an invalid escape sequence in a normal string literal
fig.colorbar(im, ax=ax, label=r'Perturbation $\eta$')
ax.set(title='True Model (Gaussian Bumps)', xlabel='$x$', ylabel='$y$')
plt.show()

## 4. Generate Synthetic Data
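We place 32 observation (receiver) points on a circle of radius 0.4 around the scatterer. We then compute the scattered field at these points for every incident direction. A projection operator interpolates the field from the computational grid to the receiver locations.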

In [None]:
# Receivers: n_obs points equally spaced on a circle of radius r_obs around the scatterer
n_obs = 32
r_obs = 0.4
theta_r = jnp.linspace(0, 2 * jnp.pi, n_obs, endpoint=False)
points_query = r_obs * jnp.stack([jnp.cos(theta_r), jnp.sin(theta_r)], axis=1)

# Every receiver must lie in the physical interior: beyond it is the PML layer,
# where the recorded field would no longer be the physical scattered field.
interior_half_width = float(min(-x_int[0], x_int[-1], -y_int[0], y_int[-1]))
assert r_obs < interior_half_width, (
    f"receiver circle (r={r_obs}) extends past the interior grid "
    f"(half-width {interior_half_width:.3f})"
)

# Full-grid (interior + PML) coordinates with the same centring as x_int / y_int:
# node npml + i of the full grid coincides with interior node i.
x_full = (jnp.arange(nx) - npml - nxint // 2) * h
y_full = (jnp.arange(ny) - npml - nyint // 2) * h
# Interpolation from the full grid to the receiver locations
projection_op = get_projection_op(x_full, y_full, points_query)

# Create forward function with custom adjoint for efficiency
forward_fun = create_forward_with_adjoint(solver, inc, projection_op)

print("Generating synthetic data...")
data = forward_fun(eta_true.flatten())

print(f"Data generated with shape: {data.shape} (obs_points, directions)")

## 5. Inverse Reconstruction
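We solve the inverse problem using L-BFGS, starting from an initial guess of zero and stopping after at most 30 iterations. Note that the synthetic data are noise-free and are generated with the same forward model used in the inversion (an "inverse crime"). The resulting reconstruction is therefore more optimistic than one obtained from real measurements.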

In [None]:
print("Starting Reconstruction (L-BFGS)...")
# Zero initial guess, flattened to match the layout passed to forward_fun
eta_init = jnp.zeros_like(eta_true.ravel())

# A short optimization run (iteration cap of 30)
reconstructed_flat, state = solve_inverse_problem(
    eta_init,
    forward_fun,
    data,
    maxiter=30,
)

# Back to the 2-D interior grid layout of eta_true, i.e. (nxint, nyint)
reconstructed = reconstructed_flat.reshape(eta_true.shape)
print(f"Optimization finished. Final misfit: {state.value}")

## 6. Visualize Results
Side-by-side comparison of the ground truth, the reconstructed model, and the reconstruction error (ground truth minus reconstruction).

In [None]:
# imshow's extent refers to the outer pixel edges; pad by half a grid cell so
# pixel centres line up with the grid nodes x_int / y_int.
extent = [
    float(x_int[0] - h / 2), float(x_int[-1] + h / 2),
    float(y_int[0] - h / 2), float(y_int[-1] + h / 2),
]

error = eta_true - reconstructed
# Symmetric colour range so that zero error maps to the centre (white) of 'seismic'
max_err = float(jnp.abs(error).max())

# Truth and reconstruction share one colour range so amplitudes can be compared directly
vmin = float(jnp.minimum(eta_true.min(), reconstructed.min()))
vmax = float(jnp.maximum(eta_true.max(), reconstructed.max()))

# (field, title, imshow styling) for each panel
panels = [
    (eta_true, 'Ground Truth', {'vmin': vmin, 'vmax': vmax}),
    (reconstructed, 'Reconstructed', {'vmin': vmin, 'vmax': vmax}),
    (error, 'Reconstruction Error', {'cmap': 'seismic', 'vmin': -max_err, 'vmax': max_err}),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (field, title, style) in zip(axes, panels):
    # Transpose: fields are indexed [ix, iy] while imshow expects [row=y, col=x]
    im = ax.imshow(field.T, extent=extent, origin='lower', **style)
    ax.set(title=title, xlabel='$x$', ylabel='$y$')
    fig.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()