In [None]:
import os
from pathlib import Path

import yaml

# Root of the shared data area. Defaults to the original location, but can be
# overridden with the GLASS_DATA_ROOT environment variable so the notebook can
# run on machines where the data lives elsewhere.
root_dir = Path(os.environ.get("GLASS_DATA_ROOT", "/media/sharedData/data"))
out_base_dir = root_dir / "2024_08_16_A2744_v4" / "glass_niriss"
bin_data_dir = out_base_dir / "binned_data"
root_name = "glass-a2744"

# Ancillary per-filter information (presumably for the convolved images, going
# by the file name). Its keys are used as filter identifiers when building the
# filter throughput paths later in this notebook.
with open(out_base_dir / "conv_ancillary_data.yaml", "r", encoding="utf-8") as file:
    info_dict = yaml.safe_load(file)

In [None]:
from glass_niriss.isophotal import reproject_and_convolve

# Reference image passed as ``ref_path``; presumably it defines the pixel grid
# the segmentation map is reprojected onto.
ref_mosaic = (
    root_dir / "2023_11_07_spectral_orders" / "Prep" / "nis-wfss-ir_drz_sci.fits"
)

# The path of the original segmentation map
orig_seg = (
    out_base_dir.parent
    / "grizli_home"
    / "tests"
    / "Prep_tests"
    / "glass-a2744-ir_seg.fits"
)

# Output location; the file name passed to ``new_names`` is taken from this
# single path so the two can never disagree.
psf_matched_dir = out_base_dir / "PSF_matched_data"
repr_seg_path = psf_matched_dir / "glass-a2744_seg_map.fits"

# Set to True to force the reprojection to be redone even if the output exists
# (mirrors ``remake_atlas`` used for the Bagpipes atlas).
remake_seg = False

if remake_seg or not repr_seg_path.is_file():
    # order=0 interpolation is presumably nearest-neighbour, keeping the
    # integer segment IDs intact instead of blending neighbouring labels.
    reproject_and_convolve(
        ref_path=ref_mosaic,
        orig_images=orig_seg,
        psfs=None,
        psf_target=None,
        out_dir=psf_matched_dir,
        new_names=repr_seg_path.name,
        reproject_image_kw={"method": "interp", "order": 0, "compress": False},
    )

In [None]:
# Bagpipes working area: one sub-directory for the filter throughput curves
# and one for the pre-computed model atlases.
pipes_dir = out_base_dir / "sed_fitting" / "pipes"
filter_dir = pipes_dir / "filter_throughputs"
atlas_dir = pipes_dir / "atlases"

for sub_dir in (pipes_dir, filter_dir, atlas_dir):
    sub_dir.mkdir(exist_ok=True, parents=True)

# One throughput file per entry of the ancillary dictionary, named after its key.
filter_list = [str(filter_dir / f"{key}.txt") for key in info_dict]

In [None]:

# Objects analysed so far, mapped to the redshift adopted for each. Choose the
# target by setting ``obj_id`` below; ``obj_z`` is looked up from this table so
# the ID/redshift pair cannot get out of sync.
candidate_redshifts = {
    1761: 3.06,
    497: 0.3033,
    1597: 2.6724,
    3311: 1.34,
    2606: 0.296,
    2224: 0.3064,
    1742: 3.06,
    908: 0.3033,
    3278: 0.296,
    2328: 1.363,
    2720: 3.04,
    5021: 1.8868,
    3137: 0.9384,
    2928: 3.052,
    2074: 1.369,
}
# NOTE(review): object 732 was also listed, without a redshift of its own
# (0.296 was paired with 2606 here), and 497 appeared a second time with
# z = 0.30; confirm these before adding them.

obj_id = 2928
obj_z = candidate_redshifts[obj_id]

In [None]:
# Spatial binning configuration. Presumably: hexagonal bins when ``use_hex``
# is True (Voronoi otherwise, going by the "vorbin" label), the bin diameter,
# and the target S/N evaluated in ``sn_filter``.
use_hex = True
bin_diameter = 2
target_sn = 20
sn_filter = "jwst-nircam-f200w"

from glass_niriss.sed import bin_and_save

# The scheme label is computed separately because nesting the same quote type
# inside an f-string is only valid from Python 3.12 (PEP 701).
bin_scheme = "hexbin" if use_hex else "vorbin"
binned_name = f"{obj_id}_{bin_scheme}_{bin_diameter}_{target_sn}"
binned_data_path = bin_data_dir / f"{binned_name}_data.fits"

# Binning is only redone when its output file is missing.
if not binned_data_path.is_file():
    bin_and_save(
        obj_id=obj_id,
        out_dir=bin_data_dir,
        seg_map=repr_seg_path,
        info_dict=info_dict,
        sn_filter=sn_filter,
        target_sn=target_sn,
        bin_diameter=bin_diameter,
        use_hex=use_hex,
        overwrite=True,
    )

In [None]:
from glass_niriss.pipeline import generate_fit_params

# Fit instructions for the model atlas, generated for the selected object's
# redshift, and echoed for inspection.
bagpipes_atlas_params = generate_fit_params(obj_z=obj_z)
print(bagpipes_atlas_params)

