# Comparing KN surrogate models

**Abstract:** Here we compare the MSE of the surrogate models between, e.g. Tensorflow and Jax/flax models.

In [1]:
%load_ext autoreload 
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from nmma.em.training import SVDTrainingModel
import nmma as nmma
import time
import arviz

params = {"axes.grid": True,
        "text.usetex" : True,
        "font.family" : "serif",
        "ytick.color" : "black",
        "xtick.color" : "black",
        "axes.labelcolor" : "black",
        "axes.edgecolor" : "black",
        "font.serif" : ["Computer Modern Serif"],
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "axes.labelsize": 16,
        "legend.fontsize": 16,
        "legend.title_fontsize": 16,
        "figure.titlesize": 16}

plt.rcParams.update(params)

from nmma.em.io import read_photometry_files
from nmma.em.utils import interpolate_nans

import inspect 
import nmma.em.model_parameters as model_parameters

MODEL_FUNCTIONS = {
    k: v for k, v in model_parameters.__dict__.items() if inspect.isfunction(v)
}

model_name = "Bu2022Ye"
model_function = MODEL_FUNCTIONS[model_name]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from .autonotebook import tqdm as notebook_tqdm


Install afterglowpy if you want to simulate afterglows.
Install wrapt_timeout_decorator if you want timeout simulations.


In [2]:
import jax
import jaxlib
jax.devices() # check if CUDA is present

[cuda(id=0)]

## Preprocessing data

In [3]:
lcs_dir = "/home/urash/twouters/KN_Lightcurves/lightcurves/lcs_bulla_2022" # for remote SSH Potsdam
out_dir = "/home/urash/twouters/nmma_models/flax_models/" # initial flax models will be saved here
filenames = os.listdir(lcs_dir)
full_filenames = [os.path.join(lcs_dir, f) for f in filenames]
print(f"There are {len(full_filenames)} lightcurves for this model.")

There are 7700 lightcurves for this model.


In [4]:
# Read the data and interpolate the NaNs, and get training data
data = read_photometry_files(full_filenames)
data = interpolate_nans(data)
training_data, parameters = model_function(data)

In [5]:
#Extract times and filters
key = list(training_data.keys())[0]
example = training_data[key]
t = example["t"]
keys = list(example.keys())
filts = [k for k in keys if k not in parameters + ["t"]]

## Get the flax model


In [6]:
training_model = SVDTrainingModel(
        model_name,
        training_data,
        parameters,
        t,
        filts,
        interpolation_type="flax",
        svd_path=out_dir # initial flax models will be saved here
    )

print(training_model.svd_path)

The grid will be interpolated to sample_time with interp1d
Model exists... will load that model.
/home/urash/twouters/nmma_models/flax_models/


In [7]:
training_model.__dict__.keys()

dict_keys(['model', 'data', 'model_parameters', 'sample_times', 'filters', 'n_coeff', 'n_epochs', 'interpolation_type', 'data_type', 'data_time_unit', 'plot', 'plotdir', 'ncpus', 'univariate_spline', 'univariate_spline_s', 'random_seed', 'svd_path', 'svd_model'])

## Get the Tensorflow model (to do)

In [8]:
# Read the data and interpolate the NaNs, and get training data
data = read_photometry_files(full_filenames)
data = interpolate_nans(data)
training_data, parameters = model_function(data)

KeyboardInterrupt: 

In [None]:
svd_path_tensorflow = "/home/urash/twouters/nmma_models"
tensorflow_model = SVDTrainingModel(
        model_name,
        training_data,
        parameters,
        t,
        filts,
        interpolation_type="tensorflow",
        svd_path=svd_path_tensorflow
    )

The grid will be interpolated to sample_time with interp1d
Model exists... will load that model.


2023-12-12 15:32:57.704761: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-12 15:32:57.704856: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-12 15:32:57.704930: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


/home/urash/twouters/nmma_models/flax_models/


## Get the input and output pairs of the Bu2022Ye model

In [None]:
def get_input_values(training_data: dict, parameters: list) -> np.ndarray:
    """
    From a dictionary of training data, extract the input values for the model.
    """
    
    input_values = []
    
    for key in training_data.keys():
        data = training_data[key]
        input_values.append([data[param] for param in parameters])
    
    return np.array(input_values)

In [None]:
def get_output_values(training_data: dict) -> np.ndarray:
    """
    From a dictionary of training data, extract the output values for the model.
    """
    
    output_values = [training_data[key]["data"] for key in training_data.keys()]
    
    return np.array(output_values)

In [None]:
input_values = get_input_values(training_data, parameters)
output_values = get_output_values(training_data)

In [None]:
print(np.shape(input_values))
print(np.shape(output_values))

(7700, 6)
(7700, 100, 26)


NOTE: this can be done faster with jax.vmap etc, but not for 

In [None]:
N = 100
# Select a random subset of indices for the input values
idx_list = np.random.choice(len(input_values), N, replace=False)
sampled_input_values = input_values[idx_list]
sampled_output_values = output_values[idx_list]

Get the flax output

In [None]:
# For this list, we compute the LCs using the flax model
flax_output = []
start = time.time()
for i in idx_list:
    # Compute the lightcurve
    _, _, mag = nmma.em.utils.calc_lc(t,
                                input_values[i], 
                                svd_mag_model = training_model.svd_model, 
                                interpolation_type="flax", 
                                filters = filts, 
                                )
    # Convert this dictionary to values of the LCs
    mag = mag.values()
    mag = np.array(list(mag)).T
    flax_output.append(mag)
end = time.time()
print(f"Computing all the flax lightcurves for a subset of {N} lightcurves took {end-start} seconds.")
# Make sure this is a np.ndarray
flax_output = np.array(flax_output)

Install afterglowpy if you want to simulate afterglows.
Install wrapt_timeout_decorator if you want timeout simulations.


