In [126]:
from astroquery.esa.jwst import Jwst
import numpy as np
from datetime import datetime

from tqdm.notebook import tqdm
tqdm.pandas()

import pandas as pd
from astropy.io import fits
import astropy.time as at
from astroquery.jplhorizons import Horizons
import re
import os
import sys
import logging
import glob
import shapely.wkt
from shapely.geometry import Polygon, Point
from reproject import reproject_interp
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import matplotlib.pyplot as plt
import math as mt
from astropy.time import Time
from astroquery.imcce import Skybot
import astropy.units as u
from astropy.visualization import ZScaleInterval, ImageNormalize,LinearStretch, SqrtStretch, simple_norm
from astropy.nddata import Cutout2D
from PIL import Image
from matplotlib import cm
import numpy as np
from sbident import SBIdent
import time

from scipy import ndimage
from skimage.filters import threshold_otsu
import matplotlib.colors as mcolors

import warnings
warnings.filterwarnings('ignore')

In [127]:
#Cosmos Account Credentials
Jwst.login(user='nmartind', password='Mr.Fantastic1999')

INFO: OK [astroquery.utils.tap.core]


In [128]:
def filterCSV4asteroids(df):    
    # Filter rows where the specified column contains a '[' which contain asteroids
    filteredDF = df[df['Asteroids'].str.contains(r'\[', na=False)]
    
    return(filteredDF)

In [129]:
def isLocatedInImage(targetRA, targetDEC, imagePOLYGON):
    #Check if the RA and DEC coordinate exist within the image bound polygon (returns boolean)
    return(imagePOLYGON.contains(Point(targetRA, targetDEC)))

In [130]:
def formatPolygon(string):
    #Format the saved polygon from the archive to the one suitable with shapely
    values_ = string.split(' ')
    string_ = string.replace('Polygon ','POLYGON((')
    
    # add in the commas at every other space
    string_ = re.sub(r'(\s\S*?)\s', r'\1, ',string_)
    
    #Add in the final chracter so that it matches the first (ie forms a closed polygon)
    string_ = string_ + f', {values_[1]} {values_[2]}))'
    string_ = string_.replace('POLYGON(', 'POLYGON (')
    return(string_)

In [131]:
def filterCSVbyProposal(propRange, df):
    startingProposal = propRange[0]
    endingProposal = propRange[1]
    
    #Filter rows based on the proposal range of interest
    filteredDF = df[df['Proposal'].apply(lambda x: startingProposal <= int(x) <= endingProposal)]
    
    return(filteredDF)

In [132]:
def pullCSVdata(CSVloc, propRange):
    #Pull level 3 CSV data    
    df = pd.read_csv(CSVloc)

    # Filter rows where the specified column contains a '['
    dfAsteroids = filterCSV4asteroids(df)
    
    if not dfAsteroids.empty:
    
        #filter dataframe fro proposals within the range specified
        dfFiltered = filterCSVbyProposal(propRange, dfAsteroids)
        
        if not dfFiltered.empty:
            #return the useful dataframe
            proposals = dfFiltered['Proposal'].to_list()
            observations = dfFiltered['Observation'].to_list()
            asteroids = dfFiltered['Asteroids'].to_list()
            asteroidListGrouped = [x.replace('[','').replace(']','').replace("'","").split(', ') for x in asteroids]
            #asteroidListFlattened  = [item for sublist in asteroidListGrouped for item in sublist]

            polygonString = dfFiltered['Polygon'].to_list()
            expStartMJD = dfFiltered['Exp Start'].to_list()
            expEndMJD = dfFiltered['Exp End'].to_list()

            return(proposals, observations, asteroidListGrouped, polygonString, expStartMJD, expEndMJD)
            
            
        else:
            print(f"Empty Dataframe, No asteroids in that proposal range {propRange}")
            return ([],[],[],[],[])
            
    else:
        print(f"Empty Dataframe, No asteroids in CSV")
        return([],[],[],[],[])

In [133]:
def pullWCS(imagePath):
    #recover data from a fits file
    if os.path.exists(imagePath):
        with fits.open(imagePath) as hdul:        
            data = hdul[1].data
            header = hdul[1].header
            wcs_info = WCS(header)
                       
    else:
        wcs_info = None
        data = None
   
    return(data, wcs_info)

