#### TOLIMAN Pupil Gluing Analysis: Preliminary Measurements

**Aim:**  
To determine:
1. The optical aberrations induced by the lab setup for pupil testing (consisting of 2 OAPs)
2. The intensity distribution of the optical fiber output, for later modelling

If we can show these aberrations are static over a long enough period of time (> 30min) then we can confidently remove them from the phase retrieval analysis of the later measurements (glued vs non-glued).

We place a spider mask within the collimated beam to characterise these aberrations via phase retrieval. The spider provides the pupil asymmetry that phase retrieval needs to break the sign ambiguity of even aberrations (thank you differentiable modelling/dLux 💖).

In [None]:
import dLux as dl
import dLux.utils as dlu


import jax.numpy as jnp
import numpy as np
import jax.random as jr
import jax.scipy as jsp
from jax import vmap  

import zodiax as zdx
import optax
from tqdm.notebook import tqdm

from skimage.io import imread
from skimage.filters import window
import skimage as ski
from skimage.transform import resize

import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm

# Default plotting style applied to every figure in this notebook.
plt.rcParams.update({
    'image.cmap': 'inferno',
    "font.family": "serif",
    "image.origin": 'lower',
    'figure.dpi': 72,
    'figure.figsize': (10, 10),
    "axes.titlesize": 18,
    "figure.titlesize": 18,
    "axes.labelsize": 15,
})

In [None]:
# ------- Physical Parameters ---------------------------------------------------------------------#
aperture_npix = 512           # Number of pixels across the aperture
aperture_diameter = 126e-3    # (m)
spider_width = 20e-3          # Spider width (m)
spider_angle =270             # Spider angle (degrees), clockwise, 0 is spider pointing vertically up
coords = dlu.pixel_coords(npixels=aperture_npix, diameter=aperture_diameter)  # pixel coordinate grid (m); unpacked as X, Y by gauss_2d
circle = dlu.circle(coords=coords, radius=aperture_diameter/2)  # circular aperture mask on that grid

# Observations wavelengths (bandpass of 530-640nm)
wavelengths = np.linspace(530e-9, 640e-9, 100)  # Wavelengths to simulate (m)
laser_wavelength =  635e-09  # for laser data (m)
wf_npixels = aperture_npix  # Number of pixels across the wavefront
wf_diam = aperture_diameter             # Diameter of initial wavefront to propagate wavefront (m)

# Detector parameters (BFS-U3-200S6-BD)
BFS_px_sep = 2.4e-6 *1e3        # pixel separation (mm): 2.4e-6 m pitch converted m -> mm
f_det = 1300#1350                    # Focal length from OAP2 to detector (mm); 1350 is an alternative value left commented out
px_ang_sep = 2*np.arctan( (BFS_px_sep/2)/f_det ) # angular sep between pixels (rad); both lengths are in mm

# Simulated Detector
psf_npix = 40                 # Number of pixels along one dim of the PSF (data crops must match this)
oversample = 1                 # Oversampling factor for the PSF
psf_pixel_scale = dlu.rad2arcsec(px_ang_sep) # arcsec (to match detector plate scale); 80e-4 appears to be an older hard-coded value


#### Load in some ✨real✨ data 🌈
- Check intensity distribution across pupil first


In [None]:
# Measured pupil illumination. Crop the camera image, min-max normalise it,
# smooth it, and resample it onto the aperture grid. The result is stored as
# `intensity_dist`.
fname = "data/intensity/15_07_intensity_dist.png"
data = imread(fname, as_gray=True) 
manual_lim = [1363,4203,386,3214]  # crop bounds in pixels: [col_start, col_end, row_start, row_end]
data = data[manual_lim[2]:manual_lim[3], manual_lim[0]:manual_lim[1]]
data = (data - data.min())/(data.max()-data.min())  # re-map from 0-1

# Heavy Gaussian smoothing (sigma = 120 px), presumably to remove small-scale structure/noise
blurred = ski.filters.gaussian(data, sigma=(120, 120))

plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
plt.imshow(data)
plt.title("Data - pre-focus")
plt.colorbar()
plt.subplot(1,3,2)
plt.imshow(blurred)
plt.title("Data blurred")
plt.colorbar()
plt.subplot(1,3,3)
# NOTE(review): the crop is 2828 x 2840 px, so resampling to a square grid slightly changes the aspect ratio
intensity_dist = resize(blurred, (aperture_npix, aperture_npix))
intensity_dist = (intensity_dist - intensity_dist.min())/(intensity_dist.max()-intensity_dist.min()) # re-map from 0-1
# intensity_dist = np.fliplr(intensity_dist)
plt.title("Blurred re-sized")
plt.imshow(intensity_dist)
plt.colorbar()


In [None]:
# Use 2D gaussian to approximate this
def gauss_2d(x_0, y_0, var_x, var_y, envelope, pixel_coords):
    """
        Output a 2D gaussian array with a peak amplitude of 1, masked by an
        aperture profile.

        Note: despite their names, ``var_x``/``var_y`` are used as the
        standard deviations (gaussian widths, m). The exponent divides by
        ``2 * var**2``, so they are not variances.

        Parameters:
        -----------
        x_0, y_0 : float
            (x,y) coordinate of the center of the gaussian (m)
        var_x, var_y : float
            Standard deviation (width) of the gaussian in the x and y
            directions (m).
        envelope : ndarray
            2D array of the aperture profile, with the same shape as each
            coordinate array. The gaussian is multiplied by it, so it is
            omitted where the envelope is 0 and kept where it is 1.
        pixel_coords : ndarray
            Pixel coordinates over which the gaussian is defined, unpacked as
            ``X, Y = pixel_coords`` (e.g. shape (2, npix, npix)).

        Returns:
        --------
        Array
            ``exp(-(X-x_0)^2/(2 var_x^2) - (Y-y_0)^2/(2 var_y^2)) * envelope``

        Raises:
        -------
        ValueError
            If the coordinate arrays and the envelope differ in shape.
    """
    X, Y = pixel_coords
    # Compare full shapes (not just the first axis) so mismatched grids fail
    # with a clear message instead of a later broadcasting error.
    if jnp.shape(X) != jnp.shape(envelope):
        raise ValueError("Envelope and pixel coords must have the same shape")

    z = jnp.exp(-(X - x_0)**2/(2*var_x**2) - (Y - y_0)**2/(2*var_y**2))

    return z * envelope

