In [None]:
import astropy
from astropy.table import Table
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from bayesn import SEDmodel
import pandas as pd
import jax.random as jr
from numpyro.infer import MCMC, NUTS
from tqdm import tqdm
import sncosmo
import emcee
import nestle
import corner
print(np.__version__)
from astropy.cosmology import FlatLambdaCDM
import astropy.units as u
from dustmaps.sfd import SFDQuery
from astropy.coordinates import SkyCoord

In [None]:
def obtain_data(head_table, idx, phot_table=None, return_peak_mjd=False):
    """Extract one supernova's metadata, light curve and Milky Way E(B-V).

    Parameters
    ----------
    head_table : table-like
        HEAD table; ``head_table[idx]`` must provide SNID, RA, DEC, PHOTOZ,
        PTROBS_MIN and PTROBS_MAX.
    idx : int
        Row index into ``head_table``.
    phot_table : table-like, optional
        PHOT table the PTROBS pointers index into. Defaults to the
        notebook-global ``phot_table`` (the original behaviour).
    return_peak_mjd : bool, optional
        If True, also return the MJD of the brightest ``SIM_MAGOBS`` epoch.

    Returns
    -------
    tuple
        ``(snid, ra, dec, z, lc, mw_ebv)``; with ``return_peak_mjd=True`` the
        peak MJD is appended as a seventh element.
    """
    if phot_table is None:
        phot_table = globals()['phot_table']

    row = head_table[idx]
    snid = row['SNID']
    ra = row['RA']
    dec = row['DEC']
    z = row['PHOTOZ']
    # PTROBS_MIN/MAX appear to be 1-based inclusive pointers (hence the -1).
    lc = phot_table[row['PTROBS_MIN'] - 1:row['PTROBS_MAX']]

    # Constructing SFDQuery loads the dust maps from disk; build it once and reuse.
    sfd_query = getattr(obtain_data, '_sfd_query', None)
    if sfd_query is None:
        sfd_query = SFDQuery()
        obtain_data._sfd_query = sfd_query
    coords_sn = SkyCoord(ra * u.deg, dec * u.deg, frame='icrs')
    mw_ebv = sfd_query(coords_sn)

    if not return_peak_mjd:
        return snid, ra, dec, z, lc, mw_ebv

    # np.asarray works for astropy Columns as well as pandas Series
    # (astropy Columns have no .iloc).
    mags = np.asarray(lc['SIM_MAGOBS'], dtype=float)
    mjds = np.asarray(lc['MJD'], dtype=float)
    peak_mjd = float(mjds[np.argmin(mags)])
    return snid, ra, dec, z, lc, mw_ebv, peak_mjd

