In [None]:
from pathlib import Path
import yaml
import os

# Base data directory. Defaults to the shared-drive location this notebook was
# written against; set the GLASS_DATA_ROOT environment variable to run elsewhere.
root_dir = Path(os.environ.get("GLASS_DATA_ROOT", Path("/media") / "sharedData" / "data"))
root_name = "glass-a2744"

# All products used below live under a single dated reduction directory.
release_dir = root_dir / "2024_08_16_A2744_v4"

catalogue_dir = release_dir / "glass_niriss" / "match_catalogues"

grizli_home_dir = release_dir / "grizli_home"
grizli_extraction_dir = grizli_home_dir / "Extractions"
# Later cells read and write per-object files relative to the working
# directory, so everything below assumes we are inside Extractions/.
os.chdir(grizli_extraction_dir)

In [None]:
import logging
import shutil  # used by the fitting cells to copy cached beams files

from grizli import jwst_utils, multifit

# Set grizli's JWST-utils quiet level to WARNING and apply it
# (presumably suppressing INFO/DEBUG pipeline chatter -- confirm).
jwst_utils.QUIET_LEVEL = logging.WARNING
jwst_utils.set_quiet_logging(jwst_utils.QUIET_LEVEL)

In [None]:
# Collect every GrismFLT in the extraction directory. Path.glob yields entries in
# arbitrary filesystem order, so sort to make the input order reproducible.
flt_files = sorted(str(s) for s in Path.cwd().glob("*GrismFLT.fits"))
if not flt_files:
    raise FileNotFoundError(f"No *GrismFLT.fits files found in {Path.cwd()}")

grp = multifit.GroupFLT(
    grism_files=flt_files,
    catalog=f"{root_name}-ir.cat.fits",
    cpu_count=-1,
    sci_extn=1,
    pad=800,
)

In [None]:
from astropy.io import fits
from astropy.table import Table, vstack, join
from astropy.coordinates import match_coordinates_sky, SkyCoord
from astropy.wcs import WCS
import astropy.visualization as astrovis
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np

plot = True

# Maximum separation for treating an ALT source as the same object as an
# ASTRODEEP source.
max_offset = 0.1 * u.arcsec
# Maximum separation for assigning a spec-z source to a grizli detection.
max_match_sep = 1 * u.arcsec

specz_cat_path = catalogue_dir / "ALT_ASTRODEEP_specz.fits"
matched_cat_path = catalogue_dir / "grizli_specz_matched.fits"

# --- Step 1: build (or load) a combined ASTRODEEP + ALT spectroscopic catalogue.
if not specz_cat_path.is_file():
    astrodeep_cat = Table.read(catalogue_dir / "ASTRODEEP-JWST-ABELL2744_photoz.fits")
    alt_cat = Table.read(catalogue_dir / "ALT_DR1_public.fits")

    astrodeep_coords = SkyCoord(
        ra=astrodeep_cat["RA"], dec=astrodeep_cat["DEC"], unit=u.deg
    )
    alt_coords = SkyCoord(ra=alt_cat["ra"], dec=alt_cat["dec"], unit=u.deg)

    # For each ALT source: index of, and separation to, the nearest ASTRODEEP source.
    idx, sep2d, _ = match_coordinates_sky(alt_coords, astrodeep_coords)

    astrodeep_cat.rename_columns(["ID", "RA", "DEC"], ["id_astrodeep", "ra", "dec"])
    alt_cat.rename_columns(["id", "z_ALT"], ["id_alt", "zspec"])

    # ASTRODEEP rows with zspec <= -99 carry no spectroscopic redshift and are
    # excluded from the combined catalogue.
    has_astrodeep_zspec = np.asarray(astrodeep_cat["zspec"] > -99)

    # Keep an ALT redshift unless its ASTRODEEP counterpart already contributes a
    # spec-z. Positional matching alone is not enough: an ALT source whose
    # counterpart has no ASTRODEEP spec-z would otherwise be lost entirely.
    use_idx = (sep2d > max_offset) | ~has_astrodeep_zspec[idx]

    combined_spec_cat = vstack(
        [astrodeep_cat[has_astrodeep_zspec], alt_cat[use_idx]]
    )
    combined_spec_cat.write(specz_cat_path)