In [134]:
def generateFolder(existingfolder, newfolder):
    if os.path.exists(existingfolder):
        if os.path.exists(newfolder):
            pass
        else:
            os.mkdir(newfolder)
        return(True)

    else:
        return(False)

In [135]:
def produceOriginalImage(imageData, WCS, obsID, filePath):
    if not os.path.exists(f'{filePath}/Original_{obsID}.png'):
        fig = plt.figure(figsize = (20,20))
        ax  = plt.subplot(1,1,1, projection = WCS)

        ax.set_xlabel('RA')
        ax.set_ylabel('DEC')

        ax.title.set_text(f'{obsID} i2d Image')

        #interval = ZScaleInterval()
        #vmin, vmax = interval.get_limits(imageData)
        #norm = ImageNormalize(vmin=vmin, vmax=vmax)
        norm = simple_norm(imageData, 'sqrt')

        cax = ax.imshow(imageData, cmap='viridis',norm=norm)

        plt.savefig(f'{filePath}/Original_{obsID}.png') 
        plt.close()
        #print(f'{filePath}/Original_{obsID}.png')
    #else:
    #    continue
    #    print('previously generated original')

In [136]:
def produceOverlayImage(imageData, WCS, obsID, asteroidList, Xlist, Ylist, filePath):
    if not os.path.exists(f'{filePath}/Overlay_{obsID}.png'):
        try:
            fig = plt.figure(figsize = (20+10*(len(asteroidList)-1),10))
            ax1 = plt.subplot(1,len(asteroidList)+1,1, projection = WCS)

            ax1.set_xlabel('RA')
            ax1.set_ylabel('DEC')

            ax1.title.set_text(f'{obsID} i2d Image')

            #asinh_norm = AsinhNorm(vmin=np.min(imageData), vmax=np.max(imageData), a=2)
            norm = simple_norm(imageData, 'sqrt')

            cax1 = ax1.imshow(imageData, cmap='viridis',norm=norm, aspect='auto', zorder = 1)

            colors = ['orangered','maroon','firebrick','goldenrod','tomato','chocolate', 'tab:red', 'tab:orange','yellow','red','orange']

            for asteroid in asteroidList:
                indx = asteroidList.index(asteroid)

                ax_ = plt.subplot(1, len(asteroidList)+1, indx+2, projection = WCS)
                ax_.set_xlabel('RA')
                ax_.set_ylabel('DEC')
                ax_.title.set_text(f'{obsID} Asteroid {asteroid}')

                # Apply custom normalization
                #exp_norm = ExpNorm(vmin=np.min(imageData), vmax=np.max(imageData), k=0.95)

                cax_ = ax_.imshow(imageData, cmap='viridis',norm=norm, aspect='auto', zorder = 1)

                X = Xlist[indx]
                Y = Ylist[indx]

                extraPixels = 10        

                if len(X) >> 1:
                    P1 = [X[0],  Y[0]]
                    P2 = [X[-1], Y[-1]]

                else:
                    P1 = [X[0],Y[0]]
                    P2 = [X[0],Y[0]]

                #Define the default bounds    
                lowerx = round(min(P1[0],P2[0])) - extraPixels
                upperx = round(max(P1[0],P2[0])) + extraPixels
                lowery = round(min(P1[1],P2[1])) - extraPixels
                uppery = round(max(P1[1],P2[1])) + extraPixels

                ax1.plot([lowerx, upperx, upperx, lowerx, lowerx],[lowery, lowery, uppery, uppery, lowery], alpha = 0.6,  c=colors[indx], linewidth = 2, zorder = 2, label = asteroid)

                ax_.scatter(X, Y, s=15, alpha = 0.8, c=colors[indx], zorder = 2, label = asteroid)

                ax_.set(xlim=(lowerx, upperx), ylim=(lowery, uppery))       

            #plt.colorbar(cax, orientation='vertical')
            ax1.legend()
            #plt.tight_layout()
            plt.savefig(f'{filePath}/Overlay_{obsID}.png') 
            plt.close()

        except Exception as e:
            print(e)
            print('re plotting overlay')
            plt.close()
            fig = plt.figure(figsize = (20+10*(len(asteroidList)-1),10))
            ax1 = plt.subplot(1,len(asteroidList)+1,1, projection = WCS)
            ax1.set_xlabel('RA')
            ax1.set_ylabel('DEC')
            ax1.title.set_text(f'{obsID} i2d Image')

            asinh_norm = AsinhNorm(vmin=np.min(imageData), vmax=np.max(imageData), a=2)
            norm = simple_norm(imageData, 'sqrt')

            cax1 = ax1.imshow(imageData, cmap='viridis',norm=asinh_norm, aspect='auto', zorder = 1)

            for asteroid in asteroidList:
                indx = asteroidList.index(asteroid)
                ax_ = plt.subplot(1, len(asteroidList)+1, indx+2, projection = WCS)
                ax_.set_xlabel('RA')
                ax_.set_ylabel('DEC')
                ax_.title.set_text(f'{obsID} Asteroid {asteroid}')

                # Apply custom normalization
                exp_norm = ExpNorm(vmin=np.min(imageData), vmax=np.max(imageData), k=0.95)
                cax_ = ax_.imshow(imageData, cmap='viridis',norm=exp_norm, aspect='auto', zorder = 1)

                X = Xlist[indx]
                Y = Ylist[indx]
                extraPixels = 10        
                if len(X) >> 1:
                    P1 = [X[0],  Y[0]]
                    P2 = [X[-1], Y[-1]]
                else:
                    P1 = [X[0],Y[0]]
                    P2 = [X[0],Y[0]]
                #Define the default bounds    
                lowerx = round(min(P1[0],P2[0])) - extraPixels
                upperx = round(max(P1[0],P2[0])) + extraPixels
                lowery = round(min(P1[1],P2[1])) - extraPixels
                uppery = round(max(P1[1],P2[1])) + extraPixels
                ax1.plot([lowerx, upperx, upperx, lowerx, lowerx],[lowery, lowery, uppery, uppery, lowery], alpha = 0.6, linewidth = 2, zorder = 2, label = asteroid)
                ax_.scatter(X, Y, s=15, alpha = 0.8, zorder = 2, label = asteroid)
                ax_.set(xlim=(lowerx, upperx), ylim=(lowery, uppery))       
            #plt.colorbar(cax, orientation='vertical')
            ax1.legend()
            #plt.tight_layout()
            plt.savefig(f'{filePath}/Overlay_{obsID}.png') 
            plt.close()

