# GRB Tier-1 Simulation & CDS Exploration

**Goal:** simulate Tier-1 grid (12 GRB cases + background), then auto-generate:
- spectra (log–log),
- CDS 2D and 3D plots,
- HEALPix sky maps (+ zooms),
ready to paste into Google Slides.

This notebook uses the DC3 response/orientation and your SourceInjector workflow.


In [11]:
# Imports
import os
from pathlib import Path
import json
import csv
import numpy as np
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.coordinates import SkyCoord

from cosipy import SpacecraftFile, SourceInjector
from cosipy.util import fetch_wasabi_file
from histpy import Histogram
from threeML import Band, Model, PointSource

# Optional: HEALPix maps
try:
    import healpy as hp
    HAVE_HEALPY = True
except Exception:
    HAVE_HEALPY = False

import h5py  # for HDF5 traversal

# Directories
DATA_DIR = Path("/data01/grb_2").expanduser()
OUT_DIR  = DATA_DIR / "tier1_outputs"
FIG_DIR  = OUT_DIR / "figs"
for d in (DATA_DIR, OUT_DIR, FIG_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Response & Orientation (DC3 examples)
RESPONSE_H5 = DATA_DIR / (
    "SMEXv12.Continuum.HEALPixO3_10bins_log_flat."
    "binnedimaging.imagingresponse.nonsparse_nside8.area."
    "good_chunks_unzip.earthocc.h5"
)
ORI_FILE = DATA_DIR / "20280301_3_month_with_orbital_info.ori"

# Fetch once if missing
if not RESPONSE_H5.exists():
    zip_path = DATA_DIR / RESPONSE_H5.name.replace(".h5", ".zip")
    fetch_wasabi_file(
        "COSI-SMEX/DC2/Responses/"
        "SMEXv12.Continuum.HEALPixO3_10bins_log_flat."
        "binnedimaging.imagingresponse.nonsparse_nside8.area."
        "good_chunks_unzip.earthocc.zip",
        zip_path,
    )
    import shutil
    shutil.unpack_archive(zip_path, DATA_DIR)
    zip_path.unlink(missing_ok=True)

if not ORI_FILE.exists():
    fetch_wasabi_file(
        "COSI-SMEX/DC2/Data/Orientation/20280301_3_month_with_orbital_info.ori",
        ORI_FILE,
    )

print("Response:", RESPONSE_H5)
print("Orientation:", ORI_FILE)

Response: /data01/grb_2/SMEXv12.Continuum.HEALPixO3_10bins_log_flat.binnedimaging.imagingresponse.nonsparse_nside8.area.good_chunks_unzip.earthocc.h5
Orientation: /data01/grb_2/20280301_3_month_with_orbital_info.ori


🧩 2) Tier-1 Definitions (Spectra, Durations, Sky, Flux)

In [12]:
# Spectral sets (Band) — Tier-1 uses S1,S2 (long) and S4,S5 (short)
SPECTRA = {
    "S1": {"alpha": -1.1, "beta": -2.3, "xp_keV": 150.0},  # long, softer
    "S2": {"alpha": -1.0, "beta": -2.3, "xp_keV": 300.0},  # long, typical
    "S4": {"alpha": -0.6, "beta": -2.1, "xp_keV": 600.0},  # short, hard
    "S5": {"alpha": -0.8, "beta": -2.2, "xp_keV": 800.0},  # short, very hard
}

# Duration labels (metadata for filenames; injector uses orientation timeline)
DURATIONS = {
    "SHORT0p8": {"class": "short", "T90_s": 0.8},
    "LONG60":   {"class": "long",  "T90_s": 60.0},
}

# Sky positions (Galactic): plane vs high latitude
SKY = {
    "G1": {"l_deg": 0.0,  "b_deg": 0.0},   # plane (GC)
    "G5": {"l_deg": 45.0, "b_deg": +50.0}, # high-lat north
}

