# Working with Superphot+

Superphot+ was designed to rapidly fit photometric SN-like light curves to an empirical model for subsequent classification or analysis.
This tutorial briefly covers how to import light curves directly from ALeRCE or ANTARES, apply pre-processing for improved quality, and run various sampling methods to fit the light curves.

## Data Importing and Preprocessing

Superphot+ is built on SNAPI (https://github.com/kdesoto-astro/snapi), which provides functionality for easy importing and pre-processing of photometric data from various alert brokers. Here, we use SNAPI to import a ZTF light curve from ALeRCE.

In [None]:
# All imports/filepaths for this section
!pip list
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from snapi.query_agents import ALeRCEQueryAgent, TNSQueryAgent
from snapi import Formatter, Photometry, Transient

# Name of the supernova used throughout this section.
test_sn = "2023lkw"

# Tutorial data directory: <two levels above the working directory>/data/tutorial.
p = Path.cwd().parents[1]
SAVE_DIR = os.path.join(str(p), "data", "tutorial")
print(SAVE_DIR)


In [None]:
# One query agent per source; the TNS agent is pointed at SAVE_DIR via db_path.
alerce_agent = ALeRCEQueryAgent()
tns_agent = TNSQueryAgent(db_path=SAVE_DIR)

# Start from a transient that only knows its name.
transient = Transient(iid=test_sn)

# TNS is queried first, and its results are ingested before the ALeRCE query.
# local=True (original note: we don't want spectra).
qr_tns, success = tns_agent.query_transient(transient, local=True)
for tns_result in qr_tns:
    tns_info = tns_result.to_dict()
    transient.ingest_query_info(tns_info)
print(transient.internal_names)

# Query ALeRCE with the updated transient and ingest its results as well.
qr_alerce, success = alerce_agent.query_transient(transient)
for alerce_result in qr_alerce:
    alerce_info = alerce_result.to_dict()
    transient.ingest_query_info(alerce_info)

In [None]:
# Plot the imported light curve held by the transient.
photometry = transient.photometry
# The Formatter is created before the figure and then reused for the legend and
# final styling below.
formatter = Formatter()
fig, ax = plt.subplots()

# mags=False: presumably plots fluxes rather than magnitudes - confirm in SNAPI docs.
photometry.plot(ax, mags=False)
formatter.add_legend(ax)
formatter.make_plot_pretty(ax)
plt.show()

Here, the non-detections are marked as semi-transparent upper limits, with the detections shown opaque with uncertainty margins. Superphot+ currently only works with detections.

Let's phase/normalize the light curve and correct for extinction. Note that, for the Superphot+ samplers, the photometry MUST be phased and normalized to be fit correctly.

In [None]:
# Pre-process the photometry in place: phase, then normalize, then correct for
# extinction using the transient's coordinates. The Superphot+ samplers require
# phased and normalized photometry.
photometry.phase(inplace=True)
photometry.normalize(inplace=True)
photometry.correct_extinction(coordinates=transient.coordinates, inplace=True)
# Store the processed photometry back on the transient so it is included when saved.
transient.photometry = photometry

# Re-plot the processed light curve, reusing the `formatter` created earlier.
fig, ax = plt.subplots()

photometry.plot(ax, mags=False)
formatter.add_legend(ax)
formatter.make_plot_pretty(ax)
plt.show()


Now let's save this file for later use:

In [None]:
# Write the pre-processed transient to <SAVE_DIR>/<test_sn> for the later sections.
transient.save(os.path.join(SAVE_DIR, test_sn))

For our numpyro samplers, we need to pad all bands to have the same number of points. To do this, we create a padded variant of our transient photometry:

In [None]:
# Reload the pre-processed transient saved in the previous step.
transient = Transient.load(
    os.path.join(SAVE_DIR, test_sn),
)

# Target length per band: the smallest power of two that is >= the number of
# detections in the photometry (presumably counted across all bands - confirm).
n_pad = int(2**np.ceil(np.log2(len(transient.photometry.detections))))

# Values written into every padded row. NOTE(review): the late phase and very large
# flux uncertainty presumably keep padded points from influencing the fit - confirm
# against the sampler implementation.
fill = {'phase': 1000., 'flux': 0.1, 'flux_unc': 1000., 'zpt': 23.90, 'non_detections': False}

# Pad each band by the difference between the target length and its own
# detection count.
padded_lcs = [
    lc.pad(fill, n_pad - len(lc.detections))
    for lc in transient.photometry.light_curves
]

# Rebuild the photometry from the padded bands and save it next to the unpadded
# file, under the "_padded" suffix.
padded_photometry = Photometry.from_light_curves(padded_lcs)
transient.photometry = padded_photometry
transient.save(
    os.path.join(SAVE_DIR, test_sn+"_padded")
)

# Plot the padded light curve for a visual check.
formatter = Formatter()
fig, ax = plt.subplots()
padded_photometry.plot(ax, mags=False)
formatter.add_legend(ax)
formatter.make_plot_pretty(ax)
plt.show()


## Fitting Light Curves

In [None]:
# All imports for this section
import os
from pathlib import Path

import matplotlib.pyplot as plt
from snapi import Transient, Formatter, SamplerResult

from superphot_plus.samplers import DynestySampler, NUTSSampler, SVISampler
from superphot_plus.priors import generate_priors, SuperphotPrior

# Tutorial data directory: <two levels above the working directory>/data/tutorial.
p = Path.cwd().parents[1]
SAVE_DIR = os.path.join(str(p), "data", "tutorial")

# Name of the supernova fitted in this section.
test_sn = "2023lkw"

There are a few sampling techniques implemented for rapid fitting of light curves:
* Nested sampling (`dynesty`) constrains the posterior space with nested ellipsoids of increasing density.
* Advanced HMC with the NUTS sampler (using `numpyro`) uses Hamiltonian Monte Carlo with the No-U-Turn criterion, which ends each trajectory before it doubles back on itself, to increase sampling efficiency.
* Stochastic variational inference (SVI; also using `numpyro`) approximates the marginal distributions for each fit as Gaussians, which sacrifices precision for much faster runtime. Recommended for realtime applications.

Let's use each to fit our test light curve:

In [None]:
# Load the transient saved under <SAVE_DIR>/<test_sn> and take its photometry.
fn_to_fit = os.path.join(SAVE_DIR, test_sn)
transient = Transient.load(fn_to_fit)
photometry = transient.photometry
# Generate priors for the two ZTF bands; `priors` is shared by the sampler cells below.
priors = generate_priors(["ZTF_r","ZTF_g"])

In [None]:
%%time

# Nested-sampling fit, run on the `photometry` loaded above.
sampler = DynestySampler(priors=priors, random_state=42, verbose=True, nlive=100)
sampler.fit_photometry(photometry)

# Save the result as <SAVE_DIR>/<test_sn>_<sampler.name> so later cells can reload it.
sampler.result.save(os.path.join(SAVE_DIR, f"{test_sn}_{sampler.name}"))

print("Nested sampling")
print(sampler.result.fit_parameters.head())

In [None]:
%%time

# The NUTS fit runs on the padded photometry saved under the "_padded" suffix.
padded_fn = os.path.join(SAVE_DIR, f"{test_sn}_padded")
pad_transient = Transient.load(padded_fn)
pad_photometry = pad_transient.photometry

# Single chain: 5000 warm-up steps followed by 1000 samples.
sampler = NUTSSampler(
    priors=priors, num_chains=1, num_warmup=5000, num_samples=1000, random_state=42
)
sampler.fit_photometry(pad_photometry)

# Save the result as <SAVE_DIR>/<test_sn>_<sampler.name> so later cells can reload it.
sampler.result.save(os.path.join(SAVE_DIR, f"{test_sn}_{sampler.name}"))

print("NUTS")
print(sampler.result.fit_parameters.head())

In [None]:
%%time

# Padding is only needed when the sampler will be run repeatedly on different events;
# for a one-time fit, feel free to use the unpadded photometry instead.
padded_fn = os.path.join(SAVE_DIR, f"{test_sn}_padded")
pad_transient = Transient.load(padded_fn)
pad_photometry = pad_transient.photometry

# SVI fit with the default priors and 10,000 iterations.
sampler = SVISampler(priors=priors, num_iter=10_000, random_state=42)
sampler.fit_photometry(pad_photometry)

# Save the result as <SAVE_DIR>/<test_sn>_<sampler.name> so later cells can reload it.
sampler.result.save(os.path.join(SAVE_DIR, f"{test_sn}_{sampler.name}"))

print("SVI")
print(sampler.result.fit_parameters.head())

In [None]:
# NEW!!! Hierarchical SVI - uses better priors.
# The loaded priors get their own name so the default `priors` used by the other
# fitting cells is not silently replaced when those cells are re-run afterwards.
hier_priors = SuperphotPrior.load(os.path.join(SAVE_DIR, "global_priors_hier_svi"))

# Same padded photometry as the other numpyro-based fits.
padded_fn = os.path.join(SAVE_DIR, test_sn + "_padded")
pad_transient = Transient.load(padded_fn)
pad_photometry = pad_transient.photometry

# Same SVI settings as the default-prior fit; only the priors differ.
sampler = SVISampler(
    priors=hier_priors,
    num_iter=10_000,
    random_state=42,
)
sampler.fit_photometry(pad_photometry)

# The "_hierarchical" suffix keeps this result separate from the default-prior SVI fit.
sampler.result.save(
    os.path.join(SAVE_DIR, test_sn+f"_{sampler.name}_hierarchical")
)
# Distinct label so this output is not confused with the default-prior SVI cell.
print("Hierarchical SVI")
print(sampler.result.fit_parameters.head())

Now, let's plot each fit to compare results!

In [None]:
# Reload the unpadded photometry and regenerate the default priors so this cell
# does not depend on state left behind by the fitting cells.
fn_to_fit = os.path.join(SAVE_DIR, test_sn)
transient = Transient.load(fn_to_fit)
photometry = transient.photometry
priors = generate_priors(["ZTF_r", "ZTF_g"])

fig, ax = plt.subplots()
formatter = Formatter()

# The samplers are only used to load saved results and draw them; nothing is refit.
dsampler = DynestySampler(priors=priors)
nsampler = NUTSSampler(priors=priors)
ssampler = SVISampler(priors=priors)

# (sampler, filename suffix) pairs in plotting order. The SVI sampler is used twice:
# once for the default-prior result and once for the hierarchical result.
fits_to_plot = [
    (dsampler, ""),
    (nsampler, ""),
    (ssampler, ""),
    (ssampler, "_hierarchical"),
]
for sampler, suffix in fits_to_plot:
    sr = SamplerResult.load(os.path.join(SAVE_DIR, f"{test_sn}_{sampler.name}{suffix}"))
    sampler.load_result(sr)
    ax = sampler.plot_fit(ax, formatter, photometry)

# Draw the observed photometry last, then fix the y-range and finish the styling.
photometry.plot(ax, formatter, mags=False)
ax.set_ylim((-0.1, 1.3))
formatter.add_legend(ax)
formatter.make_plot_pretty(ax)
plt.show()


All of the fits look decent, differing most in the r-band fit. Note how the hierarchical SVI fit does better than the one from default priors for the same number of training iterations. Let's examine the distribution of the r-band plateau duration (`gamma_ZTF_r`) for each sampler:

In [None]:
# Saved results to compare; each name is the suffix of a file written by a fitting cell.
result_names = (
    "superphot_dynesty",
    "superphot_svi",
    "superphot_nuts",
    "superphot_svi_hierarchical",
)
for sampler_name in result_names:
    sr = SamplerResult.load(os.path.join(SAVE_DIR, f"{test_sn}_{sampler_name}"))
    # param naming convention: {paramname}_{filter}
    plt.hist(
        sr.fit_parameters["gamma_ZTF_r"],
        alpha=0.5,
        label=sampler_name,
        density=True,
    )

plt.xlabel("gamma_ZTF_r")
plt.legend()
plt.show()

## Classification

Superphot+ uses the resulting fit parameters as input features for a LightGBM classifier. We can call the model's evaluate() function to return probabilities of the object being each of 5 major supernova types. But first, we convert our auxiliary-band and log-Gaussian parameters back to Gaussian relative values, for better normalization before classification.

In [None]:
import os
import numpy as np
from superphot_plus.priors import generate_priors
from superphot_plus.supernova_class import SupernovaClass
from superphot_plus.model import SuperphotLightGBM
from snapi import SamplerResult
from pathlib import Path

# Tutorial data directory: <two levels above the working directory>/data/tutorial.
p = Path.cwd().parents[1]
SAVE_DIR = os.path.join(str(p), "data", "tutorial")
# Name of the supernova classified in this section.
test_sn = "2023lkw"


In [None]:

# Load the trained classifier from the tutorial data directory.
model_fn = os.path.join(SAVE_DIR, "model_superphot_full.pt")
# if from the full train workflow notebook, use below:
#model_fn = os.path.join(SAVE_DIR, "models/model_superphot_svi_LightGBM_None_False_25_None_10_100_42_0.pt")
full_model = SuperphotLightGBM.load(model_fn)

# Default priors and the saved NUTS fit for the test supernova.
priors = generate_priors(["ZTF_r", "ZTF_g"])
sr = SamplerResult.load(os.path.join(SAVE_DIR, test_sn + "_superphot_nuts"))

# convert fit parameters back to uncorrelated Gaussian draws
uncorr_fits = priors.reverse_transform(sr.fit_parameters)

# Every row gets the supernova name as its index, which the groupby() operations
# inside model.evaluate() rely on.
uncorr_fits.index = [test_sn] * len(uncorr_fits)

# exclude A_ZTF_r and t_0_ZTF_r - that's how the model was initially trained!
uncorr_fits.drop(columns=['A_ZTF_r', 't_0_ZTF_r'], inplace=True)
print(uncorr_fits.median(axis=0))

# use full-phase classifier
probs_avg = full_model.evaluate(uncorr_fits)
# Label the probability columns with the class names in sorted order.
probs_avg.columns = np.sort(['SN Ia', 'SN Ibc', 'SN II', 'SN IIn', 'SLSN-I'])
# if the model was trained with the SN Ibn class, use below:
#probs_avg.columns = np.sort(['SN Ia', 'SN Ibc', 'SN II', 'SN IIn', 'SLSN-I', 'SN Ibn'])
# Predicted class = column with the highest probability in the first row.
pred_class = probs_avg.idxmax(axis=1).iloc[0]
print(probs_avg, pred_class)

Looking at the TNS page, we see that 2023lkw is indeed a SN II. For an example of how to use the built-in SuperphotTrainer class, see the full train workflow notebook. To include more SN classes, please train a new model.