# Training KN flax models

In [1]:
%load_ext autoreload 
%autoreload 2

import inspect 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import time
import arviz
from sklearn.model_selection import train_test_split

# NMMA imports
from nmma.em.training import SVDTrainingModel
import nmma as nmma
from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans
import nmma.em.model_parameters as model_parameters

### jax and friends
import jax
import jax.numpy as jnp
from flax import linen as nn  # Linen API
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax

# List the devices JAX can see; used here to check that CUDA is picked up.
print(jax.devices())

# Global matplotlib style: grid on, LaTeX-rendered serif text, black axes and
# ticks. The base entries come first, then every size-related key is appended
# with the same value so the resulting dict matches the hand-written one.
params = {
    "axes.grid": True,
    "text.usetex": True,
    "font.family": "serif",
    "ytick.color": "black",
    "xtick.color": "black",
    "axes.labelcolor": "black",
    "axes.edgecolor": "black",
    "font.serif": ["Computer Modern Serif"],
}
params.update(
    dict.fromkeys(
        (
            "xtick.labelsize",
            "ytick.labelsize",
            "axes.labelsize",
            "legend.fontsize",
            "legend.title_fontsize",
            "figure.titlesize",
        ),
        16,
    )
)
plt.rcParams.update(params)

# Map every plain function found in the model_parameters module namespace to
# its attribute name, so the preprocessing routine that reads the data can be
# looked up by model name.
MODEL_FUNCTIONS = dict(
    (attr_name, attr_value)
    for attr_name, attr_value in vars(model_parameters).items()
    if inspect.isfunction(attr_value)
)

# Model to train, and the matching preprocessing function.
model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from .autonotebook import tqdm as notebook_tqdm


Install afterglowpy if you want to simulate afterglows.
Install wrapt_timeout_decorator if you want timeout simulations.


## Data preprocessing

In [None]:
# Directory holding the light-curve files. The default is the original remote
# (SSH Potsdam) location; set the KN_LCS_DIR environment variable to point the
# notebook at another copy of the data without editing this cell.
lcs_dir = os.environ.get(
    "KN_LCS_DIR",
    "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022",
)

# os.listdir returns entries in arbitrary order; sorting makes the file order
# (and therefore the example light curve selected below) reproducible.
filenames = sorted(os.listdir(lcs_dir))
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]

print("Cleaning data...")
data = read_photometry_files(full_filenames)
data = interpolate_nans(data)

print("Getting training data...")
training_data, parameters = model_function(data)

# Fail with a clear message instead of a bare IndexError when nothing was read.
if not training_data:
    raise ValueError(f"No training data could be built from the files in {lcs_dir!r}")

# Use the first light curve as a template for the time grid and filter names.
# NOTE(review): this assumes all light curves share the same "t" and the same
# filter keys — confirm against the data set.
key = list(training_data.keys())[0]
example = training_data[key]
t = example["t"]
keys = list(example.keys())

# Every key that is neither a model parameter nor the time array is a filter.
non_filter_keys = parameters + ["t"]
filts = [k for k in keys if k not in non_filter_keys]
print("Filters:")
print(filts)

print("Getting the SVD model, start_training=False")
svd_ncoeff = 10
training_model = SVDTrainingModel(
        model_name,
        training_data,
        parameters,
        t,
        filts,
        n_coeff=svd_ncoeff,
        interpolation_type="flax",
        start_training=False # don't train, just prep the data, we train later on
    )

In [None]:
# Build the SVD model explicitly (the training model above was created with
# start_training=False) and attach it to the training model, keeping a
# notebook-level handle to the same object.
print("Fitting SVD etc")
svd_model = training_model.svd_model = training_model.generate_svd_model()
print("Fitting SVD etc DONE")

## Training the NN on SVD-decomposed data

In [None]:

# NOTE(review): this cell is entirely commented out and executes nothing.
# The disabled lines sketch calling training_model.train_model(), then taking
# 'param_array_postprocess' as features and the transposed 'cAmat' as labels
# for the first filter, and splitting them with train_test_split.
# Confirm whether this is still needed or can be removed.

# training_model.train_model()

# X = training_model.svd_model[filts[0]]['param_array_postprocess']
# print(f"Features have shape {X.shape}")

# y = training_model.svd_model[filts[0]]['cAmat'].T
# _, output_ndim = y.shape
# print(f"Labels have shape {y.shape}")

# train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=0)

## Training the NN on pure lightcurve data