In [7]:
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
from glob import glob
import os
from matplotlib.colors import ListedColormap
from scipy.special import linestyle
from stdatamodels.jwst import datamodels as dm
from BNBG.Pipeline.BetterBackgroundSubtractStep import getDataWithMask, cleanupImage
from BNBG.utils import getSourcePosition
from matplotlib.patches import Rectangle
import matplotlib.patheffects as path_effects

In [74]:
class MultiStepObject:
	def __init__(self, folder):
		"""
		Acts as a container for MOS data at Stage 2 (cals), Stage 2 + BNBG (BNBG) and backgrounds (bkg), and Stage 3 (s2d).

		Parameters
		----------
		folder : str
			Path to the folder
		"""
		def get_sorted_files(folder, pattern):
			files = glob(os.path.join(folder, pattern))
			files.sort()
			return files

		cal_list = get_sorted_files(folder, '*nrs1_cal.fits') + get_sorted_files(folder, '*nrs2_cal.fits')
		print(cal_list)
		s2d2_list = get_sorted_files(folder, '*nrs1_s2d.fits') + get_sorted_files(folder, '*nrs2_s2d.fits')
		bkg_list = get_sorted_files(folder, '*nrs1_bkg-BNBG.fits') + get_sorted_files(folder, '*nrs2_bkg-BNBG.fits')
		BNBG_list = get_sorted_files(folder, '*nrs1_cal-BNBG.fits') + get_sorted_files(folder, '*nrs2_cal-BNBG.fits')

		# Assuming lists are of format
		# [001_nrs1, 002_nrs1, 003_nrs1, 001_nrs2, 002_nrs2, 003_nrs2]

		s2d_list = glob(os.path.join(folder, 'Final/*_s2d.fits'))
		print("Opening CALs")
		self.cal_list = [fits.open(_) for _ in cal_list]

		# list of lists, each containing an HDU representing wavelengths from cal
		self.wavelength_list = [[_ for _ in cal if _.name == "WAVELENGTH"] for cal in self.cal_list]

		# list of lists, each containing an HDU representing data from cal
		self.cal_list = [[_ for _ in cal if _.name == "SCI"] for cal in self.cal_list]

		# list of lists, each containing the source id of every slit
		# TODO : some duplicate sources, need to find another way to discriminate between slits
		self.cal_sources = [[_.header["sourceid"] for _ in cal] for cal in self.cal_list]

		print("Opening S2D2S")
		# list of Datamodels for each s2d from stage 2, dm.slits acts the same as a list of lists
		self.s2d2_list = [dm.open(_) for _ in s2d2_list]

		print("Opening BKGs")
		# list of lists, each containing an HDU representing the bkg of each cal
		self.bkg_list = [fits.open(_) for _ in bkg_list]
		self.bkg_list = [[_ for _ in bkg if _.name == "SCI"] for bkg in self.bkg_list]

		print("Opening BNBGs")
		# list of lists, each containing an HDU representing the background subtracted cal
		self.BNBG_list = [fits.open(_) for _ in BNBG_list]
		self.BNBG_list = [[_ for _ in BNBG if _.name == "SCI"] for BNBG in self.BNBG_list]

		print("Opening S2Ds")
		# list of s2d HDU from stage 3
		self.s2d_list = [fits.open(_) for _ in s2d_list]

		# list of corresponding source ids
		self.s2d_sources = [s2d[1].header["sourceid"] for s2d in self.s2d_list]

	def plot(self, directory):
		"""
		Iterates on every source and saves the plot for each.

		Parameters
		----------
		directory : str
			Path to the folder where the plots will be saved
		"""
		for source_id in self.s2d_sources:
			cal_list = self._get_hdus_by_source(source_id, self.cal_list)
			bkg_list = self._get_hdus_by_source(source_id, self.bkg_list)
			bnbg_list = self._get_hdus_by_source(source_id, self.BNBG_list)
			wave_list = self._get_hdus_by_source(source_id, self.wavelength_list)
			s2d_list = MultiStepObject._get_slit_by_source(source_id, self.s2d2_list)
			s2d = self.s2d_list[self.s2d_sources.index(source_id)]

			isnrs1 = not all(_ is None for _ in cal_list[:3])
			isnrs2 = not all(_ is None for _ in cal_list[3:])

			if isnrs1:
				fig = MultiStepObject._plot(cal_list[:3], bkg_list[:3], bnbg_list[:3], wave_list[:3], s2d_list[:3], s2d)
				fig.savefig(f"{directory}{source_id}-nrs1.png")
				plt.show()
			if isnrs2:
				fig = MultiStepObject._plot(cal_list[3:], bkg_list[3:], bnbg_list[3:], wave_list[3:], s2d_list[3:], s2d)
				fig.savefig(f"{directory}{source_id}-nrs2.png")
				plt.show()


	def _get_hdus_by_source(self, source_id, data_list):
		"""
		For a given source_id, will return a list the size of each list of lists of HDU,
		with None if no HDU corresponds to source_id within the file, or the first HDU found if there is.

		Parameters
		----------
		source_id : int
			The id of the source

		data_list :
			A list of lists of HDUs. Some HDUs may correpond to the same source_id, only the first one will be returned

		Returns
		-------
		result : list
			A list of length len(data_list), with each element being either an HDU or None
		"""
		# Hdus, sids are lists, one for each 00N_nrsM file (N=1,2,3, M=1,2)
		# Next will return the first element of all hdu that verify sid = source_id
		# If none is found, returns None
		return [
			next((hdu for hdu, sid in zip(hdus, sids) if sid == source_id), None)
			for hdus, sids in zip(data_list, self.cal_sources)
		]

	@staticmethod
	def _get_slit_by_source(source_id, slits_list):
		"""
		For a given source_id, will return a list the size of slits_list,
		with None if no slit corresponds to source_id within the slits, or the first slit found if there is.

		Parameters
		----------
		source_id : int
			The id of the source

		slits_list :
			A list of MultiSlitModels. Some slits may correpond to the same source_id, only the first one will be returned

		Returns
		-------
		result : list
			A list of length len(slits_list), with each element being either a slit or None
		"""
		return [next((slit for slit in multi_slit_model.slits if slit.source_id == source_id), None) for multi_slit_model in slits_list]

	@staticmethod
	def _binary_to_colormap(use_first_group=True):
		"""
		Create a colormap for 6-bit binary values based on additive color mixing.

		Parameters
		----------
		use_first_group : boolean
			If True, use bits 0,2,4; if False, use bits 1,3,5

		Returns
		-------

		"""
		colors = []
		for i in range(64):  # Iterate over all possible 6-bit values
			binary_str = f"{i:06b}"  # Convert number to 6-bit binary string
			if use_first_group:
				r, g, b = int(binary_str[0]), int(binary_str[2]), int(binary_str[4])
			else:
				r, g, b = int(binary_str[1]), int(binary_str[3]), int(binary_str[5])
			colors.append((r, g, b))  # Additive color mixing

		return ListedColormap(colors)

	@staticmethod
	def _plot(cal: list, bkg: list, bnbg: list, wave : list, s2d2: list, s2d):
		"""
		Creates a plot for a given sourceid and nrs.

		Parameters
		----------
		cal : list
			Should be of length 3, the cal HDUs corresponding to source_id

		bkg : list
			Should be of length 3, the bkg HDUs corresponding to source_id

		bnbg : list
			Should be of length 3, the background subtracted cals HDUs corresponding to source_id

		wave : list
			Should be of length 3, the wavelength HDUs corresponding to source_id

		s2d2 : list
			Should be of length 3, the s2d datamodels slits corresponding to source_id

		s2d : HDU
			A single HDU corresponding to the stage 3 s2d file for source_id

		Returns
		-------
		fig : matplotlib.figure.Figure
			The figure containing the plot

		"""
		# Get information from cal on source_id and nrs
		nrs = [_ for _ in s2d2 if not _ is None][0].meta.instrument.detector
		sourceid = [_ for _ in cal if not _ is None][0].header["SOURCEID"]
		cmap = NRS2_CMAP if nrs == "NRS2" else NRS2_CMAP

		# Create figure grid
		fig = plt.figure(figsize=(16, 12))
		gs = fig.add_gridspec(6, 4, height_ratios=[4, 4, 1.5, 1.5, 1.5, 12], hspace=0, wspace=0)

		# Block of 4x3 plots
		ax_s2d2 = [fig.add_subplot(gs[i+2, 0], xticks=[], yticks=[]) for i in range(3)]
		ax_cal = [fig.add_subplot(gs[i+2, 1], xticks=[], yticks=[]) for i in range(3)]
		ax_bkg = [fig.add_subplot(gs[i+2, 2], xticks=[], yticks=[]) for i in range(3)]
		ax_bnbg = [fig.add_subplot(gs[i+2, 3], xticks=[], yticks=[]) for i in range(3)]

		# 1 long bottom plot for the extracted background
		ax_spec = fig.add_subplot(gs[5, :])

		# 1 long top plot for the final stage 3 s2d
		ax_s2d = fig.add_subplot(gs[1, :], xticks=[], yticks=[])
		# The corresponding CON plot
		ax_con = fig.add_subplot(gs[0, :], xticks=[], yticks=[])

		def plot_cal(axs, imgs, masked=False):

			# Define same scale for every axs
			values = np.array([])
			for i in range(3):
				if imgs[i] is not None:
					values = np.append(values, imgs[i].data.ravel())
			values = values[np.isfinite(values)]
			z1, z2 = (0, 1) if len(values) == 0 else ZScaleInterval().get_limits(values)

			# Plot 3 nods
			for i in range(3):
				if imgs[i] is not None:
					# Can be called either if imgs is a list of HDUs or of slits
					img = imgs[i].data.copy()
					if not masked:
						axs[i].imshow(img, interpolation="none", origin="lower", vmin=z1, vmax=z2, aspect="auto", zorder=0)
					else :
						# If mask=True, will calculate a mask and plot the image twice, once in black and white,
						# and a second time in color and on top but with masked pixels
						axs[i].imshow(img, interpolation="none", origin="lower", vmin=z1, vmax=z2, aspect="auto", zorder=0, cmap="gray")

						# It is assumed here that imgs is a list of slits
						source = getSourcePosition(imgs[i])
						mask = cleanupImage(img,imgs[i].err.copy(),source=source)
						img[mask] = np.nan

						axs[i].imshow(img, interpolation="none", origin="lower", vmin=z1, vmax=z2, aspect="auto", zorder=1)


		plot_cal(ax_cal, cal)
		plot_cal(ax_bkg, bkg)
		plot_cal(ax_bnbg, bnbg)
		plot_cal(ax_s2d2, s2d2, masked=True)

		for ax, c, i in zip(ax_s2d2, ["r", "g", "b"], range(1,4)):
			ax.add_patch(Rectangle((0, 0), 4, 1, transform=ax.transAxes, color=c, alpha=0.3, zorder=-1, clip_on=False))
			text = ax.text(0.1, 0.3, f"00{i}", color="w", transform=ax.transAxes)
			text.set_path_effects([path_effects.Stroke(linewidth=3, foreground='black'), path_effects.Normal()])

		z1, z2 = 1e10, -1e10
		for i, c in enumerate([np.array([1,0,0,1]), np.array([0,1,0,1]), np.array([0,0,1,1])]):
			if s2d2[i] is not None:
				# Data
				Y, X = np.indices(s2d2[i].data.shape)
				_, _, dataLambda = s2d2[i].meta.wcs.transform("detector", "world", X, Y)
				x,y,dy = getDataWithMask(s2d2[i].data.copy(), s2d2[i].err.copy(), dataLambda.copy(), source=getSourcePosition(s2d2[i]))
				ax_spec.scatter(x, y, color=c, marker='+', alpha=0.6, label=f"00{i+1}")
				z1 = min(np.nanmin(y),z1)
				z2 = max(np.nanmax(y),z2)

				# Fit
				X = wave[i].data.ravel()
				Y = bkg[i].data.ravel()
				mask = (np.isfinite(X)) & (np.isfinite(Y))
				X, Y = X[mask], Y[mask]
				indices = np.argsort(X)
				X, Y = X[indices], Y[indices]

				color = np.clip(c+np.array([0.5, 0.5, 0.5, 1]), 0, 1)
				ax_spec.plot(X, Y, color=color)

		ax_spec.set_ylabel("Flux")
		ax_spec.set_xlabel(fr"$\lambda$ (µm)")

		# Hard coding the y limits to prevent the fit from scaling them
		dz = (z2 - z1) / 2
		z = (z1 + z2) / 2
		ax_spec.set_ylim(z-dz*1.05,z+dz*1.05)
		ax_spec.grid(True)
		ax_spec.legend()

		# Stage 3 s2d of source
		s2d_sci = s2d["SCI"].data
		z1, z2 = ZScaleInterval().get_limits(s2d_sci)
		ax_s2d.imshow(s2d_sci, interpolation="none", origin="lower", vmin=z1, vmax=z2, aspect="auto")

		# Context image for stage 3 s2d
		con_img = s2d["CON"].data[0, :, :].copy()
		ax_con.imshow(con_img, interpolation="none", origin="lower", cmap=cmap, vmin=0, vmax=63, aspect="auto")
		ax_con.set_title(f"{sourceid}-{nrs}")

		print(np.unique(con_img))

		for ax in ax_s2d2 + ax_cal + ax_bkg + ax_bnbg + [ax_s2d] + [ax_con]:
			ax.set_axis_off()

		return fig

