In [None]:
import os
from diffaux.disk_bulge_modeling.generate_bulge_disk_sample import (
    get_bulge_disk_test_sample,
    get_bulge_disk_decomposition,
    get_zindexes,
)

from diffaux.disk_bulge_modeling.disk_bulge_kernels import (
    calc_tform_pop,
    _bulge_sfh_vmap,
    _sigmoid,
)
from jax import random as jran
import numpy as np
import jax.numpy as jnp
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from itertools import zip_longest

from diffaux.validation.plot_disk_bulge import (
    plot_qs_profiles_for_zvals,
    plot_qs_profiles,
    plot_q_profile,
)
from diffaux.validation.plot_utilities import get_nrow_ncol

# Root PRNG key for the notebook; every random key used below is obtained by
# jran.split-ing this key (seed 0 makes the whole notebook reproducible).
ran_key = jran.key(0)

In [None]:
# Draw the bulge/disk test sample with a sub-key split off the root key.
lgmp_min, redshift, Lbox = 11.0, 0.05, 75.0
halo_key, ran_key = jran.split(ran_key)
diffstar_cens = get_bulge_disk_test_sample(
    halo_key,
    lgmp_min=lgmp_min,
    redshift=redshift,
    Lbox=Lbox,
)

## Bulge/disk decomposition (logsm0 bug-fixed version)

In [None]:
# Run the bulge/disk decomposition (new_model=False), then print the keys of the
# resulting dict and the shapes of the three fbulge parameter arrays.
disk_bulge_key, ran_key = jran.split(ran_key)
diffstar_cens = get_bulge_disk_decomposition(disk_bulge_key, diffstar_cens, new_model=False)
print(diffstar_cens.keys())
for fbulge_name in ("tcrit_bulge", "fbulge_early", "fbulge_late"):
    print(fbulge_name, diffstar_cens[fbulge_name].shape)

In [None]:
# Map the target redshifts 0.0, 0.5, ..., 2.5 onto indices of the sample's z_table;
# zs holds the table redshifts returned by get_zindexes (presumably nearest matches — confirm).
redshifts = diffstar_cens["z_table"]
zvalues = [0.5 * step for step in range(6)]
zindexes, zs = get_zindexes(zvalues, redshifts)

## Test `generate_fbulge_params`

In [None]:
from dsps.constants import SFR_MIN
from diffstar.utils import cumulative_mstar_formed_galpop

# Stellar-mass bin edges (log10 M*/Msun) shared by the binned profiles below.
logM_min, logM_max, Nm = 7.0, 11.5, 18
Mbins = np.linspace(logM_min, logM_max, Nm + 1)

# Sample selection on the last column of smh / sSFR (labelled z=0 in the title).
logMz0_min = 8.0
logssfr_min = -16
logsm0_cens = np.log10(diffstar_cens["smh"][:, -1])  # z=0 log stellar mass, no SFR floor
mass_mask = logsm0_cens > logMz0_min
sfr_mask = np.log10(diffstar_cens["sSFR"][:, -1]) > logssfr_min
mask = mass_mask & sfr_mask
title = "$M^*_{{z=0}} > 10^{{{:.1f}}} M_\\odot, sSFR > 10^{{{}}} yr^{{-1}}$".format(logMz0_min, logssfr_min)
xname = "log_M0_min_{:.1f}_logssfr_min_{}".format(logMz0_min, -logssfr_min)

# Recompute t10 / t90 exactly as in the library code: floor the SFH at SFR_MIN,
# rebuild the cumulative stellar-mass history, then evaluate calc_tform_pop at the
# 0.1 and 0.9 fractions.
tarr = diffstar_cens["t_table"]
smh_pop_cens = diffstar_cens["smh"]
sfh_raw = diffstar_cens["sfh"]
sfh_pop = np.where(sfh_raw < SFR_MIN, SFR_MIN, sfh_raw)
smh_pop = cumulative_mstar_formed_galpop(tarr, sfh_pop)
t10, t90 = (calc_tform_pop(tarr, smh_pop, frac) for frac in (0.1, 0.9))
# z=0 log stellar mass after applying the SFR floor
logsm0 = np.log10(smh_pop[:, -1])

In [None]:
from diffaux.disk_bulge_modeling.mc_disk_bulge import generate_fbulge_params
from diffaux.disk_bulge_modeling.mc_disk_bulge import mc_disk_bulge