In [None]:
from glass_niriss.sed import AtlasGenerator

# Number of model draws in the atlas. Kept as an int because it is a count;
# ``:.2E`` formatting still yields "1.00E+05", so existing atlas file names
# are unchanged.
n_samples = 100_000
n_cores = 16

# Set to True to regenerate the atlas even when the file already exists.
remake_atlas = False

# Single-quoted keys keep these f-strings valid before Python 3.12 (PEP 701).
run_name = (
    f"z_{bagpipes_atlas_params['redshift'][0]}_"
    f"{bagpipes_atlas_params['redshift'][1]}_"
    f"{n_samples:.2E}"
)
atlas_path = atlas_dir / f"{run_name}.hdf5"

if remake_atlas or not atlas_path.is_file():
    atlas_gen = AtlasGenerator(
        fit_instructions=bagpipes_atlas_params,
        filt_list=filter_list,
        phot_units="ergscma",
    )

    atlas_gen.gen_samples(n_samples=n_samples, parallel=n_cores)

    atlas_gen.write_samples(filepath=atlas_path)

In [None]:
import os
from functools import partial

import numpy as np
from astropy.table import Table

from glass_niriss.pipeline import load_photom_bagpipes
from glass_niriss.sed import AtlasFitter

# Work from inside the Bagpipes directory. NOTE(review): presumably the fitter
# resolves some paths relative to the current working directory; confirm
# before removing this.
os.chdir(pipes_dir)

# Photometry loader for the fitter, bound to the binned catalogue stored in
# the "PHOT_CAT" extension of the binned data file.
load_fn = partial(
    load_photom_bagpipes,
    phot_cat=binned_data_path,
    cat_hdu_index="PHOT_CAT",
)

fit = AtlasFitter(
    fit_instructions=bagpipes_atlas_params,
    atlas_path=atlas_path,
    out_path=pipes_dir.parent,
    overwrite=False,
)

# Every row of the binned catalogue is fitted, identified by its row index.
obs_table = Table.read(binned_data_path, hdu="PHOT_CAT")
cat_IDs = np.arange(len(obs_table))

# Wrapped in Path(...) so the join also works if ``fit.out_path`` is a str.
catalogue_out_path = fit.out_path / Path(f"{binned_name}_{run_name}.fits")

if catalogue_out_path.is_file():
    # A previous run already produced the catalogue: reuse it.
    fit.cat = Table.read(catalogue_out_path)
else:
    fit.fit_catalogue(
        IDs=cat_IDs,
        load_data=load_fn,
        spectrum_exists=False,
        make_plots=False,
        cat_filt_list=filter_list,
        run=f"{binned_name}_{run_name}",
        parallel=8,
    )
    print(fit.cat)

In [None]:
import matplotlib.pyplot as plt
import cmcrameri.cm as cmc
from astropy.io import fits

# Catalogue column mapped onto the bins. Other columns previously explored here:
# "continuity:massformed_50", "ssfr_50", "sfr_50", "continuity:metallicity_50",
# "mass_weighted_age_50", "dust:Av_50", "dust:eta", "nebular:logU_50",
# "redshift_50".
plot_column = "stellar_mass_50"

# Area of a single pixel, used to convert per-bin values into per-area values.
# NOTE(review): 0.04 is presumably the pixel scale in arcsec/pixel and 4.63 a
# kpc/arcsec conversion. The latter depends on redshift, so confirm it is
# appropriate for the selected obj_z (and whether lensing should be included).
pix_scale_arcsec = 0.04
kpc_per_arcsec = 4.63
pix_area = (pix_scale_arcsec * kpc_per_arcsec) ** 2

fig, axs = plt.subplots(1, 1)

seg_map = fits.getdata(binned_data_path, hdu="SEG_MAP")
print(len(np.unique(seg_map)))

plot_map = np.full_like(seg_map, np.nan, dtype=float)
for row in fit.cat:
    # Compute the bin mask once and reuse it for both the area and assignment.
    bin_mask = seg_map == int(row["#ID"])
    n_pix = np.count_nonzero(bin_mask)
    if n_pix == 0:
        # Bin absent from the map: nothing to paint, and skipping avoids log10(0).
        continue
    # The catalogue value is treated as logarithmic, so subtracting log10(area)
    # gives a per-area quantity.
    plot_map[bin_mask] = row[plot_column] - np.log10(n_pix * pix_area)

# Background pixels stay blank.
plot_map[seg_map == 0] = np.nan

im = axs.imshow(plot_map, origin="lower", cmap="rainbow")
axs.set_facecolor("k")
axs.set_title(f"Object {obj_id}")
axs.set_xlabel("x [pixel]")
axs.set_ylabel("y [pixel]")
plt.colorbar(im, label=f"{plot_column} - log10(area)")

plt.show()

In [None]:
# grizli_extraction_dir = root_dir / "2024_08_16_A2744_v4" / "grizli_home" / "Extractions"
# beams_path = [*grizli_extraction_dir.glob(f"*{obj_id}.beams.fits")]

