Here we perform PSF estimation and deconvolution on real data, using VLT-MUSE observations of Ganymede as an example.

A few things to consider:
- Can we estimate the Strehl ratio?

In [None]:
# Import library

# Packages required
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from astropy.io import fits
import os
from deconvbench import Deconvbench
from mpl_toolkits.axes_grid1 import make_axes_locatable
from amiral import instructment, utils, parameter, gradient, minimisation, array, plotting
from astropy.visualization import make_lupton_rgb
import scipy


# from plotting import plot_PSF_PSD as amiral_plt
from scipy.optimize import minimize 

%matplotlib inline

import tools

rcParams["figure.figsize"] = 20,33

In [None]:
# Function for getting snr
def snr_map(array, n_sky, n_ron):
    """Per-pixel signal-to-noise ratio map.

    Uses S/N = S / sqrt(S + Sky + Dark + RON^2) with the dark current
    neglected, evaluated independently for every pixel.

    Parameters
    ----------
    array : array_like
        Image of the signal S. Any shape is accepted (the previous version
        allocated a square output from ``shape[0]`` and failed on
        non-square images).
    n_sky : float or array_like
        Sky background level added to the noise variance, e.g. the mean of
        a background region taken from a corner of the image.
    n_ron : float
        Read-out noise standard deviation; it enters the variance squared.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``array``.
    """
    signal = np.asarray(array, dtype=float)
    # Noise variance = photon noise (S) + sky + read-out noise squared
    return signal / np.sqrt(signal + n_sky + n_ron**2)

def get_snr(array):
    """Global signal-to-noise estimate: mean of ``array`` over its standard deviation.

    NaN values are not skipped, so any NaN in ``array`` makes the result NaN.
    """
    sigma = np.std(array)
    return np.mean(array) / sigma

def psd_object(param, dimension=768):
    """Parametric power spectral density model of the object.

    PSD(rho) = param[0] / ((rho / param[1]) ** param[2] + 1)

    Parameters
    ----------
    param : sequence of 3 floats
        (amplitude, knee frequency, power-law exponent). The knee is in the
        units of ``utils.dist`` (presumably frequency pixels; confirm).
        ``param[1]`` must be non-zero.
    dimension : int, optional
        Side length of the square frequency grid. Defaults to 768, the value
        previously hard-coded; pass the actual image size for other data.

    Returns
    -------
    numpy.ndarray
        ``dimension x dimension`` PSD, fftshift-ed so the zero frequency is at
        the array centre (assuming ``utils.dist`` places it at [0, 0]).
    """
    rho = np.fft.fftshift(utils.dist(dimension)) / param[1]
    return param[0] / (np.power(rho, param[2]) + 1.)
    
def plot_psd_object(psd_obj):
    """Plot a horizontal cut through the centre row of a 2-D object PSD.

    The centre row is derived from the array's own shape instead of the
    global ``aosys_cls`` sampling factor and a hard-coded 256, so any centred
    PSD (e.g. the fftshift-ed output of ``psd_object``) can be plotted.
    A red dashed line marks the level 1.

    Parameters
    ----------
    psd_obj : array_like
        2-D PSD with the zero frequency at the centre.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the plot, so callers can customise it.
    """
    _, ax = plt.subplots()
    ycent = np.shape(psd_obj)[0] // 2

    ax.plot(np.abs(psd_obj[ycent, ...]))
    ax.set_title('PSD Object(total)')
    ax.axhline(y=1, color='r', ls='--')

    return ax

def create_psfao19_otf(otf_tel, guess, aosys_cls):
    """Build the PSFAO19 atmospheric OTF and the total (atmosphere x telescope) OTF.

    The residual phase PSD is the AO-corrected part
    (``aosys_cls.psd_residual_ao(guess=guess)``) plus the halo part
    (``aosys_cls.psd_residual_halo(r0=guess[0])``). It is converted to an
    atmospheric OTF by ``aosys_cls.otf_atmo`` and multiplied by ``otf_tel``.

    Returns
    -------
    tuple
        ``(otf_atmo, otf_atmo * otf_tel)``.
    """
    residual_psd = (aosys_cls.psd_residual_ao(guess=guess)
                    + aosys_cls.psd_residual_halo(r0=guess[0]))
    atmo = aosys_cls.otf_atmo(residual_psd)
    return atmo, atmo * otf_tel


