In [None]:
import os

import numpy as np
from jax import random as jran

from diffsky.mc_diffsky import mc_diffstar_galpop, mc_diffstar_cenpop

# Master JAX PRNG key for the notebook (fixed seed 0 for reproducibility).
ran_key = jran.key(seed=0)

### Generate the subhalo catalog (subcat) and star-formation-history (SFH) catalog

In [None]:
# Draw a Monte Carlo population of central galaxies with diffstar SFHs.
halo_key, ran_key = jran.split(ran_key, 2)
lgmp_min = 11.0  # minimum log10 halo mass for the population (units as expected by diffsky; confirm)
redshift = 0.05
Lbox = 100.0
volume_com = Lbox**3  # comoving volume of a cubic box of side Lbox
# Consume the freshly split halo_key. Passing ran_key here would reuse a key that is
# split again later in the notebook, correlating the random draws.
args = (halo_key, redshift, lgmp_min, volume_com)
diffstar_cens = mc_diffstar_cenpop(*args, return_internal_quantities=True)

In [None]:
# Sanity check: number of time-table points and the SFH array shape
# (expected to be (n_galaxies, n_times) — confirm).
n_times = len(diffstar_cens['t_table'])
print(n_times, diffstar_cens['sfh'].shape)

### Explore Existing Model for Disk-Bulge Decomposition

In [None]:
from diffaux.disk_bulge_modeling.mc_disk_bulge import mc_disk_bulge, generate_fbulge_params

In [None]:
disk_bulge_key, ran_key = jran.split(ran_key, 2)
# Use the dedicated subkey; passing ran_key would reuse the key that is carried forward
# for later splits.
_res = mc_disk_bulge(disk_bulge_key, diffstar_cens['t_table'], diffstar_cens['sfh'])
fbulge_params, smh, eff_bulge, sfh_bulge, smh_bulge, bth = _res
# Columns of fbulge_params, as unpacked here: (tcrit, fbulge_early, fbulge_late).
tcrit_bulge = fbulge_params[:, 0]
fbulge_early = fbulge_params[:, 1]
fbulge_late = fbulge_params[:, 2]

In [None]:
# Shapes of the disk-bulge model outputs, printed on one line.
print(*(arr.shape for arr in (smh, eff_bulge, sfh_bulge, smh_bulge, bth)))

In [None]:
# Specific SFR (SFR divided by stellar mass) for the whole galaxy and for the bulge.
diffstar_cens['sSFR'] = np.divide(diffstar_cens['sfh'], diffstar_cens['smh'])
diffstar_cens['sSFR_bulge'] = np.divide(sfh_bulge, smh_bulge)
# Disk histories are total minus bulge (mass first, then SFR), followed by the disk sSFR.
for _total_key, _bulge_hist in (('smh', smh_bulge), ('sfh', sfh_bulge)):
    diffstar_cens[_total_key + '_disk'] = diffstar_cens[_total_key] - _bulge_hist
diffstar_cens['sSFR_disk'] = np.divide(diffstar_cens['sfh_disk'], diffstar_cens['smh_disk'])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
from itertools import zip_longest
from diffaux.validation.plot_utilities import get_nrow_ncol
# Root directory for saved figures. Set the DISKBULGE_PLOTDIR environment variable to
# redirect output instead of editing this personal default path.
plotdir = os.environ.get('DISKBULGE_PLOTDIR', '/Users/kovacs/cosmology/DiskBulgePlots')

# History panels as (array, y-axis label, file-name tag) triples so the three lists
# below can never drift out of alignment. Currently disabled panels:
#   (sfh_bulge, 'Bulge_SFH ($M_\\odot \\mathrm{yr}^{-1}$)', 'SFHB'),
#   (diffstar_cens['sfh'], 'SFH ($M_\\odot \\mathrm{yr}^{-1}$)', 'SFH'),
_history_panels = [
    (eff_bulge, 'Bulge_efficiency', 'effB'),
    (smh_bulge, 'Bulge_SMH ($M_\\odot$)', 'SMHB'),
    (diffstar_cens['smh'], 'SMH ($M_\\odot$)', 'SMH'),
    (bth, 'B/T', 'BT'),
    (diffstar_cens['sSFR_bulge'], 'Bulge_sSFH ($\\mathrm{yr}^{-1}$)', 'sSFRB'),
    (diffstar_cens['sSFR'], 'sSFH ($\\mathrm{yr}^{-1}$)', 'sSFR'),
]
qs, ylabels, labels = (list(column) for column in zip(*_history_panels))

