In [None]:
%matplotlib inline

from typing import TYPE_CHECKING

import suPAErnova as snpae
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import tensorflow as tf
from tensorflow import keras as ks

# Packages needed only for type checking
if TYPE_CHECKING:
    from suPAErnova.uttils.suPAErnova_types import CONFIG

# Paths used throughout this notebook: the current working directory, the examples
# root (four directory levels above it), and the data directory inside that root.
cwd = Path.cwd()
examples_dir = cwd.parent.parent.parent.parent
data_dir = examples_dir.joinpath("suPAErnova_data")  # Put your data into this directory

# Global Configuration

In [None]:
# Global run switches
verbose = False  # Increase log verbosity
force = False  # Rerun all steps every time

# An empty dictionary is all that is needed to initialise the global configuration
cfg = snpae.setup_global_config(dict(), verbose=verbose, force=force)

# Data Configuration and Setup

In [None]:
# Note that these keys *MUST* be capitalised. This is not the case when using a `suPAErnova.toml` config file.
data_cfg = {
    # === Required Keys ===

    # Path to directory containing data.
    #   Can be absolute or relative to the base path.
    "DATA_DIR": str(data_dir),  # Needs to be a string so that SuPAErnova can validate it

    # Metadata CSV containing SN names and SALT fit parameters.
    #   Can be absolute or relative to the data path.
    "META": "meta.csv",

    # TXT file containing additional SALT fit parameters.
    #   Can be absolute or relative to the data path.
    "IDR": "IDR_eTmax.txt",

    # TXT file containing a mask of bad spectra / wavelength ranges.
    #   Can be absolute or relative to the data path.
    "MASK": "mask_info_wmin_wmax.txt",
}

cfg["DATA"] = data_cfg
data = snpae.steps.DATAStep(cfg)

# Each stage returns a (success, result) pair; on failure, `result` is what gets logged as the error.
# Stop at the first failure. Each stage builds on the previous one. A failed result() would also
# rebind `cfg` below to that error value, breaking every later cell with a confusing message.
success, result = data.setup()
if not success:  # Make sure you handle failures appropriately!
    data.log.error(f"Error running setup: {result}")
    raise RuntimeError(f"DATA setup failed: {result}")
success, result = data.run()
if not success:  # Make sure you handle failures appropriately!
    data.log.error(f"Error running: {result}")
    raise RuntimeError(f"DATA run failed: {result}")
success, result = data.result()
if not success:  # Make sure you handle failures appropriately!
    data.log.error(f"Error saving results: {result}")
    raise RuntimeError(f"DATA result failed: {result}")

# The result() function of a step returns the original cfg passed in, but now with the step stored in cfg["GLOBAL"]["RESULTS"][step.name]
# This allows later steps to access the results of previous steps.
cfg = result
print(cfg["GLOBAL"]["RESULTS"][data.name])

# TF PAE Setup
This example just uses the default parameters.

In [None]:
# No TF_PAE options are set here, so the step uses its default parameters.
cfg["TF_PAE"] = tf_pae_cfg = {}

# Callbacks
The easiest way to interact with the `SuPAErnova` pipeline, without having to delve into the source code, is through callback functions. These are user-defined functions which run before and after different stages within each step. When using a `suPAErnova.toml` file, these callback functions are defined in scripts, with the paths to these scripts provided in the config file. Here you can define your functions directly and pass them in.

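Be careful with some of these callback functions: many of them run within `@tf.function`-decorated functions, which has some annoying side effects you'll need to pay attention to. For example, a plain Python `print` inside such a function only runs while the function is being traced, so use `tf.print` if you want output every time it runs.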

In [None]:
# Lifecycle callbacks: every stage gets a PRE and a POST hook, which here simply announce themselves.
# Each callback takes a single `self` argument.


def pre_analyse(self) -> None:
    """Run before the analyse stage."""
    print("pre-analyse callback")


def post_analyse(self) -> None:
    """Run after the analyse stage."""
    print("post-analyse callback")


def pre_result(self) -> None:
    """Run before the result stage."""
    print("pre-result callback")


def post_result(self) -> None:
    """Run after the result stage."""
    print("post-result callback")


def pre_run(self) -> None:
    """Run before the run stage."""
    print("pre-run callback")


def post_run(self) -> None:
    """Run after the run stage."""
    print("post-run callback")


def pre_setup(self) -> None:
    """Run before the setup stage."""
    print("pre-setup callback")


def post_setup(self) -> None:
    """Run after the setup stage."""
    print("post-setup callback")


def pre_validate(self) -> None:
    """Run before the validate stage."""
    print("pre-validate callback")


def post_validate(self) -> None:
    """Run after the validate stage."""
    print("post-validate callback")


def pre_train_all(self) -> None:
    """Run before the train_all stage."""
    print("pre-train_all callback")


def post_train_all(self) -> None:
    """Run after the train_all stage."""
    print("post-train_all callback")


def pre_train_amplitude(self) -> None:
    """Run before the train_amplitude stage."""
    print("pre-train_amplitude callback")


def post_train_amplitude(self) -> None:
    """Run after the train_amplitude stage."""
    print("post-train_amplitude callback")


def pre_train_colour(self) -> None:
    """Run before the train_colour stage."""
    print("pre-train_colour callback")


def post_train_colour(self) -> None:
    """Run after the train_colour stage."""
    print("post-train_colour callback")


def pre_train_latents(self) -> None:
    """Run before the train_latents stage."""
    print("pre-train_latents callback")


def post_train_latents(self) -> None:
    """Run after the train_latents stage."""
    print("post-train_latents callback")


def pre_train_model(self) -> None:
    """Run before the train_model stage."""
    print("pre-train_model callback")


def post_train_model(self) -> None:
    """Run after the train_model stage."""
    print("post-train_model callback")


def pre_train_time(self) -> None:
    """Run before the train_time stage."""
    print("pre-train_time callback")


def post_train_time(self) -> None:
    """Run after the train_time stage."""
    print("post-train_time callback")


# Register every hook under its (upper-case) stage name.
callbacks = {
    "ANALYSE": {"PRE": pre_analyse, "POST": post_analyse},
    "RESULT": {"PRE": pre_result, "POST": post_result},
    "RUN": {"PRE": pre_run, "POST": post_run},
    "SETUP": {"PRE": pre_setup, "POST": post_setup},
    "VALIDATE": {"PRE": pre_validate, "POST": post_validate},
    "TRAIN_ALL": {"PRE": pre_train_all, "POST": post_train_all},
    "TRAIN_AMPLITUDE": {"PRE": pre_train_amplitude, "POST": post_train_amplitude},
    "TRAIN_COLOUR": {"PRE": pre_train_colour, "POST": post_train_colour},
    "TRAIN_LATENTS": {"PRE": pre_train_latents, "POST": post_train_latents},
    "TRAIN_MODEL": {"PRE": pre_train_model, "POST": post_train_model},
    "TRAIN_TIME": {"PRE": pre_train_time, "POST": post_train_time},
}

## Custom activation function

In [None]:
# setup_activation
def pre_setup_activation(self) -> None:
    """Before the activation is set up: show the raw `ACTIVATION` option the user supplied."""
    print("pre-setup_activation callback")
    print(self.opts["ACTIVATION"])  # The string the user has passed


def post_setup_activation(self) -> None:
    """After the activation is set up: wrap the pre-built activation in a custom, serialisable one.

    At this point `self.opts["ACTIVATION"]` holds the pre-built activation function. You can
    replace it with your own function, as long as it has the signature
        f(x: _ActivationInput) -> tf.Tensor
    """
    print("post-setup_activation callback")
    print(self.opts["ACTIVATION"])  # The corresponding pre-built activation function

    builtin_activation = self.opts["ACTIVATION"]

    # Registering the function as serialisable lets it be saved and loaded alongside the TF_PAEModel
    @ks.utils.register_keras_serializable(name="custom_activation")
    def custom_activation(x: "ks.activations._ActivationInput") -> "tf.Tensor":
        # Warning: printing here fires *a lot*, so the demonstration print below is left disabled.
        # TODO: Find where this is being used outside of a `@tf.function` wrapper, and fix that
        # print(f"my custom activation function received {x} as input")
        return builtin_activation(x)

    self.opts["ACTIVATION"] = custom_activation
    print(self.opts["ACTIVATION"])  # Your custom activation function


callbacks["SETUP_ACTIVATION"] = dict(PRE=pre_setup_activation, POST=post_setup_activation)

## Custom loss function

In [None]:
# setup_loss
def pre_setup_loss(self) -> None:
    """Before the loss is set up: show the raw `LOSS` option the user supplied."""
    print("pre-setup_loss callback")
    print(self.opts["LOSS"])  # The string the user has passed


def post_setup_loss(self) -> None:
    """After the loss is set up: wrap the pre-built loss function in a custom one.

    At this point `self.opts["LOSS"]` holds the pre-built loss function. You can replace it
    with your own function, as long as it has the signature
        f(x: tf.Tensor, x_pred: tf.Tensor, kwargs: CONFIG[tf.Tensor]) -> tf.Tensor
    where
      - `x` is the amplitude (flux) of the real SN spectra,
      - `x_pred` is the amplitude (flux) of the corresponding encoded-decoded SN spectra,
      - `kwargs` is a dictionary containing:
          "sigma": the uncertainty in the amplitude (flux) of the real SN spectra,
          "mask": the mask of spectral data which should be ignored,
          "model": the `TF_PAEModel` being trained, which itself contains lots of other
                   information you can use.

    The returned tensor must have shape=(), i.e. be a single number as a Tensor. The various
    reduce_* functions can usually achieve this.
    """
    print("post-setup_loss callback")
    print(self.opts["LOSS"])  # The corresponding pre-built loss function

    builtin_loss = self.opts["LOSS"]

    def custom_loss(x: "tf.Tensor", x_pred: "tf.Tensor", kwargs: "CONFIG[tf.Tensor]") -> "tf.Tensor":
        # This runs within a `@tf.function` wrapper, so this print only fires when the function is traced.
        # Use tf.print to print every time this is run.
        print(f"my custom loss function received a real spectra: {x}, a predicted spectra: {x_pred}, and kwargs: {kwargs}")
        return builtin_loss(x, x_pred, kwargs)

    self.opts["LOSS"] = custom_loss
    print(self.opts["LOSS"])  # Your custom loss function


callbacks["SETUP_LOSS"] = dict(PRE=pre_setup_loss, POST=post_setup_loss)

## Custom optimiser

In [None]:
# setup_optimiser
def pre_setup_optimiser(self) -> None:
    """Before the optimiser is set up: show the raw `OPTIMISER` option the user supplied."""
    print("pre-setup_optimiser callback")
    print(self.opts["OPTIMISER"])  # The string the user has passed


def post_setup_optimiser(self) -> None:
    """After the optimiser is set up: wrap the pre-built optimiser function in a custom one.

    At this point `self.optimiser` holds the pre-built optimiser function. You can replace it
    with your own function, as long as it has the signature
        f(lr: ks.optimizers.schedules.LearningRateSchedule | float, kwargs: CFG) -> ks.optimizers.Optimizer
    where
      - `lr` is the incoming learning rate, either as a float or a LearningRateSchedule,
      - `kwargs` is a dictionary containing:
          "lr_decay_steps": the user-defined `lr_decay_steps` parameter,
          "lr_decay_rate": the user-defined `lr_decay_rate` parameter,
          "weight_decay_rate": the user-defined `weight_decay_rate` parameter.
    """
    print("post-setup_optimiser callback")
    print(self.optimiser)  # The corresponding pre-built optimiser function

    builtin_optimiser = self.optimiser

    def custom_optimiser(lr: "ks.optimizers.schedules.LearningRateSchedule | float", kwargs: "CFG") -> "ks.optimizers.Optimizer":
        # NOTE(review): if this is called within a `@tf.function` wrapper, this print only fires
        # when the function is traced. Use tf.print to print every time this is run.
        print(f"my custom optimiser received a learning rate: {lr} and kwargs: {kwargs}")
        return builtin_optimiser(lr, kwargs)

    self.optimiser = custom_optimiser
    print(self.optimiser)  # Your custom optimiser function


callbacks["SETUP_OPTIMISER"] = dict(PRE=pre_setup_optimiser, POST=post_setup_optimiser)

## Custom scheduler

In [None]:
# setup_scheduler
def pre_setup_scheduler(self) -> None:
    """Before the scheduler is set up: show the raw `SCHEDULER` option the user supplied."""
    print("pre-setup_scheduler callback")
    print(self.opts["SCHEDULER"])  # The string the user has passed


def post_setup_scheduler(self) -> None:
    """After the scheduler is set up: wrap the pre-built scheduler function in a custom one.

    At this point `self.scheduler` holds the pre-built scheduler function. You can replace it
    with your own function, as long as it has the signature
        f(lr: float, kwargs: CFG) -> ks.optimizers.schedules.LearningRateSchedule | float
    where
      - `lr` is the incoming learning rate, as a float,
      - `kwargs` is a dictionary containing:
          "lr_decay_steps": the user-defined `lr_decay_steps` parameter,
          "lr_decay_rate": the user-defined `lr_decay_rate` parameter,
          "weight_decay_rate": the user-defined `weight_decay_rate` parameter.
    """
    print("post-setup_scheduler callback")
    print(self.scheduler)  # The corresponding pre-built scheduler function

    builtin_scheduler = self.scheduler

    def custom_scheduler(lr: "float", kwargs: "CFG") -> "ks.optimizers.schedules.LearningRateSchedule | float":
        # NOTE(review): if this is called within a `@tf.function` wrapper, this print only fires
        # when the function is traced. Use tf.print to print every time this is run.
        print(f"my custom scheduler received a learning rate: {lr} and kwargs: {kwargs}")
        return builtin_scheduler(lr, kwargs)

    self.scheduler = custom_scheduler
    print(self.scheduler)  # Your custom scheduler function


callbacks["SETUP_SCHEDULER"] = dict(PRE=pre_setup_scheduler, POST=post_setup_scheduler)

In [None]:
# Attach every callback defined above to the TF_PAE section, then build the step from the full config.
cfg["TF_PAE"].update(CALLBACKS=callbacks)
tf_pae = snpae.steps.TF_PAEStep(cfg)
print(tf_pae)

In [None]:
# Run each TF_PAE stage in order. Each returns a (success, result) pair, and on failure `result`
# is what gets logged as the error. Stop at the first failure: running a later stage after an
# earlier one failed would only produce confusing follow-on errors.
success, result = tf_pae.setup()
if not success:  # Make sure you handle failures appropriately!
    tf_pae.log.error(f"Error running setup: {result}")
    raise RuntimeError(f"TF_PAE setup failed: {result}")
success, result = tf_pae.run()
if not success:  # Make sure you handle failures appropriately!
    tf_pae.log.error(f"Error running: {result}")
    raise RuntimeError(f"TF_PAE run failed: {result}")
success, result = tf_pae.result()
if not success:  # Make sure you handle failures appropriately!
    tf_pae.log.error(f"Error saving results: {result}")
    raise RuntimeError(f"TF_PAE result failed: {result}")
success, result = tf_pae.analyse()
if not success:  # Make sure you handle failures appropriately!
    tf_pae.log.error(f"Error analysing: {result}")
    raise RuntimeError(f"TF_PAE analyse failed: {result}")