# Flux scale factors for Band.K at 100 keV (low/nominal/high)
FLUX = {"F0p1": 0.1, "F1": 1.0, "F10": 10.0}

# Base normalization and pivot (same as your earlier setup)
K_BASE = 7.56e-4 / (u.cm * u.cm * u.s * u.keV)
PIVOT_KEV = 100.0 * u.keV

🗂️ 3) Build & Save the Tier-1 Manifest (12 GRB cases)

In [13]:
manifest = []

# Short @ high-lat: S4 and S5 × {0.1, 1, 10}
for spec_id in ("S4", "S5"):
    for flux_id, scale in FLUX.items():
        manifest.append({
            "case_id":    f"SHORT_{spec_id}_{flux_id}",
            "duration_id":"SHORT0p8",
            "spec_id":    spec_id,
            "flux_id":    flux_id,
            "flux_scale": scale,
            "sky_id":     "G5",
        })

# Long @ plane: S1 and S2 × {0.1, 1, 10}
for spec_id in ("S1", "S2"):
    for flux_id, scale in FLUX.items():
        manifest.append({
            "case_id":    f"LONG_{spec_id}_{flux_id}",
            "duration_id":"LONG60",
            "spec_id":    spec_id,
            "flux_id":    flux_id,
            "flux_scale": scale,
            "sky_id":     "G1",
        })

print("Tier-1 cases:", len(manifest))
print(manifest[0])

# Save CSV/JSON for reproducibility
CSV_PATH = OUT_DIR / "tier1_manifest.csv"
JSON_PATH = OUT_DIR / "tier1_manifest.json"

with open(CSV_PATH, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(manifest[0].keys()))
    writer.writeheader()
    writer.writerows(manifest)

with open(JSON_PATH, "w") as f:
    json.dump(manifest, f, indent=2)

print("Saved:", CSV_PATH)
print("Saved:", JSON_PATH)

Tier-1 cases: 12
{'case_id': 'SHORT_S4_F0p1', 'duration_id': 'SHORT0p8', 'spec_id': 'S4', 'flux_id': 'F0p1', 'flux_scale': 0.1, 'sky_id': 'G5'}
Saved: /data01/grb_2/tier1_outputs/tier1_manifest.csv
Saved: /data01/grb_2/tier1_outputs/tier1_manifest.json


🛠️ 4) Helpers — Model Build, Filenames, Histogram utils, Plotting, HEALPix

In [14]:
# ---------- Model builders (FIXED: PointSource name is positional) ----------
def make_band_from_ids(spec_id: str, flux_scale: float) -> Band:
    s = SPECTRA[spec_id]
    band = Band()
    band.alpha.value = float(s["alpha"])
    band.beta.value  = float(s["beta"])
    band.xp.value    = float(s["xp_keV"])
    band.piv.value   = float(PIVOT_KEV.value)
    band.K.value     = float(K_BASE.value * flux_scale)
    band.xp.unit = u.keV
    band.piv.unit = u.keV
    band.K.unit  = 1/(u.cm*u.cm*u.s*u.keV)
    return band

def make_point_model(row: dict) -> Model:
    sky = SKY[row["sky_id"]]
    src_coord = SkyCoord(l=sky["l_deg"], b=sky["b_deg"], unit="deg", frame="galactic")
    band = make_band_from_ids(row["spec_id"], row["flux_scale"])
    # IMPORTANT: pass the source name as the FIRST POSITIONAL ARG
    ps = PointSource(row["case_id"], l=src_coord.l.deg, b=src_coord.b.deg, spectral_shape=band)
    return Model(ps)

# ---------- Filenames ----------
def out_stem_for(row: dict) -> str:
    dur  = row["duration_id"]
    spec = row["spec_id"]
    flux = row["flux_id"]
    sky  = SKY[row["sky_id"]]
    L = int(round(sky["l_deg"]))
    B = int(round(sky["b_deg"]))
    sign = "+" if B >= 0 else "-"
    return f"GRB_{dur}_{spec}_{flux}_L{L:03d}_B{sign}{abs(B):02d}"