In [137]:
def highContrastImage(imageData, WCS, asteroidName, X,Y, path, observation):
    #if not os.path.exists(f'{path}/HighContrastZoom_{observation}_{asteroidName}.png'):
    
    #Generate cutout
    rows, cols = imageData.shape

    extraPixels = 30

    # Calculate the crop boundaries, with bounds checking
    col_start = max(mt.floor(min(X)-extraPixels), 0)  # Ensure not below 0
    col_end =   min(mt.floor(max(X)+extraPixels), cols - 1)  # Ensure not above max columns
    row_start = max(mt.floor(min(Y)-extraPixels), 0)  # Ensure not below 0
    row_end =   min(mt.floor(max(Y)+extraPixels), rows - 1)  # Ensure not above max rows

    yBound = [col_start, col_end]
    xBound = [row_start, row_end]

    
    fig = plt.figure(figsize = (20,20))

    plt.title(f"High Contrast Zoom of Asteroid {asteroidName}")
    plt.xlabel('Pixel Columns')
    plt.ylabel('Pixel Rows')

    #exp_norm = ExpNorm(vmin=np.min(imageData), vmax=np.max(imageData), k=0.95)

    cax = plt.imshow(imageData[xBound[0]:xBound[1],yBound[0]:yBound[1]],cmap='viridis')

    plt.colorbar(cax)
    plt.gca().invert_yaxis()
    plt.savefig(f'{path}/HighContrastZoom_{observation}_{asteroidName}.png') 
    plt.close() 
    
    return(imageData[xBound[0]:xBound[1],yBound[0]:yBound[1]])

    #else:
    #    return(False)

In [138]:
# Define custom asinh normalization function
class AsinhNorm(mcolors.Normalize):
    def __init__(self, vmin=None, vmax=None, a=0.1, clip=False):
        self.a = a  # Scaling parameter
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)

    def __call__(self, value, clip=None):
        value = np.ma.masked_array(value, mask=np.isnan(value))
        if self.vmin is None:
            self.vmin = np.min(value)
        if self.vmax is None:
            self.vmax = np.max(value)
        normalized = (value - self.vmin) / (self.vmax - self.vmin)
        return np.arcsinh(self.a * normalized) / np.arcsinh(self.a)

