In [2]:
import astropy.table as at
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from bayesn import SEDmodel
import pandas as pd
import jax.random as jr
from numpyro.infer import MCMC, NUTS

In [None]:
# Load the HEAD (one row per SN) and PHOT (photometry rows) tables for chunk 0000.
# The FITS files are opened in `with` blocks so the handles are closed as soon as
# the data has been copied into an astropy Table (Table() copies its input by default).
HEAD_PATH = "/global/cfs/cdirs/lsst/www/jolteon/data/FINAL2/JOLTEON_FINAL_0000_HEAD.FITS"
PHOT_PATH = "/global/cfs/cdirs/lsst/www/jolteon/data/FINAL2/JOLTEON_FINAL_0000_PHOT.FITS"

with fits.open(HEAD_PATH) as head_hdul:
    head_table = at.Table(head_hdul[1].data)
with fits.open(PHOT_PATH) as phot_hdul:
    phot_table = at.Table(phot_hdul[1].data)

In [None]:
#option 1:

# Walk the HEAD table and slice out each SN's photometry rows.
# NOTE: every name assigned in the loop is overwritten on each pass, so after the
# loop they describe the LAST supernova only (that is what the print below shows).
for sn in head_table:
    snid = sn['SNID']
    # Reset per row so a missing/non-positive SPECZ does not inherit the previous SN's value.
    redshift = None
    if 'SPECZ' in sn.colnames and sn['SPECZ'] > 0:
        redshift = sn['SPECZ']
    ptrmin = sn['PTRMIN']
    ptrmax = sn['PTRMAX']
    # The guard must test the same column that is read (it used to test 'evb_mnv',
    # so this branch could never run).
    ebv_mv = None
    if 'ebv_mv' in sn.colnames:
        ebv_mv = sn['ebv_mv']
    # PTRMIN/PTRMAX are treated as 1-based inclusive row pointers into phot_table,
    # hence the -1 on the start and the unmodified end of the half-open slice.
    start_idx = ptrmin - 1
    end_idx   = ptrmax
    lightcurve = phot_table[start_idx:end_idx]

print(snid, len(lightcurve))

# Survey band label -> filter name handed to the SED model.
filt_map = {'g': 'g_DES', 'i': 'i_DES', 'z': 'z_DES'}
model = SEDmodel(load_model='85_day_model.YAML')

def peak_mag_bayesn(lightcurve, z, filt_map, ebv_mw=1/31):
    """Fit one light curve with the module-level ``model`` and return its 'm_b' value.

    Parameters
    ----------
    lightcurve : table-like exposing ``to_pandas()``
        Must provide the columns MJD, FLUXCAL, FLUXCALERR and BAND.
    z : float
        Redshift forwarded to ``model.fit``.
    filt_map : dict
        Maps the BAND labels found in ``lightcurve`` to the filter names used by the model.
    ebv_mw : float, optional
        Milky Way E(B-V) forwarded to ``model.fit``.

    Returns
    -------
    The ``'m_b'`` entry of the second object returned by ``model.fit``, or ``None``
    when that entry is absent.

    Raises
    ------
    ValueError
        If no observation has a positive, finite flux in a band known to ``filt_map``.
    """
    df = lightcurve.to_pandas().rename(
        columns={'FLUXCAL': 'flux_c', 'FLUXCALERR': 'dflux_c', 'BAND': 'filt'}
    )
    df = df.dropna(subset=['MJD', 'flux_c', 'dflux_c', 'filt'])

    def _clean_band(label):
        # Band labels may arrive as bytes and/or space padded -- presumably a FITS
        # string-column artefact; harmless when they are already clean strings.
        if isinstance(label, bytes):
            label = label.decode()
        return label.strip() if isinstance(label, str) else label

    # Labels missing from filt_map come back as NaN here and are filtered out below.
    df = df.assign(filt=df['filt'].map(_clean_band).map(filt_map))

    # Only positive, finite fluxes have a magnitude. Zero flux would give m = -inf,
    # which is not NaN and would otherwise be picked as the "brightest" epoch.
    usable = (
        df['filt'].notna()
        & np.isfinite(df['flux_c'])
        & np.isfinite(df['dflux_c'])
        & (df['flux_c'] > 0)
    )
    df = df.loc[usable].copy()
    if df.empty:
        raise ValueError("no observations with positive flux in a mapped band")

    # Flux -> magnitude with the 27.5 zero point, and first-order error propagation.
    df['m'] = -2.5 * np.log10(df['flux_c']) + 27.5
    df['dm'] = np.abs(-2.5 * df['dflux_c'] / (np.log(10) * df['flux_c']))

    # Epoch of the brightest (numerically smallest magnitude) point, used as the peak-time guess.
    peak_mjd = df['MJD'].to_numpy()[int(np.argmin(df['m'].to_numpy()))]

    # NOTE(review): the filters were already translated through filt_map above and the
    # same map is forwarded again -- confirm against SEDmodel.fit that it is not re-applied.
    samples, sn_props = model.fit(
        df['MJD'], df['m'], df['dm'], df['filt'],
        z=z, peak_mjd=peak_mjd,
        ebv_mw=ebv_mw,
        filt_map=filt_map,
        mag=True
    )
    # NOTE(review): assumes sn_props is dict-like and carries 'm_b' -- confirm.
    peak_mag = sn_props.get('m_b', None)
    return peak_mag