In [None]:
class snia:
    """One simulated SN Ia: metadata, light curve, and fitting/plotting helpers.

    Parameters
    ----------
    snid, label : str
        Identifiers; ``label`` is shown in plot titles.
    ra, dec : float
        Sky position as stored in the HEAD table.
    z : float
        Redshift used by all fits and by ``calc_mu_exp``.
    lc : astropy Table or pandas DataFrame
        Light curve with at least MJD, BAND, FLUXCAL, FLUXCALERR and
        SIM_MAGOBS columns.
    mw_ebv : float
        Milky Way E(B-V) along the line of sight.
    lensed : bool, optional
    noted : bool, optional
        Flag for marking abnormal SNe.
    peak_mjd : float, optional
        Reference epoch for t0 and the fit window. If omitted, it is the MJD
        of the brightest ``SIM_MAGOBS`` point of ``lc``.
    """

    # Default plot colours keyed by raw band name.
    _PLOT_COLORS = {'u': '#9467bd', 'g': '#377eb8', 'r': '#4daf4a', 'i': '#e3c530', 'z': '#ff7f00'}
    # Raw band name -> sncosmo bandpass name.
    _SNCOSMO_FILTERS = {'g': 'lsstg', 'r': 'lsstr', 'i': 'lssti', 'z': 'lsstz', 'u': 'lsstu'}
    # Raw band name -> BayeSN filter name.
    _BAYESN_FILTERS = {'g': 'g_LSST', 'r': 'r_LSST', 'i': 'i_LSST', 'z': 'z_LSST', 'u': 'u_LSST'}
    # Zero point implied by m = -2.5 log10(FLUXCAL) + 27.5.
    _FLUXCAL_ZP = 27.5
    # Fit window relative to peak_mjd, in days (inclusive on both ends).
    _FIT_WINDOW = (-20.0, 50.0)

    def __init__(self, snid, label, ra, dec, z, lc, mw_ebv, lensed=False, noted=False, peak_mjd=None):
        self.snid = snid      # string
        self.label = label    # string
        self.ra = ra          # float
        self.dec = dec        # float
        self.z = z            # float
        self.lc = lc          # astropy Table (a pandas DataFrame also works)
        self.mw_ebv = mw_ebv  # float
        self.lensed = lensed  # boolean
        self.noted = noted    # boolean; for marking abnormal SNs
        # None means "derive from the light curve" (see the peak_mjd property).
        self._peak_mjd = peak_mjd

    @property
    def peak_mjd(self):
        """MJD of the brightest SIM_MAGOBS epoch, unless set explicitly."""
        if self._peak_mjd is not None:
            return self._peak_mjd
        mags = np.asarray(self.lc['SIM_MAGOBS'], dtype=float)
        mjds = np.asarray(self.lc['MJD'], dtype=float)
        return float(mjds[np.argmin(mags)])

    @peak_mjd.setter
    def peak_mjd(self, value):
        self._peak_mjd = value

    @staticmethod
    def _band_name(band):
        """Normalise a band label: decode bytes and strip padding whitespace."""
        if isinstance(band, bytes):
            band = band.decode()
        return band.strip() if isinstance(band, str) else band

    def _clean_lc(self, filt_map):
        """Return a cleaned pandas copy of ``self.lc`` ready for fitting.

        Maps band names through ``filt_map``, keeps rows with positive FLUXCAL
        and FLUXCALERR, adds magnitude columns ``m``/``dm``, drops non-finite
        rows and restricts to ``_FIT_WINDOW`` around ``peak_mjd``.
        """
        if hasattr(self.lc, 'to_pandas'):
            lc = self.lc.to_pandas()
        else:
            lc = pd.DataFrame(self.lc).copy()
        lc['BAND'] = [filt_map.get(b, b) for b in (self._band_name(x) for x in lc['BAND'])]

        # .copy() so the column assignments below act on a real frame, not a view.
        lc = lc[(lc['FLUXCAL'] > 0) & (lc['FLUXCALERR'] > 0)].copy()
        lc['m'] = -2.5 * np.log10(lc['FLUXCAL']) + self._FLUXCAL_ZP
        lc['dm'] = np.abs(-2.5 * lc['FLUXCALERR'] / (np.log(10) * lc['FLUXCAL']))
        lc = lc.dropna(subset=['m', 'dm', 'BAND'])
        usable = np.isfinite(lc['m']) & np.isfinite(lc['dm']) & (lc['BAND'] != '')
        lc = lc[usable]

        lo, hi = self._FIT_WINDOW
        rel_mjd = lc['MJD'] - self.peak_mjd
        return lc[(rel_mjd >= lo) & (rel_mjd <= hi)]

    def calc_mu_exp(self, H0_value=73.6, Om0_value=0.334):
        """Distance modulus at ``self.z`` for a flat LambdaCDM cosmology."""
        cosmo = FlatLambdaCDM(H0=H0_value * u.km / u.s / u.Mpc, Om0=Om0_value)
        return cosmo.distmod(self.z)

    def sn_plot(self, colors=None, markers=None):
        """Scatter-plot SIM_MAGOBS against MJD, coloured by band.

        ``markers`` is accepted for backward compatibility but is unused.
        """
        colors = self._PLOT_COLORS if colors is None else colors
        point_colors = [colors[self._band_name(b)] for b in self.lc['BAND']]

        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        ax.scatter(self.lc['MJD'], self.lc['SIM_MAGOBS'], c=point_colors, alpha=0.5)
        # bottom=37, top=18 already puts bright (small) magnitudes at the top;
        # do not also call invert_yaxis(), which would undo it.
        ax.set_ylim(37, 18)

        legend_patches = [mpatches.Patch(color=color, label=band.strip()) for band, color in colors.items()]
        ax.legend(handles=legend_patches, title="Bands", fontsize=14, title_fontsize=16, ncol=3)
        ax.set_xlabel("MJD", fontsize=18)
        ax.set_ylabel("Apparent Magnitude", fontsize=18)
        ax.set_title(f'JOLTEON ID: {self.label}', fontsize=24)
        fig.tight_layout()
        plt.show()

    @staticmethod
    def _x0_to_mb(x0, x0_err):
        """Convert SALT x0 and its 1-sigma error to m_B and its error (mag)."""
        # Offsets kept from the original notebook; calibration origin not documented here.
        mb = -2.5 * np.log10(x0) + 10.635 - 0.27
        # |d(mb)/d(x0)| = 2.5 / (x0 ln 10): propagate the x0 error into magnitudes.
        mb_err = np.abs(2.5 / (x0 * np.log(10.0))) * x0_err
        return mb, mb_err

    @staticmethod
    def _salt2_mu(mb, mberr, x1, x1err, c, cerr, alpha, beta, M, x0, z,
                  cov_x1_c=0.0, cov_x1_x0=0.0, cov_c_x0=0.0,
                  sigint=None, hostmass=None, deltam=None, peczerr=0.00083):
        """Tripp distance modulus ``mu = mb + alpha*x1 - beta*c - M`` and its error.

        ``mberr`` must already be in magnitudes (see ``_x0_to_mb``).
        """
        sf = -2.5 / (x0 * np.log(10.0))  # d(mb)/d(x0), converts x0 covariances to mb
        cov_mb_x1 = cov_x1_x0 * sf
        cov_mb_c = cov_c_x0 * sf
        mu = mb + alpha * x1 - beta * c - M
        var_fit = (mberr ** 2 + alpha ** 2 * x1err ** 2 + beta ** 2 * cerr ** 2
                   + 2.0 * alpha * cov_mb_x1 - 2.0 * beta * cov_mb_c
                   - 2.0 * alpha * beta * cov_x1_c)

        if deltam:
            # Mass step: +deltam/2 above 10, -deltam/2 below, unchanged at exactly 10.
            hostmass = np.asarray(hostmass)
            mu = mu + np.where(hostmass > 10, deltam / 2.0,
                               np.where(hostmass < 10, -deltam / 2.0, 0.0))

        # peczerr is presumably a peculiar-velocity redshift error; 0.055*z looks
        # like a lensing dispersion term -- both kept from the original notebook.
        zerr = peczerr * 5.0 / np.log(10) * (1.0 + z) / (z * (1.0 + z / 2.0))
        muerr = np.sqrt(var_fit + zerr ** 2 + 0.055 ** 2 * z ** 2)
        if sigint:
            muerr = np.sqrt(muerr ** 2 + sigint ** 2)
        return mu, muerr

    def sncosmo_fit(self, alpha=0.14, beta=3.1, M0=-19.36, mw_rv=3.1, filt_map=None):
        """Fit SALT3 (with fixed Milky Way dust) by nested sampling and derive mu.

        Returns
        -------
        result, fitted_model, mu_obs, mu_obs_err

        Raises
        ------
        ValueError
            If no observations survive cleaning.
        """
        filt_map = self._SNCOSMO_FILTERS if filt_map is None else filt_map
        lcs = (self._clean_lc(filt_map)
               .rename(columns={'FLUXCAL': 'FLUX', 'FLUXCALERR': 'FLUXERR'})
               # sncosmo requires both a zero point and a magnitude system column.
               .assign(zp=self._FLUXCAL_ZP, zpsys='ab'))
        if lcs.empty:
            raise ValueError(f"SN {self.snid}: no usable observations in the fit window")
        data = Table.from_pandas(lcs)

        model = sncosmo.Model(source='salt3', effects=[sncosmo.OD94Dust()],
                              effect_names=['mw'], effect_frames=['obs'])
        # nest_lc only varies the parameters listed below; values set here stay fixed.
        # (nest_lc has no `fixed=` argument -- extra kwargs are silently ignored.)
        model.set(z=self.z, t0=self.peak_mjd, mw_ebv=self.mw_ebv, mw_r_v=mw_rv)
        result, fitted_model = sncosmo.nest_lc(
            data, model, ['x0', 'x1', 'c'],
            bounds={'x1': (0.5, 2.2), 'c': (-0.4, 0.6), 'x0': (-0.003, 0.003)},
            minsnr=0)

        x0_fit = fitted_model.get('x0')
        x1_fit = fitted_model.get('x1')
        c_fit = fitted_model.get('c')
        # NOTE(review): the x0 bound allows x0 <= 0, which gives m_B = NaN.
        m_B, m_B_err = self._x0_to_mb(x0_fit, result.errors['x0'])

        mu_obs, mu_obs_err = self._salt2_mu(
            mb=m_B, mberr=m_B_err,
            x1=x1_fit, x1err=result.errors['x1'],
            c=c_fit, cerr=result.errors['c'],
            alpha=alpha, beta=beta, M=M0, x0=x0_fit, z=self.z)
        return result, fitted_model, mu_obs, mu_obs_err

    def bayesn_fit(self, load_model='BAYESN_1.YAML', filt_map=None):
        """Fit the cleaned light curve with a BayeSN SED model.

        Returns
        -------
        samples, sn_props
            Whatever ``SEDmodel.fit`` returns.
        """
        filt_map = self._BAYESN_FILTERS if filt_map is None else filt_map
        lcb = self._clean_lc(filt_map)

        model_b = SEDmodel(load_model=load_model)
        # BAND values are already mapped; filt_map is still forwarded, matching the
        # original call -- TODO confirm SEDmodel.fit accepts already-mapped names.
        samples, sn_props = model_b.fit(
            lcb['MJD'], lcb['m'], lcb['dm'], lcb['BAND'],
            z=self.z, peak_mjd=self.peak_mjd,
            ebv_mw=self.mw_ebv,
            filt_map=filt_map,
            mag=True)
        return samples, sn_props