TypeError: Expected a callable value, got (array([ 0.103,  0.109,  0.116,  0.122,  0.13 ,  0.137,  0.145,  0.154,
        0.163,  0.173,  0.183,  0.194,  0.206,  0.218,  0.231,  0.244,
        0.259,  0.274,  0.29 ,  0.308,  0.326,  0.345,  0.366,  0.387,
        0.41 ,  0.435,  0.461,  0.488,  0.517,  0.547,  0.58 ,  0.614,
        0.651,  0.689,  0.73 ,  0.773,  0.819,  0.868,  0.919,  0.974,
        1.032,  1.093,  1.158,  1.226,  1.299,  1.376,  1.458,  1.544,
        1.636,  1.733,  1.836,  1.944,  2.06 ,  2.182,  2.311,  2.448,
        2.594,  2.747,  2.91 ,  3.083,  3.266,  3.459,  3.665,  3.882,
        4.112,  4.356,  4.614,  4.888,  5.178,  5.485,  5.81 ,  6.155,
        6.52 ,  6.906,  7.316,  7.75 ,  8.21 ,  8.696,  9.212,  9.759,
       10.337, 10.95 , 11.6  , 12.288, 13.016, 13.788, 14.606, 15.472,
       16.39 , 17.362, 18.392, 19.482, 20.638, 21.862, 23.158, 24.532,
       25.987, 27.528, 29.16 , 30.89 ]), array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf]), {'bessellux': Array([-16.751293 , -18.977451 , -18.94074  , -18.91399  , -18.814096 ,
       -18.700592 , -18.60614  , -18.485725 , -18.386997 , -18.311417 ,
       -18.246378 , -18.148388 , -18.064178 , -18.011395 , -17.95673  ,
       -17.917662 , -17.85725  , -17.803675 , -17.766825 , -17.70323  ,
       -17.631596 , -17.553183 , -17.47446  , -17.364845 , -17.247435 ,
       -17.139841 , -17.004412 , -16.881496 , -16.71247  , -16.554745 ,
       -16.372353 , -16.182894 , -15.9832115, -15.771519 , -15.5336685,
       -15.304959 , -15.02552  , -14.725747 , -14.465273 , -14.1835575,
       -13.905733 , -13.56159  , -13.30611  , -12.993885 , -12.751465 ,
       -12.436157 , -12.120747 , -11.849781 , -11.598325 , -11.243218 ,
       -10.979667 , -10.650383 , -10.345467 , -10.013781 ,  -9.685959 ,
        -9.35085  ,  -8.964596 ,  -8.849244 ,  -8.635338 ,  -8.095481 ,
        -7.8237476,  -7.5578537,  -7.2342424,  -7.380654 ,  -6.9913883,
        -6.498862 ,  -6.13962  ,  -5.394746 ,  -5.10202  ,  -4.404832 ,
        -3.2549524,  -2.6166115,  -0.5141649,   2.2920647,   4.8976793,
         8.034873 ,  11.520538 ,  14.527928 ,  15.5009775,  17.339388 ,
        19.174082 ,  21.039932 ,  22.545977 ,  24.501453 ,  26.897596 ,
        30.731693 ,  35.254303 ,  40.084644 ,  44.743538 ,  49.29722  ,
        52.381268 ,  55.152714 ,  58.204903 ,  60.022858 ,  61.907036 ,
        63.266647 ,  64.91644  ,  66.38286  ,  68.02579  ,  69.76727  ],      dtype=float32), 'bessellb': Array([-16.309128 , -18.783136 , -18.90305  , -18.892752 , -18.867716 ,
       -18.806423 , -18.745462 , -18.660173 , -18.571793 , -18.496632 ,
       -18.389078 , -18.306093 , -18.217342 , -18.107147 , -18.015453 ,
       -17.950567 , -17.890646 , -17.810745 , -17.756104 , -17.686335 ,
       -17.624882 , -17.550453 , -17.49119  , -17.415472 , -17.328423 ,
       -17.236637 , -17.155731 , -17.062426 , -16.971764 , -16.872553 ,
       -16.760162 , -16.638058 , -16.496582 , -16.36097  , -16.20666  ,
       -16.015234 , -15.83116  , -15.656319 , -15.452609 , -15.246688 ,
       -15.036016 , -14.798851 , -14.5674305, -14.361815 , -14.147237 ,
       -13.9093485, -13.709379 , -13.490095 , -13.253906 , -13.03146  ,
       -12.796795 , -12.55353  , -12.311841 , -12.027464 , -11.756523 ,
       -11.475058 , -11.176333 , -10.832855 , -10.54365  , -10.214512 ,
        -9.883816 ,  -9.560476 ,  -9.204601 ,  -8.836167 ,  -8.483488 ,
        -8.095986 ,  -7.741948 ,  -7.3040047,  -6.9602427,  -6.4538217,
        -6.1068964,  -5.690644 ,  -5.1574993,  -4.525301 ,  -3.9455538,
        -3.3897858,  -2.3562784,  -1.5402222,  -0.5834913,   0.8914299,
         2.151147 ,   2.694314 ,   3.554696 ,   3.9009743,   3.1202955,
         1.9719787,   1.858006 ,   3.7686214,   8.874469 ,  13.899042 ,
        18.443184 ,  22.965801 ,  28.474243 ,  33.068863 ,  37.902878 ,
        42.7128   ,  48.11064  ,  52.873917 ,  58.202698 ,  63.905365 ],      dtype=float32), 'bessellv': Array([-15.849637  , -18.482498  , -18.564224  , -18.610699  ,
       -18.644447  , -18.612736  , -18.573618  , -18.475414  ,
       -18.402103  , -18.327763  , -18.201773  , -18.106447  ,
       -18.003498  , -17.910772  , -17.824392  , -17.740467  ,
       -17.640379  , -17.562002  , -17.489098  , -17.456892  ,
       -17.407028  , -17.348299  , -17.32369   , -17.274317  ,
       -17.229382  , -17.177849  , -17.112076  , -17.055222  ,
       -16.994482  , -16.932205  , -16.859774  , -16.78524   ,
       -16.709627  , -16.624437  , -16.537699  , -16.432632  ,
       -16.310627  , -16.18729   , -16.052572  , -15.912987  ,
       -15.769782  , -15.61715   , -15.4595175 , -15.308001  ,
       -15.160095  , -15.01586   , -14.881351  , -14.732122  ,
       -14.58401   , -14.431541  , -14.267277  , -14.110758  ,
       -13.94143   , -13.774079  , -13.59107   , -13.405139  ,
       -13.209355  , -12.991715  , -12.786954  , -12.549671  ,
       -12.298037  , -12.040736  , -11.780163  , -11.498024  ,
       -11.201969  , -10.914095  , -10.597845  , -10.302679  ,
        -9.988976  ,  -9.633313  ,  -9.278353  ,  -8.896804  ,
        -8.5273075 ,  -8.117317  ,  -7.6661315 ,  -7.1881876 ,
        -6.7437377 ,  -6.277047  ,  -5.7604184 ,  -5.218246  ,
        -4.6575975 ,  -4.0856786 ,  -3.2725782 ,  -2.6690674 ,
        -1.92521   ,  -1.081543  ,  -0.60533047,  -0.45557404,
        -0.07952023,   0.34092522,   0.39949608,   1.700201  ,
         4.410519  ,   7.247734  ,  10.361893  ,  14.433945  ,
        19.1445    ,  24.580658  ,  30.585602  ,  36.95114   ],      dtype=float32), 'bessellr': Array([-16.302277  , -18.855976  , -18.899433  , -18.871824  ,
       -18.853073  , -18.816357  , -18.720966  , -18.650497  ,
       -18.552517  , -18.465002  , -18.362104  , -18.276165  ,
       -18.126282  , -18.015757  , -17.878946  , -17.788185  ,
       -17.699785  , -17.60648   , -17.536297  , -17.464073  ,
       -17.404842  , -17.369167  , -17.336172  , -17.309772  ,
       -17.255121  , -17.227282  , -17.19997   , -17.16262   ,
       -17.12644   , -17.106007  , -17.037193  , -16.989977  ,
       -16.95648   , -16.919172  , -16.878796  , -16.823915  ,
       -16.764357  , -16.681795  , -16.591715  , -16.524275  ,
       -16.428785  , -16.336191  , -16.242289  , -16.150637  ,
       -16.071629  , -15.982826  , -15.907478  , -15.822877  ,
       -15.730568  , -15.641329  , -15.540405  , -15.436243  ,
       -15.317936  , -15.194929  , -15.065442  , -14.945208  ,
       -14.814129  , -14.67357   , -14.527254  , -14.367859  ,
       -14.210293  , -14.052744  , -13.878063  , -13.690914  ,
       -13.491549  , -13.293713  , -13.073487  , -12.8358    ,
       -12.587292  , -12.334347  , -12.060528  , -11.766717  ,
       -11.461586  , -11.124941  , -10.776482  , -10.416521  ,
       -10.03817   ,  -9.675401  ,  -9.296481  ,  -8.907358  ,
        -8.538157  ,  -8.139436  ,  -7.747313  ,  -7.358813  ,
        -6.997682  ,  -6.631397  ,  -6.206792  ,  -5.839556  ,
        -5.6056366 ,  -5.759703  ,  -5.914816  ,  -5.6711206 ,
        -4.079891  ,  -1.8490028 ,   0.2322712 ,   1.8976822 ,
         2.0095177 ,   0.6936035 ,  -0.40873718,  -1.577301  ],      dtype=float32), 'besselli': Array([-15.648238  , -18.375725  , -18.544743  , -18.546608  ,
       -18.58786   , -18.549358  , -18.523165  , -18.461584  ,
       -18.357218  , -18.242828  , -18.14328   , -18.039835  ,
       -17.93264   , -17.81729   , -17.729265  , -17.633673  ,
       -17.528992  , -17.438463  , -17.38109   , -17.36412   ,
       -17.33354   , -17.288935  , -17.271788  , -17.238472  ,
       -17.2332    , -17.229645  , -17.196663  , -17.180244  ,
       -17.160696  , -17.147598  , -17.128815  , -17.089993  ,
       -17.0705    , -17.03476   , -16.989628  , -16.945219  ,
       -16.90071   , -16.839676  , -16.786259  , -16.724003  ,
       -16.663733  , -16.605007  , -16.54404   , -16.483875  ,
       -16.407951  , -16.341095  , -16.272055  , -16.217718  ,
       -16.154121  , -16.082819  , -16.011969  , -15.934174  ,
       -15.847744  , -15.758758  , -15.666677  , -15.589335  ,
       -15.504266  , -15.410749  , -15.303921  , -15.185944  ,
       -15.076662  , -14.972307  , -14.864415  , -14.747538  ,
       -14.620932  , -14.490513  , -14.352356  , -14.196676  ,
       -14.031161  , -13.861691  , -13.677768  , -13.477423  ,
       -13.272227  , -13.03734   , -12.788935  , -12.518487  ,
       -12.239972  , -11.969065  , -11.670268  , -11.361851  ,
       -11.072774  , -10.763947  , -10.440321  , -10.128799  ,
        -9.820688  ,  -9.48521   ,  -9.134303  ,  -8.815043  ,
        -8.524538  ,  -8.396261  ,  -8.192995  ,  -7.5663695 ,
        -6.635439  ,  -4.957306  ,  -2.5247498 ,  -0.31444168,
         1.2374649 ,   1.3063507 ,   1.4540863 ,   1.6106873 ],      dtype=float32), 'sdssu': Array([-16.54648   , -18.807915  , -18.844185  , -18.870224  ,
       -18.75462   , -18.686298  , -18.610731  , -18.509138  ,
       -18.426586  , -18.366499  , -18.293013  , -18.21133   ,
       -18.161274  , -18.089577  , -18.042309  , -18.004053  ,
       -17.96433   , -17.892075  , -17.824705  , -17.76595   ,
       -17.70331   , -17.617258  , -17.513153  , -17.378979  ,
       -17.251627  , -17.131208  , -16.985386  , -16.838678  ,
       -16.668018  , -16.490635  , -16.30442   , -16.079508  ,
       -15.886096  , -15.631777  , -15.363921  , -15.104986  ,
       -14.791289  , -14.490036  , -14.173283  , -13.876753  ,
       -13.570662  , -13.245192  , -12.921764  , -12.633241  ,
       -12.331796  , -12.005334  , -11.736166  , -11.407467  ,
       -11.151259  , -10.889794  , -10.593793  , -10.330121  ,
       -10.049509  ,  -9.796641  ,  -9.550878  ,  -9.281895  ,
        -9.062252  ,  -8.88962   ,  -8.67458   ,  -8.425657  ,
        -8.234057  ,  -8.025497  ,  -7.8011365 ,  -7.6041455 ,
        -7.3691807 ,  -7.1457996 ,  -6.906489  ,  -6.6339235 ,
        -6.374765  ,  -6.1330185 ,  -5.881382  ,  -5.5906987 ,
        -5.2765985 ,  -4.98473   ,  -4.654337  ,  -4.315755  ,
        -3.947905  ,  -3.5816102 ,  -3.227591  ,  -2.7824879 ,
        -2.5172386 ,  -2.1338139 ,  -1.7429261 ,  -1.3034754 ,
        -0.98911285,  -0.41415977,   0.13832569,   0.49893522,
         0.742363  ,   0.60571814,   0.4073887 ,   1.268553  ,
         2.1786113 ,   3.0156755 ,   4.441366  ,   6.230213  ,
         7.1608124 ,   8.110947  ,   9.62291   ,  11.225616  ],      dtype=float32), 'ps1__g': Array([-16.030113  , -18.556503  , -18.658009  , -18.696165  ,
       -18.706587  , -18.682186  , -18.640715  , -18.558952  ,
       -18.48913   , -18.41789   , -18.288857  , -18.203363  ,
       -18.107876  , -18.008272  , -17.919067  , -17.845526  ,
       -17.763636  , -17.686567  , -17.62367   , -17.562729  ,
       -17.506083  , -17.43981   , -17.39213   , -17.32094   ,
       -17.25313   , -17.169062  , -17.08971   , -17.008118  ,
       -16.926018  , -16.8428    , -16.741272  , -16.640854  ,
       -16.53321   , -16.417662  , -16.303162  , -16.152452  ,
       -15.991707  , -15.837172  , -15.661578  , -15.487061  ,
       -15.309354  , -15.107672  , -14.907578  , -14.730427  ,
       -14.543741  , -14.346797  , -14.17284   , -13.978741  ,
       -13.783546  , -13.608024  , -13.398989  , -13.190023  ,
       -12.987444  , -12.763016  , -12.540903  , -12.291926  ,
       -12.034767  , -11.767007  , -11.504545  , -11.21234   ,
       -10.92881   , -10.611417  , -10.286892  ,  -9.93836   ,
        -9.6010685 ,  -9.284616  ,  -8.907492  ,  -8.523253  ,
        -8.20516   ,  -7.745872  ,  -7.3311443 ,  -6.907628  ,
        -6.4097867 ,  -5.867861  ,  -5.332534  ,  -4.834978  ,
        -4.1586847 ,  -3.5316124 ,  -2.5443745 ,  -1.5274343 ,
        -0.38869858,   0.09842491,   1.0464029 ,   1.92099   ,
         3.4690247 ,   4.6025543 ,   5.6727085 ,   8.387989  ,
        12.220398  ,  15.9719925 ,  20.041275  ,  24.93805   ,
        31.271164  ,  36.920364  ,  43.477356  ,  49.822296  ,
        57.147125  ,  63.918427  ,  71.22015   ,  78.960205  ],      dtype=float32), 'ps1__r': Array([-15.483446  , -18.110651  , -18.24542   , -18.30535   ,
       -18.347326  , -18.376303  , -18.37554   , -18.355864  ,
       -18.30812   , -18.294806  , -18.238707  , -18.162838  ,
       -18.064274  , -17.992804  , -17.910019  , -17.833433  ,
       -17.75556   , -17.671432  , -17.595943  , -17.541348  ,
       -17.496508  , -17.461689  , -17.429695  , -17.399683  ,
       -17.352951  , -17.307167  , -17.257582  , -17.215702  ,
       -17.158752  , -17.109364  , -17.048134  , -16.992971  ,
       -16.950235  , -16.89963   , -16.842499  , -16.767605  ,
       -16.694637  , -16.606627  , -16.511665  , -16.416973  ,
       -16.312214  , -16.202158  , -16.091236  , -15.976486  ,
       -15.873231  , -15.773337  , -15.68892   , -15.586396  ,
       -15.471753  , -15.351587  , -15.220101  , -15.085382  ,
       -14.9460945 , -14.805103  , -14.644435  , -14.493931  ,
       -14.339003  , -14.165812  , -13.985205  , -13.782293  ,
       -13.579292  , -13.3679495 , -13.153318  , -12.920231  ,
       -12.667095  , -12.41777   , -12.150534  , -11.878872  ,
       -11.570686  , -11.24859   , -10.920385  , -10.574147  ,
       -10.178253  ,  -9.805195  ,  -9.350393  ,  -8.891343  ,
        -8.42395   ,  -7.952493  ,  -7.402361  ,  -6.9177155 ,
        -6.547129  ,  -6.005271  ,  -5.338272  ,  -4.822984  ,
        -4.435397  ,  -3.924654  ,  -3.8285403 ,  -3.4762683 ,
        -2.8782773 ,  -1.855258  ,  -1.1233997 ,  -0.03410435,
         2.7343512 ,   5.20446   ,   7.7431107 ,  11.206802  ,
        15.111313  ,  19.378601  ,  24.122894  ,  29.152039  ],      dtype=float32), 'ps1__i': Array([-16.173225  , -18.826937  , -18.908346  , -18.845749  ,
       -18.840805  , -18.8012    , -18.717789  , -18.659771  ,
       -18.575548  , -18.473167  , -18.373922  , -18.284027  ,
       -18.161364  , -18.04766   , -17.937635  , -17.815605  ,
       -17.70512   , -17.601562  , -17.524012  , -17.4833    ,
       -17.4523    , -17.385174  , -17.361723  , -17.330608  ,
       -17.314625  , -17.295948  , -17.270864  , -17.238268  ,
       -17.213194  , -17.194254  , -17.162601  , -17.115093  ,
       -17.081604  , -17.054102  , -17.00893   , -16.947258  ,
       -16.898745  , -16.836004  , -16.778881  , -16.70948   ,
       -16.653124  , -16.586615  , -16.519188  , -16.451061  ,
       -16.380749  , -16.31391   , -16.253922  , -16.192358  ,
       -16.123997  , -16.046377  , -15.96349   , -15.869167  ,
       -15.762212  , -15.649428  , -15.532258  , -15.423274  ,
       -15.296728  , -15.160771  , -15.006073  , -14.837848  ,
       -14.673998  , -14.518722  , -14.349201  , -14.169172  ,
       -13.987568  , -13.81148   , -13.626747  , -13.422241  ,
       -13.213837  , -12.998818  , -12.76506   , -12.513693  ,
       -12.251691  , -11.954986  , -11.645338  , -11.323454  ,
       -10.989334  , -10.667326  , -10.3188095 ,  -9.980192  ,
        -9.667329  ,  -9.335608  ,  -8.990913  ,  -8.63635   ,
        -8.352436  ,  -8.096096  ,  -7.9260793 ,  -7.936622  ,
        -8.401615  ,  -8.922652  ,  -8.577123  ,  -6.6876698 ,
        -4.090405  ,  -2.8242798 ,  -1.5990448 ,  -0.80552673,
         1.2428436 ,   2.967392  ,   5.2423706 ,   7.6538086 ],      dtype=float32), 'ps1__z': Array([-17.748104 , -18.879087 , -19.016047 , -19.057371 , -19.201767 ,
       -19.116167 , -18.993755 , -18.91021  , -18.802591 , -18.652195 ,
       -18.571775 , -18.316826 , -18.215008 , -18.07901  , -18.07429  ,
       -17.910051 , -17.777657 , -17.701164 , -17.655525 , -17.588552 ,
       -17.498798 , -17.488483 , -17.458206 , -17.489956 , -17.412409 ,
       -17.467257 , -17.37884  , -17.394701 , -17.377735 , -17.343067 ,
       -17.331566 , -17.304335 , -17.30163  , -17.28161  , -17.208563 ,
       -17.181343 , -17.111324 , -17.067654 , -16.99641  , -16.949123 ,
       -16.919106 , -16.818584 , -16.754345 , -16.680084 , -16.61668  ,
       -16.559116 , -16.483072 , -16.440874 , -16.341507 , -16.269371 ,
       -16.20914  , -16.162523 , -16.06291  , -15.981928 , -15.892845 ,
       -15.790895 , -15.7091675, -15.595276 , -15.518221 , -15.421805 ,
       -15.3233   , -15.22784  , -15.1267605, -15.0176735, -14.90353  ,
       -14.8004465, -14.678591 , -14.554637 , -14.414427 , -14.256507 ,
       -14.091943 , -13.926966 , -13.741155 , -13.536703 , -13.338746 ,
       -13.117107 , -12.86856  , -12.648153 , -12.395281 , -12.121191 ,
       -11.887711 , -11.609364 , -11.372234 , -11.155703 , -10.852533 ,
       -10.608599 , -10.2776375, -10.027847 , -10.084621 , -10.187433 ,
        -9.982109 ,  -8.272735 ,  -6.4392676,  -5.7932997,  -4.6797905,
        -4.485321 ,  -5.5999565,  -8.875397 , -13.442028 , -18.28286  ],      dtype=float32), 'ps1__y': Array([-15.280893 , -17.689173 , -18.216358 , -18.447485 , -18.568314 ,
       -18.675688 , -18.692175 , -18.644146 , -18.632725 , -18.605755 ,
       -18.529257 , -18.455753 , -18.43392  , -18.287151 , -18.08537  ,
       -17.9077   , -17.718267 , -17.518023 , -17.44941  , -17.379902 ,
       -17.299133 , -17.266287 , -17.228405 , -17.214222 , -17.18322  ,
       -17.196133 , -17.177769 , -17.162466 , -17.144093 , -17.140566 ,
       -17.123024 , -17.106665 , -17.090656 , -17.060116 , -17.058605 ,
       -17.020134 , -17.003178 , -16.957312 , -16.908028 , -16.857573 ,
       -16.807833 , -16.75177  , -16.708992 , -16.657911 , -16.599724 ,
       -16.552326 , -16.510302 , -16.46044  , -16.40849  , -16.356863 ,
       -16.311996 , -16.259598 , -16.202635 , -16.137163 , -16.076168 ,
       -16.007288 , -15.936945 , -15.85627  , -15.763028 , -15.6764145,
       -15.581012 , -15.493702 , -15.3987465, -15.301661 , -15.209252 ,
       -15.109459 , -15.011099 , -14.904584 , -14.796735 , -14.684258 ,
       -14.565681 , -14.4390335, -14.303766 , -14.156435 , -14.003881 ,
       -13.840004 , -13.6707325, -13.502876 , -13.335838 , -13.190179 ,
       -13.036857 , -12.873279 , -12.710732 , -12.543224 , -12.350831 ,
       -12.12623  , -11.8747015, -11.587004 , -11.3375025, -11.424696 ,
       -11.359578 , -10.824601 ,  -8.887899 ,  -9.044867 ,  -9.250221 ,
        -9.790344 ,  -9.975357 ,  -9.688721 ,  -8.948608 ,  -8.1640625],      dtype=float32), 'uvot__b': Array([-16.407314  , -18.87761   , -18.992867  , -18.984484  ,
       -18.964344  , -18.890358  , -18.811031  , -18.727142  ,
       -18.616083  , -18.508213  , -18.379887  , -18.273073  ,
       -18.161112  , -18.033102  , -17.929432  , -17.849884  ,
       -17.790731  , -17.701868  , -17.639198  , -17.56354   ,
       -17.505249  , -17.43361   , -17.372755  , -17.308914  ,
       -17.235743  , -17.153906  , -17.084715  , -17.000532  ,
       -16.917747  , -16.832718  , -16.731272  , -16.6226    ,
       -16.491882  , -16.368694  , -16.227592  , -16.055897  ,
       -15.87947   , -15.70799   , -15.513462  , -15.311445  ,
       -15.1059265 , -14.877479  , -14.654379  , -14.442151  ,
       -14.226418  , -13.987337  , -13.768276  , -13.529065  ,
       -13.288548  , -13.060665  , -12.8280735 , -12.577612  ,
       -12.298759  , -12.002382  , -11.735022  , -11.434414  ,
       -11.115827  , -10.762122  , -10.466394  , -10.122644  ,
        -9.748995  ,  -9.399335  ,  -9.067306  ,  -8.704098  ,
        -8.387806  ,  -7.8970056 ,  -7.537551  ,  -7.1293306 ,
        -6.814447  ,  -6.5147333 ,  -6.0013885 ,  -5.5782967 ,
        -5.315764  ,  -4.9580336 ,  -4.196333  ,  -3.3954153 ,
        -2.3079796 ,  -0.88792133,   0.16502666,   1.4537697 ,
         2.7463102 ,   2.856905  ,   3.1053438 ,   2.7712975 ,
         3.6358213 ,   5.7300262 ,   9.642443  ,  14.012829  ,
        19.585518  ,  24.396824  ,  28.601345  ,  33.230297  ,
        39.478462  ,  44.916122  ,  50.332047  ,  54.916794  ,
        61.05011   ,  66.58551   ,  72.052826  ,  77.84839   ],      dtype=float32), 'uvot__u': Array([-1.6262770e+01, -1.8562687e+01, -1.8571199e+01, -1.8554790e+01,
       -1.8496716e+01, -1.8443750e+01, -1.8389111e+01, -1.8318354e+01,
       -1.8271307e+01, -1.8217104e+01, -1.8162924e+01, -1.8109177e+01,
       -1.8047045e+01, -1.8009550e+01, -1.7960176e+01, -1.7906912e+01,
       -1.7850449e+01, -1.7777596e+01, -1.7711853e+01, -1.7614264e+01,
       -1.7520620e+01, -1.7416164e+01, -1.7290337e+01, -1.7153112e+01,
       -1.7000849e+01, -1.6850758e+01, -1.6688742e+01, -1.6532507e+01,
       -1.6353886e+01, -1.6163876e+01, -1.5962245e+01, -1.5751445e+01,
       -1.5546451e+01, -1.5323575e+01, -1.5083391e+01, -1.4838632e+01,
       -1.4543272e+01, -1.4244198e+01, -1.3967769e+01, -1.3684238e+01,
       -1.3397837e+01, -1.3088680e+01, -1.2780863e+01, -1.2488417e+01,
       -1.2206369e+01, -1.1885511e+01, -1.1618450e+01, -1.1255810e+01,
       -1.0998631e+01, -1.0675369e+01, -1.0338176e+01, -9.9470062e+00,
       -9.5364933e+00, -9.1623182e+00, -8.7452955e+00, -8.3760748e+00,
       -7.8683844e+00, -7.5117178e+00, -7.0521517e+00, -6.4710045e+00,
       -5.9083834e+00, -5.5061255e+00, -4.9566307e+00, -4.4076691e+00,
       -4.0001049e+00, -3.2340555e+00, -2.5686131e+00, -1.9720335e+00,
       -1.3629751e+00, -7.8985119e-01,  3.0343056e-02,  5.6671333e-01,
        1.3407516e+00,  2.0564156e+00,  3.1614218e+00,  4.3063345e+00,
        5.7614737e+00,  7.0440407e+00,  8.2497673e+00,  9.3607788e+00,
        1.0597897e+01,  1.1524124e+01,  1.2054529e+01,  1.2611225e+01,
        1.3381859e+01,  1.5637648e+01,  1.9726307e+01,  2.5643867e+01,
        3.2688374e+01,  3.9385639e+01,  4.4789917e+01,  4.8937840e+01,
        5.3084553e+01,  5.6956497e+01,  6.0881531e+01,  6.4722664e+01,
        6.9074844e+01,  7.2851013e+01,  7.6737183e+01,  8.0856812e+01],      dtype=float32), 'uvot__uvm2': Array([-16.223093  , -18.640867  , -18.584923  , -18.44068   ,
       -18.346802  , -18.199108  , -18.111822  , -18.037071  ,
       -17.989565  , -17.955605  , -17.92617   , -17.888475  ,
       -17.856026  , -17.786455  , -17.720884  , -17.628439  ,
       -17.512175  , -17.381054  , -17.22528   , -17.055597  ,
       -16.848415  , -16.628584  , -16.384157  , -16.114105  ,
       -15.839399  , -15.552998  , -15.229023  , -14.908245  ,
       -14.573969  , -14.2277    , -13.851395  , -13.501066  ,
       -13.149678  , -12.76346   , -12.356773  , -11.958153  ,
       -11.550663  , -11.083935  , -10.613113  , -10.235422  ,
        -9.739975  ,  -9.138271  ,  -8.572482  ,  -8.0568285 ,
        -7.580141  ,  -7.060055  ,  -6.6224194 ,  -5.9753275 ,
        -5.55361   ,  -5.092286  ,  -4.4822617 ,  -4.1059284 ,
        -3.401164  ,  -2.539689  ,  -2.0482788 ,  -1.2980556 ,
        -0.5560522 ,   0.33388615,   1.212965  ,   2.3453722 ,
         3.6945734 ,   4.9619007 ,   6.7065554 ,   8.557312  ,
        10.659222  ,  13.038017  ,  16.24377   ,  19.334705  ,
        23.305305  ,  27.135197  ,  31.248734  ,  35.325085  ,
        39.652927  ,  43.18199   ,  46.25222   ,  49.911022  ,
        53.6805    ,  56.880642  ,  59.073845  ,  61.545025  ,
        62.946068  ,  64.74722   ,  66.31885   ,  67.73751   ,
        68.795685  ,  70.34557   ,  71.89758   ,  73.13817   ,
        74.5971    ,  75.78513   ,  76.80831   ,  77.92353   ,
        79.643425  ,  80.98483   ,  82.82054   ,  84.36571   ,
        86.28694   ,  88.29062   ,  90.39355   ,  92.62276   ],      dtype=float32), 'uvot__uvw1': Array([-16.672739  , -18.996126  , -18.919998  , -18.816149  ,
       -18.70409   , -18.567585  , -18.455244  , -18.322954  ,
       -18.209606  , -18.144495  , -18.063015  , -17.990858  ,
       -17.92661   , -17.875114  , -17.82109   , -17.743412  ,
       -17.656656  , -17.568644  , -17.448807  , -17.315006  ,
       -17.154161  , -16.993498  , -16.801508  , -16.610409  ,
       -16.385897  , -16.166662  , -15.937201  , -15.690605  ,
       -15.43832   , -15.151317  , -14.881915  , -14.596727  ,
       -14.307682  , -14.030901  , -13.730566  , -13.423858  ,
       -13.1071    , -12.789061  , -12.473412  , -12.162176  ,
       -11.820712  , -11.475555  , -11.131695  , -10.795643  ,
       -10.479919  , -10.156736  ,  -9.8655815 ,  -9.543903  ,
        -9.219154  ,  -8.916015  ,  -8.618376  ,  -8.290066  ,
        -7.9402494 ,  -7.606264  ,  -7.2973213 ,  -6.965143  ,
        -6.6271496 ,  -6.2540636 ,  -5.927517  ,  -5.582843  ,
        -5.2500887 ,  -4.91978   ,  -4.549348  ,  -4.190838  ,
        -3.8282194 ,  -3.4551258 ,  -3.0881147 ,  -2.6821318 ,
        -2.2577038 ,  -1.8461905 ,  -1.3998175 ,  -0.9578991 ,
        -0.43126106,   0.13713169,   0.73911476,   1.4375634 ,
         2.3853168 ,   3.1743054 ,   4.210164  ,   5.5553923 ,
         6.8041296 ,   7.734975  ,   8.147635  ,   7.940865  ,
         7.635991  ,   8.015276  ,   8.672232  ,  10.283695  ,
        12.376645  ,  14.397923  ,  16.22261   ,  19.940199  ,
        25.35794   ,  31.46304   ,  39.147636  ,  46.57347   ,
        55.204163  ,  63.136017  ,  71.43259   ,  80.227264  ],      dtype=float32), 'uvot__uvw2': Array([-1.6223652e+01, -1.8539698e+01, -1.8449846e+01, -1.8322731e+01,
       -1.8225981e+01, -1.8097013e+01, -1.7981098e+01, -1.7883631e+01,
       -1.7815510e+01, -1.7759548e+01, -1.7681271e+01, -1.7634663e+01,
       -1.7581600e+01, -1.7515211e+01, -1.7432787e+01, -1.7331675e+01,
       -1.7217070e+01, -1.7085505e+01, -1.6916975e+01, -1.6740477e+01,
       -1.6526367e+01, -1.6302940e+01, -1.6060333e+01, -1.5784461e+01,
       -1.5490775e+01, -1.5201364e+01, -1.4884102e+01, -1.4553843e+01,
       -1.4202816e+01, -1.3835253e+01, -1.3479160e+01, -1.3119163e+01,
       -1.2749113e+01, -1.2390896e+01, -1.2001925e+01, -1.1637350e+01,
       -1.1244538e+01, -1.0870258e+01, -1.0490719e+01, -1.0121117e+01,
       -9.7422295e+00, -9.3310585e+00, -8.9159918e+00, -8.5541687e+00,
       -8.1842718e+00, -7.8192348e+00, -7.5056877e+00, -7.1316757e+00,
       -6.8096533e+00, -6.4904404e+00, -6.1804085e+00, -5.8751011e+00,
       -5.5192804e+00, -5.1632719e+00, -4.8909559e+00, -4.5459633e+00,
       -4.2196941e+00, -3.8450069e+00, -3.5205860e+00, -3.1592703e+00,
       -2.7665377e+00, -2.3867846e+00, -2.0467272e+00, -1.6662292e+00,
       -1.2786350e+00, -8.2216358e-01, -3.9320278e-01,  5.5187225e-02,
        4.5770645e-01,  8.4973621e-01,  1.4054270e+00,  1.8609347e+00,
        2.1640553e+00,  2.6321502e+00,  3.5278215e+00,  4.3764434e+00,
        5.7193961e+00,  7.1366129e+00,  8.2992573e+00,  9.5568924e+00,
        1.1009037e+01,  1.1628003e+01,  1.2306355e+01,  1.3010990e+01,
        1.4310606e+01,  1.6526339e+01,  2.0094360e+01,  2.4170494e+01,
        2.8747696e+01,  3.3720428e+01,  3.8120510e+01,  4.2008034e+01,
        4.6422821e+01,  4.9381973e+01,  5.1989464e+01,  5.4004333e+01,
        5.6645828e+01,  5.8335602e+01,  6.0074890e+01,  6.1899689e+01],      dtype=float32), 'uvot__v': Array([-16.082767  , -18.609787  , -18.688726  , -18.697365  ,
       -18.709858  , -18.656559  , -18.59812   , -18.472887  ,
       -18.36734   , -18.267864  , -18.119291  , -18.023748  ,
       -17.947697  , -17.841022  , -17.767632  , -17.693207  ,
       -17.592962  , -17.5216    , -17.471115  , -17.457457  ,
       -17.41074   , -17.357803  , -17.350039  , -17.302082  ,
       -17.265394  , -17.21849   , -17.155186  , -17.10411   ,
       -17.038767  , -16.983507  , -16.912064  , -16.831696  ,
       -16.74483   , -16.661085  , -16.572985  , -16.470768  ,
       -16.344639  , -16.224152  , -16.093906  , -15.95051   ,
       -15.812548  , -15.653697  , -15.498938  , -15.339443  ,
       -15.194935  , -15.042944  , -14.899298  , -14.750746  ,
       -14.588947  , -14.433643  , -14.255469  , -14.08002   ,
       -13.894093  , -13.710523  , -13.512467  , -13.32094   ,
       -13.084571  , -12.843918  , -12.604874  , -12.323567  ,
       -12.005646  , -11.719849  , -11.372591  , -11.008387  ,
       -10.623217  , -10.274706  ,  -9.909782  ,  -9.522337  ,
        -9.124023  ,  -8.690749  ,  -8.209464  ,  -7.7394676 ,
        -7.283542  ,  -6.7065554 ,  -6.070883  ,  -5.2486925 ,
        -4.5342493 ,  -3.939207  ,  -3.2499094 ,  -2.303112  ,
        -1.2470541 ,  -0.5383282 ,   0.34345245,   1.0674639 ,
         2.1858063 ,   3.4487    ,   3.6514196 ,   4.0894194 ,
         3.8225632 ,   4.2923822 ,   6.067768  ,   9.087004  ,
        12.691689  ,  16.462692  ,  21.003387  ,  25.436752  ,
        30.313934  ,  35.203827  ,  40.374817  ,  45.856293  ],      dtype=float32), 'uvot__white': Array([-16.358513  , -18.69475   , -18.636652  , -18.547543  ,
       -18.434357  , -18.315506  , -18.221758  , -18.114775  ,
       -18.024252  , -17.949314  , -17.872982  , -17.811752  ,
       -17.744509  , -17.689678  , -17.63056   , -17.568636  ,
       -17.494246  , -17.411905  , -17.327532  , -17.233677  ,
       -17.127539  , -17.00669   , -16.884033  , -16.74578   ,
       -16.597542  , -16.452074  , -16.298777  , -16.147589  ,
       -15.996634  , -15.8379    , -15.684668  , -15.532518  ,
       -15.375284  , -15.236743  , -15.092187  , -14.929336  ,
       -14.756325  , -14.580505  , -14.392279  , -14.2071495 ,
       -14.021442  , -13.832154  , -13.640781  , -13.462077  ,
       -13.2972    , -13.1409445 , -13.005913  , -12.85216   ,
       -12.698731  , -12.534834  , -12.367848  , -12.199224  ,
       -12.016863  , -11.835243  , -11.646686  , -11.457499  ,
       -11.265835  , -11.046671  , -10.831181  , -10.593597  ,
       -10.358668  , -10.108255  ,  -9.853073  ,  -9.582443  ,
        -9.306504  ,  -9.032211  ,  -8.736001  ,  -8.434916  ,
        -8.117977  ,  -7.783633  ,  -7.432095  ,  -7.0599833 ,
        -6.67362   ,  -6.280486  ,  -5.8339014 ,  -5.4125085 ,
        -4.9641485 ,  -4.5504894 ,  -4.120958  ,  -3.6346092 ,
        -3.1967492 ,  -2.7194037 ,  -2.261681  ,  -1.8901005 ,
        -1.4622669 ,  -1.1429682 ,  -0.72464323,  -0.46876287,
        -0.54933834,  -1.0872045 ,  -1.2190075 ,  -0.2231512 ,
         2.7933712 ,   5.3260117 ,   7.036293  ,   9.208748  ,
        12.637695  ,  16.333755  ,  20.428696  ,  24.76944   ],      dtype=float32), 'atlasc': Array([-16.287638  , -18.772953  , -18.801167  , -18.821253  ,
       -18.798717  , -18.75109   , -18.685146  , -18.592882  ,
       -18.50972   , -18.426449  , -18.315012  , -18.215803  ,
       -18.113964  , -18.018656  , -17.929436  , -17.844948  ,
       -17.757593  , -17.677828  , -17.606886  , -17.546312  ,
       -17.498009  , -17.428236  , -17.379553  , -17.32066   ,
       -17.255245  , -17.18673   , -17.114084  , -17.045374  ,
       -16.97148   , -16.902618  , -16.821957  , -16.742954  ,
       -16.664602  , -16.585064  , -16.503479  , -16.402092  ,
       -16.291534  , -16.179464  , -16.05835   , -15.934354  ,
       -15.806213  , -15.673554  , -15.538871  , -15.41024   ,
       -15.282967  , -15.161581  , -15.055734  , -14.933249  ,
       -14.802426  , -14.664959  , -14.518381  , -14.374186  ,
       -14.218008  , -14.055441  , -13.870512  , -13.702307  ,
       -13.514254  , -13.305563  , -13.102136  , -12.871044  ,
       -12.631449  , -12.386379  , -12.144876  , -11.892049  ,
       -11.609972  , -11.339907  , -11.045628  , -10.75536   ,
       -10.447658  , -10.099319  ,  -9.762301  ,  -9.388537  ,
        -9.023083  ,  -8.617689  ,  -8.15678   ,  -7.6605253 ,
        -7.223755  ,  -6.712633  ,  -6.2455244 ,  -5.785849  ,
        -5.1726975 ,  -4.521707  ,  -3.6800847 ,  -2.8839898 ,
        -2.0051422 ,  -0.8891525 ,  -0.09560776,   0.55977726,
         1.0034504 ,   1.5671206 ,   2.8600245 ,   5.887272  ,
         9.331846  ,  12.687336  ,  16.539337  ,  20.59526   ,
        25.477814  ,  30.727127  ,  36.978546  ,  43.605286  ],      dtype=float32), 'atlaso': Array([-16.283981  , -18.908432  , -18.936174  , -18.875057  ,
       -18.86838   , -18.830883  , -18.727018  , -18.615755  ,
       -18.501331  , -18.361597  , -18.2499    , -18.141413  ,
       -18.015879  , -17.894463  , -17.780748  , -17.679113  ,
       -17.585604  , -17.501556  , -17.435587  , -17.390556  ,
       -17.357046  , -17.316332  , -17.289648  , -17.266071  ,
       -17.245892  , -17.22203   , -17.195024  , -17.16967   ,
       -17.137053  , -17.103258  , -17.061884  , -17.016003  ,
       -16.974464  , -16.930214  , -16.87542   , -16.814896  ,
       -16.748257  , -16.674202  , -16.597094  , -16.520826  ,
       -16.440166  , -16.352283  , -16.264694  , -16.17591   ,
       -16.088285  , -15.995161  , -15.912338  , -15.82749   ,
       -15.735565  , -15.6360655 , -15.530423  , -15.416619  ,
       -15.291068  , -15.164326  , -15.028656  , -14.906351  ,
       -14.772788  , -14.627901  , -14.473642  , -14.306238  ,
       -14.141898  , -13.984777  , -13.809423  , -13.626504  ,
       -13.43593   , -13.249734  , -13.051955  , -12.831694  ,
       -12.607392  , -12.376748  , -12.123653  , -11.85573   ,
       -11.57851   , -11.270992  , -10.953611  , -10.61573   ,
       -10.272453  ,  -9.950995  ,  -9.6080475 ,  -9.268942  ,
        -8.95549   ,  -8.601849  ,  -8.239805  ,  -7.87652   ,
        -7.533234  ,  -7.136754  ,  -6.720172  ,  -6.340086  ,
        -6.107077  ,  -6.1855903 ,  -6.03901   ,  -5.1090837 ,
        -3.3520002 ,  -2.3713074 ,  -2.1812363 ,  -2.025711  ,
        -0.88720703,  -0.1350708 ,   1.006134  ,   2.2159424 ],      dtype=float32), '2massj': Array([-15.23453 , -16.536264, -17.08165 , -17.48962 , -17.60108 ,
       -17.907343, -18.048222, -18.276505, -18.440216, -18.44364 ,
       -18.331799, -18.35409 , -18.299368, -18.247402, -18.149498,
       -18.083086, -17.929863, -17.691704, -17.518032, -17.356796,
       -17.246109, -17.192987, -17.157457, -17.068867, -17.0374  ,
       -16.987835, -16.971998, -16.932026, -16.933968, -16.954113,
       -16.972942, -16.947983, -16.924448, -16.93974 , -16.93415 ,
       -16.952353, -16.884182, -16.882515, -16.8357  , -16.813732,
       -16.779009, -16.735153, -16.689852, -16.655645, -16.644754,
       -16.617235, -16.608715, -16.613214, -16.584122, -16.579226,
       -16.559896, -16.546585, -16.542387, -16.546768, -16.553633,
       -16.561676, -16.576178, -16.572414, -16.575096, -16.574776,
       -16.580383, -16.584587, -16.59186 , -16.580105, -16.564259,
       -16.531252, -16.481052, -16.419628, -16.334843, -16.241098,
       -16.132238, -16.013706, -15.882619, -15.738175, -15.589442,
       -15.422646, -15.245774, -15.089586, -14.938534, -14.790201,
       -14.643371, -14.502363, -14.370993, -14.237255, -14.110386,
       -13.977399, -13.834917, -13.665228, -13.512692, -13.491491,
       -13.594628, -13.092566, -11.647736,  -9.686649,  -8.577446,
        -8.762695,  -8.805893,  -8.323845,  -8.582542, -10.186569],      dtype=float32), '2massh': Array([-14.370985, -16.833868, -17.434078, -17.478277, -17.072977,
       -16.906565, -16.801466, -16.780231, -16.789455, -16.812962,
       -16.727184, -16.773834, -16.843712, -16.882683, -16.874413,
       -16.933887, -16.897457, -16.827904, -16.779295, -16.752014,
       -16.688393, -16.640255, -16.639736, -16.563065, -16.549162,
       -16.575657, -16.60181 , -16.612335, -16.62733 , -16.63786 ,
       -16.62267 , -16.64159 , -16.653397, -16.676102, -16.668993,
       -16.656689, -16.635115, -16.618258, -16.605722, -16.592215,
       -16.571396, -16.554619, -16.529278, -16.516756, -16.506147,
       -16.51203 , -16.517452, -16.513268, -16.497168, -16.491877,
       -16.472973, -16.461681, -16.45138 , -16.450565, -16.44758 ,
       -16.440916, -16.447546, -16.439514, -16.438442, -16.437515,
       -16.437414, -16.444683, -16.442764, -16.443623, -16.444304,
       -16.45074 , -16.45449 , -16.448837, -16.44512 , -16.438366,
       -16.423588, -16.404488, -16.38255 , -16.356285, -16.323618,
       -16.283297, -16.23492 , -16.18423 , -16.12875 , -16.060942,
       -15.986155, -15.897761, -15.806158, -15.703618, -15.599552,
       -15.474547, -15.342377, -15.201944, -15.05316 , -14.84199 ,
       -14.577156, -14.342713, -14.093035, -13.659736, -13.24516 ,
       -13.068741, -12.114056, -10.638458,  -9.991051,  -9.304779],      dtype=float32), '2massks': Array([ -9.470455 ,  -7.693514 , -11.808802 , -14.336    , -15.188431 ,
       -15.597246 , -16.081512 , -16.108406 , -16.195578 , -16.424559 ,
       -16.376066 , -16.46755  , -16.546705 , -16.314112 , -16.074366 ,
       -16.031136 , -16.104807 , -16.053707 , -16.10529  , -16.0253   ,
       -16.023441 , -16.027727 , -15.979695 , -15.989905 , -16.019669 ,
       -16.018301 , -16.10069  , -16.133362 , -16.11332  , -16.12715  ,
       -16.136425 , -16.149717 , -16.179405 , -16.213253 , -16.278404 ,
       -16.223282 , -16.167067 , -16.191193 , -16.166613 , -16.15434  ,
       -16.143814 , -16.174332 , -16.129467 , -16.167284 , -16.161982 ,
       -16.158895 , -16.194027 , -16.182142 , -16.185385 , -16.18294  ,
       -16.216427 , -16.19144  , -16.21519  , -16.238007 , -16.251562 ,
       -16.28459  , -16.289328 , -16.308838 , -16.315271 , -16.351673 ,
       -16.37933  , -16.404045 , -16.427315 , -16.460306 , -16.48463  ,
       -16.507921 , -16.519485 , -16.547253 , -16.561182 , -16.554445 ,
       -16.550827 , -16.542912 , -16.526627 , -16.506136 , -16.470331 ,
       -16.441257 , -16.391012 , -16.335787 , -16.259274 , -16.17529  ,
       -16.07635  , -15.958251 , -15.846211 , -15.7337055, -15.622306 ,
       -15.50075  , -15.3865385, -15.255363 , -15.127911 , -14.993031 ,
       -14.844004 , -14.70205  , -14.511631 , -14.306505 , -14.018784 ,
       -13.668093 , -13.236774 , -12.670759 , -11.872883 , -15.151148 ],      dtype=float32), 'ztfg': Array([-15.772129  , -18.320892  , -18.43288   , -18.484575  ,
       -18.491686  , -18.481766  , -18.448809  , -18.37889   ,
       -18.318335  , -18.253336  , -18.140322  , -18.069223  ,
       -17.979584  , -17.8856    , -17.803896  , -17.73556   ,
       -17.65985   , -17.587206  , -17.532799  , -17.470852  ,
       -17.417587  , -17.35754   , -17.312248  , -17.245367  ,
       -17.178871  , -17.097235  , -17.029188  , -16.949701  ,
       -16.87215   , -16.798468  , -16.706663  , -16.620708  ,
       -16.519032  , -16.418047  , -16.318588  , -16.179554  ,
       -16.03045   , -15.886807  , -15.721505  , -15.556217  ,
       -15.390172  , -15.185933  , -14.995248  , -14.82428   ,
       -14.649346  , -14.457754  , -14.300677  , -14.121938  ,
       -13.935627  , -13.767874  , -13.568681  , -13.37904   ,
       -13.163475  , -12.954523  , -12.732788  , -12.510538  ,
       -12.248693  , -11.973305  , -11.721714  , -11.415252  ,
       -11.141406  , -10.838574  , -10.499096  , -10.10349   ,
        -9.768442  ,  -9.408928  ,  -9.0772705 ,  -8.684952  ,
        -8.345127  ,  -7.8161416 ,  -7.4304767 ,  -7.0433264 ,
        -6.4263754 ,  -5.7191877 ,  -5.055436  ,  -4.3633537 ,
        -3.3309002 ,  -2.625084  ,  -1.546751  ,  -0.34356213,
         0.48571968,   0.60820293,  -0.10072327,  -0.66506577,
        -1.5962591 ,  -1.7468328 ,   1.1143837 ,   4.8076286 ,
        10.069817  ,  14.982475  ,  20.013168  ,  25.336075  ,
        31.681     ,  37.480804  ,  43.70526   ,  49.804474  ,
        56.62909   ,  62.9823    ,  69.82837   ,  77.08545   ],      dtype=float32), 'ztfr': Array([-15.960594  , -18.583803  , -18.66496   , -18.656425  ,
       -18.65863   , -18.636261  , -18.555864  , -18.489721  ,
       -18.40367   , -18.314716  , -18.210762  , -18.096031  ,
       -17.961716  , -17.852213  , -17.734444  , -17.621525  ,
       -17.547304  , -17.447742  , -17.378788  , -17.325811  ,
       -17.2939    , -17.26101   , -17.230722  , -17.215385  ,
       -17.18534   , -17.158182  , -17.118877  , -17.095705  ,
       -17.056248  , -17.016502  , -16.965748  , -16.919447  ,
       -16.878061  , -16.830587  , -16.771276  , -16.70102   ,
       -16.632004  , -16.548016  , -16.45801   , -16.366737  ,
       -16.269419  , -16.166422  , -16.056835  , -15.949722  ,
       -15.848029  , -15.7500515 , -15.66906   , -15.576261  ,
       -15.4730425 , -15.36188   , -15.244295  , -15.124579  ,
       -14.998223  , -14.871509  , -14.734875  , -14.608687  ,
       -14.483653  , -14.336303  , -14.181261  , -14.014381  ,
       -13.846426  , -13.6729965 , -13.495538  , -13.302461  ,
       -13.09746   , -12.889979  , -12.667173  , -12.433775  ,
       -12.161707  , -11.871835  , -11.553125  , -11.221348  ,
       -10.856398  , -10.473299  , -10.032507  ,  -9.610563  ,
        -9.16524   ,  -8.704119  ,  -8.284309  ,  -7.765768  ,
        -7.3549037 ,  -6.8569074 ,  -6.337792  ,  -5.845028  ,
        -5.610551  ,  -5.083511  ,  -4.6806498 ,  -4.6215205 ,
        -4.7558975 ,  -4.8308287 ,  -4.3931828 ,  -2.3722725 ,
         0.93536186,   4.086624  ,   7.582306  ,  11.76059   ,
        17.027634  ,  22.72963   ,  29.175262  ,  36.007904  ],      dtype=float32), 'ztfi': Array([-15.792493  , -18.324368  , -18.464746  , -18.45646   ,
       -18.50505   , -18.460644  , -18.424038  , -18.353561  ,
       -18.283194  , -18.164705  , -18.080393  , -18.00904   ,
       -17.913483  , -17.834442  , -17.747639  , -17.666817  ,
       -17.56364   , -17.488167  , -17.42873   , -17.402395  ,
       -17.381649  , -17.33349   , -17.300434  , -17.26037   ,
       -17.25128   , -17.242388  , -17.188198  , -17.15699   ,
       -17.127197  , -17.113838  , -17.082489  , -17.034948  ,
       -16.99881   , -16.973415  , -16.932024  , -16.877914  ,
       -16.832945  , -16.758835  , -16.707357  , -16.636919  ,
       -16.574066  , -16.517593  , -16.46005   , -16.402489  ,
       -16.335796  , -16.279882  , -16.22033   , -16.177853  ,
       -16.132893  , -16.068192  , -16.014868  , -15.948302  ,
       -15.869617  , -15.787788  , -15.700533  , -15.629569  ,
       -15.536847  , -15.43677   , -15.315778  , -15.189185  ,
       -15.068685  , -14.9490385 , -14.819545  , -14.677732  ,
       -14.527348  , -14.368735  , -14.2005625 , -14.018286  ,
       -13.823902  , -13.627251  , -13.41889   , -13.193704  ,
       -12.964486  , -12.707693  , -12.438104  , -12.159219  ,
       -11.852673  , -11.564217  , -11.246758  , -10.919043  ,
       -10.6069975 , -10.293084  ,  -9.9546585 ,  -9.637359  ,
        -9.324613  ,  -9.003272  ,  -8.642328  ,  -8.288932  ,
        -7.949544  ,  -7.9238396 ,  -8.01818   ,  -7.7933044 ,
        -6.6145267 ,  -4.8133717 ,  -2.4050732 ,  -0.21736717,
        -1.2484894 ,  -2.327942  ,  -3.1846237 ,  -4.092758  ],      dtype=float32)})