def plot_histories(qs, t_table, labels, ylabels, plot_label=None,
                   color_array=None, row_mask=None, lgnd_label='#{}',
                   pltname='History_{}_step_{}.png', yscale='', reverse=False,
                   check_step=5000, xlimlo=0.5, xlimhi=14,
                   step=300, plotdir = os.path.join(plotdir, 'DiskBulge_Histories')):
    """Plot a thinned, sorted sample of per-row histories, one panel per quantity.

    Parameters
    ----------
    qs : list of arrays
        Quantities to plot; ``q[row]`` is a history tabulated on ``t_table``.
        All are expected to have the same number of rows.
    t_table : array
        Times of the history points (plotted as Gyr, per the x-axis label).
    labels : list of str
        Short names joined with '_' into the output file name.
    ylabels : list of str
        y-axis label per panel; axes without a label are hidden.
    plot_label : str, optional
        Extra tag appended to the file name.
    color_array : 1-d array, optional
        Per-row values that set the plotting order and fill ``lgnd_label``.
        Defaults to the row index.
    row_mask : 1-d bool array, optional
        Rows to include (default: all rows).
    lgnd_label : str
        Format string for the legend entries of the first and last curves.
    pltname : str
        File-name template, formatted with (joined labels, step).
    yscale : str
        'log' puts SMH/SFH/SFR panels on a log y-axis and tags the file name.
    reverse : bool
        Plot in descending ``color_array`` order and print a coarse summary.
    check_step : int
        Stride of that summary.
    xlimlo, xlimhi : float
        x-axis limits.
    step : int
        Plot every ``step``-th row of the sorted selection.
    plotdir : str
        Output directory; created if it does not exist.
    """
    nrow, ncol = get_nrow_ncol(len(qs))
    # squeeze=False keeps ax_all 2-d so .flat also works for a single panel
    fig, ax_all = plt.subplots(nrow, ncol, figsize=(5*ncol, 4*nrow), squeeze=False)
    n_rows = len(qs[0])  # expect all q to have the same number of rows
    colors = cm.coolwarm(np.linspace(0, 1, n_rows))
    indices = np.arange(n_rows)
    if color_array is None:
        # nothing to colour by: order curves by row index
        color_array = indices
    if row_mask is None:
        row_mask = np.ones(n_rows, dtype=bool)
    sort_array = np.argsort(color_array[row_mask])
    if reverse:
        sort_array = sort_array[::-1]
        print(len(color_array[row_mask]), color_array[row_mask][sort_array][::check_step])

    # Legend values and colour indices are identical for every panel.
    c_plot = color_array[row_mask][sort_array][::step]
    # masked row indices increase monotonically, so the colour gradient follows plotting order
    i_plot = indices[row_mask][::step]

    for ax, q, ylabel in zip_longest(ax_all.flat, qs, ylabels):
        if ylabel is None:
            ax.set_visible(False)
            continue
        if len(color_array[row_mask]) != len(q[row_mask]):
            print('oops: array mismatch')
        q_plot = q[row_mask][sort_array][::step]
        # index of the last curve actually drawn (zip stops at the shortest sequence)
        n_last = min(len(q_plot), len(c_plot)) - 1
        for n, (h, c, i) in enumerate(zip(q_plot, c_plot, i_plot)):
            if n == 0 or n == n_last:
                # label only the first and last curves to show the range of color_array
                ax.plot(t_table, h, color=colors[i], label=lgnd_label.format(c))
            else:
                ax.plot(t_table, h, color=colors[i])

        ax.set_xlim(xlimlo, xlimhi)
        ax.set_ylabel(ylabel)
        if yscale == 'log' and ('SMH' in ylabel or 'SFH' in ylabel or 'SFR' in ylabel):
            ax.set_yscale('log')
        ax.set_xlabel('$t$ (Gyr)')
        ax.legend(loc='best')
    fig.suptitle('Sample Histories')

    xname = '_'.join(labels)
    if plot_label:
        xname = '_'.join([xname, plot_label])
    if yscale == 'log':
        xname = '_'.join([xname, yscale])
    if plotdir:
        os.makedirs(plotdir, exist_ok=True)
    fn = os.path.join(plotdir, pltname.format(xname, step))
    fig.tight_layout()
    fig.savefig(fn)
    print('Saving {}'.format(fn))