else:
    combined_spec_cat = Table.read(specz_cat_path)

# --- Step 2: tag spec-z rows with grizli NUMBERs and join onto the grizli catalogue.
if not matched_cat_path.is_file():
    # grp_idx[i] indexes the grp.catalog detection nearest to spec-z row i;
    # grp_sep2d[i] is the corresponding angular separation.
    grp_idx, grp_sep2d = grp.catalog.match_to_catalog_sky(combined_spec_cat)
    grp_idx = np.asarray(grp_idx)

    # -99 marks spec-z rows without a grizli counterpart.
    combined_spec_cat["NUMBER"] = np.full(len(combined_spec_cat), -99, dtype=int)

    # Each grizli detection receives at most one spec-z row: the closest one
    # within max_match_sep (ties go to the earliest row). After a stable sort of
    # the candidates by separation, the first occurrence of each grizli index is
    # its best match, which np.unique(return_index=True) picks out.
    sep_arcsec = grp_sep2d.to_value(u.arcsec)
    candidates = np.flatnonzero(sep_arcsec < max_match_sep.to_value(u.arcsec))
    candidates = candidates[np.argsort(sep_arcsec[candidates], kind="stable")]
    _, first = np.unique(grp_idx[candidates], return_index=True)
    best = candidates[first]
    combined_spec_cat["NUMBER"][best] = grp.catalog["NUMBER"][grp_idx[best]]
    print(f"Matched {len(best)} spec-z sources to grizli detections.")

    grizli_cat = Table.read(f"{root_name}-ir.cat.fits")
    grizli_specz_cat = join(
        grizli_cat, combined_spec_cat, keys="NUMBER", join_type="left"
    )

    if plot:
        with fits.open(
            grizli_extraction_dir.parent / "Prep" / f"{root_name}-ir_drc_sci.fits"
        ) as hdul:
            # Copy out so the arrays remain usable after the file is closed.
            direct_img = hdul[0].data.copy()
            direct_wcs = WCS(hdul[0].header.copy())

        fig, ax = plt.subplots(dpi=600, subplot_kw={"projection": direct_wcs})
        ax.imshow(
            direct_img,
            norm=astrovis.ImageNormalize(
                direct_img,
                stretch=astrovis.LogStretch(),
                interval=astrovis.PercentileInterval(99.9),
            ),
            cmap="binary",
        )
        # Remember the image extent so the scatter overlay cannot rescale the axes.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ax.scatter(
            grizli_specz_cat["ra"],
            grizli_specz_cat["dec"],
            transform=ax.get_transform("world"),
            c=grizli_specz_cat["zspec"],
            s=1,
            # NOTE(review): colour range narrowed to 0.25-0.35, presumably to
            # highlight cluster members -- confirm.
            vmin=0.25,
            vmax=0.35,
            cmap="rainbow",
        )
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    grizli_specz_cat.write(matched_cat_path, overwrite=True)
else:
    grizli_specz_cat = Table.read(matched_cat_path)

In [None]:
# NOTE(review): this cell is entirely commented out and does nothing when run.
# It appears to be an earlier draft of the spec-z fitting loop in the next cell
# (same beam-size estimate, redshift window and dz grid), but it fits
# pre-computed beams copied from specz/beams and deletes/moves outputs by a loose
# "*{obj_id}.*" glob. Kept for reference only; consider deleting it.
# from grizli import fitting
# from grizli.pipeline import auto_script

# grizli_specz_cat.sort("FLUX_AUTO")

# specz_dir = Path.cwd() / "specz"
# max_size = 450
# pline = {
#     "kernel": "square",
#     "pixfrac": 1.0,
#     "pixscale": 0.06,
#     "size": 50,
#     "wcs": None,
# }
# args = auto_script.generate_fit_params(
#     pline=pline,
#     field_root=root_name,
#     min_sens=0.0,
#     min_mask=0.0,
#     include_photometry=False,  # set both of these to True to include photometry in fitting
#     use_phot_obj=False,
# )