# Quick look at gauss_2d on the aperture grid; `var` is passed as the gaussian width (m).
var = 0.05
test = gauss_2d(x_0=0, y_0=0, var_x=var, var_y=var, envelope=circle, pixel_coords=coords)
plt.imshow(test)
plt.title("Approximated WF intensity")

In [None]:
# Create dLux class for this gauss transmissive layer
# maybe modelling a source object would be easier?
class GaussTransmissiveLayer(dl.layers.optical_layers.TransmissiveLayer):
    """
        Inherits from dl.layers.TransmissiveLayer and allows a Gaussian
        transmissive layer to be simulated. Useful for simulating a wavefront
        that is not uniform in intensity.

        The Gaussian is re-evaluated from ``gauss_param`` every time the layer
        is applied, so ``gauss_param`` can be optimised directly. The inherited
        ``transmission`` attribute is only a snapshot computed at construction.
        It is NOT updated when ``gauss_param`` changes; use
        ``get_transmission()`` for the current profile.

        Attributes
        ----------
        transmission: Array
            Transmission evaluated from the initial ``gauss_param`` at
            construction time.
        gauss_param: Array = [x_0, y_0, sigma_x, sigma_y], shape (4,)
            The parameters defining the 2D Gaussian applied to the input
            wavefront (passed to ``__init__`` as ``gaussian_param``):
            x_0, y_0 = float
                (x,y) coordinate of the center of the gaussian (m)
            sigma_x, sigma_y = float
                Standard deviation (width) in the x and y directions (m);
                the exponent divides by ``2 * sigma**2``.
        envelope: Array
            2D array of the aperture profile. The Gaussian is multiplied by
            it, so it is masked out where the envelope is 0 and included
            where the envelope is 1.
        X, Y : Array
            Pixel coordinate arrays unpacked from ``pixel_coords``.
        normalise: bool
            Whether to normalise the wavefront after passing through the optic.
    """
    envelope: jnp.array
    X: jnp.array
    Y: jnp.array

    gauss_param: jnp.array

    def __init__(
        self: dl.layers.optical_layers.OpticalLayer,
        gaussian_param: np.array,
        envelope: np.array,
        pixel_coords: np.array,
        normalise: bool = False,
    ):
        """
        Parameters
        ----------
        gaussian_param : array-like, shape (4,)
            Initial [x_0, y_0, sigma_x, sigma_y]. Lists/tuples are accepted
            and converted to a float array.
        envelope : Array
            2D aperture profile, same shape as each coordinate array.
        pixel_coords : Array
            Unpacked as ``X, Y = pixel_coords``.
        normalise : bool
            Whether to normalise the wavefront after this layer.

        Raises
        ------
        ValueError
            If the coordinate and envelope shapes differ, or if
            ``gaussian_param`` does not have shape (4,).
        """
        X, Y = pixel_coords
        if jnp.shape(X) != jnp.shape(envelope):
            raise ValueError("Envelope and pixel coords must have the same shape")
        # Accept any array-like; float np.array inputs pass through unchanged.
        gaussian_param = np.asarray(gaussian_param, dtype=float)
        if gaussian_param.shape != (4,):
            raise ValueError(
                "Gaussian parameters must be of shape (4,) in form [x_0, y_0, sigma_x, sigma_y] "
            )

        self.X, self.Y = X, Y
        self.gauss_param = gaussian_param
        self.envelope = envelope

        # The parent stores this initial evaluation as `transmission`.
        super().__init__(
            transmission=self._gaussian(self.X, self.Y, self.gauss_param, self.envelope),
            normalise=normalise,
        )

    @staticmethod
    def _gaussian(X, Y, gauss_param, envelope):
        """Evaluate the envelope-masked Gaussian for [x_0, y_0, sigma_x, sigma_y]."""
        x_0, y_0, sigma_x, sigma_y = gauss_param
        z = jnp.exp(-(X - x_0)**2/(2*sigma_x**2) - (Y - y_0)**2/(2*sigma_y**2))
        return z * envelope

    def get_transmission(self):
        """Return the transmission for the current ``gauss_param``."""
        return self._gaussian(self.X, self.Y, self.gauss_param, self.envelope)

    def apply(self: dl.layers.optical_layers.OpticalLayer, wavefront: dl.wavefronts.Wavefront) -> dl.wavefronts.Wavefront:
        """
        Applies the layer to the wavefront.

        The transmission is recomputed from ``gauss_param`` (not read from the
        stored ``transmission`` snapshot), so fitted parameters take effect.

        Parameters
        ----------
        wavefront : Wavefront
            The wavefront to operate on.

        Returns
        -------
        wavefront : Wavefront
            The transformed wavefront.
        """
        wavefront *= self.get_transmission()
        if self.normalise:
            wavefront = wavefront.normalise()
        return wavefront


In [None]:
# --- Simulate Spider -----------------------------------------------------------------#
# Pupil = circular aperture combined with a single spider arm at `spider_angle`.
spider = dlu.spider(coords=coords, width=spider_width, angles=[spider_angle])
transmission = dlu.combine([circle, spider]) 
# transmission *= intensity_dist # Scale transmission by true intensity distribution
# Stand-alone Gaussian intensity map; only referenced by the commented-out TransmissiveLayer below.
gauss_intensity_dist= gauss_2d(x_0=0, y_0=0, var_x=0.05, var_y=0.05, envelope=circle, pixel_coords=coords)

