In [235]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
from matplotlib.widgets import Slider
import matplotlib.patheffects as path_effects
import os
import stdatamodels.jwst.datamodels as dm
from astropy.io import fits
from glob import glob
import pandas as pd

In [240]:
class DataCube:
	"""
	An object combining multiple 1D and 2D spectra from a given folder.

	Each row of ``table`` is one source. Its ``s2d`` and ``x1d`` cells are lists whose
	i-th entries belong together: after combining several cubes, position i holds the
	product of the i-th combined cube. A product that does not exist is stored as a
	``None`` placeholder so that positions never shift.

	Parameters :
	----------
	folder :
		path to the folder containing the final data. Every ``*_s2d.fits`` file in it is
		indexed, together with the ``*_x1d.fits`` file of the same name when one exists.

	Properties :
	----------
	table :
		a table used for making a correspondence between a source id, a s2d file and a x1d file.
		Those files are stored as lists of paths (``None`` where a product is missing).
	dataTable :
		same layout as ``table`` but with the paths replaced by the opened datamodels
		(``None`` placeholders are kept). Empty until ``preloadDataCube`` is called.
	"""

	def __init__(self, folder):
		"""Index the s2d/x1d products found in *folder* (see the class docstring)."""
		s2dPaths = sorted(glob(os.path.join(folder, '*_s2d.fits')))
		# Only swap the file-name suffix: replacing every "_s2d" in the full path would
		# also rewrite directory names that happen to contain "_s2d".
		x1dPaths = [path[:-len("_s2d.fits")] + "_x1d.fits" for path in s2dPaths]

		s2dList = [[path] for path in s2dPaths]
		# A missing x1d is kept as a None placeholder so it stays aligned with its s2d
		x1dList = [[path] if os.path.exists(path) else [None] for path in x1dPaths]
		# fits.getval opens and closes the file itself, instead of relying on garbage
		# collection to close an HDUList returned by fits.open
		sourceList = [fits.getval(path, "SOURCEID", ext=1) for path in s2dPaths]

		self.table = pd.DataFrame({"sourceID": sourceList, "s2d": s2dList, "x1d": x1dList})

		# Initializes the dataframe which will contain the data models
		self.dataTable = pd.DataFrame({"sourceID" : [], "s2d" : [], "x1d" : []})

	def combineDataCube(self, datacube):
		"""
		Combines 2 Datacubes.

		Performs an outer join on "sourceID". For every source, the file lists of *datacube*
		are appended after those of this cube. A source absent from one of the two cubes
		gets that cube's share padded with ``None`` placeholders. This keeps position i
		pointing to the same combined cube for every source.
		If a sourceID appears on several rows of either table, the merge pairs each such
		row of one table with each such row of the other.

		Parameters
		----------
		datacube :
			Another datacube to be appended to this one (may be ``self``)
		"""
		print("Starting Combining Datacubes")

		print("Merging...")
		# Width (number of products per row) of each side, measured before merging, so
		# missing sources can be padded with the right number of placeholders.
		tables = {"_self": self.table, "_other": datacube.table}
		widths = {
			col + suffix: max((len(entry) for entry in table[col]), default=0)
			for suffix, table in tables.items()
			for col in ("s2d", "x1d")
		}
		merged = self.table.merge(
			datacube.table, on="sourceID", how="outer", suffixes=("_self", "_other")
		)

		print("Replacing empty values...")
		# The outer join leaves NaN where a source is missing from one side
		for name, width in widths.items():
			merged[name] = [entry if isinstance(entry, list) else [None] * width for entry in merged[name]]

		print("Appending paths...")
		for col in ("s2d", "x1d"):
			merged[col] = [mine + theirs for mine, theirs in zip(merged[col + "_self"], merged[col + "_other"])]

		# Keep only necessary columns: 'sourceID', 's2d', 'x1d'
		self.table = merged[["sourceID", "s2d", "x1d"]].reset_index(drop=True)

		print("Finished Combining Datacubes!")

	def preloadDataCube(self):
		"""
		Initializes self.dataTable, a table structurally identical to self.table,
		except the paths are replaced by the corresponding datamodels (opened with
		``dm.open``). ``None`` placeholders are kept so that s2d and x1d positions stay aligned.
		"""
		print("Starting loading data...")
		print("Copying...")
		self.dataTable = self.table.copy()

		def processList(paths):
			# Keep a None slot for anything that is not a path, instead of dropping it,
			# so that index i still refers to the same product in both columns
			return [dm.open(path) if isinstance(path, str) else None for path in paths]

		print("Loading...")
		# Process the 'x1d' and 's2d' columns
		self.dataTable["x1d"] = self.dataTable["x1d"].apply(processList)
		self.dataTable["s2d"] = self.dataTable["s2d"].apply(processList)

		print("Finished loading data!")

	def exploreDataCube(self, labels=("Custom Pipeline", "Basic Pipeline", "No Subtraction"), zscaleIndex=1):
		"""
		Opens an interactive figure to browse the preloaded sources one at a time.

		The top panels show the s2d image at each product position, using shared ZScale
		limits. The bottom panel overlays the x1d spectra with their errors. A slider
		(or the left/right arrow keys) selects the source. Missing products are skipped.

		Parameters
		----------
		labels :
			one label per product position, in the order the datacubes were combined.
			Positions beyond ``len(labels)`` are not shown.
		zscaleIndex :
			position of the s2d product whose ZScale limits are shared by all image panels.
			If that product is missing, the first available one is used.

		Raises
		------
		ValueError
			if ``dataTable`` is empty, e.g. ``preloadDataCube`` has not been called.
		"""
		N = len(self.dataTable)
		if N == 0:
			raise ValueError("dataTable is empty: call preloadDataCube() before exploreDataCube()")

		nImages = len(labels)
		# squeeze=False always yields a 2D array of axes, even for a single row
		fig, axes = plt.subplots(nImages + 1, 1, figsize=(18, 7), squeeze=False, gridspec_kw={'height_ratios': [1] * nImages + [5]})
		axes = axes[:, 0]
		imageAxes, specAxis = axes[:nImages], axes[nImages]
		fig.subplots_adjust(left=0.1, bottom=0.15, right=0.9, top=0.9, hspace=0)

		def labelAxis(ax, label):
			# White text with a black outline, readable over the image
			text = ax.text(0.02, 0.3, label, color="w", transform=ax.transAxes)
			text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])

		def productAt(models, i):
			# None when position i is out of range or holds a placeholder
			return models[i] if 0 <= i < len(models) else None

		def draw(idx):
			# Single drawing routine shared by the initial view and every slider update
			s2dModels = self.dataTable["s2d"].iloc[idx]
			x1dModels = self.dataTable["x1d"].iloc[idx]

			reference = productAt(s2dModels, zscaleIndex)
			if reference is None:
				reference = next((model for model in s2dModels if model is not None), None)
			z1, z2 = ZScaleInterval().get_limits(reference.data) if reference is not None else (None, None)

			for i, (ax, label) in enumerate(zip(imageAxes, labels)):
				ax.clear()
				model = productAt(s2dModels, i)
				if model is not None:
					ax.imshow(model.data, aspect='auto', vmin=z1, vmax=z2, cmap="viridis")
				labelAxis(ax, label)
				ax.set_axis_off()

			specAxis.clear()
			for i, label in enumerate(labels):
				model = productAt(x1dModels, i)
				if model is None:
					continue
				spec = model.spec[0].spec_table
				# Fluxes are plotted as stored (no per-product scaling)
				specAxis.errorbar(spec["WAVELENGTH"], spec["FLUX"], yerr=spec["FLUX_ERROR"], label=label, capsize=3)

			specAxis.text(0.05, 0.05, f"SourceID: {self.dataTable['sourceID'].iloc[idx]}", color="k", transform=specAxis.transAxes, size=15)
			specAxis.grid()
			# Avoid matplotlib's "no artists with labels" warning when every x1d is missing
			if specAxis.get_legend_handles_labels()[0]:
				specAxis.legend()

		draw(0)

		# Slider
		sliderAxis = fig.add_axes((0.2, 0.05, 0.6, 0.03))
		slider = Slider(sliderAxis, 'Source', 0, N - 1, valinit=0, valstep=1)

		def update(val):
			draw(int(val))
			fig.canvas.draw_idle()

		# Attach the update function to the slider
		slider.on_changed(update)

		def onKey(event):
			# Arrow keys step the slider by one source, clamped to the valid range
			current = slider.val
			if event.key == "right":
				slider.set_val(min(current + 1, N - 1))
			elif event.key == "left":
				slider.set_val(max(current - 1, 0))

		# Connect keypress handler
		fig.canvas.mpl_connect("key_press_event", onKey)

		# Show the plot
		plt.show()


