In [None]:
import os
from diffaux.disk_bulge_modeling.generate_bulge_disk_sample import (
    get_bulge_disk_test_sample,
    get_bulge_disk_decomposition,
    get_zindexes,
)

from jax import random as jran
import numpy as np
import jax.numpy as jnp
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict, namedtuple
from itertools import zip_longest
from diffaux.validation.plot_disk_bulge import (
    plot_qs_profiles_for_zvals,
    plot_qs_profiles,
    plot_q_profile,
    plot_q_vs_xs_color_scatter,
)
from diffaux.disk_bulge_modeling.mc_disk_bulge import generate_fbulge_parameters_2d_sigmoid
from diffaux.disk_bulge_modeling.disk_bulge_kernels import (
    calc_tform_pop,
    _sigmoid,
    _sigmoid_2d,
)
from diffaux.validation.plot_utilities import get_nrow_ncol

# from diffsky.utils import _sigmoid

# Root PRNG key for the notebook; later cells derive sub-keys from it with jran.split.
RANDOM_SEED = 0
ran_key = jran.key(RANDOM_SEED)

In [None]:
# Give the sample generator its own sub-key so ran_key stays available for later draws.
halo_key, ran_key = jran.split(ran_key, 2)

# Sample configuration (lgmp_min is presumably a log10 halo-mass cut and Lbox the box
# size -- units not visible here, TODO confirm against get_bulge_disk_test_sample).
lgmp_min, redshift, Lbox = 11.0, 0.05, 75.0

diffstar_cens = get_bulge_disk_test_sample(
    halo_key,
    lgmp_min=lgmp_min,
    redshift=redshift,
    Lbox=Lbox,
)

In [None]:
# Redshift grid of the sample, and the redshifts we want to inspect on it.
# get_zindexes presumably returns the matching z_table indices and redshifts (TODO confirm).
redshifts = diffstar_cens["z_table"]
zvalues = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]
zindexes, zs = get_zindexes(zvalues, redshifts)

## Fbulge Experiments

In [None]:
# Selection cuts, applied to the last column (final snapshot) of each galaxy's history.
logMz0_min = 8.0
logssfr_min = -16

mass_mask = np.log10(diffstar_cens["smh"][:, -1]) > logMz0_min
sfr_mask = np.log10(diffstar_cens["sSFR"][:, -1]) > logssfr_min
mask = np.logical_and(mass_mask, sfr_mask)

# LaTeX plot title and a filename-safe tag that both describe the cuts above.
title = f"$M^*_{{z=0}} > 10^{{{logMz0_min:.1f}}} M_\\odot, sSFR > 10^{{{logssfr_min}}} yr^{{-1}}$"
xname = f"log_M0_min_{logMz0_min:.1f}_logssfr_min_{-logssfr_min}"

# Full log10 histories of the selected galaxies.
logsm, logssfr = (np.log10(diffstar_cens[col])[mask] for col in ("smh", "sSFR"))

In [None]:
from dsps.constants import SFR_MIN
from diffstar.utils import cumulative_mstar_formed_galpop
from diffaux.disk_bulge_modeling.disk_bulge_kernels import (
    calc_tform_pop,
)

# Recompute t10 / t90 the way the library code does: floor the SFH at SFR_MIN,
# accumulate it into a stellar-mass history, then evaluate calc_tform_pop at the
# 0.1 and 0.9 fractions.
tarr = diffstar_cens["t_table"]
print(SFR_MIN)

sfh_raw = diffstar_cens["sfh"]
sfh_pop = np.where(sfh_raw < SFR_MIN, SFR_MIN, sfh_raw)
smh_pop = cumulative_mstar_formed_galpop(tarr, sfh_pop)
t10, t90 = (calc_tform_pop(tarr, smh_pop, frac) for frac in (0.1, 0.9))


def _min_max(*arrays):
    """Return [min(a0), max(a0), min(a1), max(a1), ...] for quick range checks."""
    return [reducer(arr) for arr in arrays for reducer in (np.min, np.max)]


print(*_min_max(t10, t90))
print(*_min_max(sfh_pop, smh_pop))

## 2d Sigmoid: Fbulge_early

In [None]:
# fbulge_early as a 2D sigmoid of each galaxy's final-snapshot log sSFR and log M*.
# In the _sigmoid_2d call below, (x0, kx) are paired with logssfr0 and (y0, ky) with
# logsm0; zmin/zmax fill the two output-value slots (presumably the asymptotes --
# TODO confirm argument semantics against _sigmoid_2d).
x0 = -10.0  # presumably the log10 sSFR transition location
y0 = 10  # presumably the log10 M* transition location
kx = 0.2
ky = 0.2
zmin = 1.0
zmax = 0.2
logsm0 = np.log10(diffstar_cens["smh"][:, -1])
logssfr0 = np.log10(diffstar_cens["sSFR"][:, -1])
fbulge_early = _sigmoid_2d(logssfr0, x0, logsm0, y0, kx, ky, zmin, zmax)
print(fbulge_early.shape)

# Figure output directory. Set the DISKBULGE_PLOTDIR environment variable to redirect
# plots on other machines instead of editing this hard-coded default path.
plotdir = os.environ.get("DISKBULGE_PLOTDIR", "/Users/kovacs/cosmology/BulgeDisk/DiskBulgePlots")