In [None]:
# All galaxies, ordered by row index (no colour quantity given); every 5000th history drawn.
plot_histories(qs, t_table=diffstar_cens['t_table'], labels=labels, ylabels=ylabels, step=5000)

In [None]:
# Histories of galaxies above a final (z=0) stellar-mass threshold, coloured by that mass.
logM_min = 8.0
log_mstar_z0 = np.log10(diffstar_cens['smh'][:, -1])
row_mask = log_mstar_z0 > logM_min
print(np.count_nonzero(row_mask))
# Spot-check the selection: first entry and entry 30000 (clamped to the sample size).
log_mstar_sel = log_mstar_z0[row_mask]
if log_mstar_sel.size:
    print(log_mstar_sel[0], log_mstar_sel[min(30000, log_mstar_sel.size - 1)])
plot_histories(qs, diffstar_cens['t_table'], labels, ylabels,
               color_array=log_mstar_z0, row_mask=row_mask, yscale='log',
               plot_label='Mstar_z0', step=500, lgnd_label='$\\log_{{10}}(M^*/M_\\odot) = {:.1f}$')

In [None]:
# Same mass-selected histories (row_mask defined above), coloured by final sSFR and drawn
# in descending order (reverse=True).
plot_histories(qs, diffstar_cens['t_table'], labels, ylabels,
               color_array=np.log10(diffstar_cens['sSFR'][:, -1]), row_mask=row_mask, yscale='log',
               reverse=True,
               plot_label='sSFR_z0', step=500, lgnd_label='$\\log_{{10}}(sSFR/yr^{{-1}}) = {:.1f}$')


In [None]:
from dsps.cosmology.defaults import DEFAULT_COSMOLOGY
from dsps.cosmology.flat_wcdm import lookback_to_z, age_at_z, age_at_z0
# Cosmology sanity check: lookback time to z=1 plus the age at z=1 is printed next to
# the present-day age so the two can be compared.
tl, a = (func(1.0, *DEFAULT_COSMOLOGY) for func in (lookback_to_z, age_at_z))
print(DEFAULT_COSMOLOGY, tl, a, tl + a, age_at_z0(*DEFAULT_COSMOLOGY))
# interpolate to invert t_table
def get_redshifts_from_times(t_table, cosmo_params, zmin=.001, zmax=50, Ngrid=200, zcheck=3):
    """Return the redshift corresponding to each cosmic time in ``t_table``.

    Ages are tabulated with ``age_at_z`` on a log-spaced redshift grid and inverted
    by linear interpolation. As a round-trip check, ages recomputed at the
    interpolated redshifts are compared with ``t_table`` and the result for
    z <= zcheck is printed.

    Parameters
    ----------
    t_table : array
        Cosmic times (same units as ``age_at_z`` returns) to convert.
    cosmo_params : sequence
        Cosmological parameters unpacked into ``age_at_z``.
    zmin, zmax : float
        Redshift range of the interpolation grid. Times outside the tabulated age
        range are clamped by ``np.interp`` to zmax (too early) or zmin (too late).
    Ngrid : int
        Number of grid points.
    zcheck : float
        Upper redshift for which the round-trip check is reported.

    Returns
    -------
    redshifts : ndarray
        One redshift per entry of ``t_table``.
    """
    # zgrid runs from zmax down to zmin so that age_grid increases, as np.interp requires.
    zgrid = np.logspace(np.log10(zmax), np.log10(zmin), Ngrid)
    age_grid = age_at_z(zgrid, *cosmo_params)
    redshifts = np.interp(t_table, age_grid, zgrid)
    mask = (redshifts <= zcheck)
    t_interp = age_at_z(redshifts, *cosmo_params)
    # Compare with the t_table argument itself, not a notebook global.
    check = np.isclose(t_interp, t_table, atol=1e-3, rtol=1e-3)
    print("Check within 1e-3 for z< {}: {}".format(zcheck, np.all(check[mask])))
    return redshifts