# from grizli import jwst_utils
# import logging
# jwst_utils.QUIET_LEVEL = logging.WARNING
# jwst_utils.set_quiet_logging(jwst_utils.QUIET_LEVEL)


# if len(beams_path)==0:
#     from grizli import multifit, fitting
#     from grizli.pipeline import auto_script

#     os.chdir(grizli_extraction_dir)

#     flt_files = [str(s) for s in Path.cwd().glob("*GrismFLT.fits")][:]

#     grp = multifit.GroupFLT(
#         grism_files=flt_files,
#         catalog=f"{root_name}-ir.cat.fits",
#         cpu_count=-1,
#         sci_extn=1,
#         pad=800,
#     )

#     print("5. Extracting spectra...")
#     pline = {
#         "kernel": "square",
#         "pixfrac": 1.0,
#         "pixscale": 0.03,
#         "size": 50,
#         "wcs": None,
#     }
#     args = auto_script.generate_fit_params(
#         pline=pline,
#         field_root=root_name,
#         min_sens=0.0,
#         min_mask=0.0,
#         include_photometry=False,  # set both of these to True to include photometry in fitting
#         use_phot_obj=False,
#     )

#     # for id in obj["id"]):
#     # print(id)
#     # obj_id = 1597
#     # obj_z = 2.6724
#     # obj_id = 3311
#     # obj_z = 1.3397
#     # obj_id = 1761
#     # obj_z = 3.06
#     # obj_id = 886
#     # obj_z = 0.3033
#     # obj_id = 2308
#     # obj_z = 0.3033
#     beams = grp.get_beams(
#         obj_id,
#         # center_rd = (3.60940, -30.39839),
#         size=50,  # Make sure the size here is large enough to avoid the beam being cut off
#         min_mask=0,
#         min_sens=0,
#         show_exception=True,
#         beam_id="A",
#     )
#     # print (beams)
#     mb = multifit.MultiBeam(
#         beams, fcontam=0.2, min_sens=0.0, min_mask=0, group_name=root_name
#     )
#     mb.fit_trace_shift()
#     # _ = mb.oned_figure()
#     #     _ = mb.drizzle_grisms_and_PAs(size=32, scale=0.5, diff=False)
#     mb.write_master_fits()
#     from grizli import fitting

#     _ = fitting.run_all_parallel(
#         obj_id,
#         zr=[obj_z-0.02, obj_z + 0.02],
#         # zr = [2.5,3.5],
#         verbose=True,
#         get_output_data=True,
#         skip_complete=False,
#         save_figures=True,
#     )
#     print("5. Extracting spectra...[COMPLETE]")

In [None]:
# from glass_niriss.grism import RegionsMultiBeam
# import scipy
# print (scipy.__version__)


# grizli_extraction_dir = (
#     root_dir / "2024_08_16_A2744_v4" / "grizli_home" / "Extractions"
# )

# regions_out_dir = out_base_dir / "multiregion_grism" / f"{obj_id}"
# regions_out_dir.mkdir(exist_ok=True, parents=True)

# beams_path = [*grizli_extraction_dir.glob(f"*{obj_id}.beams.fits")]
# if len(beams_path) >= 1:
#     beams_path = beams_path[0]
# else:
#     raise IOError("Beams file does not exist.")

# multib = RegionsMultiBeam(
#     binned_data=binned_data_path,
#     pipes_dir=pipes_dir,
#     # f"bcg_{obj_id}_{bin_mode}_{bin_size}_{sn_target}_z_{obj_z}_{obj_z}_{atlas_size:.2E}",
#     run_name=f"{binned_name}_{run_name}",
#     beams=str(beams_path),
#     min_mask=0.0,
#     min_sens=0.0,
#     mask_resid=False,
#     verbose=False,
# )

# multib.fit_at_z(
#     z=obj_z,
#     n_samples=3,
#     veldisp=500,
#     oversamp_factor=3,
#     fit_stacks=True,
#     temp_dir=regions_out_dir.parent,
#     out_dir = regions_out_dir,
#     # direct_images=direct_images, poly_order=3
#     num_iters=30,
#     # force_iter=2,
#     cpu_count=16,
# )

# # 300 - 12752
# # 350 - 12754
# # 400 - 12756
# # 450 - 12760
# # 500 - 12765
# # 550 - 12770
# # 600 - 12770

In [None]:
# import os
# import multiprocessing as mp
# import time

# class Worker:
#     def __init__(self, data):
#         self.data = data

#     def initializer(self):
#         # setting up a database connection
#         print("{} with PID {} initialized".format(self, os.getpid()))

#     def __call__(self, value):
#         print("{} with PID {} called with value={}".format(self, os.getpid(), value))
#         # doing something with self.data and write it to the database
#         time.sleep(0.5)

# def worker_initializer(data):
#     global worker
#     worker = Worker(data)
#     worker.initializer()

# def worker_call(*args, **kwds):
#     return worker(*args, **kwds)

# def main():
#     print("main has PID {}".format(os.getpid()))
#     with mp.Pool(processes=2, initializer=worker_initializer, initargs=([1,2,3,4],)) as pool:
#         pool.map(worker_call, range(4))

# if __name__=="__main__":
#     main()