In [2]:
def sn_plot(lc, head_table, colors=None, markers=None):
    """Scatter-plot a light curve's SIM_MAGOBS against MJD, coloured by band.

    Parameters
    ----------
    lc : table-like
        Light curve with MJD, SIM_MAGOBS and BAND columns.
    head_table : mapping-like
        HEAD row providing SNID (used in the title) and LABEL (printed).
    colors : dict, optional
        Band name -> matplotlib colour. Defaults to a ugriz palette.
    markers : dict, optional
        Accepted for backward compatibility; unused.
    """
    if colors is None:
        colors = {'u': '#9467bd', 'g': '#377eb8', 'r': '#4daf4a', 'i': '#e3c530', 'z': '#ff7f00'}

    def _band_name(band):
        # FITS string columns may arrive as bytes and/or space-padded.
        if isinstance(band, bytes):
            band = band.decode()
        return band.strip() if isinstance(band, str) else band

    print(head_table['SNID'], head_table['LABEL'])

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    ax.scatter(lc['MJD'], lc['SIM_MAGOBS'], c=[colors[_band_name(b)] for b in lc['BAND']], alpha=0.5)
    # bottom=37, top=18 already puts bright (small) magnitudes at the top;
    # do not also call invert_yaxis(), which would undo it.
    ax.set_ylim(37, 18)

    legend_patches = [mpatches.Patch(color=color, label=band.strip()) for band, color in colors.items()]
    ax.legend(handles=legend_patches, title="Bands", fontsize=14, title_fontsize=16, ncol=3)
    ax.set_xlabel("MJD", fontsize=18)
    ax.set_ylabel("Apparent Magnitude", fontsize=18)
    ax.set_title(f'JOLTEON ID: {head_table["SNID"]}', fontsize=24)
    fig.tight_layout()
    plt.show()

