In [7]:
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
from matplotlib.patches import Rectangle
from matplotlib.widgets import Slider
from matplotlib.lines import Line2D
import astropy.units as u
from matplotlib.image import AxesImage
from matplotlib.text import Text
import matplotlib.patheffects as path_effects
import os
from astropy.io import fits
from glob import glob
import pandas as pd
import numpy as np
from astropy.table import Table

In [8]:
class DataCube:
	"""
	An object combining multiple 1D and 2D spectra from a given folder.

	Parameters
	----------
	folder : str
		Path to the folder containing the final data. Every file whose name
		matches ``*_s2d*`` is kept and paired with its ``_x1d`` counterpart
		(same file name with ``_s2d`` replaced by ``_x1d``) when that exists.
	keyword : str
		A keyword used for identifying what type of data we're looking at
		(e.g. "P5-Basic"); it is shown as the legend in exploreDataCube.

	Attributes
	----------
	table : pandas.DataFrame
		A table used for making a correspondence between a source id, its s2d
		files and its x1d files. It has a ``sourceID`` column (the SOURCEID
		keyword of HDU 1 of each s2d file) and list-valued ``s2d``, ``x1d`` and
		``keyword`` columns whose entries are aligned: position ``i`` of each
		list describes the same s2d/x1d pair. Missing x1d files are stored as None.
	dataTable : pandas.DataFrame
		Same layout as ``table``. It holds the paths until preloadDataCube()
		replaces them with the loaded data and adds an ``extract`` column.
	"""

	# List-valued columns of `table`, concatenated by combineDataCube.
	_LIST_COLUMNS = ("s2d", "x1d", "keyword")

	def __init__(self, folder, keyword):
		s2dList = sorted(glob(os.path.join(folder, '*_s2d*')))

		if s2dList:
			x1dList = [self._matchingX1d(path) for path in s2dList]
			# getheader closes each file after reading it (a bare fits.open(...)[1]
			# would leave one file handle open per s2d file).
			sourceList = [fits.getheader(path, 1)["SOURCEID"] for path in s2dList]

			flat = pd.DataFrame({"sourceID": sourceList, "s2d": s2dList, "x1d": x1dList, "keyword": keyword})
			# One row per source with sourceID as a regular column: the same layout
			# combineDataCube produces and exploreDataCube reads.
			self.table = flat.groupby("sourceID", sort=False).agg(list).reset_index()
		else:
			# Empty folder: keep the expected columns so the cube can still be combined.
			self.table = pd.DataFrame(columns=["sourceID", *self._LIST_COLUMNS])

		# Initializes the dataframe which will contain the data models
		self.dataTable = self.table.copy()

	@staticmethod
	def _matchingX1d(s2dPath):
		"""Returns the x1d path paired with ``s2dPath``, or None if that file does not exist."""
		# Only the file name is rewritten, so a directory containing "_s2d" is left intact.
		directory, name = os.path.split(s2dPath)
		candidate = os.path.join(directory, name.replace("_s2d", "_x1d"))
		return candidate if os.path.exists(candidate) else None

	@staticmethod
	def _withSourceColumn(table):
		"""Returns ``table`` with ``sourceID`` as a column (older layouts kept it as the index)."""
		if "sourceID" not in table.columns and table.index.name == "sourceID":
			return table.reset_index()
		return table

	@staticmethod
	def _asList(value):
		"""Normalises a merged cell: lists are kept, NaN (source absent on that side) becomes []."""
		if isinstance(value, list):
			return value
		if isinstance(value, tuple):
			return list(value)
		return [] if pd.isna(value) else [value]

	def combineDataCube(self, datacube):
		"""
		Combines another DataCube into this one, in place.

		Sources are matched on ``sourceID`` with an outer join, so sources present
		in only one cube are kept. For every source, the ``s2d``, ``x1d`` and
		``keyword`` lists of this cube are followed by those of ``datacube``;
		a cube that lacks the source contributes an empty list.

		Parameters
		----------
		datacube : DataCube or None
			Another datacube to be appended to this one. If None, a None
			placeholder is appended to every list of every source instead.

		Returns
		-------
		None
			``self.table`` is replaced. ``self.dataTable`` is NOT updated:
			call preloadDataCube() afterwards.
		"""
		print("Starting Combining Datacubes")
		left = self._withSourceColumn(self.table)

		if datacube is None:
			print("Appending placeholders...")
			self.table = left.assign(**{
				col: left[col].apply(lambda items: list(items) + [None])
				for col in self._LIST_COLUMNS
			})
			print("Finished Combining Datacubes!")
			return

		print("Merging...")
		# Both sides carry sourceID as a column, so this is a plain column-on-column join.
		right = self._withSourceColumn(datacube.table)
		merged = pd.merge(
			left, right, on="sourceID", how="outer", suffixes=("_self", "_other"), sort=False
		)

		print("Appending paths...")
		for col in self._LIST_COLUMNS:
			selfSide = merged[col + "_self"].apply(self._asList)
			otherSide = merged[col + "_other"].apply(self._asList)
			# Element-wise list concatenation: this cube's entries first, then the other's.
			merged[col] = selfSide + otherSide

		# Keep only necessary columns: 'sourceID', 's2d', 'x1d', 'keyword'
		self.table = merged[["sourceID", *self._LIST_COLUMNS]].reset_index(drop=True)

		print("Finished Combining Datacubes!")

	@staticmethod
	def _extractionBox(x1dPath):
		"""Returns (EXTRXSTR, EXTRXSTP, EXTRYSTR, EXTRYSTP) from HDU 1 of an x1d file, or None."""
		if not isinstance(x1dPath, str):
			return None
		header = fits.getheader(x1dPath, 1)
		return (header["EXTRXSTR"], header["EXTRXSTP"], header["EXTRYSTR"], header["EXTRYSTP"])

	def preloadDataCube(self):
		"""
		Initializes self.dataTable, a table structurally identical to self.table,
		except the paths are replaced by the corresponding data:

		- ``x1d``: astropy Tables read from HDU 1 of each x1d file (None if missing),
		- ``s2d``: open FITS HDULists (None if missing),
		- ``extract``: a new column of (EXTRXSTR, EXTRXSTP, EXTRYSTR, EXTRYSTP)
		  tuples read from the HDU 1 header of each x1d file (None if missing).

		The x1d FLUX and FLUX_ERROR columns are then rescaled according to the
		BUNIT of the matching s2d HDU 1. Must be re-run whenever self.table is
		combined or filtered, since self.dataTable is rebuilt from it.
		"""
		print("Starting loading data...")
		print("Copying...")
		self.dataTable = self.table.copy()

		print("Loading...")
		self.dataTable["x1d"] = self.dataTable["x1d"].apply(
			lambda fileList:
			[Table.read(file, 1) if isinstance(file, str) else None for file in fileList])
		# The HDULists stay open because exploreDataCube reads their data later.
		# NOTE(review): one open file handle per s2d file; large cubes may hit the OS open-file limit.
		self.dataTable["s2d"] = self.dataTable["s2d"].apply(
			lambda fileList:
			[fits.open(file) if isinstance(file, str) else None for file in fileList])

		print("Getting Extraction")
		self.dataTable["extract"] = self.table["x1d"].apply(
			lambda fileList: [self._extractionBox(file) for file in fileList]
		)

		print("Flux Correction...")
		# zip walks the aligned lists positionally, independent of the table's index labels.
		for s2ds, x1ds in zip(self.dataTable["s2d"], self.dataTable["x1d"]):
			for s2d, x1d in zip(s2ds, x1ds):
				if s2d is None or x1d is None:
					continue
				header = s2d[1].header
				toJy = 1
				if header["BUNIT"] == "MJy/sr":
					toJy = header["PIXAR_SR"] * 1e6
					header["BUNIT"] = "Jy"
				elif header["BUNIT"] == "MJy":
					toJy = 1e6
					header["BUNIT"] = "Jy"
				# NOTE(review): only the x1d FLUX/FLUX_ERROR values are rescaled; the s2d
				# pixel data keep their original scale although BUNIT is relabelled "Jy".
				# This presumes the x1d fluxes are expressed in the s2d BUNIT — confirm.
				x1d["FLUX"] *= toJy
				x1d["FLUX_ERROR"] *= toJy

		print("Finished loading data!")

	@staticmethod
	def exploreDataCube(dc, n=2):
		"""
		Opens an interactive viewer for a preloaded DataCube.

		The figure has ``n`` image panels (the first ``n`` s2d images of the
		current source, with the x1d extraction box drawn as a dotted rectangle)
		above one spectrum panel (the matching x1d FLUX and FLUX_ERROR where
		FLUX > 0, on a log scale).

		Controls: the "Source" slider or the left/right arrow keys change the
		source; moving the mouse over the spectrum panel draws a marker at the
		matching pixel column of each image panel.

		Parameters
		----------
		dc : DataCube
			A cube on which preloadDataCube() has been called.
		n : int, optional
			Number of s2d/x1d entries shown per source (default 2). Panels for
			entries a source does not have are left empty.

		Raises
		------
		ValueError
			If ``n < 1`` or the cube has no sources.
		RuntimeError
			If preloadDataCube() has not been called.
		"""
		if n < 1:
			raise ValueError("n must be at least 1")
		data = dc.dataTable
		if "extract" not in data.columns:
			raise RuntimeError("Call preloadDataCube() before exploreDataCube().")
		N = len(data)
		if N == 0:
			raise ValueError("The DataCube has no sources to explore.")
		sourceIDs = data["sourceID"] if "sourceID" in data.columns else data.index.to_series()

		def entry(column, row, i):
			"""Entry ``i`` of ``column`` for the row at position ``row``, or None if absent."""
			# Positional access keeps working after filtering leaves a non-0..N-1 index.
			values = data[column].iloc[row]
			return values[i] if i < len(values) else None

		fig, axes = plt.subplots(n + 1, 1, figsize=(18, 7), gridspec_kw={'height_ratios': [1] * n + [4 * n]})
		fig.subplots_adjust(left=0.1, bottom=0.15, right=0.9, top=0.9, hspace=0)
		specAx = axes[-1]

		idx = np.random.randint(N)
		colors = ["xkcd:teal", "xkcd:violet", "xkcd:hot pink", "xkcd:yellow"]

		# Vertical line for mouse tracking
		vline = Line2D([0, 0], [0, 1], color='r', linestyle='dashed', linewidth=0.5)
		specAx.add_line(vline)
		image_vlines = [axes[i].axvline(x=0, color='r', linestyle='dashed', linewidth=1) for i in range(n)]

		# Spectrum panel: one flux line and one thinner error line per entry
		spectrum_lines: list[Line2D] = [
			specAx.plot([], [], label=f"Legend {i}", color=colors[i % len(colors)])[0] for i in range(n)
		]
		error_lines = [specAx.plot([], [], color=colors[i % len(colors)], linewidth=0.7)[0] for i in range(n)]
		src_text = specAx.text(0.05, 0.05, "SourceID: 0", color="k", transform=specAx.transAxes, size=15)

		# Image panels: placeholder image, legend text and extraction rectangle
		img_artists: list[AxesImage] = []
		text_artist: list[Text] = []
		extr_rects: list[Rectangle] = []
		for i in range(n):
			img_artists.append(
				axes[i].imshow(np.zeros((1, 1)), aspect='auto', cmap="viridis", origin="lower", interpolation="none")
			)

			txt = axes[i].text(0.02, 0.3, f"legend {i}", color="w", transform=axes[i].transAxes)
			txt.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])
			text_artist.append(txt)

			rect = Rectangle((0, 0), 1, 1, edgecolor='r', facecolor='none', lw=1, linestyle='dotted')
			axes[i].add_patch(rect)
			extr_rects.append(rect)

		def update(val):
			"""Redraws every panel for the source at slider position ``val``."""
			row = int(round(val))
			y1, y2 = np.nan, np.nan

			for i in range(n):
				legend = entry("keyword", row, i)
				s2d = entry("s2d", row, i)
				x1d = entry("x1d", row, i)
				box = entry("extract", row, i)

				if s2d is not None:
					img = s2d[1].data
					img_artists[i].set_data(img)
					img_artists[i].set_extent((0, img.shape[1], 0, img.shape[0]))
					positive = img[img > 0]
					# ZScale needs at least one pixel; keep the previous limits otherwise.
					if positive.size:
						z1, z2 = ZScaleInterval().get_limits(positive)
						img_artists[i].set_clim(z1, z2)
				else:
					# Clear the panel instead of leaving the previous source's image.
					img_artists[i].set_data(np.zeros((1, 1)))
					img_artists[i].set_extent((0, 1, 0, 1))

				if s2d is not None and box is not None:
					xStart, xStop, yStart, yStop = box
					extr_rects[i].set_bounds(xStart, yStart, xStop - xStart, yStop - yStart)
					extr_rects[i].set_visible(True)
				else:
					extr_rects[i].set_visible(False)

				text_artist[i].set_text(legend)

				if x1d is not None:
					wavelength = x1d["WAVELENGTH"]
					flux = x1d["FLUX"]
					err = x1d["FLUX_ERROR"]
					mask = flux > 0  # the flux axis is log-scaled

					spectrum_lines[i].set_data(wavelength[mask], flux[mask])
					spectrum_lines[i].set_label(legend)
					error_lines[i].set_data(wavelength[mask], err[mask])

					y1, y2 = np.nanmin(np.append(flux[mask], y1)), np.nanmax(np.append(flux[mask], y2))
				else:
					spectrum_lines[i].set_data([], [])
					spectrum_lines[i].set_label("_nolegend_")
					error_lines[i].set_data([], [])

			src_text.set_text(f"SourceID: {sourceIDs.iloc[row]}")
			specAx.legend()
			specAx.relim()
			# y1/y2 stay NaN when no positive flux was plotted; set_ylim rejects NaN.
			if np.isfinite(y1) and np.isfinite(y2) and y1 < y2:
				specAx.set_ylim(y1, y2)
			fig.canvas.draw_idle()

		def onKey(event):
			"""Steps the slider with the left/right arrow keys, clamped to [0, N - 1]."""
			current = int(round(slider.val))
			if event.key == "right":
				slider.set_val(min(current + 1, N - 1))
			elif event.key == "left":
				slider.set_val(max(current - 1, 0))

		def on_move(event):
			"""Mirrors the cursor wavelength of the spectrum panel onto each image panel."""
			if event.inaxes is not specAx or event.xdata is None:
				return
			row = int(round(slider.val))
			cursor_wavelength = event.xdata

			vline.set_xdata([cursor_wavelength, cursor_wavelength])
			vline.set_ydata(specAx.get_ylim())
			for i in range(n):
				s2d = entry("s2d", row, i)
				if s2d is None:
					continue
				# NOTE(review): assumes HDU 3 of the s2d file is a per-pixel wavelength map
				# in the spectrum's x units (µm) — confirm against the product layout.
				wavelength_map = np.nanmean(s2d[3].data, axis=0)
				distance = np.abs(wavelength_map - cursor_wavelength)
				if np.all(np.isnan(distance)):
					continue
				# nanargmin: a plain argmin would return the first all-NaN column.
				pixel_pos = np.nanargmin(distance)
				image_vlines[i].set_xdata([pixel_pos, pixel_pos])

			fig.canvas.draw_idle()

		# Slider
		ax_slider = fig.add_axes((0.2, 0.05, 0.6, 0.03))
		slider = Slider(ax_slider, 'Source', 0, N - 1, valinit=idx, valstep=1)

		update(idx)
		slider.on_changed(update)

		# The canvas keeps strong references to these closures, which keeps the slider alive.
		fig.canvas.mpl_connect("key_press_event", onKey)
		fig.canvas.mpl_connect("motion_notify_event", on_move)

		specAx.set_xlim(0.5, 5.4)
		specAx.set_xlabel(r"$\lambda$ (µm)")
		specAx.set_ylabel(r"Flux (Jy)")
		specAx.set_yscale("log")
		specAx.legend()
		specAx.grid()

		# Show the plot
		plt.show(block=False)



