In [None]:
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter

from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata.utils import Cutout2D


In [None]:
def round_to_significant_digit(value, error):

    ''' Round an error to one significant digit and a value to the same decimal place.

    Parameters:
    value: float or int
        the value to be rounded
    error: float or int
        the error of the value; the position of its first significant digit
        sets the rounding precision of both numbers
    Returns:
    rounded_value: float
        the value rounded to the decimal place of the rounded error's first
        significant digit
    rounded_error: float
        the error rounded to one significant digit

    Non-finite errors (nan, inf) carry no precision information; in that case
    both inputs are returned unrounded, converted to float.
    '''

    value = float(value)
    error = float(error)

    # f"{nan:e}" is just "nan" (no exponent part), so parsing the exponent
    # below would raise IndexError; bail out before that.
    if not np.isfinite(error):
        return value, error

    def _decimals_for_first_digit(x):
        # Number of decimal places that keeps exactly the first non-zero digit
        # of x, read from the exponent of its scientific notation
        # (negative => round to tens, hundreds, ...).
        # NOTE(review): x == 0 gives 0 decimals, i.e. the value is rounded to
        # an integer when the error is exactly zero — confirm that is intended.
        return -int(f"{x:e}".split('e')[1])

    # Round the error to one significant digit
    rounded_error = round(error, _decimals_for_first_digit(error))

    # Recompute the precision from the rounded error in case rounding carried
    # into the next decade (e.g. 0.96 -> 1.0 should round the value to 0 decimals)
    rounded_value = round(value, _decimals_for_first_digit(rounded_error))

    return rounded_value, rounded_error

In [None]:
def _centered_cutout(img_data, x_center, y_center, size):
    ''' Return a (2*size, 2*size) cutout of img_data around (x_center, y_center).

    Pixels that fall outside img_data are filled with NaN, so a source near
    (or beyond) the mosaic edge still yields a full-size array.
    '''
    cutout = np.full((2 * size, 2 * size), np.nan)

    # Use floor (not int() truncation) so a partially-off-image source keeps the
    # same pixel-to-cutout mapping as an in-bounds one; for in-bounds positions
    # floor and int() agree.
    y0 = int(np.floor(y_center - size))
    x0 = int(np.floor(x_center - size))

    y_start, y_end = max(0, y0), min(img_data.shape[0], y0 + 2 * size)
    x_start, x_end = max(0, x0), min(img_data.shape[1], x0 + 2 * size)

    # Skip entirely off-image sources; a negative-length slice would otherwise
    # be interpreted as "count from the end" and mis-shape the assignment.
    if y_end > y_start and x_end > x_start:
        cutout[y_start - y0:y_end - y0, x_start - x0:x_end - x0] = img_data[y_start:y_end, x_start:x_end]
    return cutout


def _format_magnitude(raw_mag, raw_err):
    ''' Return (mag, mag_err) rounded for display, or ('nan', 'nan') if unavailable.

    An error of -1 (after rounding) is treated as "no measurement", as is a
    non-finite error.
    '''
    # Checked here so this function does not depend on how
    # round_to_significant_digit treats nan/inf input.
    if not np.isfinite(float(raw_err)):
        return 'nan', 'nan'
    mag, mag_err = round_to_significant_digit(raw_mag, raw_err)
    if mag_err == -1:
        return 'nan', 'nan'
    return mag, mag_err


