# FDTD vs Spectral Methods (3D Baseline, Atomic Model)

This notebook compares a **3D FDTD** simulation against spectral propagation methods (Fresnel, Angular Spectrum, and the Wave Propagation Method, WPM)
for a **single-atom 3D potential**. The FDTD settings mirror the working configuration from the 3D plane-wave notebook
for stability and reproducibility.

In [None]:
%matplotlib widget

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import abtem
import ase
from ase.units import _hplanck, _c, _me, _e

from wide_angle_propagation import (
    simulate_fresnel_as,
    simulate_wpm,
    fresnel_propagation_kernel,
    angular_spectrum_propagation_kernel
)
from wide_angle_propagation import RelativisticSolver3D, calculate_physics_params

# Configure JAX (64-bit for stability)
jax.config.update("jax_enable_x64", True)

## 1. Physics and Grid Setup

We use the same wavelength and grid settings as the working 3D plane-wave notebook to keep FDTD stable.

In [None]:
def wavelength_to_energy_ev(lam_angstrom):
    """Return the relativistic kinetic energy (eV) of an electron with the given wavelength.

    Args:
        lam_angstrom: de Broglie wavelength in Angstrom.

    Uses p = h / lambda and E_kin = sqrt((p c)^2 + (m c^2)^2) - m c^2, with the
    SI constants imported from ``ase.units``; the result is converted J -> eV.
    """
    rest_energy_J = _me * _c**2
    momentum = _hplanck / (lam_angstrom * 1e-10)
    total_energy_J = np.sqrt((momentum * _c) ** 2 + rest_energy_J**2)
    return (total_energy_J - rest_energy_J) / _e

# Plane-wave notebook settings
wavelength_angstrom = 0.05
e_total, p_inf, e_sim = calculate_physics_params(wavelength_angstrom)
k0 = p_inf  # wavenumber; later cells should not reuse this name for indices
energy_ev = wavelength_to_energy_ev(wavelength_angstrom)

# Grid (ROI). linspace includes both end points, so spacing is L / (n - 1).
nx, ny = 64, 64
roi_nz = 128

roi_xmin, roi_xmax = -1.0, 1.0
roi_ymin, roi_ymax = -1.0, 1.0
roi_zmin, roi_zmax = -1.0, 1.0

x = np.linspace(roi_xmin, roi_xmax, nx)
y = np.linspace(roi_ymin, roi_ymax, ny)
z = np.linspace(roi_zmin, roi_zmax, roi_nz)
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
extent_roi = [roi_xmin, roi_xmax, roi_ymin, roi_ymax, roi_zmin, roi_zmax]

# Extended grid for FDTD: PML + source buffer on both z sides of the ROI.
# The padding is rounded to a whole number of dz cells *per side*, so every ROI
# slice coincides with an extended-grid node (z_ext[n_pad + k] == z[k] up to
# rounding) and both pads have the same width.
pml_thickness = 0.4
source_buffer = 0.2
z_pad = pml_thickness + source_buffer
n_pad = int(np.round(z_pad / dz))
ext_nz = roi_nz + 2 * n_pad
ext_zmin = roi_zmin - n_pad * dz
z_ext = ext_zmin + dz * np.arange(ext_nz)
ext_zmax = z_ext[-1]
extent_ext = [roi_xmin, roi_xmax, roi_ymin, roi_ymax, ext_zmin, ext_zmax]

print("Grid Summary:")
print(f"  ROI:  {nx}x{ny}x{roi_nz}")
print(f"  FDTD: {nx}x{ny}x{ext_nz} ({n_pad} pad cells per side)")
print(f"  Steps: dx={dx:.4f}, dz={dz:.4f}")
print(f"  Energy: {energy_ev:.2f} eV")

## 2. Build Atomic Potential (Single Atom)

We use abTEM to generate a 3D potential for a single Au atom, with slice thickness matching dz.

In [None]:
# Match single_atom defaults for determinism
abtem.config.set({"device": "cpu"})
abtem.config.set({"precision": "float64"})