# Zernike aberrations
zernike_indicies = jnp.arange(4, 15) # Noll indices 4-14 (piston, tip and tilt excluded)
coeffs = jnp.zeros(zernike_indicies.shape)  # start from an unaberrated pupil
basis = dlu.zernike_basis(js=zernike_indicies, coordinates=coords, diameter=aperture_diameter)

layers = [
    # ('intensity_dist', dl.layers.TransmissiveLayer(transmission= gauss_intensity_dist)),
    # Fittable Gaussian illumination [x_0, y_0, sigma_x, sigma_y] (m), same start values as gauss_intensity_dist.
    ('intensity_dist', GaussTransmissiveLayer(gaussian_param=np.array([0,0,0.05,0.05]), envelope=circle, pixel_coords= coords)),
    # Spider-masked aperture carrying the Zernike phase basis (its coefficients are fitted later).
    ('aperture', dl.layers.BasisOptic(basis, transmission, coeffs, normalise=False)),
]

# Construct Optics
optics = dl.AngularOpticalSystem(wf_npixels = wf_npixels, 
                                diameter=wf_diam, 
                                layers=layers, 
                                psf_npixels=psf_npix, 
                                psf_pixel_scale=psf_pixel_scale,
                                oversample=oversample)

# sim_psf = optics.propagate_mono(laser_wavelength)
# opd = optics.aperture.eval_basis()

# Using PointSources instead of single PointSource object to overcome float grad issue 
# https://github.com/LouisDesdoigts/dLux/issues/271 
src = dl.PointSources(wavelengths=[laser_wavelength], flux =jnp.asarray([1e8],dtype=float))

# Telescope = optics + source; instrument.model() gives the simulated PSF (total flux printed below).
instrument = dl.Telescope(optics, ('source', src))
sim_psf = instrument.model()

# Show setup and transmission results
plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
plt.imshow(transmission)
plt.colorbar()
plt.title('Mask Transmission')
plt.subplot(1,3,2)
norm_psf = PowerNorm(0.5, vmax=sim_psf.max(), vmin=sim_psf.min())
print("Total flux: {}".format(sim_psf.sum()))
plt.imshow(sim_psf, norm=norm_psf)
plt.title('sqrt PSF (laser)')
plt.colorbar()
plt.subplot(1,3,3)
# `.transmission` is the construction-time snapshot; it equals get_transmission() here because nothing is fitted yet.
plt.imshow(optics.intensity_dist.transmission)
plt.title('Intensity Distribution')
plt.colorbar()




#### Load in spider images ⭐️
- 03_laser_90deg_mean.png: Spider (90 deg oriented) setup with laser
- 04_laser_90deg_mean.png: Same setup as above scenario, data taken ~1hr later 

In [None]:
fname = "data/spider/03_laser_90deg_mean.png"
data = imread(fname, as_gray=True) 

# Scale intensity
data = data**1.2 # non-linear behaviour estimation
current_range = data.max() - data.min()
new_range = sim_psf.max() - sim_psf.min()
scaled_data = ( (data - data.min()) * new_range )/current_range + sim_psf.min()
# new_range = data.max() - 0
# scaled_data = ( (data - data.min()) * new_range )/current_range + 0

psf_center_idx = np.unravel_index(np.argmax(scaled_data, axis=None), scaled_data.shape)
psf_hlf_sz = 20
scaled_data = scaled_data[psf_center_idx[0]-psf_hlf_sz:psf_center_idx[0]+psf_hlf_sz,
                            psf_center_idx[1]-psf_hlf_sz:psf_center_idx[1]+psf_hlf_sz]

data = data[psf_center_idx[0]-psf_hlf_sz:psf_center_idx[0]+psf_hlf_sz,
                            psf_center_idx[1]-psf_hlf_sz:psf_center_idx[1]+psf_hlf_sz]
data = data + sim_psf.min() # ensure mu = 0 for k !=0 does not occur (logpmf shits itself)

print("Total flux (Raw): {}".format(data.sum()))
print("Total flux (Scaled): {}".format(scaled_data.sum()))

%matplotlib inline
plt.figure(figsize=(17,7))
plt.subplot(1,2,1)
plt.imshow(scaled_data, norm=norm_psf)
# plt.imshow(data**0.2)
plt.colorbar()
plt.title("Data (scaled)")
plt.subplot(1,2,2)
plt.imshow(sim_psf, norm=norm_psf)
plt.colorbar()
plt.title("Simulated")

#### Phase Retrieval

In [None]:
# @zdx.filter_jit
# @zdx.filter_value_and_grad(param)
# def loss_fn_mse(model, data, wavelength_center):

#     simu_psf = model.propagate_mono(wavelength_center)

#     mse = 1/simu_psf.size * ((data-simu_psf)**2).sum()

#     return mse

In [None]:
params = [
    'aperture.coefficients',
    'source.position',

    #for gauss_2d params
    'intensity_dist.gauss_param',
    
    # 'source.flux', # I don't think flux makes sense to fit as we have a limited number of non-zero pixels (need to retake data)
    # 'intensity_dist.transmission',
    ]


@zdx.filter_jit
@zdx.filter_value_and_grad(params)
def loss_fn_gaussian(model, data):
    """
    Negative Gaussian log-likelihood of ``data`` given ``model.model()``.

    The per-pixel standard deviation is 10% of the data value. The function
    is differentiated w.r.t. ``params`` and returns ``(loss, grads)``.

    NOTE(review): pixels where ``data == 0`` give a zero ``scale`` and hence
    an invalid (inf/NaN) log-likelihood. Keep ``data`` strictly positive.
    """
    simu_psf = model.model()

    uncertainty = 0.1 # 10% err per pix TODO try increasing

    loss = -jsp.stats.norm.logpdf(x=simu_psf, loc=data, scale=data*uncertainty).sum()

    return loss