def paths_for(row: dict):
    stem = out_stem_for(row)
    h5   = OUT_DIR / f"{stem}.h5"
    # figures
    fig_spec   = FIG_DIR / f"{stem}_spectrum.png"
    fig_cds2   = FIG_DIR / f"{stem}_cds2d.png"
    fig_cds3   = FIG_DIR / f"{stem}_cds3d.png"
    # sky products
    fits_sky   = OUT_DIR / f"{stem}_sky_map.fits"
    fig_moll   = FIG_DIR / f"{stem}_sky_moll.png"
    fig_gnom_g = FIG_DIR / f"{stem}_sky_gnom_gal.png"
    fig_gnom_eq= FIG_DIR / f"{stem}_sky_gnom_eq.png"
    return dict(
        h5=h5, fig_spec=fig_spec, fig_cds2=fig_cds2, fig_cds3=fig_cds3,
        fits_sky=fits_sky, fig_moll=fig_moll, fig_gnom_g=fig_gnom_g, fig_gnom_eq=fig_gnom_eq
    )

# ---------- Histogram utilities ----------
def project_energy(h: Histogram):
    for name in ("Em", "Energy", "E", "En", "Energies"):
        try: return h.project(name)
        except Exception: pass
    raise RuntimeError("No energy axis found (Em/Energy/E/En/Energies).")

def axis_label(a):
    for k in ("label", "title", "axis_name", "name"):
        if hasattr(a, k): return getattr(a, k)
    return str(a)

def axis_centers(ax):
    for attr in ("centers", "bin_centers", "ticks"):
        if hasattr(ax, attr):
            arr = np.asarray(getattr(ax, attr))
            if arr.size: return arr
    for edges in ("edges", "bins", "bin_edges"):
        if hasattr(ax, edges):
            e = np.asarray(getattr(ax, edges))
            if e.ndim == 1 and e.size >= 2:
                return 0.5 * (e[1:] + e[:-1])
    return None

def hist_values(hobj):
    for attr in ("values", "counts", "y", "z", "data"):
        if hasattr(hobj, attr):
            try:
                arr = np.asarray(getattr(hobj, attr))
                if arr.size: return arr
            except Exception: pass
    try:
        dense = hobj.todense()
        for attr in ("values", "counts", "y", "data"):
            if hasattr(dense, attr):
                arr = np.asarray(getattr(dense, attr))
                if arr.size: return arr
    except Exception: pass
    return None

def cds2d_get_axes(h: Histogram):
    labels = [axis_label(a) for a in h.axes]
    geom = [x for x in labels if x not in ("Em","Energy","E","En","Energies","Time")]
    if len(geom) < 2:
        raise RuntimeError("Could not identify two geometry axes for CDS.")
    return geom[0], geom[1]

# ---------- Plotting ----------
def plot_spectrum_png(hfile: Path, out_png: Path, title: str):
    h = Histogram.open(str(hfile))
    proj = project_energy(h)
    fig, ax = plt.subplots(figsize=(6.5,4.4))
    proj.draw(ax=ax, label=title)
    ax.set_xscale("log"); ax.set_yscale("log")
    ax.set_xlabel("Energy [keV]"); ax.set_ylabel("Counts")
    ax.set_title(title); ax.legend()
    fig.tight_layout(); fig.savefig(out_png, dpi=160); plt.close(fig)