In [None]:
def sncosmo_process_data(lc, peak_mjd_s, filt_map_s=None):
    """Clean a light curve for sncosmo fitting.

    Parameters
    ----------
    lc : astropy Table or pandas DataFrame
        Light curve with MJD, BAND, FLUXCAL and FLUXCALERR columns.
    peak_mjd_s : float
        Reference epoch; rows outside [-20, +50] days of it are dropped.
    filt_map_s : dict, optional
        Raw band name -> sncosmo bandpass name (defaults to LSST ugriz).

    Returns
    -------
    pandas.DataFrame
        Rows with positive FLUXCAL/FLUXCALERR and finite magnitudes, with
        mapped BAND names and added ``m``/``dm`` magnitude columns.
    """
    if filt_map_s is None:
        filt_map_s = {'g': 'lsstg', 'r': 'lsstr', 'i': 'lssti', 'z': 'lsstz', 'u': 'lsstu'}

    def _band_name(band):
        # FITS string columns may arrive as bytes and/or space-padded.
        if isinstance(band, bytes):
            band = band.decode()
        return band.strip() if isinstance(band, str) else band

    lcs = lc.to_pandas() if hasattr(lc, 'to_pandas') else pd.DataFrame(lc).copy()
    lcs['BAND'] = [filt_map_s.get(b, b) for b in (_band_name(x) for x in lcs['BAND'])]

    # .copy() so the column assignments below act on a real frame, not a view.
    lcs = lcs[(lcs['FLUXCAL'] > 0) & (lcs['FLUXCALERR'] > 0)].copy()
    lcs['m'] = -2.5 * np.log10(lcs['FLUXCAL']) + 27.5
    lcs['dm'] = np.abs(-2.5 * lcs['FLUXCALERR'] / (np.log(10) * lcs['FLUXCAL']))
    lcs = lcs.dropna(subset=['m', 'dm', 'BAND'])

    mask_s = np.isfinite(lcs['m']) & np.isfinite(lcs['dm']) & (lcs['BAND'] != '')
    lcs = lcs[mask_s]

    relative_mjd_s = lcs['MJD'] - peak_mjd_s
    return lcs[(relative_mjd_s >= -20) & (relative_mjd_s <= 50)]

