# Full Train Workflow

Run this workflow when you need to train a new classifier from scratch. It regenerates the transient data, refits all samples, and retrains the classifier.

## Step 0: Update configuration file.

In the same folder as this notebook, there is a $\texttt{config.yaml}$ file, which contains all filepaths and configuration options for the training workflow. Please update this now!

The most important filepath arguments are:
* $\texttt{create-dirs}$: Probably keep set to True. Create any data subdirectories that are missing.
* $\texttt{data-dir}$: This is where all generated data is stored. Set to the root directory for all outputs.
* $\texttt{relative-dirs}$: If true, all data for each step is stored within subdirectories of data_dir.
* $\texttt{transient-data-fn}$: This is where all transient data is stored as a TransientGroup. Technically a directory but loaded as a single file. If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{sampler-results-fn}$: Where light curve fits are stored. If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{figs-dir}$: Where all figures are stored (only generated if $\texttt{plot}$ is set to True). If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{models-dir}$: Where all classification models are stored. If relative_dirs is True, is created as a subdirectory of data_dir.
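
As a rough illustration of how $\texttt{relative-dirs}$ interacts with the other path options, the sketch below joins each path onto the data directory when the flag is set. The `resolve_paths` helper and the example values are hypothetical and are not part of superphot_plus; `SuperphotConfig` may resolve paths differently.

```python
import os

# Hypothetical helper, not part of superphot_plus: shows one way relative
# paths could be resolved against data_dir when relative_dirs is True.
PATH_KEYS = ("transient_data_fn", "sampler_results_fn", "figs_dir", "models_dir")


def resolve_paths(options):
    """Return a copy of `options` with path entries joined onto data_dir."""
    resolved = dict(options)
    if options.get("relative_dirs", False):
        for key in PATH_KEYS:
            if key in options:
                resolved[key] = os.path.join(options["data_dir"], options[key])
    return resolved


# Example values are made up for illustration.
example = {
    "data_dir": "data/tutorial",
    "relative_dirs": True,
    "transient_data_fn": "transients",
    "figs_dir": "figs",
}
resolved = resolve_paths(example)
print(resolved["transient_data_fn"])
print(resolved["figs_dir"])
```

With the flag set to False, the same call would return the paths unchanged, so they would need to be given in full.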

The most important sampling and classifier arguments are:
* $\texttt{sampler}$: Set to either superphot_dynesty or superphot_svi (all lowercase). SVI is faster but forces the posterior into a multivariate Gaussian.
* $\texttt{model-type}$: Set to either LightGBM (recommended) or MLP.
* $\texttt{use-redshift-features}$: If True, includes peak absolute magnitude and redshift as training features.
* $\texttt{fits-per-majority}$: Oversamples such that the majority class has this many samples fed into the classifier. Minority classes will correspond to more input samples per event. Defaults to 5.
* $\texttt{target-label}$: For binary classification - this is the positive label. Set to None for multiclass classification.
* $\texttt{n-folds}$: Number of K-folds. I usually set to 10.
* $\texttt{num-epochs}$: Number of estimators for LightGBM or number of training epochs for MLP.
* $\texttt{n-parallel}$: Number of threads to parallelize data import + sampling over.
* $\texttt{random-seed}$: For reproducibility.
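
To make the effect of $\texttt{fits-per-majority}$ more concrete, here is a minimal sketch of one possible balancing rule: the majority class contributes a fixed number of posterior draws per event, and every other class is given enough draws per event to reach the same total. The `draws_per_event` function and the class counts are assumptions made for illustration; this is not a description of what `SuperphotTrainer` does internally.

```python
import math
from collections import Counter


def draws_per_event(labels, fits_per_majority=5):
    """Hypothetical sketch: posterior draws to use per event, by class.

    The majority class contributes fits_per_majority draws per event; every
    other class gets enough draws per event to reach the same total.
    """
    counts = Counter(labels)
    target_total = max(counts.values()) * fits_per_majority
    return {cls: math.ceil(target_total / n) for cls, n in counts.items()}


# Made-up class counts for illustration only.
labels = ["SN Ia"] * 100 + ["SN II"] * 40 + ["SLSN-I"] * 8
per_event = draws_per_event(labels, fits_per_majority=5)
print(per_event)
```

Under this rule the rarest class is represented by many more draws per event than the majority class, which is the behavior the option description refers to.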

## Step 1: Generate new TransientGroup

Here we will import data from TNS + ALeRCE and generate a new TransientGroup from a list of event names. Names can be from TNS or ZTF.

The below code block will retrieve all spectroscopically classified TNS transients. Feel free to change to your own list of names or import script.

In [None]:
from pathlib import Path
import os
from snapi.query_agents import TNSQueryAgent

p = Path(os.getcwd()).parents[1]
SAVE_DIR = os.path.join(p, "data", "tutorial")
print(SAVE_DIR)

tns_agent = TNSQueryAgent(db_path=SAVE_DIR)
#tns_agent.update_local_database() # IMPORTANT: run this line if first time using SNAPI or if you want to reimport TNS csv
all_names = tns_agent.retrieve_all_names() # only spectroscopically classified
all_names = [x for x in all_names if int(x[:4]) > 2018] # because pre-2019 templates are pretty bad
print(len(all_names), all_names[:5])
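
The year filter in the cell above calls `int(x[:4])`, which works when every name starts with a four-digit year (for example `2019abc`). If you swap in your own list and it contains ZTF-style names, that call raises a `ValueError`. The sketch below is one way to make the filter tolerant of such names; `discovered_after` is a hypothetical helper and the names are made up.

```python
def discovered_after(name, year=2018):
    """Return True if `name` starts with a four-digit year greater than `year`.

    Names that do not start with four digits (e.g. ZTF identifiers) return
    False instead of raising ValueError.
    """
    prefix = name[:4]
    return len(prefix) == 4 and prefix.isdecimal() and int(prefix) > year


# Made-up names for illustration.
names = ["2018abc", "2019xyz", "2021aa", "ZTF19abcdefg"]
kept = [n for n in names if discovered_after(n)]
print(kept)
```

Note that this drops ZTF-style names entirely; if you want to keep them, handle them in a separate branch rather than relying on the year prefix.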

The following script imports data for all provided names and generates a TransientGroup object. It runs in parallel across n_cores threads.

For the entire TNS dataset (~16000 events), this takes ~30 minutes on 8 parallel cores.

In [2]:
# First, update the config file so that outputs are saved under your data folder path.
# Feel free to change to any path you want
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
config.update(data_dir=SAVE_DIR)
config.write_to_file('config.yaml')


In [3]:
from superphot_plus.config import SuperphotConfig
from superphot_plus.data_generation import import_all_names

config = SuperphotConfig.from_file("config.yaml")
save_dir = config.transient_data_fn

# import data for all_names from query agents
import_all_names(
    all_names, save_dir,
    checkpoint_freq=512,
    n_cores=config.n_parallel,
    overwrite=True,
    skipped_names_fn = os.path.join(config.data_dir, "skipped_names.txt"),
    tns_db_path=SAVE_DIR,
) # set overwrite=False to continue from where left off
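
The `checkpoint_freq` and `overwrite` arguments above follow a common checkpoint-and-resume pattern. The sketch below shows that pattern in isolation using only the standard library. It is a generic illustration, not the implementation of `import_all_names`; the function names and the JSON checkpoint format are assumptions.

```python
import json
import os
import tempfile


def process_with_checkpoints(names, process, checkpoint_path,
                             checkpoint_freq=512, overwrite=False):
    """Apply `process` to each name, saving results every `checkpoint_freq` items.

    With overwrite=False, names already present in the checkpoint are skipped.
    """
    results = {}
    if not overwrite and os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            results = json.load(f)
    pending = [n for n in names if n not in results]
    for i, name in enumerate(pending, start=1):
        results[name] = process(name)
        # Write a checkpoint periodically and once more at the very end.
        if i % checkpoint_freq == 0 or i == len(pending):
            with open(checkpoint_path, "w") as f:
                json.dump(results, f)
    return results


calls = []


def fake_import(name):
    """Stand-in for an expensive per-event import."""
    calls.append(name)
    return len(name)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.json")
    first = process_with_checkpoints(["2019a", "2020bc"], fake_import, path,
                                     checkpoint_freq=1)
    # The second run resumes: only the new name is processed.
    second = process_with_checkpoints(["2019a", "2020bc", "2021def"],
                                      fake_import, path, checkpoint_freq=1)
print(calls)
```

A smaller checkpoint frequency loses less work after an interruption but writes to disk more often.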

In [4]:
# Let's check the TransientGroup we created!
from snapi import TransientGroup
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)
transient_group.add_col('peak_abs_mag', lambda x: x.peak_abs_mag)
print(len(transient_group.metadata))
print(transient_group.metadata.head())
print(transient_group.metadata.groupby('spec_class').count())

## Step 2 (Option 1): Fit all transients using SVI (faster)

Here, we choose to fit our transients using stochastic variational inference (SVI). If using this option, make sure sampler='superphot_svi' in the config.yaml file. This option is faster but assumes Gaussianity of the posterior space, which can be limiting for certain light curve fits.
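
For reference, SVI replaces sampling with optimization: it picks a family of distributions $q_\phi(\theta)$ and maximizes the evidence lower bound (ELBO) over $\phi$. The expression below is the standard textbook form with a multivariate Gaussian family. The notation ($\theta$ for the light curve parameters, $y$ for the photometry) is generic and is not taken from the superphot_plus source.

$$
\mathrm{ELBO}(\phi) = \mathbb{E}_{q_\phi(\theta)}\left[\log p(y \mid \theta) + \log p(\theta) - \log q_\phi(\theta)\right],
\qquad q_\phi(\theta) = \mathcal{N}(\theta \mid \mu, \Sigma), \quad \phi = (\mu, \Sigma)
$$

Because $q_\phi$ is Gaussian by construction, a skewed or multimodal posterior is summarized by a single Gaussian, which is the limitation mentioned above.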

For all ~8300 TNS transients passing quality cuts, this takes ~30 minutes.

In [6]:
from snapi.scripts import fit_transient_group
from snapi import TransientGroup, SamplerResultGroup
from superphot_plus.samplers.numpyro_sampler import SVISampler
from superphot_plus.priors import generate_priors, SuperphotPrior
from superphot_plus.config import SuperphotConfig
import os
from pathlib import Path

p = Path(os.getcwd()).parents[1]
SAVE_DIR = os.path.join(p, "data", "tutorial")

config = SuperphotConfig.from_file("config.yaml")
# Load the global priors from the tutorial data folder.
# To start from the default priors instead, use:
# priors = generate_priors(["ZTF_r","ZTF_g"])
priors = SuperphotPrior.load(os.path.join(SAVE_DIR, "global_priors_hier_svi"))
svi_sampler = SVISampler(
    priors=priors,
    num_iter=10_000,
    random_state=config.random_seed,
)

transient_group = TransientGroup.load(config.transient_data_fn)
print("Transient group loaded")

result = fit_transient_group(
    transient_group,
    sampler = svi_sampler,
    parallelize=True,
    n_parallel=config.n_parallel,
    checkpoint_fn = os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
    checkpoint_freq = 512,
    pad=True,
    overwrite=True # set to False to continue where left off
)
SamplerResultGroup(result).save(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

In [7]:
# sanity check plot
import os
import matplotlib.pyplot as plt

from snapi import TransientGroup, SamplerResultGroup, Formatter
from superphot_plus.samplers.numpyro_sampler import SVISampler
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

priors = generate_priors(["ZTF_r","ZTF_g"])
svi_sampler = SVISampler(
    priors=priors,
    num_iter=10_000,
    random_state=config.random_seed,
)

names = sampler_results.metadata.index

formatter = Formatter()
for n in names[-5:]:
    t = transient_group[n] # can index like dictionary
    sr = sampler_results[n]
    svi_sampler.load_result(sr)
    
    fig, ax = plt.subplots()
    svi_sampler.plot_fit(
        ax,
        photometry = t.photometry,
        formatter = formatter,
    )
    formatter.reset_colors()
    formatter.reset_markers()
    t.photometry.plot(
        ax,
        mags=False,
        formatter=formatter
    )
    formatter.make_plot_pretty(ax)
    formatter.add_legend(ax)
    formatter.reset_colors()
    formatter.reset_markers()
    
    plt.show()
    

## Step 2a: Hierarchical Bayesian Inference (slowest)
Optimize the priors hierarchically for the dataset. This step uses a subset (~500) of the objects.

In [1]:
from snapi.scripts import fit_many_hierarchical
from snapi import TransientGroup, SamplerResultGroup
from superphot_plus.samplers.numpyro_sampler import SVISampler
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

import warnings
warnings.filterwarnings("ignore")

config = SuperphotConfig.from_file("config.yaml")
priors = generate_priors(
    ["ZTF_r", "ZTF_g"],
    priors_dir='priors',
)

transient_group = TransientGroup.load(config.transient_data_fn)
print(len(transient_group))
print("Transient group loaded")

svi_sampler = SVISampler(
    priors=priors,
    num_events=500,
    num_iter=20_000,
    step_size=0.01,
    random_state=42,
    max_length=100,
)

global_priors, indiv_results = fit_many_hierarchical(
    transient_group,
    500,
    svi_sampler,
    pad=False
)
global_priors[0].save("hierarchical_svi_global_loc")
global_priors[1].save("hierarchical_svi_global_scale")
srg_indiv = SamplerResultGroup(indiv_results)
srg_indiv.save("hierarchical_svi_srg")

In [2]:
from superphot_plus.priors import generate_priors, SuperphotPrior
from snapi import SamplerResult
import os

# convert global SRs to new prior files
all_bands = ["ZTF_r", "ZTF_g",]
priors = generate_priors(
    all_bands,
    priors_dir='priors',
)
global_prior_mu = SamplerResult.load("hierarchical_svi_global_loc")
gauss_prior_mu = priors.reverse_transform(global_prior_mu.fit_parameters)
global_prior_sigma = SamplerResult.load("hierarchical_svi_global_scale")
gauss_prior_sigma = priors.reverse_transform(global_prior_sigma.fit_parameters)

df = priors.dataframe.copy()

all_params = df['param']
df.loc[:, 'mean'] = gauss_prior_mu[all_params].mean(axis=0).to_numpy()
df.loc[:, 'stddev'] = gauss_prior_sigma[all_params].mean(axis=0).to_numpy()

print(df)

global_priors = SuperphotPrior(df)

global_priors.save("priors/global_priors_hier_svi")

## Step 2 (Option 2): Fit light curves using dynesty (slower)

Here, we fit our transient photometry using the dynesty nested sampler. This is slower but does not assume Gaussianity of the posterior space, so it can better capture degeneracies between parameters. If you use this option, make sure to set sampler=superphot_dynesty in the config.yaml file.

Runtime for 7202 TNS samples: ~200 minutes (3.5 hours)

In [None]:
from snapi.scripts import fit_transient_group
from snapi import TransientGroup, SamplerResultGroup
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.priors import generate_priors, SuperphotPrior
from superphot_plus.config import SuperphotConfig

from pathlib import Path
import os

p = Path(os.getcwd()).parents[1]
SAVE_DIR = os.path.join(p, "data", "tutorial")

config = SuperphotConfig.from_file("config.yaml")

transient_group = TransientGroup.load(config.transient_data_fn)
print("Transient group loaded")

# To use previously saved global priors instead, load them with:
# priors = SuperphotPrior.load(os.path.join(SAVE_DIR, "global_priors_hier_svi"))
priors = generate_priors(["ZTF_r","ZTF_g"], priors_dir='priors')

dynesty_sampler = DynestySampler(
    priors=priors,
    random_state=config.random_seed,
)

result = fit_transient_group(
    transient_group,
    sampler = dynesty_sampler,
    parallelize=True,
    n_parallel=config.n_parallel,
    checkpoint_fn = os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
    checkpoint_freq = 128,
    pad=False,
    overwrite=True, # False to continue from checkpoint
)
SamplerResultGroup(result).save(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
)


In [None]:
# sanity check plot
import os
import matplotlib.pyplot as plt

from snapi import TransientGroup, SamplerResultGroup, Formatter
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

priors = generate_priors(["ZTF_r","ZTF_g"])

dynesty_sampler = DynestySampler(
    priors=priors,
    random_state=config.random_seed,
)

names = sampler_results.metadata.index

formatter = Formatter()
for n in names[-5:]: # newest
    t = transient_group[n] # can index like dictionary
    sr = sampler_results[n]
    dynesty_sampler.load_result(sr)
    
    fig, ax = plt.subplots()
    dynesty_sampler.plot_fit(
        ax,
        photometry = t.photometry,
        formatter = formatter,
    )
    formatter.reset_colors()
    formatter.reset_markers()
    t.photometry.plot(
        ax,
        mags=False,
        formatter=formatter
    )
    formatter.make_plot_pretty(ax)
    formatter.add_legend(ax)

    formatter.reset_colors()
    formatter.reset_markers()
    
    plt.show()
    

## Step 2.5: Convert SamplerResultGroup posteriors back to uncorrelated Gaussians

When sampling, the posteriors are saved as the inputs to our flux model. The Gaussian priors, however, were converted to log-Gaussians and multiplied by base parameters where necessary before being fed into the model function. Therefore, we must revert these log-Gaussian and relative parameters back to their original uncorrelated Gaussian draws before using as classifier inputs. We do this below:

In [2]:
# warning: only run once!
import os
from snapi import SamplerResultGroup, TransientGroup
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")

priors = generate_priors(["ZTF_r","ZTF_g"])
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

# also add peak_abs_mag column
transient_group = TransientGroup.load(config.transient_data_fn)

new_sr = []
for i, sr in enumerate(sampler_results):
    if i % 1000 == 0:
        print(f"Converted {i} out of {len(sampler_results)} fits")
    n = sr.id
    transient = transient_group[n]    
    params = priors.reverse_transform(sr.fit_parameters)
    params['peak_abs_mag'] = transient.peak_abs_mag
    sr.fit_parameters = params
    new_sr.append(sr)
    
new_sampler_results = SamplerResultGroup(new_sr)
new_sampler_results.save(config.sampler_results_fn)

print(new_sampler_results.metadata)
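
To illustrate what reverting a log-Gaussian, relative parameter means, here is a minimal sketch under an assumed parameterization: the sampler draws `u` from a Gaussian and passes `base * 10**u` to the flux model, where `base` is 1 for absolute parameters and another parameter's value for relative ones. The base-10 convention, the function names, and the numbers are assumptions for illustration; the actual mapping is whatever `priors.reverse_transform` implements.

```python
import math


# Hypothetical parameterization, for illustration only: the sampler is assumed
# to draw u ~ Normal(mean, stddev) and feed value = base * 10**u to the model.
def forward(u, base=1.0):
    """Map a Gaussian draw u to the model-space value."""
    return base * 10.0 ** u


def reverse(value, base=1.0):
    """Recover the Gaussian draw u from a model-space value."""
    return math.log10(value / base)


# Made-up numbers: an r-band amplitude and a g-band amplitude stored relative to it.
u_amp_r, u_ratio_g = 2.3, -0.05
amp_r = forward(u_amp_r)
amp_g = forward(u_ratio_g, base=amp_r)

recovered_r = reverse(amp_r)
recovered_ratio = reverse(amp_g, base=amp_r)
print(round(recovered_r, 6), round(recovered_ratio, 6))
```

In this picture, applying `reverse` a second time would take the logarithm of a value that is already a Gaussian draw, which is presumably why the conversion cell should only be run once.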


In [3]:
# make summary df and all samples
import os
from snapi import SamplerResultGroup
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
srg = SamplerResultGroup.load(config.sampler_results_fn)
metadata = srg.metadata
metadata.to_csv(os.path.join(config.data_dir, "all_samples.csv"))
print(metadata.head())

## Step 3: Train + evaluate classifier from sampling posteriors

Here we train a classifier with our uncorrelated posterior features. This script will automatically split the data into K-folds, oversample the training and validation sets to even out minority classes, and train either LightGBMs (recommended) or MLPs. If plot is True, metric plots and confusion matrices will also be generated.

In [1]:
from superphot_plus import SuperphotConfig, SuperphotTrainer
import pandas as pd
import os

config = SuperphotConfig.from_file("config.yaml")

# remove A_ZTF_r and t_0_ZTF_r from params used in classification - see paper for details
metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
keep_cols = metadata.drop(
    columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
).columns
config.input_features = [c.replace("_median", "") for c in keep_cols]
print(config.input_features)

# train classifier
trainer = SuperphotTrainer(config)
trainer.config.input_features = [
    #"A_ZTF_r",
    #"peak_abs_mag",
    "beta_ZTF_r",
    "gamma_ZTF_r",
    "tau_rise_ZTF_r",
    "tau_fall_ZTF_r",
    "extra_sigma_ZTF_r",
    "A_ZTF_g",
    "beta_ZTF_g",
    "gamma_ZTF_g",
    "t_0_ZTF_g",
    "tau_rise_ZTF_g",
    "tau_fall_ZTF_g",
    "extra_sigma_ZTF_g",
]
trainer.run()
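
The K-fold split described in this step can be pictured with the sketch below, which assigns events to folds class by class so that every fold contains a share of each class. It is a standard-library illustration of stratified splitting under made-up class counts, not the code that `SuperphotTrainer` runs; `stratified_folds` is a hypothetical helper.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds=10, seed=0):
    """Assign each event index to a fold so every class is spread across folds."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    fold_of = {}
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Deal the shuffled members of this class out to folds in turn.
        for position, idx in enumerate(members):
            fold_of[idx] = position % n_folds
    return fold_of


# Made-up class counts for illustration only.
labels = ["SN Ia"] * 50 + ["SN II"] * 20 + ["SLSN-I"] * 10
fold_of = stratified_folds(labels, n_folds=5, seed=42)
test_idx = [i for i, f in fold_of.items() if f == 0]
print(len(test_idx))
```

The split here is made over events rather than over individual posterior draws, so that all draws from one event stay in the same fold; oversampling of minority classes would then be applied within each fold's training and validation sets.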

Finally, we train a version of the classifier without a test set (i.e., the entire dataset is used for either training or validation). This is the version we will use to classify a new, separate dataset.

In addition to the classic full-phase classifier, we train a classifier that only uses early-phase features (excludes plateau durations and fall timescales). This is more effective at classifying partial supernova light curves.

In [None]:
from superphot_plus import SuperphotConfig, SuperphotTrainer
from snapi import TransientGroup, SamplerResultGroup
import os
import pandas as pd

config = SuperphotConfig.from_file("config.yaml")

# remove A_ZTF_r and t_0_ZTF_r from params used in classification - see paper for details
metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
keep_cols = metadata.drop(
    columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
).columns
config.input_features = [c.replace("_median", "") for c in keep_cols]
print(config.input_features)

transient_group = TransientGroup.load(config.transient_data_fn)
srg = SamplerResultGroup.load(config.sampler_results_fn)

trainer = SuperphotTrainer(config)
trainer.setup_model()
meta_df = trainer.retrieve_transient_metadata(transient_group)
train_df, val_df = trainer.split(meta_df, split_frac=0.1)

train_srg = srg.filter(train_df.index)
val_srg = srg.filter(val_df.index)


# train full-type
trainer.config.input_features = [
    "beta_ZTF_r",
    "gamma_ZTF_r",
    "tau_rise_ZTF_r",
    "tau_fall_ZTF_r",
    "extra_sigma_ZTF_r",
    "A_ZTF_g",
    "beta_ZTF_g",
    "gamma_ZTF_g",
    "t_0_ZTF_g",
    "tau_rise_ZTF_g",
    "tau_fall_ZTF_g",
    "extra_sigma_ZTF_g",
]
trainer.train(0, (train_df, train_srg), (val_df, val_srg))
trainer.models[0].save(config.model_prefix + "_full")
print(trainer.models[0].best_model.feature_name_)

probs_avg = trainer.evaluate(0, (meta_df, srg))
class_cols = ['SLSN-I', 'SN II', 'SN Ia', 'SN Ibc', 'SN IIn']
probs_avg[class_cols] = probs_avg[class_cols].round(2)
probs_avg['max_prob'] = probs_avg[class_cols].max(axis=1)
probs_avg.drop(columns=['true_class','fold'], inplace=True)
# save alongside the other probability outputs
probs_avg.to_csv(config.probs_fn[:-4] + "_full.csv")

print("Full results saved")
# train early-type
metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
keep_cols = metadata.drop(
    columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
).columns
trainer.config.input_features = [c.replace("_median", "") for c in keep_cols if (
        ('tau_fall' not in c) and ('gamma' not in c)
    )
]

trainer.train(1, (train_df, train_srg), (val_df, val_srg))
trainer.models[1].save(config.model_prefix + "_early")
print(trainer.models[1].best_model.feature_name_)

probs_avg = trainer.evaluate(1, (meta_df, srg))
probs_avg.to_csv(config.probs_fn[:-4] + "_early.csv")
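
As a quick check of which features the early-phase classifier drops, the sketch below applies the same substring filter used above to the twelve full-phase feature names listed in this step. In the actual cell the early-phase list is built from the columns of all_samples.csv, so it may contain columns beyond these twelve.

```python
# Full-phase feature list copied from the training cell above.
full_features = [
    "beta_ZTF_r", "gamma_ZTF_r", "tau_rise_ZTF_r", "tau_fall_ZTF_r",
    "extra_sigma_ZTF_r", "A_ZTF_g", "beta_ZTF_g", "gamma_ZTF_g",
    "t_0_ZTF_g", "tau_rise_ZTF_g", "tau_fall_ZTF_g", "extra_sigma_ZTF_g",
]

# Drop plateau-duration (gamma) and fall-timescale (tau_fall) features.
early_features = [
    f for f in full_features if "tau_fall" not in f and "gamma" not in f
]
print(early_features)
```

Four of the twelve features are removed, two per band, leaving the rise time, amplitude, slope, reference time, and scatter terms that are available before the light curve has faded.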