Get the TF output

In [None]:
# For this list, we compute the LCs using the flax model
tf_output = []
start = time.time()
for i in idx_list:
    # Compute the lightcurve
    _, _, mag = nmma.em.utils.calc_lc(t,
                                input_values[i], 
                                svd_mag_model = tensorflow_model.svd_model, 
                                interpolation_type="tensorflow", 
                                filters = filts, 
                                )
    # Convert this dictionary to values of the LCs
    mag = mag.values()
    mag = np.array(list(mag)).T
    tf_output.append(mag)
end = time.time()
print(f"Computing all the tensorflow lightcurves for a subset of {N} lightcurves took {end-start} seconds.")
# Make sure this is a np.ndarray
tf_output = np.array(tf_output)

Computing all the flax lightcurves for a subset of 100 lightcurves took 5.414904594421387 seconds.


## Compare MSE or MAE values

TODO: Best to compare this as a distribution, and perhaps best to consider MAE, or some self-defined loss function or error function?

In [None]:
def mse(y_true, y_pred, axis=None):
    return np.mean((y_true - y_pred)**2, axis=axis)

def se(y_true, y_pred):
    return (y_true - y_pred)**2

def mae(y_true, y_pred, axis=None):
    return np.mean(np.abs(y_true - y_pred), axis=axis)