NRS2_CMAP = MultiStepObject._binary_to_colormap(use_first_group=True)
NRS1_CMAP = MultiStepObject._binary_to_colormap(use_first_group=False)

In [70]:
MSO = MultiStepObject("/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA")

['/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00001_nrs1_cal.fits', '/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00002_nrs1_cal.fits', '/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00003_nrs1_cal.fits', '/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00001_nrs2_cal.fits', '/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00002_nrs2_cal.fits', '/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/jw01345063001_03101_00003_nrs2_cal.fits']
Opening CALs
Opening S2D2S
Opening BKGs
Opening BNBGs
Opening S2Ds


In [40]:
%matplotlib inline
plt.close("all")
MSO.plot("/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA/SlitAnalysis")

TypeError: MultiStepObject._plot() takes 5 positional arguments but 8 were given

In [75]:
%matplotlib Qt5Agg

source_id = 4402

cal_list = MSO._get_hdus_by_source(source_id, MSO.cal_list)
bkg_list = MSO._get_hdus_by_source(source_id, MSO.bkg_list)
bnbg_list = MSO._get_hdus_by_source(source_id, MSO.BNBG_list)
wave_list = MSO._get_hdus_by_source(source_id, MSO.wavelength_list)
s2d_list = MultiStepObject._get_slit_by_source(source_id, MSO.s2d2_list)
s2d = MSO.s2d_list[MSO.s2d_sources.index(source_id)]

