In [1]:
%matplotlib inline

In [2]:
#__ = plt.style.use("./diffstar.mplstyle")
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.optimize import curve_fit
# Named hex colours for figures (these match matplotlib's default "tab" palette).
mred, morange, mgreen, mblue, mpurple = (
    "#d62728",
    "#ff7f0e",
    "#2ca02c",
    "#1f77b4",
    "#9467bd",
)

# Global figure typography: serif font at 22 pt, all text rendered through LaTeX.
plt.rc("font", family="serif", size=22)
plt.rc("text", usetex=True)
# amsmath is loaded in the LaTeX preamble because labels use \dfrac.
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

In [3]:
from jax import numpy as jnp
from jax import jit as jjit
from jax import vmap
from jax import grad
from jax import random as jran
import os
import h5py
# Short labels for eight parameters. The leading "u" presumably marks the
# unbounded (unconstrained) fit parameters -- TODO confirm. This list is not
# referenced by any of the cells visible here.
names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]
from functools import partial


In [4]:
from diffstar.utils import _get_dt_array, jax_np_interp
from diffstar.fit_smah_helpers import get_header, load_diffstar_data
from diffstar.stars import (
    calculate_sm_sfr_fstar_history_from_mah,
    compute_fstar,
    compute_fstar_vmap,
    _sfr_eff_plaw,
    _get_bounded_sfr_params, 
    _get_unbounded_sfr_params,
    _get_unbounded_sfr_params_vmap,
    _get_bounded_sfr_params_vmap,
    calculate_histories,
    calculate_histories_vmap,
    calculate_histories_batch,
    fstar_tools,
)
from diffstar.constants import TODAY

from diffstar.quenching import _get_bounded_q_params, _get_unbounded_q_params, _get_bounded_q_params_vmap, _get_unbounded_q_params_vmap
from diffmah.individual_halo_assembly import _calc_halo_history

from chainconsumer import ChainConsumer

def calculate_fstar_data_batch(tarr, sm_cumsum, index_select, index_high, fstar_tdelay, chunk_size=5000):
    """Evaluate ``compute_fstar_vmap`` over many galaxies in memory-bounded chunks.

    Parameters
    ----------
    tarr : array
        Time grid forwarded unchanged to ``compute_fstar_vmap``.
    sm_cumsum : array, leading axis of length ``ng``
        Per-galaxy cumulative stellar-mass histories; rows are processed in chunks.
    index_select, index_high : arrays
        Index tables forwarded unchanged to ``compute_fstar_vmap``. The number
        of output columns is ``len(index_high)``.
    fstar_tdelay : float
        Forwarded unchanged to ``compute_fstar_vmap``.
    chunk_size : int, optional
        Nominal number of galaxies per call. The number of chunks is
        ``max(1, int(ng / chunk_size))``, which reproduces the original
        chunking whenever ``ng >= chunk_size``.

    Returns
    -------
    fstar : ndarray of shape (ng, len(index_high))
    """
    ng = len(sm_cumsum)
    nt = len(index_high)
    fstar = np.zeros((ng, nt))
    if ng == 0:
        # Nothing to evaluate; avoid calling the vmapped kernel on an empty batch.
        return fstar

    # int(ng / chunk_size) is 0 for small samples and np.array_split rejects
    # zero sections, so always request at least one chunk.
    n_chunks = max(1, int(ng / chunk_size))
    for inds in np.array_split(np.arange(ng), n_chunks):
        fstar[inds] = compute_fstar_vmap(tarr, sm_cumsum[inds], index_select, index_high, fstar_tdelay)
    return fstar
    
def _calculate_sm(lgt, dt, mah_params, sfr_params, q_params, index_select, index_high, fstar_tdelay):
    """Compute the halo and galaxy histories of a single object.

    The halo history is evaluated first with ``_calc_halo_history`` and its two
    outputs are fed to ``calculate_sm_sfr_fstar_history_from_mah`` together with
    the star-formation and quenching parameters.

    Returns
    -------
    tuple
        ``(mstar, sfr, fstar, dmhdt, log_mah)``.
    """
    halo_history = _calc_halo_history(lgt, *mah_params)
    dmhdt, log_mah = halo_history

    galaxy_history = calculate_sm_sfr_fstar_history_from_mah(
        lgt,
        dt,
        dmhdt,
        log_mah,
        sfr_params,
        q_params,
        index_select,
        index_high,
        fstar_tdelay,
    )
    mstar, sfr, fstar = galaxy_history

    return (mstar, sfr, fstar, dmhdt, log_mah)