def ae(y_true, y_pred):
    return np.abs(y_true - y_pred)

def my_format(low: float, med: float, high: float, nb: int = 3) -> str:
    med = np.round(med, nb)
    low = med - low
    low = np.round(low, nb)
    high = high - med
    high = np.round(high, nb)
    
    return f"{med} - {low} + {high}"

# # TODO with arviz summarize the errors
# def summarize_data(values: np.array, percentile: float = 0.95) -> None:
    
#     med = np.median(values)
#     result = arviz.hdi(values, hdi_prob = percentile)
    
#     print(my_format(low, med, high))
    
#     return

In [None]:
which_dataset = flax_output
which_error = mae

diffs = se(which_dataset, sampled_output_values)
axis = 0
mse_values = which_error(which_dataset, sampled_output_values, axis=axis)
mse_values = np.mean(mse_values, axis=0)
for f, val in zip(filts, mse_values):
    print(f"{f}: {val}")

bessellux: 1.860606712292868
bessellb: 1.7921464100782574
bessellv: 1.2305521863871158
bessellr: 0.8128799439449892
besselli: 0.7583155514392446
sdssu: 1.0992091068356826
ps1__g: 1.8798792498944232
ps1__r: 1.0871671955916962
ps1__i: 0.8739246164214045
ps1__z: 0.8943951113252402
ps1__y: 0.9373291937542888
uvot__b: 1.8021596713178516
uvot__u: 1.7665092726441651
uvot__uvm2: 1.871222932129877
uvot__uvw1: 1.6923179907486763
uvot__uvw2: 1.991923127090116
uvot__v: 1.5010415651636924
uvot__white: 0.8017472904174968
atlasc: 1.278656720359771
atlaso: 0.7374362913587105
2massj: 0.7798304492014924
2massh: 0.5952465996901724
2massks: 0.679392397537318
ztfg: 1.8178616188139933
ztfr: 1.0294186804296865
ztfi: 0.8224096511046391


