In [12]:
import getpass
import importlib.util
import json
import os
import subprocess
import sys
from time import perf_counter

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sb
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from IPython.display import FileLink, display
from ipywidgets import widgets
from matplotlib import colors as col
from photutils import aperture as aper


# CasJobs credentials: use CASJOBS_USERID / CASJOBS_PW from the environment
# when they are set, otherwise prompt for them interactively. Never hardcode
# credentials in a notebook; any password previously committed here should be
# considered leaked and rotated.
if not os.environ.get('CASJOBS_USERID'):
    os.environ['CASJOBS_USERID'] = input('Enter CasJobs UserID: ')
if not os.environ.get('CASJOBS_PW'):
    os.environ['CASJOBS_PW'] = getpass.getpass('Enter CasJobs password: ')

# Plot style shared by every figure in this notebook.
sb.set_style('darkgrid')
matplotlib.rcParams['font.size'] = 12
matplotlib.rcParams['figure.figsize'] = (10, 10)

In [2]:
# Install mastcasjobs only when it is not importable. Checking for a local
# 'mastcasjobs' path (the previous guard) does not tell whether the package is
# installed. `sys.executable -m pip` installs into the kernel's own environment.
# NOTE(review): @master is unpinned; pin a commit/tag for reproducibility.
if importlib.util.find_spec('mastcasjobs') is None:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet",
         "git+https://github.com/rlwastro/mastcasjobs@master"],
        check=True,
    )

In [3]:
import mastcasjobs