def plot_cds2d_png(hfile: Path, out_png: Path, title: str):
    h = Histogram.open(str(hfile))
    axA, axB = cds2d_get_axes(h)
    try: h_noE = h.integrate("Em")
    except Exception:
        try: h_noE = h.marginalize("Em")
        except Exception: h_noE = h
    h2d = h_noE.project([axA, axB])

    fig, ax = plt.subplots(figsize=(6.3,5.0))
    drew = False
    try:
        h2d.draw(ax=ax); drew = True
    except Exception:
        pass
    if not drew:
        Xc = axis_centers(h2d.axes[0]); Yc = axis_centers(h2d.axes[1]); Z = hist_values(h2d)
        if (Xc is not None) and (Yc is not None) and (Z is not None) and (Z.ndim==2):
            im = ax.imshow(Z, origin="lower", aspect="auto",
                           extent=[Xc.min(), Xc.max(), Yc.min(), Yc.max()])
            fig.colorbar(im, ax=ax, label="Counts")
    ax.set_xlabel(axA); ax.set_ylabel(axB)
    ax.set_title(title)
    fig.tight_layout(); fig.savefig(out_png, dpi=160); plt.close(fig)

def plot_cds3d_png(hfile: Path, out_png: Path, title: str):
    h = Histogram.open(str(hfile))
    axA, axB = cds2d_get_axes(h)
    try: h_noE = h.integrate("Em")
    except Exception:
        try: h_noE = h.marginalize("Em")
        except Exception: h_noE = h
    h2d = h_noE.project([axA, axB])
    Xc = axis_centers(h2d.axes[0]); Yc = axis_centers(h2d.axes[1]); Z = hist_values(h2d)
    if (Xc is None) or (Yc is None) or (Z is None) or (Z.ndim != 2):
        print("[cds3d] Skipped (no centers/values)")
        return
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    Xg, Yg = np.meshgrid(Xc, Yc, indexing="ij")
    sx = max(1, Xg.shape[0]//128); sy = max(1, Yg.shape[1]//128)
    fig = plt.figure(figsize=(7.6,5.8))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(Xg[::sx,::sy], Yg[::sx,::sy], Z[::sx,::sy], linewidth=0, antialiased=False, shade=True)
    ax.set_xlabel(axA); ax.set_ylabel(axB); ax.set_zlabel("Counts")
    ax.set_title(title)
    fig.tight_layout(); fig.savefig(out_png, dpi=160); plt.close(fig)

# ---------- HEALPix reconstruction & figures ----------
def infer_nside(n_pix:int):
    ns = int(round((n_pix/12.0)**0.5))
    return ns if 12*ns*ns == n_pix else None

def hdf5_to_healpix(h5_path: Path):
    with h5py.File(str(h5_path), "r") as f:
        cands = []
        def _walk(g, p=""):
            for k,v in g.items():
                path = f"{p}/{k}" if p else k
                if isinstance(v, h5py.Dataset) and v.ndim==3:
                    cands.append((path, v.shape))
                elif isinstance(v, h5py.Group):
                    _walk(v, path)
        _walk(f)
        cands.sort(key=lambda x: np.prod(x[1]), reverse=True)
        if not cands: raise RuntimeError("No 3D dataset found in HDF5")
        path, shape = cands[0]
        cube = f[path][...].astype(float)

    sky_axis = None; nside = None
    for i,n in enumerate(shape):
        ns = infer_nside(n)
        if ns is not None: sky_axis, nside = i, ns; break
    if sky_axis is None: raise RuntimeError("No HEALPix axis detected in 3D cube")

    axes_to_sum = tuple(j for j in range(cube.ndim) if j != sky_axis)
    sky = np.nansum(cube, axis=axes_to_sum)
    sky[~np.isfinite(sky)] = 0.0
    return nside, sky.astype(np.float32)

def save_sky_figures(h5_path: Path, fits_out: Path, moll_png: Path, gnom_g_png: Path, gnom_eq_png: Path):
    if not HAVE_HEALPY:
        print("[sky] healpy not installed; skipped sky maps")
        return
    nside, sky = hdf5_to_healpix(h5_path)
    hp.write_map(str(fits_out), sky, coord='G', nest=False, dtype=np.float32, overwrite=True)

    idx = int(np.nanargmax(sky))
    lon_b, lat_b = hp.pix2ang(nside, idx, lonlat=True)

    # Mollweide
    plt.figure(figsize=(9.0,5.4))
    hp.mollview(sky, coord='G', title='All-sky (Galactic)', unit='counts',
                min=np.nanpercentile(sky, 15), max=np.nanpercentile(sky, 99.5), cmap='viridis')
    hp.graticule(); hp.projplot(lon_b, lat_b, lonlat=True, marker='*', color='r', markersize=12)
    plt.savefig(moll_png, dpi=160); plt.close()

    # Gnomonic around brightest (Galactic)
    plt.figure(figsize=(7.0,6.0))
    hp.gnomview(sky, rot=(lon_b, lat_b), coord='G', xsize=800, reso=5.0,
                title=f'Zoom (Galactic) l={lon_b:.2f}°, b={lat_b:.2f}°',
                unit='counts', min=np.nanpercentile(sky, 20), max=np.nanpercentile(sky, 99.7), cmap='viridis')
    hp.graticule(); hp.projplot(lon_b, lat_b, lonlat=True, marker='*', color='r', markersize=12)
    plt.savefig(gnom_g_png, dpi=160); plt.close()

    # Equatorial center (try coord transform; fallback if not supported)
    eq = SkyCoord(l=lon_b*u.deg, b=lat_b*u.deg, frame='galactic').icrs
    ra_c, dec_c = eq.ra.deg, eq.dec.deg
    ok = False
    try:
        plt.figure(figsize=(7.0,6.0))
        hp.gnomview(sky, rot=(ra_c, dec_c), coord=['G','C'], xsize=800, reso=5.0,
                    title=f'Zoom (Equatorial) RA={ra_c:.2f}°, Dec={dec_c:.2f}°',
                    unit='counts', min=np.nanpercentile(sky, 20), max=np.nanpercentile(sky, 99.7), cmap='plasma')
        hp.graticule(); hp.projplot(ra_c, dec_c, lonlat=True, marker='*', color='w', markersize=12)
        plt.savefig(gnom_eq_png, dpi=160); plt.close(); ok=True
    except TypeError:
        pass
    if not ok:
        rot = hp.Rotator(coord=['C','G'])
        l_c, b_c = rot(ra_c, dec_c, lonlat=True)
        plt.figure(figsize=(7.0,6.0))
        hp.gnomview(sky, rot=(l_c, b_c), coord='G', xsize=800, reso=5.0,
                    title=f'Zoom (Equatorial center) RA={ra_c:.2f}°, Dec={dec_c:.2f}°',
                    unit='counts', min=np.nanpercentile(sky, 20), max=np.nanpercentile(sky, 99.7), cmap='plasma')
        hp.graticule(); hp.projplot(l_c, b_c, lonlat=True, marker='*', color='w', markersize=12)
        plt.savefig(gnom_eq_png, dpi=160); plt.close()


🚀 5) Run: Inject All Tier-1 Cases + Save Plots & Sky Maps

