# Study the correlation between the SSP Parameters

Study the correlations between the SSP parameters using the `fors2tostellarpopsynthesis` package.

- Author : Sylvie Dagoret-Campagne
- Affiliation : IJCLab/IN2P3/CNRS
- Organisation : LSST-DESC
- creation date : 2023-12-05
- last update : 2023-12-05


| computer | processor | kernel              |    date     |
| --- | --- | --- | --- |
| CC       | CPU       | conda_jax0325_py310 | 2023-11-10  |
| macbookpro | CPU | conda_jaxcpu_dsps_py310 | 2023-11-10  | 


libraries 
=========

jax
---

- jaxlib-0.3.25+cuda11.cudnn82
- jaxopt
- optax
- corner
- arviz
- numpyro
- graphviz

sps
---

- fsps
- prospect
- dsps
- diffstar
- diffmah



Package versions reported by `pip list` (filtered) in the `conda_jax0325_py310` environment, run from
`/pbs/throng/lsst/users/dagoret/desc/StellarPopulationSynthesis`:

| lib | version |
|--- | --- | 
|jax  |                         0.4.20 |
|jaxlib |                       0.4.20 |
|jaxopt  |                      0.8.2 |



## examples

- jaxcosmo : https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb
- on atmosphere : https://github.com/sylvielsstfr/FitDiffAtmo/blob/main/docs/notebooks/fitdiffatmo/test_numpyro_orderedict_diffatmemul_5params_P_pwv_oz_tau_beta.ipynb

## Import

### import external packages

In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import re
import pickle 
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.colors as colors
import matplotlib.cm as cmx
import collections
from collections import OrderedDict
import re
import matplotlib.gridspec as gridspec
from sklearn.gaussian_process import GaussianProcessRegressor, kernels

In [None]:
import jax
import jax.numpy as jnp
from jax import vmap
import jaxopt
import optax
jax.config.update("jax_enable_x64", True)
from interpax import interp1d

from jax.lax import fori_loop
from jax.lax import select,cond
from jax.lax import concatenate

In [None]:
import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition


import corner
import arviz as az

### import internal packages

In [None]:
from fors2tostellarpopsynthesis.parameters  import SSPParametersFit,paramslist_to_dict

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_jaxopt import (SSP_DATA,mean_spectrum,mean_mags,mean_sfr,ssp_spectrum_fromparam)

In [None]:
from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(PARAM_SIMLAW_NODUST,PARAM_SIMLAW_WITHDUST,
                            PARAM_NAMES,PARAM_VAL,PARAM_MIN,PARAM_MAX,PARAM_SIGMA)

from fors2tostellarpopsynthesis.fitters.fitter_numpyro import(galaxymodel_nodust_av,galaxymodel_nodust,galaxymodel_withdust_av,galaxymodel_withdust)

## Configuration

### matplotlib configuration

In [None]:
# Notebook-wide plotting defaults: larger figures and enlarged fonts for
# labels, titles, tick labels and legends.
plt.rcParams.update({
    "figure.figsize": (12, 6),
    "axes.labelsize": "xx-large",
    "axes.titlesize": "xx-large",
    "xtick.labelsize": "xx-large",
    "ytick.labelsize": "xx-large",
    "legend.fontsize": 16,
    "font.size": 15,
})

## Fit parameters

In [None]:
# Container for the SSP fit parameters. Its DICT_PARAMS_true is used below as
# the fiducial parameter set, and its INIT_PARAMS array is edited in the next cell.
p = SSPParametersFit()

In [None]:
# Overwrite the last four initial parameter values with (0, 0, 0, 1).
# JAX arrays are immutable, so each `.at[...].set(...)` returns a new array
# that is re-bound to p.INIT_PARAMS.
# NOTE(review): which physical parameters occupy these last four slots is not
# visible here -- confirm against SSPParametersFit.
for _idx, _val in zip((-4, -3, -2, -1), (0., 0., 0., 1.)):
    p.INIT_PARAMS = p.INIT_PARAMS.at[_idx].set(_val)