In [139]:
class ExpNorm(mcolors.Normalize):
    def __init__(self, vmin=None, vmax=None, k=5, clip=False):
        self.k = k  # Scaling parameter for exponential stretching
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)

    def __call__(self, value, clip=None):
        value = np.ma.masked_array(value, mask=np.isnan(value))
        if self.vmin is None:
            self.vmin = np.min(value)
        if self.vmax is None:
            self.vmax = np.max(value)
        normalized = (value - self.vmin) / (self.vmax - self.vmin)
        return np.exp(self.k * normalized) - 1

In [145]:
def generateSNRimage(imageData, path, observation, asteroidName):
    #if not os.path.exists(f'{path}/SNR_{observation}_{asteroidName}.png'):
    # Step 1: Mask out the zero values and nan values
    #valid_pixels = image[~np.isnan(image)].flatten()
    non_zero_image = np.ma.masked_equal(imageData, 0)

    masked_image = np.ma.masked_where(np.isnan(non_zero_image), non_zero_image)

    # Step 2: Flatten the image, excluding the masked zero values
    valid_pixels = masked_image.compressed()  # This gives a 1D array of non-zero values

    # Step 3: Apply Otsu's threshold to separate the source from the background
    threshold_value = threshold_otsu(valid_pixels)

    # Step 4: Separate signal (pixels above threshold) and background (below threshold)
    signal_pixels = valid_pixels[valid_pixels > threshold_value]
    background_pixels = valid_pixels[valid_pixels <= threshold_value]

    # Step 5: Calculate mean and standard deviation of signal and background
    mean_signal = np.mean(signal_pixels)
    mean_background = np.mean(background_pixels)

    snr = mean_signal / mean_background

    #print('background',mean_background)
    #print('signal',mean_signal)

    # Plot the image and mask
    plt.figure(figsize=(30, 10))

    plt.subplot(1, 3, 1)
    plt.title('Original Image with Source')
    cax1 = plt.imshow(masked_image, cmap='gray', zorder = 2)
    plt.scatter([0],[0], c = 'white', label = asteroidName, zorder = 1)
    plt.legend()
    plt.colorbar(cax1)
    plt.gca().invert_yaxis()

    plt.subplot(1, 3, 2)
    plt.title(f'Mask using Otsu Threshold = {threshold_value:.2f},  S/N = {snr:.2f}')
    cax2 = plt.imshow(masked_image > threshold_value, cmap='gray')
    plt.colorbar(cax2)
    plt.gca().invert_yaxis()


    # Step 1: Flatten the 2D image into a 1D array
    flattened_data = imageData.ravel()  # Flatten the 2D array to 1D
    #remove any nans from the 1D array        
    cleaned_data = [x for x in flattened_data if x!= 0 and not np.isnan(x)]                                                                  

    # Step 2: Create a 1D histogram of pixel values
    # Binning pixel values into 50 bins (you can adjust the number of bins)
    hist, bin_edges = np.histogram(cleaned_data, bins=50)

    # Step 3: Plot the 1D histogram
    plt.subplot(1, 3, 3)
    plt.bar(bin_edges[:-1], hist, width=np.diff(bin_edges), edgecolor='black', align='edge')
    plt.title('1D Histogram of Pixel Values')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.grid(True)

    # Step 4: Find the most common pixel value (the bin with the maximum count)
    max_bin_index = max(hist)  # Index of the bin with the highest frequency
    #most_common_pixel_value = bin_edges[max_bin_index]  # The corresponding pixel value

    plt.axvline(x = mean_background, ymax=(max_bin_index / plt.ylim()[1])+20, color = 'black', linestyle = 'dashed', label = f'Average Background = {mean_background:.2f}')
    plt.axvline(x = threshold_value, ymax=(max_bin_index / plt.ylim()[1])+20, color = 'green', linestyle = 'dashed', label = f'Otsu Threshold = {threshold_value:.2f}')
    plt.axvline(x = mean_signal,     ymax=(max_bin_index / plt.ylim()[1])+20, color = 'red',   linestyle = 'dashed', label = f'Average Signal = {mean_signal:.2f}')

    plt.legend(loc = 'upper right')
    #print(f'The most common pixel value is approximately: {most_common_pixel_value}')

    plt.savefig(f'{path}/SNR_{observation}_{asteroidName}.png') 
    plt.close()
    print(f'{path}/SNR_{observation}_{asteroidName}.png')

    return(f'{snr:.3f}')