# Generate fbulge parameters in four ways to isolate the effect of the PRNG key and
# of the SFR floor on logsm0.
# logsm0 without SFR floor, new key
fbulge_key_cens, ran_key = jran.split(ran_key, 2)
fbulge_params_new_noSFR = generate_fbulge_params(fbulge_key_cens, t10, t90, logsm0_cens)
# logsm0 with SFR floor, new key
fbulge_key, ran_key = jran.split(ran_key, 2)
fbulge_params_new_SFR = generate_fbulge_params(fbulge_key, t10, t90, logsm0)
# replicate exact call in get_bulge_disk_decomposition (re-uses disk_bulge_key)
fbulge_params_same_SFR = generate_fbulge_params(disk_bulge_key, t10, t90, logsm0)
fbulge_params_same_noSFR = generate_fbulge_params(disk_bulge_key, t10, t90, logsm0_cens)

# Repeat the full mc_disk_bulge call with a new key.
mc_key, ran_key = jran.split(ran_key, 2)
tarr = diffstar_cens["t_table"]
sfh = diffstar_cens["sfh"]
fbulge_params_mc, smh_mc, eff_bulge_mc, sfh_bulge_mc, smh_bulge_mc, bth_mc = mc_disk_bulge(mc_key, tarr, sfh)
print(fbulge_params_mc.shape)
print(eff_bulge_mc.shape)
print(diffstar_cens["eff_bulge"].shape)

# Check agreement with the stored decomposition by comparing the arrays directly
# (relative tolerance). The previous ratio form isclose(a / b, b / b) produced nan/nan
# wherever b == 0, and nan never compares close, so exact zeros counted as mismatches.
# NOTE(review): mc_key differs from disk_bulge_key, so key-dependent quantities
# (presumably bth / eff_bulge) are not expected to agree — confirm.
for qname, q_mc in (("smh", smh_mc), ("bth", bth_mc), ("eff_bulge", eff_bulge_mc)):
    print(qname, jnp.allclose(diffstar_cens[qname], q_mc))

In [None]:
# compare values of fbulge
# Output root for all figures; set DISKBULGE_PLOTDIR to override the author's local path.
plotdir = os.environ.get("DISKBULGE_PLOTDIR", "/Users/kovacs/cosmology/BulgeDisk/DiskBulgePlots")

# One (tcrit, early, late) tuple per fbulge variant, restricted to the masked sample.
fa = tuple(diffstar_cens[name][mask] for name in ("tcrit_bulge", "fbulge_early", "fbulge_late"))
fb = tuple(fbulge_params_new_noSFR[:, i][mask] for i in range(3))
fc = tuple(fbulge_params_new_SFR[:, i][mask] for i in range(3))
fd = tuple(fbulge_params_mc[:, i][mask] for i in range(3))
fe = tuple(fbulge_params_same_SFR[:, i][mask] for i in range(3))
ff = tuple(fbulge_params_same_noSFR[:, i][mask] for i in range(3))

# x values: SFR-floored logsm0 (xa) or un-floored logsm0_cens (xb), matching how
# each variant was generated.
xa = [logsm0[mask] for _ in range(len(fa))]
xb = [logsm0_cens[mask] for _ in range(len(fb))]
# A tuple, not a generator expression: a generator is exhausted after one pass, so a
# second iteration by the plotting code would silently see no labels.
lxs = ("logsm0",) * len(fa)
lys = ("fbulge_tcrit", "fbulge_early", "fbulge_late")
xba = [xa, xb, xa, xa, xa, xb]
fba = [fa, fb, fc, fd, fe, ff]
binz = [Mbins for _ in range(len(fba))]
labels = [
    "mc_original",
    "fbulge_new_key_noSFRcut",
    "fbulge_new_key_SFRcut",
    "mc_new_key",
    "fbulge_same_key_SFRcut",
    "fbulge_same_key_noSFRcut",
]
plot_qs_profiles(
    xba,
    fba,
    labels,
    lxs,
    lys,
    binz,
    plotdir=plotdir,
    title=title,
    xname=xname,
    pltname="check_fbulge_{}.png",
    plotsubdir="Fbulge",
)