In [15]:
# Load orientation & injector once
ORI = SpacecraftFile.parse_from_file(ORI_FILE)
INJECTOR = SourceInjector(response_path=RESPONSE_H5)

for row in manifest:
    paths = paths_for(row)

    # Build model and (re)create HDF5
    model = make_point_model(row)
    if paths["h5"].exists():
        paths["h5"].unlink()

    INJECTOR.inject_model(model=model, orientation=ORI, make_spectrum_plot=False, data_save_path=str(paths["h5"]))
    print("[OK] Injected:", paths["h5"].name)

    # Figures
    plot_spectrum_png(paths["h5"], paths["fig_spec"], title=f"{row['case_id']} — Spectrum")
    plot_cds2d_png(paths["h5"],   paths["fig_cds2"],  title=f"{row['case_id']} — CDS 2D")
    plot_cds3d_png(paths["h5"],   paths["fig_cds3"],  title=f"{row['case_id']} — CDS 3D")

    # Sky products (if healpy available)
    save_sky_figures(paths["h5"], paths["fits_sky"], paths["fig_moll"], paths["fig_gnom_g"], paths["fig_gnom_eq"])

print("All Tier-1 injections complete.")

[OK] Injected: GRB_SHORT0p8_S4_F0p1_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_SHORT0p8_S4_F1_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_SHORT0p8_S4_F10_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_SHORT0p8_S5_F0p1_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_SHORT0p8_S5_F1_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_SHORT0p8_S5_F10_L045_B+50.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_LONG60_S1_F0p1_L000_B+00.h5
[cds3d] Skipped (no centers/values)