In [None]:
def sncosmo_fit(z, peak_mjd_s, mw_ebv, mw_rv=3.1, lcs=None):
    """Fit SALT3 (with fixed Milky Way dust) to one light curve by nested sampling.

    Parameters
    ----------
    z : float
        Redshift, held fixed.
    peak_mjd_s : float
        t0, held fixed.
    mw_ebv : float
        Milky Way E(B-V), held fixed.
    mw_rv : float, optional
        Milky Way R_V, held fixed.
    lcs : pandas DataFrame or astropy Table, optional
        Cleaned light curve (e.g. from ``sncosmo_process_data``) with MJD,
        sncosmo BAND names and FLUXCAL/FLUXCALERR (or FLUX/FLUXERR).
        Defaults to the notebook-global ``lcs``.

    Returns
    -------
    result, fitted_model
        As returned by ``sncosmo.nest_lc``.

    Raises
    ------
    ValueError
        If no light curve is available or it has no rows.
    """
    if lcs is None:
        lcs = globals().get('lcs')
    if lcs is None:
        raise ValueError("no light curve: pass lcs=... (e.g. the output of sncosmo_process_data)")

    # Rename the flux columns to names sncosmo recognises.
    renames = {'FLUXCAL': 'FLUX', 'FLUXCALERR': 'FLUXERR'}
    if isinstance(lcs, pd.DataFrame):
        data = Table.from_pandas(lcs.rename(columns=renames))
    else:
        data = Table(lcs, copy=True)
        for old, new in renames.items():
            if old in data.colnames:
                data.rename_column(old, new)
    if len(data) == 0:
        raise ValueError("light curve has no usable observations")

    # m = -2.5 log10(FLUXCAL) + 27.5 on the AB system, so FLUXCAL has zp = 27.5.
    # sncosmo requires both a zp and a zpsys column.
    data['zp'] = np.full(len(data), 27.5)
    data['zpsys'] = ['ab'] * len(data)

    model = sncosmo.Model(source='salt3', effects=[sncosmo.OD94Dust()],
                          effect_names=['mw'], effect_frames=['obs'])
    # nest_lc only varies the parameters listed below; values set here stay fixed.
    # (nest_lc has no `fixed=` argument -- extra kwargs are silently ignored.)
    model.set(z=z, t0=peak_mjd_s, mw_ebv=mw_ebv, mw_r_v=mw_rv)

    result, fitted_model = sncosmo.nest_lc(
        data, model, ['x0', 'x1', 'c'],
        bounds={'x1': (0.5, 2.2), 'c': (-0.4, 0.6), 'x0': (-0.003, 0.003)},
        minsnr=0)
    return result, fitted_model
    

In [None]:
def sncosmo_corner(result):
    """Corner plot of an sncosmo ``nest_lc`` posterior.

    Nested-sampling draws carry per-sample weights (``result.weights``).
    They must go through corner's ``weights`` argument; an unknown keyword
    such as ``sample_weight`` is not applied.
    """
    corner.corner(result.samples, labels=result.vparam_names, weights=result.weights)
    plt.show()

