In [None]:
from jax import random as jran

# Fixed seed so that Restart & Run All reproduces every Monte Carlo draw below.
SEED = 0
ran_key = jran.key(SEED)

## Generate Monte Carlo subhalo catalog and SFHs

In [None]:
from diffsky.mc_diffsky import mc_diffstar_galpop, mc_diffstar_cenpop

# JAX PRNG discipline: split off a dedicated subkey for this Monte Carlo draw
# and leave `ran_key` unconsumed so later cells can split it again without
# reusing randomness that already went into the catalogs.
halo_key, ran_key = jran.split(ran_key, 2)

lgmp_min = 11.0  # presumably log10 of the minimum (peak) halo mass — confirm
redshift = 0.05
Lbox = 100.0  # box side length; NOTE(review): units presumably comoving Mpc (or Mpc/h) — confirm
volume_com = Lbox**3  # volume of the simulated box, passed to the generators
args = (halo_key, redshift, lgmp_min, volume_com)
# 'gals' population and 'cens' population (labels used in the plots below).
# NOTE(review): both generators receive the same key — confirm this is intended.
diffstar_data = mc_diffstar_galpop(*args, return_internal_quantities=True)
diffstar_cens = mc_diffstar_cenpop(*args, return_internal_quantities=True)

In [None]:
# Inspect the structure of both Monte Carlo outputs (top-level keys and the
# fields of each subhalo catalog), then list the MAH parameter names.
for mc_output in (diffstar_data, diffstar_cens):
    print(mc_output.keys())
    print(mc_output["subcat"]._fields)
print(diffstar_data["subcat"].mah_params._fields)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Histogram of log10 M_p0 for each population, with a log-scaled y-axis.
# Bins start at the lgmp_min passed to the generators, so the plot follows
# that setting; values outside [lgmp_min, 15] are not shown.
logmp0_bins = np.linspace(lgmp_min, 15, 40)
for subcat, l in zip([diffstar_data["subcat"], diffstar_cens["subcat"]],
                     ['gals', 'cens']):
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.set_yscale('log')
    ax.hist(subcat.logmp0, bins=logmp0_bins)
    # logmp0 is a log quantity, so label it as such.
    ax.set_xlabel(r'$\log_{10} M_{\rm p0}$')
    ax.set_ylabel('N')
    fig.suptitle(l)
    fig.savefig('MC_logmp0_{}.png'.format(l))

In [None]:
# Plot the distribution of each input MAH parameter (one panel per field of
# subcat.mah_params) for both populations.
for subcat, l in zip([diffstar_data["subcat"], diffstar_cens["subcat"]],
                     ['gals', 'cens']):
    mah_names = subcat.mah_params._fields
    # squeeze=False keeps ax_all a 2-D array even with a single parameter,
    # so .flat is always available.
    fig, ax_all = plt.subplots(1, len(mah_names), figsize=(18, 3), squeeze=False)
    for ax, par, xlbl in zip(ax_all.flat, subcat.mah_params, mah_names):
        ax.hist(par, bins=50)
        ax.set_xlabel(xlbl)
    fig.suptitle(l)
    fig.savefig('MC_mah_params_{}.png'.format(l))

In [None]:
from diffmah import mah_halopop
import matplotlib.cm as cm

In [None]:
# Overlay a subsample (every `step`-th halo) of halo mass accretion histories
# for each population, on log-log axes.
step = 100
for subcat, l, t_table in zip([diffstar_data["subcat"], diffstar_cens["subcat"]],
                              ['gals', 'cens'],
                              [diffstar_data["t_table"], diffstar_cens["t_table"]]):
    # NOTE(review): the last entry of t_table is used as t0; confirm t_table
    # ends at the present-day age rather than at the catalog `redshift`.
    logt0 = np.log10(t_table[-1])
    dmhdt, log_mah = mah_halopop(subcat.mah_params, t_table, logt0)
    # Colors run red -> blue along catalog order (not sorted by mass).
    colors = cm.coolwarm(np.linspace(1, 0, len(subcat.logmp0)))
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.loglog()
    for logmah, c in zip(log_mah[::step], colors[::step]):
        ax.plot(t_table, 10**logmah, color=c)
    ax.set_xlabel('$t$')
    ax.set_ylabel(r'$M_{\rm halo}(t)$')
    fig.suptitle(l)
    fig.savefig('MC_MAH_{}.png'.format(l))