@zdx.filter_jit
@zdx.filter_value_and_grad(params)
def loss_fn_poisson(model, data):
    """
    Poisson negative log-likelihood of the observed ``data``.

    The model PSF ``model.model()`` is the expected count rate. The observed
    data are the Poisson counts ``k``, and the model is the mean ``mu``.

    The data here are rescaled floats rather than integer counts, so the
    continuous extension ``k*log(mu) - mu - lgamma(k + 1)`` is evaluated
    directly. This equals ``poisson.logpmf(k, mu)`` for integer ``k``, and
    ``xlogy`` makes zero-valued data pixels contribute ``-mu`` instead of NaN.

    The function is differentiated w.r.t. ``params`` and returns
    ``(loss, grads)``.
    """
    simu_psf = model.model()

    log_like = (jsp.special.xlogy(data, simu_psf)
                - simu_psf
                - jsp.special.gammaln(data + 1))
    loss = -log_like.sum()

    return loss

In [None]:
# Jointly fit the Zernike coefficients, source position and Gaussian
# illumination parameters by minimising loss_fn_poisson against `scaled_data`.
learning_rate = 1e-8
# optimisers = [
#             optax.adam(optax.linear_schedule(init_value=learning_rate, end_value=0, transition_begin=500, transition_steps=1)),
#             optax.adam(optax.linear_schedule(init_value=learning_rate, end_value=0, transition_begin=500, transition_steps=1)),
#             # optax.adam(optax.piecewise_constant_schedule(init_value=1e-2*1e4, boundaries_and_scales={400: int(1e4)})),
#             optax.sgd(optax.linear_schedule(init_value=0, end_value=1e-7, transition_begin=500, transition_steps=1)),
#             # optax.adam(optax.linear_schedule(init_value=0, end_value=1e6, transition_begin=500, transition_steps=1)),
#             # optax.adam(learning_rate=1e6),
#             # optax.adam(learning_rate=1e-3),
#               ]
optimisers = [
            optax.adam(learning_rate=learning_rate),  # 'aperture.coefficients'
            optax.adam(learning_rate=learning_rate),  # 'source.position'
            # optax.adam(optax.piecewise_constant_schedule(init_value=1e-2*1e4, boundaries_and_scales={400: int(1e4)})),
            optax.adam(learning_rate=1e-4),  # 'intensity_dist.gauss_param'
            # optax.adam(optax.linear_schedule(init_value=0, end_value=1e6, transition_begin=500, transition_steps=1)),
            # optax.adam(learning_rate=1e6),
            # optax.adam(learning_rate=1e-3),
              ]


# One optimiser per entry of `params`, in the same order (flux is not fitted)
optim, opt_state = zdx.get_optimiser(instrument, params, optimisers) # Needs to be iterable param (i.e. accessible via instrument class)

progress_bar = tqdm(range(2000), desc='Loss: ')

# Run optimisation loop 
# Dist_sum records the Gaussian parameters each epoch (the name is left over from the commented-out sum-of-diff below)
net_losses, Coeffs, Positions, Fluxes, Dist_sum = [],[],[],[], []
for i in progress_bar:
    poiss_loss, poiss_grads = loss_fn_poisson(model = instrument, data = scaled_data)
    # gauss_loss, gauss_grads = loss_fn_gaussian(model = instrument, data = data)
    # poiss_grads = poiss_grads.set("source.flux",gauss_grads.source.flux) # flux won't converge with poission (converges to Nan)

    updates, opt_state = optim.update(poiss_grads, opt_state)
    instrument = zdx.apply_updates(instrument, updates) 
    
    # # Manual update for gauss transmissive 
    # # (updates are additive)
    # prev_trans= instrument.intensity_dist.transmission
    # new_trans = instrument.intensity_dist.get_transmission()
    # instrument = instrument.set('intensity_dist.transmission', new_trans)


    net_losses.append(poiss_loss)
    Fluxes.append(instrument.source.flux)  # 'source.flux' is not in `params`, so this stays constant
    Coeffs.append(instrument.aperture.coefficients)
    Positions.append(instrument.source.position)
    # Dist_sum.append((new_trans - prev_trans).sum())
    Dist_sum.append(instrument.intensity_dist.gauss_param)
    
    progress_bar.set_postfix({'Loss': poiss_loss})




In [None]:
# Visualise results
%matplotlib inline
plt.figure(figsize=(7,4))
plt.plot(np.array(net_losses))
ax = plt.gca()
ax.set_title("Training History")
ax.set_xlabel("Training Epoch")
ax.set_ylabel("Poisson Log-Likelihood")

plt.rcParams['figure.figsize'] = (17, 17)
plt.rcParams['image.cmap'] = 'inferno'

plt.figure(figsize=(30,10))
plt.subplot(1,4,1)
plt.plot(np.asarray(Positions)[:,0,0], label="Position X")
plt.plot(np.asarray(Positions)[:,0,1], label="Position Y")
plt.title("Position")
plt.legend()
plt.subplot(1,4,2)
plt.plot(np.asarray(Dist_sum))
plt.title("Intensity Distribution Sum of Diff")
plt.subplot(1,4,3)
arr_coeffs = np.asarray(Coeffs)
for i in range(len(Coeffs[0])):
    label = "Coeff " + str(zernike_indicies[i])
    plt.plot(arr_coeffs[:,i], label=label)
plt.legend()
plt.subplot(1,4,4)
plt.plot(np.asarray(Fluxes))
plt.title("Flux")

plt.figure(figsize=(12,10))
norm_psf = PowerNorm(0.2, vmax=scaled_data.max(), vmin=scaled_data.min())
plt.subplot(2,2,1)
plt.imshow(scaled_data, norm=norm_psf)
plt.colorbar()
plt.title('Data')

plt.subplot(2,2,2)
model_psf = instrument.model()
current_range = model_psf.max() - model_psf.min()
new_range = scaled_data.max() - scaled_data.min()
model_psf = ( (model_psf - model_psf.min()) * new_range )/current_range + scaled_data.min()
norm_psf = PowerNorm(0.2, vmax=model_psf.max(), vmin=model_psf.min())
plt.imshow(model_psf, norm=norm_psf)
plt.title('Model')
plt.colorbar()

