In [98]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
from matplotlib.widgets import Slider
import os
import stdatamodels.jwst.datamodels as dm
from astropy.io import fits
from glob import glob
import pandas as pd

In [147]:
class DataCube:
	"""
	An object combining multiple 1D and 2D spectra from a given folder.
	
	
	Parameters :
	----------
	folder :
		path to the folder containing the final data. Every file matching 
		*_s2d.fits is registered, together with its *_x1d.fits counterpart 
		when that file exists
	
	Properties : 
	----------
	table : 
		a table used for making a correspondence between a source id, a s2d file and a x1d file. 
		Those files are stored as a list of paths (a missing x1d file is stored as None)
	dataTable :
		same layout as table, but holding opened datamodels instead of paths.
		Empty until preloadDataCube() is called
	"""
	
	# Labels of the spectra of one source, by position in its s2d / x1d lists
	PIPELINE_LABELS = ("Custom Pipeline", "Basic Pipeline", "No Subtraction")
	
	def __init__(self, folder):	
		s2dPaths = sorted(glob(os.path.join(folder, '*_s2d.fits')))
		
		sourceList = []
		s2dList = []
		x1dList = []
		for s2dPath in s2dPaths:
			# Context manager so the FITS file is closed as soon as the header is read
			with fits.open(s2dPath) as hdul:
				sourceList.append(hdul[1].header["SOURCEID"])
			
			x1dPath = self._x1dPathFor(s2dPath)
			s2dList.append([s2dPath])
			x1dList.append([x1dPath] if os.path.exists(x1dPath) else [None])
		
		self.table = pd.DataFrame({"sourceID": sourceList, "s2d": s2dList, "x1d": x1dList})
		
		# Initializes the dataframe which will contain the data models
		self.dataTable = pd.DataFrame({"sourceID" : [], "s2d" : [], "x1d" : []})

	@staticmethod
	def _x1dPathFor(s2dPath):
		"""
		Returns the path of the x1d file matching a given s2d file
		Only the trailing _s2d.fits of the file name is swapped, so a folder 
		(or a file stem) containing "_s2d" is left untouched
		"""
		directory, filename = os.path.split(s2dPath)
		stem = filename[:-len("_s2d.fits")]
		return os.path.join(directory, stem + "_x1d.fits")
		
	def combineDataCube(self, datacube):
		"""
		Combines 2 Datacubes
		The path lists of sources present in both datacubes are concatenated 
		(this one first), sources present in only one of them are kept as is.
		Only self.table is updated, self.dataTable is left untouched
		
		Parameters
		----------
		datacube :
			Another datacube to be appended to this one
		"""
		
		print("Starting Combining Datacubes")
		
		print("Merging...")
		# Perform an outer join to include all sourceIDs from both DataFrames
		merged = self.table.merge(
			datacube.table, on="sourceID", how="outer", suffixes=("_self", "_other")
		)
		
		print("Replacing empty values...")
		# A source missing from one side has no list there after the join: use an empty list
		for column in ("s2d_self", "s2d_other", "x1d_self", "x1d_other"):
			merged[column] = merged[column].apply(lambda x: x if isinstance(x, list) else [])
		
		print("Appending paths...")
		# Combine the 's2d' and 'x1d' columns
		merged["s2d"] = merged["s2d_self"] + merged["s2d_other"]
		merged["x1d"] = merged["x1d_self"] + merged["x1d_other"]
		
		# Keep only necessary columns: 'sourceID', 's2d', 'x1d'
		# (copied so that self.table does not stay tied to the temporary merged frame)
		self.table = merged[["sourceID", "s2d", "x1d"]].copy()
		
		print("Finished Combining Datacubes!")

	def preloadDataCube(self):
		"""
		Initializes self.dataTable, a table structurally identical to self.table, 
		except the paths are replaced by the corresponding datamodels
		Entries which are not a path (None for a missing x1d file) are dropped
		"""
		print("Starting loading data...")
		print("Copying...")
		self.dataTable = self.table.copy()
		
		# Process lists of file paths
		# NOTE(review): dropping the None entries shifts the position of the following 
		# models, while exploreDataCube labels them by position - confirm this is intended
		def processList(file_list):
			return [dm.open(file) for file in file_list if isinstance(file, str)]
		
		print("Loading...")
		# Process the 'x1d' and 's2d' columns
		self.dataTable["x1d"] = self.dataTable["x1d"].apply(processList)
		self.dataTable["s2d"] = self.dataTable["s2d"].apply(processList)
		
		print("Finished loading data!")
		
	def _drawSource(self, axes, idx):
		"""
		Draws the 2D and 1D spectra of one row of self.dataTable
		
		Parameters
		----------
		axes :
			the 4 axes of the explorer figure (3 images then the 1D plot)
		idx :
			position of the row of self.dataTable to display
		"""
		# iloc selects a row by position (dataTable[idx] would look for a column named idx)
		row = self.dataTable.iloc[idx]
		s2dModels = row["s2d"]
		x1dModels = row["x1d"]
		
		# Color limits are recomputed for every source, from the second image when 
		# there is one (as in the initial display), else from the only image available
		limits = {}
		if len(s2dModels) > 0:
			reference = s2dModels[1] if len(s2dModels) > 1 else s2dModels[0]
			z1, z2 = ZScaleInterval().get_limits(reference.data)
			limits = {"vmin": z1, "vmax": z2}
		
		# The panels are cleared then redrawn, so images of a different shape are handled
		for position, label in enumerate(self.PIPELINE_LABELS):
			axes[position].clear()
			axes[position].set_title(label)
			# A panel stays empty when the source has no spectrum at this position
			if position < len(s2dModels):
				axes[position].imshow(s2dModels[position].data, aspect='auto', **limits)
		
		axes[3].clear()
		for label, model in zip(self.PIPELINE_LABELS, x1dModels):
			spectrum = model.spec[0].spec_table
			axes[3].errorbar(
				spectrum["WAVELENGTH"], spectrum["FLUX"], yerr=spectrum["FLUX_ERROR"], 
				label=label, capsize=3
			)
		
		# clear() removed the legend, it has to be rebuilt (only if something was plotted)
		if len(x1dModels) > 0:
			axes[3].legend()
		axes[3].set_title(f"SourceID: {row['sourceID']}")
		
	def exploreDataCube(self):
		"""
		Opens an interactive figure showing, for one source at a time, its 2D spectra 
		as images and its 1D spectra (flux with error bars) on a common plot. 
		A slider selects the source when there is more than one.
		Requires preloadDataCube() to have been called, and an interactive 
		matplotlib backend for the slider to respond
		
		Raises
		------
		ValueError :
			if self.dataTable is empty
		"""
		N = len(self.dataTable)
		if N == 0:
			raise ValueError("No data to explore: call preloadDataCube() on a non-empty DataCube first")
		
		fig, axes = plt.subplots(4, 1, figsize=(18, 7), gridspec_kw={'height_ratios': [1, 1, 1, 3]})
		fig.subplots_adjust(left=0.1, bottom=0.15, right=0.9, top=0.9, hspace=0.05)
		
		current_index = 0
		self._drawSource(axes, current_index)
		
		# Slider (a single source would give it an empty range, so it is skipped)
		slider = None
		if N > 1:
			# Update sourceID
			def update(val):
				self._drawSource(axes, int(val))
				fig.canvas.draw_idle()
			
			ax_slider = fig.add_axes((0.2, 0.05, 0.6, 0.03))
			slider = Slider(ax_slider, 'Source', 0, N - 1, valinit=current_index, valstep=1)
			# Without this registration, moving the slider does nothing
			slider.on_changed(update)
		
		# A widget only stays responsive while a reference to it is kept
		self._explorer = (fig, slider)


In [69]:
# Folder holding the final *_s2d.fits / *_x1d.fits products to index
finalFolder = "/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/Final"
_ = DataCube(folder=finalFolder)
# Last expression of the cell: displays the sourceID / s2d / x1d path table
_.table

Unnamed: 0,sourceID,s2d,x1d
0,44419,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
