# Kilonova surrogate modelling with $\texttt{jax}$ and $\texttt{flax}$

**Abstract:** Building and training KN surrogate models with jax and flax. 

See the new tutorial notebook in the main NMMA repository to get a sense of these individual steps.

In [38]:
%load_ext autoreload 
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from nmma.em.training import SVDTrainingModel
import nmma as nmma

params = {"axes.grid": True,
        "text.usetex" : True,
        "font.family" : "serif",
        "ytick.color" : "black",
        "xtick.color" : "black",
        "axes.labelcolor" : "black",
        "axes.edgecolor" : "black",
        "font.serif" : ["Computer Modern Serif"],
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "axes.labelsize": 16,
        "legend.fontsize": 16,
        "legend.title_fontsize": 16,
        "figure.titlesize": 16}

plt.rcParams.update(params)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
import jax
import jaxlib
jax.devices() # check if CUDA is present

[cuda(id=0)]

## Preprocessing data

In [40]:
lcs_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022" # for remote SSH Potsdam
out_dir = "/home/urash/twouters/nmma_models/flax_models/" # initial flax models will be saved here
filenames = os.listdir(lcs_dir)
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]
print(f"There are {len(full_filenames)} lightcurves for this model.")

There are 7700 lightcurves for this model.


In [41]:
from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans

data = read_photometry_files(full_filenames)
data = interpolate_nans(data)

In [None]:
import inspect 
import nmma.em.model_parameters as model_parameters

MODEL_FUNCTIONS = {
    k: v for k, v in model_parameters.__dict__.items() if inspect.isfunction(v)
}

model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]
training_data, parameters = model_function(data)

In [None]:
key = list(training_data.keys())[0]
example = training_data[key]
t = example["t"]
keys = list(example.keys())
filts = [k for k in keys if k not in parameters + ["t"]]
print(filts)

['bessellux', 'bessellb', 'bessellv', 'bessellr', 'besselli', 'sdssu', 'ps1__g', 'ps1__r', 'ps1__i', 'ps1__z', 'ps1__y', 'uvot__b', 'uvot__u', 'uvot__uvm2', 'uvot__uvw1', 'uvot__uvw2', 'uvot__v', 'uvot__white', 'atlasc', 'atlaso', '2massj', '2massh', '2massks', 'ztfg', 'ztfr', 'ztfi']


## Flax model


In [None]:
training_model = SVDTrainingModel(
        model_name,
        training_data,
        parameters,
        t,
        filts,
        interpolation_type="flax",
        svd_path=out_dir # initial flax models will be saved here
    )

print(training_model.svd_path)

The grid will be interpolated to sample_time with interp1d
Model exists... will load that model.
getting model
getting model: OK


KeyError: 'gmodelps'

In [None]:
training_model.__dict__.keys()

dict_keys(['model', 'data', 'model_parameters', 'sample_times', 'filters', 'n_coeff', 'n_epochs', 'interpolation_type', 'data_type', 'data_time_unit', 'plot', 'plotdir', 'ncpus', 'univariate_spline', 'univariate_spline_s', 'random_seed', 'svd_path'])

How to generate the lightcurve

In [None]:
example_parameters = [-2.30103, 0.12, 0.3, -1.30103, 0.03, 25.84]

In [None]:
test = nmma.em.utils.calc_lc(t,
                            example_parameters, 
                            svd_mag_model = training_model, 
                            interpolation_type="flax", 
                            filters = filts, 
                            )

TypeError: 'SVDTrainingModel' object is not subscriptable