[OK] Injected: GRB_LONG60_S1_F1_L000_B+00.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_LONG60_S1_F10_L000_B+00.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_LONG60_S2_F0p1_L000_B+00.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_LONG60_S2_F1_L000_B+00.h5
[cds3d] Skipped (no centers/values)
[OK] Injected: GRB_LONG60_S2_F10_L000_B+00.h5
[cds3d] Skipped (no centers/values)
All Tier-1 injections complete.


<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

🌌 6) Background-Only Injection (Near-Zero Flux)

In [16]:
# Find smallest allowed K to avoid parameter bound violations
_probe = Band()
K_MIN = getattr(_probe.K, "min_value", 1e-50)

# Build tiny-K model at high latitude G5
row_bg = {
    "case_id":    "BKG_ONLY",
    "duration_id":"LONG60",  # label only
    "spec_id":    "S1",      # alpha/beta used but K tiny
    "flux_id":    "Fmin",    # label only
    "flux_scale": 0.0,
    "sky_id":     "G5",
}

s = SPECTRA["S1"]
zero = Band()
zero.alpha.value = float(s["alpha"]); zero.beta.value = float(s["beta"])
zero.xp.value    = float(s["xp_keV"]); zero.piv.value = float(PIVOT_KEV.value)
zero.K.value     = float(K_MIN)
zero.xp.unit=u.keV; zero.piv.unit=u.keV; zero.K.unit=1/(u.cm*u.cm*u.s*u.keV)

sky = SKY[row_bg["sky_id"]]
src = SkyCoord(l=sky["l_deg"], b=sky["b_deg"], unit="deg", frame="galactic")

# IMPORTANT: PointSource name as positional argument
ps  = PointSource(row_bg["case_id"], l=src.l.deg, b=src.b.deg, spectral_shape=zero)
model_bg = Model(ps)

paths_bg = paths_for(row_bg)
if paths_bg["h5"].exists():
    paths_bg["h5"].unlink()

INJECTOR.inject_model(model=model_bg, orientation=ORI, make_spectrum_plot=False, data_save_path=str(paths_bg["h5"]))
print("[OK] Injected background-only:", paths_bg["h5"].name)

# Figures
plot_spectrum_png(paths_bg["h5"], paths_bg["fig_spec"], title="Background-only — Spectrum")
plot_cds2d_png(paths_bg["h5"],   paths_bg["fig_cds2"],  title="Background-only — CDS 2D")
plot_cds3d_png(paths_bg["h5"],   paths_bg["fig_cds3"],  title="Background-only — CDS 3D")

# Sky (if healpy)
save_sky_figures(paths_bg["h5"], paths_bg["fits_sky"], paths_bg["fig_moll"], paths_bg["fig_gnom_g"], paths_bg["fig_gnom_eq"])


[OK] Injected background-only: GRB_LONG60_S1_Fmin_L045_B+50.h5
[cds3d] Skipped (no centers/values)


<Figure size 900x540 with 0 Axes>

<Figure size 700x600 with 0 Axes>

<Figure size 700x600 with 0 Axes>

📊 7) Quick Overlays (Spectra) & CDS 2D Comparison Panel