sn_results = []  # one dict per fitted SN: SNID, redshift, peak_mag, type
n_without_z = 0  # SNe skipped because they have no positive SPECZ

for sn in head_table:
    snid = sn['SNID'] #to be changed

    # Reset per SN: without this a missing SPECZ silently reused the previous SN's
    # redshift (or raised NameError on the very first row).
    redshift = None
    if 'SPECZ' in sn.colnames and sn['SPECZ'] > 0: #to be changed
        redshift = float(sn['SPECZ'])
    if redshift is None:
        n_without_z += 1
        continue

    # PTRMIN/PTRMAX are treated as 1-based inclusive pointers into phot_table.
    start_idx = sn['PTRMIN'] - 1
    end_idx   = sn['PTRMAX']
    lightcurve = phot_table[start_idx:end_idx]

    try:
        # Pass the redshift read above. The call used to reference an undefined `z`,
        # which raised NameError inside this try and made every fit look "failed".
        peak_mag = peak_mag_bayesn(lightcurve, redshift, filt_map)
    except Exception as e:
        # Best effort: report the failure and move on to the next SN.
        print(f"Fit failed for SNID {snid}: {e}")
        continue

    sn_type = None #maybe?

    sn_results.append({
        "SNID": snid,
        "redshift": redshift,
        "peak_mag": float(peak_mag) if peak_mag is not None else None,
        "type": sn_type
    })

if n_without_z:
    print(f"Skipped {n_without_z} SNe without a positive SPECZ")

plt.hist([r['peak_mag'] for r in sn_results if r['peak_mag'] is not None], bins=30)
plt.xlabel('Peak magnitude (m_b)')
plt.ylabel('Number of SNe')
plt.show()

In [None]:
#option 2:

# Band label -> model filter name, and the SED model instance used by this option.
filt_map = dict(g='g_DES', i='i_DES', z='z_DES')
model = SEDmodel(load_model='85_day_model.YAML')

def _native_endian(column):
    """Return ``column`` as a plain ndarray in native byte order.

    FITS-backed columns can be big-endian, which pandas does not reliably handle;
    for data that is already native this is just a copy.
    """
    arr = np.asarray(column)
    return arr.astype(arr.dtype.newbyteorder('='))


def run_hierarchical_dust_fit(head_table, phot_table, filt_map,
                              num_warmup=1000, num_samples=2000, seed=0):
    """Run NUTS on ``model.dust_model`` and summarise the posterior per supernova.

    Parameters
    ----------
    head_table : table with ``colnames`` and column access by name
        Must contain SNID and SPECZ; ``ebv_mv`` is used for the Milky Way E(B-V)
        when present, otherwise 1/31 is used for every SN.
    phot_table : table
        Photometry, forwarded untouched to ``model.process_dataset``.
    filt_map : dict
        Its values are passed as the filter list.
    num_warmup, num_samples : int, optional
        MCMC warm-up and sample counts (defaults match the previous hard-coded values).
    seed : int, optional
        Seed for the JAX PRNG key.

    Returns
    -------
    pandas.DataFrame
        Columns SNID, m_b_med, AV_med, RV_med (medians over the sample axis) and z.
    """
    # Decide once whether the column exists instead of re-checking it for every row.
    if 'ebv_mv' in head_table.colnames:
        ebv_mw_array = np.asarray(head_table['ebv_mv'], dtype=float)
    else:
        ebv_mw_array = np.full(len(head_table), 1/31)

    # NOTE(review): confirm this call signature against the installed BayeSN version.
    obs, weights = model.process_dataset(
        head_table, phot_table,
        ebv_mw=ebv_mw_array,
        filter_list=list(filt_map.values())
    )

    kernel = NUTS(model.dust_model)
    mcmc   = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(jr.PRNGKey(seed), obs, weights)
    samples = mcmc.get_samples()

    # Collapse the leading (sample) axis to one median per SN.
    mb_med = np.median(samples['m_b'], axis=0)
    av_med = np.median(samples['AV'], axis=0)
    rv_med = np.median(samples['Rv'], axis=0)

    return pd.DataFrame({
        'SNID':    _native_endian(head_table['SNID']),
        'm_b_med': mb_med,
        'AV_med':  av_med,
        'RV_med':  rv_med,
        'z':       _native_endian(head_table['SPECZ'])
    })

