## Import modules

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sncosmo
from bayesn import SEDmodel
import matplotlib.patches as mpatches
import os
import pandas as pd
import pickle

## New code to add 

In [None]:
def get_flux_and_chi_squared_from_chains(
    self, t, bands, chains, zs, ebv_mws, mag=True, num_samples=None, mean=False, observations_path=None
):
    """
    Returns model photometry for posterior samples from BayeSN fits, which can be used to make light curve fit
    plots, together with a per-(band, phase) chi squared against observed photometry.

    NOTE(review): this function uses ``jnp`` and ``tqdm``, which are not imported in the visible import cell;
    they must be in scope wherever this function is installed (presumably the BayeSN model module -- confirm).

    Parameters
    ----------
    t: array-like
        Array of phases to evaluate model photometry at
    bands: array-like
        Bandpasses to evaluate model photometry in. Either a flat sequence of band names used for every SN, or a
        list of lists giving the bands for each SN separately.
    chains: dict or str
        Posterior samples from a BayeSN fit, or the path to a pickle file containing them. The keys "theta",
        "AV", "tmax", "mu", "eps" and "delM" are read, plus "RV" if present.
    zs: array-like or float
        Array of heliocentric redshifts corresponding to the SNe you are obtaining model fit light curves for.
    ebv_mws: array-like or float
        Array containing Milky Way extinction values corresponding to the SNe you are obtaining model fit light
        curves for.
    mag: Bool, optional
        Boolean to specify whether you want magnitude or flux data. If True, magnitudes will be returned. If False,
        flux densities (f_lambda) will be returned. Default to True i.e. mag data.
        NOTE(review): the chi squared compares the model values against the observed FLUXCAL column, so it is
        presumably only meaningful with mag=False -- confirm.
    num_samples: int, optional
        An optional keyword argument to specify the number of posterior samples you wish to obtain photometry for.
        Might be useful in testing if you are looking at lots of SNe, as otherwise this function will take a while
        to generate e.g. photometry for 1000 posterior samples across 1000 SNe. Default to None, meaning that
        photometry will be calculated for all posterior samples in chains provided. Values larger than the number
        of available samples are clamped to that number.
    mean: Bool, optional
        If True, the selected posterior samples are averaged parameter by parameter and a single light curve is
        simulated from those mean parameters. Default to False.
    observations_path: str, optional
        Path to an SNANA-format file containing the observational data, used to compare to fits. If None, no
        comparison is made and chi_squared is returned filled with NaN.

    Returns
    -------

    flux_grid: jax.numpy.array
        Array of shape (number of SNe, number of posterior samples, number of bands, number of phases to evaluate),
        containing photometry across all SNe, all posterior samples, all bands and at all phases requested. The
        posterior sample axis has length 1 when mean=True.
    chi_squared: numpy.array
        Array of shape (number of bands, number of phases to evaluate), containing the chi squared statistic for
        each mean flux we have an observation for. Entries without an observation are NaN. The array is entirely
        NaN when observations_path is None or when any observed phase is missing from t.

    """
    if isinstance(chains, (str, os.PathLike)):
        # NOTE: pickle.load can execute arbitrary code; only open chain files from a trusted source.
        with open(chains, "rb") as file:
            chains = pickle.load(file)

    theta_shape = chains["theta"].shape
    N_sne = theta_shape[2]
    total_draws = theta_shape[0] * theta_shape[1]
    # Number of posterior draws used, kept separate from the number of light curves produced so that
    # mean=True really averages over the draws instead of over a single truncated sample.
    n_draws = total_draws if num_samples is None else min(num_samples, total_draws)
    n_out = 1 if mean else n_draws

    # Accept scalars (float, int, 0-d arrays) as well as sequences for single-SN use.
    zs = np.atleast_1d(zs)
    ebv_mws = np.atleast_1d(ebv_mws)

    band_list = isinstance(bands[0], list)
    max_bands = max(len(b) for b in bands) if band_list else len(bands)

    flux_grid = jnp.zeros((N_sne, n_out, max_bands, len(t)))
    has_rv = "RV" in chains
    n_l_knots = self.l_knots.shape[0]
    n_tau_knots = self.tau_knots.shape[0]

    # self.band_weights is temporarily narrowed to one SN per iteration; remember the full array so it can be
    # restored afterwards, otherwise the model is left in a corrupted state for later calls.
    band_weights = self.band_weights

    print("Getting best fit light curves from chains...")
    try:
        for i in tqdm(np.arange(N_sne)):
            fit_bands = bands[i] if band_list else bands

            theta = chains["theta"][..., i].flatten(order="F")[:n_draws]
            AV = chains["AV"][..., i].flatten(order="F")[:n_draws]
            tmax = chains["tmax"][..., i].flatten(order="F")[:n_draws]
            mu = chains["mu"][..., i].flatten(order="F")[:n_draws]
            del_M = chains["delM"][..., i].flatten(order="F")[:n_draws]
            RV = chains["RV"][..., i].flatten(order="F")[:n_draws] if has_rv else None

            # eps is stored flattened without the first/last wavelength knots; rebuild the full grid with
            # zeros at those two knots.
            eps = chains["eps"][..., i]
            eps = eps.reshape((eps.shape[0] * eps.shape[1], eps.shape[2]), order="F")
            eps = eps.reshape((eps.shape[0], n_l_knots - 2, n_tau_knots), order="F")
            eps = eps[:n_draws, ...]
            eps_full = jnp.zeros((eps.shape[0], n_l_knots, n_tau_knots))
            eps = eps_full.at[:, 1:-1, :].set(eps)

            if mean:
                theta = theta.mean()[None]
                AV = AV.mean()[None]
                mu = mu.mean()[None]
                eps = eps.mean(axis=0)[None]
                del_M = del_M.mean()[None]
                tmax = tmax.mean()[None]
                if RV is not None:
                    RV = RV.mean()[None]

            if band_weights is not None:
                self.band_weights = band_weights[i : i + 1, ...]

            lc, lc_err, params = self.simulate_light_curve(
                t,
                theta.shape[0],
                fit_bands,
                theta=theta,
                AV=AV,
                mu=mu,
                tmax=tmax,
                del_M=del_M,
                eps=eps,
                RV=RV,
                z=zs[i],
                write_to_files=False,
                ebv_mw=ebv_mws[i],
                yerr=0,
                mag=mag,
            )
            lc = lc.T.reshape(n_out, len(fit_bands), len(t))
            flux_grid = flux_grid.at[i, :, : len(fit_bands), :].set(lc)
    finally:
        self.band_weights = band_weights

    # Mean model value per (band, phase), averaged over SNe and posterior samples.
    flux = np.asarray(flux_grid.mean(axis=(0, 1)))
    # NaN marks "no observation here"; np.empty would leave arbitrary values in those cells.
    chi_squared = np.full(flux.shape, np.nan)
    if observations_path is None:
        return flux_grid, chi_squared

    # read_snana_ascii returns the header metadata and a dict of tables keyed by table name.
    meta, lcdata = sncosmo.read_snana_ascii(observations_path, default_tablename="OBS")
    observations = lcdata["OBS"].to_pandas()
    # NOTE(review): phases are taken as MJD - SEARCH_PEAKMJD with no (1 + z) correction -- confirm this
    # matches the convention used for t.
    obs_phases = observations["MJD"].to_numpy() - meta["SEARCH_PEAKMJD"]

    t_grid = np.asarray(t)
    # The averaged flux has a single band axis, so per-SN band lists are resolved with the first SN's list.
    # NOTE(review): the chi squared is presumably only well defined for a single SN -- confirm.
    band_names = bands[0] if band_list else bands
    band_index = {name: j for j, name in enumerate(band_names)}

    # Only compare when every observed phase lies exactly on the requested phase grid.
    if np.isin(obs_phases, t_grid).all():
        print("Getting chi squared...")
        for phase, band, observed_flux, observed_flux_error in zip(
            obs_phases, observations["FLT"], observations["FLUXCAL"], observations["FLUXCALERR"]
        ):
            index_band = band_index.get(band)
            if index_band is None:
                # Observation in a band that was not requested; nothing to compare against.
                continue
            index_t = np.flatnonzero(t_grid == phase)
            flux_from_chain = flux[index_band, index_t]
            chi_squared[index_band, index_t] = (
                (observed_flux - flux_from_chain) ** 2
            ) / (observed_flux_error ** 2)

    return flux_grid, chi_squared