# Training kilonova (KN) Flax models

In [1]:
%load_ext autoreload 
%autoreload 2

import inspect 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import time
import arviz
from sklearn.model_selection import train_test_split

# NMMA imports
from nmma.em.training import SVDTrainingModel
import nmma as nmma
from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans
import nmma.em.model_parameters as model_parameters

### jax and friends
import jax
import jax.numpy as jnp
from flax import linen as nn  # Linen API
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax

# Show the devices JAX can see; a CUDA entry means the GPU backend is usable.
print(jax.devices())

# Global matplotlib style for every figure in this notebook.
# Text is rendered through LaTeX ("text.usetex") with a Computer Modern serif font.
params = {"axes.grid": True, "text.usetex": True, "font.family": "serif"}

# Ticks, axis labels and axis edges are all drawn in black.
params.update(
    dict.fromkeys(
        ("ytick.color", "xtick.color", "axes.labelcolor", "axes.edgecolor"),
        "black",
    )
)
params["font.serif"] = ["Computer Modern Serif"]

# One common font size for tick labels, axis labels, legends and figure titles.
params.update(
    dict.fromkeys(
        (
            "xtick.labelsize",
            "ytick.labelsize",
            "axes.labelsize",
            "legend.fontsize",
            "legend.title_fontsize",
            "figure.titlesize",
        ),
        16,
    )
)
plt.rcParams.update(params)

# Map every plain function found in nmma.em.model_parameters to its name, so the
# preprocessing routine that reads the data can be looked up from the model name.
MODEL_FUNCTIONS = {
    func_name: func
    for func_name, func in vars(model_parameters).items()
    if inspect.isfunction(func)
}

# Model whose light curves are used for training; it selects the preprocessing function.
model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from .autonotebook import tqdm as notebook_tqdm


Install afterglowpy if you want to simulate afterglows.
Install wrapt_timeout_decorator if you want timeout simulations.
[cuda(id=0)]


## Data preprocessing

In [2]:
# Directory holding the light-curve files used for training.
lcs_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022" # for remote SSH Potsdam
# os.listdir returns entries in arbitrary order; sort them so the order of the training
# data (and therefore the seeded train/validation split later on) is reproducible.
filenames = sorted(os.listdir(lcs_dir))
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]

out_dir = "/home/urash/twouters/nmma_models/flax_models_new/" # the trained models will be saved here
# Make sure the output directory exists and holds no files from a previous run.
print("Checking if output directory exists and cleaning it...")
# makedirs also creates missing parent directories and is a no-op if out_dir already exists.
os.makedirs(out_dir, exist_ok=True)
for entry_name in os.listdir(out_dir):
    entry_path = os.path.join(out_dir, entry_name)
    # os.remove raises on directories, which used to abort the clean-up halfway through;
    # only files and symlinks are deleted, sub-directories are left untouched.
    if os.path.isfile(entry_path) or os.path.islink(entry_path):
        os.remove(entry_path)

# Read all light-curve files, then fill NaN entries by interpolation.
print("Cleaning data...")
data = interpolate_nans(read_photometry_files(full_filenames))

# The model-specific preprocessing function returns the training set and its parameter names.
print("Getting training data...")
training_data, parameters = model_function(data)

# Use the first entry to recover the time grid and the filter names.
# NOTE(review): assumes every entry has the same keys and time grid as the first — confirm.
key = list(training_data.keys())[0]
example = training_data[key]
t = example["t"]
keys = list(example)
# Whatever is neither the time axis nor a model parameter is treated as a filter.
filts = [name for name in keys if name != "t" and name not in parameters]
print("Filters:", filts, sep="\n")

# Only prepare the SVD training model here; fitting and training happen in later cells.
print("Getting the SVD model, start_training=False")
svd_ncoeff = 10
training_model = SVDTrainingModel(
    model_name,
    training_data,
    parameters,
    t,
    filts,
    n_coeff=svd_ncoeff,
    interpolation_type="flax",
    svd_path=out_dir,
    start_training=False,  # no training yet, only data preparation
    load_model=False,  # nothing to load, the model is trained later on
)

Checking if output directory exists and cleaning it...
Cleaning data...
Getting training data...
Filters:
['bessellux', 'bessellb', 'bessellv', 'bessellr', 'besselli', 'sdssu', 'ps1__g', 'ps1__r', 'ps1__i', 'ps1__z', 'ps1__y', 'uvot__b', 'uvot__u', 'uvot__uvm2', 'uvot__uvw1', 'uvot__uvw2', 'uvot__v', 'uvot__white', 'atlasc', 'atlaso', '2massj', '2massh', '2massks', 'ztfg', 'ztfr', 'ztfi']
Getting the SVD model, start_training=False
The grid will be interpolated to sample_time with interp1d
Not loading new model


In [3]:
# Build the SVD model and attach it to the training model, where the later cells read it
# back through training_model.svd_model[filter_name].
print("Fitting SVD etc")
svd_model = training_model.svd_model = training_model.generate_svd_model()
print("Fitting SVD etc DONE")

Fitting SVD etc
Normalizing mag filter bessellux...
Normalizing mag filter bessellb...
Normalizing mag filter bessellv...
Normalizing mag filter bessellr...
Normalizing mag filter besselli...
Normalizing mag filter sdssu...
Normalizing mag filter ps1__g...
Normalizing mag filter ps1__r...
Normalizing mag filter ps1__i...
Normalizing mag filter ps1__z...
Normalizing mag filter ps1__y...
Normalizing mag filter uvot__b...
Normalizing mag filter uvot__u...
Normalizing mag filter uvot__uvm2...
Normalizing mag filter uvot__uvw1...
Normalizing mag filter uvot__uvw2...
Normalizing mag filter uvot__v...
Normalizing mag filter uvot__white...
Normalizing mag filter atlasc...
Normalizing mag filter atlaso...
Normalizing mag filter 2massj...
Normalizing mag filter 2massh...
Normalizing mag filter 2massks...
Normalizing mag filter ztfg...
Normalizing mag filter ztfr...
Normalizing mag filter ztfi...
Fitting SVD etc DONE


