#### TOLIMAN Pupil Gluing Analysis: Preliminary Measurements

**Aim:**  
To determine:
1. The optical aberrations induced by the lab setup for pupil testing (consisting of 2 off-axis parabolic mirrors, OAPs)
2. The intensity distribution output by the optical fiber, for later modelling

If we can show that these aberrations are static over a sufficiently long period (~1 hr), then we can confidently remove them from the phase-retrieval analysis of the later measurements (glued vs non-glued).

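We have chosen to place a spider mask within the collimated beam to provide the necessary asymmetry: with a centrally symmetric pupil, the sign of even aberrations (e.g. defocus, astigmatism) cannot be recovered from a single in-focus PSF. With the mask in place, we characterise these aberrations via phase retrieval (thanks to differentiable modelling with dLux 💖).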

In [None]:
import dLux as dl
import dLux.utils as dlu


import jax.numpy as jnp
import numpy as np
import jax.random as jr
import jax.scipy as jsp
from jax import vmap  

import zodiax as zdx
import optax
from tqdm.notebook import tqdm

from skimage.io import imread

import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm

# Global matplotlib defaults for every figure in this notebook.
plt.rcParams.update({
    'image.cmap': 'inferno',
    'font.family': 'serif',
    'image.origin': 'lower',     # row 0 at the bottom of imshow plots
    'figure.dpi': 72,
    'figure.figsize': (10, 10),
    'axes.titlesize': 18,
    'figure.titlesize': 18,
    'axes.labelsize': 15,
})

In [None]:
# ------- Physical Parameters ---------------------------------------------------------------------#
# Pupil / aperture geometry
aperture_npix = 2000            # pixels across the aperture
aperture_diameter = 126e-3      # aperture diameter (m)
spider_width = 20e-3            # spider arm width (m)
spider_angle = 90               # spider arm angle (deg), clockwise; 0 points vertically up

# Wavelengths
wavelengths = np.linspace(530e-9, 640e-9, 100)  # observation bandpass grid, 530-640 nm (m)
laser_wavelength = 635e-09                      # wavelength used for the laser data (m)

# Wavefront sampling: same pixel count and physical size as the aperture
wf_npixels = aperture_npix
wf_diam = aperture_diameter

# Detector parameters (BFS-U3-200S6-BD)
BFS_px_sep = 2.4e-6 * 1e3   # pixel pitch, 2.4 um converted to mm
f_det = 2000                # focal length from OAP2 to detector (mm) TODO
# Angle subtended by one detector pixel at distance f_det (rad)
px_ang_sep = 2 * np.arctan(0.5 * BFS_px_sep / f_det)

# Simulated detector
psf_npix = 3600     # PSF size along each axis (pixels)
oversample = 1      # PSF oversampling factor
# NOTE(review): hard-coded plate scale (arcsec); px_ang_sep above is therefore unused.
# The detector-matched alternative was dlu.rad2arcsec(px_ang_sep) * oversample; for an
# oversampled PSF the scale would normally be *divided* by oversample -- confirm.
psf_pixel_scale = 80e-4


In [None]:
# --- Simulate Spider -----------------------------------------------------------------#
# Build the pupil (circular aperture + one spider arm), attach a Zernike aberration basis,
# and propagate a monochromatic laser PSF as the starting model for phase retrieval.
coords = dlu.pixel_coords(npixels=aperture_npix, diameter=aperture_diameter)
circle = dlu.circle(coords=coords, radius=aperture_diameter/2)
# Use the configured spider_angle (previously hard-coded to 90 here).
spider = dlu.spider(coords=coords, width=spider_width, angles=[spider_angle])
# Combine the circle and spider masks into a single transmission map.
# The second argument of dlu.combine is its own `oversample` (downsampling) factor.
# Keep it at 1 so the map stays aperture_npix wide, matching the Zernike basis and
# wf_npixels. The PSF `oversample` used to be passed here; presumably that would have
# changed the map's size whenever it was set > 1.
transmission = dlu.combine([circle, spider], 1)

# Zernike aberrations
zernike_indexes = jnp.arange(2, 11)  # Noll indices 2..10 (piston excluded)
coeffs = jnp.zeros(zernike_indexes.shape)  # start from an unaberrated pupil
basis = dlu.zernike_basis(js=zernike_indexes, coordinates=coords, diameter=aperture_diameter)

layers = [
    ('aperture', dl.layers.BasisOptic(basis, transmission, coeffs, normalise=True))
]

# Construct Optics
optics = dl.AngularOpticalSystem(wf_npixels = wf_npixels,
                                diameter=wf_diam,
                                layers=layers,
                                psf_npixels=psf_npix,
                                psf_pixel_scale=psf_pixel_scale,
                                oversample=oversample)

sim_psf = optics.propagate_mono(laser_wavelength)
opd = optics.aperture.eval_basis()

# Show setup and transmission results
plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.imshow(transmission)
plt.colorbar()
plt.title('Transmission')
plt.subplot(1,2,2)
# PowerNorm with gamma=0.2 (a stronger stretch than sqrt) to reveal the faint PSF wings
norm_psf = PowerNorm(0.2, vmax=sim_psf.max(), vmin=sim_psf.min())
plt.imshow(sim_psf, norm=norm_psf)
plt.title('PSF (laser, gamma=0.2 stretch)')
plt.colorbar()



#### Load in some ✨real✨ data 🌈

In [None]:
fname = "data/lab_data/4_4_glued_002_mean_image.png"
# as_gray only converts colour images; an image already stored in greyscale keeps its
# integer dtype, so cast to float before doing arithmetic on it.
data = np.asarray(imread(fname, as_gray=True), dtype=np.float64)

# The loss compares data and model pixel-by-pixel, so the shapes must agree.
if data.shape != sim_psf.shape:
    raise ValueError(
        f"Data shape {data.shape} does not match simulated PSF shape {sim_psf.shape}"
    )

# Scale intensity: linearly map the data's [min, max] onto the simulated PSF's
# [min, max] so data and model share a common intensity scale.
current_range = data.max() - data.min()
if current_range == 0:
    raise ValueError(f"Image {fname!r} is constant; its intensity cannot be rescaled")
new_range = sim_psf.max() - sim_psf.min()
scaled_data = ( (data - data.min()) * new_range )/current_range + sim_psf.min()

#### Phase Retrieval

In [None]:
# Loss functions (jit-compiled)
param = 'aperture.coefficients'  # model leaf that is differentiated / optimised

# Poisson negative log-likelihood
flux_scale = 1e13 # scale factor for flux to help poission log-likelihood calc
@zdx.filter_jit
@zdx.filter_value_and_grad(param)
def loss_fn_poisson(model, data, wavelength_center):
    """
        Negative Poisson log-likelihood of the observed image given the model PSF.

        The observed image is the Poisson variate k and the model PSF is the Poisson
        mean mu. The arguments were previously swapped, which scored the model as the
        "observation". Both images are multiplied by `flux_scale` to turn normalised
        intensities into photon-count-like numbers. The observed counts are then rounded,
        because SciPy's poisson.logpmf (and, in recent versions, JAX's) returns -inf for
        non-integer k. The optimiser minimises the summed -logpmf.

        The decorators make the call return ``(loss, grads)``, with gradients taken
        with respect to `param` only.

        Parameters
        ----------
        model : dLux model
            dLux model to propagate the wavefront
        data : Array
            Observed image, rescaled to the model PSF's intensity range and with the
            same shape as the model PSF
        wavelength_center : float
            Wavelength (m) at which the monochromatic model PSF is propagated

        Returns
        -------
        loss : Array (scalar)
            Summed negative Poisson log-likelihood over all pixels
    """
    simu_psf = model.propagate_mono(wavelength_center)

    counts = jnp.round(flux_scale * data)   # observed: integer counts (k)
    expected = flux_scale * simu_psf        # model: Poisson mean (mu)
    return -jsp.stats.poisson.logpmf(k = counts, mu = expected).sum()


In [None]:
learning_rate = 1e-10
# Adam optimiser acting only on the `param` leaf of the model
optim, opt_state = zdx.get_optimiser(optics, param, optax.adam(learning_rate))

progress_bar = tqdm(range(200), desc='Loss: ')

# Run optimisation loop.
# Each loss and the model it was evaluated on are recorded together, *before* the update,
# so net_losses[i] is the loss of models[i] (e.g. models[np.argmin(net_losses)] is the
# best model seen). Previously models[i] held the post-update model, one step ahead of
# its loss. `optics` itself still ends as the fully updated model.
net_losses, models = [], []
for i in progress_bar:
    loss, grads = loss_fn_poisson(model = optics, data = scaled_data, wavelength_center = laser_wavelength)
    net_losses.append(loss)
    models.append(optics)

    updates, opt_state = optim.update(grads, opt_state)
    optics = zdx.apply_updates(optics, updates)

    progress_bar.set_postfix({'Loss': float(loss)})

In [None]:
# Visualise results
# Training history: the plotted loss is the *negative* Poisson log-likelihood (the
# quantity minimised), recorded once per optimiser step.
plt.figure()
plt.plot(np.array(net_losses))
ax = plt.gca()
ax.set_title("Training History")
ax.set_xlabel("Training Step")
ax.set_ylabel("Negative Poisson Log-Likelihood")

plt.rcParams['figure.figsize'] = (17, 17)
plt.rcParams['image.cmap'] = 'inferno'

# Data vs model PSF on the same gamma=0.2 PowerNorm (limits taken from the data),
# plus the retrieved aberration map masked by the pupil transmission.
plt.figure()
norm_psf = PowerNorm(0.2, vmax=scaled_data.max(), vmin=scaled_data.min())
plt.subplot(1,3,1)
plt.imshow(scaled_data, norm=norm_psf)
plt.colorbar()
plt.title('Data')
plt.subplot(1,3,2)
model_psf = optics.propagate_mono(laser_wavelength)
plt.imshow(model_psf, norm=norm_psf)
plt.title('Model')
plt.colorbar()
plt.subplot(1,3,3)
opd = optics.aperture.eval_basis()
plt.imshow(opd*transmission)
plt.title('Retrieved Aberrations')
plt.colorbar()