In [161]:
def jplHorizonsSearch(targetID, startTime, stopTime, polyString):
    #Search the JPL Horizons data for a specific target to get orbital values. This method is more accurate then the cone search and provides a double check for asteroids (named from the cone search) existing in the image
    
    polyStringFmt =  str(formatPolygon(str(polyString)))
    #generate polygon as a shapely function
    poly = shapely.wkt.loads(polyStringFmt)
    
    probeMinutes = 1
        
    #Grab the data every 600min since we only care about it if belongs in the image not the shape at this point, always returns first point even for short exposures
    obj = Horizons(id=targetID, location='Geocentric@JWST', epochs={'start': str(startTime), 'stop' : str(stopTime), 'step' : str(probeMinutes) + 'm'}) 

    #get Ephemerides data
    ephemJPL = obj.ephemerides()
    ephemJPL = ephemJPL.to_pandas()
    
    asteroidRA =   ephemJPL['RA']
    asteroidDEC =  ephemJPL['DEC']
    
    contained = False
    
    for positionIndx in range(len(asteroidRA)):
        isContained = isLocatedInImage(asteroidRA[positionIndx], asteroidDEC[positionIndx], poly)
        if isContained:           
            contained = True
                
    return(asteroidRA, asteroidDEC, contained)

In [162]:
def generateImages(prop, obs, asteroidList, polygon, expStart, expEnd, filePath):

    containedAsteroidNames = []
    containedAsteroidX = []
    containedAsteroidY = []
    
    for asteroid in asteroidList:
        raList, decList, containedCheck = jplHorizonsSearch(asteroid, expStart, expEnd, polygon)
        
        #print('contained',containedCheck)
        if containedCheck:
            #pull the WCS and image data for the observation
            dataPath = f'/data/user/jwst_jw01/jw0{prop}/{obs}_i2d.fits.gz'
            imageData, WCS = pullWCS(dataPath)
            
            #print('data exists',WCS != None)
            if WCS is None:
                print(f"ERROR: Image Data from {dataPath} Not Found")
            else:
                #convert RA and DEC to pixel locations
                asteroidPixelX, asteroidPixelY = WCS.all_world2pix(raList, decList, 0)
                
                containedAsteroidNames.append(asteroid)
                containedAsteroidX.append(asteroidPixelX)
                containedAsteroidY.append(asteroidPixelY)
                
                    
    
    if len(containedAsteroidNames) > 0:
    
        produceOverlayImage(imageData, WCS, obs, containedAsteroidNames, containedAsteroidX, containedAsteroidY, filePath)
        produceOriginalImage(imageData, WCS, obs, filePath)
        for indx in range(len(containedAsteroidNames)):
            croppedImg = highContrastImage(imageData, WCS, containedAsteroidNames[indx], containedAsteroidX[indx], containedAsteroidY[indx], filePath, obs)

In [163]:
def main(csvName, outputFolder, propRange):
    
    #generateFolder(outputFolder)
    #pull the useful asteroid rows from the provided lvl 3 csv
    proposals, lvl3observations, asteroids, polyString, expStartList, expEndList = pullCSVdata(csvName, propRange)
    
    if proposals:
        
        uniqueProposals = sorted(list(set(proposals)))
        print(f"From proposals {uniqueProposals} found {len(lvl3observations)} Observations containing asteroids")
        
        for indx in tqdm(range(len(proposals))):
            prop = proposals[indx]
            obs = lvl3observations[indx]
            asteroidList  = asteroids[indx]
            polygon =  (str(polyString[indx]))
            expStart = expStartList[indx]
            expEnd = expEndList[indx]
            
            propPath = f'{outputFolder}/{prop}'    
            lvl3path = f'{propPath}/Level3_Products'
            lvl2path = f'{propPath}/Level2_Members'
            
            print(lvl2path)
            
            if generateFolder(lvl2path, lvl3path):
    
                generateImages(prop, obs, asteroidList, polygon, expStart, expEnd, lvl3path)

In [164]:
%%time
main("MIRI/LVL3/LVL3_Full.csv", "MIRI/LVL2", [1720,1730])

#check the snr 

From proposals [1727] found 17 Observations containing asteroids


  0%|          | 0/17 [00:00<?, ?it/s]

MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
MIRI/LVL2/1727/Level2_Members
CPU times: user 28.4 s, sys: 11.6 s, total: 40 s
Wall time: 1min 4s