def resize_array(array, size):
    """
    Crop a centred ``size`` x ``size`` window out of a 2-D array.

    Row and column centres are computed independently. The previous version
    used the row centre for both axes, which mis-cropped non-square arrays.
    An odd ``size`` now yields exactly ``size`` pixels per side instead of
    ``size - 1``. Even sizes give the same window as before.

    Parameters
    ----------
    array : numpy.ndarray
        2-D input array.
    size : int
        Side length of the crop in pixels.

    Returns
    -------
    numpy.ndarray
        A view of ``array`` of shape (size, size).

    Raises
    ------
    ValueError
        If ``size`` exceeds either dimension of ``array``. Negative start
        indices would otherwise silently wrap around.
    """
    n_rows, n_cols = np.shape(array)[:2]
    if size > n_rows or size > n_cols:
        raise ValueError("crop size %d exceeds array shape %s" % (size, np.shape(array)))

    row0 = n_rows // 2 - size // 2
    col0 = n_cols // 2 - size // 2
    return array[row0:row0 + size, col0:col0 + size]


def zero_padding(array, pad, verbose=False):
    """
    Embed a 2-D array at the centre of a zero array ``pad`` times larger per axis.

    Parameters
    ----------
    array : array_like
        2-D input array.
    pad : int
        Integer enlargement factor applied to each dimension.
    verbose : bool, optional
        If True, print the padded shape. The previous version always printed
        these debug lines.

    Returns
    -------
    numpy.ndarray
        Array of shape (rows * pad, cols * pad). Real inputs are promoted to
        float64 as before. Complex inputs (e.g. FFTs/OTFs) keep their
        imaginary part, which the old float64 buffer silently discarded.
    """
    n_rows, n_cols = np.shape(array)
    padded_rows, padded_cols = n_rows * pad, n_cols * pad

    dtype = np.result_type(np.asarray(array).dtype, float)
    array_resize = np.zeros((padded_rows, padded_cols), dtype=dtype)

    # Offsets that put the input's centre on the padded array's centre
    dx = padded_rows // 2 - n_rows // 2
    dy = padded_cols // 2 - n_cols // 2
    array_resize[dx:dx + n_rows, dy:dy + n_cols] = array

    if verbose:
        print(padded_rows, padded_cols)
        print(array_resize.shape)

    return array_resize



In [None]:
# PATH
# Directory holding the cleaned Ganymede cubes. Set GANYMEDE_DATA_DIR to override the
# user-specific default instead of editing this cell. os.path.join(..., "") guarantees a
# trailing separator because later cells build file names as wdir + name.
wdir = os.path.join(os.environ.get("GANYMEDE_DATA_DIR", "/Users/alau/Data/MUSE_DATA/Ganymede/2019sep08/"), "")
data_cube = ["Ganymede_clean_cube_1", "Ganymede_clean_cube_2"]


Take a look at the raw data cube before performing any analysis.

In [None]:
# Open the second cleaned cube and read its data array from HDU index 1.
# `_cube` stays open because its header (CRVAL3, CD3_3) is read again in later cells.
cube_path = wdir + data_cube[1] + ".fits"
_cube = fits.open(cube_path)
_cube.info()

cube = _cube[1].data

In [None]:
# Compare the SNR

# snr_slice = get_snr(data)
# snr_cube = get_snr(array)

# print(snr_slice,snr_cube)

In [None]:
# Get information from the input image
DIMENSION = np.shape(cube[0])[0]  # side length (pixels) of a spatial slice; assumed square

# Total signal of the first slice. The cleaned cube contains invalid pixels (it is
# passed through np.ma.masked_invalid later), which would make np.sum return NaN,
# so np.nansum skips them. Check the unit and make sure you know what it is.
FLUX = np.nansum(cube[0])

# Reference wavelength of the spectral axis (CRVAL3), converted to nm
# (the header value is presumably in Angstrom; 1e-10 / 1e-9 converts A -> nm)
wvl = _cube[1].header['CRVAL3']*1e-10/1e-9

Extract a single wavelength slice from the cube for analysis.


In [None]:
num = 100  # 0-based index of the spectral slice to analyse
_slice = cube[num]