# Redshift of every time-table entry, plus the analysis window centres
# (0.0, 0.5, ..., 2.5) and a default half-width.
redshifts = get_redshifts_from_times(diffstar_cens['t_table'], DEFAULT_COSMOLOGY)
zvalues = [0.5 * k for k in range(6)]
dz = 0.1


In [None]:
def plot_q(q, zvalues, redshifts, cut_array, cuts, dz=0.1,
           cut_labels=['{{}} $\\leq$ {:.0f}', '{{}} $\\geq$ {:.0f}'],
           colors = ['r', 'blue'], xlabel='B/T', cut_name='$\\log_{10}(sSFR/yr^{-1})$',
           pltname='BT_cut_on_{}.png', yscale='log', xscale='', bins=50, xname='log_sSFR',
           plotdir = os.path.join(plotdir, 'DiskBulge_Histograms'), lgnd_title=''):
    """Histogram ``q`` in redshift windows, split into samples by ``cut_array``.

    For each centre in ``zvalues`` the columns with ``z - dz <= redshift <= z + dz``
    are selected. Within that window, each element of ``q`` goes into the first
    sample when the matching element of ``cut_array`` is <= cuts[0], and into a
    later sample when it is >= that sample's cut. The samples are overlaid.

    Parameters
    ----------
    q, cut_array : 2-d arrays
        Same shape; columns are matched to ``redshifts``.
    zvalues : sequence of float
        Window centres, one panel each; surplus axes are hidden.
    redshifts : 1-d array
        Redshift of each column.
    cuts : sequence of float
        Thresholds (first an upper limit, later ones lower limits).
    dz : float
        Half-width of each redshift window.
    cut_labels, colors : sequences
        Legend templates and colours, zipped with ``cuts``.
    xlabel, cut_name, lgnd_title : str
        x-axis label, name inserted into the legend labels, suffix for the legend title.
    pltname, xname : str
        Output file name is ``pltname.format(xname)``.
    yscale, xscale : str
        'log' for logarithmic axes.
    bins : int or array
        Passed to ``ax.hist``.
    plotdir : str
        Output directory (relative component joined onto the global plotdir); created if missing.
    """
    nrow, ncol = get_nrow_ncol(len(zvalues))
    # squeeze=False keeps ax_all 2-d so .flat also works for a single panel
    fig, ax_all = plt.subplots(nrow, ncol, figsize=(5*ncol, 4*nrow), squeeze=False)

    for ax, z in zip_longest(ax_all.flat, zvalues):
        if z is None:
            # more axes than redshift windows
            ax.set_visible(False)
            continue
        zmask = (z-dz <= redshifts) & (z+dz >= redshifts)
        zlabel = '${:.1f} \\leq z \\leq {:.1f}$'.format(max(0., z-dz), z+dz)
        # restrict both arrays to the columns inside the redshift window
        q_z = q[:, zmask]
        cut_array_z = cut_array[:, zmask]
        for n, (cut, cut_label, color) in enumerate(zip(cuts, cut_labels, colors)):
            # element-wise selection: first cut is an upper limit, later cuts lower limits
            cut_mask = (cut_array_z <= cut) if n == 0 else (cut_array_z >= cut)
            label = cut_label.format(cut).format(cut_name)
            ax.hist(q_z[cut_mask], bins=bins, color=color, label=label, alpha=0.4)
        if yscale == 'log':
            ax.set_yscale('log')
        if xscale == 'log':
            ax.set_xscale('log')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('$N$')
        ax.legend(loc='best', title=zlabel+lgnd_title)

    if plotdir:
        os.makedirs(plotdir, exist_ok=True)
    fn = os.path.join(plotdir, pltname.format(xname))
    fig.tight_layout()
    fig.savefig(fn)
    print('Saving {}'.format(fn))
    return

