In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift
from PIL import Image, ImageDraw, ImageFont

%matplotlib ipympl

# Parameters (edit here to explore other configurations)
Nx, Ny = 80, 60  # SLM pixels along x and y
wavelength = 632.8e-9  # meters (not used: the far field is computed in FFT-bin units)
pixel_pitch = 23e-6  # meters (not used: could convert FFT bins to diffraction angles)
fill_factor = 0.85  # active fraction of each pixel pitch, applied along each axis
aperture_diameter = Ny  # pixels, for the optional circular flat-top illumination
num_iterations = 50  # Gerchberg-Saxton iterations
central_spot_radius = 5  # SLM-pixel units: radius of the ZOD central spot to suppress
oversample_factor = 10  # simulation samples per SLM pixel along each axis

# Gerchberg-Saxton starts from random phases drawn from the global np.random
# state; seeding it here makes a Restart & Run All reproducible.
SEED = 0
np.random.seed(SEED)

In [None]:
# Relative weights of the two holograms when their SLM fields are superposed
# into the final phase mask.
C_corr = 0.35  # correction-hologram weight: an example value the user may scan
C_target = 1.0  # target-hologram weight

def create_circular_aperture(Nx, Ny, diameter):
    """Return an (Nx, Ny) float mask that is 1.0 inside a centred disc, 0.0 outside.

    The disc is centred on index (Nx // 2, Ny // 2); a sample is inside when its
    distance to the centre is at most ``diameter / 2`` (all in pixel units).
    """
    offsets_x = np.arange(Nx) - Nx // 2
    offsets_y = np.arange(Ny) - Ny // 2
    # Broadcast a column of x offsets against a row of y offsets -> (Nx, Ny) grid.
    dist_sq = offsets_x[:, None] ** 2 + offsets_y[None, :] ** 2
    return np.where(dist_sq <= (diameter / 2) ** 2, 1.0, 0.0)

def apply_fill_factor_with_oversampling(mask, oversample_factor=10, fill_factor=0.85):
    """Upsample a per-SLM-pixel mask and zero the inactive border of every pixel.

    Each SLM pixel becomes an ``oversample_factor`` x ``oversample_factor`` block
    of samples. Inside each block only a centred square of side
    ``int(oversample_factor * fill_factor)`` keeps the pixel's value; the rest
    (the inter-pixel dead zone) is set to 0. The fill factor is therefore applied
    per axis (active area fraction ~ ``fill_factor**2``), and the side length is
    truncated (e.g. 10 * 0.85 -> 8 samples).

    Parameters
    ----------
    mask : ndarray, shape (Nx, Ny)
        Value of each SLM pixel (e.g. illumination amplitude).
    oversample_factor : int, optional
        Samples per SLM pixel along each axis (default 10).
    fill_factor : float, optional
        Active fraction of the pixel pitch along each axis, in (0, 1] (default 0.85).

    Returns
    -------
    ndarray, shape (Nx * oversample_factor, Ny * oversample_factor)

    Raises
    ------
    ValueError
        If ``mask`` is not 2-D, ``oversample_factor < 1``, ``fill_factor`` is not
        in (0, 1], or the active square would be empty.
    """
    if np.ndim(mask) != 2:
        raise ValueError(f"mask must be 2-D, got shape {np.shape(mask)}")
    if oversample_factor < 1:
        raise ValueError(f"oversample_factor must be >= 1, got {oversample_factor}")
    if not 0 < fill_factor <= 1:
        raise ValueError(f"fill_factor must be in (0, 1], got {fill_factor}")

    active_size = int(oversample_factor * fill_factor)
    if active_size < 1:
        raise ValueError(
            f"fill_factor={fill_factor} leaves no active sample at "
            f"oversample_factor={oversample_factor}"
        )
    start_idx = (oversample_factor - active_size) // 2

    # Active (modulating) region of a single SLM pixel.
    pixel_mask = np.zeros((oversample_factor, oversample_factor), dtype=bool)
    pixel_mask[start_idx:start_idx + active_size, start_idx:start_idx + active_size] = True

    # Kronecker product: every mask value is spread over one block and multiplied
    # by the single-pixel active-area pattern (same as repeat + tile + multiply).
    return np.kron(mask, pixel_mask)