# Jitted, batched version of _calculate_sm: mah_params, sfr_params and q_params
# are mapped over their leading axis (in_axes=0); every other argument is
# shared by all objects (in_axes=None).
calculate_sm = jjit(vmap(_calculate_sm, in_axes=(None, None, 0, 0, 0, None, None, None)))

# Column names parsed from the diffstar header string: the first character is
# dropped (presumably a leading comment marker -- TODO confirm) and the rest is
# split on whitespace. Entries 1-5 are used as the star-formation columns and
# entries 6-9 as the quenching columns.
colnames = get_header()[1:].strip().split()
sfr_colnames = colnames[1:6]
q_colnames = colnames[6:10]


In [5]:
from diffstarpop.utils import get_t50_p50
# Jitted, batched _calc_halo_history: the first argument is shared by all halos
# (in_axes=None) while each of the six following arguments is mapped over its
# leading axis.
_calc_halo_history_vmap = jjit(vmap(_calc_halo_history, in_axes=(None, *[0]*6)))

# Directory holding the SMDPL fit files. The location can be overridden with the
# DIFFMAH_DATA_DIR environment variable so the notebook is not tied to a single
# machine; the original absolute path remains the default.
path = os.environ.get("DIFFMAH_DATA_DIR", "/Users/alarcon/Documents/diffmah_data/SMDPL/")
# printf-style template for the per-chunk diffstar fit files (%i is the file index).
runname = "run1_SMDPL_diffstar_default_%i.h5"

def get_mah_params(runname, data_path=path):
    """Read one fit file and return its halo-assembly parameters as a 2-D array.

    Every dataset of ``data_path/runname`` is loaded; all keys except
    ``"halo_id"`` are stored under a ``"fit_"`` prefix.

    Returns
    -------
    ndarray
        One row per entry, with columns in this order: ``log10(t0)``,
        ``logmp_fit``, ``mah_logtc``, ``mah_k``, ``early_index``, ``late_index``.
    """
    with h5py.File(os.path.join(data_path, runname), "r") as hdf:
        fitting_data = {
            (key if key == "halo_id" else "fit_" + key): hdf[key][...]
            for key in hdf.keys()
        }

    # t0 is stored linearly in the file and converted to log10 here; the
    # remaining columns are copied as they are.
    plain_columns = (
        "fit_logmp_fit",
        "fit_mah_logtc",
        "fit_mah_k",
        "fit_early_index",
        "fit_late_index",
    )
    stacked = [np.log10(fitting_data["fit_t0"])]
    stacked.extend(fitting_data[column] for column in plain_columns)

    return np.array(stacked).T

