# Non-common path wavefront sensing with a vector-Zernike wavefront sensor

In this tutorial, we introduce the classical Zernike wavefront sensor (ZWFS) and a way to reconstruct phase aberrations with it. We then introduce the vector-Zernike WFS (vZWFS) and show how it allows simultaneous sensing of phase and amplitude aberrations.

This tutorial assumes that the reader is familiar with propagation through liquid-crystal optics. To learn how this works, follow the "VectorApodizingPhasePlate" tutorial.

We'll start by importing all relevant libraries and setting up the pupil grid and the telescope aperture (the Magellan pupil).

In [None]:
from hcipy import *
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import os

# For notebook animations
from matplotlib import animation
from IPython.display import HTML

In [None]:
pupil_grid = make_pupil_grid(256, 1.5)

# Magellan aperture, normalized to a diameter of 1
aperture = make_magellan_aperture(True)

telescope_pupil = aperture(pupil_grid)

The classical Zernike wavefront sensor is implemented in HCIPy. A ZWFS is a focal-plane optic, but HCIPy implements it as a propagation from pupil plane to pupil plane, similar to the vortex coronagraph. This keeps the calculation fast, because it uses matrix Fourier transforms (MFT).

First, we create the ZWFS optical element. The only required parameter is the pupil grid. Other parameters that influence the performance are:
1. The phase step (`phase_step`). For optimal sensitivity, use $\pi/2$.
2. The phase dot diameter (`phase_dot_diameter`), in units of $\lambda/D$. For optimal sensitivity, use $1.06\,\lambda/D$.
3. The number of pixels used by the MFT (`num_pix`).
4. The pupil diameter.
5. The reference wavelength.
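To see why a $\pi/2$ phase step is a good choice, the sketch below uses the scalar ZWFS intensity model of N'Diaye et al. (2013), $I_c = P^2 + 2b^2(1 - \cos\theta) + 2Pb[\sin\phi\sin\theta - \cos\phi(1 - \cos\theta)]$, with pupil amplitude $P$, reference-wave amplitude $b$, phase step $\theta$ and phase $\phi$. It assumes a fixed $b = 0.5$ (in reality $b$ also depends on the dot diameter). Under that assumption the linear response $\partial I_c/\partial\phi$ at $\phi = 0$ equals $2Pb\sin\theta$, which peaks at $\theta = \pi/2$. This is plain NumPy, not HCIPy code, and the helper names are hypothetical.

```python
import numpy as np


def zwfs_intensity(phi, theta, P=1.0, b=0.5):
    """Scalar ZWFS intensity model for a real reference wave b (assumed fixed)."""
    return (P**2 + 2 * b**2 * (1 - np.cos(theta))
            + 2 * P * b * (np.sin(phi) * np.sin(theta) - np.cos(phi) * (1 - np.cos(theta))))


def linear_sensitivity(theta, P=1.0, b=0.5, dphi=1e-6):
    """Central-difference estimate of dI/dphi at phi = 0."""
    return (zwfs_intensity(dphi, theta, P, b) - zwfs_intensity(-dphi, theta, P, b)) / (2 * dphi)


# Scan phase steps between 0 and pi and find the most sensitive one
thetas = np.linspace(0.05, np.pi - 0.05, 301)
sens = np.array([linear_sensitivity(t) for t in thetas])
best_theta = thetas[np.argmax(sens)]
print(f"best phase step: {best_theta / np.pi:.3f} pi, sensitivity {sens.max():.3f}")
```

The sensitivity follows $\sin\theta$, so the scan lands on $\theta = \pi/2$.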


In [None]:
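# Ideal ZWFS: default phase step (pi/2) and dot diameter (1.06 lambda/D)
ZWFS_ideal = ZernikeWavefrontSensorOptics(pupil_grid)
# Non-ideal ZWFS: 10% phase-step error and a 1.2 lambda/D dot
ZWFS_non_ideal = ZernikeWavefrontSensorOptics(pupil_grid, phase_step=0.9 * np.pi / 2, phase_dot_diameter=1.2)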

In [None]:
def plot_ZWFS(wavefront_in, wavefront_out):
    '''Plot the input wavefront and the ZWFS response.

    Parameters
    ----------
    wavefront_in : Wavefront
        The aberrated wavefront coming in.
    wavefront_out : Wavefront
        The wavefront_in propagated through the ZWFS.
    '''

    # Plot the input amplitude, the input phase and the output intensity
    fig = plt.figure()
    ax1 = fig.add_subplot(131)
    im1 = imshow_field(wavefront_in.amplitude, cmap='gray')
    ax1.set_title('Input amplitude')
    
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im1, cax=cax, orientation='vertical')

    ax2 = fig.add_subplot(132)
    im2 = imshow_field(wavefront_in.phase, cmap='RdBu')
    ax2.set_title('Input phase')
    
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    
    ax3 = fig.add_subplot(133)
    im3 = imshow_field(wavefront_out.intensity, cmap='gray')
    ax3.set_title('Output intensity')
    
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im3, cax=cax, orientation='vertical')
    plt.show()

Now it is time to measure wavefront aberrations with the ZWFS. We create a random phase aberration with a power-law spectrum using `make_power_law_error`, remove its mean (piston) inside the pupil, and propagate it through the ZWFS.

In [None]:
# Random phase screen with a power-law power spectrum
phase_aberrated = make_power_law_error(pupil_grid, 0.3, 1)

# Remove piston inside the pupil
phase_aberrated -= np.mean(phase_aberrated[telescope_pupil >= 0.5])

wf = Wavefront(telescope_pupil * np.exp(1j * phase_aberrated))

wf_out = ZWFS_ideal.forward(wf)

plot_ZWFS(wf, wf_out)
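`make_power_law_error` generates the random phase screen used above. To illustrate the general idea (this is not HCIPy's implementation), a power-law phase screen can be made by filtering complex white noise in the Fourier domain with an amplitude proportional to $f^{\alpha/2}$, so that the power spectral density scales as $f^{\alpha}$. The helper name, the exponent and the RMS normalization below are assumptions for this sketch.

```python
import numpy as np


def power_law_phase_screen(n, exponent=-2.5, rms=0.1, seed=None):
    """Hypothetical helper: periodic n x n phase screen with PSD proportional to f**exponent.

    The mean (piston) is removed and the screen is scaled to the requested RMS.
    """
    rng = np.random.default_rng(seed)
    f = np.fft.fftfreq(n)
    fx, fy = np.meshgrid(f, f)
    fr = np.hypot(fx, fy)
    fr[0, 0] = np.inf  # inf**negative = 0, so the f = 0 (piston) term gets zero weight
    # PSD ~ f**exponent, so the Fourier amplitude ~ f**(exponent / 2)
    filt = fr ** (exponent / 2)
    noise = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    screen = np.fft.ifft2(noise * filt).real
    screen -= screen.mean()
    return screen * (rms / screen.std())


screen = power_law_phase_screen(128, seed=1)
print(screen.shape, screen.std())
```

Most of the power ends up at low spatial frequencies, which is why such screens look smooth with large-scale structure.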

The output intensity clearly depends on the input phase. For a simple reconstruction, we use the algorithm of N'Diaye et al. (2013) [1]:

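\begin{equation}
\phi = -1 + \sqrt{2 I_c},
\end{equation}

where $\phi$ is the phase and $I_c$ is the measured intensity, with the incoming pupil intensity normalized to 1. This expression is a second-order approximation that assumes a phase step of $\pi/2$, a uniform pupil amplitude, and a reference-wave amplitude $b = 0.5$, where $b$ is the amplitude of the light diffracted by the phase dot.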

[1] N'Diaye et al. "Calibration of quasi-static aberrations in exoplanet direct-imaging instruments with a Zernike phase-mask sensor", Astronomy & Astrophysics 555 (2013)
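The reconstruction formula can be checked without an optics simulation. The sketch below assumes the scalar intensity model of N'Diaye et al. [1] with a uniform pupil amplitude $P = 1$, a phase step $\theta = \pi/2$ and $b = 0.5$, for which $I_c = 3/2 + \sin\phi - \cos\phi$. It simulates $I_c$ for a range of phases and applies $\phi = -1 + \sqrt{2 I_c}$. The residual error grows as $\phi^3$, which is why the formula is limited to small aberrations.

```python
import numpy as np

# Assumed model parameters: uniform unit pupil amplitude, pi/2 phase step, b = 0.5
P, b, theta = 1.0, 0.5, np.pi / 2

phi_true = np.linspace(-0.2, 0.2, 41)

# Scalar ZWFS intensity for a real reference wave b
I_c = (P**2 + 2 * b**2 * (1 - np.cos(theta))
       + 2 * P * b * (np.sin(phi_true) * np.sin(theta) - np.cos(phi_true) * (1 - np.cos(theta))))

# Small-aberration reconstruction
phi_est = -1 + np.sqrt(2 * I_c)

error = phi_est - phi_true
print(f"max |error| for |phi| <= 0.2 rad: {np.abs(error).max():.2e} rad")
```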

In [None]:
def plot_reconstruction_phase(phase_in, phase_out, telescope_pupil):
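    '''Plot the incoming aberrated phase, the reconstructed phase and their difference.

    Parameters
    ----------
    phase_in : Field
        The phase of the aberrated wavefront coming in.
    phase_out : Field
        The phase of the aberrated wavefront as reconstructed by the ZWFS.
    telescope_pupil : Field
        The telescope pupil, used to mask the plots and to remove piston from the difference.
    '''

    diff = phase_out - phase_in

    # Remove piston from the difference inside the pupil
    diff -= np.mean(diff[telescope_pupil >= 0.5])

    # Plot the input phase, the reconstructed phase and the difference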
    fig = plt.figure(figsize = (10,10))
    ax1 = fig.add_subplot(131)
    im1 = imshow_field(phase_in, cmap='RdBu', vmin = -0.2, vmax = 0.2, mask = telescope_pupil)
    ax1.set_title('Input phase')
    
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im1, cax=cax, orientation='vertical')

    ax2 = fig.add_subplot(132)
    im2 = imshow_field(phase_out, cmap='RdBu', vmin = -0.2, vmax = 0.2, mask = telescope_pupil)
    ax2.set_title('Reconstructed phase')
    
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    
    ax3 = fig.add_subplot(133)
    im3 = imshow_field(diff, cmap='RdBu', vmin = -0.02, vmax = 0.02, mask = telescope_pupil)
    ax3.set_title('Difference')
    
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im3, cax=cax, orientation='vertical')
    plt.show()

In [None]:
# Small-aberration reconstruction, assuming b = 0.5
phase_est = -1 + np.sqrt(2 * wf_out.intensity)

plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)

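The simple expression above implicitly assumes a reference-wave amplitude of $b = 0.5$ everywhere in the pupil. For a $\pi/2$ phase step and a uniform pupil amplitude, the second-order solution of N'Diaye et al. [1] for a general $b$ is

\begin{equation}
\phi = -1 + \sqrt{3 - 2b - \frac{1 - I_c}{b}},
\end{equation}

which reduces to $\phi = -1 + \sqrt{2 I_c}$ for $b = 0.5$. Below, the reference wave is approximated as $b = \sqrt{S}\, b_0$, where $S$ is the Strehl ratio of the aberrated wavefront and $b_0$ is the reference wave of the aberration-free pupil, obtained by propagating it to the focal plane, keeping only the light that falls on the phase dot, and propagating back.

In [None]: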
# Focal grid with 20 samples per lambda/D, extending to 2 lambda/D
focal_grid = make_focal_grid(20, 2)

prop = FraunhoferPropagator(pupil_grid, focal_grid)

# Mask with the size of the phase dot (1.06 lambda/D)
M = Apodizer(circular_aperture(1.06)(focal_grid))

# Strehl ratio of the aberrated wavefront
S = get_strehl_from_pupil(wf.electric_field, telescope_pupil)

# Reference wave of the aberration-free pupil, scaled by sqrt(S)
b0 = prop.backward(M(prop.forward(Wavefront(telescope_pupil)))).electric_field.real

b = np.sqrt(S) * b0

# General second-order reconstruction
phase_est = -1 + np.sqrt(np.abs(3 - 2 * b - (1 - wf_out.intensity) / b))

phase_est -= np.mean(phase_est[telescope_pupil >= 0.5])

plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)
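The benefit of the general expression shows up when $b$ differs from 0.5. The sketch below uses the same scalar intensity model with $P = 1$ and $\theta = \pi/2$, but with an assumed reference-wave amplitude $b = 0.4$ chosen only for illustration. The simple formula $-1 + \sqrt{2 I_c}$ then returns a biased phase, while $-1 + \sqrt{3 - 2b - (1 - I_c)/b}$ only keeps the small third-order residual, provided $b$ is known.

```python
import numpy as np

P, theta = 1.0, np.pi / 2
b = 0.4  # assumed reference-wave amplitude, different from the ideal 0.5

phi_true = np.linspace(-0.2, 0.2, 41)
I_c = (P**2 + 2 * b**2 * (1 - np.cos(theta))
       + 2 * P * b * (np.sin(phi_true) * np.sin(theta) - np.cos(phi_true) * (1 - np.cos(theta))))

# Simple formula: implicitly assumes b = 0.5
phi_simple = -1 + np.sqrt(2 * I_c)
# General second-order formula with the correct b
phi_general = -1 + np.sqrt(3 - 2 * b - (1 - I_c) / b)

err_simple = np.abs(phi_simple - phi_true).max()
err_general = np.abs(phi_general - phi_true).max()
print(f"max error, simple: {err_simple:.2e} rad, general: {err_general:.2e} rad")
```

In the notebook $b$ varies across the pupil and is only known approximately through $\sqrt{S}\, b_0$, so the real residuals are larger than in this idealized check.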

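So far, the aberrations only affected the phase. To create a wavefront that also contains amplitude aberrations, we take a new power-law phase screen, remove piston, tip and tilt, and propagate it over a short distance with a Fresnel propagator. This propagation converts part of the phase aberrations into intensity variations. The left panel shows the relative intensity variations and the right panel the phase.

In [None]: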
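prop_extra = FresnelPropagator(pupil_grid, distance=1e-4)

# New random phase screen with piston removed inside the pupil
phase_aberrated = make_power_law_error(pupil_grid, 0.3, 1)
phase_aberrated -= np.mean(phase_aberrated[telescope_pupil >= 0.5])

# Project out piston, tip and tilt inside the pupil
zbasis = make_zernike_basis(3, 1, pupil_grid)

for mode in zbasis:
    mode *= telescope_pupil
    phase_aberrated -= mode * np.dot(phase_aberrated, mode) / np.dot(mode, mode)

# A super-Gaussian pupil, np.exp(-(pupil_grid.as_('polar').r / 0.68)**20),
# can be used instead of the telescope pupil to reduce edge effects
p = telescope_pupil

# Short Fresnel propagation converts part of the phase into amplitude aberrations
wf_new = prop_extra(Wavefront(p * np.exp(1j * phase_aberrated)))

# The phase after propagation is the reference for the reconstructions below
phase_aberrated = wf_new.phase

# Remove the light outside the pupil
wf_new.electric_field[telescope_pupil < 0.5] = 0

# Left: relative intensity variations, right: phase
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
imshow_field(wf_new.I * telescope_pupil - 1, vmin=-0.02, vmax=0.02, cmap='gray')
plt.title('Intensity - 1')
plt.subplot(1, 2, 2)
imshow_field(wf_new.phase * telescope_pupil, vmin=-0.2, vmax=0.2, cmap='RdBu', mask=telescope_pupil)
plt.title('Phase')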
plt.show()
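Why does a short Fresnel propagation create amplitude aberrations? For a weak sinusoidal phase ripple $a\sin(2\pi x/\Lambda)$, Fresnel propagation over a distance $z$ produces an intensity ripple with an amplitude of approximately $2a\sin(\pi\lambda z/\Lambda^2)$. The NumPy sketch below checks this with a Fresnel transfer function on a periodic grid. It is not HCIPy's `FresnelPropagator`; the helper name, grid size, wavelength and ripple parameters are arbitrary assumptions.

```python
import numpy as np


def fresnel_propagate(field, dx, wavelength, distance):
    """Hypothetical helper: propagate a periodic 2D complex field with the Fresnel transfer function."""
    ny, nx = field.shape
    fx = np.fft.fftfreq(nx, d=dx)
    fy = np.fft.fftfreq(ny, d=dx)
    FX, FY = np.meshgrid(fx, fy)
    # Fresnel transfer function, without the constant phase factor exp(ikz)
    H = np.exp(-1j * np.pi * wavelength * distance * (FX**2 + FY**2))
    return np.fft.ifft2(np.fft.fft2(field) * H)


n, width = 256, 1.0
dx = width / n
x = np.arange(n) * dx
X, Y = np.meshgrid(x, x)

# Assumed ripple amplitude (rad), ripple period and wavelength, in arbitrary units
a, period, wavelength = 0.01, width / 8, 1e-3
phase = a * np.sin(2 * np.pi * X / period)
field = np.exp(1j * phase)

# Choose z such that pi * wavelength * z / period**2 = pi / 4
z = period**2 / (4 * wavelength)
intensity = np.abs(fresnel_propagate(field, dx, wavelength, z))**2

ripple = (intensity.max() - intensity.min()) / 2
expected = 2 * a * np.sin(np.pi * wavelength * z / period**2)
print(f"intensity ripple: {ripple:.5f}, small-phase prediction: {expected:.5f}")
```

At $z = 0$ the intensity is flat, and the propagation conserves the total power, so the effect is purely a redistribution of light driven by the phase.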


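Next, we measure this wavefront with the classical ZWFS and reconstruct the phase with the general formula. The reconstruction assumes a uniform pupil amplitude, so amplitude aberrations are interpreted as phase.

In [None]: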
wf_out = ZWFS_ideal.forward(wf_new.copy())

# Scale the reference wave b0 from above with the Strehl ratio of the new wavefront
S = get_strehl_from_pupil(wf_new.electric_field, telescope_pupil)
b = np.sqrt(S) * b0

phase_est = -1 + np.sqrt(np.abs(3 - 2 * b - (1 - wf_out.intensity) / b))

plot_reconstruction_phase(phase_aberrated, phase_est, telescope_pupil)

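The vector-Zernike WFS replaces the scalar phase dot with a liquid-crystal dot that imposes phase steps of opposite sign on the two circular polarization states. Splitting the light with a circular polarizing beam splitter therefore yields two ZWFS measurements at once. Their difference is odd in the phase, while their sum depends on the phase only to second order and is dominated by the pupil amplitude, so phase and amplitude can be reconstructed simultaneously. Below we create an ideal vZWFS and a non-ideal one with perturbed parameters.

In [None]: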
vZWFS_ideal = VectorZernikeWavefrontSensorOptics(pupil_grid, num_pix=32)

vZWFS_non_ideal = VectorZernikeWavefrontSensorOptics(pupil_grid, phase_retardation=np.pi * 1.05, phase_step=0.9 * np.pi / 2, phase_dot_diameter=1.2)

In [None]:
def plot_reconstruction_amplitude(amplitude_in, amplitude_out, telescope_pupil):
    '''Plot the incoming aberrated amplitude, the reconstructed amplitude and their difference.

    Parameters
    ----------
    amplitude_in : Field
        The amplitude of the aberrated wavefront coming in.
    amplitude_out : Field
        The amplitude of the aberrated wavefront as reconstructed by the vZWFS.
    telescope_pupil : Field
        The telescope pupil, used to mask the plots.
    '''

    # Show deviations from a uniform unit amplitude
    amplitude_in = amplitude_in - 1
    amplitude_out = amplitude_out - 1

    # Plot the input amplitude, the reconstructed amplitude and the difference
    fig = plt.figure(figsize = (10,10))
    ax1 = fig.add_subplot(131)
    im1 = imshow_field(amplitude_in, cmap='gray', vmin = -0.05, vmax = 0.05, mask = telescope_pupil)
    ax1.set_title('Input amplitude')
    
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im1, cax=cax, orientation='vertical')

    ax2 = fig.add_subplot(132)
    im2 = imshow_field(amplitude_out, cmap='gray', vmin = -0.05, vmax = 0.05, mask = telescope_pupil)
    ax2.set_title('Reconstructed amplitude')
    
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    
    ax3 = fig.add_subplot(133)
    im3 = imshow_field(amplitude_out - amplitude_in, cmap='gray', vmin = -0.01, vmax = 0.01, mask = telescope_pupil)
    ax3.set_title('Difference')
    
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im3, cax=cax, orientation='vertical')
    plt.show()

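The `analyzer` function below splits the vZWFS output into its two circular polarization channels, solves the measured intensities $I_L$ and $I_R$ for the pupil amplitude and phase, and refines the reference wave $b$ a few times using the current estimate of the wavefront.

In [None]: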
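def analyzer(wavefront):
    '''Reconstruct the pupil amplitude and phase from the output of a vZWFS.

    Parameters
    ----------
    wavefront : Wavefront
        The wavefront after propagation through the vZWFS.

    Returns
    -------
    amp_est : Field
        The estimated pupil amplitude.
    phase_est : Field
        The estimated pupil phase.
    '''
    # Split the light into the two circular polarization channels
    CPBS = CircularPolarizingBeamSplitter()
    wf_ch1, wf_ch2 = CPBS.forward(wavefront.copy())
    I_L = wf_ch1.I
    I_R = wf_ch2.I

    # Reference wave of the aberration-free pupil (uses the global focal_grid and prop)
    M = Apodizer(circular_aperture(1.06)(focal_grid))
    b0 = np.abs(prop.backward(M(prop.forward(Wavefront(telescope_pupil)))).electric_field)

    # Shortcut: the Strehl ratio is taken from the true input wavefront (global wf_new)
    S = get_strehl_from_pupil(wf_new.electric_field, telescope_pupil)
    b = np.sqrt(S) * b0

    # Solve for amplitude and phase, then update b from the current wavefront estimate:
    #   P   = sqrt(I_L + I_R + sqrt(4 b^2 (I_L + I_R) - (I_L - I_R)^2 - 4 b^4))
    #   phi = arcsin((I_L - I_R) / (2 P b)), clipped to the valid arcsin domain
    for i in range(4):
        amp_est = np.sqrt(I_L + I_R + np.sqrt(4 * b**2 * (I_R + I_L) - (I_R - I_L)**2 - 4 * b**4))
        phase_est = np.arcsin(np.clip((I_L - I_R) / (2 * amp_est * b), -1, 1))
        wf_est = Wavefront(amp_est * telescope_pupil * np.exp(1j * phase_est))

        b = prop.backward(M(prop.forward(wf_est))).electric_field.real
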
    return amp_est, phase_est
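
The closed-form amplitude solution can be checked on a pointwise model. The sketch below assumes that each circular polarization channel carries half of the light and sees the scalar ZWFS intensity model with phase steps $+\pi/2$ and $-\pi/2$, so that $I_L = \tfrac{1}{2}[P^2 + 2b^2 + 2Pb(\sin\phi - \cos\phi)]$ and $I_R = \tfrac{1}{2}[P^2 + 2b^2 - 2Pb(\sin\phi + \cos\phi)]$, with a known real reference wave $b$. This normalization is the one for which the amplitude formula used in `analyzer` is exact (for $P\cos\phi \geq b$), and the phase then follows from $\phi = \arcsin[(I_L - I_R)/(2Pb)]$. The helper names and the ranges of $P$ and $\phi$ are hypothetical.

```python
import numpy as np


def vzwfs_channels(P, phi, b):
    """Two-channel vZWFS intensities under the assumed model (half the light per channel)."""
    I_L = 0.5 * (P**2 + 2 * b**2 + 2 * P * b * (np.sin(phi) - np.cos(phi)))
    I_R = 0.5 * (P**2 + 2 * b**2 - 2 * P * b * (np.sin(phi) + np.cos(phi)))
    return I_L, I_R


def invert_channels(I_L, I_R, b):
    """Amplitude from the analyzer formula; phase from arcsin((I_L - I_R) / (2 P b))."""
    amp = np.sqrt(I_L + I_R + np.sqrt(4 * b**2 * (I_L + I_R) - (I_L - I_R)**2 - 4 * b**4))
    phase = np.arcsin((I_L - I_R) / (2 * amp * b))
    return amp, phase


rng = np.random.default_rng(0)
P_true = rng.uniform(0.9, 1.1, 1000)     # assumed amplitude errors of +-10%
phi_true = rng.uniform(-0.3, 0.3, 1000)  # assumed phase errors in radians
b = 0.5

I_L, I_R = vzwfs_channels(P_true, phi_true, b)
P_est, phi_est = invert_channels(I_L, I_R, b)
print(np.abs(P_est - P_true).max(), np.abs(phi_est - phi_true).max())
```

With a known $b$ the inversion is exact to rounding error; in the full simulation $b$ is unknown, which is why `analyzer` re-estimates it from the current wavefront.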
    

In [None]:
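# Reconstruct amplitude and phase from the ideal vZWFS measurement
amp_est, phase_est = analyzer(vZWFS_ideal(wf_new.copy()))

# Compare in intensity: square the estimated amplitude
int_est = amp_est**2

plot_reconstruction_phase(wf_new.phase, phase_est, telescope_pupil)
plot_reconstruction_amplitude(wf_new.I, int_est, telescope_pupil)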