In [None]:
def plot_comparison_profiles_at_zvalues(
    xlist,
    qlist,
    binz,
    labels,
    zindexes,
    zs,
    xlabel,
    ylabel,
    plotdir=plotdir,
    title=title,
    xname=xname,
    error=True,
    pltname="check_eff_bulge_{}.png",
    plotsubdir="Fbulge",
):
    """Overlay binned profiles of several quantities, one panel per selected redshift.

    For every series ``(x, q, bins, label)`` the columns ``zindexes`` of ``x`` and ``q``
    are extracted and passed, with the figure's axes, to ``plot_q_profile``. Each panel's
    legend is then titled with the matching value from ``zs``. The figure is saved to
    ``<plotdir>/<plotsubdir>/<pltname.format(xname)>``.

    Parameters
    ----------
    xlist, qlist : sequences of 2-D arrays
        x values and quantities per series, indexed as ``arr[:, zidx]``
        (presumably galaxies x redshift-table entries — confirm).
    binz : sequence
        Bins per series, forwarded to ``plot_q_profile``.
    labels : sequence of str
        Legend label per series.
    zindexes : sequence of int
        Column indices to plot (presumably one per panel).
    zs : sequence of float
        Redshifts used as legend titles; ``len(zs)`` sets the panel grid.
    xlabel, ylabel : str
        Axis labels applied to every labelled panel.
    plotdir, title, xname : str
        Output base directory, figure suptitle and filename tag. NOTE: the defaults
        are the notebook globals captured when this cell ran, not at call time.
    error : bool
        Forwarded to ``plot_q_profile``.
    pltname : str
        Filename template with one ``{}`` that receives ``xname``.
    plotsubdir : str
        Subdirectory of ``plotdir``; created if it does not exist.
    """
    outdir = os.path.join(plotdir, plotsubdir)
    # savefig fails on a missing directory; create it so a fresh checkout works.
    os.makedirs(outdir, exist_ok=True)
    nrow, ncol = get_nrow_ncol(len(zs))
    fig, ax_all = plt.subplots(nrow, ncol, figsize=(5 * ncol, 4 * nrow))

    for x, q, bins, label in zip(xlist, qlist, binz, labels):
        xvals = [x[:, zidx] for zidx in zindexes]
        yvals = [q[:, zidx] for zidx in zindexes]
        plot_q_profile(ax_all, xvals, yvals, bins, label, error=error)

    for ax, z in zip(ax_all.flat, zs):
        ax.legend(loc="best", title="$z = {:.2f}$".format(z))
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    fn = os.path.join(outdir, pltname.format(xname))
    fig.suptitle(title, y=0.97)
    # Save this figure explicitly rather than whatever pyplot considers "current".
    fig.savefig(fn)
    print(f"Saving {fn}")

In [None]:
# Check agreement of other bulge quantities (eff_bulge, B/T) computed in different ways.
# Each variant feeds a set of fbulge parameters to _bulge_sfh_vmap(tarr, sfh_pop, params),
# whose output is unpacked as (smh, eff_bulge, sfh_bulge, smh_bulge, bth).

# fbulge parameters drawn above with a new key and the SFR-floored logsm0
smh_fnew_SFR, eff_bulge_fnew_SFR, sfh_bulge_fnew_SFR, smh_bulge_fnew_SFR, bth_fnew_SFR = _bulge_sfh_vmap(
    tarr, sfh_pop, fbulge_params_new_SFR
)

# fbulge parameters regenerated with the same key get_bulge_disk_decomposition used
# (the "_fnew_mc" suffix is historical)
smh_fnew_mc, eff_bulge_fnew_mc, sfh_bulge_fnew_mc, smh_bulge_fnew_mc, bth_fnew_mc = _bulge_sfh_vmap(
    tarr, sfh_pop, fbulge_params_same_SFR
)

# fbulge parameters stored in the diffstar_cens dict, stacked column-wise
fbulge_params_same_cens = jnp.asarray(
    [diffstar_cens[name] for name in ("tcrit_bulge", "fbulge_early", "fbulge_late")]
).T
smh_same_cens, eff_bulge_same_cens, sfh_bulge_same_cens, smh_bulge_same_cens, bth_same_cens = _bulge_sfh_vmap(
    tarr, sfh_pop, fbulge_params_same_cens
)