In [None]:
print("MAE flax")
print(mae(flax_output, sampled_output_values))
print("MSE flax")
print(mse(flax_output, sampled_output_values))

MAE flax
1.2459069052297256
MSE flax
20.490649731915262


In [None]:
print("MAE tf")
print(mae(tf_output, sampled_output_values))
print("MSE tf")
print(mse(tf_output, sampled_output_values))

MAE tf
1.958473016121413
MSE tf
44.68970564439654


## Speed: can we improve the speed generation of flax?

In [None]:
import jax
import jax.numpy as jnp
import scipy.interpolate as interp

In [None]:
# def calc_lc_flax(
#     tt,
#     param_list, # TODO add type hint, but can break if jax not imported
#     svd_mag_model=None,
#     svd_lbol_model=None,
#     mag_ncoeff=None,
#     lbol_ncoeff=None,
#     filters=None,
#     use_jit=True, # TODO implement a jitted version?
# ):
#     mAB = {}

#     if filters is None:
#         filters = list(svd_mag_model.keys())
#     else:
#         # add null output for radio and X-ray filters
#         for filt in filters:
#             if filt.startswith(("radio", "X-ray")):
#                 mAB[filt] = jnp.inf * jnp.ones(len(tt))

#     for jj, filt in enumerate(filters):
#         if filt in mAB:
#             continue

