In [None]:
import os
import qp
import jax
import matplotlib.pyplot as plt
import seaborn as sns
from jax import numpy as jnp

import pandas as pd
import tables_io
import numpy as np

from jax import vmap
from jax.scipy.special import gamma as jgamma
import pickle

In [None]:
# Absolute paths (resolved against the current working directory) of the three
# pickled SHIRE priors compared in this notebook.
path_prior_faint, path_prior_bright, path_prior_ref = map(
    os.path.abspath,
    (
        'model_shireSPS_inform_lsstSimhp10552_demo_splithi.pkl',   # used as the FAINT prior
        'model_shireSPS_inform_lsstSimhp10552_demo_splitlow.pkl',  # used as the BRIGHT prior
        'model_shireSPS_inform_lsstSimhp10552_demo.pkl',           # reference prior ("all LSST sim")
    ),
)

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file handle is closed before returning.
    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


model_prior_faint = _load_pickle(path_prior_faint)
model_prior_bright = _load_pickle(path_prior_bright)
model_prior_ref = _load_pickle(path_prior_ref)

In [None]:
# BPZ prior parameters, labelled "BPZ-2000" in the plots (presumably the
# Benitez 2000 values -- verify). Per-type arrays are ordered E_S0, Sbc/Scd, Irr.
# The plotting cells compute the Irr fraction as ``1 - np.sum(fo_arr)``, so
# 'fo_arr' only lists E_S0 and Sbc/Scd (the implied Irr fraction is 0.15).
# 'kt_arr' likewise only covers the first two types.
bpz_model = {
    'fo_arr': jnp.array([0.35, 0.5]),
    'kt_arr': jnp.array([0.147, 0.450]),
    'zo_arr': jnp.array([0.431, 0.39, 0.063]),
    'a_arr': jnp.array([2.46, 1.81, 0.91]),
    'km_arr': jnp.array([0.091, 0.0636, 0.123]),
    'mo': 20.0,  # single reference magnitude shared by all types
    'nt_array': jnp.array([1, 2, 3])
}

In [None]:
# Root of the rail_base source tree. Set the RAIL_BASE_SRC environment variable
# to point elsewhere instead of editing this user-specific absolute path.
DATDIR = os.environ.get("RAIL_BASE_SRC", "/global/u2/j/jcheval/rail_base/src")
cosmospriorfile = os.path.join(DATDIR, "rail", "examples_data", "estimation_data", "data", "COSMOS31_HDFN_prior.pkl")
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(cosmospriorfile, 'rb') as _f:
    cosmos_prior_dict = pickle.load(_f)
# Set (or overwrite) these entries; the plotting cells read 'mo' as a single
# scalar reference magnitude shared by all types.
# NOTE(review): meaning of the 10/5/16 values is not established here --
# presumably number of templates per broad type; confirm.
cosmos_prior_dict['nt_array'] = jnp.array([10, 5, 16])
cosmos_prior_dict['mo'] = 20.0

In [None]:
f, a = plt.subplots(nrows=1, ncols=1)  # single-axes figure for the type-fraction curves

def frac_func(X, m0, m):
    """Fraction of galaxies of one type at magnitude ``m``.

    Computes ``fo * exp(-kt * (max(m, m0) - m0))``: magnitudes brighter than
    ``m0`` are clamped to ``m0``, where the fraction equals ``fo``.

    Parameters
    ----------
    X : tuple
        ``(fo, kt)``, the fraction at ``m0`` and its exponential decay rate.
    m0 : float
        Reference magnitude.
    m : float or array
        Magnitude(s) at which to evaluate the fraction.
    """
    _m = jnp.where(m < m0, m0, m)
    fo, kt = X
    return fo * jnp.exp(-kt * (_m - m0))


