# Comparing KN surrogate models

**Abstract:** Here we compare the MSE of the surrogate models between, e.g. the existing Tensorflow model on Zenodo vs a new, deep network as surrogate. 

In [1]:
%load_ext autoreload 
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from nmma.em.training import SVDTrainingModel
import nmma as nmma
import time
import arviz

params = {"axes.grid": True,
        "text.usetex" : True,
        "font.family" : "serif",
        "ytick.color" : "black",
        "xtick.color" : "black",
        "axes.labelcolor" : "black",
        "axes.edgecolor" : "black",
        "font.serif" : ["Computer Modern Serif"],
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "axes.labelsize": 16,
        "legend.fontsize": 16,
        "legend.title_fontsize": 16,
        "figure.titlesize": 16}

plt.rcParams.update(params)

from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans

import inspect 
import nmma.em.model_parameters as model_parameters

MODEL_FUNCTIONS = {
    k: v for k, v in model_parameters.__dict__.items() if inspect.isfunction(v)
}

model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Install wrapt_timeout_decorator if you want timeout simulations.


## Preprocessing data

In [2]:
# Choose model and set location of the kilonova lightcurves
bulla_2022_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022"
bulla_2019_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2019"

# Choose the model here
# model_name = "Bu2022Ye"
model_name = "Bu2019lm"
model_function = MODEL_FUNCTIONS[model_name]

# Set the location of the lightcurves and outdir based on chosen model
if model_name == "Bu2022Ye":
    lcs_dir = bulla_2022_dir
    
elif model_name == "Bu2019lm":
    lcs_dir = bulla_2019_dir

svd_path = f"/home/urash/twouters/new_nmma_models/"
old_svd_path = f"/home/urash/twouters/nmma_models/"

# Process the KN lightcurves
filenames = os.listdir(lcs_dir)
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]
print(f"There are {len(full_filenames)} lightcurves for this model.")

print("Reading lightcurves and interpolating NaNs...")
data = read_photometry_files(full_filenames)
data = interpolate_nans(data)
keys = list(data.keys())
filts = sorted(list(set(data[keys[0]].keys()) - {"t"}))

print("Reading lightcurves and interpolating NaNs... DONE")

# Limit to the filters of interest for the KN event that Peter is interested in:
if model_name == "Bu2022Ye":
    filts = ["ztfg", "ztfi", "ztfr"] # limited for now for Peter's KN event
else:
    zenodo_filts  = ["ztfg", "ztfi", "ztfr"]
    filts = ['sdss__g', 'sdss__i', 'sdss__r'] # NOTE we ignore , 'sdss__u', 'sdss__z' for the comparison with the zenodo data
    
print("Filters:")
print(filts)

# Get the time array
dat = pd.read_csv(full_filenames[0], delimiter=" ", escapechar='#')
dat = dat.rename(columns={" t[days]": "t"})
t = dat["t"].values

print("Genrating training data...")
training_data, parameters = model_function(data)
print("Genrating training data... DONE")

There are 1596 lightcurves for this model.
Reading lightcurves and interpolating NaNs...
Reading lightcurves and interpolating NaNs... DONE
Filters:
['sdss__g', 'sdss__i', 'sdss__r']
Genrating training data...
Genrating training data... DONE


## Get the models

In [4]:
model_name

'Bu2019lm'

In [7]:
from nmma.em.model import SVDLightCurveModel

print("Loading new LC model...")
new_lc_model = SVDLightCurveModel(
        model_name,
        t,
        svd_path=svd_path,
        parameter_conversion=None,
        mag_ncoeff=10,
        lbol_ncoeff=None,
        interpolation_type="tensorflow",
        model_parameters=None,
        filters=filts,
        local_only=True
)
print("Loading new LC model... DONE")

if model_name == "Bu2022Ye":
        old_model_name = model_name
        old_filts = filts # use the same filters
else:
        old_model_name = model_name 
        old_filts = zenodo_filts # load the zenodo filters instead

print("Loading old LC model...")
old_lc_model = SVDLightCurveModel(
        model_name,
        t,
        svd_path=old_svd_path,
        parameter_conversion=None,
        mag_ncoeff=10,
        lbol_ncoeff=None,
        interpolation_type="tensorflow",
        model_parameters=None,
        filters=old_filts,
        local_only=True
)
print("Loading old LC model... DONE")

Loading new LC model...
Loaded filter sdss__g
Loaded filter sdss__i
Loaded filter sdss__r
Loading new LC model... DONE
Loading old LC model...
Loaded filter ztfg
Loaded filter ztfi
Loaded filter ztfr
Loading old LC model... DONE


## Get the input and output pairs of lightcurve grid

In [None]:
def get_input_values(training_data: dict, parameters: list) -> np.ndarray:
    """
    From a dictionary of training data, extract the input values for the model.
    """
    
    input_values = []
    
    for key in training_data.keys():
        data = training_data[key]
        input_values.append([data[param] for param in parameters])
    
    return np.array(input_values)

