In [11]:
import matplotlib.pyplot as plt
import numpy as np
import pickle

def plot_detected_lc_with_model(obs_record, lc_templates, title=None, show_peak=True):
    """
    Plot a single detected GRB light curve and overlay the synthetic template.
    
    Parameters
    ----------
    obs_record : dict
        One element from GRB_ObsDataLC_*.pkl.
    lc_templates : list of dicts
        The loaded template light curves from GRBAfterglow_templates.pkl.
    """
    mjd_obs = obs_record['mjd_obs']
    mag_obs = obs_record['mag_obs']
    snr_obs = obs_record['snr_obs']
    filters = obs_record['filter']
    detected = obs_record['detected']

    # Remove invalid values (those with mag > 40, usually placeholders)
    valid = mag_obs < 40
    mjd_obs = mjd_obs[valid]
    mag_obs = mag_obs[valid]
    snr_obs = snr_obs[valid]
    filters = filters[valid]
    detected = detected[valid]

    fig, ax = plt.subplots(figsize=(9, 5))
    
    filter_colors = {
        'u': 'violet', 'g': 'green', 'r': 'red',
        'i': 'orange', 'z': 'brown', 'y': 'gold'
    }

    # --- Plot observed points ---
    for f in np.unique(filters):
        mask = (filters == f)
        det_mask = mask & detected
        nondet_mask = mask & ~detected

        ax.scatter(mjd_obs[det_mask], mag_obs[det_mask],
                   label=f"{f}-band", color=filter_colors.get(f, 'gray'), s=50)
        ax.scatter(mjd_obs[nondet_mask], mag_obs[nondet_mask],
                   color=filter_colors.get(f, 'gray'), s=15, alpha=0.3, marker='x')

    # --- Overlay synthetic model ---
    lc_indx = obs_record['file_indx']
    t_peak = obs_record['peak_mjd']
    distance = obs_record['distance_Mpc']
    ebv = obs_record['ebv']

    # Distance modulus and extinction correction
    distmod = 5 * np.log10(distance * 1e6) - 5
    from rubin_sim.phot_utils import DustValues
    ax1 = DustValues().ax1

    t_model = lc_templates[lc_indx]['u']['ph']  # time grid (same for all filters)
    mjd_model = t_model + t_peak

    for f in lc_templates[lc_indx].keys():
        if f not in filter_colors:
            continue
        mag = lc_templates[lc_indx][f]['mag']
        mag_corr = mag + distmod + ax1[f] * ebv
        ax.plot(mjd_model, mag_corr, '-', color=filter_colors[f], alpha=0.4)

    # Peak marker
    if show_peak and 'peak_mjd' in obs_record and 'peak_mag' in obs_record and obs_record['peak_mag'] < 40:
        ax.plot(obs_record['peak_mjd'], obs_record['peak_mag'], 'k*', markersize=12, label='Peak')

    # Zoom in
    if detected.any():
        min_mjd = mjd_obs[detected].min() - 1
        max_mjd = mjd_obs[detected].max() + 3
        ax.set_xlim(min_mjd, max_mjd)

    # Axis labels
    ax.invert_yaxis()
    ax.set_xlabel("MJD")
    ax.set_ylabel("Apparent Magnitude")
    ax.grid(True, alpha=0.3)

    # Title
    if title is None:
        ra = obs_record.get('ra', 0.0)
        dec = np.degrees(obs_record.get('dec', 0.0))
        title = f"Detected GRB at RA={ra:.1f}°, Dec={dec:.1f}°"
    ax.set_title(title)
    ax.legend()
    plt.show()


In [12]:
# Load detected events
with open("AllTransient_MetricDetection/GRB_Afterglows/GRB_ObsDataLC_baseline_v4.3.1_10yrs.pkl", "rb") as f:
    detected_lcs = pickle.load(f)

# Load templates
with open("GRBAfterglow_templates.pkl", "rb") as f:
    lc_data = pickle.load(f)
    templates = lc_data['lightcurves']

# Plot
plot_detected_lc_with_model(detected_lcs[0], templates)
plot_detected_lc_with_model(detected_lcs[5], templates, title="Bright GRB example!")


IndexError: boolean index did not match indexed array along axis 0; size of axis is 809 but size of corresponding boolean axis is 57