plt.subplot(2,2,3)
resid = scaled_data - model_psf
# resid = data - model_psf
plt.imshow(resid, cmap='bwr', vmax = np.abs(resid).max(), vmin = -np.abs(resid).max())
plt.colorbar()
plt.title('Residuals')

plt.subplot(2,2,4)
opd = instrument.aperture.eval_basis()
plt.imshow(opd*transmission, cmap='viridis')
plt.title('Retrieved Aberrations')
plt.colorbar()

In [None]:
# Fitted Gaussian illumination parameters [x_0, y_0, sigma_x, sigma_y] (m)
print(instrument.intensity_dist.gauss_param)

In [None]:
# Fitted illumination profile, recomputed from gauss_param (the `.transmission`
# attribute is only the construction-time snapshot, hence commented out below)
plt.imshow(instrument.intensity_dist.get_transmission())
# plt.imshow(instrument.intensity_dist.transmission)
plt.colorbar()

In [None]:
# Mark the same pixel on the data and on the rescaled model PSF to compare
# their alignment by eye.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
point = [20, 20]

for i, image in enumerate((scaled_data, model_psf)):
    ax[i].imshow(image, cmap='bone_r')
    ax[i].scatter(*point, marker='x', color='r')

plt.show()

In [None]:
# The fit above updated `instrument` (a Telescope wrapping `optics`). `optics`
# itself was never reassigned, so it still holds the initial zero coefficients.
print("Found coefficients for noll idxs: {}\n{}".format(zernike_indicies, instrument.aperture.coefficients))

run1_coeffs = instrument.aperture.coefficients

Let's try with the second data set (taken ~1 hr after the first).

In [None]:
# Re-initialise model
# Rebuild the optical system from the same `layers` list as the first fit
# (zero Zernike coefficients, initial Gaussian parameters) for the second data set.
optics = dl.AngularOpticalSystem(wf_npixels = wf_npixels, 
                                diameter=wf_diam, 
                                layers=layers, 
                                psf_npixels=psf_npix, 
                                psf_pixel_scale=psf_pixel_scale,
                                oversample=oversample)

print(optics.aperture.coefficients)  # expected: all zeros

In [None]:
fname = "data/spider/04_laser_90deg_mean.png"
data = imread(fname, as_gray=True) 

# Scale intensity
data = data**1.2 # non-linear behaviour estimation
# Linearly map the image onto the simulated PSF's [min, max] range.
current_range = data.max() - data.min()
new_range = sim_psf.max() - sim_psf.min()
scaled_data = ( (data - data.min()) * new_range )/current_range + sim_psf.min()

# Crop a psf_npix x psf_npix window centred on the brightest pixel. The crop
# must match the simulated PSF's shape (psf_npixels=psf_npix) because the loss
# compares data and model pixel-by-pixel; a half-size of 50 gave 100 x 100.
psf_center_idx = np.unravel_index(np.argmax(scaled_data, axis=None), scaled_data.shape)
psf_hlf_sz = psf_npix // 2
row0 = psf_center_idx[0] - psf_hlf_sz
col0 = psf_center_idx[1] - psf_hlf_sz
if (row0 < 0 or col0 < 0
        or row0 + psf_npix > scaled_data.shape[0]
        or col0 + psf_npix > scaled_data.shape[1]):
    raise ValueError(
        f"PSF peak at {psf_center_idx} is too close to the image edge for a {psf_npix}-pixel crop"
    )
scaled_data = scaled_data[row0:row0 + psf_npix, col0:col0 + psf_npix]

plt.figure(figsize=(17,7))
plt.subplot(1,2,1)
plt.imshow(scaled_data, norm = norm_psf)
plt.colorbar()
plt.title("Data")
plt.subplot(1,2,2)
plt.imshow(sim_psf, norm = norm_psf)
plt.colorbar()
plt.title("Simulated")

In [None]:
learning_rate = 1e-10

# `loss_fn_poisson(model, data)` defined earlier calls `model.model()` and
# differentiates `params` (which include 'source.position' and
# 'intensity_dist.gauss_param'), so it needs a Telescope, not bare optics.
# The previous version referenced an undefined `param` and passed an
# unsupported `wavelength_center` argument. Wrap the re-initialised optics
# with the same source as run 1 (the data were scaled against that
# Telescope's PSF).
run_instrument = dl.Telescope(optics, ('source', src))
# As in run 1, zdx.get_optimiser takes one optimiser per parameter path.
optim, opt_state = zdx.get_optimiser(
    run_instrument, params, [optax.adam(learning_rate) for _ in params]
)

progress_bar = tqdm(range(1000), desc='Loss: ')

# Run optimisation loop 
net_losses, models = [], []
for i in progress_bar:
    loss, grads = loss_fn_poisson(model = run_instrument, data = scaled_data)
    updates, opt_state = optim.update(grads, opt_state)
    run_instrument = zdx.apply_updates(run_instrument, updates)

    net_losses.append(loss)
    models.append(run_instrument.optics)
    
    progress_bar.set_postfix({'Loss': loss})

# Downstream cells read the fitted system from `optics`.
optics = run_instrument.optics

In [None]:
# Visualise results
plt.figure(figsize=(7,4))
plt.plot(np.array(net_losses))
ax = plt.gca()
ax.set_title("Training History")
ax.set_xlabel("Training Epoch")
ax.set_ylabel("Poisson Log-Likelihood")  # NOTE: the plotted loss is the *negative* log-likelihood

plt.rcParams['figure.figsize'] = (17, 17)
plt.rcParams['image.cmap'] = 'inferno'

plt.figure(figsize=(12,10))
norm_psf = PowerNorm(0.5, vmax=scaled_data.max(), vmin=scaled_data.min())
plt.subplot(2,2,1)
plt.imshow(scaled_data, norm=norm_psf)
plt.colorbar()
plt.title('Data')