In [None]:
# x-axis arrays (log M*, log sSFR) and the arrays used to color them, presumably one
# pair per panel (TODO confirm against plot_q_vs_xs_color_scatter).
xs = [logsm0[mask], logssfr0[mask]]
color_arrays = [logssfr0[mask], logsm0[mask]]

early_opts = {
    "xname": xname,
    "title": title,
    "wspace": 0.35,
    "N": 2000,
    "plotdir": plotdir,
    "pltname": "{}_vs_SFR_Mstar_{}_new.png",
}
plot_q_vs_xs_color_scatter(fbulge_early[mask], xs, color_arrays, "fbulge_early", **early_opts)

## 2d Sigmoid: Fbulge_late

In [None]:
# Same 2D sigmoid shape as fbulge_early, but the first output-value slot is filled
# per galaxy with fbulge_early instead of the constant zmin.
fbulge_late = _sigmoid_2d(
    logssfr0, x0,
    logsm0, y0,
    kx, ky,
    fbulge_early, zmax,
)
print(fbulge_late.shape)

In [None]:
# fbulge_late against the same (log M*, log sSFR) axes used for fbulge_early.
late_opts = {
    "xname": xname,
    "title": title,
    "wspace": 0.35,
    "N": 2000,
    "plotdir": plotdir,
    "pltname": "{}_vs_SFR_Mstar_{}_new.png",
}
plot_q_vs_xs_color_scatter(fbulge_late[mask], xs, color_arrays, "fbulge_late", **late_opts)

In [None]:
# fbulge_early plotted against fbulge_late. Both x-axis entries are fbulge_late; the
# entries differ only in the accompanying color array (log sSFR vs. log M*).
fs = [fbulge_late[mask], fbulge_late[mask]]
flabels = ["fbulge_late", "fbulge_late"]
plot_q_vs_xs_color_scatter(
    fbulge_early[mask],
    fs,
    color_arrays,
    "fbulge_early_late",
    xname=xname,
    title=title,
    wspace=0.35,
    xlabels=flabels,
    N=2000,
    plotdir=plotdir,
    # The x-axis here is fbulge_late, not SFR/M*; name the file after it, matching the
    # "{}_vs_t10_{}" / "{}_vs_t90_{}" convention of the tcrit_bulge plots.
    pltname="{}_vs_fbulge_late_{}_new.png",
)

## 2d Sigmoid: tcrit_bulge

In [None]:
# tcrit_bulge as a 2D sigmoid in (log sSFR, log M*) whose two output-value slots are
# each galaxy's own t90 and t10.
xf, kf = -10.0, 0.5  # presumably log sSFR transition location / steepness
xm, km = 10, 0.8  # presumably log M* transition location / steepness
tcrit_bulge = _sigmoid_2d(logssfr0, xf, logsm0, xm, kf, km, t90, t10)
print(tcrit_bulge.shape)

In [None]:
# tcrit_bulge against the (log M*, log sSFR) axes.
tcrit_opts = {
    "xname": xname,
    "title": title,
    "wspace": 0.35,
    "N": 3000,
    "plotdir": plotdir,
    "pltname": "{}_vs_SFR_Mstar_{}_new.png",
}
plot_q_vs_xs_color_scatter(tcrit_bulge[mask], xs, color_arrays, "tcrit_bulge", **tcrit_opts)

In [None]:
# tcrit_bulge against t10; both x-axis entries are t10, differing only in color array.
ts = [t10[mask] for _ in range(2)]
tlabels = ["t10"] * 2
plot_q_vs_xs_color_scatter(
    tcrit_bulge[mask], ts, color_arrays, "tcrit_bulge",
    xname=xname, title=title, wspace=0.35, xlabels=tlabels,
    N=3000, plotdir=plotdir, pltname="{}_vs_t10_{}_new.png",
)

In [None]:
# tcrit_bulge against t90; both x-axis entries are t90, differing only in color array.
ts = [t90[mask] for _ in range(2)]
tlabels = ["t90"] * 2
plot_q_vs_xs_color_scatter(
    tcrit_bulge[mask], ts, color_arrays, "tcrit_bulge",
    xname=xname, title=title, wspace=0.35, xlabels=tlabels,
    N=3000, plotdir=plotdir, pltname="{}_vs_t90_{}_new.png",
)

In [None]:
# Combine the experiments above via the library's generate_fbulge_parameters_2d_sigmoid.
# No stochasticity yet.
from diffaux.disk_bulge_modeling.mc_disk_bulge import DEFAULT_FBULGEPARAMS

default_fbulge_params = DEFAULT_FBULGEPARAMS
print(default_fbulge_params)
print(default_fbulge_params.tcrit_logsm0_k)

In [None]:
# Generate the fbulge parameters with a dedicated sub-key. Pass fbulge_key -- not the
# parent ran_key -- so the parent is never consumed directly and stays safe to split
# again in later cells (JAX keys must not be reused).
fbulge_key, ran_key = jran.split(ran_key, 2)
fbulge_params = generate_fbulge_parameters_2d_sigmoid(
    fbulge_key, logsm0, logssfr0, t10, t90, DEFAULT_FBULGEPARAMS
)
print(fbulge_params.shape)