#         if mag_ncoeff:
#             n_coeff = min(mag_ncoeff, svd_mag_model[filt]["n_coeff"])
#         else:
#             n_coeff = svd_mag_model[filt]["n_coeff"]
#         # param_array = svd_mag_model[filt]["param_array"]
#         # cAmat = svd_mag_model[filt]["cAmat"]
#         VA = svd_mag_model[filt]["VA"]
#         param_mins = svd_mag_model[filt]["param_mins"]
#         param_maxs = svd_mag_model[filt]["param_maxs"]
#         mins = svd_mag_model[filt]["mins"]
#         maxs = svd_mag_model[filt]["maxs"]
#         tt_interp = svd_mag_model[filt]["tt"]

#         param_list_postprocess = (param_list - param_mins) / (param_maxs - param_mins)

#         # Watch out for possible confusion with names here
#         state = svd_mag_model[filt]["model"]

#         # Apply the model to the given parameters
#         cAproj = state.apply_fn({'params': state.params}, param_list_postprocess)
#         cAstd = jnp.ones((n_coeff,))

#         # Go from SVD coefficients to original lightcurve data
#         mag_back = jnp.dot(VA[:, :n_coeff], cAproj)
#         mag_back = mag_back * (maxs - mins) + mins

#         ## TODO how to implement this in jax?
#         # ii = jnp.where((~jnp.isnan(mag_back)) * (tt_interp < 20.0))[0]
#         # if len(ii) < 2:
#         #     maginterp = jnp.nan * jnp.ones(tt.shape)
#         # else:
#         #     f = interp.interp1d(tt_interp[ii], mag_back[ii], fill_value="extrapolate")
#         #     maginterp = f(tt)

