In [None]:
import os
from diffaux.disk_bulge_modeling.generate_bulge_disk_sample import (
    get_bulge_disk_test_sample,
    get_bulge_disk_decomposition,
    get_zindexes,
)
from jax import random as jran
import numpy as np
import jax.numpy as jnp
from jax import jit as jjit
from jax import lax, nn
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import itertools
from itertools import zip_longest
from diffsky.utils import _sigmoid

from diffaux.validation.plot_disk_bulge import (
    plot_qs_profiles_for_zvals,
    plot_qs_profiles,
    plot_q_profile,
    plot_sigmoid_2d,
)
from diffaux.disk_bulge_modeling.disk_bulge_kernels import (
    _sigmoid_2d,
)
from diffaux.validation.plot_utilities import get_nrow_ncol, get_subsample

# Root PRNG key for the notebook; later cells split sub-keys from it.
ran_key = jran.key(seed=0)

In [None]:
# Split off a key for the halo sample and advance ran_key.
# NOTE: re-running this cell advances ran_key again, so it draws a different sample.
halo_key, ran_key = jran.split(ran_key, num=2)
lgmp_min = 11.0  # presumably log10 of the minimum halo mass -- confirm in generate_bulge_disk_sample
redshift = 0.05
Lbox = 75.0  # presumably the box side length -- confirm units
diffstar_cens = get_bulge_disk_test_sample(
    halo_key,
    lgmp_min=lgmp_min,
    redshift=redshift,
    Lbox=Lbox,
)

In [None]:
# Redshift grid of the sample, and the target redshifts 0.0, 0.5, ..., 2.5 at which to plot.
redshifts = diffstar_cens["z_table"]
zvalues = [0.5 * step for step in range(6)]
# get_zindexes presumably picks the grid entries closest to each target -- TODO confirm
zindexes, zs = get_zindexes(zvalues, redshifts)

In [None]:
from diffsky.dustpop.avpop import DEFAULT_AVPOP_PDICT, DEFAULT_AVPOP_PARAMS, LGSM_K, LGSSFR_K
# Show the default Av-population parameters, both as the dict and as the params object.
print(DEFAULT_AVPOP_PDICT, DEFAULT_AVPOP_PARAMS, sep=" ")

In [None]:
def get_av_from_avpop_params(avpop_params, logsm, logssfr):
    """Evaluate the nested-sigmoid Av-population model and its ingredients.

    Three sigmoids in stellar mass produce the parameters of a final sigmoid
    in sSFR: its transition point (``lgav_logssfr_x0``) and its two levels
    (``lgav_logssfr_q`` and ``lgav_logssfr_ms``; presumably the quenched and
    main-sequence levels). This assumes diffsky's
    ``_sigmoid(x, x0, k, ylo, yhi)`` argument order, as the ``*_ylo_*`` /
    ``*_yhi_*`` parameter names suggest.

    Parameters
    ----------
    avpop_params : params object
        Must expose the ``lgav_logsm_{x0,ylo,yhi}_{x0,q,ms}`` attributes,
        e.g. ``DEFAULT_AVPOP_PARAMS``.
    logsm : array_like
        log10 stellar mass.
    logssfr : array_like
        log10 specific star-formation rate. It is combined elementwise with
        the mass-dependent terms, so it should match (or broadcast against)
        ``logsm``.

    Returns
    -------
    tuple of arrays
        ``(lgav_logssfr_x0, lgav_logssfr_q, lgav_logssfr_ms, lgav)``: the
        three mass-dependent parameters and the final ``lgav`` (presumably
        log10 Av -- confirm against diffsky.dustpop.avpop).
    """
    lgav_logssfr_x0 = _sigmoid(
        logsm,
        avpop_params.lgav_logsm_x0_x0,
        LGSM_K,
        avpop_params.lgav_logsm_ylo_x0,
        avpop_params.lgav_logsm_yhi_x0,
    )
    lgav_logssfr_q = _sigmoid(
        logsm,
        avpop_params.lgav_logsm_x0_q,
        LGSM_K,
        avpop_params.lgav_logsm_ylo_q,
        avpop_params.lgav_logsm_yhi_q,
    )
    # This sigmoid is in logsm, so it takes the stellar-mass steepness LGSM_K,
    # consistent with the two logsm sigmoids above.
    lgav_logssfr_ms = _sigmoid(
        logsm,
        avpop_params.lgav_logsm_x0_ms,
        LGSM_K,
        avpop_params.lgav_logsm_ylo_ms,
        avpop_params.lgav_logsm_yhi_ms,
    )

    # Final sigmoid in logssfr, whose midpoint and levels depend on stellar mass.
    lgav = _sigmoid(
        logssfr,
        lgav_logssfr_x0,
        LGSSFR_K,
        lgav_logssfr_q,
        lgav_logssfr_ms,
    )
    return lgav_logssfr_x0, lgav_logssfr_q, lgav_logssfr_ms, lgav

In [None]:
# Keep galaxies whose final-snapshot (index -1, treated here as z=0) stellar mass
# exceeds 10**logMz0_min and whose final sSFR exceeds 10**logssfr_min.
logMz0_min = 8.0
logssfr_min = -16
mass_mask = np.log10(diffstar_cens["smh"][:, -1]) > logMz0_min
sfr_mask = np.log10(diffstar_cens["sSFR"][:, -1]) > logssfr_min
mask = mass_mask & sfr_mask
title = "$M^*_{{z=0}} > 10^{{{:.1f}}} M_\\odot, sSFR > 10^{{{}}} yr^{{-1}}$".format(logMz0_min, logssfr_min)
xname = "log_M0_min_{:.1f}_logssfr_min_{}".format(logMz0_min, -logssfr_min)

# Select rows before taking logs: log10 is then evaluated only on the kept
# galaxies' histories, avoiding wasted work and log-of-zero warnings from
# rows that would be discarded anyway. Values are identical to log-then-mask.
logsm = np.log10(diffstar_cens["smh"][mask])
logssfr = np.log10(diffstar_cens["sSFR"][mask])
lgav_logssfr_x0, lgav_logssfr_q, lgav_logssfr_ms, lgav = get_av_from_avpop_params(
    DEFAULT_AVPOP_PARAMS, logsm, logssfr
)

In [None]:
from diffaux.validation.plot_utilities import get_nrow_ncol

# Output directory for figures. Set DISK_BULGE_PLOTDIR to redirect plots on
# another machine; the default is the original author's local directory.
plotdir = os.environ.get(
    "DISK_BULGE_PLOTDIR", "/Users/kovacs/cosmology/BulgeDisk/DiskBulgePlots"
)

# Nm equal-width bins in log10 stellar mass.
logM_min = 7.5
logM_max = 11.5
Nm = 8
Mbins = np.linspace(logM_min, logM_max, Nm + 1)

# Ns equal-width bins in log10 sSFR.
logsSFR_min = -15
logsSFR_max = -8
Ns = 7
sbins = np.linspace(logsSFR_min, logsSFR_max, Ns + 1)

# Quantities to profile: the three stellar-mass sigmoids (binned in logsm),
# then the final lgav (binned in logssfr). The lists are aligned index by index.
qs = [lgav_logssfr_x0, lgav_logssfr_q, lgav_logssfr_ms, lgav]
xs = [logsm, logsm, logsm, logssfr]
binz = [Mbins, Mbins, Mbins, sbins]
labels = ['logssfr_x0', 'logssfr_q', 'logssfr_ms', 'lgav']
xlabels = ['logsm', 'logsm', 'logsm', 'logssfr']

In [None]:

# Binned profiles of each Av-model quantity; pltname carries a redshift placeholder,
# presumably filled per entry of zs.
plot_qs_profiles_for_zvals(
    xs, qs, binz, labels,
    zindexes, zs, xlabels,
    plotsubdir="sigmoids",
    plotdir=plotdir,
    pltname="avpop_z_{}.png",
    title=title,
    xname=xname,
)

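### 2D sigmoid

Explore `_sigmoid_2d` on the (log sSFR, log M*) bin grid for several steepness pairs, including a tcrit-like output range and a reversed-axis orientation.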

In [None]:
# The 2D sigmoid is evaluated on grids built from sbins (x) and Mbins (y).
# x0 / y0 are presumably the sigmoid midpoints along each axis.
x0 = -10.0
y0 = 10

# Steepness values along each axis: Nk + 1 evenly spaced points.
kx_min, kx_max, Nkx = 0.2, 0.8, 2
ky_min, ky_max, Nky = 0.2, 0.8, 2
kxs = np.linspace(kx_min, kx_max, num=Nkx + 1)
kys = np.linspace(ky_min, ky_max, num=Nky + 1)
print(kxs, kys)

# Every (kx, ky) combination, kx varying slowest (same order as itertools.product).
kpairs = np.array([(kx, ky) for kx in kxs for ky in kys])

In [None]:
# Grid: X varies along sbins (columns), Y along Mbins (rows).
X, Y = jnp.meshgrid(sbins, Mbins, indexing="xy")
print(X[0], Y[:, 0], jnp.min(X), jnp.max(Y))

# Output levels passed to _sigmoid_2d; note zmin > zmax in this configuration.
zmin, zmax = 1.0, 0.0
for kx, ky in kpairs:
    print(kx, ky)
    # Report the range of the sigmoid over the grid for this steepness pair.
    z = _sigmoid_2d(X, x0, Y, y0, kx, ky, zmin, zmax)
    print(jnp.min(z), jnp.max(z))
    

In [None]:
# One panel per (kx, ky) pair in kpairs on the current X/Y grid.
plot_sigmoid_2d(
    X, x0, Y, y0, kpairs, zmin, zmax,
    alpha=0.6,
    contour=True,
    ytit=0.99,
    plotdir=plotdir,
    title=title,
    xname=xname,
)

In [None]:
# Same 2D-sigmoid exploration, but with output levels in a tcrit-like range
# (tmin=13, tmax=2; presumably Gyr -- confirm).
x0 = -10.0
y0 = 10

# (tx, ty) pairs are passed to plot_sigmoid_2d in the steepness-pair slot.
tx_min, tx_max, Ntx = 0.2, 0.8, 2
ty_min, ty_max, Nty = 0.2, 0.8, 2
txs = np.linspace(tx_min, tx_max, num=Ntx + 1)
tys = np.linspace(ty_min, ty_max, num=Nty + 1)
print(txs, tys)
# tx varies slowest, matching itertools.product ordering.
tpairs = np.array([(tx, ty) for tx in txs for ty in tys])

X, Y = jnp.meshgrid(sbins, Mbins, indexing="xy")
print(X[0], Y[:, 0], jnp.min(X), jnp.max(Y))

tmin, tmax = 13.0, 2.0
tmax=2.0

In [None]:
# Same panel layout as the default plot, with tcrit-like output levels.
plot_sigmoid_2d(
    X, x0, Y, y0, tpairs, tmin, tmax,
    alpha=0.6,
    contour=True,
    ytit=0.99,
    pltname="sigmoid_2d_tcritlike_{}.png",
    plotdir=plotdir,
    title=title,
    xname=xname,
)

In [None]:
# Reversed orientation: X now spans Mbins (log10 M*) and Y spans sbins (log10 sSFR),
# with new output levels, midpoints and negated steepness values.
zmin, zmax = 0.2, 0.95
X, Y = jnp.meshgrid(Mbins, sbins, indexing="xy")
print(X[0], Y[:, 0], jnp.min(X), jnp.max(Y))
x0 = 10.0
y0 = -10.0

kxs = np.linspace(-kx_min, -kx_max, num=Nkx + 1)
kys = np.linspace(-ky_min, -ky_max, num=Nky + 1)
print(kxs, kys)
# kx varies slowest, matching itertools.product ordering.
kpairs = np.array([(kx, ky) for kx in kxs for ky in kys])

In [None]:
# Reversed-axis figure: axis labels swapped to match the Mbins/sbins grid.
plot_sigmoid_2d(
    X, x0, Y, y0, kpairs, zmin, zmax,
    xlabel="$\\log_{10}(M^*/M_\\odot)$",
    ylabel="$\\log_{10}(sSFR/yr)$",
    alpha=0.6,
    contour=True,
    ytit=0.99,
    pltname="sigmoid_2d_reversed_{}.png",
    plotdir=plotdir,
    title=title,
    xname=xname,
)