def get_sfh_params(run_name, data_path=path):
    """Load the diffstar fit parameters of one file, unbounded and bounded.

    Parameters
    ----------
    run_name : str
        File name of the diffstar fit file.
    data_path : str, optional
        Directory containing ``run_name``. Defaults to the module-level
        ``path``, matching ``get_mah_params``.

    Returns
    -------
    u_fit_params : ndarray
        Unbounded parameters as stored in the file, one row per entry. The
        star-formation columns (header entries 1-5) come first, followed by
        the quenching columns (header entries 6-9).
    fit_params : ndarray
        The same parameters after ``_get_bounded_sfr_params_vmap`` and
        ``_get_bounded_q_params_vmap``, in the same column order.
    """
    header_cols = get_header()[1:].strip().split()
    sfr_cols = header_cols[1:6]
    q_cols = header_cols[6:10]
    wanted = sfr_cols + q_cols

    # Only the columns that are actually used are read from disk.
    fn = os.path.join(data_path, run_name)
    with h5py.File(fn, "r") as hdf:
        sfr_fitdata = {key: hdf[key][...] for key in wanted}

    u_fit_params = np.array([sfr_fitdata[key] for key in wanted]).T

    # Split the combined table instead of rebuilding each half from the dict.
    n_sfr = len(sfr_cols)
    u_sfr_fit_params = u_fit_params[:, :n_sfr]
    u_q_fit_params = u_fit_params[:, n_sfr:]

    sfr_fit_params = _get_bounded_sfr_params_vmap(*u_sfr_fit_params.T)
    q_fit_params = _get_bounded_q_params_vmap(*u_q_fit_params.T)

    # The vmapped helpers return one array per parameter; stack them as columns.
    sfr_fit_params = np.array([np.array(x) for x in sfr_fit_params]).T
    q_fit_params = np.array([np.array(x) for x in q_fit_params]).T

    fit_params = np.concatenate((sfr_fit_params, q_fit_params), axis=1)

    return u_fit_params, fit_params


def calculate_histories(n_files=576, seed=None):
    """Assemble a mass-dependent random subsample of the SMDPL fit parameters.

    NOTE: this definition shadows the ``calculate_histories`` imported from
    ``diffstar.stars`` earlier in the notebook. The name is kept so existing
    calls keep working.

    For each file index the halo-assembly and star-formation fit files are
    read with ``get_mah_params`` and ``get_sfh_params``. Halos are then
    down-sampled on column 1 of the assembly parameters (``logmp_fit``):
    halos above ``mh_trans`` are always kept, lower ones are kept with
    probability ``10**(0.82 * (logmp - mh_trans))``. The halo histories of the
    kept objects are evaluated on a 100-point time grid from 0.1 to ``TODAY``
    and passed to ``get_t50_p50``.

    Parameters
    ----------
    n_files : int, optional
        Number of file indices ``0 .. n_files-1`` to read (default 576).
    seed : int or None, optional
        Seed for the down-sampling. ``None`` (default) draws from numpy's
        global random state as before; an integer makes the selection
        reproducible.

    Returns
    -------
    mah_params_arr, u_fit_params_arr, fit_params_arr : ndarray
        Concatenated rows of the kept halos, as returned by ``get_mah_params``
        and ``get_sfh_params``.
    p50 : array
        Second element of the ``get_t50_p50`` result for the kept halos.

    Raises
    ------
    ValueError
        If ``n_files`` is smaller than 1.
    """
    if n_files < 1:
        raise ValueError("n_files must be at least 1")

    tarr = np.linspace(0.1, TODAY, 100)
    lgt = np.log10(tarr)

    # Both objects expose ``uniform(size=...)``.
    rng = np.random if seed is None else np.random.default_rng(seed)

    mh_trans = 12.0  # log-mass above which every halo is kept
    slope = 0.82  # steepness of the keep-probability below mh_trans

    log_mah = []
    mah_params_arr = []
    u_fit_params_arr = []
    fit_params_arr = []
    for i in range(n_files):
        mah_params = get_mah_params("run1_SMDPL_diffmah_default_%d.h5" % i)
        u_fit_params, fit_params = get_sfh_params("run1_SMDPL_diffstar_default_%i.h5" % i)

        mpeak = mah_params[:, 1]
        weight = np.where(mpeak > mh_trans, 1.0, 10 ** (slope * (mpeak - mh_trans)))
        mask = rng.uniform(size=len(mpeak)) < weight

        mah_params = mah_params[mask]
        u_fit_params = u_fit_params[mask]
        fit_params = fit_params[mask]

        mah_params_arr.append(mah_params)
        u_fit_params_arr.append(u_fit_params)
        fit_params_arr.append(fit_params)
        print(i, len(mah_params))

        # Only the second output (log_mah) of the halo history is needed.
        log_mah.append(_calc_halo_history_vmap(lgt, *mah_params.T)[1])

    log_mah = np.concatenate(log_mah)
    p50 = get_t50_p50(tarr, 10**log_mah, 0.5, log_mah[:, -1], window_length=101)[1]
    mah_params_arr = np.concatenate(mah_params_arr)
    u_fit_params_arr = np.concatenate(u_fit_params_arr)
    fit_params_arr = np.concatenate(fit_params_arr)
    return mah_params_arr, u_fit_params_arr, fit_params_arr, p50