## Training the NN on SVD-decomposed data

In [4]:
# NOTE(review): everything in this cell is commented out. It looks like an earlier attempt
# that trained through training_model.train_model() and split the first filter's
# features/labels by hand — confirm whether it is still needed or can be deleted.
# training_model.train_model()

# X = training_model.svd_model[filts[0]]['param_array_postprocess']
# print(f"Features have shape {X.shape}")

# y = training_model.svd_model[filts[0]]['cAmat'].T
# _, output_ndim = y.shape
# print(f"Labels have shape {y.shape}")

# train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=0)

# training_model.train_model()

## Training the NN on pure lightcurve data

In [5]:
# Look at what the SVD step stored for a single filter (the first one in `filts`).
# `example` is rebound here to that filter's entry in training_model.svd_model.
f = filts[0]
example = training_model.svd_model[f]
print(example.keys())

dict_keys(['param_array_postprocess', 'param_mins', 'param_maxs', 'mins', 'maxs', 'data_postprocess', 'tt', 'n_coeff', 'cAmat', 'cAstd', 'VA'])


In [6]:
# First row of the post-processed data for the chosen filter: print its shape, then its values.
example_lc = example["data_postprocess"][0]
print(np.shape(example_lc), example_lc, sep="\n")

(100,)
[0.0826376  0.15969144 0.16007111 0.23861728 0.20389001 0.23419204
 0.21398067 0.25880838 0.25490905 0.2163728  0.26100838 0.27132975
 0.3287103  0.27528475 0.25608953 0.27083055 0.2250996  0.2271274
 0.25033793 0.25257262 0.24311594 0.24934321 0.24304573 0.21190975
 0.20025202 0.18777384 0.20451045 0.1688794  0.16265685 0.16674764
 0.14438413 0.13087566 0.13134213 0.11371091 0.09845088 0.09084181
 0.08185636 0.07280343 0.06862349 0.0710695  0.06948268 0.06184098
 0.05235351 0.05395652 0.06808086 0.06734241 0.04973246 0.06139467
 0.06115733 0.0680774  0.05455876 0.06639323 0.06043593 0.07564822
 0.07554679 0.07642581 0.07406344 0.07451619 0.07369871 0.06734786
 0.06828704 0.06667483 0.05962791 0.06161439 0.05520395 0.05569221
 0.05585517 0.05421866 0.05890765 0.06243982 0.06608411 0.06731206
 0.05943458 0.06467805 0.07009179 0.08482901 0.06662319 0.04768074
 0.06146692 0.07485542 0.04930198 0.10634124 0.10679315 0.18439244
 0.27207393 0.26129637 0.19121844 0.14393029 0.0612196  

In [7]:
# Bare expression: the notebook shows the object's repr as this cell's output.
training_model

<nmma.em.training.SVDTrainingModel at 0x7fd3d0207670>

In [9]:
import nmma.em.utils_flax as utils_flax

# Train one flax MLP per filter: inputs are the post-processed model parameters, targets are
# the rows of cAmat.T taken from the SVD model.
# NOTE(review): the section heading mentions "pure lightcurve data", but the targets used
# here are the SVD coefficients (cAmat) — confirm which one is intended.
# NOTE(review): the recorded run aborted on the first filter with
# "XlaRuntimeError: RESOURCE_EXHAUSTED" while allocating about 16 MiB, which suggests the
# GPU was already almost full before training started (e.g. memory held by another process
# or kernel) — TODO confirm; see JAX's notes on GPU memory allocation.

key = jax.random.PRNGKey(0)

# Loss curves per filter. Previously train_losses/val_losses were overwritten on every
# iteration, so only the last filter's curves were left to inspect after the loop.
loss_histories = {}

for filt in training_model.filters:
    # Split the random key to get a PRNG key for initialization of the network parameters
    key, init_key = jax.random.split(key)
    print("Computing NN (using flax) for filter %s..." % filt)

    param_array_postprocess = training_model.svd_model[filt]["param_array_postprocess"]
    cAmat = training_model.svd_model[filt]["cAmat"]

    # 75% / 25% train/validation split, seeded by the training model for reproducibility.
    train_X, val_X, train_y, val_y = train_test_split(
        param_array_postprocess,
        cAmat.T,
        shuffle=True,
        test_size=0.25,
        random_state=training_model.random_seed,
    )

    # Config holds everything for the training setup. It is fetched inside the loop on
    # purpose: whether train_loop mutates it is not visible from this notebook.
    config = utils_flax.get_default_config()
    # The network input size is the number of columns of the parameter array fetched above.
    input_ndim = param_array_postprocess.shape[1]

    # TODO - make architecture also part of config, if changed later on?
    # Create neural network and initialize the state
    net = utils_flax.MLP(layer_sizes=config.layer_sizes, act_func=config.act_func)
    state = utils_flax.create_train_state(net, jnp.ones(input_ndim), init_key, config)

    # Perform training loop
    state, train_losses, val_losses = utils_flax.train_loop(state, train_X, train_y, val_X, val_y, config)

    # Store the trained state next to the SVD data of this filter.
    training_model.svd_model[filt]["model"] = state
    loss_histories[filt] = (train_losses, val_losses)

Computing NN (using flax) for filter bessellux...


2023-12-12 18:02:27.461969: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 16.18MiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 16962016 bytes.