def gerchberg_saxton(target_amplitude, aperture_amplitude, num_iter=50, rng=None):
    """Compute an SLM phase mask with the Gerchberg-Saxton algorithm.

    Starting from a random phase, alternately propagates between the SLM plane
    and the far field (``fft2`` / ``ifft2``). It imposes ``target_amplitude`` in
    the far field and ``aperture_amplitude`` in the SLM plane, keeping the
    current phase each time.

    Parameters
    ----------
    target_amplitude : array_like
        Desired far-field amplitude, same shape as ``aperture_amplitude``. It is
        imposed on ``fft2(field)`` directly, i.e. in *unshifted* FFT order where
        index (0, 0) is the zero order. ``ifftshift`` a target drawn centred.
    aperture_amplitude : ndarray
        SLM-plane amplitude; zeros mark dead zones (outside the aperture or
        between pixels).
    num_iter : int, optional
        Number of GS iterations (default 50).
    rng : numpy.random.Generator, optional
        Source of the random initial phase. When ``None`` (default) the global
        ``np.random`` state is used, so ``np.random.seed`` controls it.

    Returns
    -------
    ndarray
        Phase (radians, in [-pi, pi]) of the final SLM-plane field. Values in dead
        zones are meaningless because the amplitude there is zero.
    """
    shape = aperture_amplitude.shape
    uniform = np.random.rand(*shape) if rng is None else rng.random(shape)
    field = aperture_amplitude * np.exp(1j * 2 * np.pi * uniform)
    for _ in range(num_iter):
        # Far-field constraint: keep the phase, replace the amplitude.
        far_field = fft2(field)
        far_field = target_amplitude * np.exp(1j * np.angle(far_field))
        # SLM-plane constraint: phase-only modulation of the fixed illumination;
        # aperture_amplitude is 0 in the dead zones.
        field = ifft2(far_field)
        field = aperture_amplitude * np.exp(1j * np.angle(field))
    return np.angle(field)

def generate_far_field_intensity(phase, aperture_amplitude):
    """Return the peak-normalised far-field intensity of a phase-modulated aperture.

    The SLM field ``aperture_amplitude * exp(i * phase)`` is Fourier transformed
    with ``fft2``. The intensity is divided by its maximum and ``fftshift``-ed so
    the zero order sits at the array centre. An all-dark field (zero peak) is
    returned as zeros instead of NaNs.
    """
    field = aperture_amplitude * np.exp(1j * phase)
    intensity = np.abs(fft2(field)) ** 2
    peak = intensity.max()
    if peak > 0:
        intensity = intensity / peak
    return fftshift(intensity)



def compute_weighted_phase(phi_corr, phi_target, aperture_ampl, C_corr=0.35, C_target=1):
    """Phase of the weighted superposition of two phase-only SLM fields.

    Each phase map is turned into a field ``aperture_ampl * exp(i * phi)``.
    The fields are summed as ``C_corr * U_corr + C_target * U_target``, and the
    phase of the sum (radians, in [-pi, pi]) is returned.
    """
    field_corr, field_target = (
        aperture_ampl * np.exp(1j * phi) for phi in (phi_corr, phi_target)
    )
    combined = C_corr * field_corr + C_target * field_target
    return np.angle(combined)

# Render the target text as a boolean image at the oversampled resolution.
# Pillow sizes are (width, height), so the resulting array is (Ny_os, Nx_os).
width, height = Nx * oversample_factor, Ny * oversample_factor
image = Image.new('L', (width, height), color=0)  # black background
draw = ImageDraw.Draw(image)

# "arial.ttf" is not available on every system (e.g. most Linux installs), where
# truetype() raises OSError; fall back to Pillow's bundled default font.
try:
    font = ImageFont.truetype("arial.ttf", size=100)
except OSError:
    try:
        font = ImageFont.load_default(size=100)  # `size` needs Pillow >= 10.1
    except TypeError:
        font = ImageFont.load_default()

# Draw white text starting at the image origin (top-left corner).
text = 'HEIG'
position = (0, 0)
draw.text(position, text, fill=255, font=font)

# np.bool was removed in NumPy 1.24 (re-added only in 2.0); builtin bool always works.
array_text = np.array(image).astype(bool)

plt.figure()
plt.imshow(array_text)
plt.title("Target text")

In [None]:
# SLM illumination: flat top over the whole panel. For a circular illumination use
#   aperture_mask = create_circular_aperture(Nx, Ny, aperture_diameter)
aperture_mask = np.ones((Nx, Ny))