# Build the down-sampled parameter tables; these four arrays are the inputs of
# the calculate_SMDPL_sumstats call further down. This reads every fit file and
# prints one "index count" line per file.
mah_params_arr, u_fit_params_arr, fit_params_arr, p50_arr = calculate_histories()

0 3251
1 2980
2 3147
3 3060
4 3164
5 3209
6 3091
7 3043
8 3347
9 3334
10 3190
11 3296
12 3205
13 3261
14 3280
15 3201
16 3468
17 3534
18 3741
19 3570
20 3498
21 3672
22 3353
23 3476
24 3029
25 3150
26 3320
27 4195
28 2267
29 3114
30 3215
31 3114
32 3167
33 3056
34 3274
35 3259
36 3111
37 3095
38 3163
39 3222
40 3549
41 3553
42 3385
43 3443
44 3446
45 3498
46 3547
47 3426
48 3181
49 3258
50 3189
51 3043
52 3188
53 3156
54 3272
55 3297
56 3088
57 3199
58 3210
59 3189
60 3216
61 3047
62 3233
63 3154
64 3462
65 3583
66 3313
67 3429
68 3635
69 3441
70 3724
71 3452
72 3163
73 3972
74 2556
75 3142
76 3074
77 3279
78 3170
79 3001
80 3207
81 3232
82 3213
83 3203
84 3132
85 3130
86 3157
87 3176
88 3396
89 3611
90 3485
91 3456
92 3516
93 3494
94 3450
95 3587
96 3359
97 3123
98 3213
99 3176
100 3180
101 3070
102 3289
103 3063
104 3273
105 3357
106 3194
107 3355
108 3216
109 3174
110 3202
111 3400
112 3657
113 3627
114 3474
115 3562
116 3556
117 3601
118 3578
119 3507
120 3164
121 3312
122 3188
123

In [13]:
from diffstarpop.monte_carlo_diff_halo_population import (
    sm_sfr_history_diffstar_scan_XsfhXmah_vmap,
    sm_sfr_history_diffstar_scan_MS_XsfhXmah_vmap,
    _jax_get_dt_array

)
def calculate_SMDPL_sumstats(
    t_table,
    logm0_binmids,
    logm0_bin_widths,
    mah_params,
    fit_params,
    p50,
    outfile="/Users/alarcon/Documents/diffmah_data/SMDPL_sfh_sumstats.npy",
):
    """Compute star-formation-history summary statistics in halo-mass bins.

    Halos are selected in each bin by column 1 of ``mah_params``, using the
    open interval ``binmid - width < logmpeak < binmid + width``. Their
    histories are generated with
    ``sm_sfr_history_diffstar_scan_XsfhXmah_vmap`` and reduced with
    ``calculate_sumstats_bin``.

    Parameters
    ----------
    t_table : array
        Time grid on which histories are evaluated.
    logm0_binmids, logm0_bin_widths : sequences of equal length
        Bin centres and half-widths.
    mah_params : ndarray
        Halo-assembly parameters; columns 1, 2, 4 and 5 are passed on.
    fit_params : ndarray
        Bounded diffstar parameters; columns 0, 1, 2, 4 are passed as the
        star-formation set and columns 5-8 as the quenching set.
    p50 : array
        Per-halo value forwarded to ``calculate_sumstats_bin``.
    outfile : str or None, optional
        Where the result is written with ``np.save``. Defaults to the original
        hard-coded location; pass ``None`` to skip writing.

    Returns
    -------
    list of ndarray
        One array per statistic returned by ``calculate_sumstats_bin``, each
        stacked over the mass bins along its first axis. The list is empty
        when no bins are given.
    """
    logmpeak = mah_params[:, 1]

    lgt = np.log10(t_table)

    fstar_tdelay = 1.0
    index_select, index_high = fstar_tools(t_table, fstar_tdelay=fstar_tdelay)
    dt = _jax_get_dt_array(t_table)

    stats = []
    for i in range(len(logm0_binmids)):
        lo = logm0_binmids[i] - logm0_bin_widths[i]
        hi = logm0_binmids[i] + logm0_bin_widths[i]

        print("Calculating m0=[%.2f, %.2f]"%(lo, hi))
        sel = (logmpeak > lo) & (logmpeak < hi)
        print("Nhalos:", sel.sum())

        # Boolean selection followed by a column list already yields fresh
        # arrays, so no explicit copy is required.
        mah_bin = mah_params[sel]
        fit_bin = fit_params[sel]
        mstar_histories, sfr_histories, _ = sm_sfr_history_diffstar_scan_XsfhXmah_vmap(
            t_table,
            lgt,
            dt,
            mah_bin[:, [1, 2, 4, 5]],
            fit_bin[:, [0, 1, 2, 4]],
            fit_bin[:, [5, 6, 7, 8]],
            index_select,
            index_high,
            fstar_tdelay,
        )

        # Weight 1 marks star-forming entries (sSFR above the threshold) and 0
        # marks quenched ones; calculate_sumstats_bin takes this as weights_MS.
        ssfr = sfr_histories / mstar_histories
        weights_ms = jnp.where(ssfr > 1e-11, 1.0, 0.0)

        stats.append(
            calculate_sumstats_bin(mstar_histories, sfr_histories, p50[sel], weights_ms)
        )

    print("Reshaping results")

    # Transpose from "per bin, per statistic" to "per statistic, per bin".
    # zip(*[]) is empty, so an empty bin list yields an empty result rather
    # than failing on an undefined loop variable.
    new_stats = [np.array(per_stat) for per_stat in zip(*stats)]

    if outfile is not None:
        np.save(outfile, new_stats)
    return new_stats



