# Comparing KN surrogate models

**Abstract:** Here we compare the MSE of the surrogate models between, e.g. Tensorflow and Jax/flax models.

In [None]:
%load_ext autoreload 
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from nmma.em.training import SVDTrainingModel
from nmma.em.model import SVDLightCurveModel
import nmma as nmma
import time
import arviz

params = {"axes.grid": True,
        "text.usetex" : True,
        "font.family" : "serif",
        "ytick.color" : "black",
        "xtick.color" : "black",
        "axes.labelcolor" : "black",
        "axes.edgecolor" : "black",
        "font.serif" : ["Computer Modern Serif"],
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "axes.labelsize": 16,
        "legend.fontsize": 16,
        "legend.title_fontsize": 16,
        "figure.titlesize": 16}

plt.rcParams.update(params)

from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans

import inspect 
import nmma.em.model_parameters as model_parameters

MODEL_FUNCTIONS = {
    k: v for k, v in model_parameters.__dict__.items() if inspect.isfunction(v)
}

model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]

In [None]:
import jax
import jaxlib
jax.devices() # check if CUDA is present

## Preprocessing data

In [None]:
lcs_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022" # for remote SSH Potsdam
flax_svd_path = "/home/urash/twouters/nmma_models/flax_models/" # initial flax models will be saved here
svd_path = "/home/urash/twouters/nmma_models/"
filenames = os.listdir(lcs_dir)
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]
print(f"There are {len(full_filenames)} lightcurves for this model.")

In [None]:
# Read the data and interpolate the NaNs, and get training data
data = read_photometry_files(full_filenames)
data = interpolate_nans(data)
training_data, parameters = model_function(data)

In [None]:
#Extract times and filters
key = list(training_data.keys())[0]
example = training_data[key]
t = example["t"]
keys = list(example.keys())
filts = [k for k in keys if k not in parameters + ["t"]]

## Get the models


In [None]:
flax_model = SVDLightCurveModel(
        model_name,
        t,
        svd_path=flax_svd_path,
        parameter_conversion=None,
        mag_ncoeff=10,
        lbol_ncoeff=None,
        interpolation_type="flax",
        model_parameters=None,
        filters=filts,
        local_only=True
)
print(flax_model.svd_path)

In [None]:
# training_model.__dict__.keys()

In [None]:
tf_model = SVDLightCurveModel(
        model_name,
        t,
        svd_path=svd_path,
        parameter_conversion=None,
        mag_ncoeff=10,
        lbol_ncoeff=None,
        interpolation_type="tensorflow",
        model_parameters=None,
        filters=filts,
        local_only=True
)
print(tf_model.svd_path)

## Get the input and output pairs of the Bu2022Ye model

In [None]:
def get_input_values(training_data: dict, parameters: list) -> np.ndarray:
    """
    From a dictionary of training data, extract the input values for the model.
    """
    
    input_values = []
    
    for key in training_data.keys():
        data = training_data[key]
        input_values.append([data[param] for param in parameters])
    
    return np.array(input_values)

In [None]:
def get_output_values(training_data: dict) -> np.ndarray:
    """
    From a dictionary of training data, extract the output values for the model.
    """
    keys = list(training_data.keys())
    if "data" in training_data[keys[0]].keys():
        # This is the version that was in use for NMMA-GPU
        output_values = [training_data[key]["data"] for key in training_data.keys()]
    else:
        # This is the version for NMMA CPU, 12/12/2023
        output_values = []
        for key in training_data.keys():
            data = training_data[key]
            new_list = [data[f] for f in filts]
            output_values.append(new_list)
    
    return np.array(output_values)

In [None]:
input_values = get_input_values(training_data, parameters)
output_values = get_output_values(training_data)

In [None]:
print(np.shape(input_values))
print(np.shape(output_values))

NOTE: this can be done faster with jax.vmap etc, but not for 

In [40]:
N = 100
# Select a random subset of indices for the input values
idx_list = np.random.choice(len(input_values), N, replace=False)
sampled_input_values = input_values[idx_list]
sampled_output_values = output_values[idx_list]

**NOTE** this is without optimization from jax!

In [41]:
output_dict = {"flax": [], 
               "tf": []}

models_dict = {"flax": flax_model, 
               "tf": tf_model}