results_df = run_hierarchical_dust_fit(head_table, phot_table, filt_map)
print(results_df.head())


plt.hist(results_df['m_b_med'], bins=30)
plt.xlabel('Peak magnitude (m_b)')
plt.ylabel('Number of SNe')
plt.show()


# Re-shape the fit summary into the list-of-dicts layout used by the binning and
# candidate-selection cells.
sn_results = []
for _, row in results_df.iterrows():
    specz = float(row["z"])
    sn_results.append({
        "SNID":     row["SNID"],
        # Same rule as the option-1 loop (SPECZ > 0): a non-positive or NaN value is
        # stored as None, which the downstream cells already skip.
        "redshift": specz if specz > 0 else None,
        "peak_mag": float(row["m_b_med"]),
        "type":     None
    })


In [None]:
# Group peak magnitudes into fixed-width redshift bins, keyed by each bin's lower edge.
bin_width = 0.05
binned_peaks = {}

for result in sn_results:
    z = result["redshift"]
    if z is not None:
        # NOTE(review): int() truncates a floating-point quotient, so a redshift sitting
        # exactly on a bin edge can land one bin low. The candidate-selection cell
        # recomputes the key with this same formula, so change both together.
        bin_index = int(z / bin_width)
        bin_z_lower = bin_index * bin_width
        if result["peak_mag"] is not None:
            if bin_z_lower not in binned_peaks:
                binned_peaks[bin_z_lower] = []
            binned_peaks[bin_z_lower].append(result["peak_mag"])


def check_bin(example_bins = [0.0, 0.25, 0.5, 1.0]):
    """Print how many SNe fell into each of the given redshift bins.

    ``example_bins`` holds bin lower edges; edges that are not keys of the
    module-level ``binned_peaks`` are skipped silently. The list is only read,
    never modified, and is returned unchanged.
    """
    for b in example_bins:
        if b in binned_peaks:
            print(f"Bin z=[{b:.2f}, {b+bin_width:.2f}): {len(binned_peaks[b])} SNe")
    return example_bins

In [None]:
# Median peak magnitude per redshift bin; bins holding no magnitudes are left out.
bin_median_peak = {}
for bin_z, peak_list in binned_peaks.items():
    if len(peak_list) != 0:
        median_mag = float(np.median(peak_list))
        bin_median_peak[bin_z] = median_mag
    

def check_median(n=5):
    """Print the median peak magnitude of the ``n`` lowest-redshift bins; return ``n``."""
    ordered_edges = sorted(bin_median_peak.keys())
    lowest_edges = ordered_edges[:n]
    for b in lowest_edges:
        print(f"Bin starting {b:.2f}: median peak mag = {bin_median_peak[b]:.2f}")
    return n

In [None]:
# Flag SNe whose peak is more than `threshold` magnitudes brighter (numerically
# smaller) than the median peak magnitude of their redshift bin.
candidates = []
threshold = 1.5 #ab mags

for result in sn_results:
    m_peak = result["peak_mag"]
    z = result["redshift"]
    if m_peak is None or z is None:
        continue
    # Same bin key as used when the bins were built, so the lookup below matches.
    bin_index = int(z / bin_width)
    bin_z_lower = bin_index * bin_width
    if bin_z_lower not in bin_median_peak:
        continue
    median_mag = bin_median_peak[bin_z_lower]

    # Use the names this cell actually defines: the comparison used to reference an
    # undefined `mag_threshold` and the append an undefined `lensed_candidates`.
    if m_peak < median_mag - threshold:
        diff = median_mag - m_peak
        candidate_info = {
            "SNID": result["SNID"],
            "redshift": result["redshift"],
            "peak_mag": result["peak_mag"],
            "type": result["type"],
            "mag_diff": diff
        }
        candidates.append(candidate_info)

In [None]:
# One summary line per flagged candidate; nothing is printed when the list is empty.
for c in (candidates or []):
    snid, z, peak = c["SNID"], c["redshift"], c["peak_mag"]
    mag_excess = c["mag_diff"]  # magnitudes brighter than the bin median
    sn_type = c.get("type")

    if sn_type is None:
        type_str = "Type=Unknown"
    else:
        type_str = f"Type={sn_type}"
    print(f"SNID {snid}: z={z:.3f}, PeakMag={peak:.2f}, {type_str}, Δm={mag_excess:.2f} mag brighter than median")