plt.subplot(2,2,2)
# NOTE(review): propagate_mono uses the optics alone, with no source flux. `scaled_data` was
# mapped onto the range of `sim_psf`, which (running cells top-to-bottom) is the Telescope PSF
# from the first model cell. Model and data may therefore be on different scales here; check
# this before reading the residuals.
model_psf = optics.propagate_mono(laser_wavelength)
plt.imshow(model_psf, norm=norm_psf)
plt.title('Model')
plt.colorbar()

plt.subplot(2,2,3)
resid = scaled_data - model_psf
plt.imshow(resid, cmap='bwr', vmax = np.abs(resid).max(), vmin = -np.abs(resid).max())
plt.colorbar()
plt.title('Residuals')

plt.subplot(2,2,4)
opd = optics.aperture.eval_basis()  # OPD from the current Zernike coefficients
plt.imshow(opd*transmission)
plt.title('Retrieved Aberrations')
plt.colorbar()

In [None]:
print("Found coefficients for noll idxs: {}\n{}".format(zernike_indicies, optics.aperture.coefficients))

run2_coeffs = optics.aperture.coefficients

# Compare the Zernike coefficients retrieved from the two data sets.
plt.figure(figsize=(10,5))
plt.plot(zernike_indicies,run1_coeffs, label='Run 1')
plt.plot(zernike_indicies, run2_coeffs, label='Run 2')
plt.xlabel('Noll Index')
plt.ylabel('Coefficient')
plt.title("Phase Retrieval 1hr Apart")
plt.legend()  # the curves are labelled, so show which run is which
plt.grid()

#### Let's try different orientation spider 🕷️

In [None]:
# --- Simulate Spider -----------------------------------------------------------------#
# Model for the 180-degree spider data set (02_laser_180deg_mean.png, loaded below).
spider = dlu.spider(coords=coords, width=spider_width, angles=[180]) #TODO try 0 deg for spider because loading in data flips along x-axis
transmission = dlu.combine([circle, spider]) 
transmission *= intensity_dist  # weight the pupil by the measured (blurred, resampled) illumination

# Zernike aberrations
coeffs = np.zeros(zernike_indicies.shape)#run1_coeffs
basis = dlu.zernike_basis(js=zernike_indicies, coordinates=coords, diameter=aperture_diameter)

layers = [
    # Single layer: the illumination is folded into `transmission` above (no fittable Gaussian layer here).
    ('aperture', dl.layers.BasisOptic(basis, transmission, coeffs, normalise=True))
]

# Construct Optics
optics = dl.AngularOpticalSystem(wf_npixels = wf_npixels, 
                                diameter=wf_diam, 
                                layers=layers, 
                                psf_npixels=psf_npix, 
                                psf_pixel_scale=psf_pixel_scale,
                                oversample=oversample)


# Monochromatic PSF from the optics alone (no Telescope/source); used below to scale the data.
sim_psf = optics.propagate_mono(laser_wavelength)
opd = optics.aperture.eval_basis()

# Show setup and transmission results
plt.figure(figsize=(17,5))
plt.subplot(1,3,1)
plt.imshow(transmission)
plt.colorbar()
plt.title('Transmission')
plt.subplot(1,3,2)
norm_psf = PowerNorm(0.5, vmax=sim_psf.max(), vmin=sim_psf.min())
plt.imshow(sim_psf, norm=norm_psf)
plt.title('sqrt PSF (laser)')
plt.colorbar()

plt.subplot(1,3,3)
plt.imshow(opd)
plt.title('Initialised Aberrations')
plt.colorbar(label='OPD (m)')


In [None]:
fname = "data/spider/02_laser_180deg_mean.png"
data = imread(fname, as_gray=True) 

# Scale intensity
data = data**1.2 # non-linear behaviour estimation
# Linearly map the image onto the simulated PSF's [min, max] range.
current_range = data.max() - data.min()
new_range = sim_psf.max() - sim_psf.min()
scaled_data = ( (data - data.min()) * new_range )/current_range + sim_psf.min()

# Crop a psf_npix x psf_npix window centred on the brightest pixel. The crop
# must match the simulated PSF's shape (psf_npixels=psf_npix) because the loss
# compares data and model pixel-by-pixel; a half-size of 50 gave 100 x 100.
psf_center_idx = np.unravel_index(np.argmax(scaled_data, axis=None), scaled_data.shape)
psf_hlf_sz = psf_npix // 2
row0 = psf_center_idx[0] - psf_hlf_sz
col0 = psf_center_idx[1] - psf_hlf_sz
if (row0 < 0 or col0 < 0
        or row0 + psf_npix > scaled_data.shape[0]
        or col0 + psf_npix > scaled_data.shape[1]):
    raise ValueError(
        f"PSF peak at {psf_center_idx} is too close to the image edge for a {psf_npix}-pixel crop"
    )
scaled_data = scaled_data[row0:row0 + psf_npix, col0:col0 + psf_npix]

plt.figure(figsize=(17,7))
plt.subplot(1,2,1)
plt.imshow(scaled_data, norm = norm_psf)
plt.colorbar()
plt.title("Data")
plt.subplot(1,2,2)
plt.imshow(sim_psf, norm = norm_psf)
plt.colorbar()
plt.title("Simulated")

In [None]:
learning_rate = 1e-10

# This model has no source and no Gaussian layer, so fit only the Zernike
# coefficients against the monochromatic PSF. The earlier
# `loss_fn_poisson(model, data)` calls `model.model()` and differentiates the
# Telescope-level `params`, so it cannot be used with bare optics here (the
# previous version also referenced an undefined `param`).
param = 'aperture.coefficients'


@zdx.filter_jit
@zdx.filter_value_and_grad(param)
def loss_fn_poisson_mono(model, data, wavelength_center):
    """Poisson NLL of `data` (counts) given the monochromatic model PSF (mean)."""
    simu_psf = model.propagate_mono(wavelength_center)
    # Continuous extension of -logpmf(k=data, mu=simu_psf); the data are rescaled floats.
    return -(jsp.special.xlogy(data, simu_psf)
             - simu_psf
             - jsp.special.gammaln(data + 1)).sum()