# Size the periodic cell as n_points * spacing on every axis (as already done for
# z) so abTEM samples the potential at exactly (dx, dy, dz): abTEM's sampling is
# cell / gpts, while dx = L / (nx - 1) comes from linspace. Using
# roi_xmax - roi_xmin would give L / nx, mismatching the spectral kernels'
# `sampling` and the FDTD extent.
cell_x = nx * dx
cell_y = ny * dy
cell_z = roi_nz * dz
atoms = ase.Atoms('Au', cell=[cell_x, cell_y, cell_z], pbc=True)
atoms.center()

# Build potential; the first axis is the slice (z) axis.
# NOTE(review): abTEM orders the in-plane axes following gpts=(nx, ny), i.e.
# (x, y). With nx == ny and a centred atom, later cells' (y, x) reading is
# harmless, but confirm before making the grid non-square.
pot = abtem.Potential(atoms, gpts=(nx, ny), slice_thickness=dz, projection='finite')
pot_arr = pot.build(lazy=False).array
assert pot_arr.shape[1:] == (nx, ny), f"unexpected in-plane potential shape {pot_arr.shape}"

print(f"Potential shape (slices, gpts[0], gpts[1]): {pot_arr.shape}")
fig, ax = plt.subplots()
im = ax.imshow(pot_arr[:, ny // 2, :], extent=[roi_xmin, roi_xmax, roi_zmin, roi_zmax], origin='lower', cmap='RdBu')
ax.set_title("Potential Slice (XZ)")
ax.set_xlabel("X (Å)")
ax.set_ylabel("Z (Å)")
fig.colorbar(im, ax=ax)
plt.show()

## 3. Embed Potential into Extended Grid (FDTD)

In [None]:
# abTEM's array already has the slice (z) axis first, matching the z-major layout
# of the FDTD grid, so no transpose is needed.
v_phys_roi_t = pot_arr
n_pot_slices = v_phys_roi_t.shape[0]

v_eff_fdtd_t = np.zeros((ext_nz, ny, nx), dtype=np.float64)

# Index in the extended z-grid closest to the ROI start. Named k_roi_start (not
# k0) so it does not overwrite the wavenumber k0 from the physics cell.
k_roi_start = int(np.argmin(np.abs(z_ext - roi_zmin)))
if k_roi_start + n_pot_slices > ext_nz:
    k_roi_start = ext_nz - n_pot_slices

# Use the actual slice count from abTEM rather than assuming it equals roi_nz.
# The FDTD grid receives the negated abTEM potential (solver sign convention).
v_eff_fdtd_t[k_roi_start : k_roi_start + n_pot_slices, :, :] = -v_phys_roi_t

print(f"Inserted ROI potential into ext grid at z indices "
      f"[{k_roi_start}, {k_roi_start + n_pot_slices}) of {ext_nz}")
fig, ax = plt.subplots()
# origin='upper' draws row 0 (z = ext_zmin) at the top, so the extent lists z in
# descending order (bottom=ext_zmax, top=ext_zmin); an ascending extent would
# label the z axis upside down.
im = ax.imshow(v_eff_fdtd_t[:, ny // 2, :], extent=[roi_xmin, roi_xmax, ext_zmax, ext_zmin], origin='upper', cmap='RdBu')
ax.set_title("Zero-padded Potential Slice (XZ)")
ax.set_xlabel("X (Å)")
ax.set_ylabel("Z (Å)")
fig.colorbar(im, ax=ax)
plt.show()

## 4. Run 3D FDTD (Plane-wave settings)

In [None]:
# FDTD run length and output cadence: all time steps go into a single saved frame.
total_steps, n_frames = 30000, 1
steps_per_frame = total_steps // n_frames
snapshot_count = 1

def run_fdtd_3d():
    """Run the 3D relativistic FDTD solver on the zero-padded potential.

    Reads notebook-level settings: the padded potential ``v_eff_fdtd_t``, grid
    values (``dz``, ``ext_nz``, ``ext_zmin``, ``extent_ext``), PML/source-buffer
    widths, the wavelength and the step/frame counts.

    Returns:
        The solver's result (the calling cell reads 'snapshots', 'E_sim' and
        'dt' from it).

    Raises:
        RuntimeError: if the solver returns a falsy result.
    """
    solver = RelativisticSolver3D()

    # Potential handed to the solver: the padded abTEM array divided by the
    # slice thickness (abTEM gives a projected potential per slice).
    potential_ev = v_eff_fdtd_t / dz

    # Source plane: half-way into the buffer after the low-z PML, but at least
    # two cells past the PML and never on the outermost grid planes.
    source_depth = pml_thickness + source_buffer * 0.5
    pml_cell_count = int(np.ceil(pml_thickness / dz))
    k_source = max(int(source_depth / dz), pml_cell_count + 2)
    k_source = max(1, min(k_source, ext_nz - 2))

    print(f"Total steps: {total_steps} (saving {n_frames} frames)")
    print(f"k_source={k_source} (z ~ {ext_zmin + k_source * dz:.3f} Å)")

    result = solver.run(
        potential_ev,
        wavelength_angstrom=wavelength_angstrom,
        extent=extent_ext,
        n_frames=n_frames,
        steps_per_frame=steps_per_frame,
        snapshot_count=snapshot_count,
        pml_thick=pml_thickness,
        k_source=k_source,
        use_angstrom_units=True,
    )
    if result:
        return result
    raise RuntimeError("FDTD Failed")

fdtd_res = run_fdtd_3d()
snapshots = fdtd_res['snapshots']
E_sim = fdtd_res['E_sim']
dt = fdtd_res['dt']

# Phase advance per saved frame. If it is close to a multiple of 2π, successive
# frames are nearly in phase and the oscillation aliases away.
phase_per_frame = E_sim * steps_per_frame * dt
# Parenthesise 2π: `phase_per_frame / 2 * np.pi` would evaluate as (x / 2) * π.
cycles_per_frame = phase_per_frame / (2 * np.pi)

print(f"Phase advance per frame: {phase_per_frame:.4f} rad")
print(f"Phase advance / 2π = {cycles_per_frame:.4f}")
print(f"Phase advance mod 2π = {np.mod(phase_per_frame, 2 * np.pi):.4f} rad")

## 5. Extract and Normalize FDTD Exit Wave

In [None]:
# Exit plane (ROI end), entry plane (ROI start), and a monitor plane two cells
# in front of the ROI used to measure the incident intensity.
k_exit = int(np.argmin(np.abs(z_ext - roi_zmax)))
k_entry = int(np.argmin(np.abs(z_ext - roi_zmin)))
k_monitor = max(0, k_entry - 2)

psi_exit_accum = np.zeros((ny, nx), dtype=np.complex128)
psi_monitor_accum = np.zeros((ny, nx), dtype=np.complex128)
count = 0

# Demodulate each snapshot by exp(+i E_sim t) and average over the last 40%.
start_avg = int(0.6 * len(snapshots))
for t, phi, psi in snapshots[start_avg:]:
    phasor = np.exp(1j * E_sim * t)
    field3d = (phi + 1j * psi) * phasor
    psi_exit_accum += field3d[k_exit, :, :]
    psi_monitor_accum += field3d[k_monitor, :, :]
    count += 1

if count == 0:
    print("WARNING: no snapshots in the averaging window; FDTD fields are all zero")
psi_fdtd_exit = psi_exit_accum / max(1, count)
psi_monitor = psi_monitor_accum / max(1, count)

# Normalize by incident intensity measured before ROI so the FDTD exit wave is
# on the same unit-amplitude scale as the spectral plane-wave input.
I_incident = np.mean(np.abs(psi_monitor) ** 2)
if I_incident > 1e-9:
    psi_fdtd_exit = psi_fdtd_exit / np.sqrt(I_incident)
    print(f"Normalized FDTD by incident intensity: {I_incident:.6e}")
else:
    # Do not skip silently: the comparison plots below would be on a different scale.
    print(f"WARNING: incident intensity {I_incident:.3e} is too small; FDTD exit wave left un-normalized")

fig, ax = plt.subplots()
im = ax.imshow(np.abs(psi_fdtd_exit) ** 2, extent=[roi_xmin, roi_xmax, roi_ymin, roi_ymax], origin='lower')
ax.set_title("FDTD Exit Wave Intensity (Angstrom Units)")
ax.set_xlabel("X (Å)")
ax.set_ylabel("Y (Å)")
fig.colorbar(im, ax=ax)
plt.show()

## 6. Spectral Methods (Fresnel / Angular Spectrum / WPM)

In [None]:
# Incident wave: unit-amplitude plane wave (all ones).
psi_plane = jnp.ones((ny, nx), dtype=jnp.complex128)

# Potential for the BPM-style solvers: the abTEM slice array divided by dz. This
# is the same scaling the FDTD cell applies to what it calls a projected potential.
# NOTE(review): the FDTD grid holds the *negated* abTEM potential while this uses
# it as-is; confirm this matches the sign convention each solver expects.
v_bpm_eV = np.array(v_phys_roi_t, dtype=np.float64) / dz
v_bpm_eV_jax = jnp.array(v_bpm_eV)

# Sampling for kernels in (y, x) order, matching the (ny, nx) wave array.
sampling = (dy, dx)

print("Starting spectral simulations (using energy in eV)...")

# Pre-propagate through the vacuum padding (z_pad) the FDTD grid adds before the
# ROI. psi_plane and psi_init are kept as distinct names so each name matches its
# content. For a uniform plane wave this presumably adds only a global phase;
# TODO confirm the kernel is unit-modulus at zero frequency.
vacuum_dist = z_pad
prop_kernel_vac = angular_spectrum_propagation_kernel(ny, nx, sampling, vacuum_dist, energy_ev)
psi_init = jnp.fft.ifft2(jnp.fft.fft2(psi_plane) * prop_kernel_vac)

# 1. Fresnel
prop_kernel_fresnel = fresnel_propagation_kernel(ny, nx, sampling, dz, energy_ev)
psi_fresnel, _, psi_fresnel_xyz = simulate_fresnel_as(v_bpm_eV_jax, psi_init, prop_kernel_fresnel, dz, energy_ev)

# 2. Angular Spectrum
prop_kernel_as = angular_spectrum_propagation_kernel(ny, nx, sampling, dz, energy_ev)
psi_as, _, psi_as_xyz = simulate_fresnel_as(v_bpm_eV_jax, psi_init, prop_kernel_as, dz, energy_ev)

# 3. WPM
psi_wpm, _, psi_wpm_xyz = simulate_wpm(v_bpm_eV_jax, psi_init, dz, energy_ev, sampling, n_bins=32)

print("Spectral simulations complete.")

## 7. Comparison (Exit Plane and X–Z Slices)

In [None]:
# Exit-plane intensity |psi|^2 of every method, paired with its panel title.
methods = [
    (np.abs(psi_fdtd_exit) ** 2, "FDTD (3D)"),
    (np.abs(psi_fresnel) ** 2, "Fresnel"),
    (np.abs(psi_as) ** 2, "Ang. Spec."),
    (np.abs(psi_wpm) ** 2, "WPM"),
]
extent_xy = [roi_xmin, roi_xmax, roi_ymin, roi_ymax]

fig, axes = plt.subplots(1, len(methods), figsize=(16, 4))
for panel_idx, (data, title) in enumerate(methods):
    ax = axes[panel_idx]
    im = ax.imshow(data, extent=extent_xy, origin='lower', cmap='inferno')
    ax.set_title(title)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# X-Z view (central Y slice)
central_y = ny // 2

# Time-average the demodulated FDTD field on the central Y plane over the same
# window (last 40% of snapshots) used for the exit wave.
start_avg = int(0.6 * len(snapshots))
psi_fdtd_xz_accum = np.zeros((len(z_ext), nx), dtype=np.complex128)
count = 0
for t, phi, psi in snapshots[start_avg:]:
    phasor = np.exp(1j * E_sim * t)
    field3d = (phi + 1j * psi) * phasor
    psi_fdtd_xz_accum += field3d[:, central_y, :]
    count += 1
psi_fdtd_xz = psi_fdtd_xz_accum / max(1, count)

# Spectral intensities cover the ROI only (row k assumed at z[k], as in the 1D
# cross-section), not the padded FDTD grid, so they get their own z extent.
psi_fresnel_xz = (np.abs(np.array(psi_fresnel_xyz)[:, central_y, :]) ** 2).astype(np.float64)
psi_as_xz = (np.abs(np.array(psi_as_xyz)[:, central_y, :]) ** 2).astype(np.float64)
psi_wpm_xz = (np.abs(np.array(psi_wpm_xyz)[:, central_y, :]) ** 2).astype(np.float64)

# (intensity, title, z of row 0) for each panel.
xz_panels = [
    (np.abs(psi_fdtd_xz) ** 2, "FDTD (3D) X-Z", z_ext[0]),
    (psi_fresnel_xz, "Fresnel X-Z", z[0]),
    (psi_as_xz, "Ang. Spec. X-Z", z[0]),
    (psi_wpm_xz, "WPM X-Z", z[0]),
]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (data, title, z_first) in zip(axes, xz_panels):
    z_last = z_first + (data.shape[0] - 1) * dz
    # origin='upper' draws row 0 at the top, so the extent's "top" entry must be
    # the z of row 0; an ascending (z_min, z_max) extent labels z upside down.
    im = ax.imshow(data, extent=[x[0], x[-1], z_last, z_first], origin='upper', aspect='auto', cmap='inferno')
    # Common z window (full FDTD grid, z increasing downward) so panels line up.
    ax.set_ylim(z_ext[-1], z_ext[0])
    ax.set_title(title)
    ax.set_xlabel("X (Å)")
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
axes[0].set_ylabel("Z (Å)")

plt.tight_layout()
plt.show()

In [None]:
# 1D cross-section of all waves at the ROI exit plane, z = roi_zmax (+1 Å),
# each curve normalized by its own maximum.
target_z = roi_zmax
iz_roi = int(np.argmin(np.abs(z - target_z)))       # for spectral methods (z)
iz_fdtd = int(np.argmin(np.abs(z_ext - target_z)))   # for FDTD (z_ext)
# Guard against a spectral solver returning fewer rows than the ROI z grid.
n_spec_rows = min(psi_fresnel_xz.shape[0], psi_as_xz.shape[0], psi_wpm_xz.shape[0])
iz_roi = min(iz_roi, n_spec_rows - 1)

cross_fdtd = np.abs(psi_fdtd_xz[iz_fdtd, :]) ** 2
cross_fresnel = psi_fresnel_xz[iz_roi, :].astype(np.float64)
cross_as = psi_as_xz[iz_roi, :].astype(np.float64)
cross_wpm = psi_wpm_xz[iz_roi, :].astype(np.float64)


def norm(a):
    """Scale `a` so its maximum is 1; return it unchanged if the maximum is not positive."""
    m = np.max(a)
    return a / m if m > 0 else a


cross_fdtd = norm(cross_fdtd)
cross_fresnel = norm(cross_fresnel)
cross_as = norm(cross_as)
cross_wpm = norm(cross_wpm)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, cross_fdtd,  label=f'FDTD (z={z_ext[iz_fdtd]:.3f} Å)', lw=1.6)
ax.plot(x, cross_fresnel, label=f'Fresnel (z={z[iz_roi]:.3f} Å)', lw=1.2)
ax.plot(x, cross_as,     label=f'Ang. Spec. (z={z[iz_roi]:.3f} Å)', lw=1.2)
ax.plot(x, cross_wpm,    label=f'WPM (z={z[iz_roi]:.3f} Å)', lw=1.2)
ax.set_xlabel('x (Å)')
ax.set_ylabel('Normalized intensity (a.u.)')
ax.set_title(f'1D cross-section at z = {target_z:.3f} Å (normalized)')
ax.grid(True)
ax.legend(loc='best')
fig.tight_layout()
plt.show()