def kt3(fo_arr, kt_arr, m0, m):
    """Decay rate of the third (Irr) type that keeps the three fractions summing to one.

    Solves ``fo_irr * exp(-kt3 * dm) = 1 - fo_arr[0] exp(-kt_arr[0] dm)
    - fo_arr[1] exp(-kt_arr[1] dm)`` for ``kt3``, with ``dm = max(m, m0) - m0``
    and ``fo_irr = 1 - fo_arr[0] - fo_arr[1]``. Only the first two entries of
    ``fo_arr``/``kt_arr`` are read, so 2-element (E_S0, Sbc/Scd) and 3-element
    arrays give the same result.

    At ``dm == 0`` the closed form is 0/0. The analytic limit
    ``-(fo_arr[0] kt_arr[0] + fo_arr[1] kt_arr[1]) / fo_irr`` is returned instead.
    """
    _m = jnp.where(m < m0, m0, m)
    dm = _m - m0
    # Irr fraction at m0: the remainder of the first two types.
    # fo_arr[-1] would be the Sbc fraction for a 2-element fo_arr.
    fo_irr = 1.0 - fo_arr[0] - fo_arr[1]
    remainder = 1.0 - fo_arr[0] * jnp.exp(-kt_arr[0] * dm) - fo_arr[1] * jnp.exp(-kt_arr[1] * dm)
    at_m0 = -(fo_arr[0] * kt_arr[0] + fo_arr[1] * kt_arr[1]) / fo_irr
    # Divide by a dummy 1.0 where dm == 0 so the unused branch stays finite.
    safe_dm = jnp.where(dm > 0, dm, 1.0)
    return jnp.where(dm > 0, -jnp.log(remainder / fo_irr) / safe_dm, at_m0)

refmags = jnp.linspace(16, 26, 1000)

# SHIRE priors compared here, in plotting order: (prior dict, label suffix, line style).
_shire_curves = [
    (model_prior_faint, " FAINT -- LSST sim", {'ls': ':'}),
    (model_prior_bright, " BRIGHT -- LSST sim", {'ls': '--'}),
    (model_prior_ref, " all LSST sim", {'lw': 3}),
]

for _prior, _suffix, _style in _shire_curves:
    # SHIRE priors store one reference magnitude per type in 'mo'.
    _frac_e = frac_func((_prior['fo_arr'][0], _prior['kt_arr'][0]), _prior['mo'][0], refmags)
    _frac_s = frac_func((_prior['fo_arr'][1], _prior['kt_arr'][1]), _prior['mo'][1], refmags)
    # The Irr fraction is the remainder, so the three fractions sum to one.
    _fracs = (_frac_e, _frac_s, 1.0 - _frac_e - _frac_s)
    for typ, _c, frac in zip(["E_S0", "Sbc/Scd", "Irr"], ['tab:blue', 'tab:orange', 'tab:green'], _fracs):
        a.plot(refmags, frac, label=typ + _suffix, c=_c, **_style)

default = frac_func((1/3, 0.3), 20., refmags)
a.plot(refmags, default, label='default', c='k')
a.set_xlabel('AB magnitude in i-band')
a.set_ylabel('Fraction of galaxies of a given type')
a.legend();

In [None]:
f, a = plt.subplots(1, 1)

# frac_func is defined in the previous cell. It is not redefined here, so both
# fraction plots use exactly the same implementation.
refmags = jnp.linspace(16, 26, 1000)

# Priors compared here, in plotting order: (prior dict, label suffix, line style).
_frac_curves = [
    (cosmos_prior_dict, " COSMOS (RAIL)", {'ls': '--'}),
    (bpz_model, " BPZ-2000", {'ls': ':'}),
    (model_prior_ref, " SHIRE (all LSST sim)", {'lw': 3}),
]

for _prior, _suffix, _style in _frac_curves:
    # SHIRE priors store one reference magnitude per type; BPZ/COSMOS store a single scalar.
    _m0s = [_prior['mo'][i] if np.ndim(_prior['mo']) else _prior['mo'] for i in range(2)]
    _frac_e = frac_func((_prior['fo_arr'][0], _prior['kt_arr'][0]), _m0s[0], refmags)
    _frac_s = frac_func((_prior['fo_arr'][1], _prior['kt_arr'][1]), _m0s[1], refmags)
    # The Irr fraction is the remainder, so the three fractions sum to one.
    _fracs = (_frac_e, _frac_s, 1.0 - _frac_e - _frac_s)
    for typ, _c, frac in zip(["E_S0", "Sbc/Scd", "Irr"], ['tab:blue', 'tab:orange', 'tab:green'], _fracs):
        a.plot(refmags, frac, label=typ + _suffix, c=_c, **_style)

default = frac_func((1/3, 0.3), 20., refmags)
a.plot(refmags, default, label='default', c='k')
a.set_xlabel('AB magnitude in i-band')
a.set_ylabel('Fraction of galaxies of a given type')
a.legend(loc='lower left', bbox_to_anchor=(1., 0.));