def multi_filter_plot(
    n, size=7,
    gsmooth=0, percentage=(10, 99), save=False,
    loc='/mnt/D/JWST_data/COSMOS-Webb/',
    catalog='candidate.csv',
    mosaic_template='COSMOS-mosaic.fits',
    ):

    ''' Plot JWST cutouts of one candidate source in four filters.

    One panel per filter (F115W, F150W, F277W, F444W), each centred on the
    source's F444W catalogue position (ALPHA/DELTA_J2000_F444W), with that
    filter's MAG_AUTO +/- MAGERR_AUTO in the panel title.

    Parameters:
    n: str or int
        the id of the source; catalogue row label int(n) - 1 is used
        (assumes ids start at 1 and follow row order — confirm for new catalogues)
    size: int
        half-width of the cutout in pixels; each panel shows 2*size x 2*size pixels
    gsmooth: float
        the sigma (in pixels) of the gaussian smoothing; 0 disables smoothing
    percentage: sequence of two numbers
        lower and upper percentiles of the cutout used as display vmin/vmax
    save: bool
        whether to save the figure to ./BD_24/BD_<id>.png (directory is created if missing)
    loc: str
        directory prefix of the mosaic files; it is concatenated directly with
        the file name, so keep the trailing slash
    catalog: str
        path of the candidate catalogue CSV
    mosaic_template: str
        mosaic file name; '{filt}' is replaced by the filter name so per-filter
        mosaics can be used. The default has no placeholder, so every filter
        reads the same file.
    '''
    import os

    candi = pd.read_csv(catalog)

    BDid = int(n)
    row = BDid - 1  # n from id to index

    ra = candi['ALPHA_J2000_F444W'].loc[row]
    dec = candi['DELTA_J2000_F444W'].loc[row]
    print(ra, dec)

    filters = ["F115W", "F150W", "F277W", "F444W"]

    images, wcss, pix_sizes = {}, {}, {}
    # path -> (data, WCS, pixel scale): a mosaic shared by several filters is opened once
    loaded = {}
    for band in filters:
        path = loc + mosaic_template.format(filt=band)
        if path not in loaded:
            with fits.open(path) as hdul:
                hdu = hdul[1]
                loaded[path] = (hdu.data, WCS(hdu.header), np.sqrt(hdu.header['PIXAR_A2']))
        images[band], wcss[band], pix_sizes[band] = loaded[path]

    fig = plt.figure(figsize=(25, 11))

    # plot data from JWST, one column per filter (matplotlib subplot indices are 1-based)
    for panel, band in enumerate(filters, start=1):
        # NOTE(review): the axes use the full-mosaic WCS while displaying cutout
        # pixel coordinates; this is only harmless because all ticks are hidden below.
        ax = fig.add_subplot(1, len(filters), panel, projection=wcss[band])
        pix = pix_sizes[band]

        x_single, y_single = wcss[band].all_world2pix(ra, dec, 0)
        img_center = _centered_cutout(images[band], x_single, y_single, size)

        # gaussian smoothing
        # NOTE(review): NaN padding from off-image pixels propagates through
        # gaussian_filter, so smoothed cutouts near a mosaic edge gain extra NaNs.
        if gsmooth:
            img_center = gaussian_filter(img_center, sigma=gsmooth)

        ax.imshow(
            img_center,
            vmax=np.nanpercentile(img_center, percentage[1]),
            vmin=np.nanpercentile(img_center, percentage[0]),
        )

        # Red bars marking the centre, each spanning 0.5-1 (in units of the
        # pixel scale, i.e. arcsec per the title) from it along +x and +y
        inner, outer = size + 0.5 / pix, size + 1 / pix
        ax.plot([size, size], [inner, outer], color='red', lw=3)
        ax.plot([inner, outer], [size, size], color='red', lw=3)

        mag, mag_err = _format_magnitude(
            candi[f'MAG_AUTO_{band}'].loc[row],
            candi[f'MAGERR_AUTO_{band}'].loc[row],
        )
        ax.set_title(f"{band} \n {mag}$\\pm${mag_err}", fontsize=27)

        # Hiding x and y axis ticks
        for coord in (ax.coords[0], ax.coords[1]):
            coord.set_ticks_visible(False)
            coord.set_ticklabel_visible(False)

    # zero-pad ids below 10 (BD05, BD12, ...)
    label = f"{BDid:02d}"
    fig.suptitle(
        f"BD{label},  Image size: {round(2*size * pix_sizes['F115W'], 2)}[arcsec]",
        fontsize=30,
    )

    if save:
        os.makedirs('./BD_24', exist_ok=True)
        fig.savefig(f'./BD_24/BD_{label}.png')
            