# for row in grizli_specz_cat[:]:
#     if not np.isfinite(row["zspec"]):
#         continue
#     obj_id = row["NUMBER"]
#     obj_z = row["zspec"]

#     if (specz_dir / "stack" / f"{root_name}_{obj_id:05}.stack.fits").is_file():
#         print (f"{obj_id} exists already.")
#         continue
#     # if obj_id<450 or obj_id>470:
#     #     continue
#     # print (obj_id, obj_z, row["FLUX_AUTO"])
#     # continue
#     # Maximum diagonal extent of detection bounding box, measured from centre
#     # det_diag = np.sqrt((row["XMAX"] - row["XMIN"])**2 + (row["YMAX"] - row["YMIN"])**2)
#     det_halfdiag = np.sqrt(
#         (np.nanmax([row["XMAX"] - row["X"], row["X"] - row["XMIN"]])) ** 2
#         + (np.nanmax([row["YMAX"] - row["Y"], row["Y"] - row["YMIN"]])) ** 2
#     )

#     # pixel scale is half detection
#     # Include factor of 25% to account for blotting and pixelation effects
#     est_beam_size = int(np.nanmin([np.ceil(0.5*1.25*det_halfdiag), max_size]))
#     import shutil
#     try:
#         # print (f"Fetching beams for {obj_id}...")
#         # # beams = grp.get_beams(
#         # #     obj_id,
#         # #     size=est_beam_size, 
#         # #     min_mask=0,
#         # #     min_sens=0,
#         # #     beam_id="A",
#         # # )
#         # mb = multifit.MultiBeam(
#         #     f"beams/{root_name}_{obj_id:05}.beams.fits", fcontam=0.2, min_sens=0.0, min_mask=0, group_name=root_name
#         # )
#         # # mb.fit_trace_shift()
#         # # _ = mb.oned_figure()
#         # #     _ = mb.drizzle_grisms_and_PAs(size=32, scale=0.5, diff=False)
#         # # mb.write_master_fits()
#         shutil.copy2(f"specz/beams/{root_name}_{obj_id:05}.beams.fits", f"{root_name}_{obj_id:05}.beams.fits")
#         # # print (f"Saved beams for {obj_id}.")
#         print (f"Fitting {obj_id}...")

#         _ = fitting.run_all_parallel(
#             obj_id,
#             zr=[obj_z-0.005, obj_z + 0.005],
#             dz=[0.002,0.0002],
#             verbose=True,
#             get_output_data=True,
#             skip_complete=False,
#             save_figures=True,
#         )
#         print("Fit complete, output saved.")
#         [p.unlink() for p in Path.cwd().glob(f"*{obj_id}.beams.fits")]
#         [p.unlink() for p in Path.cwd().glob(f"*{obj_id}.full.png")]
#         [p.unlink() for p in Path.cwd().glob(f"*{obj_id}.log_par")]
#         [p.rename(specz_dir / "stack" / p.name) for p in Path.cwd().glob(f"*{obj_id}.*stack*")]
#     except:
#         print (f"Extraction failed for {obj_id}.")

In [None]:
from grizli import fitting
from grizli.pipeline import auto_script

# Fit every grizli object that has a matched spectroscopic redshift, restricting
# the redshift grid to a narrow window around the spec-z. Per-object products
# are moved into specz/<filetype>/. Objects that already have a full.fits there
# are skipped, so re-running the cell resumes where it stopped.

grizli_specz_cat.sort("FLUX_AUTO")

specz_dir = Path.cwd() / "specz"
output_filetypes = ["beams", "full", "1D", "row", "line", "log_par", "stack"]
for filetype in output_filetypes:
    (specz_dir / filetype).mkdir(exist_ok=True, parents=True)

max_size = 450  # upper limit on est_beam_size
zspec_half_window = 0.005  # fit z in [zspec - 0.005, zspec + 0.005]