In [None]:
def nz_func(m, z, z0, alpha, km, m0):  # pragma: no cover
    """Normalised redshift prior of one galaxy type at magnitude ``m``.

    Returns ``z**alpha * exp(-(z / zm)**alpha) / I`` where
    ``zm = z0 + km * (max(m, m0) - m0)`` and
    ``I = zm**(alpha + 1) * Gamma(1 + 1/alpha) / alpha`` is the integral of the
    numerator over ``z`` in ``[0, inf)``, so the density has unit area.
    Magnitudes brighter than ``m0`` are clamped to ``m0``.
    """
    clamped_mag = jnp.where(m < m0, m0, m)
    z_scale = z0 + (km * (clamped_mag - m0))
    normalisation = jnp.power(z_scale, alpha + 1) * jgamma(1 + 1 / alpha) / alpha
    unnormalised = jnp.power(z, alpha) * jnp.exp(-jnp.power(z / z_scale, alpha))
    return unnormalised / normalisation

# Evaluate nz_func on a vector of redshifts: only the second argument (z) is
# mapped over its leading axis; every other argument is shared.
vmap_dndz_z = vmap(nz_func, in_axes=(None, 0) + (None,) * 4)

pzs = jnp.linspace(0.01, 3.1, 301)

# Priors compared below, in plotting order:
# (prior dict, left-panel label suffix, left-panel style, right-panel label, right-panel style).
_prior_specs = [
    (model_prior_ref, " SHIRE (all LSST sim)", {'lw': 3},
     "SHIRE (all LSST sim)", {'lw': 3, 'c': 'blue'}),
    (bpz_model, " BPZ-2000", {'ls': ':'},
     "BPZ-2000", {'ls': ':', 'lw': 2, 'c': 'red'}),
    (cosmos_prior_dict, " COSMOS (RAIL)", {'ls': '--'},
     "COSMOS (RAIL)", {'ls': '--', 'lw': 2, 'c': 'green'}),
]
_type_names = ["E_S0", "Sbc/Scd", "Irr"]
_type_colors = ['tab:blue', 'tab:orange', 'tab:green']

for m in np.arange(19, 27, 1):
    f, a = plt.subplots(1, 2, figsize=(12, 5))

    _vals_by_prior = []      # per prior: unweighted n(z | type, m) for the 3 types
    _marginal_by_prior = []  # per prior: sum over types of fraction(type, m) * n(z | type, m)
    for _prior, _, _, _, _ in _prior_specs:
        # SHIRE priors store one reference magnitude per type; BPZ/COSMOS a single scalar.
        _m0s = [_prior['mo'][i] if np.ndim(_prior['mo']) else _prior['mo'] for i in range(3)]
        _frac_e = frac_func((_prior['fo_arr'][0], _prior['kt_arr'][0]), _m0s[0], m)
        _frac_s = frac_func((_prior['fo_arr'][1], _prior['kt_arr'][1]), _m0s[1], m)
        # The Irr fraction is the remainder, so the three fractions sum to one.
        _fracs = (_frac_e, _frac_s, 1.0 - _frac_e - _frac_s)
        _vals = [
            vmap_dndz_z(m, pzs, _prior['zo_arr'][i], _prior['a_arr'][i], _prior['km_arr'][i], _m0s[i])
            for i in range(3)
        ]
        # Non-finite weighted values are zeroed before summing over types.
        _weighted = [jnp.where(jnp.isfinite(v * fr), v * fr, 0.) for v, fr in zip(_vals, _fracs)]
        _vals_by_prior.append(_vals)
        _marginal_by_prior.append(_weighted[0] + _weighted[1] + _weighted[2])

    # Left panel: per-type priors (not weighted by the type fractions), grouped by type.
    for ityp, (typ, _c) in enumerate(zip(_type_names, _type_colors)):
        for (_, _suffix, _style, _, _), _vals in zip(_prior_specs, _vals_by_prior):
            a[0].plot(pzs, _vals[ityp], label=typ + _suffix, c=_c, **_style)

    valdefault = vmap_dndz_z(m, pzs, 0.4, 1.8, 0.1, 20.0)
    fracdef = frac_func((1/3, 0.3), 20.0, m)
    a[0].plot(pzs, valdefault, c='k', label='Default')
    a[0].legend()

    # Right panel: type-marginalised priors, each normalised to unit area on pzs.
    normdef = jnp.trapezoid(valdefault*fracdef, x=pzs)
    a[1].plot(pzs, valdefault*fracdef/normdef, c='k', label='Default')
    for (_, _, _, _label, _style), _marginal in zip(_prior_specs, _marginal_by_prior):
        a[1].plot(pzs, _marginal / jnp.trapezoid(_marginal, x=pzs), label=_label, **_style)
    a[1].legend()

    a[0].set_title('Priors for 3 categories of galaxies')
    a[1].set_title('Marginalised prior distributions (sum on galaxy types)')
    a[0].set_xlabel('Redshift z')
    a[1].set_xlabel('Redshift z')
    a[0].set_ylabel('PDF')
    a[1].set_ylabel('PDF')
    f.suptitle(f'Comparison of prior distributions at m={m:.2f}')
    # Render this figure now rather than keeping every figure of the loop open.
    plt.show()