In [39]:
class PSF_gen():
    """PSF Generator using Source catalog or CasJobs GALEX Catalog.

    Builds a simulated detector readout. Every catalogue source is drawn as a
    normalised PSF patch scaled by its NUV AB magnitude. Shot noise, sky, QE,
    PRNU, dark current, bias/read noise, gain and a full-well clip are then
    applied (see ``__call__``).

    Parameters
    ----------
    name (string) : Catalog name or coordinates of the source. It is resolved
    with the MAST name resolver and used for a GALEX CasJobs box query when
    ``df`` is None.

    df   (pd.DataFrame ): Pandas dataframe with source catalog

    cols (dict) : dict object with column name conversions for ra,dec,mag_nuv.
    Eg {'RA': 'ra','Dec' : 'dec', 'ABmag' : 'mag_nuv'}

    'ra' (degrees)

    'dec' (degrees)

    'mag_nuv' (ABmag)

    size (float) : Side of the square CasJobs search box in degrees. The
    query spans ``size/2`` on either side of the resolved position.

    exp_time (float) : Exposure time in seconds.

    axis (str) : 'on' or 'off'. Selects the on-/off-axis PSF and zero point.

    mode (str) : PSF model, 'Zeemax' or 'HCIPy'. Matching is case-insensitive,
    and 'zmax'/'zemax' are accepted as aliases of 'Zeemax'.

    Attributes
    ----------
    pixel_scale (float) : pixel scale of detector in arcseconds/pixel.

    n_pix_main (int) : Number of pixels along one axis of the detector image
    (updated to the cropped size during ``__call__``).

    n_pix_sub (int) : Number of pixels along one axis of the PSF patch.

    Raises
    ------
    ValueError : unknown mode/axis, or neither ``name`` nor ``df`` supplied.
    """

    # PSF patch file for each (canonical mode, axis) pair.
    _PSF_FILES = {
        ('Zeemax', 'on'):  'data/On_PSF_Zmax.npy',
        ('Zeemax', 'off'): 'data/Off_PSF_Zmax.npy',
        ('HCIPy', 'on'):   'data/on_axis_hcipy.npy',
        ('HCIPy', 'off'):  'data/off_axis_hcipy.npy',
    }
    # Accepted spellings of `mode` (matched lower-case) -> canonical name.
    # 'zmax' is the constructor default, so it must resolve to a real PSF.
    _MODE_ALIASES = {'zeemax': 'Zeemax', 'zemax': 'Zeemax', 'zmax': 'Zeemax',
                     'hcipy': 'HCIPy'}
    # Largest detector side (pixels) that init_image_array will allocate.
    _MAX_N_PIX = 10000
    # When an image side (after removing one PSF-patch border) exceeds this,
    # a second border is removed. compute_coeff_arrays and generate_photons
    # must agree on it so the coefficient maps match the photon image.
    _CROP_THRESHOLD = 205

    def __init__(self, name=None, df=None, cols=None, size=0.01,
                 exp_time=600, axis='on', mode='zmax'):

        self.axis = axis
        self.mode = mode
        self.name = name
        self.cols = cols

        self.radius = size/2              # half-width of the catalogue query box (deg)
        self.pixel_scale = 0.1            # arcsec / pixel
        self.qe = 0.5                     # mean quantum efficiency
        self.bias = 733                   # mean of bias_array
        self.gain = 1.85                  # (charge + bias) / gain gives ADU
        self.bit_res = 16
        # NOTE(review): a 16-bit ADC saturates at 2**16 - 1 = 65535; kept at 2**16.
        self.full_well_capacity = pow(2, self.bit_res)

        self.M_sky = 27.5
        # Sky magnitude per pixel: subtract 2.5*log10(pixel area in arcsec^2),
        # so M_sky is presumably in mag/arcsec^2.
        self.M_sky_p = self.M_sky - 2.5*np.log10(self.pixel_scale**2)
        self.RN = 3                       # std of bias_array (read noise)
        self.exp_time = exp_time          # seconds
        self.df = df

        self.params = {'shot_noise' : 'gaussian',
                       'sky'        : 'gaussian',
                       'PRNU'       :  True,
                       'DC'         :  True,
                       'T'          :  218,        # K
                       'DFM'        :  1.424e-2,   # 14.24pA
                       'pixel_area' :  1e-6,       # cm2
                       'DNFP'       :  False,
                       'DN'         :  0.4}
        self.init_df()
        self.init_psf_patch()

    def init_psf_patch(self, return_psf=False):
        """Load the normalised PSF patch and set the flux zero points.

        Sets ``image_g_sub`` (PSF patch normalised to sum 1), ``zero_flux``
        (photons from a 0 mag source over ``exp_time``) and ``sky_bag_flux``
        (sky photons per pixel over ``exp_time``).

        Parameters
        ----------
        return_psf (bool) : if True, also return ``image_g_sub * zero_flux``.

        Raises
        ------
        ValueError : ``mode``/``axis`` do not name a known PSF file.
        """
        # Photons from a 0 AB mag source during exp_time. The factors are,
        # presumably: 1.51e3 photons/s/cm^2/Jy per unit dlambda/lambda, 3631 Jy
        # (AB zero point), area of a 100 cm aperture, bandwidth ratio and
        # optical throughputs. Confirm against the instrument model.
        self.zero_mag_s_on =  self.exp_time*1.51e3*3631*np.pi*(100/2)**2*(1500/2250)*0.8**6*0.95**2*0.68*0.83 # Photons
        self.zero_mag_s_off = self.exp_time*1.51e3*3631*np.pi*(100/2)**2*(1500/2250)*0.8**5*0.95**2*0.83      # Photons

        canonical_mode = self._MODE_ALIASES.get(str(self.mode).lower())
        psf_file = self._PSF_FILES.get((canonical_mode, self.axis))
        if psf_file is None:
            raise ValueError(f"Invalid Mode or axis: mode={self.mode!r}, axis={self.axis!r}")

        image = np.load(psf_file)
        image = image / image.sum()          # out-of-place: also safe for integer arrays
        self.image_g_sub = image
        self.zero_flux = self.zero_mag_s_on if self.axis == 'on' else self.zero_mag_s_off
        self.sky_bag_flux = self.zero_flux*pow(10, -0.4*self.M_sky_p)

        if return_psf:
            return image*self.zero_flux

    def mastQuery(self, request, json_return=False, timeout=60):
        """
        Perform a MAST query.

        Parameters
        ----------
        request (dictionary): The MAST request json object

        json_return (bool): return the decoded JSON instead of the raw text

        timeout (float): seconds to wait for the server before giving up

        Returns the text response or (if json_return=True) the json response

        Raises requests.HTTPError on an error status.
        """
        url = "https://mast.stsci.edu/api/v0/invoke"

        # MAST expects the request as a json string in the form field 'request'.
        request_string = json.dumps(request)
        r = requests.post(url, data=dict(request=request_string), timeout=timeout)
        r.raise_for_status()

        return r.json() if json_return else r.text

    def resolve(self, name):
        """Get the RA and Dec for an object using the MAST name resolver

        Parameters
        ----------
        name (str): Name of object

        Returns RA, Dec tuple with position

        Raises ValueError if the resolver returns no coordinate."""
        resolverRequest = {'service': 'Mast.Name.Lookup',
                           'params': {'input': name,
                                      'format': 'json'},
                           }
        resolvedObject = self.mastQuery(resolverRequest, json_return=True)
        # Only the first resolved coordinate's RA/Dec are needed. A missing key
        # or an empty list both mean the name was not resolved.
        try:
            coord = resolvedObject['resolvedCoordinate'][0]
            objRa, objDec = coord['ra'], coord['decl']
        except (IndexError, KeyError, TypeError) as e:
            raise ValueError("Unknown object '{}'".format(name)) from e
        return (objRa, objDec)

    def init_df(self):
        """Set ``self.df``, ``self.ra`` and ``self.dec``.

        If a name is given and no DataFrame, the name is resolved and the
        GALEX catalogue is queried in a box of half-width ``self.radius``.
        Otherwise the supplied DataFrame is used (columns renamed via
        ``self.cols``), and its RA/Dec mid-range becomes the field centre.

        Raises ValueError when neither a name nor a DataFrame is available.
        """
        if self.name is not None and self.df is None:
            self.ra, self.dec = self.resolve(self.name)
            print("Input : MASTCasJobs")
            self.jobs = mastcasjobs.MastCasJobs(context="GALEX_Catalogs")
            self.query =   f"""SELECT ra,dec, mag_nuv,mag_fuv 
                              FROM
                              gcat_asc
                              WHERE
                              ra BETWEEN {self.ra} - {self.radius} AND {self.ra}   
                              + {self.radius}
                              AND 
                              dec BETWEEN {self.dec}- {self.radius} AND {self.dec} 
                              + {self.radius}
                          """
            print("Generating Dataframe")
            self.df = self.jobs.quick(self.query, task_name="python cone search").to_pandas()

        elif self.df is None:
            raise ValueError("Provide either a source name or a catalog DataFrame (df).")

        else:
            print("Input : Dataframe.")
            if self.cols is not None:
                self.df = self.df.rename(columns=self.cols)
            self.ra = (self.df['ra'].max()+self.df['ra'].min())/2
            self.dec = (self.df['dec'].max()+self.df['dec'].min())/2
            self.name = f" RA : {np.round(self.ra,3)} degrees, Dec : {np.round(self.dec,3)} degrees"

    def init_image_array(self, return_img=False):
        """Allocate the zero detector image and its WCS.

        The image side is the larger of the catalogue's RA/Dec extents in
        pixels, plus a PSF-patch border on each side.

        Returns (image, wcs) when ``return_img`` is True.
        Raises ValueError when the side would exceed ``_MAX_N_PIX``.
        """
        self.n_pix_sub = self.image_g_sub.shape[0]
        del_ra = self.df.ra.max() - self.df.ra.min()
        del_dec = self.df.dec.max() - self.df.dec.min()

        # NOTE(review): the RA extent is not scaled by cos(dec); confirm intended.
        n_pix_main = max(del_ra, del_dec)*3600/self.pixel_scale
        self.n_pix_main = int(n_pix_main) + 2*self.n_pix_sub

        if self.n_pix_main > self._MAX_N_PIX:
            raise ValueError("FoV is too big.")

        self.image = np.zeros((self.n_pix_main, self.n_pix_main))
        self.wcs = self.create_wcs(self.n_pix_main, self.ra, self.dec, self.pixel_scale)

        if return_img:
            return self.image, self.wcs

    def create_wcs(self, npix, ra, dec, pixel_scale):
        """
        Create a TAN-projection WCS for an ``npix`` x ``npix`` image centred on
        (``ra``, ``dec``) in degrees, with ``pixel_scale`` in arcsec/pixel
        (RA increasing to the left).
        """
        w = WCS(naxis=2)
        # NOTE(review): FITS CRPIX is 1-based, so (npix-1)//2 puts the
        # reference about one pixel below the geometric centre. Kept as-is
        # because changing it shifts every simulated source position.
        w.wcs.crpix = [(npix-1)//2, (npix-1)//2]
        w.wcs.cdelt = np.array([-pixel_scale/3600, pixel_scale/3600])
        w.wcs.crval = [ra, dec]
        w.wcs.ctype = ["RA---TAN", "DEC--TAN"]
        return w

    def compute_coeff_arrays(self):
        """Draw the per-pixel detector maps used by ``__call__``.

        The maps are QE, bias (+read noise), and optionally PRNU, dark current
        and dark-noise fixed pattern. Their size matches the cropped image
        returned by ``generate_photons``. The random draw order is unchanged,
        so seeded runs reproduce.

        Raises ValueError if ``qe`` is outside (0, 1].
        """
        if not 0 < self.qe <= 1:
            raise ValueError('QE should in the range (0,1]')

        # Same crop rule as generate_photons: one patch border removed, and a
        # second one when the result is still larger than _CROP_THRESHOLD.
        n_pix = self.n_pix_main - self.n_pix_sub
        if n_pix > self._CROP_THRESHOLD:
            n_pix -= self.n_pix_sub
        shape = (n_pix, n_pix)

        self.qe_array = np.random.normal(loc=self.qe, scale=0.01, size=shape)
        self.bias_array = np.random.normal(loc=self.bias, scale=self.RN, size=shape)

        if self.params['PRNU']:
            self.PRNU_array = np.random.normal(loc=0, scale=0.02, size=shape)

        if self.params['DC']:
            self.DR = self.dark_current(self.params['T'], self.params['DFM'],
                                        self.params['pixel_area'])
            # Mean dark signal DR*exp_time with a Poisson-like spread.
            self.DC_array = np.random.normal(loc=self.DR*self.exp_time,
                                             scale=np.sqrt(self.DR*self.exp_time),
                                             size=shape)

        if self.params['DNFP'] and self.params['DC']:
            self.DNFP_array = np.random.lognormal(mean=0,
                                                  sigma=self.exp_time*self.DR*self.params['DN'],
                                                  size=shape)
            self.DC_array *= (1 + self.DNFP_array)

        print("Coefficients computed...")

    def dark_current(self, T, DFM, pixel_area):
        """Dark-current rate from an empirical silicon model.

        Parameters
        ----------
        T (float) : temperature in K.
        DFM (float) : dark-current figure of merit (labelled "14.24pA" in params).
        pixel_area (float) : pixel area in cm^2.

        Returns the rate DR. ``compute_coeff_arrays`` multiplies it by
        ``exp_time``, so it is presumably electrons per second per pixel.
        """
        Kb = 8.62e-5                 # Boltzmann constant, eV/K
        const = 2.55741439581387e15

        # Temperature-dependent band gap (eV), Varshni form Eg0 - a*T^2/(T+b).
        EgT = 1.1557 - (7.021e-4*T**2/(1108+T))
        DR = const*pixel_area*(T**1.5)*DFM*np.exp(-EgT/(2*Kb*T))
        return DR

    def generate_photons(self, image, npix_m, npix_s, df):
        """
        Add a scaled PSF patch for every catalogue source to ``image`` (in
        place) and return the image cropped of its patch borders.

        Each source with NUV AB magnitude m gets ``zero_flux * 10**(-m/2.5)``
        photons distributed as ``image_g_sub`` (an ``npix_s`` x ``npix_s``
        patch). Its centre is at the source pixel given by ``self.wcs``.
        ``npix_m`` is accepted for interface compatibility but unused.
        """
        # Odd patch widths put the extra pixel on the right/bottom.
        patch_width_l = npix_s//2
        patch_width_r = npix_s - patch_width_l

        # Convert all positions at once instead of one SkyCoord per row.
        coords = SkyCoord(df['ra'], df['dec'], unit=u.deg)
        pix_rows, pix_cols = self.wcs.world_to_array_index(coords)

        for r, c, ab_mag in zip(np.atleast_1d(pix_rows), np.atleast_1d(pix_cols), df['mag_nuv']):
            # Photons in the exposure (zero_flux already includes exp_time).
            flux = self.zero_flux*10**(-ab_mag/2.5)
            image[r - patch_width_l: r + patch_width_r,
                  c - patch_width_l: c + patch_width_r] += flux*self.image_g_sub

        # Remove one patch border, and a second one for large images; the
        # same rule sizes the arrays in compute_coeff_arrays.
        crop = slice(patch_width_l - 1, -patch_width_r - 1)
        image = image[crop, crop]
        if image.shape[0] > self._CROP_THRESHOLD:
            image = image[crop, crop]

        return image

    def compute_shot_noise(self, array, type_='gaussian'):
        """Return ``array`` with photon shot noise applied.

        Parameters
        ----------
        array : expected counts. A scalar (NumPy or plain Python number) is
        expanded to an ``n_pix_main`` x ``n_pix_main`` map.
        type_ (str) : 'gaussian' (normal, sigma = sqrt(mean)) or 'poisson'.
        Any other value disables the noise and returns ``array`` unchanged.
        """
        if np.ndim(array) == 0:
            size = (self.n_pix_main, self.n_pix_main)
        else:
            size = np.shape(array)

        if type_ == 'gaussian':
            shot_noise = np.random.normal(loc=array, scale=np.sqrt(array), size=size)
        elif type_ == 'poisson':
            shot_noise = np.random.poisson(lam=array, size=size).astype(np.asarray(array).dtype)
        else:
            print('Light shot noise disabled')
            shot_noise = array

        return shot_noise

    def __call__(self, params=None, n_stack=1, stack_type='median'):
        """Simulate ``n_stack`` exposures and combine them.

        Parameters
        ----------
        params (dict) : overrides merged into ``self.params``.
        n_stack (int) : number of exposures to simulate (>= 1).
        stack_type (str) : 'median' or 'mean' combination when n_stack > 1.

        Returns the readout image in ADU (also stored in ``self.readout``).
        The detector maps are drawn once and shared by all exposures.

        Raises ValueError for n_stack < 1 or an unknown stack_type.
        """
        if n_stack < 1:
            raise ValueError("n_stack must be at least 1")
        if params is not None:
            self.params.update(params)
        readout_stack = []

        self.init_image_array()
        self.compute_coeff_arrays()

        for _ in range(n_stack):
            self.init_image_array()

            self.photon_array = self.generate_photons(self.image, self.n_pix_main, self.n_pix_sub, self.df)
            self.photoelec_array = self.photon_array*self.qe_array

            self.shot_noise = self.compute_shot_noise(self.photoelec_array, type_=self.params['shot_noise'])
            self.photoelec_array = self.shot_noise.copy()

            # The photon image is cropped; later scalar noise maps use this size.
            self.n_pix_main = self.photoelec_array.shape[0]

            if self.params['sky']:
                self.sky_array = self.compute_shot_noise(self.sky_bag_flux, 'gaussian')
                self.sky_array *= self.qe_array
                self.light_array = self.photoelec_array + self.sky_array
            else:
                self.light_array = self.photoelec_array

            if self.params['PRNU']:
                self.light_array *= (1 + self.PRNU_array)

            # Without dark current the collected charge is just the light.
            if self.params['DC']:
                self.charge = self.light_array + self.DC_array
            else:
                self.charge = self.light_array

            readout = (self.charge + self.bias_array)/self.gain   # ADU

            # Full well: clip every exposure before it is stacked.
            self.readout = np.minimum(readout, self.full_well_capacity)
            readout_stack.append(self.readout)

        if n_stack > 1:
            combine = {'median': np.median, 'mean': np.mean}.get(stack_type)
            if combine is None:
                raise ValueError(f"stack_type must be 'median' or 'mean', got {stack_type!r}")
            self.readout = combine(np.array(readout_stack), axis=0)

        self.wcs = self.create_wcs(self.n_pix_main, self.ra, self.dec, self.pixel_scale)
        self.header = self.wcs.to_header()

        print("Image generation completed!")
        return self.readout

In [40]:
class PSF(PSF_gen):
    """PSF_gen plus aperture photometry, plotting and FITS export.

    Calling an instance simulates the field (``PSF_gen.__call__``). It then
    measures every catalogue source in a 0.3 arcsec radius aperture, using a
    0.3-1 arcsec annulus around that source for the sky. The results are
    stored in ``self.phot_table``.
    """

    # Aperture radius and outer sky-annulus radius, in arcseconds.
    _APERTURE_RADIUS = 0.3
    _ANNULUS_OUTER = 1

    def __init__(self, name=None, df=None, cols=None, size=0.01,
                 exp_time=100, axis='on', mode='zmax'):
        super().__init__(name, df, cols, size, exp_time, axis, mode)

    def __call__(self, params=None, n_stack=1, stack_type='median'):
        """Simulate the field, then run aperture photometry on every source.

        Arguments are forwarded to ``PSF_gen.__call__``. The photutils table
        is stored in ``self.phot_table``, with the columns ``sky_flux``,
        ``flux``, ``flux_err``, ``SNR``, ``mag_in``, ``mag_0.3`` and
        ``mag_err`` added.

        Returns the simulated readout (``self.readout``).
        """
        super().__call__(params=params, n_stack=n_stack, stack_type=stack_type)

        # NOTE(review): these zero-point constants are not derived in this
        # notebook, and adding the bias before multiplying by the gain mixes
        # units. Confirm their provenance.
        if self.axis == 'off':
            zero_p_flux = (3010478142.88666 + self.bias)*self.gain
        else:
            zero_p_flux = (1516505736.205873 + self.bias)*self.gain

        data = self.readout
        coords = SkyCoord(self.df['ra'], self.df['dec'], unit=u.deg)
        pix_rows, pix_cols = self.wcs.world_to_array_index(coords)
        # photutils positions are (x, y) = (column, row).
        positions = list(zip(np.atleast_1d(pix_cols), np.atleast_1d(pix_rows)))

        r_ap = self._APERTURE_RADIUS/self.pixel_scale
        r_out = self._ANNULUS_OUTER/self.pixel_scale
        aperture = aper.CircularAperture(positions, r=r_ap)
        annulus = aper.CircularAnnulus(positions, r_in=r_ap, r_out=r_out)

        # Median sky per pixel from each source's own annulus. 'center' masks
        # make get_values return unweighted pixel values.
        masks = annulus.to_mask(method='center')
        if not isinstance(masks, list):
            masks = [masks]
        sky_per_pix = np.array([np.median(m.get_values(data)) for m in masks])

        phot_table = aper.aperture_photometry(data, [aperture, annulus])

        phot_table['sky_flux'] = np.pi*r_ap**2*sky_per_pix
        phot_table['flux'] = phot_table['aperture_sum_0'].value - phot_table['sky_flux'].value
        phot_table['flux_err'] = np.sqrt(phot_table['flux'].value + phot_table['sky_flux'].value)
        phot_table['SNR'] = phot_table['flux'].value/phot_table['flux_err'].value

        phot_table['mag_in'] = self.df['mag_nuv'].to_numpy()
        phot_table['mag_0.3'] = -2.5*np.log10(phot_table['flux'].value/(zero_p_flux*self.exp_time))
        # sigma_m = (2.5 / ln 10) * sigma_F / F
        phot_table['mag_err'] = 2.5/np.log(10)*phot_table['flux_err'].value/phot_table['flux'].value
        self.phot_table = phot_table
        return self.readout

    def show_field(self, figsize=(10, 10)):
        """
        Function for creating a scatter plot of sources within the FoV

        Returns
        -------
        fig, ax
        """
        if getattr(self, 'wcs', None) is None:
            self.init_image_array()

        # Field size from the catalogue extent. n_pix_main cannot be used
        # because __call__ replaces it with the cropped image size.
        fov = max(self.df['ra'].max() - self.df['ra'].min(),
                  self.df['dec'].max() - self.df['dec'].min())

        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ax.scatter(self.df['ra'], self.df['dec'], marker='.', color='black')
        ax.set_title(f" Requested Center : {self.name} \n FoV : {np.round(fov,3)} degrees | {len(self.df)} sources")
        ax.invert_xaxis()
        ax.set_xlabel('RA (Degrees)')
        ax.set_ylabel('Dec (Degrees)')
        return fig, ax

    def _source_data(self, source):
        """Return the simulation array named by ``source``.

        The accepted names are Readout, Sky, DC, QE, Bias (bias + dark), PRNU
        and DNFP. Raises ValueError for any other name.
        """
        getters = {
            'Readout': lambda: self.readout,
            'Sky':     lambda: self.sky_array,
            'DC':      lambda: self.DC_array,
            'QE':      lambda: self.qe_array,
            'Bias':    lambda: self.bias_array + self.DC_array,
            'PRNU':    lambda: self.PRNU_array,
            'DNFP':    lambda: self.DNFP_array,
        }
        if source not in getters:
            raise ValueError(f"Unknown source {source!r}; expected one of {sorted(getters)}")
        return getters[source]()

    def show_image(self, source='Readout', figsize=(15, 10)):
        """
        Function for plotting the simulated field with PSFs

        Returns
        -------
        fig, ax
        """
        data = self._source_data(source)
        norm = col.LogNorm() if source in ('Readout', 'DNFP') else None

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(projection=self.wcs)
        ax.patch.set_edgecolor('black')
        ax.patch.set_linewidth(3)
        img = ax.imshow(data, cmap='jet', norm=norm)
        fig.colorbar(img, ax=ax)
        ax.set_title(f'{source} \nRequested center : {self.name}')
        ax.grid(False)
        plt.show()
        return fig, ax

    def show_hist(self, source='Readout', bins=None, figsize=(15, 8)):
        """Plot a log-count histogram of the array named by ``source``.

        Returns
        -------
        fig, ax
        """
        data = np.ravel(self._source_data(source))
        fig, ax = plt.subplots(1, 1, figsize=figsize)

        if bins is None:
            # 19 equal-width bins over [min, max] give the same edges as
            # np.linspace(min, max, 20). Unlike explicit edges, this also
            # works when every value is identical.
            bins = 19
        ax.hist(data, bins=bins)
        ax.set_title(f'{source} histogram')
        ax.set_ylabel('Count')
        ax.set_yscale('log')
        plt.show()
        return fig, ax

    def writeto(self, name):
        """
        Write the simulated charge image, with its WCS header, to the FITS
        file ``name`` (overwriting it). Prints "Generate PSF" if no simulation
        has been run yet.

        NOTE(review): this writes ``self.charge`` (before bias/gain), not the
        ADU ``self.readout`` that show_image displays. Confirm intended.
        """
        charge = getattr(self, 'charge', None)
        if charge is None:
            print("Generate PSF")
            return
        hdu = fits.PrimaryHDU(charge, header=self.header)
        fits.HDUList([hdu]).writeto(str(name), overwrite=True)

# **PSF Simulator by INSIST**


In [41]:
# ---- Simulation input form ------------------------------------------------
# Source: free-text name/coordinates, or a CSV upload.
l1 = widgets.Label("Enter Source name or coordinates(Eg. M 67, 06h 03m 20s 42 18 00) or Upload CSV")
name = widgets.Text()
btn_up = widgets.FileUpload(multiple=False)
b1 = widgets.HBox([name, widgets.Label("\t Or "), btn_up])

# Field size in degrees (free text; `submit` validates it).
l2 = widgets.Label("Enter size of field (degrees)")
size = widgets.Text()

# PSF model and on/off-axis selection.
l3 = widgets.Label("Choose mode:")
mode = widgets.Dropdown(options=["HCIPy", "Zeemax"], value='HCIPy')
l4 = widgets.Label("Choose axis:")
axis = widgets.Dropdown(options=["off", "on"], value='off')

# Exposure time in seconds (free text; `submit` validates it).
l5 = widgets.Label("Enter exposure time (seconds): ")
exp_time = widgets.Text()

# Action buttons. The Show Image / Download pair is displayed by `submit`.
btn = widgets.Button(description="Submit")
btn_show_image = widgets.Button(description="Show Image")
btn_dn = widgets.Button(description="Download FITS")
buttons = widgets.HBox([btn_show_image, btn_dn])

# Render the form top to bottom.
display(l1, b1, l2, size, l3, mode, l4, axis, l5, exp_time, btn)

# Shared area that the button handlers write into.
output = widgets.Output()

@output.capture(clear_output=True, wait=True)
def submit(b):
    """Handle the "Submit" button: validate the form, then run the simulator.

    Empty or invalid fields fall back to defaults (source M 67, size 0.02 deg,
    exposure 600 s), and a label announces each default used. The simulator
    instance is stored in the global ``psf_`` so the "Show Image" handler and
    later cells can use it. A ValueError from the simulator (e.g. an
    unresolvable source name) is shown as a label instead of a traceback.
    """
    global psf_

    def _positive_float(text):
        # None unless `text` parses as a finite number greater than zero.
        try:
            value = float(text)
        except ValueError:
            return None
        return value if np.isfinite(value) and value > 0 else None

    display(widgets.Label("Loading sources in the field...", font_size=75))
    start = perf_counter()

    # Any non-empty text is a valid name or coordinate string for the resolver.
    name_in = name.value.strip()
    if not name_in:
        name_in = 'M 67'
        display(widgets.Label("Default source : M 67 set"))

    size_in = _positive_float(size.value)
    if size_in is None or size_in < 0.0002:
        size_in = 0.02
        display(widgets.Label("Default size : 0.02 set"))

    exp_time_in = _positive_float(exp_time.value)
    if exp_time_in is None:
        exp_time_in = 600
        display(widgets.Label("Default exp_time : 600 seconds set"))

    if len(btn_up.value) != 0:
        # Catalogue upload is not implemented yet.
        display(widgets.Label("CSV upload is not supported yet; clear the upload and enter a source name."))
        return

    try:
        psf = PSF(name=name_in, size=size_in, mode=mode.value, axis=axis.value, exp_time=exp_time_in)
        psf_ = psf
        psf()
    except ValueError as err:
        display(widgets.Label(f"Error: {err}"))
        return

    time_taken = perf_counter() - start
    display(widgets.Label(f"Time taken : {np.round(time_taken,3)} seconds"))
    display(buttons)


def show(b):
    """Handle "Show Image": plot the readout of the last simulation.

    Shows a hint instead of raising NameError when Submit has not run yet.
    """
    with output:
        if 'psf_' not in globals():
            display(widgets.Label("Run Submit first."))
            return
        display(widgets.Label("Loading Image..."))
        psf_.show_image()

def download(b):
    """Placeholder handler for the "Download FITS" button.

    FITS export is not wired up yet, so this only shows a notice in the
    shared output area.
    TODO: call psf_.writeto(...) and display an IPython FileLink to the file.
    """
    notice = widgets.Label("Coming Soon..")
    with output:
        display(notice)

# Connect each button to its handler, then render the shared output area.
for button, handler in ((btn, submit), (btn_show_image, show), (btn_dn, download)):
    button.on_click(handler)

output

Label(value='Enter Source name or coordinates(Eg. M 67, 06h 03m 20s 42 18 00) or Upload CSV')

HBox(children=(Text(value=''), Label(value='\t Or '), FileUpload(value={}, description='Upload')))

Label(value='Enter size of field (degrees)')

Text(value='')

Label(value='Choose mode:')

Dropdown(options=('HCIPy', 'Zeemax'), value='HCIPy')

Label(value='Choose axis:')

Dropdown(options=('off', 'on'), value='off')

Label(value='Enter exposure time (seconds): ')

Text(value='')

Button(description='Submit', style=ButtonStyle())

Output()

In [42]:
# Inspect the simulator created by the last "Submit" click. The guard keeps
# Restart & Run All from failing before the widget has been used.
vars(psf_) if 'psf_' in globals() else "Click 'Submit' above to create psf_ first."

{'axis': 'on',
 'mode': 'HCIPy',
 'name': 'M 67',
 'cols': None,
 'radius': 0.01,
 'pixel_scale': 0.1,
 'qe': 0.5,
 'bias': 733,
 'gain': 1.85,
 'bit_res': 16,
 'full_well_capacity': 65536,
 'M_sky': 27.5,
 'M_sky_p': 32.5,
 'RN': 3,
 'exp_time': 600,
 'df':            ra        dec   mag_nuv   mag_fuv
 0  132.832869  11.811535  17.48029  16.84529,
 'params': {'shot_noise': 'gaussian',
  'sky': 'gaussian',
  'PRNU': True,
  'DC': True,
  'T': 218,
  'DFM': 0.01424,
  'pixel_area': 1e-06,
  'DNFP': True,
  'DN': 0.4},
 'ra': 132.83387,
 'dec': 11.81196,
 'jobs': <mastcasjobs.MastCasJobs at 0x7fd23a9a2190>,
 'query': 'SELECT ra,dec, mag_nuv,mag_fuv \n                              FROM\n                              gcat_asc\n                              WHERE\n                              ra BETWEEN 132.83387 - 0.01 AND 132.83387   \n                              + 0.01\n                              AND \n                              dec BETWEEN 11.81196- 0.01 AND 11.81196 \n        