In [None]:
def sncosmo_plot_a(lcs, result, fitted_model):
    """Plot the photometry with the fitted sncosmo model overlaid.

    ``lcs`` is passed to ``sncosmo.plot_lc`` as the data; ``result.errors``
    supplies the parameter uncertainties shown alongside the fit.
    """
    param_errors = result.errors
    sncosmo.plot_lc(lcs, model=fitted_model, errors=param_errors)
    plt.show()

In [None]:
def sncosmo_dist_mod(result, fitted_model, z=None, alpha=0.14, beta=3.1, M0=-19.36):
    """Tripp distance modulus from an sncosmo SALT fit.

    Parameters
    ----------
    result : sncosmo result
        Must provide ``errors['x0']``, ``errors['x1']`` and ``errors['c']``.
    fitted_model : sncosmo Model
        Must expose ``param_names``/``parameters`` including x0, x1, c (and z
        when ``z`` is not given).
    z : float, optional
        Redshift for the error budget; defaults to the model's ``z``.
    alpha, beta : float, optional
        Stretch and colour standardisation coefficients.
    M0 : float, optional
        Absolute magnitude; the default reproduces the original hard-coded +19.36.

    Returns
    -------
    mu_obs_s, mu_obs_s_err : float
    """
    def _param(name):
        return fitted_model.parameters[fitted_model.param_names.index(name)]

    if z is None:
        z = _param('z')
    x0_fit = _param('x0')
    x1_fit = _param('x1')
    c_fit = _param('c')
    x0_fit_err = result.errors['x0']
    x1_fit_err = result.errors['x1']
    c_fit_err = result.errors['c']

    # Offsets kept from the original notebook; calibration origin not documented here.
    m_B = -2.5 * np.log10(x0_fit) + 10.635 - .27
    # |d(m_B)/d(x0)| = 2.5 / (x0 ln 10): the x0 error must be converted to magnitudes.
    m_B_err = np.abs(2.5 / (x0_fit * np.log(10.0))) * x0_fit_err

    def salt2mu(x1=None, x1err=None,
                c=None, cerr=None,
                mb=None, mberr=None,  # mberr in magnitudes
                cov_x1_c=0, cov_x1_x0=0, cov_c_x0=0,
                alpha=None, beta=None, hostmass=None,
                M=None, x0=None, sigint=None, z=None, peczerr=0.00083, deltam=None):
        sf = -2.5 / (x0 * np.log(10.0))  # d(mb)/d(x0), converts x0 covariances to mb
        cov_mb_c = cov_c_x0 * sf
        cov_mb_x1 = cov_x1_x0 * sf
        mu_out = mb + x1 * alpha - beta * c - M
        var_fit = (mberr ** 2. + alpha ** 2. * x1err ** 2. + beta ** 2. * cerr ** 2.
                   + 2.0 * alpha * cov_mb_x1 - 2.0 * beta * cov_mb_c
                   - 2.0 * alpha * beta * cov_x1_c)

        if deltam:
            # Mass step: +deltam/2 above 10, -deltam/2 below, unchanged at exactly 10.
            hostmass = np.asarray(hostmass)
            mu_out = mu_out + np.where(hostmass > 10, deltam / 2.,
                                       np.where(hostmass < 10, -deltam / 2., 0.))

        # peczerr is presumably a peculiar-velocity redshift error; 0.055*z looks
        # like a lensing dispersion term -- both kept from the original notebook.
        zerr = peczerr * 5.0 / np.log(10) * (1.0 + z) / (z * (1.0 + z / 2.0))
        muerr_out = np.sqrt(var_fit + zerr ** 2. + 0.055 ** 2. * z ** 2.)
        if sigint:
            muerr_out = np.sqrt(muerr_out ** 2. + sigint ** 2.)
        return mu_out, muerr_out

    mu_obs_s, mu_obs_s_err = salt2mu(x1=x1_fit, x1err=x1_fit_err, c=c_fit, cerr=c_fit_err,
                                     mb=m_B, mberr=m_B_err, alpha=alpha, beta=beta,
                                     M=M0, x0=x0_fit, z=z)
    return mu_obs_s, mu_obs_s_err
    

