In [None]:
import os
import re
from diffsky.experimental.disk_bulge_modeling.generate_bulge_disk_sample import (
    get_bulge_disk_test_sample,
    get_bulge_disk_decomposition,
)
from diffaux.validation.plot_utilities import get_zindexes
# or from lsstdesc_diffsky.disk_bulge_modeling.

from jax import random as jran
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import copy
from collections import OrderedDict, namedtuple
from itertools import zip_longest
from diffaux.validation.plot_disk_bulge import (
    plot_q_with_cuts,
    plot_q1_vs_q2,
    plot_q1_q2,
    plot_qs_profiles_with_cuts,
)
from diffaux.validation.plot_utilities import get_nrow_ncol

# Root PRNG key for the notebook; the sampling cell splits a halo key off it.
ran_key = jran.key(seed=0)

In [None]:
# Split off a key for the halo sample; the second half replaces ran_key.
halo_key, ran_key = jran.split(ran_key, 2)

# Settings forwarded to the test-sample generator.
lgmp_min, redshift, Lbox = 10.5, 0.05, 100.0

diffstar_cens = get_bulge_disk_test_sample(
    halo_key,
    lgmp_min=lgmp_min,
    redshift=redshift,
    Lbox=Lbox,
)
# Show which quantities the sample provides.
print([name for name in diffstar_cens.keys()])

In [None]:
# Redshifts used for the per-redshift plots: 0.0 to 2.5 in steps of 0.5.
zvalues = [0.5 * step for step in range(6)]
# Redshift grid stored with the sample.
redshifts = diffstar_cens["z_table"]
# NOTE(review): presumably the positions in `redshifts` matching `zvalues`,
# plus the matched redshift values -- confirm against get_zindexes.
zindexes, zs = get_zindexes(zvalues, redshifts)
print(len(redshifts))

## Tinker with Fbulge Sigmoid Parameters

In [None]:
from diffsky.experimental.disk_bulge_modeling.mc_disk_bulge import DEFAULT_FBULGE_2dSIGMOID_PARAMS

# Baseline ("old") decomposition: run the package defaults on a private copy
# of the sample so diffstar_cens itself is not handed to the decomposition.
print(DEFAULT_FBULGE_2dSIGMOID_PARAMS)
diffstar_old = get_bulge_disk_decomposition(
    copy.deepcopy(diffstar_cens),
    fbulge_2d_params=DEFAULT_FBULGE_2dSIGMOID_PARAMS,
)

In [None]:
# Tuned ("new") fbulge parameters: every field listed here overrides the
# corresponding default; fields not listed keep their default value.
# All overrides are applied in a single _replace call so no intermediate
# parameter tuples are left behind in the kernel namespace.
fbulge_overrides = dict(
    early_zmax=0.2,
    early_logsm0_x0=10.5,
    early_logssfr0_x0=-10.5,
    early_logssfr0_k=0.8,
    early_logsm0_k=0.6,
    early_zmin=0.9,
    late_zmax=0.0,
    late_logsm0_x0=10.5,
    late_logssfr0_x0=-10.5,
    late_logssfr0_k=0.5,
    late_logsm0_k=0.4,
)
fbulge_new = DEFAULT_FBULGE_2dSIGMOID_PARAMS._replace(**fbulge_overrides)
print(fbulge_new)

# Run the decomposition on a private copy of the sample so diffstar_cens
# itself is not handed to the decomposition.
diffstar_new = get_bulge_disk_decomposition(
    copy.deepcopy(diffstar_cens),
    fbulge_2d_params=fbulge_new,
)

In [None]:
# Print, per quantity: name, old min, old max, new min, new max.
for k in ("bth", "logsm_obs", "logssfr_obs", "fbulge_tcrit", "fbulge_early", "fbulge_late"):
    extrema = [
        reducer(catalog[k])
        for catalog in (diffstar_old, diffstar_new)
        for reducer in (np.min, np.max)
    ]
    print(k, *extrema)

## Fbulge and B/T Comparisons

In [None]:
# Directory handed to the plotting helpers as `plotdir`.
# Set the DISK_BULGE_PLOTDIR environment variable to redirect output on another
# machine; when it is unset the original location is used unchanged.
plotdir = os.environ.get("DISK_BULGE_PLOTDIR", "/Users/kovacs/cosmology/BulgeDisk/DiskBulgePlots")

In [None]:
# Mass selection: keep rows whose last-column "smh" value exceeds the threshold
# (the legend text labels this quantity as the z=0 stellar mass).
logMz0_min = 7.5
mass_mask = np.log10(diffstar_cens["smh"][:, -1]) > logMz0_min
lgnd_title = ", $\\log_{{10}}(M^*_{{z=0}}/M_\\odot) > {:.1f}$".format(logMz0_min)
versions = ["(old)", "(new)"]
BT_bins = np.linspace(0.0, 1.0, 21)

# Lower/upper log sSFR bounds, one pair per figure.
sfr_cuts_lo = [-16, -10]
sfr_cuts_hi = [-11, -8]
# Label template and colour pair for each figure, in the same order as the cuts.
component_labels = ["Bulge {}", "Disk {}"]
component_colors = [["orange", "r"], ["cyan", "blue"]]