In [None]:
# setup mass mask
logMz0_min = 8.0
mass_mask = np.log10(diffstar_cens['smh'][:, -1]) > logMz0_min
lgnd_title=', $\\log_{{10}}(M^*_{{z=0}}/M_\\odot) > {:.1f}$'.format(logMz0_min)

In [None]:
# Quick check of the element-wise sSFR masking used by plot_q, for one window (z < 0.2).
zmask = redshifts < 0.2
q_z = bth[mass_mask][:, zmask]
log_ssfr_all = np.log10(diffstar_cens['sSFR'])
cut_array_z = log_ssfr_all[mass_mask][:, zmask]
cut_mask = cut_array_z <= -11
print(cut_mask.shape, q_z.shape)
print(np.count_nonzero(cut_mask), q_z[cut_mask].shape)

In [None]:
# B/T distributions in each redshift window, split at log sSFR <= -11 and >= -10.
plot_q(bth[mass_mask], zvalues, redshifts,
       cut_array=np.log10(diffstar_cens['sSFR'])[mass_mask], cuts=[-11, -10],
       dz=0.2, lgnd_title=lgnd_title)
       


In [None]:
# Bulge-efficiency distributions with the same sSFR split and mass selection.
plot_q(eff_bulge[mass_mask], zvalues, redshifts,
       cut_array=np.log10(diffstar_cens['sSFR'])[mass_mask], cuts=[-11, -10], dz=0.2,
       pltname='effB_cut_on_{}.png', xlabel='Bulge Efficiency', lgnd_title=lgnd_title)

In [None]:
# Bulge stellar-mass distributions for a higher final-mass threshold, on log-spaced bins
# between 1e7 and 1e12.
logMz0_min = 8.5
mass_mask = logMz0_min < np.log10(diffstar_cens['smh'][:, -1])
for _shape in (mass_mask.shape, smh_bulge[mass_mask].shape):
    print(_shape)
bins = np.logspace(7, 12, 51)
lgnd_title = ', $\\log_{{10}}(M^*_{{z=0}}/M_\\odot) > {:.1f}$'.format(logMz0_min)

plot_q(smh_bulge[mass_mask], zvalues, redshifts,
       cut_array=np.log10(diffstar_cens['sSFR'])[mass_mask], cuts=[-11, -10], dz=0.2,
       pltname='bulge_mass_cut_on_{}.png', xlabel='Bulge Mass ($M_\\odot$)',
       xscale='log', bins=bins, lgnd_title=lgnd_title)