isnrs1 = not all(_ is None for _ in cal_list[:3])
isnrs2 = not all(_ is None for _ in cal_list[3:])

if isnrs1:
	fig = MultiStepObject._plot(cal_list[:3], bkg_list[:3], bnbg_list[:3], wave_list[:3], s2d_list[:3], s2d)
	plt.show()
if isnrs2:
	fig = MultiStepObject._plot(cal_list[3:], bkg_list[3:], bnbg_list[3:], wave_list[3:], s2d_list[3:], s2d)
	plt.show()

[0 1 2 3 4 5 6 7]


In [144]:
from functools import reduce
to_spec3_order = np.array([0, 3, 1, 4, 2, 5])

def get_sorted_files(folder, pattern):
	files = glob(os.path.join(folder, pattern))
	files.sort()
	return files

def index_in_px(i):
	return np.where([i & 2**n for n in range(6)])[0]

files = get_sorted_files("/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA", '*nrs1_cal.fits') + get_sorted_files("/home/tim-dewachter/Documents/Thèse/BetterNIRSpecBackground/mastDownload/JWST/CEERS-NIRSPEC-P5-PRISM-MSATA", '*nrs2_cal.fits')
files = [os.path.basename(_) for _ in files]
files = [_[22:-9] for _ in files]

for hdu in MSO.s2d_list:
	print(hdu[1].header["sourceid"])
	cal_list = MSO._get_hdus_by_source(hdu[1].header["sourceid"], MSO.cal_list)
	nrs = np.array([files[i] if not cal_list[i] is None else "********" for i, _ in enumerate(cal_list)])[to_spec3_order]
	print(nrs)
	for i in np.unique(hdu["CON"].data):
		print(f"{i:02} : {i:06b} - {index_in_px(i)} - {nrs[nrs != '********'][index_in_px(i)]}")
	bitor = reduce(lambda x, y: x | y, hdu["CON"].data.ravel())
	print(f"{bitor:02} : {bitor:06b} - {index_in_px(bitor)} - {nrs[nrs != '********'][index_in_px(bitor)]}")
	print("\n")