# FITS linear WCS: lambda(i) = CRVAL3 + (i + 1 - CRPIX3) * CD3_3 for 0-based index i.
# The previous formula implicitly assumed CRPIX3 == 1; reading it (default 1 when absent)
# keeps the wavelength correct for cubes whose reference pixel is not the first slice.
hdr_cube = _cube[1].header
crpix3 = hdr_cube.get('CRPIX3', 1.0)
wvl = (hdr_cube['CRVAL3'] + (num + 1 - crpix3) * hdr_cube['CD3_3']) * 1e-10 / 1e-9  # A -> nm
print(wvl)

In [None]:
# Detector constants for the VLT-MUSE instrument
RON = 15.0   # CCD read-out noise standard deviation [e-]
GAIN = 5.0   # detector gain; units not stated here (presumably e-/ADU, confirm)

Set up the AO-system model from the telescope and instrument parameters.

In [None]:
# Telescope / AO-system configuration. Key names (including the 'no_acutuator'
# spelling) mirror the keyword arguments of instructment.aoSystem.
aosys_dict = dict(
    diameter=8.,          # telescope diameter (presumably metres)
    occ_ratio=0.14,       # central obstruction ratio
    no_acutuator=39,      # number of actuators (spelled as in the amiral API)
    wavelength=wvl,       # [nm]
    dimension=DIMENSION,  # image side length in pixels
    resolution_rad=1.1977272727272726e-07,  # angular pixel scale, presumably rad/pixel (confirm)
)

print(wvl)

# Build the AO-system model; the wavelength is converted from nm to metres here.
aosys_cls = instructment.aoSystem(
    diameter=aosys_dict['diameter'],
    occ_ratio=aosys_dict['occ_ratio'],
    no_acutuator=aosys_dict['no_acutuator'],
    wavelength=aosys_dict['wavelength'] * 1e-9,
    resolution_rad=aosys_dict['resolution_rad'],
    dimension=aosys_dict['dimension'],
)

We need a PSF: set the initial guess for the PSFAO19 model parameters.

In [None]:
# Initial guess for the PSF model parameters. Insertion order matters: later cells
# index the resulting vector by position (first five entries = PSF parameters,
# last three = object-PSD hyper-parameters).
amiral_dict = dict(
    r0=0.25,          # Fried parameter (later bounded to 0.1-0.99; presumably metres)
    background=1e-7,
    amplitude=2.5,
    ax=0.05,
    beta=1.5,
    mu=0.,            # hyper-parameters: overwritten from hyperparam_initial() before the fit
    rho0=0.,
    p=3.0,
)

amiral_keys, psf_guess = utils.dict2array(amiral_dict)

In [None]:
# ft_zoom_array = zoom_array(np.fft.fftshift(ft_array), 50)
# ft_zoom_padded = zoom_array(np.fft.fftshift(ft_data_resize), 50)

# fig, ax = plt.subplots(1,2)

# rcParams['figure.figsize'] = 16,21

# divider = make_axes_locatable(ax[0])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im = ax[0].imshow(np.real(np.log10(np.fft.fftshift(ft_array))),cmap = 'gray')
# fig.colorbar(im,cax ,ax=ax[0])
# ax[0].set_title('Array',fontsize = 18)

# divider = make_axes_locatable(ax[1])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im1 = ax[1].imshow(np.real(np.log10(np.fft.fftshift(ft_data_resize))),cmap = 'gray')
# fig.colorbar(im1, cax ,ax=ax[1])
# ax[1].set_title('Padded array', fontsize = 18)




In [None]:
# Mask invalid (NaN/inf) pixels of the cube and zero them in the first slice
cube_masked = np.ma.masked_invalid(cube)
_ind = np.where(cube_masked[0].mask == True)
cube_masked[0][_ind] = 0

# The fit runs on the slice selected earlier (`_slice`, index `num`), not on
# cube_masked[0]. Its invalid pixels therefore have to be zeroed as well, otherwise
# they would presumably propagate NaNs into the fit criterion. Zero-valued pixels are
# the ones flagged by the `img == 0` dead-pixel test later in the notebook.
img = np.ma.masked_invalid(_slice).filled(0.)
amiralparam = parameter.amiralParam(img ,guess = psf_guess, aosys = aosys_cls)