optim, opt_state = zdx.get_optimiser(optics, param, optax.adam(learning_rate)) 

progress_bar = tqdm(range(1000), desc='Loss: ')

# Run optimisation loop 
net_losses, models = [], []
for i in progress_bar:
    loss, grads = loss_fn_poisson_mono(model = optics, data = scaled_data, wavelength_center = laser_wavelength)
    updates, opt_state = optim.update(grads, opt_state)
    optics = zdx.apply_updates(optics, updates)

    net_losses.append(loss)
    models.append(optics)
    
    progress_bar.set_postfix({'Loss': loss})

In [None]:
# Visualise results
plt.figure(figsize=(7,4))
plt.plot(np.array(net_losses))
ax = plt.gca()
ax.set_title("Training History")
ax.set_xlabel("Training Epoch")
ax.set_ylabel("Poisson Log-Likelihood")  # NOTE: the plotted loss is the *negative* log-likelihood

plt.rcParams['figure.figsize'] = (17, 17)
plt.rcParams['image.cmap'] = 'inferno'

plt.figure(figsize=(12,10))
norm_psf = PowerNorm(0.5, vmax=scaled_data.max(), vmin=scaled_data.min())
plt.subplot(2,2,1)
plt.imshow(scaled_data, norm=norm_psf)
plt.colorbar()
plt.title('Data')

plt.subplot(2,2,2)
# Same propagation as the `sim_psf` the data were scaled against, so both are on one scale.
model_psf = optics.propagate_mono(laser_wavelength)
plt.imshow(model_psf, norm=norm_psf)
plt.title('Model')
plt.colorbar()

plt.subplot(2,2,3)
resid = scaled_data - model_psf
plt.imshow(resid, cmap='bwr', vmax = np.abs(resid).max(), vmin = -np.abs(resid).max())
plt.colorbar()
plt.title('Residuals')

plt.subplot(2,2,4)
opd = optics.aperture.eval_basis()  # OPD from the fitted Zernike coefficients
plt.imshow(opd*transmission)
plt.title('Retrieved Aberrations')
plt.colorbar()

In [None]:
run3_coeffs = optics.aperture.coefficients

# Runs 1 and 2 are the 90-degree spider data taken ~1 hr apart; run 3 is the
# 180-degree spider data set.
plt.figure(figsize=(10,5))
plt.plot(zernike_indicies,run1_coeffs, label='Run 1')
plt.plot(zernike_indicies, run2_coeffs, label='Run 2')
plt.plot(zernike_indicies, run3_coeffs, label='Run 3')
plt.xlabel('Noll Index')
plt.ylabel('Coefficient')
plt.title("Phase Retrieval Comparison")
plt.legend()
plt.grid()

#### Let's try ✨Phase Diversity (sort of)✨ Using two different Spider Orientations ☀️
- Phase diversity involves taking a second image with a known aberration applied. Here, the second image instead uses a
different (known) spider orientation, and we optimise against both images simultaneously.

In [None]:
# Loss functions (jit-compiled)
param = 'aperture.coefficients'

# Poisson log-likelihood
@zdx.filter_jit
@zdx.filter_value_and_grad(param)
def loss_fn_poisson(model_one, model_two_psf, data_one, data_two, wavelength_center, model_two=None):
    """
        Poisson negative log-likelihood summed over two PSFs (phase diversity).

        For each image, the observed data are the Poisson counts ``k`` and the
        simulated PSF is the mean ``mu``. The data are rescaled floats rather
        than integer counts, so the continuous extension
        ``k*log(mu) - mu - lgamma(k + 1)`` is evaluated directly. This is
        identical to ``jsp.stats.poisson.logpmf(k, mu)`` for integer ``k``.
        By maximum likelihood estimation, the system is optimised by
        minimising the negative of this, summed over pixels and both images.

        Gradients are taken w.r.t. ``param`` of ``model_one`` only.

        Parameters
        ----------
        model_one : dLux optical system
            Model for the first image; its ``param`` is differentiated.
        model_two_psf : Array
            Pre-computed PSF for the second image. As a plain array it is a
            constant for the gradient, so image two only adds to the loss
            value. Ignored when ``model_two`` is given.
        data_one, data_two : Array
            Observed images, with the same shapes as the simulated PSFs.
        wavelength_center : float
            Wavelength (m) used for monochromatic propagation.
        model_two : dLux optical system, optional
            If given, the second PSF is propagated from ``model_two`` with
            its Zernike coefficients replaced by ``model_one``'s. Both images
            then contribute to the gradient of the shared coefficients.

        Returns
        -------
        (loss, grads) from the ``filter_value_and_grad`` wrapper.
    """
    def _poisson_nll(data, psf):
        # -sum(k*log(mu) - mu - lgamma(k + 1)); xlogy keeps k == 0 pixels finite.
        return -(jsp.special.xlogy(data, psf) - psf - jsp.special.gammaln(data + 1)).sum()

    simu_psf_one = model_one.propagate_mono(wavelength_center)

    if model_two is None:
        simu_psf_two = model_two_psf
    else:
        shared = model_two.set('aperture.coefficients', model_one.aperture.coefficients)
        simu_psf_two = shared.propagate_mono(wavelength_center)

    net_loss = _poisson_nll(data_one, simu_psf_one) + _poisson_nll(data_two, simu_psf_two)

    return net_loss


In [None]:
# --- Simulate Spider -----------------------------------------------------------------#
# Build one optical system per spider orientation (180 deg -> model one,
# 270 deg -> model two). They share the aperture, illumination weighting and
# Zernike basis, and each is plotted in its own row.
circle = dlu.circle(coords=coords, radius=aperture_diameter/2) 

# Zernike aberrations
coeffs = np.zeros(zernike_indicies.shape)
basis = dlu.zernike_basis(js=zernike_indicies, coordinates=coords, diameter=aperture_diameter)