In [None]:
def bayesn_process_data(lc, peak_mjd_b, filt_map_b=None):
    """Clean a light curve for BayeSN fitting.

    Parameters
    ----------
    lc : astropy Table or pandas DataFrame
        Light curve with MJD, BAND, FLUXCAL and FLUXCALERR columns.
    peak_mjd_b : float
        Reference epoch; rows outside [-20, +50] days of it are dropped.
    filt_map_b : dict, optional
        Raw band name -> BayeSN filter name (defaults to LSST ugriz).

    Returns
    -------
    pandas.DataFrame
        Rows with positive FLUXCAL/FLUXCALERR and finite magnitudes, with
        mapped BAND names and added ``m``/``dm`` magnitude columns.
    """
    # The default map used to sit behind a trailing "#for bayesn" comment that
    # swallowed the closing "):" and made this cell a SyntaxError.
    if filt_map_b is None:
        filt_map_b = {'g': 'g_LSST', 'r': 'r_LSST', 'i': 'i_LSST', 'z': 'z_LSST', 'u': 'u_LSST'}

    def _band_name(band):
        # FITS string columns may arrive as bytes and/or space-padded.
        if isinstance(band, bytes):
            band = band.decode()
        return band.strip() if isinstance(band, str) else band

    lcb = lc.to_pandas() if hasattr(lc, 'to_pandas') else pd.DataFrame(lc).copy()
    lcb['BAND'] = [filt_map_b.get(b, b) for b in (_band_name(x) for x in lcb['BAND'])]

    # .copy() so the column assignments below act on a real frame, not a view.
    lcb = lcb[(lcb['FLUXCAL'] > 0) & (lcb['FLUXCALERR'] > 0)].copy()
    lcb['m'] = -2.5 * np.log10(lcb['FLUXCAL']) + 27.5
    lcb['dm'] = np.abs(-2.5 * lcb['FLUXCALERR'] / (np.log(10) * lcb['FLUXCAL']))
    lcb = lcb.dropna(subset=['m', 'dm', 'BAND'])

    mask_b = np.isfinite(lcb['m']) & np.isfinite(lcb['dm']) & (lcb['BAND'] != '')
    lcb = lcb[mask_b]

    relative_mjd_b = lcb['MJD'] - peak_mjd_b
    return lcb[(relative_mjd_b >= -20) & (relative_mjd_b <= 50)]

In [None]:
def bayesn_fit(lcb, z, peak_mjd_b, mw_ebv, filt_map_b, load_model='BAYESN_1.YAML'):
    """Fit a cleaned light curve (magnitudes) with a BayeSN SED model.

    Parameters
    ----------
    lcb : pandas.DataFrame
        Output of the BayeSN cleaning step: MJD, m, dm and BAND columns.
    z, peak_mjd_b, mw_ebv : float
        Redshift, reference epoch and Milky Way E(B-V) forwarded to the fit.
    filt_map_b : dict
        Band-name map forwarded to ``SEDmodel.fit``.
    load_model : str, optional
        Model passed to ``SEDmodel(load_model=...)``.

    Returns
    -------
    samples, sn_props
        Whatever ``SEDmodel.fit`` returns.
    """
    sed_model = SEDmodel(load_model=load_model)
    fit_output = sed_model.fit(
        lcb['MJD'],
        lcb['m'],
        lcb['dm'],
        lcb['BAND'],
        z=z,
        peak_mjd=peak_mjd_b,
        ebv_mw=mw_ebv,
        filt_map=filt_map_b,
        mag=True,
    )
    samples, sn_props = fit_output
    return samples, sn_props

In [None]:
def bayesn_dist_mod(samples):
    """Point estimate of the distance modulus: the mean of the ``'Ds'`` draws.

    Parameters
    ----------
    samples : mapping
        Posterior samples; only the ``'Ds'`` entry is read.
    """
    distance_draws = samples['Ds']
    return np.mean(distance_draws)

In [None]:
def calc_mu_exp(z, H0_value = 73.6, Om0_value = 0.334):
    """Expected distance modulus at redshift ``z`` in a flat LambdaCDM cosmology.

    ``H0_value`` is in km/s/Mpc and ``Om0_value`` is the matter density
    parameter. Returns the result of ``FlatLambdaCDM.distmod``.
    """
    hubble_constant = H0_value * u.km / u.s / u.Mpc
    return FlatLambdaCDM(H0=hubble_constant, Om0=Om0_value).distmod(z)