# Working with Superphot+

Superphot+ was designed to rapidly fit photometric SN-like light curves to an empirical model for subsequent classification or analysis.
This tutorial briefly covers how to import light curves directly from ALeRCE or ANTARES, apply pre-processing for improved quality, run various sampling methods to fit the light curves, and classify the resulting fits.

## Light curve import

There is a suite of helper functions in `superphot_plus.data_generation` for importing photometric light curves from the ALeRCE or ANTARES servers; the helpers for both servers are imported below:

In [None]:
import os

import numpy as np  # used explicitly by later cells (np.random, np.mean, ...)
from dustmaps.config import config

# Point dustmaps at the current working directory for its data files.
# This is kept ahead of the superphot_plus imports below (as in the original
# ordering) in case any of them touch dustmaps at import time.
config["data_dir"] = "."  # ensure dustmaps path is correct

# Star imports are kept for tutorial brevity; later cells rely on names they
# provide (e.g. `generate_single_flux_file`, `import_lc`).
from superphot_plus.constants import *  # all hyperparameters/priors for fitting
from superphot_plus.utils import *  # all utility functions
from superphot_plus.import_utils import *
from superphot_plus.data_generation.alerce import *
from superphot_plus.data_generation.antares import *

In [None]:
# ZTF object ID of the supernova used throughout this tutorial;
# any other ZTF supernova ID can be substituted here.
test_sn = "ZTF22abvdwik"

For this tutorial, we save all outputs to `../examples/outputs/` (creating the directory if needed) and fetch the test supernova's light curve into it:

In [None]:
# Every file produced by this tutorial lands in this directory.
OUTPUT_DIR = "../examples/outputs/"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no error if it already exists

# Fetch the test object's photometry into OUTPUT_DIR (read back below as a CSV).
generate_single_flux_file(test_sn, OUTPUT_DIR)

Great! Now let's load and plot the light curve:

In [None]:
import pandas as pd

# Raw photometry CSV fetched in the previous cell.
lc_fn = os.path.join(OUTPUT_DIR, f"{test_sn}.csv")
df = pd.read_csv(lc_fn)
df  # display the table

In [None]:
import matplotlib.pyplot as plt

m = df["magpsf"]  # apparent magnitudes
merr = df["sigmapsf"]  # magnitude uncertainties
t = df["mjd"]  # observation times (MJD)
b = df["fid"] - 1  # shift filter IDs so that 0 = g and 1 = r

# One errorbar series per band: (band index, colour/label, marker style).
for band_idx, band_name, marker in ((0, "g", "o"), (1, "r", "^")):
    sel = b == band_idx
    plt.errorbar(t[sel], m[sel], yerr=merr[sel], fmt=marker, c=band_name, label=band_name)

plt.legend()
plt.xlabel("MJD")
plt.ylabel("Apparent magnitude")
plt.gca().invert_yaxis()  # smaller magnitude = brighter, so put it at the top

Because our fitting procedure works in flux rather than magnitude, `import_lc` converts magnitudes to flux using an average zeropoint of 26.3, i.e. $f = 10^{-0.4\,(m - 26.3)}$. It also removes NaN values, sorts the light curve in time, clips spurious light-curve tails, and applies an extinction correction:

In [None]:
# Pre-process the raw CSV into flux-space arrays plus sky coordinates.
t, f, ferr, b, ra, dec = import_lc(lc_fn)

plt.close()  # close any open figure before drawing the flux plot
for band, marker in (("g", "o"), ("r", "^")):
    in_band = b == band
    plt.errorbar(t[in_band], f[in_band], yerr=ferr[in_band], fmt=marker, c=band, label=band)

plt.legend()
plt.xlabel("MJD")
plt.ylabel("Flux (in arbitrary units)")

Next, we save the pre-processed light curve to a separate file, which serves as input to the fitting routines:

In [None]:
from superphot_plus.lightcurve import Lightcurve

# Bundle the pre-processed arrays into a Lightcurve object ...
lc = Lightcurve(times=t, fluxes=f, flux_errors=ferr, bands=b, name=test_sn)

# ... and write it to OUTPUT_DIR/<test_sn>.npz, replacing any earlier copy.
npz_path = os.path.join(OUTPUT_DIR, f"{test_sn}.npz")
lc.save_to_file(npz_path, overwrite=True)

## Fitting Light Curves

There are a few sampling techniques implemented for rapid fitting of light curves:
* Nested sampling (`dynesty`) explores the posterior by enclosing its live points in ellipsoids that shrink onto regions of progressively higher likelihood.
* Hamiltonian Monte Carlo with the No-U-Turn Sampler (NUTS, via `numpyro`) improves sampling efficiency by automatically ending each trajectory once it starts to double back on itself (a "U-turn"), so the trajectory length does not need hand-tuning.
* Stochastic variational inference (SVI; also via `numpyro`) approximates the marginal posterior of each parameter as a Gaussian, sacrificing precision for a much faster runtime. Recommended for real-time applications.

Let's use each to fit our test light curve:

In [None]:
from superphot_plus.lightcurve import Lightcurve
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.samplers.numpyro_sampler import NumpyroSampler
from superphot_plus.surveys.surveys import Survey

# Reload the light curve saved above, then grab the ZTF fitting priors.
fn_to_fit = os.path.join(OUTPUT_DIR, f"{test_sn}.npz")
lightcurve = Lightcurve.from_file(fn_to_fit)
priors = Survey.ZTF().priors

In [None]:
%%time

sampler = DynestySampler()
posteriors = sampler.run_single_curve(lightcurve, priors=priors, rstate=np.random.default_rng(9876))
posteriors.save_to_file(OUTPUT_DIR)
print("Nested sampling")

In [None]:
%%time

sampler = NumpyroSampler()
posteriors = sampler.run_single_curve(lightcurve, priors=priors, rng_seed=1, sampler="NUTS")
posteriors.save_to_file(OUTPUT_DIR)
print("NUTS")

In [None]:
%%time

sampler = NumpyroSampler()
posteriors = sampler.run_single_curve(lightcurve, priors=priors, rng_seed=1, sampler="svi")
posteriors.save_to_file(OUTPUT_DIR)
print("SVI")

Now, let's plot each fit to compare results!

In [None]:
from superphot_plus.plotting.lightcurves import plot_lc_fit
from IPython import display
from superphot_plus.surveys.surveys import Survey

priors = Survey.ZTF().priors

# Render one PNG per sampler (<test_sn>_<method>.png); all three directory
# arguments point at OUTPUT_DIR.
for method in ("dynesty", "NUTS", "svi"):
    plot_lc_fit(
        test_sn,
        priors.reference_band,
        priors.ordered_bands,
        OUTPUT_DIR,
        OUTPUT_DIR,
        OUTPUT_DIR,
        sampling_method=method,
        file_type="png",
    )

display.Image(os.path.join(OUTPUT_DIR, f"{test_sn}_dynesty.png"))

In [None]:
display.Image(os.path.join(OUTPUT_DIR, f"{test_sn}_NUTS.png"))  # NUTS fit figure

In [None]:
display.Image(os.path.join(OUTPUT_DIR, f"{test_sn}_svi.png"))  # SVI fit figure

It looks like there is a tradeoff between fit time and fit quality, though there may also be an issue with the priors. Plotting the posterior distributions of the parameters that differ most between the fits ($t_0$ and $\gamma$), we get:

In [None]:
from superphot_plus.file_utils import get_posterior_samples

# Posterior sample arrays for each sampler (first element of the returned tuple),
# loaded in the order dynesty, NUTS, SVI.
params_dynesty, params_NUTS, params_svi = (
    get_posterior_samples(test_sn, fits_dir=OUTPUT_DIR, sampler=sampler_name)[0]
    for sampler_name in ("dynesty", "NUTS", "svi")
)
print(params_dynesty[0])  # peek at the first dynesty sample

# Column indices into the sample arrays for the two parameters compared here.
t0_idx = 3
gamma_idx = 2

for samples, label in (
    (params_dynesty, "dynesty"),
    (params_NUTS, "NUTS"),
    (params_svi, "SVI"),
):
    plt.hist(samples[:, t0_idx], alpha=0.5, label=label, density=True)
plt.xlabel("t0")
plt.legend()
plt.show()

In [None]:
from superphot_plus.surveys.surveys import Survey

ztf_priors = Survey.ZTF().priors
# Use the gamma prior of the survey's reference band (the same `reference_band`
# passed to plot_lc_fit) rather than hard-coding "r".
# NOTE(review): assumes column `gamma_idx` holds the reference-band gamma — confirm.
ref_band_priors = ztf_priors.bands[ztf_priors.reference_band]
PRIOR_GAMMA = ref_band_priors.gamma

for samples, label in (
    (params_dynesty, "dynesty"),
    (params_NUTS, "NUTS"),
    (params_svi, "SVI"),
):
    plt.hist(samples[:, gamma_idx], alpha=0.5, label=label, density=True)

# Prior mean (solid) and mean +/- std (dashed).
plt.axvline(PRIOR_GAMMA.mean, c="r", label="Prior")
plt.axvline(PRIOR_GAMMA.mean + PRIOR_GAMMA.std, c="r", linestyle="dashed")
plt.axvline(PRIOR_GAMMA.mean - PRIOR_GAMMA.std, c="r", linestyle="dashed")
plt.xlabel("log gamma")
plt.xlim((0.5, 2))
plt.legend()
plt.show()

## Classification

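Superphot+ uses the resulting fit parameters as input features for a multi-layer perceptron (MLP) classifier. We can call the classification functions to obtain the probability that the object belongs to each of five major supernova types. This step needs a pre-trained model and its config (`model.pt` and `model.yaml` in the output directory), which this tutorial does not create: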

In [None]:
import numpy as np

from superphot_plus.utils import adjust_log_dists
from superphot_plus.trainer import SuperphotTrainer
from superphot_plus.file_utils import get_posterior_samples

# A pre-trained classifier (weights + config) must already be in OUTPUT_DIR;
# this tutorial does not train one, so fail early with a clear message.
TRAINED_MODEL_FN = os.path.join(OUTPUT_DIR, "model.pt")
TRAINED_CONFIG_FN = os.path.join(OUTPUT_DIR, "model.yaml")
for required_fn in (TRAINED_MODEL_FN, TRAINED_CONFIG_FN):
    if not os.path.exists(required_fn):
        raise FileNotFoundError(
            f"Missing {required_fn}: place a trained Superphot+ model and its "
            f"config in {OUTPUT_DIR} before running this cell."
        )

trainer = SuperphotTrainer(
    TRAINED_CONFIG_FN,
    OUTPUT_DIR,
    sampler="dynesty",
    model_type="MLP",
    probs_file=None,
    n_folds=1,
)
trainer.setup_model(load_checkpoint=True)

# Route 1: classify from the saved dynesty fit in OUTPUT_DIR.
lc_probs = trainer.classify_single_light_curve(
    test_sn, OUTPUT_DIR, sampler="dynesty"
)
print(lc_probs)

# Route 2: classify directly from the posterior samples' fit parameters.
fit_params = get_posterior_samples(test_sn, fits_dir=OUTPUT_DIR, sampler="dynesty")[0]
adj_params = adjust_log_dists(fit_params)
lc_probs2 = trainer.models[0].classify_from_fit_params(adj_params)
# Sanity check: per-class difference between route 1 and route 2 averaged over
# axis 0 (presumably one row per posterior sample); near zero means agreement.
print(np.subtract(lc_probs, np.mean(lc_probs2, axis=0)))

## Improvements that need to be made

* Explore why the dynesty and numpyro fits differ
* Quantify the minimum number of SVI iterations, or NUTS warmup samples, needed for asymptotic fitting behavior
* Modularize the numpyro script and remove magic numbers
* Refine the plotting file, possibly splitting it into a separate folder