def calculate_sumstats_bin(
    mstar_histories, sfr_histories, p50, weights_MS
):
    """Reduce a set of histories to per-time-step summary statistics.

    All means and variances are taken over axis 0 of the log10 histories
    (non-positive values map to 0.0 instead of being logged).

    Parameters
    ----------
    mstar_histories, sfr_histories : arrays with the same 2-D shape
        Linear stellar-mass and star-formation-rate histories; axis 0 is
        reduced over.
    p50 : 1-D array, one value per row of the histories
        Rows with ``p50 < 0.5`` form the "early" sample, the rest the "late"
        sample.
    weights_MS : array with the same shape as the histories
        Weight of each entry in the star-forming sample; the quenched weight
        is ``1 - weights_MS``.

    Returns
    -------
    tuple of 13 arrays, in this order
        mean_sm, variance_sm, mean_sfr_MS, mean_sfr_Q, variance_sfr_MS,
        variance_sfr_Q, quench_frac, mean_sm_early, mean_sm_late,
        variance_sm_early, variance_sm_late, quench_frac_early,
        quench_frac_late.
    """
    weights_MS = jnp.asarray(weights_MS)

    # Decide whether a sample is populated BEFORE clipping. After clipping the
    # quenched fraction can never be exactly 0 (and reaches exactly 1 only
    # through float rounding), so testing it would not detect empty samples.
    has_MS = jnp.sum(weights_MS, axis=0) > 0.0
    has_Q = jnp.sum(1.0 - weights_MS, axis=0) > 0.0

    # Clip weights. When all weights in a time
    # step are 0, Nans will occur in gradients.
    eps = 1e-10
    weights_Q = jnp.clip(1.0 - weights_MS, eps, None)
    weights_MS = jnp.clip(weights_MS, eps, None)

    weights_early = jnp.where(p50 < 0.5, 1.0, 0.0)
    weights_late = 1.0 - weights_early
    weights_early = jnp.clip(weights_early, eps, None)
    weights_late = jnp.clip(weights_late, eps, None)

    mstar_histories = jnp.where(mstar_histories > 0.0, jnp.log10(mstar_histories), 0.0)
    sfr_histories = jnp.where(sfr_histories > 0.0, jnp.log10(sfr_histories), 0.0)

    mean_sm = jnp.average(mstar_histories, axis=0)
    mean_sfr_MS = jnp.average(sfr_histories, weights=weights_MS, axis=0)
    mean_sfr_Q = jnp.average(sfr_histories, weights=weights_Q, axis=0)

    mean_sm_early = jnp.average(mstar_histories, weights=weights_early, axis=0)
    mean_sm_late = jnp.average(mstar_histories, weights=weights_late, axis=0)

    variance_sm = jnp.average(
        (mstar_histories - mean_sm[None, :]) ** 2, axis=0,
    )

    variance_sfr_MS = jnp.average(
        (sfr_histories - mean_sfr_MS[None, :]) ** 2, weights=weights_MS, axis=0,
    )
    variance_sfr_Q = jnp.average(
        (sfr_histories - mean_sfr_Q[None, :]) ** 2, weights=weights_Q, axis=0,
    )
    # Each sub-sample variance is taken about that sub-sample's own mean, as is
    # done for the MS / Q variances above (not about the global mean_sm).
    variance_sm_early = jnp.average(
        (mstar_histories - mean_sm_early[None, :]) ** 2, weights=weights_early, axis=0,
    )
    variance_sm_late = jnp.average(
        (mstar_histories - mean_sm_late[None, :]) ** 2, weights=weights_late, axis=0,
    )

    NHALO_MS = jnp.sum(weights_MS, axis=0)
    NHALO_Q = jnp.sum(weights_Q, axis=0)
    quench_frac = NHALO_Q / (NHALO_Q + NHALO_MS)

    # Statistics of a sample that is empty at a given time step are set to 0.
    mean_sfr_Q = jnp.where(has_Q, mean_sfr_Q, 0.0)
    variance_sfr_Q = jnp.where(has_Q, variance_sfr_Q, 0.0)
    mean_sfr_MS = jnp.where(has_MS, mean_sfr_MS, 0.0)
    variance_sfr_MS = jnp.where(has_MS, variance_sfr_MS, 0.0)

    NHALO_MS_early = jnp.sum(weights_MS * weights_early[:, None], axis=0)
    NHALO_Q_early = jnp.sum(weights_Q * weights_early[:, None], axis=0)
    quench_frac_early = NHALO_Q_early / (NHALO_Q_early + NHALO_MS_early)

    NHALO_MS_late = jnp.sum(weights_MS * weights_late[:, None], axis=0)
    NHALO_Q_late = jnp.sum(weights_Q * weights_late[:, None], axis=0)
    quench_frac_late = NHALO_Q_late / (NHALO_Q_late + NHALO_MS_late)

    _out = (
        mean_sm,
        variance_sm,
        mean_sfr_MS,
        mean_sfr_Q,
        variance_sfr_MS,
        variance_sfr_Q,
        quench_frac,
        mean_sm_early,
        mean_sm_late,
        variance_sm_early,
        variance_sm_late,
        quench_frac_early,
        quench_frac_late,
    )
    return _out

In [14]:
# Evaluation grid: 20 time steps from 1.0 to TODAY, and five mass bins centred
# on 11.5 ... 13.5 with a half-width of 0.1 each.
t_table = np.linspace(1.0, TODAY, num=20)
logm0_binmids = np.linspace(11.5, 13.5, num=5)
logm0_bin_widths = np.full_like(logm0_binmids, 0.1)

# Summary statistics of the SMDPL sample (the call also writes them to disk).
MC_res_target = calculate_SMDPL_sumstats(
    t_table,
    logm0_binmids,
    logm0_bin_widths,
    mah_params=mah_params_arr,
    fit_params=fit_params_arr,
    p50=p50_arr,
)

Calculating m0=[11.40, 11.60]
Nhalos: 252210
Calculating m0=[11.90, 12.10]
Nhalos: 221323
Calculating m0=[12.40, 12.60]
Nhalos: 81419
Calculating m0=[12.90, 13.10]
Nhalos: 27465
Calculating m0=[13.40, 13.60]
Nhalos: 8594
Reshaping results