In [None]:
# Spectrum at the true (fiducial) parameters. The second argument is 0
# (presumably the redshift, i.e. rest frame -- confirm). wlsall is reused
# below as the wavelength grid for simulation, plotting and fitting.
_ssp_outputs = ssp_spectrum_fromparam(p.DICT_PARAMS_true, 0)
wlsall, spec_rest, spec_rest_att = _ssp_outputs

In [None]:
# Show the simulation-law tables and the per-parameter settings imported from
# fitter_numpyro, one per line, in the original order.
for _table in (
    PARAM_SIMLAW_NODUST,
    PARAM_SIMLAW_WITHDUST,
    PARAM_NAMES,
    PARAM_VAL,
    PARAM_MIN,
    PARAM_MAX,
    PARAM_SIGMA,
):
    print(_table)

In [None]:
# Observation settings for simulating and fitting the spectra:
#   z_obs        -- passed to the models as `z_obs` (presumably the observed redshift)
#   sigmarel_obs -- not used in the cells shown here; TODO(review) confirm or remove
#   sigma_obs    -- passed to the models as `sigma`; also the error-bar size in the plot
z_obs, sigmarel_obs, sigma_obs = 0.5, 0.1, 1e-11

## Bayesian modelling

In [None]:
# Boolean mask selecting the parameters whose no-dust simulation law is "fixed".
# The comparison is done in NumPy because JAX arrays cannot hold strings.
# np.asarray makes it element-wise even if PARAM_SIMLAW_NODUST is a plain
# list/tuple, where `list == "fixed"` would silently collapse to a scalar False.
# jnp.where(mask, True, False) was just the mask itself, so it is dropped.
condlist_fix = jnp.asarray(np.asarray(PARAM_SIMLAW_NODUST) == "fixed")
condlist_fix

In [None]:
# Boolean mask selecting the parameters whose no-dust simulation law is "uniform".
# The comparison is done in NumPy because JAX arrays cannot hold strings.
# np.asarray makes it element-wise even if PARAM_SIMLAW_NODUST is a plain
# list/tuple, where `list == "uniform"` would silently collapse to a scalar False.
# jnp.where(mask, True, False) was just the mask itself, so it is dropped.
condlist_uniform = jnp.asarray(np.asarray(PARAM_SIMLAW_NODUST) == "uniform")
condlist_uniform

In [None]:
# Positional arguments used only to trace and draw the model graph. The two
# one-element arrays are placeholders -- presumably for the wavelength grid and
# the observed flux (the wlsin/Fobs keywords used with mcmc.run); confirm.
_render_args = (
    jnp.array([0.]),
    jnp.array([1.]),
    PARAM_VAL,
    PARAM_MIN,
    PARAM_MAX,
    PARAM_SIGMA,
    PARAM_NAMES,
    z_obs,
    sigma_obs,
)
# Kept as the last expression so the notebook displays the rendered graph.
numpyro.render_model(galaxymodel_nodust, model_args=_render_args, render_distributions=True)

In [None]:
# Positional arguments used only to trace and draw the model graph. The two
# one-element arrays are placeholders -- presumably for the wavelength grid and
# the observed flux (the wlsin/Fobs keywords used with mcmc.run); confirm.
_render_args = (
    jnp.array([0.]),
    jnp.array([1.]),
    PARAM_VAL,
    PARAM_MIN,
    PARAM_MAX,
    PARAM_SIGMA,
    PARAM_NAMES,
    z_obs,
    sigma_obs,
)
# Kept as the last expression so the notebook displays the rendered graph.
numpyro.render_model(galaxymodel_withdust, model_args=_render_args, render_distributions=True)

In [None]:
# Deliberate stop: presumably to keep "Run All" from executing the simulation
# and sampling cells below; run those cells manually.
# An explicit raise is used because a bare `assert False` is removed when
# Python runs with optimizations (-O), which would let execution continue.
raise AssertionError("Notebook execution stopped intentionally; run the cells below manually.")

In [None]:

# Generate synthetic data at the fiducial parameters. Conditioning the model on
# the true parameter values means the priors are not sampled for those sites
# (an unconditioned model would draw them from the priors).
fiducial_model = condition(galaxymodel_nodust, p.DICT_PARAMS_true)

# Keyword arguments forwarded to the model.
_model_kwargs = dict(
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=z_obs,
    sigma=sigma_obs,
)

