In [5]:
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
from matplotlib.widgets import Slider
import matplotlib.patheffects as path_effects
import os
from astropy.io import fits
from glob import glob
import pandas as pd

In [16]:
class DataCube:
	"""
	An object combining multiple 1D and 2D spectra from a given folder.

	Each row of ``table`` describes one source. Its ``s2d`` and ``x1d`` cells are
	lists with one slot per combined data cube: slot 0 comes from ``folder``,
	slot k from the k-th call to ``combineDataCube``. Missing products are
	stored as ``None``.

	Parameters
	----------
	folder :
		path to the folder containing the final data. Every ``*_s2d.fits`` file of
		the folder is kept, together with the ``*_x1d.fits`` file of the same name
		when that file exists.

	Attributes
	----------
	table :
		a pandas DataFrame (columns ``sourceID``, ``s2d``, ``x1d``) used for making
		a correspondence between a source id, its s2d files and its x1d files.
		Those files are stored as lists of paths (``None`` when missing).
	dataTable :
		same layout as ``table`` with the paths replaced by opened FITS files;
		empty until ``preloadDataCube`` is called.
	"""

	def __init__(self, folder):
		# NOTE: to keep only 3-shutter slitlets, the s2d files can additionally be
		# filtered on len(<extension 1 header>["SHUTSTA"]) == 3.
		s2dList, x1dList, sourceList = [], [], []
		for s2dPath in sorted(glob(os.path.join(folder, '*_s2d.fits'))):
			# Swap only the trailing suffix, so a "_s2d" elsewhere in the folder path is left alone
			x1dPath = s2dPath[:-len("_s2d.fits")] + "_x1d.fits"
			s2dList.append([s2dPath])
			x1dList.append([x1dPath if os.path.exists(x1dPath) else None])
			# Only the header is needed: the context manager closes the file instead of
			# keeping one open handle per s2d file
			with fits.open(s2dPath) as hdul:
				sourceList.append(hdul[1].header["SOURCEID"])

		self.table = pd.DataFrame({"sourceID": sourceList, "s2d": s2dList, "x1d": x1dList})

		# Initializes the dataframe which will contain the data models
		self.dataTable = pd.DataFrame({"sourceID" : [], "s2d" : [], "x1d" : []})

	@staticmethod
	def _columnCount(table):
		"""Returns the number of slots per row of ``table`` (how many cubes it combines); 1 when empty."""
		widths = [len(paths) for paths in table["s2d"] if isinstance(paths, list)]
		return max(widths) if widths else 1

	@staticmethod
	def _padSlots(cell, width):
		"""Returns ``cell`` if it already is a list of slots, else ``width`` None slots (source absent from that side)."""
		return cell if isinstance(cell, list) else [None] * width

	def combineDataCube(self, datacube, i=None):
		"""
		Combines 2 Datacubes: the slots of ``datacube`` are appended after the slots
		of this one, rows being matched on ``sourceID`` (outer join, so sources
		present in only one cube are kept and padded with None).

		Parameters
		----------
		datacube :
			Another DataCube to be appended to this one, or None, which appends a
			None slot to every s2d/x1d list (an empty column).
		i :
			how many None slots to give the sources missing from this cube. By
			default it is inferred from the current width of this cube's lists, so
			successive calls stay aligned; pass it only to override that.
		"""
		print("Starting Combining Datacubes")

		if datacube is None:
			for column in ("s2d", "x1d"):
				for paths in self.table[column]:
					paths.append(None)
		else:
			print("Merging...")
			selfWidth = self._columnCount(self.table) if i is None else i
			otherWidth = self._columnCount(datacube.table)

			# Perform an outer join to include all sourceIDs from both DataFrames
			merged = self.table.merge(
				datacube.table, on="sourceID", how="outer", suffixes=("_self", "_other")
			)

			print("Replacing empty values...")
			# A source absent from one side comes back as NaN: replace it by as many
			# None slots as that side has, so every row keeps the same width
			for column in ("s2d", "x1d"):
				merged[column + "_self"] = merged[column + "_self"].apply(lambda cell: self._padSlots(cell, selfWidth))
				merged[column + "_other"] = merged[column + "_other"].apply(lambda cell: self._padSlots(cell, otherWidth))

			print("Appending paths...")
			# Element-wise list concatenation builds new lists (no aliasing with either input table)
			for column in ("s2d", "x1d"):
				merged[column] = merged[column + "_self"] + merged[column + "_other"]

			# Keep only necessary columns: 'sourceID', 's2d', 'x1d'
			self.table = merged[["sourceID", "s2d", "x1d"]]

		print("Finished Combining Datacubes!")

	def preloadDataCube(self):
		"""
		Initializes self.dataTable, a table structurally identical to self.table,
		except the paths are replaced by the corresponding opened FITS files
		(``None`` slots stay ``None``).

		The files are left open so ``exploreDataCube`` can read them. Files opened
		by a previous call are closed first, so re-running this does not
		accumulate open file handles.
		"""
		print("Starting loading data...")
		self.closeDataCube()
		print("Copying...")
		self.dataTable = self.table.copy()

		# Process lists of file paths
		def processList(fileList):
			return [fits.open(file) if isinstance(file, str) else None for file in fileList]

		print("Loading...")
		# Process the 'x1d' and 's2d' columns
		self.dataTable["x1d"] = self.dataTable["x1d"].apply(processList)
		self.dataTable["s2d"] = self.dataTable["s2d"].apply(processList)

		print("Finished loading data!")

	def closeDataCube(self):
		"""
		Closes every FITS file held in self.dataTable (opened by ``preloadDataCube``).
		Safe to call when nothing has been loaded.
		"""
		dataTable = getattr(self, "dataTable", None)
		if dataTable is None:
			return
		for column in ("s2d", "x1d"):
			if column not in dataTable:
				continue
			for slots in dataTable[column]:
				if not isinstance(slots, list):
					continue
				for hdul in slots:
					# Skips None slots and paths left over by an interrupted load
					if hasattr(hdul, "close"):
						hdul.close()

	def exploreDataCube(self):
		"""
		Opens an interactive figure to browse the preloaded sources.

		The three top rows show the s2d image of slots 0-2, with the x1d
		extraction box drawn as a dashed red rectangle when the x1d exists. The
		bottom row shows the x1d spectra (with errors) of slots 0 and 1. A slider,
		or the left/right arrow keys, selects the source.

		Requires ``preloadDataCube`` to have been called, and an interactive
		matplotlib backend for the slider and key events to respond.

		Raises
		------
		ValueError
			if ``dataTable`` is empty (nothing preloaded).
		"""
		N = len(self.dataTable["sourceID"])
		if N == 0:
			raise ValueError("dataTable is empty: call preloadDataCube() first")

		fig, axes = plt.subplots(4, 1, figsize=(18, 7), gridspec_kw={'height_ratios': [1, 1, 1, 5]})
		plt.subplots_adjust(left=0.1, bottom=0.15, right=0.9, top=0.9, hspace=0)

		idx = 0
		# One label per image row (slots 0-2)
		legends = ["No Subtract", "Basic Pipeline", ""]

		def drawExtraction(axe, x1d):
			x0 = x1d.header["EXTRXSTR"]
			x1 = x1d.header["EXTRXSTP"]
			y0 = x1d.header["EXTRYSTR"]
			y1 = x1d.header["EXTRYSTP"]
			axe.vlines((x0, x1), (y0, y0), (y1, y1), color='r', linestyles='dashed', linewidth=0.5)
			axe.hlines((y0, y1), (x0, x0), (x1, x1), color='r', linestyles='dashed', linewidth=0.5)

		# Update sourceID
		def update(val):
			idx = int(slider.val)  # Get the current slider value
			# Positional access: the slider runs over row positions, whatever the index labels are
			s2dRow = self.dataTable["s2d"].iloc[idx]
			x1dRow = self.dataTable["x1d"].iloc[idx]

			axes[3].clear()  # Clear the current error bar plot

			for i, legend in enumerate(legends):
				# Rows combining fewer cubes than there are image axes just leave them empty
				s2d = s2dRow[i] if i < len(s2dRow) else None
				x1d = x1dRow[i] if i < len(x1dRow) else None

				axes[i].clear()
				if s2d is not None:
					image = s2d[1].data
					z1, z2 = ZScaleInterval().get_limits(image)
					axes[i].imshow(image, aspect='auto', vmin=z1, vmax=z2, cmap="viridis", origin="lower")
					# The x1d file can be missing even when the s2d exists
					if x1d is not None:
						drawExtraction(axes[i], x1d[1])

				text = axes[i].text(0.02, 0.3, legend, color="w", transform=axes[i].transAxes)
				text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])
				axes[i].set_axis_off()

				# Slot 2 is only shown as an image, not in the spectrum panel
				if x1d is not None and i != 2:
					spectrum = x1d[1].data
					axes[3].errorbar(spectrum["WAVELENGTH"], spectrum["FLUX"], yerr=spectrum["FLUX_ERROR"], label=legend, capsize=3)

			axes[3].text(0.05, 0.05, f"SourceID: {self.dataTable['sourceID'].iloc[idx]}", color="k", transform=axes[3].transAxes, size=15)
			axes[3].set_xlabel(r"$\lambda$ (µm)")
			axes[3].set_ylabel(r"Flux (Jy)")
			# Only draw a legend when at least one spectrum was plotted (avoids an empty-legend warning)
			if axes[3].get_legend_handles_labels()[0]:
				axes[3].legend()
			axes[3].grid(True)
			fig.canvas.draw_idle()

		def onKey(event):
			current = slider.val
			if event.key == "right":  # Move slider one step right
				new = min(current + 1, N - 1)  # Ensure within bounds
				slider.set_val(new)
			elif event.key == "left":  # Move slider one step left
				new = max(current - 1, 0)  # Ensure within bounds
				slider.set_val(new)

		# Slider
		ax_slider = plt.axes((0.2, 0.05, 0.6, 0.03))
		slider = Slider(ax_slider, 'Source', 0, N - 1, valinit=idx, valstep=1)
		update(idx)

		# Attach the update function to the slider
		slider.on_changed(update)

		# Connect keypress handler
		fig.canvas.mpl_connect("key_press_event", onKey)

		# Show the plot
		plt.show()


