# Full Train Workflow

Run this notebook when you need to train a new classifier from scratch. It regenerates the transient data, refits all samples, and retrains the classifier.

## Step 0: Update configuration file.

In the same folder as this notebook, there is a $\texttt{config.yaml}$ file, which contains all filepaths and configuration options for the training workflow. Please update this now!

The most important filepath arguments are:
* $\texttt{create-dirs}$: Probably keep set to True. Create any data subdirectories that are missing.
* $\texttt{data-dir}$: This is where all generated data is stored. Set to the root directory for all outputs.
* $\texttt{relative-dirs}$: If true, all data for each step is stored within subdirectories of data_dir.
* $\texttt{transient-data-fn}$: This is where all transient data is stored as a TransientGroup. Technically a directory but loaded as a single file. If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{sampler-results-fn}$: Where light curve fits are stored. If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{figs-dir}$: Where all figures are stored (only generated if $\texttt{plot}$ is set to True). If relative_dirs is True, is created as a subdirectory of data_dir.
* $\texttt{models-dir}$: Where all classification models are stored. If relative_dirs is True, is created as a subdirectory of data_dir.
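
As a rough illustration of what `relative-dirs` means for the paths above, the sketch below joins each entry onto `data_dir` when the flag is on and leaves it untouched otherwise. The helper `resolve_paths` and the example values are hypothetical and not part of superphot_plus; the library's own resolution logic may differ.

```python
import os

def resolve_paths(data_dir, entries, relative_dirs=True):
    """Hypothetical helper: place each entry under data_dir if relative_dirs is set."""
    if not relative_dirs:
        # Entries are used exactly as written in the config file.
        return dict(entries)
    return {key: os.path.join(data_dir, value) for key, value in entries.items()}

# Example values, chosen only for illustration.
entries = {"figs_dir": "figs", "models_dir": "models"}
paths = resolve_paths("/tmp/superphot_run", entries)
print(paths)
```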

The most important sampling and classifier arguments are:
* $\texttt{sampler}$: Set to either superphot_dynesty or superphot_svi (all lowercase), matching the values used in Step 2. SVI is faster but forces the posterior into a multivariate Gaussian.
* $\texttt{model-type}$: Set to either LightGBM (recommended) or MLP.
* $\texttt{use-redshift-features}$: If True, includes peak absolute magnitude and redshift as training features.
* $\texttt{fits-per-majority}$: Oversamples such that the majority class has this many samples fed into the classifier. Minority classes will correspond to more input samples per event. Defaults to 5.
* $\texttt{target-label}$: For binary classification - this is the positive label. Set to None for multiclass classification.
* $\texttt{n-folds}$: Number of K-folds. I usually set to 10.
* $\texttt{num-epochs}$: Number of estimators for LightGBM or number of training epochs for MLP.
* $\texttt{n-parallel}$: Number of threads to parallelize data import + sampling over.
* $\texttt{random-seed}$: For reproducibility.
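
To make the effect of `fits-per-majority` concrete, here is a minimal sketch that assumes every class is oversampled to roughly the same total number of classifier inputs as the majority class. The function `fits_per_event` is hypothetical, and the actual oversampling scheme in superphot_plus may differ in detail.

```python
import math

def fits_per_event(class_counts, fits_per_majority=5):
    """Hypothetical helper: posterior samples drawn per event, by class."""
    n_majority = max(class_counts.values())
    # Total number of classifier inputs targeted for every class.
    target = n_majority * fits_per_majority
    return {label: math.ceil(target / n) for label, n in class_counts.items()}

# Event counts taken from the class weight table later in this notebook.
counts = {"SN Ia": 5479, "SN II": 1222, "TDE": 73}
per_event = fits_per_event(counts)
print(per_event)
```

Under this assumption, each event in a rare class contributes many more samples than each event in the majority class, which is the behavior the option describes.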

If you want to use hierarchical classification with the MLP, you need to also manually fill in the following:
* $\texttt{hierarchy}$: Set to True if you want to use the weighted hierarchical loss function (WHXE). Otherwise, make sure this is set to False.
* $\texttt{class-weights}$: The class weights for the WHXE. Compute them with the class_weights.ipynb notebook and copy the values in manually. If you are not using the WHXE, set this to None.
* $\texttt{graph}$: This is a dictionary of properties of the taxonomic graph. Set each element to null if not using the WHXE.
    * $\texttt{edges}$: A list of 2-element lists containing the edges in the taxonomic graph in use. If not using WHXE, set to None.
    * $\texttt{height}$: The height you want the classifier to go to in terms of the taxonomic tree. Set to None if not using WHXE.
    * $\texttt{root}$: The root node of the tree as a string. This is set to None if not using WHXE.
    * $\texttt{vertices}$: A list of all of the vertices within the tree, including those that are not labels within the dataset. Set to None if not using WHXE. 
    * $\texttt{ignored-leaves}$: A list of the leaf vertices within the tree that do not have any counts for them within class_weights or we otherwise want to ignore. Set to None if not using WHXE.   
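
Before filling in the `graph` entries, it can help to check that they describe a valid tree. The sketch below is one way to do that. The vertex names, the `[parent, child]` edge orientation, and the `check_graph` helper are all assumptions made for illustration; they are not taken from superphot_plus.

```python
# Hypothetical taxonomy, written as a Python dict mirroring the config entries.
graph = {
    "root": "Transient",
    "vertices": ["Transient", "SN", "TDE", "SN Ia", "SN II"],
    "edges": [["Transient", "SN"], ["Transient", "TDE"],
              ["SN", "SN Ia"], ["SN", "SN II"]],
    "height": 2,
    "ignored_leaves": [],
}

def check_graph(graph):
    """Return a list of problems found in a `graph` entry (empty if none)."""
    vertices = set(graph["vertices"])
    problems = []
    if graph["root"] not in vertices:
        problems.append("root is not listed in vertices")
    children = []
    for parent, child in graph["edges"]:
        if parent not in vertices or child not in vertices:
            problems.append(f"edge [{parent}, {child}] uses an unlisted vertex")
        children.append(child)
    if len(children) != len(set(children)):
        problems.append("a vertex has more than one parent")
    if graph["root"] in children:
        problems.append("root has a parent")
    if len(graph["edges"]) != len(vertices) - 1:
        problems.append("a tree needs exactly len(vertices) - 1 edges")
    return problems

print(check_graph(graph))
```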

## Step 1: Generate new TransientGroup

Here we import data from TNS + ALeRCE and generate a new TransientGroup from a list of event names. Names can be from TNS or ZTF.

The below code block will retrieve all spectroscopically classified TNS transients. Feel free to change to your own list of names or import script.

In [None]:
from pathlib import Path
import os
from snapi.query_agents import TNSQueryAgent

p = Path(os.getcwd()).parents[1]
print(p)
#SAVE_DIR = os.path.join(p, "data", "tutorial") # Use this line if you want to use tutorial data.
SAVE_DIR = os.path.join(p, "data", "whxe")
print(SAVE_DIR)

tns_agent = TNSQueryAgent(db_path=SAVE_DIR)
#tns_agent.update_local_database() # IMPORTANT: run this line if first time using SNAPI or if you want to reimport TNS csv
all_names = tns_agent.retrieve_all_names() # only spectroscopically classified
all_names = [x for x in all_names if int(x[:4]) > 2018] # because pre-2019 templates are pretty bad
print(len(all_names), all_names[:5])

The following script imports data for all provided names and generates a TransientGroup object. It runs in parallel across n_cores threads.

For the entire TNS dataset (~16000 events), this takes ~30 minutes on 8 parallel cores.

In [None]:
# First, update the config file so that outputs are saved under your data folder.
# Feel free to change to any path you want
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
config.update(data_dir=SAVE_DIR)
config.write_to_file('config.yaml')

In [None]:
from superphot_plus.config import SuperphotConfig
from superphot_plus.data_generation import import_all_names

config = SuperphotConfig.from_file("config.yaml")
save_dir = config.transient_data_fn

# import data for all_names from query agents 
# QUALITY CUTS HAPPEN HERE
import_all_names(
    all_names, save_dir,
    checkpoint_freq=512,
    n_cores=config.n_parallel,
    overwrite=False,
    skipped_names_fn = os.path.join(config.data_dir, "skipped_names.txt"),
    tns_db_path=SAVE_DIR
) # set overwrite=False to continue from where left off

In [None]:
# Let's check the TransientGroup we created!
from snapi import TransientGroup
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)

print(len(transient_group.metadata))
print(transient_group.metadata.head())
print(transient_group.metadata.groupby('spec_class').count())

Finally, before fitting, all photometry must be phased and normalized, because the samplers expect light curves that are already in that form.
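
As a rough picture of what phasing and normalizing involve, the sketch below shifts times so that the brightest observation sits at phase zero and divides fluxes and errors by the peak flux. This convention and the `phase_and_normalize` helper are assumptions for illustration only; the routines actually used by snapi and superphot_plus may define phase and normalization differently.

```python
def phase_and_normalize(times, fluxes, flux_errors):
    """Hypothetical helper: phase relative to peak time, scale by peak flux."""
    peak_idx = max(range(len(fluxes)), key=lambda i: fluxes[i])
    peak_time, peak_flux = times[peak_idx], fluxes[peak_idx]
    phased = [t - peak_time for t in times]
    norm_flux = [f / peak_flux for f in fluxes]
    norm_err = [e / peak_flux for e in flux_errors]
    return phased, norm_flux, norm_err

# Toy light curve in a single band (values are made up).
times = [58000.0, 58005.0, 58010.0, 58020.0]
fluxes = [10.0, 40.0, 80.0, 20.0]
errors = [1.0, 2.0, 4.0, 2.0]

phased, norm_flux, norm_err = phase_and_normalize(times, fluxes, errors)
print(phased, norm_flux)
```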

## Step 1.5: Hierarchical Counts + Class Weights

If we are using WHXE in an MLP, we need to generate a taxonomy to be used in our Weighted Hierarchical Cross Entropy loss function.

We want to use the class counts after the quality cuts, since that is the data our model trains on.

##### Substep 1: Save the quality-cut data as a CSV file.

In [None]:
from superphot_plus.config import SuperphotConfig
import pandas as pd

config = SuperphotConfig.from_file("config.yaml")
if config.use_hierarchy:
    df = transient_group.metadata
    df.to_csv(f'{config.data_dir}/quality_cut_tns_data.csv')

##### Substep 2: Adjust Weights, Labels, and Tree

Here we need to adjust a few things:

1. Update the weight dictionary so that it is based on the data remaining after the quality cuts in **import_all_names** above.
2. Confirm that the tree/taxonomy defined above is trimmed to the desired **height** given in **config.yaml**, and build an updated mapping schema based on the **mapping** schema.

##### Task 2.1 Update Weight Dictionary

If you have not already done so, go back to the **class_weights.ipynb** notebook and point it at the CSV written after the quality cuts instead of the CSV it originally loaded. Rerun it and update the config.yaml file as appropriate.

In [None]:
from superphot_plus.config import SuperphotConfig
from superphot_plus.model.taxonomy import Taxonomy

config = SuperphotConfig.from_file("config.yaml")
if config.use_hierarchy:
    # Copied from the output produced in class_weights.ipynb; replace with your own version.
    # Each entry maps a class label to [class weight, number of events].
    new_weights = {
        'SN Ia': [0.06716554115714546, 5479],
        'SN II': [0.3011456628477905, 1222],
        'SN IIn': [1.6576576576576576, 222],
        'SN Ia-91T-like': [1.8775510204081634, 196],
        'SN Ic': [2.5734265734265733, 143],
        'SN Ib': [2.852713178294574, 129],
        'SLSN-I': [3.5384615384615383, 104],
        'SN IIb': [3.5728155339805827, 103],
        'SN IIP': [4.088888888888889, 90],
        'TDE': [5.041095890410959, 73],
        'SN Ic-BL': [5.411764705882353, 68],
        'SLSN-II': [7.215686274509804, 51],
        'SN Ia-pec': [9.68421052631579, 38],
        'SN Ia-91bg-like': [9.945945945945946, 37],
        'SN Ibn': [13.62962962962963, 27],
        'SN': [14.153846153846153, 26],
        'SN Iax[02cx-like]': [15.333333333333334, 24],
        'SN Ib/c': [16.727272727272727, 22],
        'SN Ia-CSM': [18.4, 20],
        'SN Ia-SC': [36.8, 10],
        'SN II-pec': [52.57142857142857, 7],
        'SN Ib-pec': [73.6, 5],
    }

    config.update(class_weights=new_weights)
    config.write_to_file('config.yaml')
    print(config.class_weights)
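
The weights listed above are numerically consistent with inverse-frequency weighting, weight = N_total / (N_classes × count). The sketch below recomputes them from the counts alone. That class_weights.ipynb uses exactly this formula is an inference from the numbers rather than something stated in this notebook, and `balanced_class_weights` is a hypothetical helper.

```python
def balanced_class_weights(class_counts):
    """Hypothetical helper: map each label to [N_total / (N_classes * count), count]."""
    n_total = sum(class_counts.values())
    n_classes = len(class_counts)
    return {label: [n_total / (n_classes * n), n] for label, n in class_counts.items()}

# Event counts copied from the dictionary above.
counts = {
    'SN Ia': 5479, 'SN II': 1222, 'SN IIn': 222, 'SN Ia-91T-like': 196,
    'SN Ic': 143, 'SN Ib': 129, 'SLSN-I': 104, 'SN IIb': 103, 'SN IIP': 90,
    'TDE': 73, 'SN Ic-BL': 68, 'SLSN-II': 51, 'SN Ia-pec': 38,
    'SN Ia-91bg-like': 37, 'SN Ibn': 27, 'SN': 26, 'SN Iax[02cx-like]': 24,
    'SN Ib/c': 22, 'SN Ia-CSM': 20, 'SN Ia-SC': 10, 'SN II-pec': 7,
    'SN Ib-pec': 5,
}

weights = balanced_class_weights(counts)
print(weights['SN Ia'], weights['SN Ib-pec'])
```

A quick check like this is useful after changing the quality cuts: if the counts change, the weights in config.yaml should change with them.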

##### Task 2.2 Adjust the Tree

The **taxonomy** object is the tree used during training, and it also trims the tree (and the associated values) to the desired height. All of this happens in the backend when **taxonomy** is defined, so you do not normally see it. As a sanity check, we build a duplicate below so that you can inspect the taxonomy being used.

If you want to adjust the height, you need to go and change it in the **config.yaml** file now. 

In [None]:
from superphot_plus.config import SuperphotConfig
from superphot_plus.model.taxonomy import Taxonomy

config = SuperphotConfig.from_file("config.yaml")

if config.use_hierarchy:
    taxonomy = Taxonomy(config)
    print(taxonomy)

In [None]:
if config.use_hierarchy:    
    taxonomy.draw_graph()

##### Task 2.3 Define Needed Attributes

For the tree we have defined we need to extract certain features to be used later in the hierarchical loss function. Again, note this is a duplicate and not the actual one used in the backend. This is simply for demonstration purposes.

In [None]:
if config.use_hierarchy:
    all_paths, path_lengths, mask_list, y_dict = taxonomy.calc_paths_and_masks()
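
To give a feel for the kind of quantities extracted here, the sketch below computes the root-to-vertex path and its length for every vertex of a small tree. It is not the implementation behind `taxonomy.calc_paths_and_masks()`; the example tree, the `[parent, child]` edge orientation, and the `root_paths` helper are assumptions for illustration.

```python
def root_paths(root, edges):
    """Hypothetical helper: map each vertex to its path from the root."""
    parent = {child: par for par, child in edges}
    paths = {}
    for vertex in [root] + list(parent):
        path = [vertex]
        # Walk upward until the root is reached, then reverse the path.
        while path[-1] != root:
            path.append(parent[path[-1]])
        paths[vertex] = path[::-1]
    return paths

# Hypothetical taxonomy with edges given as [parent, child].
edges = [["Transient", "SN"], ["Transient", "TDE"],
         ["SN", "SN Ia"], ["SN", "SN II"]]

paths = root_paths("Transient", edges)
path_lengths = {vertex: len(path) for vertex, path in paths.items()}
print(paths["SN Ia"], path_lengths["SN Ia"])
```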

## Step 2 (Option 1): Fit all transients using SVI (faster)

Here, we choose to fit our transients using stochastic variational inference (SVI). If using this option, make sure sampler='superphot_svi' in the config.yaml file. This option is faster but assumes Gaussianity of the posterior space, which can be limiting for certain light curve fits.

For all 7202 TNS transients passing quality cuts, this takes ~30 minutes.

In [None]:
from snapi.scripts import fit_transient_group
from snapi import TransientGroup, SamplerResultGroup
from superphot_plus.samplers.numpyro_sampler import SVISampler
from superphot_plus.priors import generate_priors, SuperphotPrior
from superphot_plus.config import SuperphotConfig
import os
from pathlib import Path

p = Path(os.getcwd()).parents[1]
SAVE_DIR = os.path.join(p, "data", "tutorial")

config = SuperphotConfig.from_file("config.yaml")

#priors = generate_priors(["ZTF_r","ZTF_g"])
priors = SuperphotPrior.load(os.path.join(SAVE_DIR, "global_priors_hier_svi"))
svi_sampler = SVISampler(
    priors=priors,
    num_iter=10_000,
    random_state=config.random_seed,
)

transient_group = TransientGroup.load(config.transient_data_fn)
print("Transient group loaded")

result = fit_transient_group(
    transient_group,
    sampler = svi_sampler,
    parallelize=True,
    n_parallel=config.n_parallel,
    checkpoint_fn = os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
    checkpoint_freq = 512,
    pad=True,
    overwrite=True # set to False to continue where left off
)
SamplerResultGroup(result).save(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

In [None]:
# sanity check plot
import os
import matplotlib.pyplot as plt

from snapi import TransientGroup, SamplerResultGroup, Formatter
from superphot_plus.samplers.numpyro_sampler import SVISampler
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

priors = generate_priors(["ZTF_r","ZTF_g"])
svi_sampler = SVISampler(
    priors=priors,
    num_iter=10_000,
    random_state=config.random_seed,
)

print(len(sampler_results), sampler_results.metadata.tail())
names = sampler_results.metadata.index

formatter = Formatter()
for n in names[-5:]:
    t = transient_group[n] # can index like dictionary
    sr = sampler_results[n]
    svi_sampler.load_result(sr)
    
    fig, ax = plt.subplots()
    svi_sampler.plot_fit(
        ax,
        photometry = t.photometry,
        formatter = formatter,
    )
    formatter.reset_colors()
    formatter.reset_markers()
    t.photometry.plot(
        ax,
        mags=False,
        formatter=formatter
    )
    formatter.make_plot_pretty(ax)
    formatter.add_legend(ax)
    formatter.reset_colors()
    formatter.reset_markers()
    
    plt.show()
    

## Step 2 (Option 2): Fit light curves using dynesty (slower)

Here, we fit our transient photometry using the dynesty nested sampler. This is slower but does not assume Gaussianity of the posterior space, so can better capture degeneracies between parameters. If you use this, make sure to set sampler=superphot_dynesty in the config.yaml file.

Runtime for 7202 TNS samples: ~200 minutes (3.5 hours)

In [None]:
from snapi.scripts import fit_transient_group
from snapi import TransientGroup, SamplerResultGroup
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.priors import generate_priors, SuperphotPrior
from superphot_plus.config import SuperphotConfig

from pathlib import Path
import os

p = Path(os.getcwd()).parents[1]
SAVE_DIR = os.path.join(p, "data", "tutorial")

config = SuperphotConfig.from_file("config.yaml")

#priors = generate_priors(["ZTF_r","ZTF_g"])
priors = SuperphotPrior.load(os.path.join(SAVE_DIR, "global_priors_hier_svi"))

transient_group = TransientGroup.load(config.transient_data_fn)
print("Transient group loaded")

# NOTE: this overrides the priors loaded above; comment it out to fit with the loaded priors.
priors = generate_priors(["ZTF_r","ZTF_g"])

dynesty_sampler = DynestySampler(
    priors=priors,
    random_state=config.random_seed,
)

result = fit_transient_group(
    transient_group,
    sampler = dynesty_sampler,
    parallelize=True,
    n_parallel=config.n_parallel,
    checkpoint_fn = os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
    checkpoint_freq = 128,
    pad=False,
    overwrite=True, # False to continue from checkpoint
)
SamplerResultGroup(result).save(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    ),
)


In [None]:
# sanity check plot
import os
import matplotlib.pyplot as plt

from snapi import TransientGroup, SamplerResultGroup, Formatter
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")
transient_group = TransientGroup.load(config.transient_data_fn)
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

priors = generate_priors(["ZTF_r","ZTF_g"])

dynesty_sampler = DynestySampler(
    priors=priors,
    random_state=config.random_seed,
)

names = sampler_results.metadata.index

formatter = Formatter()
for n in names[-5:]: # last five fits
    t = transient_group[n] # can index like dictionary
    sr = sampler_results[n]
    dynesty_sampler.load_result(sr)
    
    fig, ax = plt.subplots()
    dynesty_sampler.plot_fit(
        ax,
        photometry = t.photometry,
        formatter = formatter,
    )
    formatter.reset_colors()
    formatter.reset_markers()
    t.photometry.plot(
        ax,
        mags=False,
        formatter=formatter
    )
    formatter.make_plot_pretty(ax)
    formatter.add_legend(ax)

    formatter.reset_colors()
    formatter.reset_markers()
    
    plt.show()
    

## Step 2.5: Convert SamplerResultGroup posteriors back to uncorrelated Gaussians

When sampling, the posteriors are saved as the inputs to our flux model. The Gaussian priors, however, were converted to log-Gaussians and multiplied by base parameters where necessary before being fed into the model function. Therefore, we must revert these log-Gaussian and relative parameters back to their original uncorrelated Gaussian draws before using as classifier inputs. We do this below:

In [None]:
# warning: only run once!
import os
from snapi import SamplerResultGroup
from superphot_plus.priors import generate_priors
from superphot_plus.config import SuperphotConfig

config = SuperphotConfig.from_file("config.yaml")

priors = generate_priors(["ZTF_r","ZTF_g"])
sampler_results = SamplerResultGroup.load(
    os.path.join(
        config.data_dir,
        "tmp_sampler_results"
    )
)

new_sr = []
for i, sr in enumerate(sampler_results):
    if i % 1000 == 0:
        print(f"Converted {i} out of {len(sampler_results)} fits")
    sr.fit_parameters = priors.reverse_transform(sr.fit_parameters)
    new_sr.append(sr)
    
new_sampler_results = SamplerResultGroup(new_sr)
new_sampler_results.save(config.sampler_results_fn)
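
For intuition, the toy example below shows what a reverse transform of this kind does for one log-Gaussian base parameter and one relative parameter. The parameter names, the base-10 logarithm, and both helper functions are assumptions for illustration; they do not reproduce `priors.reverse_transform`.

```python
import math

def forward_transform(draws):
    """Gaussian draws -> flux-model inputs, under the assumed convention."""
    amp_r = 10 ** draws["log_A_r"]        # log-Gaussian base parameter
    amp_g = amp_r * draws["A_g_ratio"]    # relative parameter times its base
    return {"A_r": amp_r, "A_g": amp_g}

def reverse_transform(params):
    """Flux-model inputs -> original Gaussian draws (inverse of forward_transform)."""
    return {
        "log_A_r": math.log10(params["A_r"]),
        "A_g_ratio": params["A_g"] / params["A_r"],
    }

draws = {"log_A_r": 2.0, "A_g_ratio": 1.1}
model_inputs = forward_transform(draws)
recovered = reverse_transform(model_inputs)
print(model_inputs, recovered)
```

In this sketch, applying `reverse_transform` a second time would take the logarithm of an already-logged value, which illustrates why the conversion cell above should only be run once.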


In [None]:
from snapi import SamplerResultGroup
from superphot_plus.config import SuperphotConfig
import os

config = SuperphotConfig.from_file("config.yaml")
srg = SamplerResultGroup.load(config.sampler_results_fn)
metadata = srg.metadata
metadata.to_csv(os.path.join(config.data_dir, "all_samples.csv"))
print(metadata.head())

## Step 3: Train + evaluate classifier from sampling posteriors

Here we train a classifier with our uncorrelated posterior features. This script will automatically split the data into K-folds, oversample the training and validation sets to even out minority classes, and train either LightGBMs (recommended) or MLPs. If plot is True, metric plots and confusion matrices will also be generated.

In [None]:
from superphot_plus import SuperphotConfig, SuperphotTrainer
import pandas as pd
import os

config = SuperphotConfig.from_file("config.yaml")

# remove A_ZTF_r and t_0_ZTF_r from params used in classification - see paper for details
metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
keep_cols = metadata.drop(
    columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
).columns
config.input_features = [c.replace("_median", "") for c in keep_cols]
print(config.input_features)

# train classifier
trainer = SuperphotTrainer(config)
trainer.run()
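
The K-fold splitting mentioned above can be pictured with the short sketch below, which shuffles event indices and rotates which fold is held out. It is a plain, unstratified split written for illustration; `kfold_indices` is a hypothetical helper, and SuperphotTrainer may split differently (for example, by stratifying on class).

```python
import random

def kfold_indices(n_samples, n_folds, seed=42):
    """Hypothetical helper: yield (train, test) index lists for each fold."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into n_folds roughly equal folds.
    folds = [indices[i::n_folds] for i in range(n_folds)]
    for k in range(n_folds):
        test = folds[k]
        train = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train, test

splits = list(kfold_indices(n_samples=25, n_folds=5))
for train, test in splits:
    print(len(train), len(test))
```

Every event appears in exactly one held-out fold, so each event receives a prediction from a model that never saw it during training.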

Finally, we train a version of the classifier without a test set (that is, the entire dataset is used for training or validation). This is the model we will use to classify a new, disparate dataset.

In addition to the classic full-phase classifier, we train a classifier that only uses early-phase features (excludes plateau durations and fall timescales). This is more effective at classifying partial supernova light curves.

In [None]:
from superphot_plus import SuperphotConfig, SuperphotTrainer
from snapi import TransientGroup, SamplerResultGroup
import os
import pandas as pd

config = SuperphotConfig.from_file("config.yaml")

# remove A_ZTF_r and t_0_ZTF_r from params used in classification - see paper for details
metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
keep_cols = metadata.drop(
    columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
).columns
config.input_features = [c.replace("_median", "") for c in keep_cols]
print(config.input_features)

transient_group = TransientGroup.load(config.transient_data_fn)
srg = SamplerResultGroup.load(config.sampler_results_fn)

trainer = SuperphotTrainer(config)
trainer.setup_model()
meta_df = trainer.retrieve_transient_metadata(transient_group)
train_df, val_df = trainer.split(meta_df, split_frac=0.1)
train_srg = srg.filter(train_df.index)
val_srg = srg.filter(val_df.index)

trainer.train(0, (train_df, train_srg), (val_df, val_srg))
trainer.models[0].save(config.model_prefix + "_full")
print(trainer.models[0].best_model.feature_name_)

# trainer.evaluate calls mlp.evaluate, which in turn calls mlp.get_predictions
probs_avg = trainer.evaluate(0, (meta_df, srg))
probs_avg.to_csv(config.probs_fn[:-4] + "_full.csv")

# # train early-type
# metadata = pd.read_csv(os.path.join(config.data_dir, "all_samples.csv"), index_col = 0)
# keep_cols = metadata.drop(
#     columns=['A_ZTF_r_median', 't_0_ZTF_r_median', 'score_median', 'sampler']
# ).columns
# trainer.config.input_features = [c.replace("_median", "") for c in keep_cols if (
#         ('tau_fall' not in c) and ('gamma' not in c)
#     )
# ]

# trainer.train(1, (train_df, train_srg), (val_df, val_srg))
# trainer.models[1].save(config.model_prefix + "_early")
# print(trainer.models[1].best_model.feature_name_)

# probs_avg = trainer.evaluate(1, (meta_df, srg))
# probs_avg.to_csv(config.probs_fn[:-4] + "_early.csv")
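
The early-phase feature selection in the commented-out block above amounts to a simple name filter. The standalone sketch below applies the same filter to a made-up list of column names; the names in `columns` are illustrative and are not read from all_samples.csv.

```python
# Illustrative column names only; the real ones come from all_samples.csv.
columns = [
    "beta_ZTF_r_median", "gamma_ZTF_r_median", "tau_rise_ZTF_r_median",
    "tau_fall_ZTF_r_median", "gamma_ZTF_g_median", "tau_fall_ZTF_g_median",
]

# Drop plateau durations (gamma) and fall timescales (tau_fall),
# then strip the "_median" suffix as in the training cells.
early_features = [
    c.replace("_median", "")
    for c in columns
    if ("tau_fall" not in c) and ("gamma" not in c)
]
print(early_features)
```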



Below we plot the confusion matrix for the full model, and for the early model if you trained it above. These plots can be found in the same folder where you saved all the data for this training run, under **figs**.

In [None]:
from superphot_plus.plotting.confusion_matrices import *
from superphot_plus import SuperphotConfig
import pandas as pd

config = SuperphotConfig.from_file("config.yaml")

def df_from_csv(filepath):
    """
    Loads a CSV file into a pandas DataFrame.

    Parameters:
        filepath (str): Path to the CSV file.

    Returns:
        pd.DataFrame: The resulting DataFrame.
    """
    try:
        df = pd.read_csv(filepath)
        return df
    except FileNotFoundError:
        print(f"Error: File not found at '{filepath}'")
    except pd.errors.EmptyDataError:
        print(f"Error: File at '{filepath}' is empty or improperly formatted")
    except Exception as e:
        print(f"Unexpected error: {e}")
    
#early_df = df_from_csv("/Users/charisgraham/Desktop/Summer2025/superphotplus/superphot-plus/data/whxe/probabilities/probs_superphot_svi_MLP_None_False_10_None_5_1_42_64_3_0.001_64_early.csv")
full_df = df_from_csv(f"{config.probs_fn[:-4]}_full.csv")


In [None]:
plot_matrices(config, full_df)

In [None]:
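# Requires early_df, which is only defined if you trained the early model and loaded its probabilities above.
# plot_matrices(config, early_df)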
