In [None]:
import getpass
import importlib.util
import json
import os
import subprocess
import sys
from time import perf_counter

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sb
from IPython.display import FileLink
from ipywidgets import widgets

# CasJobs credentials: reuse CASJOBS_USERID / CASJOBS_PW from the environment when
# they are already set (mastcasjobs.MastCasJobs is constructed later without
# explicit credentials and reads these variables), otherwise prompt for them.
# Never hard-code them in the notebook: anything written here is shared with it.
if not os.environ.get('CASJOBS_USERID'):
    os.environ['CASJOBS_USERID'] = input('Enter Casjobs UserID: ')
if not os.environ.get('CASJOBS_PW'):
    # getpass keeps the password out of the cell output.
    os.environ['CASJOBS_PW'] = getpass.getpass('Enter Casjobs password: ')

from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.stats import gaussian_fwhm_to_sigma, gaussian_sigma_to_fwhm
from astropy.modeling import models, fitting
from astropy.visualization import LogStretch
from astropy.visualization.mpl_normalize import ImageNormalize
from astropy import wcs
from astropy.io import fits
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Global plot styling: seaborn's "dark" theme, 12 pt fonts and square 10x10 figures.
sb.set_style('dark')
matplotlib.rcParams.update({'font.size': 12, 'figure.figsize': (10, 10)})

In [None]:
# Install mastcasjobs only when it is not importable by this kernel. A local
# 'mastcasjobs' directory says nothing about whether the package is installed,
# and `sys.executable -m pip` targets the interpreter actually running the kernel.
if importlib.util.find_spec('mastcasjobs') is None:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet",
         "git+https://github.com/rlwastro/mastcasjobs@master"],
        check=True,  # surface install failures here instead of at the next import
    )
    # Make the freshly installed package visible to the import system.
    importlib.invalidate_caches()

In [None]:
import mastcasjobs

In [None]:
def mastQuery(request, json_return=False, timeout=60):
    """Perform a MAST API query.

    Parameters
    ----------
    request : dict
        The MAST request object. It is JSON-encoded and posted as the
        ``request`` form field to the MAST ``invoke`` endpoint.
    json_return : bool, optional
        If True, return the decoded JSON response; otherwise the raw text.
    timeout : float or None, optional
        Seconds to wait for MAST before ``requests`` gives up, so a stalled
        server cannot hang the notebook indefinitely. ``None`` waits forever.

    Returns
    -------
    dict or str
        The decoded JSON (``json_return=True``) or the response text.

    Raises
    ------
    requests.HTTPError
        If MAST answers with an HTTP error status.
    """
    url = "https://mast.stsci.edu/api/v0/invoke"

    response = requests.post(url, data={'request': json.dumps(request)}, timeout=timeout)
    response.raise_for_status()

    return response.json() if json_return else response.text


def resolve(name):
    """Get the RA and Dec of an object using the MAST name resolver.

    Parameters
    ----------
    name : str
        Object name (or coordinates) understood by ``Mast.Name.Lookup``.

    Returns
    -------
    tuple
        ``(ra, dec)`` taken from the first resolved coordinate.

    Raises
    ------
    ValueError
        If the resolver returned no usable coordinate for ``name``.
    """
    resolver_request = {'service': 'Mast.Name.Lookup',
                        'params': {'input': name,
                                   'format': 'json'},
                        }
    resolved = mastQuery(resolver_request, json_return=True)

    # The resolver returns a variety of information about the object; only the
    # first resolved coordinate's RA and Dec are needed. A missing key, a null
    # list or an empty list all mean the name could not be resolved.
    try:
        coordinate = resolved['resolvedCoordinate'][0]
        return (coordinate['ra'], coordinate['decl'])
    except (IndexError, KeyError, TypeError) as exc:
        raise ValueError("Unknown object '{}'".format(name)) from exc

In [None]:
class PSF_gen():
    """Simulated field: GALEX catalog sources rendered as Gaussian PSFs on a pixel grid.

    Construction resolves ``name`` with the MAST name resolver, queries the
    ``gcat_asc`` table (CasJobs context ``GALEX_Catalogs``) for sources covering
    the simulated image, builds a TAN WCS centred on the target and keeps the
    sources whose stamps fit inside the image. Call :meth:`generate_psf` to
    render ``self.image``, then :meth:`show_image` or :meth:`writeto`.

    Parameters
    ----------
    name : str
        Object name or coordinates passed to the MAST name resolver.
    pixel_scale : float
        Pixel scale in arcseconds/pixel.
    fwhm_in : float
        PSF FWHM in arcseconds; stored in pixels as ``self.fwhm``.
    n_pix_main : int
        Side length (pixels) of the square output image.
    n_pix_sub : int
        Side length (pixels) of the square stamp each source is rendered into.
    """

    def __init__(self,name,pixel_scale=0.1,fwhm_in=0.2,n_pix_main=8000,n_pix_sub=200):
        self.name = name
        self.ra, self.dec = resolve(name)
        self.pixel_scale = pixel_scale
        self.fwhm = fwhm_in/pixel_scale  # FWHM in pixels
        self.n_pix_main = n_pix_main
        self.n_pix_sub = n_pix_sub
        self.jobs = mastcasjobs.MastCasJobs(context="GALEX_Catalogs")
        # The query box follows the image footprint; a fixed +/-0.1 deg box left
        # larger fields (bigger pixel_scale or n_pix_main) partly empty.
        self.query = self._build_query()
        self.df = self.jobs.quick(self.query, task_name="python cone search").to_pandas()
        self.wcs = self.create_wcs()
        self.df_field = self.create_field()
        self.image = None

    def _half_widths(self, margin_pix=0):
        """Return ``(half_ra, half_dec)`` in degrees of the field shrunk by ``margin_pix`` pixels.

        An RA offset dRA spans dRA*cos(dec) on the sky, so the RA half-width is
        divided by cos(dec); it is capped at 180 deg so it stays finite at the poles.
        """
        half_dec = 0.5*(self.n_pix_main - margin_pix)*self.pixel_scale/3600
        cos_dec = np.cos(np.radians(self.dec))
        half_ra = min(half_dec/max(cos_dec, 1e-10), 180.0)
        return half_ra, half_dec

    def _build_query(self):
        """Build the CasJobs SQL box query covering the whole simulated image."""
        half_ra, half_dec = self._half_widths()
        ra_lo, ra_hi = self.ra - half_ra, self.ra + half_ra
        if ra_lo < 0 or ra_hi > 360:
            # The box straddles RA = 0/360: select both sides of the wrap.
            ra_clause = f"(ra >= {ra_lo % 360} OR ra <= {ra_hi % 360})"
        else:
            ra_clause = f"ra BETWEEN {ra_lo} AND {ra_hi}"
        return f"""SELECT ra,dec,mag_nuv,mag_fuv
                   FROM
                   gcat_asc
                   WHERE
                   {ra_clause}
                   AND
                   dec BETWEEN {self.dec - half_dec} AND {self.dec + half_dec}
                """

    def create_wcs(self):
        """Return a 2-D RA---TAN/DEC--TAN WCS centred on the target.

        The reference pixel is the image centre and both axes use
        ``pixel_scale`` arcsec/pixel.
        NOTE(review): CDELT1 is positive, so RA increases with column index;
        FITS images conventionally use a negative RA CDELT - confirm intended.
        """
        w = wcs.WCS(naxis=2)
        w.wcs.crpix = [self.n_pix_main//2,self.n_pix_main//2]
        w.wcs.cdelt = np.array([self.pixel_scale/3600, self.pixel_scale/3600])
        w.wcs.crval = [self.ra, self.dec]
        w.wcs.ctype = ["RA---TAN", "DEC--TAN"]
        return w

    def create_field(self):
        """Return the catalog rows whose full stamp should fit inside the image.

        Sources are kept when they lie more than ``n_pix_sub/2`` pixels from
        every edge (in RA/Dec, with the cos(dec) factor and RA wrap handled).
        """
        half_ra, half_dec = self._half_widths(margin_pix=self.n_pix_sub)
        # Signed RA offset wrapped into [-180, 180) so fields across RA=0/360 work.
        d_ra = (self.df['ra'] - self.ra + 180) % 360 - 180
        d_dec = self.df['dec'] - self.dec
        inside = (d_ra.abs() < half_ra) & (d_dec.abs() < half_dec)
        return self.df[inside]

    def show_field(self,figsize=(10,10)):
        """Scatter-plot the selected field sources in RA/Dec and return ``(fig, ax)``."""
        fig, ax = plt.subplots(1,1,figsize=figsize)
        ax.scatter(self.df_field['ra'],self.df_field['dec'],marker='.',color='black')
        ax.set_title(f" Requested Center : {self.name} \n FoV : {np.round(self.pixel_scale*self.n_pix_main/3600,3)} degrees | {len(self.df_field)} sources")
        ax.invert_xaxis()  # larger RA on the left
        ax.set_xlabel('RA (Degrees)')
        ax.set_ylabel('Dec (Degrees)')
        plt.show()
        return fig, ax

    @staticmethod
    def _add_stamp(image, stamp, top, left):
        """Add ``stamp`` into ``image`` with its [0, 0] pixel at (top, left), clipped to the image."""
        height, width = stamp.shape
        r0, c0 = max(top, 0), max(left, 0)
        r1 = min(top + height, image.shape[0])
        c1 = min(left + width, image.shape[1])
        if r0 >= r1 or c0 >= c1:
            return  # stamp lies entirely outside the image
        image[r0:r1, c0:c1] += stamp[r0 - top:r1 - top, c0 - left:c1 - left]

    def generate_psf(self, mag_column='mag_nuv'):
        """Render every field source as a flux-normalised 2-D Gaussian into ``self.image``.

        Each source's flux is ``3631*10**(-mag/2.5)`` (AB zero point, 3631 Jy)
        from ``mag_column``; the Gaussian (sigma from ``self.fwhm``) is scaled so
        its sum equals that flux and is centred on the source's nearest pixel.
        Rows whose flux is not finite (missing or sentinel magnitudes) are skipped.
        Prints "Patch width is too small" when the image total differs from the
        catalog total by more than 0.1% (stamp truncation, clipping at the edges,
        or an undersampled PSF).

        Parameters
        ----------
        mag_column : str, optional
            Catalog column holding the AB magnitude to render (default NUV).
        """
        n_sub = self.n_pix_sub
        half = n_sub//2
        sigma_psf = self.fwhm*gaussian_fwhm_to_sigma

        # All sources share the same PSF shape: render one unit-flux stamp and scale it.
        y, x = np.mgrid[0:n_sub, 0:n_sub]
        unit_stamp = models.Gaussian2D(1/(2*np.pi*sigma_psf**2), half, half, sigma_psf, sigma_psf)(x, y)

        mags = self.df_field[mag_column].to_numpy(dtype=float)
        with np.errstate(over='ignore', invalid='ignore'):
            fluxes = 3631*10**(-mags/2.5)
        # NaN magnitudes, or large negative sentinels that overflow to inf
        # (NOTE(review): GALEX is believed to use -999 for non-detections), are dropped.
        valid = np.isfinite(fluxes)
        fluxes = fluxes[valid]

        image = np.zeros((self.n_pix_main, self.n_pix_main))
        if fluxes.size:
            # Vectorised world -> (row, col) conversion, rounded to the nearest pixel.
            rows, cols = self.wcs.world_to_array_index_values(
                self.df_field['ra'].to_numpy(dtype=float)[valid],
                self.df_field['dec'].to_numpy(dtype=float)[valid])
            for flux, row, col in zip(fluxes, rows, cols):
                self._add_stamp(image, flux*unit_stamp, row - half, col - half)

        if not np.isclose(image.sum(), fluxes.sum(), rtol=1e-3, atol=0.0):
            print("Patch width is too small")
        self.image = image

    def show_image(self, vmin=1e-11, vmax=5.5e-9):
        """Display ``self.image`` on WCS axes and return ``(fig, ax)``.

        Prints "Generate PSF" and returns None if :meth:`generate_psf` has not run.
        ``vmin``/``vmax`` set the display stretch limits.
        """
        if self.image is None:
            print("Generate PSF")
            return None
        fig = plt.figure(figsize = (15,10))
        ax = fig.add_subplot(projection=self.wcs)
        ax.imshow(self.image.astype(np.float32), cmap='gray', vmin=vmin, vmax=vmax)
        ax.set_title(f'Requested center : {self.name}\n Pixel Scale : {self.pixel_scale} arcseconds/pixel | FWHM = {self.fwhm*self.pixel_scale} arcsecs')
        ax.invert_xaxis()
        plt.show()
        return fig,ax

    def writeto(self, name, overwrite=False):
        """Write ``self.image`` (as float32) with its WCS header to ``f'{name}.fits'``.

        Prints "Generate PSF" and writes nothing if :meth:`generate_psf` has not run.
        ``overwrite`` is passed to ``HDUList.writeto``.
        """
        if self.image is None:
            print("Generate PSF")
            return None
        # The WCS must be put in the header: assigning an attribute on the HDU
        # does not serialise it into the file.
        hdu = fits.PrimaryHDU(self.image.astype(np.float32), header=self.wcs.to_header())
        fits.HDUList([hdu]).writeto(f'{name}.fits', overwrite=overwrite)


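# **PSF Generator using MAST CasJobs**

Enter a target name or coordinates, the detector pixel scale and the PSF FWHM, then press **Submit** to query GALEX catalog sources around the target and render each one as a Gaussian PSF.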


In [None]:
# Input form: a prompt label above each text box, followed by the Submit button.
l1 = widgets.Label("Enter Source name or coordinates(Eg. M 67, 06h 03m 20s 42 18 00)")
name = widgets.Text()
l2 = widgets.Label("Enter Pixel scale of detector (arcseconds/pixels)")
ps = widgets.Text()
l3 = widgets.Label("Enter FWHM of the PSF (arcseconds)" )
fwhm = widgets.Text()
btn = widgets.Button(description="Submit")
display(l1, name, l2, ps, l3, fwhm, btn)

# Follow-up actions; the submit callback displays this row after a PSF is generated.
btn_show_image = widgets.Button(description = "Show Image")
btn_dn = widgets.Button(description = "Download FITS")
buttons = widgets.HBox([btn_show_image,btn_dn])

# Shared output area that the button callbacks write into.
output = widgets.Output()

def _parse_positive_float(text):
    """Return ``text`` parsed as a finite positive float, or ``None`` if it is not one.

    This replaces splitting on '.' and checking ``isnumeric``, which accepted
    "1.2.3" (crashing ``float``) and "0" (dividing by zero) but rejected "1e-1".
    """
    try:
        value = float(text)
    except (TypeError, ValueError):
        return None
    return value if np.isfinite(value) and value > 0 else None


@output.capture(clear_output=True,wait=True)
def submit(b):
    """Submit-button callback: build a PSF_gen from the form inputs and render it.

    Blank or invalid pixel-scale / FWHM entries fall back to 0.1 arcsec/pixel and
    0.3 arcsec. On success the field is plotted, the PSF image is generated,
    the object is stored in the global ``psf_`` for the other buttons, and the
    Show Image / Download FITS buttons are displayed.
    """
    global psf_
    display(widgets.Label("Loading sources in the field...",font_size=75))

    start = perf_counter()
    pixel_scale = _parse_positive_float(ps.value)
    if pixel_scale is None:
        pixel_scale = 0.1
        display(widgets.Label("Default Pixel Scale : 0.1 set"))

    fwhm_in = _parse_positive_float(fwhm.value)
    if fwhm_in is None:
        fwhm_in = 0.3
        display(widgets.Label("Default FWHM : 0.3 set"))

    # Only reject empty input; short names such as "M1" are valid and the
    # resolver reports genuinely unknown names itself.
    source_name = name.value.strip()
    if not source_name:
        display(widgets.Label("Incorrect source Syntax"))
        return

    try:
        psf = PSF_gen(source_name, pixel_scale, fwhm_in, n_pix_main=8000, n_pix_sub=200)
    except (ValueError, KeyError, IndexError) as err:
        # Unresolvable names end up here; show a message instead of a traceback.
        display(widgets.Label(f"Could not load source '{source_name}': {err}"))
        return

    psf.show_field()
    psf.generate_psf()
    psf_ = psf
    time_taken = perf_counter() - start
    display(widgets.Label(f"Time taken : {np.round(time_taken,3)} seconds"))
    display(buttons)
    

def show(b):
    """Show-Image button callback: draw the most recent PSF image in the output area."""
    with output:
        display(widgets.Label("Loading Image..."))
        psf_.show_image()
    
def download(b):
    """Download-FITS button callback; FITS export is not wired up yet.

    TODO: write the image with ``psf_.writeto(psf_.name)`` and display an
    ``IPython.display.FileLink`` to the resulting ``.fits`` file.
    """
    with output:
        display(widgets.Label("Coming Soon.."))
    
# Attach each button to its callback, then render the shared output area as the cell result.
for button, callback in ((btn, submit), (btn_show_image, show), (btn_dn, download)):
    button.on_click(callback)
output

Label(value='Enter Source name or coordinates(Eg. M 67, 06h 03m 20s 42 18 00)')

Text(value='')

Label(value='Enter Pixel scale of detector (arcseconds/pixels)')

Text(value='')

Label(value='Enter FWHM of the PSF (arcseconds)')

Text(value='')

Button(description='Submit', style=ButtonStyle())

Output()