for row in grizli_specz_cat:
    if not np.isfinite(row["zspec"]):
        continue
    obj_id = row["NUMBER"]
    obj_z = row["zspec"]
    # Per-object file stem, e.g. "glass-a2744_00123".
    obj_stem = f"{root_name}_{obj_id:05}"

    if (specz_dir / "full" / f"{obj_stem}.full.fits").is_file():
        continue

    # Distance from the object position to the farthest corner of its detection
    # bounding box.
    det_halfdiag = np.sqrt(
        np.nanmax([row["XMAX"] - row["X"], row["X"] - row["XMIN"]]) ** 2
        + np.nanmax([row["YMAX"] - row["Y"], row["Y"] - row["YMIN"]]) ** 2
    )
    # pixel scale is half detection
    # Include factor of 25% to account for blotting and pixelation effects
    est_beam_size = int(np.nanmin([np.ceil(0.5 * 1.25 * det_halfdiag), max_size]))

    pline = {
        "kernel": "square",
        "pixfrac": 1.0,
        "pixscale": 0.06,
        # Drizzle size scaled from the beam size and pixscale, clipped to [3, 30].
        "size": int(np.clip(2 * est_beam_size * 0.06, a_min=3, a_max=30)),
        "wcs": None,
    }
    # Regenerated per object because pline["size"] depends on the object.
    args = auto_script.generate_fit_params(
        pline=pline,
        field_root=root_name,
        min_sens=0.0,
        min_mask=0.0,
        include_photometry=False,  # set both of these to True to include photometry in fitting
        use_phot_obj=False,
    )

    beams_file = f"{obj_stem}.beams.fits"
    cached_beams = specz_dir / "beams" / beams_file
    try:
        if not cached_beams.is_file():
            print(f"Fetching beams for {obj_id}...")
            beams = grp.get_beams(
                obj_id,
                size=est_beam_size,
                min_mask=0,
                min_sens=0,
                beam_id="A",
            )
            mb = multifit.MultiBeam(
                beams, fcontam=0.2, min_sens=0.0, min_mask=0, group_name=root_name
            )
            mb.write_master_fits()
            print(f"Saved beams for {obj_id}.")
        else:
            # Reuse beams from an earlier run; the fit reads them from the cwd.
            shutil.copy2(cached_beams, beams_file)

        print(f"Fitting {obj_id}...")
        _ = fitting.run_all_parallel(
            obj_id,
            zr=[obj_z - zspec_half_window, obj_z + zspec_half_window],
            dz=[0.002, 0.0002],
            verbose=True,
            get_output_data=True,
            skip_complete=False,
            save_figures=True,
        )
        print("Fit complete, output saved.")

        for filetype in output_filetypes:
            # Anchor on the zero-padded stem: an unanchored f"*{obj_id}.*" glob
            # also matches other IDs ending in the same digits (e.g. 123 vs 1123).
            for p in Path.cwd().glob(f"{obj_stem}.*{filetype}*"):
                p.rename(specz_dir / filetype / p.name)
    except Exception as e:
        # Exception (not a bare except) so KeyboardInterrupt still stops the loop.
        print(f"Extraction failed for {obj_id}: {e}")

In [None]:
from astropy.io import fits
from astropy.table import Table, vstack, join
from astropy.coordinates import match_coordinates_sky, SkyCoord
from astropy.wcs import WCS
import astropy.visualization as astrovis
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np

# Diagnostic sky plot of the matched photo-z values (disabled by default).
plot = False

# Maximum separation for assigning an ASTRODEEP source to a grizli detection.
max_match_sep = 1 * u.arcsec

photz_cat_path = catalogue_dir / "grizli_photz_matched.fits"