for key in models_dict.keys():
    print(f"Getting output for model {key}")
    model = models_dict[key]
    start = time.time()

    for i in idx_list:
        # Compute the lightcurve
        _, _, mag = nmma.em.utils.calc_lc(t,
                                    input_values[i], 
                                    svd_mag_model = model.svd_mag_model, 
                                    interpolation_type=model.interpolation_type,
                                    filters = filts, 
                                    mag_ncoeff = 10
                                    )
        # Convert this dictionary to values of the LCs
        mag = mag.values()
        mag = np.array(list(mag))#.T
        # Save to the correct output
        output_dict[key].append(mag)
    end = time.time()
    print(f"Computing all the flax lightcurves for a subset of {N} lightcurves took {end-start} seconds.")
    # Make sure this is a np.ndarray

# Convert to np.ndarray
output_dict["flax"] = np.array(output_dict["flax"])
output_dict["tf"] = np.array(output_dict["tf"])

Getting output for model flax
Computing all the flax lightcurves for a subset of 100 lightcurves took 32.9498770236969 seconds.
Getting output for model tf
Computing all the flax lightcurves for a subset of 100 lightcurves took 4.315046072006226 seconds.


## Compare MSE or MAE values

TODO: Best to compare this as a distribution, and perhaps best to consider MAE, or some self-defined loss function or error function?

In [42]:
def mse(y_true, y_pred, axis=None):
    return np.mean((y_true - y_pred)**2, axis=axis)

def se(y_true, y_pred):
    return (y_true - y_pred)**2

def mae(y_true, y_pred, axis=None):
    return np.mean(np.abs(y_true - y_pred), axis=axis)

def ae(y_true, y_pred):
    return np.abs(y_true - y_pred)

def my_format(low: float, med: float, high: float, nb: int = 3) -> str:
    med = np.round(med, nb)
    low = med - low
    low = np.round(low, nb)
    high = high - med
    high = np.round(high, nb)
    
    return f"{med} - {low} + {high}"

In [56]:
mse_dict = {"flax mse": [], 
            "tf mse": [],
            "flax mae": [], 
            "tf mae": []}
for key, output in output_dict.items():
    for error_fn, name_error_fn in zip([mse, mae], ["mse", "mae"]):
        diffs = se(output, sampled_output_values)
        axis = 0
        mse_values = error_fn(output, sampled_output_values, axis=axis)
        mse_values = np.mean(mse_values, axis=0)
        print(f"Error function: {name_error_fn}, model: {key}")
        for f, val in zip(filts, mse_values):
            # Add to my dictionary
            mse_dict[f"{key} {name_error_fn}"].append(val)
            # Print if desired
            ## print(f"{f}: {val}")

Error function: mse, model: flax
Error function: mae, model: flax
Error function: mse, model: tf
Error function: mae, model: tf


In [57]:
df = pd.DataFrame(mse_dict)
df.index = filts
df

Unnamed: 0,flax mse,tf mse,flax mae,tf mae
bessellux,5.388313,6.911255,1.145407,1.361794
bessellb,3.927375,4.915675,0.969374,1.178965
bessellv,2.731485,3.296068,0.779683,0.953346
bessellr,1.695274,2.189382,0.673414,0.841428
besselli,1.344738,1.690819,0.605197,0.766069
sdssu,0.859674,1.072204,0.531136,0.657003
ps1__g,0.68434,0.850139,0.461206,0.572337
ps1__r,0.514548,0.695942,0.450967,0.56584
ps1__i,0.46974,0.607467,0.408007,0.510941
ps1__z,0.362303,0.50263,0.38709,0.499446


## Speed comparison: improve speed generation of flax?

In [44]:
import jax
import jax.numpy as jnp
import scipy.interpolate as interp

### 1. Jit, no vmap

In [52]:
# Lambda function, so that we focus on the parameters as being the input only
from nmma.em.utils import get_calc_lc_jit
calc_lc_given_params_jit = get_calc_lc_jit(t, svd_mag_model=flax_model.svd_mag_model, filters=filts)
# Compilation
_ = calc_lc_given_params_jit(sampled_input_values[0])
start = time.time()
for i in range(N):
    _ = calc_lc_given_params_jit(sampled_input_values[i])
end = time.time()
print(f"Computing all the flax lightcurves for a subset of {N} lightcurves took {end-start} seconds.")

Computing all the flax lightcurves for a subset of 100 lightcurves took 0.032933712005615234 seconds.


### 2. Jit and vmap

In [53]:
# vmap the function
calc_lc_given_params_vmap = jax.vmap(calc_lc_given_params_jit)
# apply to input_values
# Compilation
_ = calc_lc_given_params_vmap(sampled_input_values)
# Execution time
start = time.time()
_ = calc_lc_given_params_vmap(sampled_input_values)
end = time.time()
print(f"Computing all the flax lightcurves for a subset of {N} lightcurves took {end-start} seconds.")

Computing all the flax lightcurves for a subset of 100 lightcurves took 0.0031833648681640625 seconds.