In [None]:
# Which parameters are free (1) or fixed (0) during the minimisation:
# the five PSF parameters are all free; of the three hyper-parameters, the last one is fixed.
param_mask = np.asarray([1, 1, 1, 1, 1])
hyper_param_mask = np.asarray([1, 1, 0])
mask = np.concatenate((param_mask, hyper_param_mask))

# Initial values and bounds for the hyper-parameters, derived from the current guess
hyper_guess = amiralparam.hyperparam_initial(psf_guess)
hyper_min, hyper_max = amiralparam.hyperparam_bound(psf_guess, p_upperbound=100.)

# Overwrite the last three guess entries (in place) with the hyper-parameter initial values
psf_guess[-3:] = hyper_guess[:3]

# Bounds on the PSF parameters: r, background, sig2, ax, beta
param_min = np.asarray([0.1, 0, 0, 1e-8, 1.01])
param_max = np.asarray([0.99, 1e8, 1e8, 3, 10])
lowerbound = np.concatenate((param_min, hyper_min))
upperbound = np.concatenate((param_max, hyper_max))

# Per-parameter scale factors handed to the optimiser as `numerical_condition`
# (presumably for preconditioning; confirm in amiral)
param_numerical_condition = np.array([1., 1e-4, 1., 1., 1.])
hyperparam_numerical_condition = np.array([hyper_guess[0], hyper_guess[1], 1.])
numerical_condition = np.concatenate((param_numerical_condition, hyperparam_numerical_condition))

amiral_cls = parameter.amiral(
    img=img,
    guess=psf_guess,
    aosys=aosys_cls,
    upperbound=upperbound,
    lowerbound=lowerbound,
    numerical_condition=numerical_condition,
    fourier_variable=amiralparam.fourier_variable,
    mask=mask,
)

In [None]:
# Run the minimisation starting from psf_guess. est_criterion is the estimated parameter
# vector; value_criterion / value_grad are presumably the final criterion value and gradient (confirm).
est_criterion, value_criterion, value_grad = amiral_cls.minimisation(psf_guess)

In [None]:
# Fitted parameter vector (presumably in the same order as psf_guess)
print(est_criterion)

In [None]:
# Residual phase PSD at the fitted parameters: AO-corrected part + halo part
# (halo driven by est_criterion[0]). Not used further in this cell:
# create_psfao19_otf recomputes the PSD itself.
psd_ao = aosys_cls.psd_residual_ao(est_criterion)
psd_halo = aosys_cls.psd_residual_halo(est_criterion[0])
psd = psd_halo + psd_ao

# Telescope OTF from the pupil plane
pupil = aosys_cls.get_pupil_plane()
otf_tel = aosys_cls.pupil_to_otf_tel(pupil)

# PSFAO19 OTFs from the first five (PSF) parameters; the trailing hyper-parameters are not used
est_otf_atmo, est_otf = create_psfao19_otf(otf_tel, est_criterion[:5], aosys_cls)

# Back to the image plane.
# NOTE(review): ifftshift moves a centred OTF's origin to [0, 0] and no fftshift follows,
# so est_psf is presumably peaked at the array corner. Confirm that the later
# binning/plotting/deconvolution expect that convention.
est_otf_origin = np.fft.ifftshift(est_otf)
est_psf = np.fft.ifft2(est_otf_origin)

In [None]:
# Diffraction-limited PSF from the telescope OTF. Apply the same ifftshift as for est_psf
# so both PSFs share one origin convention. Without it, ifft2 of a centred OTF carries a
# phase ramp: the real part alternates sign (even size) or loses amplitude (odd size),
# biasing the ratio below.
psf_tel = np.fft.ifft2(np.fft.ifftshift(otf_tel))

# Strehl ratio estimate: peak of the estimated PSF over the diffraction-limited peak
est_SR = np.max(np.real(est_psf)) / np.max(np.real(psf_tel))

In [None]:
# Display the estimated Strehl ratio
est_SR

Use the estimated parameters to build, bin and plot the PSF.

In [None]:
# # SNR map
# n_ron = np.sqrt(98*15**2)

# _map = snr_map(data_resize,16000., n_ron)

# plt.imshow(_map)

# print(np.min(_map))

