In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from PIL import Image, ImageDraw, ImageFont

%matplotlib ipympl

# Parameters
Nx, Ny = 100, 100 # SLM pixels
wavelength = 632.8e-9  # meters (not explicitly used for scale in far field)
pixel_pitch = 32e-6  # meters (not explicitly used in simulation, can be used for scaling)
fill_factor = 1/3
aperture_diameter = Ny  # pixels for flat top illumination
num_iterations = 10  # GS Iterations
central_spot_radius = 10  # pixels: radius of ZOD central spot to suppress
oversample_factor = 3

dxp = pixel_pitch # pupil plane px size

# Seed the global NumPy RNG: gerchberg_saxton draws its random initial phase
# from np.random, so without a fixed seed every "Restart & Run All" produces
# different holograms and figures.
SEED = 0
np.random.seed(SEED)

In [None]:
# Weight balancing and final phase mask generation.
# These are meant as the C_target / C_corr weights of compute_weighted_phase
# (target hologram vs. central-spot correction hologram). C_corr is an example
# value; the user may scan it to balance the two holograms.
C_target, C_corr = 1.0, 0.5

def zero_pad(array, pad_factor = 2):
    """Embed a 2-D array at the centre of a zero array ``pad_factor`` times larger per axis.

    The output has shape ``(rows * pad_factor, cols * pad_factor)`` and keeps
    ``array``'s dtype. The leading margin on each axis is
    ``(padded - original) // 2``, so ``depad`` with the same factor recovers
    the input exactly.
    """
    rows, cols = array.shape

    def _margins(n):
        # (before, after) zero margins for an axis of length n
        extra = n * pad_factor - n
        before = extra // 2
        return before, extra - before

    return np.pad(array, (_margins(rows), _margins(cols)), mode='constant')

def depad(padded_array, pad_factor = 2):
    """Crop the central block that ``zero_pad`` inserted into ``padded_array``.

    Each axis of length ``n`` keeps ``n // pad_factor`` samples starting at
    ``(n - n // pad_factor) // 2``. The result is a view (basic slicing), not a copy.
    """
    rows_padded, cols_padded = padded_array.shape

    def _centre_window(n_padded):
        # slice selecting the original-size window centred on this axis
        n_kept = n_padded // pad_factor
        first = (n_padded - n_kept) // 2
        return slice(first, first + n_kept)

    return padded_array[_centre_window(rows_padded), _centre_window(cols_padded)]


