# Paraxial Wave Solver Demo

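This notebook demonstrates how to use the `paraxial_wave_solver` library.
It simulates the propagation of a Gaussian beam through a random medium with the JAX-based solver, first using a spectral split-step scheme and then a fourth-order finite-difference scheme with RK4 integration, and compares the two results.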

In [None]:
# Install the package directly from GitHub using pip.
!pip install git+https://github.com/Forgotten/paraxial_wave_solver.git
!pip install matplotlib

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Import the solver module.
import paraxial_wave_solver as pws

In [None]:
# Geometry of the simulation domain: a 256 x 256 transverse grid with spacing
# 0.2 in x and y, marched through 200 steps of size 0.25 along z.
sim_config = pws.SimulationConfig(
    nx=256, dx=0.2,
    ny=256, dy=0.2,
    nz=200, dz=0.25,
    wavelength=1.0,
)

# Perfectly matched layers on the transverse (x and y) edges.
# NOTE(review): whether width_x / width_y are grid points or physical lengths
# is not visible here — confirm against the library.
pml_config = pws.PMLConfig(strength=5.0, width_x=20, width_y=20)

# Spectral discretization in the transverse plane, advanced in z with the
# split-step method.
solver_config = pws.SolverConfig(stepper='split_step', method='spectral')

In [None]:
# Initial field: a Gaussian beam with width parameter w0.
w0 = 5.0
psi_0 = pws.gaussian_beam(sim_config, w0=w0)

# Random refractive-index perturbation, seeded with a fixed PRNG key so the
# medium is reproducible across runs.
key = jax.random.PRNGKey(42)
delta_n = pws.random_medium(sim_config, correlation_length=2.0, strength=0.05, key=key)


def n_ref_fn(z):
    """Refractive index on the transverse grid at propagation distance ``z``.

    Returns ``1 + delta_n`` for the z-slice whose index is nearest to
    ``z / dz``, clamped into ``[0, nz - 1]`` so distances outside the sampled
    range reuse the first or last slice.

    Note: ``jnp.round`` rounds exact halves to the nearest even integer, so a
    ``z`` lying exactly midway between two grid planes may select either
    neighbouring slice depending on parity.
    """
    nearest = jnp.round(z / sim_config.dz).astype(int)
    slice_idx = jnp.clip(nearest, 0, sim_config.nz - 1)
    return 1.0 + delta_n[:, :, slice_idx]

In [None]:
# Propagate the Gaussian beam through the random medium with the spectral
# split-step solver. The solve is jitted inside the library.
print("Running simulation in random media (Spectral + Split-Step)...")
solver = pws.ParaxialWaveSolver(sim_config, solver_config, pml_config, n_ref_fn)
solve_result = solver.solve(psi_0)
psi_final, psi_history = solve_result
print("Simulation complete.")

In [None]:
# Visualize an x-z slice through the transverse center of the grid (y index ny // 2).
center_y = sim_config.ny // 2

# Take the y = center_y slice and transpose it so rows run along x and columns
# along z, matching the imshow extent below.
# NOTE(review): assumes psi_history is laid out as (z, x, y) — confirm.
field_xz = psi_history[:, :, center_y].T
intensity_xz = jnp.abs(field_xz)**2

# imshow extent is (left, right, bottom, top) = (z range, x range).
# NOTE(review): assumes the x coordinate starts at 0; adjust if the grid is centered.
extent = [0, sim_config.lz, 0, sim_config.lx]

fig, (ax_field, ax_index) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: beam intensity along the propagation direction.
im_field = ax_field.imshow(intensity_xz, extent=extent, origin='lower', cmap='inferno', aspect='auto')
fig.colorbar(im_field, ax=ax_field, label='Intensity')
ax_field.set(xlabel='z', ylabel='x', title='Propagation (Spectral)')

# Right panel: refractive-index perturbation on the same slice. delta_n's last
# axis is z (it is indexed that way in n_ref_fn), so this slice is (x, z).
im_index = ax_index.imshow(delta_n[:, center_y, :], extent=extent, origin='lower', cmap='gray', aspect='auto')
fig.colorbar(im_index, ax=ax_index, label='delta n')
ax_index.set(xlabel='z', ylabel='x', title='Refractive Index Perturbation (XZ)')

fig.tight_layout()
plt.show()

## Finite Difference Solver

We can also discretize the transverse derivatives with finite differences (FD) and advance the field in z with fourth-order Runge–Kutta (RK4) integration. Here we use a 4th-order FD stencil.

In [None]:
# Same physical problem, now discretized with a 4th-order finite-difference
# stencil and advanced in z with the RK4 stepper.
solver_config_fd = pws.SolverConfig(method='finite_difference', fd_order=4, stepper='rk4')

print("Running simulation in random media (FD4 + RK4)...")
solver_fd = pws.ParaxialWaveSolver(sim_config, solver_config_fd, pml_config, n_ref_fn)
fd_result = solver_fd.solve(psi_0)
psi_final_fd, psi_history_fd = fd_result
print("Simulation complete.")

In [None]:
# Visualize the FD result on the same x-z slice, next to its pointwise
# difference from the spectral run.
# NOTE(review): assumes psi_history_fd has the same (z, x, y) layout as psi_history — confirm.
field_xz_fd = psi_history_fd[:, :, center_y].T
intensity_xz_fd = jnp.abs(field_xz_fd)**2

# Same (z range, x range) extent as the spectral plot, set here so this cell
# defines its own axes.
extent = [0, sim_config.lz, 0, sim_config.lx]

# Absolute pointwise intensity difference; both maps share the (x, z) layout.
diff = jnp.abs(intensity_xz - intensity_xz_fd)

fig, (ax_fd, ax_diff) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: FD4 + RK4 beam intensity.
im_fd = ax_fd.imshow(intensity_xz_fd, extent=extent, origin='lower', cmap='inferno', aspect='auto')
fig.colorbar(im_fd, ax=ax_fd, label='Intensity')
ax_fd.set(xlabel='z', ylabel='x', title='Propagation (FD4 + RK4)')

# Right panel: where the two discretizations disagree.
im_diff = ax_diff.imshow(diff, extent=extent, origin='lower', cmap='viridis', aspect='auto')
fig.colorbar(im_diff, ax=ax_diff, label='|Diff|')
ax_diff.set(xlabel='z', ylabel='x', title='Difference (Spectral vs FD)')

fig.tight_layout()
plt.show()