In [None]:
def rebin(im, bin):
    """
    Rebin a 2-D image by summing non-overlapping ``bin`` x ``bin`` blocks.

    Adapted from
    https://www.southampton.ac.uk/~sdc1g08/AstropyFitsImageRebin.py

    When a side is not a multiple of ``bin``, the surplus rows/columns are
    dropped from the start (top/left) of the image before binning.

    :param im: Input image, 2D array
    :param bin: bin size in pixels, integer
    :return: Binned image of shape (rows // bin, cols // bin)
    """
    n_rows, n_cols = np.shape(im)[0], np.shape(im)[1]
    trimmed = im[n_rows % bin:, n_cols % bin:]

    blocks = np.reshape(
        trimmed,
        (np.shape(trimmed)[0] // bin, bin, np.shape(trimmed)[1] // bin, bin),
    )
    # Collapse the in-block column axis first, then the in-block row axis
    return blocks.sum(axis=3).sum(axis=1)

# Sum samp_factor x samp_factor pixel blocks of the estimated PSF, presumably to go from the
# model's oversampled grid to the detector pixel scale (confirm).
# NOTE(review): rebin drops leading rows/columns when the size is not a multiple of the bin,
# and est_psf comes from ifft2 without a final fftshift (peak possibly at [0, 0]);
# check the peak survives the trim.
binned_psf = rebin(est_psf, aosys_cls.samp_factor[0])




# # Output the file to a fits file in here
# hdu1 = fits.PrimaryHDU()
# hdu2 = fits.ImageHDU(data=np.real(binned_psf))
# new_hdul = fits.HDUList([hdu1, hdu2])

# hdr = new_hdul[1].header

# new_hdul.info()

# import datetime

# # Get the today's date
# date  = datetime.datetime.now()

# # Add the date to the header

# hdr['r0'] = est_criterion[0]
# hdr['bck']= est_criterion[1]
# hdr['sig2'] = est_criterion[2]
# hdr['ax']  = est_criterion[3]
# hdr['beta']  = est_criterion[4]
# hdr['mu']  = est_criterion[5]
# hdr['rho0']  = est_criterion[6]
# hdr['p']  = est_criterion[7]
# hdr['DATE'] = date.strftime("%Y-%m-%d")
# hdr['wvl(A)'] = 7307

# # the name can input from the .ini file
# array.save_fits(img_obj=new_hdul, name = os.path.join('psf_wvl_7307'))

In [None]:
est_param = est_criterion[0:5]
est_otf_atmo, est_otf_total = create_psfao19_otf(otf_tel,est_param,aosys_cls)


from amiral.plotting import plotting_PSF_PSD
# Plot the circularly averaged profile of the binned estimated PSF
rcParams['figure.figsize'] = 11,9

ycent = int((est_psf.shape[0]//2))

fig, ax = plt.subplots()

# Real data has no ground truth, so report the estimated Strehl ratio instead of the
# "SR Error" / "Flux" placeholders, which were always formatted from hard-coded zeros.
ax.set_title(r"PSF Estimation (with noise)$\mathrm{ \{p_{fixed} \}}$"
             "\nEstimated SR: %.2f\nWvl: %.2fnm" % (est_SR, wvl), fontsize=18)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylabel(r'$\mathrm{Mean Intensity [e^-]}$', fontsize=18)
ax.set_xlabel(r'$\mathrm{Position [pixel]}$', fontsize=18)
ax.plot(utils.mean_cir_array(np.real(binned_psf)), label="Estimated PSF")
ax.tick_params(axis='both', which='major', labelsize=18)
ax.legend()

# fig.savefig('/Users/alau/Data/amiral_fits/VESTA/SNR/2021apr06/deconv/VESTA_PSF_Estimation.pdf')

In [None]:
# Horizontal cut through the centre row of the binned estimated PSF (log intensity).
# Report the estimated SR rather than the former zero-valued placeholders.
plt.title(r"Estimated PSF cut$\mathrm{ \{p_{fixed} \}}$"
          "\nEstimated SR: %.2f\nWvl: %.2fnm" % (est_SR, wvl), fontsize=18)

ycent = int((binned_psf.shape[0]//2))

# binned_psf comes from an inverse FFT and is complex. Plot its real part explicitly;
# matplotlib would otherwise drop the imaginary part with a ComplexWarning.
plt.plot(np.real(binned_psf[ycent, :]))
plt.yscale('log')

In [None]:
# # masking

# def masking (array, pad): 
    
#     (_nx, _ny) = array.shape
    
#     nx = _nx * pad 
#     ny = _ny * pad 
    
#     print(nx,ny)
    
#     mask = np.full((nx,ny),False, dtype=bool)
#     _mask_arr = np.full((_nx,_ny),True, dtype=bool)
    
#     dx = _nx*pad//2 - _nx//2
#     dy = _ny*pad//2 - _ny//2
    
#     mask[dx:dx+_nx,dy:dy+_ny] = _mask_arr
    
#     return mask

# mask = masking(data, aosys_cls.samp_factor[0])

# plt.imshow(mask)



In [None]:
from maoppy.instrument import muse_nfm

# Pixel validity map: True for usable pixels, False where the analysed image is exactly 0.
# (Despite the name, True marks *good* pixels.)
deadMap = np.full(cube_masked[0].shape, True, dtype=bool)
deadMap[img == 0] = False

plt.imshow(deadMap)

In [None]:
# Deconvolve the analysed slice with the binned estimated PSF
import time

start_time = time.perf_counter()  # monotonic clock intended for measuring elapsed time
dec = Deconvbench(img, binned_psf, ron=RON)  # read-out noise from the instrument constants cell

# Optional: give dead pixels zero weight. deadMap is True on *good* pixels, so multiply by
# deadMap itself; the former `~deadMap` would have zeroed every good pixel instead.
# dec.weights *= deadMap

dec.verbose_modulo = 500  # print progress every 500 iterations
dec.regularization.scale *= 4.  # sharpen details (reduce regularization)
# Weights can also act as a mask; the original note says the SNR is computed internally.

estim = dec.run()
runtime = time.perf_counter() - start_time

print("Run Time (mintues): ", runtime / 60)

Now that we have the deconvolved object: is the flux conserved?

In [None]:
# print("Flux of the image [e-]: %f" %(np.sum(data)))
# print("Flux of the deconvolved object: %f" %(np.sum(estim)))
# print("Flux difference [%%]: %f" %(np.sum(data-estim)/np.sum(data_resize)))

def write2header(param, wvl, flux, snr, keys, extra_keys=('FLUX', 'SNR', 'WVL')):
    """Build a FITS header from fitted parameters plus flux, SNR and wavelength.

    Values are written in the order ``*param, flux, snr, wvl``.

    Parameters
    ----------
    param : sequence
        Fitted parameter values (e.g. the minimisation result).
    wvl, flux, snr : float
        Wavelength, total flux and signal-to-noise ratio appended after ``param``.
    keys : sequence of str
        Header keywords. Either one per value (``len(param) + 3``), or one per
        entry of ``param`` only. In the second case ``extra_keys`` names the
        three appended values; the previous version silently dropped them.
    extra_keys : sequence of 3 str, optional
        Keywords used for (flux, snr, wvl) when ``keys`` only covers ``param``.

    Returns
    -------
    astropy.io.fits.Header

    Raises
    ------
    ValueError
        If more keywords than values are supplied.
    """
    values = list(param) + [flux, snr, wvl]

    all_keys = list(keys)
    if len(all_keys) == len(values) - 3:
        all_keys += list(extra_keys)
    if len(all_keys) > len(values):
        raise ValueError("got %d header keywords for %d values" % (len(all_keys), len(values)))

    hdr = fits.Header()
    for key, value in zip(all_keys, values):
        hdr[key] = value

    return hdr

# Keyword arguments make the mapping explicit: the wavelength was previously passed
# positionally into the `snr` slot. Flux and SNR are not computed here, so 0 is stored.
hdr = write2header(est_criterion, wvl=wvl, flux=0., snr=0., keys=amiral_keys)

# Output directory; set GANYMEDE_DECONV_DIR to override the user-specific default
data_dir = os.environ.get("GANYMEDE_DECONV_DIR", "/Users/alau/Data/amiral_fits/ganymede_cube/deconv/")

# fits.writeto(os.path.join(data_dir, "ganymede_" + str(wvl) + "_binned_10.fits"), estim, hdr)

plt.imshow(estim)


In [None]:
# Compare the deconvolved object, the observed slice and their difference,
# each cropped to the central CROP x CROP pixels.
CROP = 128

# Set the size *before* creating the figure. Assigning rcParams after plt.subplots()
# only affected later figures, not this one.
rcParams['figure.figsize'] = 19,21
fig, ax = plt.subplots(1,3)
fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=5.0)

# Crop once and reuse for both the images and the flux numbers in the titles
estim_crop = resize_array(estim, CROP)
img_crop = resize_array(img, CROP)
diff_crop = resize_array(img - estim, CROP)

panels = [
    (estim_crop,
     'L2-L1 Deconvolved with the esitmated PSF''\nFlux [e-]: %.2E''\nWvl: %.3fnm' % (np.sum(estim_crop), wvl)),
    (img_crop,
     'Original image''\nFlux [e-]: %.2E''\nWvl: %.3fnm' % (np.sum(img_crop), wvl)),
    (diff_crop,
     'Difference''\nFlux Diff (%%): %.2f' % (100*np.sum(diff_crop)/np.sum(img_crop))),
]

for axis, (panel, title) in zip(ax, panels):
    divider = make_axes_locatable(axis)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    im = axis.imshow(panel, cmap='gray')
    fig.colorbar(im, cax, ax=axis)
    axis.set_title(title, fontsize=14)  # same font size on all three panels

print(wvl)

fig.savefig('test.pdf')

# fig.savefig('/Users/alau/Data/amiral_fits/VESTA/SNR/2021apr06/deconv/VESTA_noise_free_residual.pdf')

In [None]:
# log10 of the modulus of the centred 2-D Fourier transform of the analysed image.
# The figure size must be set before the figure is created to take effect.
rcParams['figure.figsize'] = 3,3
fig, ax = plt.subplots()
fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=5.0)

ft_img = np.fft.fft2(img)
ax.set_title('log10 |FT(image)|')
ax.imshow(np.log10(np.abs(np.fft.fftshift(ft_img))))

In [None]:
# Circularly averaged modulus of the image's Fourier transform (log-log axes).
# The figure size must be set before the figure is created to take effect.
rcParams['figure.figsize'] = 9,6
fig, ax = plt.subplots()

fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=5.0)

ft_img = np.fft.fft2(img)

ax.plot(utils.mean_cir_array(np.abs(np.fft.fftshift(ft_img))))
# This is the image spectrum (object spectrum times OTF, plus noise), not the OTF itself
ax.set_title('|FT(image)| (circular mean)')
ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
# import time

# start_time = time.time()

# dec = Deconvbench(obs_image,psf_total,ron = 10)
# dec.verbose_modulo = 100 # print every 10 iteration
# dec.regularization.scale *= 2. # sharpen details (reduce regularization)
# estim_1 = dec.run()

# end_time = time.time()
# runtime = end_time - start_time

# print("Run Time (mintues): ",runtime/60)

In [None]:
# perfect_deconv = estim_1

# zoom_diff = zoom_array(estim-perfect_deconv, 200)
# zoom_deconv = zoom_array(perfect_deconv, 200)

# fig, ax = plt.subplots(1,2)

# rcParams['figure.figsize'] = 16,19

# divider = make_axes_locatable(ax[0])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im = ax[0].imshow(zoom_deconv)
# fig.colorbar(im,cax ,ax=ax[0])
# ax[0].set_title('L2-L1 Deconvolved with the exact PSF''\nFlux [e-]: %.2E' %(np.sum(perfect_deconv)), fontsize = 14)

# divider = make_axes_locatable(ax[1])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im1 = ax[1].imshow(zoom_diff)
# fig.colorbar(im1, cax ,ax=ax[1])
# ax[1].set_title('Residual (Estimated_PSF - True_PSF)\nDiff(SR) = %.2f %%\nFlux Difference: %.2f %%\nr0 = %.2f cm sig2 = %.2f'
#              %(diff_SR,np.sum(obs_image-perfect_deconv)/np.sum(obs_image), 100*psf_param[0], psf_param[2]), fontsize = 14)


# print(np.sum(obs_image-perfect_deconv)/np.sum(obs_image))


In [None]:
# fig, ax = plt.subplots()

# rcParams['figure.figsize'] = 33 ,24

# im = ax.imshow(np.abs(estim-perfect_deconv))
# fig.colorbar(im, ax=ax)
# ax.set_title('Exact PSF Residual(Normalised)\nDiff(SR) = 2.16%')

In [None]:
# fig, ax = plt.subplots()

# rcParams['figure.figsize'] = 33 ,24

# im = ax.imshow(asteriod_resize)
# fig.colorbar(im, ax=ax)
# ax.set_title('Object')

In [None]:
# from deconvbench.stat import DSPFit, Circmoyto2D
# from deconvbench import RegulPSD

# rho, psd_param, _, psd1d = DSPFit(obs_image)
# psd1 = Circmoyto2D(rho,psd1d,obs_image.shape[0])

# hyper = est_crtierion[-3:-1]


# psd_object = psd_object(hyper)




# # plot PSD

# rcParams['figure.figsize'] = 13 ,11

# ycent = int((psf_total.shape[0]//2))
    
# fig, ax = plt.subplots()
# ax.set_xscale('log')
# ax.set_yscale('log')
# ax.plot(utils.mean_cir_array(np.real(psd1)), label = "True")
# ax.plot(utils.mean_cir_array(np.real(est_psf)), label = "Estimated PSF")

# ax.legend()

# print(psd_param)




In [None]:
# #%% ITERATIVE DECONVOLUTION
# dec_psd1 = Deconvbench(obs_image, est_psf, ron=10, positivity=False, verbose=True)
# dec_psd1.verbose_modulo = 50
# dec_psd1.regularization = RegulPSD(psd1) # set PSD regularization
# estD1 = dec_psd1.run()

In [None]:
# fig, ax = plt.subplots()

# rcParams['figure.figsize'] = 33 ,24

# im = ax.imshow(estD1)
# fig.colorbar(im, ax=ax)
# ax.set_title('Deconv')

In [None]:
# from deconvbench.stat import DSPFit, Circmoyto2D
# from deconvbench import RegulPSD

# rho, psd_param, _, psd1d = DSPFit(obs_image)
# psd1 = Circmoyto2D(rho,psd1d,obs_image.shape[0])


# #%% ITERATIVE DECONVOLUTION
# dec_psd = Deconvbench(obs_image, psf_total, ron=10, positivity=False, verbose=True)
# dec_psd.verbose_modulo = 50
# dec_psd.regularization = RegulPSD(psd1) # set PSD regularization
# estD = dec_psd.run()

In [None]:
# fig, ax = plt.subplots(1,3)

# rcParams['figure.figsize'] = 33 ,24

# divider = make_axes_locatable(ax[0])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im = ax[0].imshow(estD)
# fig.colorbar(im, cax, ax = ax[0])
# ax[0].set_title('PSD Regularisation (With exact PSF)',fontsize = '18')

# divider = make_axes_locatable(ax[1])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im1 = ax[1].imshow(estD1)
# fig.colorbar(im1, cax, ax = ax[1])
# ax[1].set_title('PSD Regularisation (With estimated PSF)',fontsize = '18')


# divider = make_axes_locatable(ax[2])
# cax = divider.append_axes("right", size="5%", pad=0.05)
# im2 = ax[2].imshow(obs_image)
# fig.colorbar(im2, cax ,ax=ax[2])
# ax[2].set_title('Observed Image',fontsize = '18')




In [None]:
# Load the R, G, B slices and form an RGB image.
# data_path must be defined here: it was commented out, so this cell raised NameError on a fresh kernel.
data_path = ["ganymede_6051_02",  "ganymede_5051_02", "ganymede_4749_77"]

# fits.getdata opens, reads and closes each file (fits.open left the handles open).
# ext=0 reads the primary HDU, as the former `[0].data` did.
R_data = fits.getdata(data_path[0] + "_test" + ".fits", ext=0)
G_data = fits.getdata(data_path[1] + "_test" + ".fits", ext=0)
B_data = fits.getdata(data_path[2] + "_test" + ".fits", ext=0)


In [None]:
# Combine the red, green and blue channels into one colour image with astropy's make_lupton_rgb
image = make_lupton_rgb(R_data, G_data, B_data)
plt.imshow(image)

In [None]:
# Red channel on its own
plt.imshow(R_data)

In [None]:
# Green channel on its own
plt.imshow(G_data)

In [None]:
# Blue channel on its own
plt.imshow(B_data)

In [None]:
import os

# Directory containing the analysed cube file
cube_file = wdir + data_cube[1] + ".fits"
dirname = os.path.dirname(cube_file)
print(dirname)