for sfr_lo, sfr_hi, label, colors in zip(sfr_cuts_lo, sfr_cuts_hi, component_labels, component_colors):
    # First word of the label template ("Bulge" / "Disk") tags the output name.
    component = label.split(" ")[0]
    qlabels = list(map(label.format, versions))
    # The cut array is taken from the old catalog for both B/T samples.
    log_ssfr_cut = jnp.log10(diffstar_old["sSFR"])[mass_mask]
    plot_q1_q2(
        diffstar_old["bth"][mass_mask],
        diffstar_new["bth"][mass_mask],
        zvalues,
        redshifts,
        log_ssfr_cut,
        sfr_lo,
        sfr_hi,
        dz=0.2,
        cut_at_z0=False,
        cut_name="$\\log_{10}(sSFR/yr^{-1})$",
        qlabels=qlabels,
        colors=colors,
        lgnd_title=lgnd_title,
        xlabel="B/T",
        xscale="linear",
        xlimits=(0.0, 1.0),
        bins=BT_bins,
        xname="log_sSFR_for_{}_model_compare_tuning".format(component),
        pltname="BT_cut_on_{}.png",
        plotdir=plotdir,
    )

## sSFR vs M* colored by B/T

In [None]:
# Thresholds for this section; only the mass threshold is applied as a mask,
# the sSFR threshold is passed to the plot as ymin.
logMz0_min = 8.0
logssfr_min = -14
mass_mask = np.log10(diffstar_cens["smh"][:, -1]) > logMz0_min
title = "$M^*_{{z=0}} > 10^{{{:.1f}}} M_\\odot, sSFR > 10^{{{}}} yr^{{-1}}$".format(logMz0_min, logssfr_min)
# The sSFR threshold is negated so the name carries no minus sign.
xname = "log_M0_min_{:.1f}_logssfr_min_{}".format(logMz0_min, -logssfr_min)

# One figure per model version, identical apart from the catalog and the tag.
for tag, catalog in (("old", diffstar_old), ("new", diffstar_new)):
    log_smh = jnp.log10(catalog["smh"][mass_mask])
    log_ssfr = jnp.log10(catalog["sSFR"][mass_mask])
    bulge_to_total = catalog["bth"][mass_mask]
    plot_q1_vs_q2(
        log_smh,
        log_ssfr,
        zvalues,
        redshifts,
        bulge_to_total,
        title=" ".join([title, tag]),
        xname=xname + "_model_tweak_{}".format(tag),
        ymin=logssfr_min,
        N=2000,
        plotdir=plotdir,
    )

## B/T as a function of M* for SF and Q galaxies

In [None]:
# Thresholds applied to the last column of the "smh" and "sSFR" tables.
logMz0_min, logssfr_min = 8.0, -15

last_smh = diffstar_cens["smh"][:, -1]
last_ssfr = diffstar_cens["sSFR"][:, -1]
mass_mask = np.log10(last_smh) > logMz0_min
sfr_mask = np.log10(last_ssfr) > logssfr_min
# Rows must pass both selections.
mask = mass_mask & sfr_mask

title = "$M^*_{{z=0}} > 10^{{{:.1f}}} M_\\odot, sSFR > 10^{{{}}} yr^{{-1}}$".format(logMz0_min, logssfr_min)
# The sSFR threshold is negated so the name carries no minus sign.
xname = "log_M0_min_{:.1f}_logssfr_min_{}".format(logMz0_min, -logssfr_min)

In [None]:
# One figure per model version: B/T against log stellar mass, with log sSFR
# passed as the fifth array (the colour bar is titled with log sSFR).
for tag, catalog in (("old", diffstar_old), ("new", diffstar_new)):
    log_smh = jnp.log10(catalog["smh"][mask])
    log_ssfr = jnp.log10(catalog["sSFR"][mask])
    plot_q1_vs_q2(
        log_smh,
        catalog["bth"][mask],
        zvalues,
        redshifts,
        log_ssfr,
        title=" ".join([title, tag]),
        xname=xname + "_model_tweak_{}".format(tag),
        pltname="BoverT_vs_logMstar_{}.png",
        plotdir=plotdir,
        xlabel="$\\log_{10}(M^*/M_\\odot)$",
        ylabel="B/T",
        cbar_title="$\\log_{10}(sSFR/yr)$",
        cmap="jet_r",
        xmin=7,
        xmax=12,
        ymin=0,
        ymax=1.0,
        label_y=0.95,
        N=2000,
    )

In [None]:
# Per-version inputs for the profile plot, ordered (old, new) and restricted
# to the rows selected by `mask`.
qs = [catalog["bth"][mask] for catalog in (diffstar_old, diffstar_new)]
xs = [catalog["smh"][mask] for catalog in (diffstar_old, diffstar_new)]
cut_arrays = [jnp.log10(catalog["sSFR"][mask]) for catalog in (diffstar_old, diffstar_new)]
# log sSFR thresholds handed to the plot as `cuts`.
cuts = [-11, -10.5]
colors_list = (("orange", "cyan"), ("r", "blue"))

In [None]:
# Compare old and new B/T profiles using the per-version inputs built in the
# previous cell (qs, xs, cut_arrays) and the log sSFR thresholds in `cuts`.
plot_qs_profiles_with_cuts(
    qs,
    xs,
    zvalues,
    redshifts,
    cut_arrays,
    cuts,
    slabels=["old", "new"],
    dz=0.2,
    bin_lo=7.5,
    bin_hi=12.0,
    Nbins=18,
    # One decimal place is required here: with "{:.0f}" the -10.5 threshold in
    # `cuts` is rendered as "-10", so the label would misstate the cut applied.
    cut_labels=("{{}} $\\leq$ {:.1f}", "{{}} $\\geq$ {:.1f}"),
    colors_list=colors_list,
    cut_name="$\\log_{10}(sSFR/yr^{-1})$",
    plotdir=plotdir,
    plotsubdir="DiskBulge_Profiles",
    pltname="BT_cut_on_{}_comparison.png",
    yscale="",
    xscale="log",
    xlabel="$\\log_{10}(M^*/M_\\odot)$",
    xname="log_sSFR",
    lgnd_title="",
    error=False,
    ylabel="B/T",
)