In [None]:
# SHIRE priors compared below, in plotting order:
# (prior dict, left-panel label suffix, left-panel style, right-panel label, right-panel style).
_split_specs = [
    (model_prior_ref, " LSST sim (all)", {'lw': 3},
     "all LSST sim", {'lw': 3, 'c': 'k'}),
    (model_prior_faint, " LSST sim FAINT", {'ls': ':'},
     "FAINT -- LSST sim", {'ls': ':', 'lw': 2, 'c': 'k'}),
    (model_prior_bright, " LSST sim BRIGHT", {'ls': '--'},
     "BRIGHT -- LSST sim", {'ls': '--', 'lw': 2, 'c': 'k'}),
]
_split_type_names = ["E_S0", "Sbc/Scd", "Irr"]
_split_type_colors = ['tab:blue', 'tab:orange', 'tab:green']

for m in np.arange(19, 27, 1):
    f, a = plt.subplots(1, 2, figsize=(12, 5))

    _vals_by_prior = []      # per prior: unweighted n(z | type, m) for the 3 types
    _marginal_by_prior = []  # per prior: sum over types of fraction(type, m) * n(z | type, m)
    for _prior, _, _, _, _ in _split_specs:
        # SHIRE priors store one reference magnitude per type; a scalar is also accepted.
        _m0s = [_prior['mo'][i] if np.ndim(_prior['mo']) else _prior['mo'] for i in range(3)]
        _frac_e = frac_func((_prior['fo_arr'][0], _prior['kt_arr'][0]), _m0s[0], m)
        _frac_s = frac_func((_prior['fo_arr'][1], _prior['kt_arr'][1]), _m0s[1], m)
        # The Irr fraction is the remainder, so the three fractions sum to one.
        _fracs = (_frac_e, _frac_s, 1.0 - _frac_e - _frac_s)
        _vals = [
            vmap_dndz_z(m, pzs, _prior['zo_arr'][i], _prior['a_arr'][i], _prior['km_arr'][i], _m0s[i])
            for i in range(3)
        ]
        # Non-finite weighted values are zeroed before summing over types.
        _weighted = [jnp.where(jnp.isfinite(v * fr), v * fr, 0.) for v, fr in zip(_vals, _fracs)]
        _vals_by_prior.append(_vals)
        _marginal_by_prior.append(_weighted[0] + _weighted[1] + _weighted[2])

    # Left panel: per-type priors (not weighted by the type fractions), grouped by type.
    for ityp, (typ, _c) in enumerate(zip(_split_type_names, _split_type_colors)):
        for (_, _suffix, _style, _, _), _vals in zip(_split_specs, _vals_by_prior):
            a[0].plot(pzs, _vals[ityp], label=typ + _suffix, c=_c, **_style)

    valdefault = vmap_dndz_z(m, pzs, 0.4, 1.8, 0.1, 20.0)
    fracdef = frac_func((1/3, 0.3), 20.0, m)
    a[0].plot(pzs, valdefault, c='k', label='Default')
    a[0].legend()

    # Right panel: type-marginalised priors, each normalised to unit area on pzs.
    normdef = jnp.trapezoid(valdefault*fracdef, x=pzs)
    a[1].plot(pzs, valdefault*fracdef/normdef, c='k', label='Default')
    for (_, _, _, _label, _style), _marginal in zip(_split_specs, _marginal_by_prior):
        a[1].plot(pzs, _marginal / jnp.trapezoid(_marginal, x=pzs), label=_label, **_style)
    a[1].legend()

    a[0].set_title('Priors for 3 categories of galaxies')
    a[1].set_title('Marginalised prior distributions (sum on galaxy types)')
    a[0].set_xlabel('Redshift z')
    a[1].set_xlabel('Redshift z')
    a[0].set_ylabel('PDF')
    a[1].set_ylabel('PDF')
    f.suptitle(f'Comparison of prior distributions at m={m:.2f}')
    # Render this figure now rather than keeping every figure of the loop open.
    plt.show()