labels = ["mc_new_key", "mc_original", "fbulge_new_SFR", "fbulge_vmap_same_SFR", "fbulge_cens_vmap"]
qlist1 = [
    eff
    [mask]
    for eff in (
        eff_bulge_mc,
        diffstar_cens["eff_bulge"],
        eff_bulge_fnew_SFR,
        eff_bulge_fnew_mc,
        eff_bulge_same_cens,
    )
]
qlist2 = [
    bth[mask]
    for bth in (bth_mc, diffstar_cens["bth"], bth_fnew_SFR, bth_fnew_mc, bth_same_cens)
]
mh = np.log10(diffstar_cens["smh"])[mask]
xlist = [mh] * len(qlist1)
binz = [Mbins] * len(qlist1)
ylabel1 = "$\\epsilon_{bulge}$"
ylabel2 = "B/T"
xlabel = "$\\log_{10}(M^*/M_\\odot)$"

In [None]:
# One figure per quantity: bulge efficiency, then B/T.
for qs, ylab, fname in (
    (qlist1, ylabel1, "check_eff_bulge_{}.png"),
    (qlist2, ylabel2, "check_bth_{}.png"),
):
    plot_comparison_profiles_at_zvalues(
        xlist, qs, binz, labels, zindexes, zs, xlabel, ylab, pltname=fname
    )

In [None]:
# reproduce generate_fbulge_params from scratch
from dsps.constants import SFR_MIN
from diffaux.disk_bulge_modeling.disk_bulge_kernels import (
    calc_tform_pop,
    _get_params_from_u_params,
    _get_params_from_u_params_vmap,
)


def generate_uparams_test(
    ran_key,
    t10,
    t90,
    logsm0,
    mu_u_tcrit=2,
    delta_mu_u_tcrit=3,
    mu_u_early=5,
    delta_mu_u_early=0.1,
    mu_u_late=5,
    delta_mu_u_late=3,
    scale_u_early=10,
    scale_u_late=8,
    scale_u_tcrit=20,
    logsm_table=(8, 11.5),
    verbose=True,
):
    """Draw unbounded fbulge u-parameters from scratch (test reimplementation).

    For each of u_tcrit, u_early and u_late:
    - The population mean is linearly interpolated in ``logsm0`` between the two table
      values ``mu - delta_mu * scale`` and ``mu + delta_mu * scale``, placed at
      ``logsm_table``. ``np.interp`` clamps outside that range.
    - Gaussian scatter of width ``scale`` is added.

    The u_late table is reversed (``mu + ...`` at the low-mass end). With the default
    positive deltas, its mean therefore decreases with mass while the other two increase.

    The model constants, formerly hard-coded, are keyword arguments whose defaults
    reproduce the original values.

    Parameters
    ----------
    ran_key : jax PRNG key
        Split into three sub-keys, one per u-parameter.
    t10 : array
        Only ``t10.size`` is used, as the number of draws per parameter.
    t90 : array
        Unused; kept so the signature parallels ``generate_fbulge_params``.
    logsm0 : array
        log10 stellar mass at which the population means are interpolated.
    verbose : bool, default True
        Print the three u tables (the original always printed them).

    Returns
    -------
    u_params : tuple of 3 arrays
        Monte Carlo draws ``(mc_u_tcrit, mc_u_early, mc_u_late)``; shape
        ``(t10.size,)`` when ``logsm0`` has the same length.
    u_params_pop : tuple of 3 arrays
        Population means ``(mu_u_tcrit_pop, mu_u_early_pop, mu_u_late_pop)``.
    """
    n = t10.size
    tcrit_key, early_key, late_key = jran.split(ran_key, 3)

    def _draw(key, u_table, scale):
        # Mean interpolated in logsm0 plus N(0, scale) scatter.
        if verbose:
            print(u_table)
        mu_pop = np.interp(logsm0, logsm_table, u_table)
        mc = jran.normal(key, shape=(n,)) * scale + mu_pop
        return mc, mu_pop

    mc_u_tcrit, mu_u_tcrit_pop = _draw(
        tcrit_key,
        [
            mu_u_tcrit - delta_mu_u_tcrit * scale_u_tcrit,
            mu_u_tcrit + delta_mu_u_tcrit * scale_u_tcrit,
        ],
        scale_u_tcrit,
    )
    mc_u_early, mu_u_early_pop = _draw(
        early_key,
        [
            mu_u_early - delta_mu_u_early * scale_u_early,
            mu_u_early + delta_mu_u_early * scale_u_early,
        ],
        scale_u_early,
    )
    # Reversed table: the late-phase mean runs opposite to the other two.
    mc_u_late, mu_u_late_pop = _draw(
        late_key,
        [
            mu_u_late + delta_mu_u_late * scale_u_late,
            mu_u_late - delta_mu_u_late * scale_u_late,
        ],
        scale_u_late,
    )

    u_params = mc_u_tcrit, mc_u_early, mc_u_late
    u_params_pop = mu_u_tcrit_pop, mu_u_early_pop, mu_u_late_pop
    return u_params, u_params_pop