In [None]:
# Infall properties for each population: logmp_ult_inf - logmhost_ult_inf
# (a log mass ratio), the host mass, and the time, all at "ult_inf".
for subcat, lbl in zip([diffstar_data["subcat"], diffstar_cens["subcat"]],
                       ['gals', 'cens']):
    labels = ('logmu_infall', 'logmhost_ult_inf', 't_ult_inf')
    logmu_infall = subcat.logmp_ult_inf - subcat.logmhost_ult_inf
    qs = [logmu_infall, subcat.logmhost_ult_inf, subcat.t_ult_inf]
    n_panels = len(qs)
    fig, ax_all = plt.subplots(1, n_panels, figsize=(15, 4))
    for idx in range(n_panels):
        panel = ax_all.flat[idx]
        __ = panel.hist(qs[idx], bins=50)
        panel.set_xlabel(labels[idx])
    fig.suptitle(lbl)
    plt.savefig('MC_logmu_logmhost_t_ult_{}.png'.format(lbl))

## Check Monte Carlo star formation histories

In [None]:
# Distributions of frac_q and of every field of sfh_params.ms_params and
# sfh_params.q_params, for each population, laid out on a 2-row grid.
for frac_q, sfh_params, lbl in zip([diffstar_data["frac_q"], diffstar_cens["frac_q"]],
                                   [diffstar_data["sfh_params"], diffstar_cens["sfh_params"]],
                                   ['gals', 'cens']):
    ms_names = list(sfh_params.ms_params._fields)
    q_names = list(sfh_params.q_params._fields)
    q_ms = [getattr(sfh_params.ms_params, f) for f in ms_names]
    q_q = [getattr(sfh_params.q_params, f) for f in q_names]
    labels = ['frac_q'] + ms_names + q_names
    qs = [frac_q] + q_ms + q_q
    print(lbl, len(frac_q), np.min(frac_q), np.max(frac_q))

    # Round up: int(len(qs)/2) made the grid too small for an odd number of
    # quantities, and zip() then silently dropped the last panel.
    ncol = (len(qs) + 1) // 2
    fig, ax_all = plt.subplots(2, ncol, figsize=(5*ncol, 10), squeeze=False)
    for ax, q, l in zip(ax_all.flat, qs, labels):
        qmin, qmax = np.min(q), np.max(q)
        print(l, qmin, qmax)
        # A constant array cannot use linspace edges; let hist choose 50 bins.
        bins = np.linspace(qmin, qmax, num=50) if qmin < qmax else 50
        ax.hist(q, bins=bins)
        ax.set_xlabel(l)
    # Hide the spare panel left over when len(qs) is odd.
    for ax in ax_all.ravel()[len(qs):]:
        ax.set_visible(False)
    fig.suptitle(lbl)
    fig.savefig('MC_sfh_params_{}.png'.format(lbl))

In [None]:
# Shapes of the SFH arrays for both populations, printed on one line.
print(*(mc_output["sfh"].shape for mc_output in (diffstar_data, diffstar_cens)))


In [None]:
# Overlay a subsample (every `step`-th galaxy) of SFR and stellar-mass
# histories for each population; colors run red -> blue along catalog order.
step = 300
for sfh, smh, t_table, l in zip([diffstar_data["sfh"], diffstar_cens["sfh"]],
                               [diffstar_data["smh"], diffstar_cens["smh"]],
                               [diffstar_data["t_table"], diffstar_cens["t_table"]],
                               ['gals', 'cens']):
    fig, ax_all = plt.subplots(1, 2, figsize=(9, 4))
    panels = zip(ax_all.flat, [sfh, smh], ['SFR', 'M*'])
    for ax, histories, ylabel in panels:
        palette = cm.coolwarm(np.linspace(1, 0, len(histories)))
        for history, color in zip(histories[::step], palette[::step]):
            __ = ax.plot(t_table, history, color=color)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('$t$')
    fig.suptitle(l)
    plt.savefig('SFH_{}_step_{}.png'.format(l, step))