# Oversample and blank the inter-pixel dead zones.
aperture_os = apply_fill_factor_with_oversampling(
    aperture_mask,
    oversample_factor=oversample_factor,
    fill_factor=fill_factor,
)
Nx_os, Ny_os = aperture_os.shape  # (Nx * oversample_factor, Ny * oversample_factor)



In [None]:
# SLM-plane amplitude at oversampled resolution (fill-factor dead zones included).
amplitude_slm = aperture_os

# Target far-field amplitudes on the same oversampled grid.
# gerchberg_saxton imposes its target directly on fft2(field), i.e. in *unshifted*
# FFT order where index (0, 0) is the zero order (DC). The plots apply fftshift,
# so DC is shown at the centre of the far-field images.

# ZOD correction target: a disc of radius central_spot_radius SLM pixels, drawn
# around the array centre and then ifftshift-ed so that it lands on DC (the
# zero-order spot) rather than on the Nyquist corner.
y_c, x_c = Ny_os // 2, Nx_os // 2
X2, Y2 = np.meshgrid(np.arange(Nx_os), np.arange(Ny_os), indexing='ij')
mask_central = (X2 - x_c)**2 + (Y2 - y_c)**2 <= (central_spot_radius * oversample_factor)**2
target_amplitude_correction = np.fft.ifftshift(mask_central.astype(float))

# Text target: the rendered image is (Ny_os, Nx_os); transpose to the (Nx_os, Ny_os)
# layout of the SLM arrays. It is used as-is, i.e. in unshifted FFT coordinates.
target_amplitude = array_text.T.astype(float)
assert target_amplitude.shape == amplitude_slm.shape, (
    f"text target {target_amplitude.shape} does not match SLM grid {amplitude_slm.shape}"
)

print("Computing correction hologram phase...")
phi_corr = gerchberg_saxton(target_amplitude_correction, amplitude_slm, num_iter=num_iterations)
print("Computing target hologram phase (central spot off)...")
phi_target = gerchberg_saxton(target_amplitude, amplitude_slm, num_iter=num_iterations)


In [None]:
# Superpose the correction and target holograms with weights C_corr / C_target and
# keep only the phase of the combined SLM field. Keyword arguments keep the call
# aligned with the signature (phi_corr, phi_target, aperture_ampl, C_corr, C_target).
phi_total = compute_weighted_phase(
    phi_corr, phi_target, amplitude_slm, C_corr=C_corr, C_target=C_target
)

# Peak-normalised far-field intensities with DC at the centre.
# An all-zero phase is the unmodulated baseline (zero order + pixel-grating orders).
intensity_raw = generate_far_field_intensity(np.zeros_like(amplitude_slm), amplitude_slm)
intensity_central_only = generate_far_field_intensity(phi_corr, amplitude_slm)
intensity_target = generate_far_field_intensity(phi_target, amplitude_slm)
intensity_total = generate_far_field_intensity(phi_total, amplitude_slm)

In [None]:

# Visualize results: far-field intensities (first figure) and the SLM phase masks
# that produced them (second figure). Phase wraps around at +/-pi, so phase maps
# use the cyclic 'twilight' colormap to avoid a false discontinuity.

def _show_panels(panels, figsize=(15, 5)):
    """Draw (title, image, cmap) tuples side by side in a new figure; return it."""
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    for ax, (title, img, cmap) in zip(axes[0], panels):
        ax.set_title(title)
        ax.imshow(img, cmap=cmap)
        ax.axis('off')
    return fig


fig_far_field = _show_panels([
    ("Far field - ZOD only", intensity_raw, 'inferno'),
    ("Far field - correction hologram only", intensity_central_only, 'inferno'),
    ("Far field - Hologram only", intensity_target, 'inferno'),
    ("Far field - Hologram and ZOD removal", intensity_total, 'inferno'),
])

# 'gray' (not 'grey'): the 'grey' alias only exists in matplotlib >= 3.9.
fig_phase = _show_panels([
    ("aperture active area", amplitude_slm, 'gray'),
    ("Phase mask - correction only", phi_corr, 'twilight'),
    ("Phase mask - Hologram only", phi_target, 'twilight'),
    ("Phase mask - Hologram and ZOD removal", phi_total, 'twilight'),
])

plt.show()