# Check for the same file this cell writes, so re-runs load the cached result.
if not photz_cat_path.is_file():
    astrodeep_cat = Table.read(catalogue_dir / "ASTRODEEP-JWST-ABELL2744_photoz.fits")
    astrodeep_cat.rename_columns(["ID", "RA", "DEC"], ["id_astrodeep_zphot", "ra", "dec"])

    # Start from the spec-z matched grizli catalogue written by the spec-z cell.
    grizli_zspec_cat = Table.read(catalogue_dir / "grizli_specz_matched.fits")

    # grp_idx[i] indexes the grp.catalog detection nearest to ASTRODEEP row i;
    # grp_sep2d[i] is the corresponding angular separation.
    grp_idx, grp_sep2d = grp.catalog.match_to_catalog_sky(astrodeep_cat)
    grp_idx = np.asarray(grp_idx)

    # -99 marks ASTRODEEP rows without a grizli counterpart.
    astrodeep_cat["NUMBER"] = np.full(len(astrodeep_cat), -99, dtype=int)

    # Each grizli detection receives at most one ASTRODEEP row: the closest one
    # within max_match_sep (ties go to the earliest row). After a stable sort of
    # the candidates by separation, the first occurrence of each grizli index is
    # its best match, which np.unique(return_index=True) picks out.
    sep_arcsec = grp_sep2d.to_value(u.arcsec)
    candidates = np.flatnonzero(sep_arcsec < max_match_sep.to_value(u.arcsec))
    candidates = candidates[np.argsort(sep_arcsec[candidates], kind="stable")]
    _, first = np.unique(grp_idx[candidates], return_index=True)
    best = candidates[first]
    astrodeep_cat["NUMBER"][best] = grp.catalog["NUMBER"][grp_idx[best]]
    print(f"Matched {len(best)} ASTRODEEP sources to grizli detections.")

    # Replace any zphot carried over from the spec-z catalogue with the value
    # from the ASTRODEEP source matched here.
    if "zphot" in grizli_zspec_cat.colnames:
        grizli_zspec_cat.remove_column("zphot")
    grizli_zphot_cat = join(
        grizli_zspec_cat,
        astrodeep_cat["NUMBER", "id_astrodeep_zphot", "zphot"],
        keys="NUMBER",
        join_type="left",
    )

    if plot:
        with fits.open(
            grizli_extraction_dir.parent / "Prep" / f"{root_name}-ir_drc_sci.fits"
        ) as hdul:
            # Copy out so the arrays remain usable after the file is closed.
            direct_img = hdul[0].data.copy()
            direct_wcs = WCS(hdul[0].header.copy())

        fig, ax = plt.subplots(dpi=600, subplot_kw={"projection": direct_wcs})
        ax.imshow(
            direct_img,
            norm=astrovis.ImageNormalize(
                direct_img,
                stretch=astrovis.LogStretch(),
                interval=astrovis.PercentileInterval(99.9),
            ),
            cmap="binary",
        )
        # Remember the image extent so the scatter overlay cannot rescale the axes.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ax.scatter(
            grizli_zphot_cat["ra"],
            grizli_zphot_cat["dec"],
            transform=ax.get_transform("world"),
            c=grizli_zphot_cat["zphot"],
            s=1,
            # NOTE(review): colour range narrowed to 0.25-0.35, presumably to
            # highlight cluster members -- confirm.
            vmin=0.25,
            vmax=0.35,
            cmap="rainbow",
        )
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    grizli_zphot_cat.write(photz_cat_path, overwrite=True)
else:
    grizli_zphot_cat = Table.read(photz_cat_path)


In [None]:
# Imported here too so this cell does not depend on the spec-z fitting cell
# having been run first in the current kernel.
from grizli import fitting
from grizli.pipeline import auto_script

# Fit the remaining grizli objects, brightest (largest FLUX_AUTO) first, using a
# redshift window seeded by the matched ASTRODEEP photo-z. Objects that already
# have a full.fits under specz/ or photz/ are skipped, so re-running resumes.
# Per-object products are moved into photz/<filetype>/.

grizli_zphot_cat.sort("FLUX_AUTO", reverse=True)

photz_dir = Path.cwd() / "photz"
specz_full_dir = grizli_extraction_dir / "specz" / "full"
output_filetypes = ["beams", "full", "1D", "row", "line", "log_par", "stack"]
for filetype in output_filetypes:
    (photz_dir / filetype).mkdir(exist_ok=True, parents=True)

