In [1]:
import matplotlib.cm as cm
from time import time
from matplotlib import lines as mlines
import numpy as np
from jax import random as jran

# Fixed-seed JAX PRNG key so any downstream random draws are reproducible.
ran_key = jran.PRNGKey(seed=0)

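### Call the target data generator to compute $\langle\log_{10}c(t)\vert M_0\rangle$, $\sigma\left(\log_{10}c(t)\vert M_0\right)$, and the corresponding $p_{50\%}$-conditioned statistics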

In [2]:
import os
from astropy.table import Table
# Directory holding the diffprof halo catalogs. Set the DIFFPROF_DRN
# environment variable to point elsewhere instead of editing this
# machine-specific default path.
diffprof_drn = os.environ.get("DIFFPROF_DRN", "/Users/aphearin/work/DATA/diffprof_data")
mdpl2 = Table.read(os.path.join(diffprof_drn, "MDPL2_halo_table.hdf5"))
bpl = Table.read(os.path.join(diffprof_drn, "BPL_halo_table.hdf5"))
print(bpl.keys())

['halo_id', 'conc_beta_early', 'conc_beta_late', 'conc_k', 'conc_lgtc', 'u_conc_beta_early', 'u_conc_beta_late', 'u_conc_k', 'u_conc_lgtc', 'logmp', 'mah_early', 'mah_late', 'mah_logtc', 'mah_k', 'log_mah_fit', 'conch_fit', 'tform_50', 'p_tform_50']


In [3]:
from diffprof.nfw_evolution import lgc_vs_lgt
from jax import jit as jjit, vmap as jvmap

# Vectorize lgc_vs_lgt over halos: the log-time array (first argument) is
# shared by all halos, while the four concentration-history parameters are
# mapped along their leading axis.
lgc_vs_lgt_vmap = jjit(jvmap(lgc_vs_lgt, in_axes=(None, 0, 0, 0, 0)))

N_T = 25
TARR_FIT = np.linspace(2, 13.8, N_T)  # time grid; presumably Gyr — TODO confirm units


def _lgconc_history(halos, tarr=TARR_FIT):
    """Evaluate log10 c(t) on ``tarr`` for every halo in ``halos``.

    ``halos`` must provide the ``conc_lgtc``, ``conc_k``, ``conc_beta_early``
    and ``conc_beta_late`` columns. The result is converted to a NumPy array;
    presumably shaped (n_halos, len(tarr)) — TODO confirm.
    """
    return np.array(lgc_vs_lgt_vmap(
        np.log10(tarr),
        halos["conc_lgtc"],
        halos["conc_k"],
        halos["conc_beta_early"],
        halos["conc_beta_late"],
    ))


lgconc_history_bpl = _lgconc_history(bpl)
lgconc_history_mdpl2 = _lgconc_history(mdpl2)


In [4]:
from diffprof.latin_hypercube import get_scipy_kdtree, retrieve_lh_sample_indices
# Mass windows (in the units of the `logmp` column) from which each
# simulation is sampled. The trailing positional arguments (1, 100_000) are
# passed through to retrieve_lh_sample_indices unchanged — TODO document them.
BPL_LGM_RANGE = (11.35, 13.65)
MDPL2_LGM_RANGE = (13.4, 14.6)

tree_bpl, tree_mdpl2 = (get_scipy_kdtree(sim['logmp']) for sim in (bpl, mdpl2))

indx_bpl = retrieve_lh_sample_indices(tree_bpl, *BPL_LGM_RANGE, 1, 100_000)
indx_mdpl2 = retrieve_lh_sample_indices(tree_mdpl2, *MDPL2_LGM_RANGE, 1, 100_000)

In [5]:
from diffprof.get_target_simdata import target_data_generator

N_MH_TARGETS, N_P_TARGETS = 1, 15

# Per-quantity pairs (BPL first, then MDPL2) restricted to the sampled halos.
_logmp_pair = (bpl['logmp'][indx_bpl], mdpl2['logmp'][indx_mdpl2])
_lgconc_pair = (lgconc_history_bpl[indx_bpl], lgconc_history_mdpl2[indx_mdpl2])
_p50_pair = (bpl['p_tform_50'][indx_bpl], mdpl2['p_tform_50'][indx_mdpl2])

args = (*_logmp_pair, *_lgconc_pair, *_p50_pair, N_MH_TARGETS, N_P_TARGETS)

# Host-mass window is passed via the lgmh_min / lgmh_max keywords.
gen = target_data_generator(*args, lgmh_min=14.0, lgmh_max=14.25)

In [6]:
target_data = next(gen)

# The generator yields six items: the mass and p50 targets, then the
# mean/std of log10 c(t) at fixed mass, then the p50-conditioned mean/std.
(
    lgmhalo_targets, p50_targets,
    target_avg_log_conc_lgm0, target_std_log_conc_lgm0,
    target_avg_log_conc_p50_lgm0, target_std_log_conc_p50_lgm0,
) = target_data

target_std_log_conc_p50_lgm0.shape

(1, 15, 25)

In [7]:
# tarr = np.linspace(1, 13.8, 200)
# p50_arr = np.array((0.1, 0.5, 0.9))
# lgm0 = 14.0
# zz = np.zeros_like(tarr)

### Call the target data model to compute $\langle\log_{10}c(t)\vert M_0,p_{50\%}\rangle$

In [8]:
from diffprof import target_data_model as tdm

# mean_lgc_old_tdm = tdm.approximate_lgconc_vs_lgm_p50(
#     tarr, lgm0, p50_arr[0], *tdm.target_data_model_params_mean_lgconc.values())
# mean_lgc_mid_tdm = tdm.approximate_lgconc_vs_lgm_p50(
#     tarr, lgm0, p50_arr[1], *tdm.target_data_model_params_mean_lgconc.values())
# mean_lgc_young_tdm = tdm.approximate_lgconc_vs_lgm_p50(
#     tarr, lgm0, p50_arr[2], *tdm.target_data_model_params_mean_lgconc.values())

# Index of the halo-mass target to evaluate (N_MH_TARGETS is 1 above).
im = 0
_tdm_params = np.array([*tdm.target_data_model_params_mean_lgconc.values()])

# One model curve of <log10 c(t) | M0, p50> on TARR_FIT per p50 target.
x = [
    tdm.approximate_lgconc_vs_lgm_p50(TARR_FIT, lgmhalo_targets[im], p50, *_tdm_params)
    for p50 in p50_targets
]
_s = (p50_targets.shape[0], TARR_FIT.shape[0])
target_avg_log_conc_p50 = np.reshape(np.concatenate(x), _s)

# assert np.allclose(target_avg_log_conc_p50[0, :], mean_lgc_old_tdm)
# assert np.allclose(target_avg_log_conc_p50[1, :], mean_lgc_mid_tdm)
# assert np.allclose(target_avg_log_conc_p50[2, :], mean_lgc_young_tdm)

### Call the target data model to compute $\sigma\left(\log_{10}c(t)\vert M_0\right)$ and $\sigma\left(\log_{10}c(t)\vert M_0,p_{50\%}\right)$

In [9]:
# Model sigma(log10 c(t) | M0) on the TARR_FIT grid at mass target `im`.
target_std_lgc_lgm = tdm.approx_std_lgconc_vs_lgm(
    TARR_FIT,
    lgmhalo_targets[im],
    *tdm.target_data_model_params_std_lgconc.values(),
)

# Model scatter conditioned on p50 at the same mass target, one entry per p50
# target (the per-entry shape depends on _scatter_vs_p50_and_lgmhalo — TODO confirm).
_pars = [*tdm.target_data_model_params_std_lgconc_p50.values()]
target_std_log_conc_p50 = np.array([
    tdm._scatter_vs_p50_and_lgmhalo(lgmhalo_targets[im], p50, *_pars)
    for p50 in p50_targets
])

In [10]:
from diffprof.dpp_loss_funcs import _mse_loss_singlemass
from diffprof.bpl_dpp import DEFAULT_PARAMS