In [None]:
def get_output_values(training_data: dict) -> np.ndarray:
    """
    From a dictionary of training data, extract the output values for the model.
    """
    if "data" in training_data[keys[0]].keys():
        # This is the version that was in use for NMMA-GPU
        output_values = [training_data[key]["data"] for key in training_data.keys()]
    else:
        # This is the version for NMMA CPU, 12/12/2023
        output_values = []
        for key in training_data.keys():
            data = training_data[key]
            new_list = [data[f] for f in filts]
            output_values.append(new_list)
    
    return np.array(output_values)

In [None]:
# Sanity check:
keys = list(training_data.keys())
example = training_data[keys[0]]
print(example.keys())

dict_keys(['log10_mej_dyn', 'vej_dyn', 'Yedyn', 'log10_mej_wind', 'vej_wind', 'KNtheta', 't', 'bessellux', 'bessellb', 'bessellv', 'bessellr', 'besselli', 'sdssu', 'ps1__g', 'ps1__r', 'ps1__i', 'ps1__z', 'ps1__y', 'uvot__b', 'uvot__u', 'uvot__uvm2', 'uvot__uvw1', 'uvot__uvw2', 'uvot__v', 'uvot__white', 'atlasc', 'atlaso', '2massj', '2massh', '2massks', 'ztfg', 'ztfr', 'ztfi'])


In [None]:
input_values = get_input_values(training_data, parameters)
output_values = get_output_values(training_data)

In [None]:
print(np.shape(input_values))
print(np.shape(output_values))

(7700, 6)
(7700, 3, 100)


In [None]:
N = 1000
# Select a random subset of indices for the input values
idx_list = np.random.choice(len(input_values), N, replace=False)
sampled_input_values = input_values[idx_list]
sampled_output_values = output_values[idx_list]

Get both outputs flax output

In [None]:
# For this list, we compute the LCs using the flax model
flax_output = []
start = time.time()

old_lc_model_output = []
new_lc_model_output = []

for i in range(len(sampled_input_values)):
    ### OLD model
    # Compute the lightcurve
    _, _, mag = nmma.em.utils.calc_lc(t,
                                sampled_input_values[i], 
                                svd_mag_model = old_lc_model.svd_mag_model, 
                                interpolation_type="tensorflow", 
                                filters = old_filts, 
                                mag_ncoeff = 10
                                )
    # Convert this dictionary to values of the LCs
    mag = mag.values()
    mag = np.array(list(mag))
    old_lc_model_output.append(mag)
    
    ### NEW model
    # Compute the lightcurve
    _, _, mag = nmma.em.utils.calc_lc(t,
                                sampled_input_values[i], 
                                svd_mag_model = new_lc_model.svd_mag_model, 
                                interpolation_type="tensorflow", 
                                filters = filts, 
                                mag_ncoeff = 10
                                )
    # Convert this dictionary to values of the LCs
    mag = mag.values()
    mag = np.array(list(mag))
    new_lc_model_output.append(mag)
end = time.time()
print(f"Computing all the lightcurves for a subset of {N} lightcurves took {end-start} seconds for both new and old model.")

# Convert to np arrays
old_lc_model_output = np.array(old_lc_model_output)
new_lc_model_output = np.array(new_lc_model_output)


Computing all the lightcurves for a subset of 1000 lightcurves took 10.635529041290283 seconds for both new and old model.


## Compare MSE or MAE values

TODO: Best to compare this as a distribution, and perhaps best to consider MAE, or some self-defined loss function or error function?

In [None]:
def mse(y_true, y_pred, axis=None):
    return np.mean((y_true - y_pred)**2, axis=axis)

def se(y_true, y_pred):
    return (y_true - y_pred)**2

def mae(y_true, y_pred, axis=None):
    return np.mean(np.abs(y_true - y_pred), axis=axis)

def ae(y_true, y_pred):
    return np.abs(y_true - y_pred)

def my_format(low: float, med: float, high: float, nb: int = 3) -> str:
    med = np.round(med, nb)
    low = med - low
    low = np.round(low, nb)
    high = high - med
    high = np.round(high, nb)
    
    return f"{med} - {low} + {high}"

# # TODO with arviz summarize the errors
# def summarize_data(values: np.array, percentile: float = 0.95) -> None:
    
#     med = np.median(values)
#     result = arviz.hdi(values, hdi_prob = percentile)
    
#     print(my_format(low, med, high))
    
#     return

In [None]:
# which_dataset = flax_output
# which_error = mae
nb_round = 5
for error_fn, name in zip([mse, mae], ["MSE", "MAE"]):
    for dataset, dataset_name in zip([old_lc_model_output, new_lc_model_output], ["old", "new"]):
        # diffs = se(dataset, sampled_output_values)
        print(f"Computing {name} for {dataset_name} model...")
        axis = 0
        mse_values = error_fn(dataset, sampled_output_values, axis=axis)
        mse_values = np.mean(mse_values, axis=0)
        for f, val in zip(filts, mse_values):
            print(f"{f}: {np.round(val, nb_round)}")

Computing MSE for old model...
ztfg: 2.52212
ztfi: 1.80887
ztfr: 1.08441
Computing MSE for new model...
ztfg: 2.16946
ztfi: 1.48286
ztfr: 0.82267
Computing MAE for old model...
ztfg: 0.98304
ztfi: 0.87592
ztfr: 0.72222
Computing MAE for new model...
ztfg: 0.89032
ztfi: 0.76405
ztfr: 0.60898