plt.figure(figsize=(10,10))
spider_models = []
for row, spider_deg in enumerate([180, 270]):
    spider = dlu.spider(coords=coords, width=spider_width, angles=[spider_deg]) 
    transmission = dlu.combine([circle, spider]) 
    transmission *= intensity_dist
    layers = [
        ('aperture', dl.layers.BasisOptic(basis, transmission, coeffs, normalise=True))
    ]

    spider_model = dl.AngularOpticalSystem(wf_npixels = wf_npixels, 
                                    diameter=wf_diam, 
                                    layers=layers, 
                                    psf_npixels=psf_npix, 
                                    psf_pixel_scale=psf_pixel_scale,
                                    oversample=oversample)
    spider_models.append(spider_model)

    sim_psf = spider_model.propagate_mono(laser_wavelength)

    plt.subplot(2, 2, 2*row + 1)
    plt.imshow(transmission)
    plt.colorbar()
    plt.title(f'Model {row + 1} Transmission')
    plt.subplot(2, 2, 2*row + 2)
    norm_psf = PowerNorm(0.5, vmax=sim_psf.max(), vmin=sim_psf.min())
    plt.imshow(sim_psf, norm=norm_psf)
    plt.title(f'Model {row + 1} sqrt PSF (laser)')
    plt.colorbar()

# As before, `sim_psf`, `norm_psf`, `transmission` and `layers` are left at model two's values.
model_one, model_two = spider_models


In [None]:
DATA_PSFS = []
fnames = ["data/spider/02_laser_180deg_mean.png", "data/spider/04_laser_90deg_mean.png"]
for fname in fnames:
    data = imread(fname, as_gray=True) 

    # Scale intensity
    data = data**1.2 # non-linear behaviour estimation
    # Map both images onto the range of `sim_psf` (model two's PSF from the cell above).
    current_range = data.max() - data.min()
    new_range = sim_psf.max() - sim_psf.min()
    scaled_data = ( (data - data.min()) * new_range )/current_range + sim_psf.min()

    # Crop a psf_npix x psf_npix window centred on the brightest pixel so each
    # image matches the simulated PSF's shape (a half-size of 50 gave 100 x 100).
    psf_center_idx = np.unravel_index(np.argmax(scaled_data, axis=None), scaled_data.shape)
    psf_hlf_sz = psf_npix // 2
    row0 = psf_center_idx[0] - psf_hlf_sz
    col0 = psf_center_idx[1] - psf_hlf_sz
    if (row0 < 0 or col0 < 0
            or row0 + psf_npix > scaled_data.shape[0]
            or col0 + psf_npix > scaled_data.shape[1]):
        raise ValueError(
            f"{fname}: PSF peak at {psf_center_idx} is too close to the image edge for a {psf_npix}-pixel crop"
        )
    scaled_data = scaled_data[row0:row0 + psf_npix, col0:col0 + psf_npix]
    
    DATA_PSFS.append(scaled_data)  

plt.figure(figsize=(17,7))
plt.subplot(1,2,1)
plt.imshow(DATA_PSFS[0], norm = norm_psf)
plt.title("Data One")
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(DATA_PSFS[1], norm = norm_psf)
plt.title("Data Two")
plt.colorbar()


In [None]:
learning_rate = 1e-11

# Optimise the Zernike coefficients (`param`) of model_one.
optim, opt_state = zdx.get_optimiser(model_one, param, optax.adam(learning_rate)) 
progress_bar = tqdm(range(500), desc='Loss: ')

# Run optimisation loop 
net_losses = []
for i in progress_bar:
    # model_two's PSF is computed here, outside the differentiated function, so image two
    # adds to the loss value but contributes no gradient; only image one drives the update.
    model_two_psf = model_two.propagate_mono(laser_wavelength)
    loss, grads = loss_fn_poisson(model_one = model_one, # grads calculated on model one coeffs but loss on both 
                                  model_two_psf=model_two_psf, 
                                  data_one = DATA_PSFS[0], 
                                  data_two = DATA_PSFS[1],
                                  wavelength_center = laser_wavelength)
    
    # Update model one
    updates, opt_state = optim.update(grads, opt_state)
    model_one = zdx.apply_updates(model_one, updates)

    # Update model two
    # Apply model_one's update to model_two too. Both start from the same zero
    # coefficients, so their coefficients should stay identical.
    model_two = zdx.apply_updates(model_two, updates)

    net_losses.append(loss)
    
    progress_bar.set_postfix({'Loss': loss})


In [None]:
# Visualise results
plt.figure(figsize=(7,4))
plt.plot(np.array(net_losses))
ax = plt.gca()
ax.set_title("Training History")
ax.set_xlabel("Training Epoch")
ax.set_ylabel("Poisson Log-Likelihood")  # NOTE: the plotted loss is the *negative* log-likelihood

plt.rcParams['figure.figsize'] = (17, 17)
plt.rcParams['image.cmap'] = 'inferno'

# One figure per (data, model) pair: data, model, residuals, retrieved OPD.
models = [model_one, model_two]
for i, data in enumerate(DATA_PSFS):
    plt.figure(figsize=(12,10))
    norm_psf = PowerNorm(0.5, vmax=data.max(), vmin=data.min())
    plt.subplot(2,2,1)
    plt.imshow(data, norm=norm_psf)
    plt.colorbar()
    plt.title('Data')

    plt.subplot(2,2,2)
    model_psf = models[i].propagate_mono(laser_wavelength)
    plt.imshow(model_psf, norm=norm_psf)
    plt.title('Model')
    plt.colorbar()

    plt.subplot(2,2,3)
    resid = data - model_psf
    plt.imshow(resid, cmap='bwr', vmax = np.abs(resid).max(), vmin = -np.abs(resid).max())
    plt.colorbar()
    plt.title('Residuals')

    plt.subplot(2,2,4)
    opd = models[i].aperture.eval_basis()
    # Mask the OPD by each model's own (spider-specific) transmission.
    plt.imshow(opd*models[i].aperture.transmission)
    plt.title('Retrieved Aberrations')
    plt.colorbar()