In [9]:
# (folder, keyword) pairs, combined in this order; the keyword labels each
# spectrum in the explorer and is used below to select sources.
BASIC_ROOT = "../../../CAPERS/V0.1/CAPERS_UDS_V0.1"
BNBG_ROOT = "../../mastDownload/JWST/CAPERS"
CUBE_SOURCES = [
	(f"{BASIC_ROOT}/P5", "P5-Basic"),
	(f"{BASIC_ROOT}/P3", "P3-Basic"),
	(f"{BASIC_ROOT}/P2", "P2-Basic"),
	(f"{BASIC_ROOT}/P1", "P1-Basic"),
	(f"{BNBG_ROOT}/P5/Final", "P5-BNBG"),
	(f"{BNBG_ROOT}/P3/Final", "P3-BNBG"),
	(f"{BNBG_ROOT}/P2/Final", "P2-BNBG"),
	(f"{BNBG_ROOT}/P1/Final", "P1-BNBG"),
]

dc = DataCube(*CUBE_SOURCES[0])
for folder, keyword in CUBE_SOURCES[1:]:
	dc.combineDataCube(DataCube(folder, keyword))

# Keep only sources that have at least one "Basic" and one "BNBG" spectrum.
hasBoth = dc.table["keyword"].apply(
	lambda keywords: any(isinstance(kw, str) and "BNBG" in kw for kw in keywords)
	and any(isinstance(kw, str) and "Basic" in kw for kw in keywords)
)
dc.table = dc.table[hasBoth].reset_index(drop=True)

display(dc.table)

Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!
Starting Combining Datacubes
Merging...
Appending paths...
Finished Combining Datacubes!


Unnamed: 0,sourceID,s2d,x1d,keyword
0,299,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"
1,501,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"
2,1279,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"
3,1294,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"
4,1432,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"
...,...,...,...,...
665,154203,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P1/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P1/CAPER...,"[P1-Basic, P1-BNBG]"
666,157001,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P1/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P1/CAPER...,"[P1-Basic, P1-BNBG]"
667,157420,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P5/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P5/CAPER...,"[P5-Basic, P5-BNBG]"
668,159638,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,[../../../CAPERS/V0.1/CAPERS_UDS_V0.1/P3/CAPER...,"[P3-Basic, P3-BNBG]"


In [10]:
dc.preloadDataCube()

Starting loading data...
Copying...
Loading...
Getting Extraction
Flux Correction...
Finished loading data!


In [11]:
# exploreDataCube uses a Slider plus key-press and mouse-motion callbacks,
# so it needs an interactive window backend (here Qt5Agg).
%matplotlib Qt5Agg

In [12]:
# Close any open figures (e.g. a previous explorer window) before opening a new one.
plt.close("all")
DataCube.exploreDataCube(dc)