# EX_03

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import numpy as np
import math
import torch.optim as optim
import torch.nn as nn

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Physical and Simulation Parameters ---
Nx = Ny = 256  # Grid points
lambda0_scalar = 532e-9  # m, wavelength
Lx_phys = Ly_phys = 0.6e-3  # m, physical size of computational window

# Propagation distances
Lz_segment = 5 * Lx_phys  # Distance P1-P2, and P2-P3
L_total = 2 * Lz_segment


# --- Helper Functions ---
def create_grids(N, phys_L, device_val, dtype_val=torch.float32):
    """Creates 1D and 2D coordinate and frequency grids."""
    dx = phys_L / N
    # Spatial coordinates
    x_1d = torch.linspace(-phys_L / 2, phys_L / 2 - dx, N, device=device_val, dtype=dtype_val)
    X, Y = torch.meshgrid(x_1d, x_1d, indexing='ij')

    # Spatial frequencies (for torch.fft.fft2 output order: DC at [0,0])
    kx_1d = 2 * math.pi * torch.fft.fftfreq(N, d=dx, device=device_val, dtype=dtype_val)
    Ky, Kx = torch.meshgrid(kx_1d, kx_1d, indexing='ij') # Note: meshgrid's first arg varies slowest (rows)
                                                      # For (ky, kx) to match image (row, col) -> (y, x)
                                                      # Ky corresponds to y (rows), Kx to x (cols)
    return X, Y, Kx, Ky, dx


# --- Angular Spectrum Propagator ---
def propagate_angular_spectrum(field_xy_in, Lz_prop, lambda0_s, Lx_s, Ly_s, backward=False):
    """
    Propagates a 2D complex field using the angular spectrum method.
    Assumes field_xy_in is complex.
    For backward propagation, set backward=True.
    """
    N_y, N_x = field_xy_in.shape
    current_device = field_xy_in.device
    dtype_real = torch.float32 # Base for physical params
    dtype_complex = field_xy_in.dtype


    lambda0 = torch.tensor(lambda0_s, device=current_device, dtype=dtype_real)
    k0 = 2 * math.pi / lambda0

    _, _, Kx_grid, Ky_grid, _ = create_grids(N_x, Lx_s, current_device, dtype_val=dtype_real)
    # Ky_grid, Kx_grid from create_grids are for (Ny, Nx) if Ly, Lx used for Ny, Nx respectively
    # If Nx=Ny and Lx=Ly, then Kx_grid and Ky_grid from create_grids(Nx, Lx_s, ...) is fine.

    Kx2_plus_Ky2 = Kx_grid**2 + Ky_grid**2

    # kz component
    # sqrt_arg can be negative for evanescent waves
    sqrt_arg = k0**2 - Kx2_plus_Ky2
    kz = torch.sqrt(torch.complex(torch.relu(sqrt_arg), torch.abs(torch.minimum(sqrt_arg, torch.tensor(0.0, device=current_device)))))
    # For evanescent waves (sqrt_arg < 0): kz becomes purely imaginary: j * sqrt(abs(sqrt_arg))
    # This leads to exp(- Lz * sqrt(Kx2+Ky2-k0^2)), which is decay.

    # Propagator
    if backward:
        propagator_k_space = torch.exp(-1j * kz.to(dtype_complex) * Lz_prop) # H* or H(-Lz)
    else:
        propagator_k_space = torch.exp(1j * kz.to(dtype_complex) * Lz_prop)  # H

    # Perform propagation
    field_k_space = torch.fft.fft2(field_xy_in)
    propagated_k_space = field_k_space * propagator_k_space
    propagated_field_xy = torch.fft.ifft2(propagated_k_space)

    return propagated_field_xy

Using device: cpu


In [3]:
torch.load("t_mid_A.pt")
torch.load("t_mid_B.pt")
torch.load("U_out_A.pt")
torch.load("U_out_B.pt")

tensor([[ 0.0973-0.0071j,  0.0669+0.0203j, -0.0002+0.0417j,  ...,
          0.0089-0.0966j,  0.0399-0.0647j,  0.0853-0.0148j],
        [ 0.0375+0.0019j,  0.0487+0.0084j,  0.0418-0.0005j,  ...,
         -0.0169-0.0251j,  0.0178-0.0321j,  0.0420-0.0195j],
        [-0.0757+0.0100j, -0.0430-0.0421j,  0.0266-0.0566j,  ...,
         -0.0104+0.0761j, -0.0326+0.0571j, -0.0645+0.0214j],
        ...,
        [-0.0854+0.0472j, -0.0592+0.0141j, -0.0215-0.0449j,  ...,
          0.0144+0.0582j,  0.0064+0.0397j, -0.0575+0.0273j],
        [-0.0555+0.0024j, -0.0331-0.0092j, -0.0079-0.0170j,  ...,
         -0.0112+0.0693j, -0.0296+0.0453j, -0.0693+0.0149j],
        [ 0.0365-0.0223j,  0.0172-0.0075j, -0.0131+0.0394j,  ...,
          0.0058-0.0272j, -0.0037-0.0068j,  0.0089+0.0046j]])