def create_circular_aperture(Nx, Ny, diameter):
    """Return an (Nx, Ny) float mask that is 1.0 inside a disc and 0.0 outside.

    The disc is centred on pixel (Nx // 2, Ny // 2) and contains every pixel
    whose integer offset (x, y) from the centre satisfies
    x**2 + y**2 <= (diameter / 2)**2.
    """
    radius_sq = (diameter / 2) ** 2
    # Column vector of x offsets and row vector of y offsets; broadcasting
    # builds the full (Nx, Ny) grid of squared distances.
    dx = (np.arange(Nx) - Nx // 2)[:, np.newaxis]
    dy = (np.arange(Ny) - Ny // 2)[np.newaxis, :]
    inside = dx ** 2 + dy ** 2 <= radius_sq
    return inside.astype(float)

def apply_fill_factor_with_oversampling(mask, oversample_factor=10, fill_factor=0.85):
    """Oversample a pixel mask and zero the inter-pixel dead zone of every pixel.

    Each input pixel becomes an ``oversample_factor x oversample_factor`` block
    of sub-pixels. Only a centred square of
    ``floor(oversample_factor * fill_factor)`` sub-pixels per side keeps the
    original value; the rest of the block is set to 0.

    ``fill_factor`` is applied per side (linearly), so the active *area*
    fraction of a pixel is about ``fill_factor**2``.
    NOTE(review): confirm which convention (linear or area) the fill_factor
    values passed in are meant to follow.

    Parameters
    ----------
    mask : 2-D array
        Per-pixel values to oversample.
    oversample_factor : int
        Sub-pixels per pixel along each axis.
    fill_factor : float
        Linear active fraction of each pixel, in [0, 1].

    Returns
    -------
    ndarray of shape (rows * oversample_factor, cols * oversample_factor)

    Raises
    ------
    ValueError
        If ``fill_factor`` is outside [0, 1]. Values above 1 made the start index
        negative and produced a meaningless mask.
    """
    if not 0.0 <= fill_factor <= 1.0:
        raise ValueError(f"fill_factor must be in [0, 1], got {fill_factor}")

    Nx, Ny = mask.shape
    # Oversample by repeating pixels
    mask_os = np.repeat(np.repeat(mask, oversample_factor, axis=0), oversample_factor, axis=1)
    # Active (modulating) sub-pixels per side. The tiny epsilon keeps binary
    # floating-point products such as 100 * 0.29 == 28.999999999999996 from
    # being truncated to 28 instead of the intended 29.
    active_size = int(np.floor(oversample_factor * fill_factor + 1e-9))
    start_idx = (oversample_factor - active_size) // 2
    pixel_mask = np.zeros((oversample_factor, oversample_factor), dtype=bool)
    pixel_mask[start_idx:start_idx + active_size, start_idx:start_idx + active_size] = True
    # Same sub-pixel pattern for every SLM pixel
    tiled_mask = np.tile(pixel_mask, (Nx, Ny))
    # Apply fill factor by zeroing dead regions
    return mask_os * tiled_mask

def gerchberg_saxton(target_amplitude, aperture_amplitude, num_iter=50, rng=None):
    """Compute a phase-only hologram with the Gerchberg-Saxton algorithm.

    The SLM field and both amplitude constraints are zero-padded with
    ``zero_pad`` (default factor 2) before iterating. The final SLM-plane
    phase is cropped back to the input size with ``depad``.

    Each iteration:
      1. propagates the SLM-plane field to the far field with ``fft2``;
      2. keeps the far-field phase and imposes ``target_amplitude``;
      3. propagates back with ``ifft2``;
      4. keeps the SLM-plane phase and imposes ``aperture_amplitude``, which is
         0 in dead zones and in the padding.

    NOTE(review): the far field is not ``fftshift``-ed, so the DC term (where
    the zero order lands) is at index [0, 0] of the padded grid. A target drawn
    at the array centre therefore sits around the Nyquist frequency, not on the
    zero order. Confirm this is the intended convention for ZOD suppression.

    Parameters
    ----------
    target_amplitude : 2-D array
        Desired far-field amplitude. The calls in this notebook pass the square
        root of a target intensity. Must have the same shape as
        ``aperture_amplitude``.
    aperture_amplitude : 2-D array
        SLM-plane amplitude / aperture mask.
    num_iter : int
        Number of GS iterations.
    rng : numpy.random.Generator, optional
        Source of the random initial phase. ``None`` (default) keeps the
        original behaviour of drawing from the global ``np.random`` state.

    Returns
    -------
    ndarray
        SLM phase in radians, in [-pi, pi], with the shape of ``aperture_amplitude``.

    Raises
    ------
    ValueError
        If the two amplitude arrays differ in shape.
    """
    if np.shape(target_amplitude) != np.shape(aperture_amplitude):
        raise ValueError(
            f"target_amplitude shape {np.shape(target_amplitude)} does not match "
            f"aperture_amplitude shape {np.shape(aperture_amplitude)}"
        )

    shape = aperture_amplitude.shape
    uniform = np.random.rand(*shape) if rng is None else rng.random(shape)
    phase = np.exp(1j * 2 * np.pi * uniform)

    field = zero_pad(aperture_amplitude * phase)
    target_padded = zero_pad(target_amplitude)
    aperture_padded = zero_pad(aperture_amplitude)

    for _ in range(num_iter):
        far_field = fft2(field)
        # Far-field constraint: keep phase, impose the target amplitude
        far_field = target_padded * np.exp(1j * np.angle(far_field))
        field = ifft2(far_field)
        # SLM-plane constraint: phase-only modulation inside the aperture;
        # aperture_padded is 0 in dead zones and in the zero padding
        field = aperture_padded * np.exp(1j * np.angle(field))

    return np.angle(depad(field))

def generate_far_field_intensity(phase, aperture_amplitude):
    """Simulate the normalized far-field intensity of a phase-only SLM.

    The SLM field ``aperture_amplitude * exp(1j * phase)`` is zero-padded
    (``zero_pad``), Fourier transformed with ``fft2`` and cropped back to the
    input size (``depad``).

    NOTE(review): no ``fftshift`` is applied, so the returned window is the
    centre of the *unshifted* spectrum. The DC / zero-order term at index
    [0, 0] of the padded grid is not at the centre of this window (for inputs
    larger than one pixel it lies outside it). This matches gerchberg_saxton,
    which also does not shift. An earlier commented-out variant applied
    ``fftshift`` to the intensity before cropping.

    TODO: consider scaling by 1 / (wavelength * F) after fft2 to include F/D.

    Returns
    -------
    ndarray
        Intensity scaled so its maximum is 1. A completely dark far field
        (e.g. an all-zero aperture) is returned as zeros instead of NaNs.
    """
    field = aperture_amplitude * np.exp(1j * phase)
    far_field = depad(fft2(zero_pad(field)))
    intensity = np.abs(far_field) ** 2
    peak = intensity.max()
    # Guard against 0/0 -> NaN (with a RuntimeWarning) for an all-dark field.
    return intensity / peak if peak > 0 else intensity

def compute_weighted_phase(phi_corr, phi_target, aperture_ampl, C_corr=0.35,C_target=1):
    """Phase of the weighted coherent sum of two phase-only SLM fields.

    With A = ``aperture_ampl``, forms
    ``C_corr * A*exp(i*phi_corr) + C_target * A*exp(i*phi_target)``
    and returns its phase in radians (range [-pi, pi]).
    """
    def slm_field(phi):
        # complex SLM-plane field for a given phase map
        return aperture_ampl * np.exp(1j * phi)

    combined = C_corr * slm_field(phi_corr) + C_target * slm_field(phi_target)
    return np.angle(combined)

def downsample_phase(phase_os, slm_mask, oversample_factor):
    """Reduce an oversampled phase map to one phase value per SLM pixel.

    Each ``oversample_factor x oversample_factor`` block of ``phase_os`` is
    collapsed to the phase of the mask-weighted sum of its unit phasors
    ``exp(1j * phase)`` (a weighted circular mean). Blocks with no active
    sub-pixel have a zero phasor sum, whose ``np.angle`` is returned as-is.

    Parameters
    ----------
    phase_os : 2-D array
        Oversampled phase in radians. Each dimension must be a multiple of
        ``oversample_factor`` for the block reshape to succeed.
    slm_mask : 2-D array, same shape as ``phase_os``
        Active-area weights (0 in dead zones).
    oversample_factor : int
        Sub-pixels per SLM pixel along each axis.

    Returns
    -------
    ndarray of shape (rows // oversample_factor, cols // oversample_factor)
        Phase in radians, in [-pi, pi].
    """
    rows_os, cols_os = phase_os.shape
    k = oversample_factor
    # Axes: (pixel row, sub-row, pixel column, sub-column)
    block_shape = (rows_os // k, k, cols_os // k, k)
    sub_axes = (1, 3)

    weights = slm_mask.reshape(block_shape)
    phasor_sum = (np.exp(1j * phase_os.reshape(block_shape)) * weights).sum(axis=sub_axes)

    # Active sub-pixels per block; empty blocks divide by 1 instead of 0.
    n_active = weights.sum(axis=sub_axes)
    n_active = np.where(n_active == 0, 1, n_active)

    return np.angle(phasor_sum / n_active)


image_path = 'heig.png'  # Replace with your image file
target_size = 50  # side length (pixels) of the image placed in the Nx x Ny target plane
binarize_threshold = 0.5  # normalized gray level at/above which a pixel is "on"

if target_size > min(Nx, Ny):
    raise ValueError(f"target_size={target_size} does not fit in a {Nx}x{Ny} target plane")

# Load as 8-bit grayscale; the context manager closes the file handle.
with Image.open(image_path) as img_file:
    img = img_file.convert('L')  # Grayscale

# Resize to target_size x target_size (PIL sizes are (width, height))
img = img.resize((target_size, target_size))

# Normalize to [0, 1] and place the image at the centre of an Nx x Ny plane.
# Slicing start .. start + target_size also works for odd sizes.
half = target_size // 2
row0, col0 = Nx // 2 - half, Ny // 2 - half
target_amplitude = np.zeros((Nx, Ny))
target_amplitude[row0:row0 + target_size, col0:col0 + target_size] = np.array(img) / 255.0
# Binarize: 1 where the normalized level >= threshold, 0 elsewhere
target_amplitude = (target_amplitude >= binarize_threshold).astype(float)

fig, ax = plt.subplots()
ax.imshow(target_amplitude, cmap='gray')
ax.set_title('Target Image for GS algorithm')
ax.axis('off')
plt.show()

In [None]:
# Create aperture mask: uniform amplitude over the whole Nx x Ny SLM.
# Alternative for a disc-shaped aperture:
#     aperture_mask = create_circular_aperture(Nx, Ny, aperture_diameter)
aperture_mask = np.full((Nx, Ny), 1.0)
(Nx_os, Ny_os) = aperture_mask.shape  # grid size used by the following cells



In [None]:
# SLM-plane amplitude shared by every hologram below.
amplitude_slm = aperture_mask

# Target amplitude patterns, on the same (Nx_os, Ny_os) grid as amplitude_slm.
# os_scale is the real oversampling ratio of that grid relative to the SLM
# pixels: 1 while aperture_mask is Nx x Ny, and oversample_factor if an
# oversampled aperture is used. Scaling by oversample_factor unconditionally
# made the suppressed spot oversample_factor times larger than
# central_spot_radius on a grid that is not oversampled.
os_scale = max(Nx_os // Nx, 1)

y_c, x_c = Ny_os // 2, Nx_os // 2
X2, Y2 = np.meshgrid(np.arange(Nx_os), np.arange(Ny_os), indexing='ij')
# NOTE(review): gerchberg_saxton does not fftshift the far field, so the array
# centre used here corresponds to the Nyquist frequency, not the zero order
# (DC is at index [0, 0] of the unshifted spectrum). Confirm this is where the
# ZOD is meant to be suppressed.
mask_central = (X2 - x_c)**2 + (Y2 - y_c)**2 <= (central_spot_radius * os_scale)**2

# Central spot correction target: uniform disc of radius central_spot_radius (SLM pixels)
target_amplitude_correction = np.zeros((Nx_os, Ny_os))
target_amplitude_correction[mask_central] = 1

# Hologram target: the binarized image. Copied so later edits to
# target_amplitude1 cannot silently modify target_amplitude.
target_amplitude1 = target_amplitude.copy()

print("Computing correction hologram phase...")
phi_corr = gerchberg_saxton(target_amplitude_correction**(1/2), amplitude_slm, num_iter=num_iterations)
print("Computing target hologram phase (central spot off)...")
phi_target = gerchberg_saxton(target_amplitude1**(1/2), amplitude_slm, num_iter=num_iterations)
print("Computing target hologram phase (all spots off)...")
# Arguments follow the signature (phi_corr, phi_target, aperture_ampl, C_corr, C_target).
# The weights are passed by keyword so they cannot be mistaken for the phase maps.
phi_total = compute_weighted_phase(phi_corr, phi_target, amplitude_slm,
                                   C_corr=C_corr, C_target=C_target)



In [None]:
# Far-field intensity of each phase mask, each normalized to its own maximum.
# The baseline uses an explicit zero (flat) phase. Passing amplitude_slm as the
# phase only gave a constant phase because the aperture is binary, and would
# break for a non-binary amplitude.
flat_phase = np.zeros_like(amplitude_slm)
intensity_raw = generate_far_field_intensity(flat_phase, amplitude_slm)  # Flat phase baseline
intensity_central_only = generate_far_field_intensity(phi_corr, amplitude_slm)
intensity_target = generate_far_field_intensity(phi_target, amplitude_slm)
intensity_total = generate_far_field_intensity(phi_total, amplitude_slm)


In [None]:
# Top row: simulated far-field intensities.
# Bottom row: the target image and the phase masks (radians) that produced them.
# The former 4x4 grid left two rows of blank axes: its first two rows were
# commented out because they referenced *_down (downsampled) variables that
# are not computed in the visible cells. matplotlib.pyplot is already
# imported as plt in the first cell.
panels = [
    ("Far field - ZOD only", intensity_raw),
    ("Far field - Central Spot Removed", intensity_central_only),
    ("Far field - Hologram only", intensity_target),
    ("Far field - Hologram and ZOD removal", intensity_total),
    ("Actual target image", target_amplitude),
    ("Upscaled phase mask - ZOD only", phi_corr),
    ("Upscaled phase mask - Hologram only", phi_target),
    ("Upscaled phase mask - Hologram and ZOD removal", phi_total),
]

fig, axs = plt.subplots(2, 4, figsize=(15, 5))
for ax, (title, image) in zip(axs.flat, panels):
    ax.set_title(title)
    ax.imshow(image, cmap='grey')
    ax.axis('off')

plt.tight_layout()
plt.show()