In [None]:
# Fresh sub-key for the from-scratch u-parameter generator; report the number of
# u-parameters and the number of draws per parameter.
test_key, ran_key = jran.split(ran_key)
u_params_test, u_params_test_pop = generate_uparams_test(test_key, t10, t90, logsm0)
n_uparams, n_draws = len(u_params_test), len(u_params_test[0])
print(n_uparams, n_draws)

In [None]:
from scipy.stats import binned_statistic

# Binned mean +/- 1 std vs. logsm0 of the from-scratch u-parameters: MC draws (q1)
# and population means (q2), one panel per parameter.
# NOTE: x0 is reused as a (dummy) argument by a later cell; keep the name.
x0 = (logsm0, logsm0, logsm0)
q1 = (u_params_test[0], u_params_test[1], u_params_test[2])
l1 = ("mc_u_tcrit", "mc_u_early", "mc_u_late")
q2 = (u_params_test_pop[0], u_params_test_pop[1], u_params_test_pop[2])
l2 = ("mu_u_tcrit_pop", "mu_u_early_pop", "mu_u_late_pop")
pltname = "generate_uparams_{}.png"
bin0 = (Mbins, Mbins, Mbins)
plotd = os.path.join(plotdir, "Fbulge")
# savefig fails on a missing directory
os.makedirs(plotd, exist_ok=True)
nrow, ncol = get_nrow_ncol(len(x0))
fig, ax_all = plt.subplots(nrow, ncol, figsize=(5 * ncol, 4 * nrow))

for ql, ll in zip([q1, q2], [l1, l2]):
    for ax, x, q, bins, label in zip(ax_all.flat, x0, ql, bin0, ll):
        xmeans, _, _ = binned_statistic(x, x, bins=bins)
        ymeans, _, _ = binned_statistic(x, q, bins=bins)
        std, _, _ = binned_statistic(x, q, bins=bins, statistic="std")
        ax.plot(xmeans, ymeans, label=label)
        ax.fill_between(xmeans, ymeans - std, ymeans + std, alpha=0.3)

# Decorate once, after both curve sets are drawn (was repeated inside the outer loop).
for ax in ax_all.flat:
    ax.legend(loc="best")
    ax.set_xlabel("logsm0")

fn = os.path.join(plotd, pltname.format(xname))
fig.suptitle(title, y=0.97)
fig.savefig(fn)
print(f"Saving {fn}")

In [None]:
def get_fbulge_from_params(u_params, x0, t10, t90, BOUNDING_K=0.1, FBULGE_MIN=0.05, FBULGE_MAX=0.95):
    """Map unbounded u-parameters to bounded fbulge parameters via ``_sigmoid``.

    Calls below assume ``_sigmoid(x, x0, k, ymin, ymax)`` — TODO confirm against disk_bulge_kernels.
    - tcrit: sigmoid of u_tcrit centred on t50 = (t10 + t90) / 2, bounded by t10 and t90.
    - early: sigmoid of u_early centred on (FBULGE_MIN + FBULGE_MAX) / 2, bounded by
      FBULGE_MIN and FBULGE_MAX.
    - late: sigmoid of u_late centred on (fbulge_early + FBULGE_MIN) / 2, with bounds
      passed in the order (fbulge_early, FBULGE_MIN).

    Parameters
    ----------
    u_params : tuple of 3 arrays
        ``(u_tcrit, u_early, u_late)``.
    x0 : any
        Unused. The original body overwrote this argument immediately, so callers'
        values never had an effect; kept only for backward compatibility.
    t10, t90 : arrays
        Formation times bounding fbulge_tcrit.
    BOUNDING_K, FBULGE_MIN, FBULGE_MAX : float
        Sigmoid steepness and fbulge bounds.

    Returns
    -------
    numpy.ndarray
        Columns (fbulge_tcrit, fbulge_early, fbulge_late); shape (n, 3) for 1-D inputs of length n.
    """
    mc_u_tcrit, mc_u_early, mc_u_late = u_params
    t50 = (t10 + t90) / 2
    fbulge_tcrit = _sigmoid(mc_u_tcrit, t50, BOUNDING_K, t10, t90)

    # Named locally instead of re-binding the (ignored) x0 parameter.
    x0_early = (FBULGE_MIN + FBULGE_MAX) / 2
    fbulge_early = _sigmoid(mc_u_early, x0_early, BOUNDING_K, FBULGE_MIN, FBULGE_MAX)

    x0_late = (fbulge_early + FBULGE_MIN) / 2
    fbulge_late = _sigmoid(mc_u_late, x0_late, BOUNDING_K, fbulge_early, FBULGE_MIN)
    return np.asarray((fbulge_tcrit, fbulge_early, fbulge_late)).T