max_size = 450  # upper limit on est_beam_size
# Object NUMBERs excluded from photo-z fitting.
# NOTE(review): the reason 2409 is skipped is not recorded here -- document it.
skip_ids = {2409}
zphot_frac_window = 0.1  # fit z in [0.9, 1.1] x zphot when zphot is usable
fallback_zr = (0, 5.0)  # fit range when zphot is missing/NaN or non-positive

for row in grizli_zphot_cat:
    obj_id = row["NUMBER"]
    obj_z = row["zphot"]
    # Per-object file stem, e.g. "glass-a2744_00123".
    obj_stem = f"{root_name}_{obj_id:05}"
    full_name = f"{obj_stem}.full.fits"

    if (specz_full_dir / full_name).is_file() or (photz_dir / "full" / full_name).is_file():
        continue
    if obj_id in skip_ids:
        continue

    # Distance from the object position to the farthest corner of its detection
    # bounding box.
    det_halfdiag = np.sqrt(
        np.nanmax([row["XMAX"] - row["X"], row["X"] - row["XMIN"]]) ** 2
        + np.nanmax([row["YMAX"] - row["Y"], row["Y"] - row["YMIN"]]) ** 2
    )
    # pixel scale is half detection
    # Include factor of 25% to account for blotting and pixelation effects
    est_beam_size = int(np.nanmin([np.ceil(0.5 * 1.25 * det_halfdiag), max_size]))

    pline = {
        "kernel": "square",
        "pixfrac": 1.0,
        "pixscale": 0.06,
        # Drizzle size scaled from the beam size and pixscale, clipped to [3, 30].
        "size": int(np.clip(2 * est_beam_size * 0.06, a_min=3, a_max=30)),
        "wcs": None,
    }
    # Regenerated per object because pline["size"] depends on the object.
    args = auto_script.generate_fit_params(
        pline=pline,
        field_root=root_name,
        min_sens=0.0,
        min_mask=0.0,
        include_photometry=False,  # set both of these to True to include photometry in fitting
        use_phot_obj=False,
    )

    beams_file = f"{obj_stem}.beams.fits"
    cached_beams = photz_dir / "beams" / beams_file
    try:
        if not cached_beams.is_file():
            print(f"Fetching beams for {obj_id}...")
            beams = grp.get_beams(
                obj_id,
                size=est_beam_size,
                min_mask=0,
                min_sens=0,
                beam_id="A",
            )
            mb = multifit.MultiBeam(
                beams, fcontam=0.2, min_sens=0.0, min_mask=0, group_name=root_name
            )
            mb.write_master_fits()
            print(f"Saved beams for {obj_id}.")
        else:
            # Reuse beams from an earlier run; the fit reads them from the cwd.
            shutil.copy2(cached_beams, beams_file)

        print(f"Fitting {obj_id}...")

        # Only a finite, positive photo-z seeds a window; masked/NaN values and
        # non-positive sentinels (e.g. -99) would give an empty or inverted range.
        if np.isfinite(obj_z) and obj_z > 0:
            zr = [obj_z * (1 - zphot_frac_window), obj_z * (1 + zphot_frac_window)]
        else:
            zr = list(fallback_zr)

        _ = fitting.run_all_parallel(
            obj_id,
            zr=zr,
            dz=[0.01, 0.001],
            verbose=True,
            get_output_data=True,
            skip_complete=False,
            save_figures=True,
        )
        print("Fit complete, output saved.")

        for filetype in output_filetypes:
            # Anchor on the zero-padded stem: an unanchored f"*{obj_id}.*" glob
            # also matches other IDs ending in the same digits (e.g. 123 vs 1123).
            for p in Path.cwd().glob(f"{obj_stem}.*{filetype}*"):
                p.rename(photz_dir / filetype / p.name)
    except Exception as e:
        print(f"Extraction failed for {obj_id}: {e}")