In [18]:
# Overlay spectra for a given group (same spectrum & sky, any number of fluxes)
def spectra_overlay_group(rows: list, png_path: Path, title: str):
    fig, ax = plt.subplots(figsize=(7.2,4.6))
    n_plotted = 0
    for row in rows:
        p = paths_for(row)
        if not p["h5"].exists():
            print("[overlay] missing file:", p["h5"])
            continue
        h = Histogram.open(str(p["h5"]))
        try:
            proj = project_energy(h)
            proj.draw(ax=ax, label=row["case_id"])
            n_plotted += 1
        except Exception as e:
            print("[overlay]", row["case_id"], "->", e)
            continue
    if n_plotted == 0:
        print("[overlay] nothing to plot for", title)
        plt.close(fig); return
    ax.set_xscale("log"); ax.set_yscale("log")
    ax.set_xlabel("Energy [keV]"); ax.set_ylabel("Counts")
    ax.set_title(title); ax.legend(fontsize=8)
    fig.tight_layout(); fig.savefig(png_path, dpi=160); plt.close(fig)


# ---- Helpers to robustly get CDS 2D matrices ----

def hist_values_from_draw(h2d, nx=None, ny=None):
    """Last-resort: render with draw() and read the color array back."""
    fig, ax = plt.subplots()
    Z = None
    try:
        h2d.draw(ax=ax)
        for coll in ax.collections:
            if hasattr(coll, "get_array"):
                data = np.asarray(coll.get_array())
                if data.size:
                    Z = data
                    break
        # reshape if we know target grid
        if Z is not None and nx is not None and ny is not None:
            if Z.size == (nx-1)*(ny-1):
                Z = Z.reshape((nx-1, ny-1))
            elif Z.size == nx*ny:
                Z = Z.reshape((nx, ny))
        return Z
    except Exception:
        return None
    finally:
        plt.close(fig)

def cds2d_grid(hfile: Path):
    """Return (axA, axB, Xc, Yc, Z) with robust fallbacks."""
    h = Histogram.open(str(hfile))
    axA, axB = cds2d_get_axes(h)
    try:
        h_noE = h.integrate("Em")
    except Exception:
        try:
            h_noE = h.marginalize("Em")
        except Exception:
            h_noE = h
    h2d = h_noE.project([axA, axB])

    Xc = axis_centers(h2d.axes[0]); Yc = axis_centers(h2d.axes[1]); Z = hist_values(h2d)
    if Z is None and (Xc is not None) and (Yc is not None):
        Z = hist_values_from_draw(h2d, len(Xc), len(Yc))
    return axA, axB, Xc, Yc, Z


# ---- Compare 1–3 CDS maps with shared color scale (robust) ----