In [241]:
import copy

# Folder holding the final *_s2d.fits / *_x1d.fits products to browse
DATA_DIR = "/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/Final"

dc = DataCube(DATA_DIR)
# Snapshot taken before any combination, so it still holds one product per source.
# Named explicitly: assigning to `_` would shadow IPython's last-output variable.
dcSnapshot = copy.deepcopy(dc)
# NOTE(review): all three product positions come from the same folder, so the
# "Custom Pipeline" / "Basic Pipeline" / "No Subtraction" panels of exploreDataCube
# show identical data. This is presumably a placeholder until the separate
# pipeline outputs are available.
dc.combineDataCube(dc)          # positions 0 and 1
dc.combineDataCube(dcSnapshot)  # position 2

dc.preloadDataCube()
dc.dataTable

Starting Combining Datacubes
Merging...
Replacing empty values...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Replacing empty values...
Appending paths...
Finished Combining Datacubes!
Starting loading data...
Copying...
Loading...
Finished loading data!


Unnamed: 0,sourceID,s2d,x1d
0,44404,"[<SlitModel(11, 428) from jw01345-o063_s44404_...",[<MultiSpecModel from jw01345-o063_s44404_nirs...
1,44419,"[<SlitModel(11, 440) from jw01345-o063_s44419_...",[<MultiSpecModel from jw01345-o063_s44419_nirs...


In [242]:
# Use the TkAgg GUI backend (separate window) so that the Slider and the left/right
# arrow-key handler created by exploreDataCube are interactive; a static inline figure would not respond.
%matplotlib tkagg

dc.exploreDataCube()