In [None]:
# Three routes to fbulge parameters: the local reimplementation applied to the MC draws
# and to the population means, and the library's _get_params_from_u_params on the MC draws.
# (x0 is passed only to fill the positional slot; get_fbulge_from_params ignores it.)
fbulge_params_gen = get_fbulge_from_params(u_params_test, x0, t10, t90)
fbulge_params_pop = get_fbulge_from_params(u_params_test_pop, x0, t10, t90)
fbulge_params_func = _get_params_from_u_params(u_params_test, t10, t90)
print(fbulge_params_gen[0])
print(*(fbulge_params_func[idx][0] for idx in range(3)))
# Column-by-column agreement between the local and library versions.
for gen_col, func_col in zip(fbulge_params_gen.T, fbulge_params_func):
    print(jnp.isclose(gen_col, func_col).all())

In [None]:
# Bounded fbulge parameters vs. the u-parameters they were computed from.
xa = (u_params_test_pop[0], u_params_test_pop[1], u_params_test_pop[2])  # population-mean u
xb = (u_params_test[0], u_params_test[1], u_params_test[2])  # MC-drawn u
ga = (fbulge_params_gen[:, 0], fbulge_params_gen[:, 1], fbulge_params_gen[:, 2])  # local, from xb
gb = (fbulge_params_pop[:, 0], fbulge_params_pop[:, 1], fbulge_params_pop[:, 2])  # local, from xa
gc = (fbulge_params_func[0], fbulge_params_func[1], fbulge_params_func[2])  # library, from xb
lxs = ("u_tcrit", "u_early", "u_late")
lys = ("fbulge_tcrit", "fbulge_early", "fbulge_late")
# Pair each x with the fbulge values derived from it. Previously ga and gb were
# swapped, plotting population-mean fbulge against MC u (and vice versa).
xba = [xb, xa, xb]
fba = [ga, gb, gc]
labels = ["u_params_test", "u_params_pop", "u_params_func"]
Nbins = 50
# One bins entry per plotted series.
bins = [Nbins for _ in range(len(xba))]
plot_qs_profiles(
    xba, fba, labels, lxs, lys, bins, title=title, xname=xname, plotdir=plotdir, plotsubdir="Fbulge"
)

## Compare fbulge parameters again, including the from-scratch reimplementation

In [None]:
# Compare the stored, library-regenerated and from-scratch (fbulge_params_gen) bulge
# parameters for the masked sample, plotted against SFR-floored logsm0.
ha = tuple(diffstar_cens[name][mask] for name in ("tcrit_bulge", "fbulge_early", "fbulge_late"))
hc = tuple(fbulge_params_new_SFR[:, i][mask] for i in range(3))
hd = tuple(fbulge_params_gen[:, i][mask] for i in range(3))
he = tuple(fbulge_params_same_SFR[:, i][mask] for i in range(3))
hba = [ha, hc, hd, he]
xxa = [logsm0[mask] for _ in range(len(ha))]
xx = [xxa] * len(hba)
# A tuple, not a generator expression: a generator is exhausted after one pass, so a
# second iteration by the plotting code would silently see no labels.
lxs = ("logsm0",) * len(ha)
lys = ("fbulge_tcrit", "fbulge_early", "fbulge_late")
binz = [Mbins for _ in range(len(hba))]
hlabels = [
    "mc_original",
    "fbulge_new_key_SFRcut",
    "fbulge_new_gen_test",
    "fbulge_same_key_SFRcut",
]
plot_qs_profiles(
    xx,
    hba,
    hlabels,
    lxs,
    lys,
    binz,
    plotdir=plotdir,
    title=title,
    xname=xname,
    pltname="check_generated_fbulge_{}.png",
    plotsubdir="Fbulge",
)