In [None]:
def plot_q1_q2(q1, q2, zvalues, redshifts, cut_array, cut_lo, cut_hi, dz=0.1,
           cut_label='{:.1f} $\\leq$ {{}} $\\leq$ {:.1f}', qlabels=['Bulge', 'Disk'],
           colors = ['r', 'blue'], xlabel='sSFR $(yr^{-1})$', cut_name='$\\log_{10}(M^*_{z=0}/M_\\odot)$',
           pltname='log_sSFR_cut_on_{}.png', yscale='log', xscale='log', bins=50, xname='log_M0_{:.1f}_{:.1f}',
           plotdir = os.path.join(plotdir, 'DiskBulge_Histograms'), lgnd_title=''):
    """Overlay histograms of ``q1`` and ``q2`` in redshift windows for a row selection.

    Rows are selected once with ``cut_lo <= cut_array[:, -1] < cut_hi``, i.e. on the
    final column. This assumes the last column is z=0, as the default ``cut_name`` and
    ``xname`` state and as ``[:, -1]`` is used elsewhere in this notebook. Every panel
    therefore shows the same rows. In each panel, all values of the selected rows
    within ``z - dz <= redshift <= z + dz`` are histogrammed.

    Parameters
    ----------
    q1, q2, cut_array : 2-d arrays
        Same shape; columns are matched to ``redshifts``.
    zvalues : sequence of float
        Window centres, one panel each; surplus axes are hidden.
    redshifts : 1-d array
        Redshift of each column.
    cut_lo, cut_hi : float
        Half-open selection range on the final column of ``cut_array``.
    dz : float
        Half-width of each redshift window.
    cut_label, cut_name : str
        Build the selection line of the legend title.
    qlabels, colors : sequences
        Legend labels and colours for q1 and q2.
    xlabel : str
        x-axis label.
    pltname, xname : str
        Output file name is ``pltname.format(xname.format(cut_lo, cut_hi))``.
    yscale, xscale : str
        'log' for logarithmic axes.
    bins : int or array
        Passed to ``ax.hist``. When bin edges are given, the x-range is zoomed to the
        data (padded by x0.5 / x2) and clipped to the edges.
    plotdir : str
        Output directory; created if missing.
    lgnd_title : str
        Suffix appended to the redshift line of the legend title.
    """
    nrow, ncol = get_nrow_ncol(len(zvalues))
    # squeeze=False keeps ax_all 2-d so .flat also works for a single panel
    fig, ax_all = plt.subplots(nrow, ncol, figsize=(5*ncol, 4*nrow), squeeze=False)

    clabel = cut_label.format(cut_lo, cut_hi).format(cut_name)
    final_vals = cut_array[:, -1]
    row_cut = (final_vals >= cut_lo) & (final_vals < cut_hi)
    # zooming the x-range only makes sense when explicit bin edges were supplied
    has_edges = np.ndim(bins) > 0
    for ax, z in zip_longest(ax_all.flat, zvalues):
        if z is None:
            ax.set_visible(False)
            continue
        zmask = (z-dz <= redshifts) & (z+dz >= redshifts)
        zlabel = '${:.1f} \\leq z \\leq {:.1f}$'.format(max(0., z-dz), z+dz)
        # selected rows, columns inside the window, flattened row by row
        q1_sel = q1[row_cut][:, zmask].ravel()
        q2_sel = q2[row_cut][:, zmask].ravel()
        for q_sel, qlabel, color in zip([q1_sel, q2_sel], qlabels, colors):
            ax.hist(q_sel, bins=bins, color=color, label=qlabel, alpha=0.4)

        if yscale == 'log':
            ax.set_yscale('log')
        if xscale == 'log':
            ax.set_xscale('log')
        # x-range spans BOTH quantities; non-finite values (e.g. 0/0 sSFR) are ignored
        both = np.concatenate([q1_sel, q2_sel])
        both = both[np.isfinite(both)]
        if has_edges and both.size:
            ax.set_xlim(max(np.min(bins), np.min(both)*0.5),
                        min(np.max(bins), np.max(both)*2.))
        ax.set_xlabel(xlabel)
        ax.set_ylabel('$N$')
        ax.legend(loc='best', title='\n'.join([zlabel + lgnd_title, clabel]))

    if plotdir:
        os.makedirs(plotdir, exist_ok=True)
    fn = os.path.join(plotdir, pltname.format(xname.format(cut_lo, cut_hi)))
    fig.tight_layout()
    fig.savefig(fn)
    print('Saving {}'.format(fn))
    return
    

In [None]:
# sSFR bin edges and stellar-mass bin edges for the bulge/disk comparison below.
bins = np.logspace(-14, -7, 71)
mass_bins = np.linspace(8.5, 11.5, 4)
print(mass_bins, np.min(bins), np.max(bins))
# Check that broadcasting a 1-d row mask to 2-d selects whole rows (window z < 0.2).
zmask = redshifts < 0.2
q1_z = diffstar_cens['sSFR_bulge'][:, zmask]
print(q1_z.shape)
cut_array_z = np.log10(diffstar_cens['smh'])[:, zmask]
last_col = cut_array_z[:, -1]
cut_mask = (8.5 <= last_col) & (last_col < 11.5)
print(cut_mask.shape, np.flatnonzero(~cut_mask)[:2])
cut_mask = np.broadcast_to(cut_mask, (np.count_nonzero(zmask), len(cut_mask))).T
print(cut_mask.shape)
q1 = q1_z[cut_mask]
print(q1.shape)

In [None]:
# One figure per stellar-mass bin (consecutive mass_bins edges): bulge vs. disk sSFR.
log_smh = np.log10(diffstar_cens['smh'])
for m_lo, m_hi in zip(mass_bins[:-1], mass_bins[1:]):
    plot_q1_q2(diffstar_cens['sSFR_bulge'], diffstar_cens['sSFR_disk'], zvalues, redshifts,
               log_smh, m_lo, m_hi, dz=0.2, bins=bins)