# A fixed seed keeps the generated data reproducible; the trace records every site.
_seeded_model = seed(fiducial_model, jax.random.PRNGKey(42))
trace_data_nodust = trace(_seeded_model).get_trace(wlsall, **_model_kwargs)

In [None]:
# Simulated no-dust spectrum: the value recorded at the model's "F" site in the trace.
spec_nodust = trace_data_nodust['F']["value"]

In [None]:
# Generate synthetic data at the fiducial parameters, this time with the dust
# model. Conditioning on the true parameter values means the priors are not
# sampled for those sites (an unconditioned model would draw them from the priors).
fiducial_model = condition(galaxymodel_withdust, p.DICT_PARAMS_true)

# Keyword arguments forwarded to the model.
_model_kwargs = dict(
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=z_obs,
    sigma=sigma_obs,
)

# Same fixed seed as the no-dust simulation, so the runs are reproducible.
_seeded_model = seed(fiducial_model, jax.random.PRNGKey(42))
trace_data_withdust = trace(_seeded_model).get_trace(wlsall, **_model_kwargs)

In [None]:
# Simulated with-dust spectrum: the value recorded at the model's "F" site in the trace.
spec_withdust = trace_data_withdust['F']["value"]

In [None]:
# Compare the simulated spectra without (black) and with (red) dust on log-log
# axes. The error bars use the constant noise level sigma_obs.
fig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.errorbar(wlsall, spec_nodust, yerr=sigma_obs, fmt='o', ms=0.5, linewidth=2, capsize=0, c='k', label="no dust")
ax.errorbar(wlsall, spec_withdust, yerr=sigma_obs, fmt='o', ms=0.5, linewidth=2, capsize=0, c='r', label="with dust")
# Raw string: "\l" in a plain literal is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning since Python 3.12).
# NOTE(review): DSPS wavelength grids are usually in Angstrom rather than nm --
# confirm the unit of wlsall before trusting this label.
ax.set_xlabel(r"$\lambda$ (nm)")
ax.set_ylabel("DSPS spectrum")
ax.legend()
ax.set_yscale('log')
ax.set_ylim(1e-11, 1e-5)
ax.set_xlim(1e2, 1e6)
ax.set_xscale('log')
ax.grid();


In [None]:
# Fit the dust-free model to the simulated no-dust spectrum with NUTS.
rng_key = jax.random.PRNGKey(42)
# Split into 4 keys; only the first (re-bound to rng_key) is used in this cell.
# The split count must stay at 4: changing it would change rng_key and hence
# the sampled chains.
rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(rng_key, 4)

num_samples = 5_000
n_chains = 4

kernel = NUTS(
    galaxymodel_nodust,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=numpyro.infer.init_to_median(),
)
# All chains run in a single vectorized computation.
mcmc = MCMC(
    kernel,
    num_warmup=1_000,
    num_samples=num_samples,
    num_chains=n_chains,
    chain_method='vectorized',
    progress_bar=True,
)
mcmc.run(
    rng_key,
    wlsin=wlsall,
    Fobs=spec_nodust,
    minparamval=PARAM_MIN,
    maxparamval=PARAM_MAX,
    sigmaparamval=PARAM_SIGMA,
    paramnames=PARAM_NAMES,
    z_obs=z_obs,
    sigma=sigma_obs,
    # Also collect the potential energy at each sample.
    extra_fields=('potential_energy',),
)
mcmc.print_summary()
# Posterior samples with the chains flattened together.
samples_nuts = mcmc.get_samples()

In [None]:
# Relative effective sample size (ESS / number of draws) for each sampled site.
# The samples are requested grouped by chain because ArviZ reads array axes as
# (chain, draw, *shape). The flattened `samples_nuts` (mcmc.get_samples())
# would be read as a single chain for scalar sites, and would have its draw
# axis mistaken for the chain axis for vector-valued sites.
az.ess(mcmc.get_samples(group_by_chain=True), relative=True)

In [None]:
# Convert the NUTS run into an ArviZ InferenceData object (kept in `data`),
# then draw the compact trace plot: one panel per variable, all chains overlaid.
data = az.from_numpyro(posterior=mcmc)
az.plot_trace(data=data, compact=True);
