# Wave Propagation Toolkit

This notebook demonstrates how to simulate coherent wave propagation without installing the full **adorym** package. It re-implements the minimal propagation utilities (`get_kernel` and `fresnel_propagate`) and applies them to a workflow that loads an initial wavefront, propagates it, interacts with a 3D object, and visualizes the resulting complex fields.

## Contents

- Setup: helper utilities and propagation functions

1. Input configuration (file paths and experimental parameters)
2. Load the initial wave (magnitude & optional phase)
3. Propagate the source wave to the sample plane
4. Build a projected transmission function from a 3D refractive index volume
5. Form the exit wave, propagate to the detector, and visualize
6. Multislice propagation through the 3D object

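Run the cells in order: each section uses variables defined by earlier ones (for example, section 5 needs `propagated_wave` and `transmission`). If the input files are missing, synthetic data are generated so the whole notebook still runs.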

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import tifffile as tiff

# Print floats in fixed-point notation (tiny values show as zero) and use the
# "magma" colormap as the default for every imshow call in the notebook.
np.set_printoptions(suppress=True)
plt.rcParams["image.cmap"] = "magma"

### Propagation helpers

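The functions below are the only adorym components we need. `get_kernel` builds either the paraxial Fresnel transfer function or, with `fresnel_approx=False` (selected by `mode="fraunhofer"` in `fresnel_propagate`), the exact angular-spectrum transfer function. Despite the mode name, neither option is a far-field Fraunhofer propagator. `fresnel_propagate` is a thin wrapper that applies the chosen kernel to a complex wavefront with FFTs.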

In [None]:
def _frequency_mesh(pixel_size_m, grid_shape):
    """Return the (fx, fy) spatial-frequency meshgrids for a 2D grid.

    Parameters
    ----------
    pixel_size_m : float or tuple of float
        Pixel size along (y, x) in metres. A scalar is used for both axes.
    grid_shape : tuple of int
        Size of the 2D field (ny, nx).

    Returns
    -------
    fx, fy : ndarray
        Frequency grids in cycles per metre, each of shape ``(ny, nx)``, in the
        unshifted order produced by :func:`numpy.fft.fftfreq`.
    """
    # Accept an isotropic scalar pixel size as well as a (dy, dx) pair.
    dy, dx = np.broadcast_to(np.asarray(pixel_size_m, dtype=np.float64), (2,))
    fy = np.fft.fftfreq(grid_shape[0], d=dy)
    fx = np.fft.fftfreq(grid_shape[1], d=dx)
    # Default "xy" indexing yields arrays of shape (ny, nx), matching the field.
    return np.meshgrid(fx, fy)


def get_kernel(distance_m, wavelength_m, pixel_size_m, grid_shape,
               fresnel_approx=True, sign_convention=1):
    r"""Return the Fourier-domain propagation kernel (transfer function).

    Parameters
    ----------
    distance_m : float
        Propagation distance in metres.
    wavelength_m : float
        Wavelength in metres.
    pixel_size_m : float or tuple of float
        Pixel size along (y, x) in metres; a scalar applies to both axes.
    grid_shape : tuple of int
        Size of the 2D field (ny, nx).
    fresnel_approx : bool, optional
        If True (default), return the paraxial Fresnel transfer function
        :math:`\exp(-i s \pi \lambda z (f_x^2 + f_y^2))`. It omits the constant
        factor :math:`\exp(i s k z)`. If False, return the exact
        angular-spectrum transfer function
        :math:`\exp(i s k z \sqrt{1 - \lambda^2 (f_x^2 + f_y^2)})`, with
        evanescent frequencies (:math:`\lambda^2 f^2 \ge 1`) set to zero.
        The False branch is *not* a far-field (Fraunhofer) propagator.
    sign_convention : {1, -1}, optional
        Choose ``1`` for :math:`\exp(ikz)` or ``-1`` for :math:`\exp(-ikz)`.

    Returns
    -------
    ndarray of complex128, shape ``grid_shape``
        Kernel in unshifted FFT order, to be multiplied with ``np.fft.fft2``
        of the field.

    Raises
    ------
    ValueError
        If ``sign_convention`` is not 1 or -1.
    """
    if sign_convention not in (1, -1):
        raise ValueError(
            f"sign_convention must be 1 or -1, got {sign_convention!r}")

    fx, fy = _frequency_mesh(pixel_size_m, grid_shape)
    quad = fx**2 + fy**2
    if fresnel_approx:
        phase = -sign_convention * np.pi * wavelength_m * distance_m * quad
        return np.exp(1j * phase)

    spectral_radius = 1 - (wavelength_m**2) * quad
    kernel = np.zeros_like(fx, dtype=np.complex128)
    # Frequencies with spectral_radius <= 0 are evanescent and stay zero.
    mask = spectral_radius > 0
    k = 2 * np.pi / wavelength_m
    # For X-rays k*z is ~1e10 rad, so evaluating k*z*sqrt(s) directly buries
    # the small frequency-dependent part in rounding error. Split it into the
    # constant k*z plus k*z*(sqrt(s) - 1). The identity
    # sqrt(s) - 1 = -lambda**2 * f**2 / (1 + sqrt(s)) avoids cancellation.
    root = np.sqrt(spectral_radius[mask])
    delta_phase = -k * distance_m * (wavelength_m**2) * quad[mask] / (1 + root)
    kernel[mask] = (np.exp(sign_convention * 1j * k * distance_m)
                    * np.exp(sign_convention * 1j * delta_phase))
    return kernel


def fresnel_propagate(field, distance_m, wavelength_m, pixel_size_m,
                      mode="fresnel", sign_convention=1):
    """Propagate ``field`` over ``distance_m`` using FFT-based convolution.

    Parameters
    ----------
    field : array_like
        Complex 2D wavefront sampled on a (ny, nx) grid.
    distance_m : float
        Propagation distance in metres.
    wavelength_m : float
        Wavelength in metres.
    pixel_size_m : tuple of float
        Pixel size along (y, x) in metres; passed through to ``get_kernel``.
    mode : {"fresnel", "fraunhofer"}, optional
        ``"fresnel"`` uses the paraxial Fresnel transfer function.
        ``"fraunhofer"`` selects ``get_kernel(..., fresnel_approx=False)``, the
        exact angular-spectrum kernel (not a far-field propagator). Matching
        is case-insensitive.
    sign_convention : {1, -1}, optional
        Forwarded to ``get_kernel``.

    Returns
    -------
    ndarray of complex
        Propagated field with the same shape as ``field``. The FFTs make the
        convolution circular, so light leaving one edge re-enters at the other.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported names.

    Warns
    -----
    UserWarning
        If ``|distance_m|`` exceeds the transfer-function sampling limit
        ``N * d**2 / wavelength_m`` on either axis.
    """
    import warnings

    mode_key = mode.lower()
    # Reject typos instead of silently falling back to the Fresnel kernel.
    if mode_key not in ("fresnel", "fraunhofer"):
        raise ValueError(
            f"Unknown propagation mode {mode!r}; expected 'fresnel' or 'fraunhofer'.")

    field = np.asarray(field)
    # The kernel's quadratic phase chirp is adequately sampled only while
    # |z| <= N * d**2 / wavelength on each axis. Beyond that it aliases, so
    # warn instead of silently returning a corrupted field.
    spacing = np.broadcast_to(np.asarray(pixel_size_m, dtype=np.float64), (2,))
    critical_m = float(np.min(np.asarray(field.shape[:2]) * spacing**2 / wavelength_m))
    if abs(distance_m) > critical_m:
        warnings.warn(
            f"Propagation distance {distance_m:g} m exceeds the transfer-function "
            f"sampling limit N*dx**2/wavelength = {critical_m:g} m; the result "
            "is likely aliased. Use a larger grid or coarser pixels.",
            stacklevel=2,
        )

    kernel = get_kernel(distance_m, wavelength_m, pixel_size_m, field.shape,
                        fresnel_approx=(mode_key == "fresnel"),
                        sign_convention=sign_convention)
    return np.fft.ifft2(np.fft.fft2(field) * kernel)

### General utilities

In [None]:
def load_initial_wave(magnitude_path, phase_path=None):
    """Load magnitude and (optional) phase TIFF files and build a complex wave.

    Parameters
    ----------
    magnitude_path : str or Path
        TIFF file holding the wave amplitude.
    phase_path : str or Path, optional
        TIFF file holding the phase in radians. If ``None``, the phase is zero.
        If a path is given but the file does not exist, a warning is issued
        and the phase falls back to zero.

    Returns
    -------
    ndarray of complex128
        ``magnitude * exp(1j * phase)``.

    Raises
    ------
    ValueError
        If the magnitude and phase images have different shapes.
    """
    import warnings

    magnitude = tiff.imread(magnitude_path).astype(np.float64)
    if phase_path is None:
        phase = np.zeros_like(magnitude)
    elif not Path(phase_path).exists():
        # Keep the best-effort fallback, but make it visible: a mistyped path
        # would otherwise silently yield a flat-phase wave.
        warnings.warn(f"Phase file {phase_path} not found; using zero phase.",
                      stacklevel=2)
        phase = np.zeros_like(magnitude)
    else:
        phase = tiff.imread(phase_path).astype(np.float64)
        if phase.shape != magnitude.shape:
            raise ValueError(
                f"Phase image shape {phase.shape} does not match magnitude "
                f"shape {magnitude.shape}")
    return magnitude * np.exp(1j * phase)


def project_transmission(n_volume, wavelength_m, voxel_size_z_m):
    """Project a ``delta + i*beta`` volume into a single 2D transmission map.

    The real (delta) and imaginary (beta) parts are each summed along axis 0,
    which this notebook treats as the beam direction with one voxel of
    ``voxel_size_z_m`` per entry. The result is
    ``exp(-k*dz*sum(beta) - 1j*k*dz*sum(delta))`` with
    ``k = 2*pi / wavelength_m``.
    """
    scale = -(2 * np.pi / wavelength_m) * voxel_size_z_m
    column_delta = np.real(n_volume).sum(axis=0)
    column_beta = np.imag(n_volume).sum(axis=0)
    return np.exp(scale * column_beta + 1j * (scale * column_delta))


def slice_transmission(n_slice, wavelength_m, voxel_size_z_m):
    """Return the element-wise transmission of one ``delta + i*beta`` slice.

    A slice of thickness ``voxel_size_z_m`` attenuates by ``exp(-k*dz*beta)``
    and shifts the phase by ``-k*dz*delta``, where ``k = 2*pi / wavelength_m``.
    """
    scale = -(2 * np.pi / wavelength_m) * voxel_size_z_m
    return np.exp(scale * np.imag(n_slice) + 1j * (scale * np.real(n_slice)))


def pad_or_crop_to_shape(array, target_shape, constant_values=0):
    """Centre-pad or centre-crop ``array`` so its leading axes match ``target_shape``.

    Parameters
    ----------
    array : array_like
        Input array.
    target_shape : sequence of int
        Desired sizes for the first ``len(target_shape)`` axes. Any remaining
        trailing axes are left unchanged.
    constant_values : scalar, optional
        Fill value for padded regions (default 0, the original behaviour).
        Use e.g. 1 when padding a transmission function so the added border
        is transparent rather than opaque.

    Returns
    -------
    ndarray
        Array with the requested leading-axis sizes. When the size difference
        on an axis is odd, the extra element is padded (or cropped) at the end.
        If only cropping happens, the result is a view of the input.

    Raises
    ------
    ValueError
        If ``target_shape`` has more entries than ``array`` has dimensions.
    """
    padded = np.asarray(array)
    if len(target_shape) > padded.ndim:
        raise ValueError(
            f"target_shape {tuple(target_shape)} has more dimensions than "
            f"array of shape {padded.shape}")
    for axis, target in enumerate(target_shape):
        current = padded.shape[axis]
        if current == target:
            continue
        if current < target:
            pad_before = (target - current) // 2
            pad_after = target - current - pad_before
            pad_width = [(0, 0)] * padded.ndim
            pad_width[axis] = (pad_before, pad_after)
            padded = np.pad(padded, pad_width, mode="constant",
                            constant_values=constant_values)
        else:
            start = (current - target) // 2
            slices = [slice(None)] * padded.ndim
            slices[axis] = slice(start, start + target)
            padded = padded[tuple(slices)]
    return padded


def show_amplitude_phase(field, title_prefix, extent=None):
    """Show |field| and angle(field) side by side, each with its own colorbar.

    ``extent`` is forwarded unchanged to ``imshow`` for both panels.
    """
    panels = (
        (np.abs(field), f"{title_prefix} amplitude"),
        (np.angle(field), f"{title_prefix} phase"),
    )
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, (image, panel_title) in zip(axes, panels):
        mappable = ax.imshow(image, extent=extent)
        ax.set_title(panel_title)
        fig.colorbar(mappable, ax=ax, shrink=0.8)
    plt.tight_layout()
    plt.show()


def show_intensity(field, title, extent=None):
    """Display the intensity ``|field|**2`` in a new figure with a colorbar."""
    plt.figure(figsize=(5, 4))
    plt.imshow(np.abs(field) ** 2, extent=extent)
    plt.colorbar(shrink=0.8)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def hc_over_e():
    """Return Planck's constant times the speed of light in eV*m (1239.841984 eV*nm)."""
    hc_ev_m = 1.239841984e-6
    return hc_ev_m

## 1. Input configuration

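Update the paths and experimental parameters in this cell. If the specified files are not found, synthetic data are generated so the notebook can still be executed. Note that FFT transfer-function propagation is well sampled only for distances up to about N·Δx²/λ. For the synthetic 512 × 512 grid of 25 nm pixels at 8 keV, that limit is ≈2 mm. Longer distances need a larger grid or coarser pixels.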

In [None]:
# === Input files: point these at your data ==================================
magnitude_path = Path("data/source_magnitude.tif")
# Optional phase image, e.g. Path("data/source_phase.tif"); None = flat phase.
phase_path = None
# .npy file with the complex volume delta + 1j*beta. Later cells treat axis 0
# as the beam (z) direction.
volume_path = Path("data/object_volume.npy")

# === Experiment parameters (SI units unless stated) =========================
energy_ev = 8000.0                          # photon energy in eV
wavelength_m = hc_over_e() / energy_ev      # lambda = hc / E
pixel_size_m = (25e-9, 25e-9)               # (dy, dx) at the wavefront plane
source_sample_distance_m = 0.02             # initial-wave plane -> sample
sample_detector_distance_m = 0.5            # sample -> detector
voxel_size_z_m = 50e-9                      # object voxel size along z
propagation_mode = "fresnel"                # "fresnel" or "fraunhofer"

## 2. Load the initial wavefront

In [None]:
if not magnitude_path.exists():
    # No data on disk: synthesize a 512 x 512 Gaussian beam with a saddle phase.
    print("Magnitude file not found; generating a synthetic Gaussian beam.")
    grid_y, grid_x = 512, 512
    y = np.linspace(-1, 1, grid_y)
    x = np.linspace(-1, 1, grid_x)
    xx, yy = np.meshgrid(x, y)
    radius_sq = xx**2 + yy**2
    magnitude = np.exp(-(radius_sq / 0.1))
    phase = np.pi * (xx**2 - yy**2)
    initial_wave = magnitude * np.exp(1j * phase)
else:
    initial_wave = load_initial_wave(magnitude_path, phase_path)
    print(f"Loaded magnitude from {magnitude_path}")

show_amplitude_phase(initial_wave, "Initial wave")

## 3. Propagate the source wave

In [None]:
# Free-space propagation from the initial-wave plane to the sample plane.
propagated_wave = fresnel_propagate(
    initial_wave,
    source_sample_distance_m,
    wavelength_m,
    pixel_size_m,
    mode=propagation_mode,
)

show_amplitude_phase(initial_wave, "Source plane")
show_amplitude_phase(propagated_wave, f"After {source_sample_distance_m:.3f} m")

## 4. Build the transmission function from a 3D refractive index volume

In [None]:
if volume_path.exists():
    n_volume = np.load(volume_path)
    print(f"Loaded volume from {volume_path}")
else:
    print("Volume file not found; generating a synthetic object (phase grating).")
    nz, ny, nx = 64, initial_wave.shape[0] // 2, initial_wave.shape[1] // 2
    y = np.linspace(-1, 1, ny)
    x = np.linspace(-1, 1, nx)
    xx, yy = np.meshgrid(x, y)
    delta = 5e-6 * np.exp(-((xx**2 + yy**2) / 0.2))
    beta = 2e-7 * (1 + 0.5 * np.sin(10 * xx))
    n_slice = delta + 1j * beta
    n_volume = np.stack([n_slice for _ in range(nz)], axis=0)

# The projection and multislice code below index the volume as (z, y, x).
if n_volume.ndim != 3:
    raise ValueError(f"Expected a 3D (z, y, x) volume, got shape {n_volume.shape}")

# Match the object's lateral size to the wavefield *before* computing the
# transmission. Zero padding in delta + i*beta is vacuum (unit transmission).
# Zero-padding the transmission itself would instead place an opaque screen
# around the object and clip the illumination.
n_volume = pad_or_crop_to_shape(n_volume, (n_volume.shape[0],) + initial_wave.shape)

transmission = project_transmission(n_volume, wavelength_m, voxel_size_z_m)
show_amplitude_phase(transmission, "Transmission (projected)")

## 5. Exit wave and detector plane

In [None]:
# Thin-object (projection) approximation: the exit wave is the incident wave
# times the projected transmission; it is then propagated to the detector.
exit_wave = propagated_wave * transmission
wave_at_detector = fresnel_propagate(
    exit_wave, sample_detector_distance_m, wavelength_m, pixel_size_m,
    mode=propagation_mode,
)

show_intensity(wave_at_detector, "Detector intensity")
show_amplitude_phase(wave_at_detector, "Detector wave")

## 6. Multislice propagation through the 3D object

In [None]:
# Multislice: multiply by each slice's transmission, propagating one voxel
# (voxel_size_z_m) between consecutive slices. After the last slice, the exit
# wave is propagated to the detector.
wave_multislice = propagated_wave.copy()
last_index = len(n_volume) - 1
for idx, n_slice in enumerate(n_volume):
    # Pad the refractive-index slice (zeros = vacuum, i.e. unit transmission)
    # rather than its transmission, which would make the padded border opaque.
    padded_slice = pad_or_crop_to_shape(n_slice, wave_multislice.shape)
    wave_multislice *= slice_transmission(padded_slice, wavelength_m, voxel_size_z_m)
    if idx < last_index:
        wave_multislice = fresnel_propagate(wave_multislice, voxel_size_z_m,
                                            wavelength_m, pixel_size_m,
                                            mode=propagation_mode)

wave_multislice_detector = fresnel_propagate(wave_multislice, sample_detector_distance_m,
                                             wavelength_m, pixel_size_m,
                                             mode=propagation_mode)

show_intensity(wave_multislice_detector, "Detector intensity (multislice)")
show_amplitude_phase(wave_multislice_detector, "Detector wave (multislice)")