#         # f = interp.interp1d(tt, mag_back, fill_value="extrapolate")
#         # maginterp = f(tt)

#         maginterp = mag_back
#         mAB[filt] = maginterp

#     # TODO: currently not used, what if we want to use it?
#     if svd_lbol_model is not None:
#         if lbol_ncoeff:
#             n_coeff = min(lbol_ncoeff, svd_lbol_model["n_coeff"])
#         else:
#             n_coeff = svd_lbol_model["n_coeff"]
#         # param_array = svd_lbol_model["param_array"]
#         # cAmat = svd_lbol_model["cAmat"]
#         VA = svd_lbol_model["VA"]
#         param_mins = svd_lbol_model["param_mins"]
#         param_maxs = svd_lbol_model["param_maxs"]
#         mins = svd_lbol_model["mins"]
#         maxs = svd_lbol_model["maxs"]
#         gps = svd_lbol_model["gps"]
#         tt_interp = svd_lbol_model["tt"]

#         param_list_postprocess = param_list
#         for i in range(len(param_mins)):
#             param_list_postprocess.at[i].set((param_list_postprocess[i] - param_mins[i]) / (
#                 param_maxs[i] - param_mins[i]
#             ))

#         # TODO add this for flax?
#         # ...
#         # ...

#         lbol_back = jnp.dot(VA[:, :n_coeff], cAproj)
#         lbol_back = lbol_back * (maxs - mins) + mins
#         # lbol_back = scipy.signal.medfilt(lbol_back, kernel_size=3)