def cds2d_compare(rows: list, png_path: Path, title: str):
    """
    rows: list of 1–3 manifest rows (e.g., [low, nominal, high]).
    Computes shared vmin/vmax across whatever is available.
    """
    from matplotlib.gridspec import GridSpec
    if not rows:
        print("[cds2d_compare] no rows provided"); return

    # Build triples only for rows that opened successfully
    triples = []
    for r in rows:
        p = paths_for(r)
        if not p["h5"].exists():
            print("[cds2d_compare] missing file:", p["h5"]); continue
        try:
            axA, axB, Xc, Yc, Z = cds2d_grid(p["h5"])
            triples.append((r["case_id"], axA, axB, Xc, Yc, Z))
        except Exception as e:
            print("[cds2d_compare]", r["case_id"], "->", e)

    if not triples:
        print("[cds2d_compare] nothing to plot for", title); return

    # Collect Z arrays to set shared color scale
    Zlist = [t[5] for t in triples if t[5] is not None]
    if not Zlist:
        print("[cds2d_compare] could not recover Z for any case"); return

    # Flatten and compute robust vmin/vmax
    Zstack = np.hstack([z.ravel() for z in Zlist if z.size > 0])
    # guard against NaNs/Infs
    Zstack = Zstack[np.isfinite(Zstack)]
    if Zstack.size == 0:
        print("[cds2d_compare] Z arrays are non-finite"); return

    vmin = np.nanpercentile(Zstack, 15)
    vmax = np.nanpercentile(Zstack, 99.5)
    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
        vmin, vmax = float(np.nanmin(Zstack)), float(np.nanmax(Zstack))
        if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
            print("[cds2d_compare] invalid vmin/vmax"); return

    # Lay out 1–3 panels dynamically
    n = len(triples)
    fig = plt.figure(figsize=(6.0*n + 1.0, 4.8))
    gs = GridSpec(1, n, figure=fig, wspace=0.25)

    for i, (case, axA, axB, Xc, Yc, Z) in enumerate(triples):
        ax = fig.add_subplot(gs[0, i])
        if (Xc is not None) and (Yc is not None) and (Z is not None) and (Z.ndim == 2):
            im = ax.imshow(
                Z, origin="lower", aspect="auto",
                extent=[Xc.min(), Xc.max(), Yc.min(), Yc.max()],
                vmin=vmin, vmax=vmax
            )
            if i == n - 1:
                cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                cbar.set_label("Counts")
        else:
            ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        ax.set_title(case, fontsize=9)
        ax.set_xlabel(axA)
        if i == 0:
            ax.set_ylabel(axB)

    fig.suptitle(title)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(png_path, dpi=170)
    plt.close(fig)


# ---------- Example calls ----------
# Example 1: short@G5 S4 — all available fluxes in your manifest (works with 1–3)
rows_S4 = sorted(
    [r for r in manifest if (r["spec_id"]=="S4" and r["sky_id"]=="G5")],
    key=lambda r: ["F0p1","F1","F10"].index(r["flux_id"]) if r["flux_id"] in ["F0p1","F1","F10"] else 1
)
spectra_overlay_group(rows_S4, FIG_DIR/"overlay_SHORT_S4_G5.png", "Short S4 @ G5 — Flux overlay")
cds2d_compare(rows_S4, FIG_DIR/"cds_COMPARE_SHORT_S4_G5.png", "CDS 2D — Short S4 @ G5 (available fluxes)")

# Example 2: long@G1 S2 — all available fluxes in your manifest (works with 1–3)
rows_S2 = sorted(
    [r for r in manifest if (r["spec_id"]=="S2" and r["sky_id"]=="G1")],
    key=lambda r: ["F0p1","F1","F10"].index(r["flux_id"]) if r["flux_id"] in ["F0p1","F1","F10"] else 1
)
spectra_overlay_group(rows_S2, FIG_DIR/"overlay_LONG_S2_G1.png", "Long S2 @ G1 — Flux overlay")
cds2d_compare(rows_S2, FIG_DIR/"cds_COMPARE_LONG_S2_G1.png", "CDS 2D — Long S2 @ G1 (available fluxes)")






📦 8) (Optional) Zip Outputs for Sharing

In [19]:
import shutil
ZIP_PATH = OUT_DIR.with_suffix(".zip")
if ZIP_PATH.exists():
    ZIP_PATH.unlink()
shutil.make_archive(str(OUT_DIR), "zip", root_dir=OUT_DIR)
print("Zipped to:", ZIP_PATH)


Zipped to: /data01/grb_2/tier1_outputs.zip


📝 9) Notes (Markdown)

- **Fix applied**: `PointSource(<name>, l=..., b=..., spectral_shape=...)` — name passed positionally.
- Keep response & CDS binning constant across runs (inherent from the response file).
- 3D surfaces may skip if centers/values aren’t available (expected for some histpy variants).
- If `healpy` isn’t installed, sky maps are skipped — `pip install healpy` to enable.
- Figures land in `tier1_outputs/figs/`. Use overlay/compare PNGs directly in Slides.