In [17]:
# Root of the CEERS NIRSpec PRISM MSATA products; edit this single constant to run elsewhere.
DATA_ROOT = "/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA"

# Slot 0: "Final" products, slot 1: "Default/Final" products, slot 2: empty placeholder.
# In the explorer these slots are labelled "No Subtract", "Basic Pipeline" and "".
dc = DataCube(os.path.join(DATA_ROOT, "Final"))
dc.combineDataCube(DataCube(os.path.join(DATA_ROOT, "Default", "Final")))
dc.combineDataCube(None)

dc.table

Starting Combining Datacubes
Merging...
Replacing empty values...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Finished Combining Datacubes!


Unnamed: 0,sourceID,s2d,x1d
0,-168,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
1,-167,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
2,-166,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
3,-165,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
4,-164,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
...,...,...,...
439,43262,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
440,43450,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
441,43461,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...
442,44055,[/home/tim-dewachter/Documents/Thèse/BetterNIR...,[/home/tim-dewachter/Documents/Thèse/BetterNIR...


In [18]:
# Open every FITS file listed in dc.table; dc.dataTable holds the opened files
# (same layout as dc.table, None where a product is missing).
dc.preloadDataCube()
dc.dataTable

Starting loading data...
Copying...
Loading...
Finished loading data!


Unnamed: 0,sourceID,s2d,x1d
0,-168,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
1,-167,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
2,-166,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
3,-165,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
4,-164,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
...,...,...,...
439,43262,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
440,43450,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
441,43461,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...
442,44055,[[<astropy.io.fits.hdu.image.PrimaryHDU object...,[[<astropy.io.fits.hdu.image.PrimaryHDU object...


In [19]:
# Interactive GUI backend: exploreDataCube relies on a Slider widget and arrow-key
# events, which need a live event loop (a static inline backend would not respond).
%matplotlib tkagg

In [20]:
# Browse the sources with the slider or the left/right arrow keys.
dc.exploreDataCube()