# Forward Scattering Problem Demo

This notebook demonstrates how to use the JAX-based Helmholtz solver to compute the scattered field from a perturbation composed of several Gaussian bumps.

In [None]:
import sys
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Add the project root to path to import inverse_scattering_jax
sys.path.append(os.path.abspath(".."))

from inverse_scattering_jax.src.helmholtz import HelmholtzOperator, HelmholtzSolver, extend_model
from inverse_scattering_jax.src.inverse_scattering import IncomingDirections, get_projection_op

## 1. Setup Parameters and Solver

In [None]:
# Number of interior grid points in x and y.
nxint, nyint = 100, 100
# Number of grid points in the PML added on each side of the interior.
npml = 20
# Total number of points in the extended domain (interior + PML on both sides).
nx, ny = nxint + 2 * npml, nyint + 2 * npml
# Grid spacing: the nxint interior points span a unit-length interval.
h = 1.0 / (nxint - 1)
# Angular frequency.
omega = 20.0
# Maximum damping in the PML.
sigma_max = 80.0
# Order of the finite-difference discretization.
order = 4

# Helmholtz operator on the extended (interior + PML) grid.
# The data type has to be given explicitly here.
op = HelmholtzOperator(
    nx,
    ny,
    npml,
    h,
    omega,
    sigma_max,
    order,
    dtype=jnp.complex64,
)
# Solver built on top of the operator, plus the set of incident directions
# (n_theta=8 of them) used to build right-hand sides.
solver = HelmholtzSolver(op)
inc = IncomingDirections(nx, ny, npml, h, omega, n_theta=8)

# Interior coordinates for visualization: index n // 2 maps to 0, spacing h.
x_int, y_int = ((jnp.arange(n) - n // 2) * h for n in (nxint, nyint))
# 'ij' indexing: Xi[i, j] = x_int[i], Yi[i, j] = y_int[j].
Xi, Yi = jnp.meshgrid(x_int, y_int, indexing="ij")

## 2. Define Gaussian Perturbation

We build the perturbation $\eta$ as a sum of three Gaussian bumps. The scattering medium seen by the solver is $m = 1 + \eta$.

In [None]:
def gaussian_bump(X, Y, x0, y0, sigma, amplitude):
    """Evaluate an isotropic Gaussian bump pointwise.

    Args:
        X, Y: coordinate arrays (e.g. from ``jnp.meshgrid``) or scalars;
            they are combined with standard broadcasting.
        x0, y0: centre of the bump.
        sigma: standard deviation controlling the bump width.
        amplitude: peak value, reached at (x0, y0).

    Returns:
        ``amplitude * exp(-((X - x0)**2 + (Y - y0)**2) / (2 * sigma**2))``.
    """
    dist_sq = (X - x0) ** 2 + (Y - y0) ** 2
    return amplitude * jnp.exp(-dist_sq / (2 * sigma**2))

# Each tuple is (x0, y0, sigma, amplitude) for one Gaussian bump.
bump_params = [
    (-0.2, -0.2, 0.05, 0.2),
    (0.2, 0.1, 0.03, 0.3),
    (0.0, 0.3, 0.04, 0.25),
]

# Perturbation on the interior grid, indexed eta[i_x, i_y] (Xi/Yi use 'ij').
eta = jnp.zeros((nxint, nyint))
for x0, y0, sigma, amplitude in bump_params:
    eta += gaussian_bump(Xi, Yi, x0, y0, sigma, amplitude)

fig, ax = plt.subplots(figsize=(6, 5))
# eta.T puts y on the rows. origin="lower" draws row 0 (smallest y) at the
# bottom so the picture agrees with `extent`. With the default origin="upper"
# the image would be drawn upside down relative to the axis labels.
im = ax.imshow(
    eta.T,
    extent=[x_int[0], x_int[-1], y_int[0], y_int[-1]],
    origin="lower",
)
# Raw string: "\e" is not a valid Python escape sequence.
fig.colorbar(im, ax=ax, label=r"Perturbation $\eta$")
ax.set(title="Scattering Medium (Gaussian Bumps)", xlabel="$x$", ylabel="$y$")
plt.show()

## 3. Compute Scattered Field

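We assemble the right-hand sides for all incident directions, then solve for the scattered field of the first direction only. The total field is the scattered field plus the incident field.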

In [None]:
# Extend eta from the interior grid to the extended (interior + PML) grid,
# then form the model m = 1 + eta used by the solver.
eta_ext = extend_model(eta, nxint, nyint, npml)
m_ext = 1.0 + eta_ext
# Right-hand sides for the incident directions; column 0 is the first direction.
rhs = inc.get_rhs(eta_ext)

# Solve for the first direction only. `info` is the solver's second return
# value and is not used here.
f_0 = rhs[:, 0]
u_scattered_0, info = solver.solve(f_0, m_ext)
# NOTE(review): the (ny, nx) reshape must match the flattening convention of
# extend_model / the solver. Here nx == ny, so both orders give the same shape.
u_scattered_0 = u_scattered_0.reshape((ny, nx))

# Total wavefield = scattered + incident, still on the full extended grid.
u_total_0 = (u_scattered_0.ravel() + inc.U_in[:, 0]).reshape((ny, nx))

# The solution includes the PML layers, but the plot extent describes the
# interior only, so crop away npml points on every side before plotting.
interior = (slice(npml, npml + nyint), slice(npml, npml + nxint))
u_scattered_int_0 = u_scattered_0[interior]
u_total_int_0 = u_total_0[interior]

extent = [x_int[0], x_int[-1], y_int[0], y_int[-1]]
panels = [
    (jnp.real(u_scattered_int_0), "Real part of Scattered Field"),
    (jnp.abs(u_total_int_0), "Amplitude of Total Field"),
]
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (field, title) in zip(axes, panels):
    # origin="lower" keeps the smallest y at the bottom, consistent with `extent`.
    im = ax.imshow(field.T, extent=extent, origin="lower")
    ax.set(title=title, xlabel="$x$", ylabel="$y$")
    fig.colorbar(im, ax=ax)
plt.show()