-119
['001_nrs1' '********' '********' '********' '********' '********']
00 : 000000 - [] - []
01 : 000001 - [0] - ['001_nrs1']
01 : 000001 - [0] - ['001_nrs1']


171
['001_nrs1' '********' '002_nrs1' '********' '003_nrs1' '********']
00 : 000000 - [] - []
01 : 000001 - [0] - ['001_nrs1']
02 : 000010 - [1] - ['002_nrs1']
03 : 000011 - [0 1] - ['001_nrs1' '002_nrs1']
04 : 000100 - [2] - ['003_nrs1']
05 : 000101 - [0 2] - ['001_nrs1' '003_nrs1']
06 : 000110 - [1 2] - ['002_nrs1' '003_nrs1']
07 : 000111 - [0 1 2] - ['001_nrs1' '002_nrs1' '003_nrs1']
07 : 000111 - [0 1 2] - ['001_nrs1' '002_nrs1' '003_nrs1']


-55
['001_nrs1' '001_nrs2' '********' '********' '********' '********']
00 : 000000 - [] - []
01 : 000001 - [0] - ['001_nrs1']
02 : 000010 - [1] - ['001_nrs2']
03 : 000011 - [0 1] - ['001_nrs1' '001_nrs2']


-65
['********' '001_nrs2' '********' '002_nrs2' '********' '********']
00 : 000000 - [] - []
01 : 000001 - [0] - ['001_nrs2']
02 : 000010 - [1] - ['002_nrs2']
03 : 000011 - [0 1

In [168]:
to_spec3_order = np.array([0, 2, 4, 1, 3, 5])
for hdu in MSO.s2d_list:
	print(hdu[1].header["sourceid"])

	cal_list = MSO._get_hdus_by_source(hdu[1].header["sourceid"], MSO.cal_list)
	index_used = np.array([to_spec3_order[i] for i in range(len(cal_list)) if not cal_list[i] is None])
	print(index_used)
	relative_to_absolute = {i: index_used[i] for i in range(len(index_used))}

	CON = hdu["CON"].data[0,:,:]
	absolute_CON = np.zeros_like(CON)
	for i in range(CON.shape[0]):
		for j in range(CON.shape[1]):
			rel_value = CON[i, j]
			# Convert to 6-bit binary string
			binary_str = format(rel_value, '06b')

			# Convert to absolute binary representation
			abs_binary = ['0'] * 6  # Start with all zeros
			for rel_idx, bit in enumerate(binary_str[::-1]):  # Reverse for LSB-first processing
				if bit == '1':
					abs_idx = relative_to_absolute[rel_idx]  # Map relative index to absolute index
					abs_binary[5 - abs_idx] = '1'  # Set the correct absolute position

			# Convert binary string back to integer
			absolute_CON[i, j] = int(''.join(abs_binary), 2)
	print(np.unique(CON))
	for i in np.unique(absolute_CON):
		print(f"{i:02} - {i:06b}")

	print("\n")

-119
[0]
[0 1]
00 - 000000
01 - 000001


171
[0 2 4]
[0 1 2 3 4 5 6 7]
00 - 000000
01 - 000001
04 - 000100
05 - 000101
16 - 010000
17 - 010001
20 - 010100
21 - 010101


-55
[0 1]
[0 1 2]
00 - 000000
01 - 000001
02 - 000010


-65
[1 3]
[0 1 2 3]
00 - 000000
02 - 000010
08 - 001000
10 - 001010


4611
[0 2 4 1 3 5]
[ 0  1  2  4  5  8 10 16 17 20 21 32 34 40 42]
00 - 000000
01 - 000001
02 - 000010
04 - 000100
06 - 000110
08 - 001000
09 - 001001
16 - 010000
17 - 010001
24 - 011000
25 - 011001
32 - 100000
34 - 100010
36 - 100100
38 - 100110


120
[3 5]
[0 1 2]
00 - 000000
08 - 001000
32 - 100000


-116
[0 2 4]
[0 1 2 3 4 5 6 7]
00 - 000000
01 - 000001
04 - 000100
05 - 000101
16 - 010000
17 - 010001
20 - 010100
21 - 010101


23275
[1 3 5]
[0 1 2 3 4 5 6 7]
00 - 000000
02 - 000010
08 - 001000
10 - 001010
32 - 100000
34 - 100010
40 - 101000
42 - 101010


-91
[0]
[0 1]
00 - 000000
01 - 000001


2269
[0 2 4]
[0 1 2 3 4 5 6 7]
00 - 000000
01 - 000001
04 - 000100
05 - 000101
16 - 010000
17 - 010001