#         ii = jnp.where(~jnp.isnan(lbol_back))[0]
#         if len(ii) < 2:
#             lbolinterp = jnp.nan * jnp.ones(tt.shape)
#         else:
#             f = interp.interp1d(tt_interp[ii], lbol_back[ii], fill_value="extrapolate")
#             lbolinterp = 10 ** f(tt)
#         lbol = lbolinterp
#     else:
#         lbol = np.inf * np.ones(len(tt))

#     return np.squeeze(tt), np.squeeze(lbol), mAB

### 1. Jit, no vmap

In [None]:
# Lambda function, so that we focus on the parameters as being the input only
from nmma.em.utils import get_calc_lc_jit
calc_lc_given_params_jit = get_calc_lc_jit(t, svd_mag_model=training_model.svd_model, filters=filts)
start = time.time()
for i in range(N):
    flax_output_jit = calc_lc_given_params_jit(sampled_input_values[i])
end = time.time()
print(f"Computing all the flax lightcurves for a subset of {N} lightcurves took {end-start} seconds.")

Install afterglowpy if you want to simulate afterglows.
Install wrapt_timeout_decorator if you want timeout simulations.
Computing all the flax lightcurves for a subset of 100 lightcurves took 1.4477312564849854 seconds.


### 2. Jit and vmap

In [None]:
# vmap the function
calc_lc_given_params_vmap = jax.vmap(calc_lc_given_params_jit)
# apply to input_values
flax_output_jit = calc_lc_given_params_vmap(sampled_input_values)