In [None]:
import os
import qp
import jax
import matplotlib.pyplot as plt
import seaborn as sns
from jax import numpy as jnp

import pandas as pd
import tables_io
import numpy as np

from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.core.common_params import SHARED_PARAMS

from rail.shire import ShireInformer, ShireEstimator, hist_outliers, plot_zp_zs_ensemble

# Switch on JAX's "jax_enable_x64" flag before any array is created.
jax.config.update("jax_enable_x64", True)

## Select and load data into the datastore

In [None]:
# Handle on the data store shared by the RAIL stages; overwriting existing
# entries is allowed (presumably so that cells can be re-run — confirm).
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

In [None]:
# Training and test catalogues, resolved to absolute paths under ./data.
# Alternatives tried previously (kept for reference):
#   train: os.path.abspath(os.path.join('./data', 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5'))
#          os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
#   test:  os.path.abspath(os.path.join('../..', 'magszgalaxies_lsstroman_gold_hp10552.h5'))
#          os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')
_data_dir = os.path.join('.', 'data')
trainFile = os.path.abspath(
    os.path.join(_data_dir, 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5')
)
testFile = os.path.abspath(
    os.path.join(_data_dir, 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5')
)

In [None]:
# Limiting magnitude and the column it is applied to.
mag_lim = 24
mag_col = "mag_i_lsst"

# Training catalogue, ordered by magnitude and split at mag_lim:
# "bright" keeps mag <= mag_lim, "faint" keeps mag > mag_lim.
traindf = pd.read_hdf(trainFile).sort_values(mag_col, axis=0)
_train_mags = traindf[mag_col]
train_bright = traindf[_train_mags <= mag_lim]
train_faint = traindf[_train_mags > mag_lim]

# Same treatment for the test catalogue.
testdf = pd.read_hdf(testFile).sort_values(mag_col, axis=0)
_test_mags = testdf[mag_col]
test_bright = testdf[_test_mags <= mag_lim]
test_faint = testdf[_test_mags > mag_lim]

In [None]:
def _abs_data_path(filename):
    """Return the absolute path of ``filename`` inside the local ./data folder."""
    return os.path.abspath(os.path.join('.', 'data', filename))


# Destination files for the magnitude-split catalogues.
trainFile_bright = _abs_data_path('train_lsstroman_gold_hp10552_bright.h5')
testFile_bright = _abs_data_path('test_lsstroman_gold_hp10552_bright.h5')
trainFile_faint = _abs_data_path('train_lsstroman_gold_hp10552_faint.h5')
testFile_faint = _abs_data_path('test_lsstroman_gold_hp10552_faint.h5')

# Write every subset under the "photometry" key; mode="w" replaces any
# file left over from a previous run.
for _subset, _destination in (
    (train_bright, trainFile_bright),
    (test_bright, testFile_bright),
    (train_faint, trainFile_faint),
    (test_faint, testFile_faint),
):
    _subset.to_hdf(_destination, key="photometry", mode="w")

In [None]:
# Register the four split files in the data store as TableHandles.
# Naming caveat: the *bright* (mag <= mag_lim) files are bound to the
# "*_lowz" variables and the *faint* (mag > mag_lim) ones to "*_hiz".
train_lowz = DS.read_file("training_data_bright", TableHandle, trainFile_bright)
test_lowz = DS.read_file("test_data_bright", TableHandle, testFile_bright)

train_hiz = DS.read_file("training_data_faint", TableHandle, trainFile_faint)
test_hiz = DS.read_file("test_data_faint", TableHandle, testFile_faint)

In [None]:
from rail.creation.degraders.quantityCut import QuantityCut

# Disabled alternative (never executed because of `if False`): apply the
# magnitude cut with the QuantityCut stage instead of the pandas split.
# NOTE(review): `train_data` and `test_data` are not defined anywhere in the
# visible notebook, so this branch would raise NameError if it were enabled.
if False:
    train_cut = QuantityCut.make_stage(name="train_cuts", cuts={mag_col: mag_lim})
    test_cut = QuantityCut.make_stage(name="test_cuts", cuts={mag_col: mag_lim})
    train_lowz = train_cut(train_data)
    test_lowz = test_cut(test_data)

In [None]:
# Every LSST band tag ("u_lsst" ... "y_lsst") points to the same filter set.
lsst_filts_dict = dict.fromkeys((f"{_n}_lsst" for _n in "ugrizy"), "filt_lsst")

# Magnitude and magnitude-error column names, in the same band order.
_bands = ["mag_" + _k for _k in lsst_filts_dict]
_errbands = ["mag_err_" + _k for _k in lsst_filts_dict]

## Inform the estimator, i.e. select a subset of galaxies as templates

In [None]:
# Options shared by both informer stages below (they differ only in
# name, output file and templ_type).
default_dict_inform = {
    "hdf5_groupname": "photometry",
    "data_path": "./data",
    "bands": _bands,
    "err_bands": _errbands,
    "spectra_file": "dsps_valid_fits_F2SM3_GG_DESI.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": lsst_filts_dict,
    # Wavelength grid bounds and step.
    "wlmin": 500.,
    "wlmax": 12000.,
    "dwl": 5.,
    "colrsbins": 60,
    # The magnitude cut doubles as the initial m0 of the informer.
    "init_m0": mag_lim,
}

### Prepare two informers: one for each method 'SPS' or 'Legacy'
- 'SPS' recomputes an SED at every redshift based on the star-formation history of the template galaxy before synthesizing the colours for every value of $z$ along the grid
- 'Legacy' computes the SED once at the template galaxy's redshift and uses it to synthesize colours at all $z$ values with the usual transformation $\lambda_\mathrm{obs} = (1+z)\lambda_\mathrm{em}$

In [None]:
# Informer using templ_type="SPS"; shares every other option with the
# Legacy informer through default_dict_inform.
run_shire_inform_sps = ShireInformer.make_stage(
    name="shireSPS_inform_lsstSimhp10552_demo_splithi",
    output="shireSPS_templates_lsstSimhp10552_demo_splithi.hf5",
    **default_dict_inform,
    templ_type="SPS"
)

# Informer using templ_type="Legacy", with its own stage name and output file.
run_shire_inform_legacy = ShireInformer.make_stage(
    name="shireLEG_inform_lsstSimhp10552_demo_splithi",
    output="shireLEG_templates_lsstSimhp10552_demo_splithi.hf5",
    **default_dict_inform,
    templ_type="Legacy"
)

### Inform the 'SPS' templates

In [None]:
%%time
# Inform the SPS stage with the faint (mag > mag_lim) training handle.
run_shire_inform_sps.inform(train_hiz)

In [None]:
# Colour-colour diagram (g-r vs r-i) of the SPS informer's classification
# table, coloured by the "CAT_NUVK" class and sized by "z_p".
# NOTE(review): _nuvk_classif is a private method of the informer; presumably
# it returns one row per template — confirm.
all_templs_df_sps = run_shire_inform_sps._nuvk_classif()
f, a = plt.subplots(1,1)
sns.scatterplot(
    data=all_templs_df_sps, x="g_lsst-r_lsst", y="r_lsst-i_lsst",
    hue="CAT_NUVK", hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', alpha=0.5,
    size='z_p', sizes=(10, 100),
    ax=a
)
a.grid()

In [None]:
# Plot the template SEDs held by the SPS informer.
run_shire_inform_sps.plot_templ_seds()

### Inform the 'Legacy' templates

In [None]:
%%time
# Inform the Legacy stage with the same faint training handle as the SPS one.
run_shire_inform_legacy.inform(train_hiz)

In [None]:
# Plot the template SEDs held by the Legacy informer.
run_shire_inform_legacy.plot_templ_seds()

In [None]:
# Same colour-colour diagram as for the SPS informer, built from the
# Legacy informer's classification table.
# NOTE(review): _nuvk_classif is a private method of the informer — confirm
# it is meant to be called from user code.
all_templs_df_leg = run_shire_inform_legacy._nuvk_classif()
f, a = plt.subplots(1,1)
sns.scatterplot(
    data=all_templs_df_leg, x="g_lsst-r_lsst", y="r_lsst-i_lsst",
    hue="CAT_NUVK", hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', alpha=0.5,
    size='z_p', sizes=(10, 100),
    ax=a
)
a.grid()

### Plot the prior

In [None]:
# Per-galaxy summary of the Legacy informer's training results: reference
# magnitude, redshift, best type index and the matching category label.
_best_types = run_shire_inform_legacy.besttypes
trainref_df = pd.DataFrame(
    {
        'ref_mag': np.array(run_shire_inform_legacy.refmags),
        'redshift': np.array(run_shire_inform_legacy.szs),
        'type_num': np.array(_best_types),
        'type': np.array([run_shire_inform_legacy.refcategs[_n] for _n in _best_types]),
    }
)

In [None]:
# Differences between adjacent columns of `mags`, stored as five colour columns.
# assumes the six columns of `mags` follow the u, g, r, i, z, y order of `_bands` — TODO confirm
trainref_df[['u-g', 'g-r', 'r-i', 'i-z', 'z-y']] = np.array(run_shire_inform_legacy.mags[:, :-1] - run_shire_inform_legacy.mags[:, 1:])

In [None]:
# Reference-magnitude histogram per galaxy type (stat='probability').
f,a = plt.subplots(1,1)
sns.histplot(data=trainref_df, x='ref_mag', hue='type', hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', stat='probability', multiple='layer', ax=a)

In [None]:
# u-g vs i-z colours of the training galaxies, coloured by best-fit type.
# `data` is passed by keyword, like every other seaborn call in this notebook;
# seaborn releases whose plotting functions are keyword-only reject the
# positional form.
sns.scatterplot(data=trainref_df, x='u-g', y='i-z', hue='type', hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', alpha=0.5)

In [None]:
# Figure and axes used by the type-fraction curves plotted at the end of this cell.
f,a = plt.subplots(1,1)

def frac_func(X, m0, m):
    """Evaluate the exponential law ``fo * exp(-kt * (m - m0))``.

    Parameters
    ----------
    X : pair
        ``(fo, kt)``: amplitude at ``m == m0`` and decay rate per magnitude.
    m0 : float
        Reference magnitude.
    m : float or array
        Magnitude(s) at which the law is evaluated.

    Returns
    -------
    JAX array (or scalar array) with the broadcast shape of the inputs.
    """
    amplitude, decay_rate = X
    delta_mag = m - m0
    return amplitude * jnp.exp(-decay_rate * delta_mag)

def kt3(fo_arr, kt_arr, m0, m):
    """Effective decay rate of the third galaxy type at magnitude ``m``.

    The first two types follow ``fo_i * exp(-kt_i * (m - m0))`` and the third
    takes whatever fraction is left, ``1 - f1 - f2``. This returns the ``kt``
    for which ``fo_3 * exp(-kt * (m - m0))`` equals that remainder, with
    ``fo_3 = 1 - fo_arr[0] - fo_arr[1]``.

    Parameters
    ----------
    fo_arr : array
        Amplitudes at ``m0``. Only the first two entries are read, so both
        two-entry and three-entry arrays are accepted.
    kt_arr : array
        Decay rates of the first two types.
    m0 : float
        Reference magnitude.
    m : float or array
        Magnitude(s) of evaluation. ``m == m0`` divides by zero and yields
        a non-finite value.
    """
    delta_mag = m - m0
    frac_first = fo_arr[0] * jnp.exp(-kt_arr[0] * delta_mag)
    frac_second = fo_arr[1] * jnp.exp(-kt_arr[1] * delta_mag)
    # Derive the third amplitude from the first two instead of reading
    # fo_arr[-1], which is the *second* amplitude when fo_arr has two entries.
    fo_third = 1 - fo_arr[0] - fo_arr[1]
    kt_val = -jnp.log((1 - frac_first - frac_second) / fo_third) / delta_mag
    return kt_val

# Reference magnitudes sorted in increasing order, used as the x-axis.
_order = np.argsort(run_shire_inform_legacy.refmags)
refmags = run_shire_inform_legacy.refmags[_order]

# One curve per galaxy category: the fitted fraction law evaluated on the
# sorted magnitudes. `nt` is unpacked from the model but not used here.
for ityp, (typ, _c) in enumerate(zip(run_shire_inform_legacy.refcategs, ['tab:blue', 'tab:orange', 'tab:green'])):
    fo, kt, nt, m0 = run_shire_inform_legacy.model['fo_arr'][ityp],\
        run_shire_inform_legacy.model['kt_arr'][ityp],\
        run_shire_inform_legacy.model['nt_array'][ityp],\
        run_shire_inform_legacy.model['mo'][ityp]
    frac = frac_func((fo, kt), m0, refmags)
    a.plot(refmags, frac, label=typ, c=_c)

# Reference curve: configured initial kt, amplitude 0.3 and m0 = 20.
# NOTE(review): the prior-comparison cell further down uses 1/3 rather than
# 0.3 for the default amplitude — confirm which value is intended.
default = frac_func((0.3, run_shire_inform_legacy.config.init_kt), 20., refmags)
a.plot(refmags, default, label='default', c='k')
a.legend()

In [None]:
# Redshift histogram per galaxy type (stat='density').
f,a = plt.subplots(1,1)
sns.histplot(data=trainref_df, x='redshift', hue='type', hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', stat='density', multiple='layer', ax=a) 

In [None]:
# Kernel-density contours in the (reference magnitude, redshift) plane, one set per type.
f,a = plt.subplots(1,1)
sns.kdeplot(data=trainref_df, x='ref_mag', y='redshift', hue='type', hue_order=['E_S0', 'Sbc/Scd', 'Irr'], palette='tab10', ax=a)

In [None]:
from jax import vmap
from jax.scipy.special import gamma as jgamma
import pickle

def nz_func(m, z, z0, alpha, km, m0):  # pragma: no cover
    """Redshift distribution ``z**alpha * exp(-(z / zm)**alpha)`` at magnitude ``m``.

    The characteristic redshift grows linearly with magnitude,
    ``zm = z0 + km * (m - m0)``. The result is divided by
    ``zm**(alpha + 1) * Gamma(1 + 1/alpha) / alpha``, which is the analytic
    integral of the numerator over ``z >= 0`` for positive ``zm`` and ``alpha``.

    Parameters
    ----------
    m : float
        Magnitude.
    z : float or array
        Redshift(s).
    z0, alpha, km : float
        Shape parameters of the distribution.
    m0 : float
        Reference magnitude.
    """
    z_scale = z0 + km * (m - m0)
    unnormalised = jnp.power(z, alpha) * jnp.exp(-jnp.power(z / z_scale, alpha))
    normalisation = jnp.power(z_scale, alpha + 1) * jgamma(1 + 1 / alpha) / alpha
    return unnormalised / normalisation

# Vectorised nz_func: maps over axis 0 of the second argument (z) while the
# magnitude and the shape parameters are held fixed.
vmap_dndz_z = vmap(
    nz_func,
    in_axes=(None, 0, None, None, None, None)
)

# Hand-written BPZ-style prior parameters, used as a reference in the
# comparison loop below.
# NOTE(review): 'fo_arr' has three entries while 'kt_arr' has only two; the
# third kt is derived with kt3() in the loop below — confirm the intended layout.
bpz_model = {
    'fo_arr': jnp.array([0.35, 0.5, 0.15]),
    'kt_arr': jnp.array([0.147, 0.450]),
    'zo_arr': jnp.array([0.431, 0.39, 0.063]),
    'a_arr': jnp.array([2.46, 1.81, 0.91]),
    'km_arr': jnp.array([0.091, 0.0636, 0.123]),
    'mo': 20.0,
    'nt_array': jnp.array([1, 2, 3])
}

# Root of the RAIL source tree that ships the COSMOS prior. The original
# author's checkout is used when it exists; otherwise fall back to the RAILDIR
# imported at the top of the notebook so the cell can run on other machines.
DATDIR = "/global/u2/j/jcheval/rail_base/src"
if not os.path.isdir(DATDIR):
    DATDIR = RAILDIR
cosmospriorfile = os.path.join(DATDIR, "rail/examples_data/estimation_data/data/COSMOS31_HDFN_prior.pkl")
# SECURITY: pickle.load can execute arbitrary code stored in the file; only
# load prior files that come from a trusted source.
with open(cosmospriorfile, 'rb') as _f:
    cosmos_prior_dict = pickle.load(_f)
# Values overridden for the comparison below: template counts per type and
# reference magnitude.
cosmos_prior_dict['nt_array'] = jnp.array([10, 5, 16])
cosmos_prior_dict['mo'] = 20.0

# For a few magnitudes between mag_lim and 26, compare the fitted SHIRE prior
# with the BPZ and COSMOS reference priors: per galaxy type (left panel), and
# summed over types then normalised on the z grid (right panel).
for m in np.linspace(mag_lim, 26, 4):
    sumprior = np.zeros_like(run_shire_inform_legacy.pzs)
    sumbpz = np.zeros_like(run_shire_inform_legacy.pzs)
    sumcos = np.zeros_like(run_shire_inform_legacy.pzs)
    f,a = plt.subplots(1,2, figsize=(12, 5))
    pzs = jnp.array(run_shire_inform_legacy.pzs)
    for ityp, (typ, _c) in enumerate(zip(run_shire_inform_legacy.refcategs, ['tab:blue', 'tab:orange', 'tab:green'])):
        # --- SHIRE: every parameter is stored per type in the fitted model.
        # NOTE(review): nt is nt_array[ityp] divided by the number of grid
        # points (pzs.shape[0]) — confirm this normalisation is intended.
        fo, kt, z0, alpha, km, nt, m0 = run_shire_inform_legacy.model['fo_arr'][ityp],\
            run_shire_inform_legacy.model['kt_arr'][ityp],\
            run_shire_inform_legacy.model['zo_arr'][ityp],\
            run_shire_inform_legacy.model['a_arr'][ityp],\
            run_shire_inform_legacy.model['km_arr'][ityp],\
            run_shire_inform_legacy.model['nt_array'][ityp]/pzs.shape[0],\
            run_shire_inform_legacy.model['mo'][ityp]
        frac = frac_func((fo, kt), m0, m)/nt
        vals = vmap_dndz_z(m, pzs, z0, alpha, km, m0)
        a[0].plot(pzs, vals*frac, label="SHIRE-"+typ, c=_c, lw=2)
        sumprior+=vals*frac

        # --- BPZ reference: the first two types have explicit (fo, kt); the
        # third takes the remaining fraction.
        z0bpz, albpz, kmbpz, m0bpz, ntbpz = bpz_model['zo_arr'][ityp],\
            bpz_model['a_arr'][ityp],\
            bpz_model['km_arr'][ityp],\
            bpz_model['mo'],\
            bpz_model['nt_array'][ityp]
        # Sum only the first two amplitudes: summing a three-entry fo_arr
        # would make the third amplitude vanish.
        fobpz = bpz_model['fo_arr'][ityp] if ityp<2 \
            else 1-np.sum(bpz_model['fo_arr'][:2])
        ktbpz = bpz_model['kt_arr'][ityp] if ityp<2 \
            else kt3(bpz_model['fo_arr'], bpz_model['kt_arr'], m0bpz, m)
        fracbpz = frac_func((fobpz, ktbpz), m0bpz, m)/ntbpz
        valsbpz = vmap_dndz_z(m, pzs, z0bpz, albpz, kmbpz, m0bpz)
        a[0].plot(pzs, valsbpz*fracbpz, label="BPZ-"+typ, c=_c, ls=':')
        sumbpz+=valsbpz*fracbpz
        
        # --- COSMOS reference: same construction, with its own reference
        # magnitude m0cos passed to kt3.
        z0cos, alcos, kmcos, m0cos, ntcos = cosmos_prior_dict['zo_arr'][ityp],\
            cosmos_prior_dict['a_arr'][ityp],\
            cosmos_prior_dict['km_arr'][ityp],\
            cosmos_prior_dict['mo'],\
            cosmos_prior_dict['nt_array'][ityp]
        focos = cosmos_prior_dict['fo_arr'][ityp] if ityp<2 \
            else 1-np.sum(cosmos_prior_dict['fo_arr'][:2])
        ktcos = cosmos_prior_dict['kt_arr'][ityp] if ityp<2 \
            else kt3(cosmos_prior_dict['fo_arr'], cosmos_prior_dict['kt_arr'], m0cos, m)
        fraccos = frac_func((focos, ktcos), m0cos, m)/ntcos
        valscos = vmap_dndz_z(m, pzs, z0cos, alcos, kmcos, m0cos)
        a[0].plot(pzs, valscos*fraccos, label="COSMOS-"+typ, c=_c, ls=(0, (3, 5, 1, 5)))
        sumcos+=valscos*fraccos
        
    # Default prior: the informer's initial parameters, amplitude 1/3, m0 = 20.
    valdefault = vmap_dndz_z(
        m, pzs,
        run_shire_inform_legacy.config.init_z0,
        run_shire_inform_legacy.config.init_alpha,
        run_shire_inform_legacy.config.init_km,
        20.0
    )
    fracdef = frac_func((1/3, run_shire_inform_legacy.config.init_kt), 20.0, m)
    a[0].plot(pzs, valdefault*fracdef, c='k', label='Default')
    a[0].legend()

    # Normalise each type-summed curve to unit area on the z grid.
    normprior = jnp.trapezoid(sumprior, x=pzs)
    normbpz = jnp.trapezoid(sumbpz, x=pzs)
    normcos = jnp.trapezoid(sumcos, x=pzs)
    normdef = jnp.trapezoid(valdefault*fracdef, x=pzs)
    a[1].plot(pzs, valdefault*fracdef/normdef, c='k', label='Default')
    a[1].plot(pzs, sumprior/normprior, label="SHIRE")
    a[1].plot(pzs, sumbpz/normbpz, label="BPZ", ls=':')
    a[1].plot(pzs, sumcos/normcos, label="COSMOS", ls=(0, (3, 5, 1, 5)))
    a[1].legend()
    a[0].set_title('Priors for 3 categories of galaxies')
    a[1].set_title('Marginalised prior distributions (sum on galaxy types)')
    a[0].set_xlabel('Redshift z')
    a[1].set_xlabel('Redshift z')
    a[0].set_ylabel('PDF')
    a[1].set_ylabel('PDF')
    f.suptitle(f'Comparison of prior distributions at m={m:.2f}')

In [None]:
# Display the model held by the SPS informer.
run_shire_inform_sps.model

In [None]:
# Display the model held by the Legacy informer, for comparison with the SPS one.
run_shire_inform_legacy.model

## Run the photometric redshifts estimation

### Build two estimators
Again, one is for the "SPS" method and the other is for the "Legacy" method. Although either `estimate` stage could work with templates from either `inform` stage, it makes more sense to keep things consistent and load the appropriate `handles` from the corresponding `inform` stage.

In [None]:
# Toggle the prior in every estimator below; runs without it get a
# "_noprior" suffix in stage names and output files.
use_prior = False
_suffix = "_noprior" if not use_prior else ""

In [None]:
# Options shared by both SHIRE estimators: redshift grid (zmin, zmax, nzbins),
# wavelength grid, photometry columns and the prior switch.
default_dict_estimate = dict(
    hdf5_groupname="photometry", #"", #
    data_path="./data",
    bands=_bands,
    err_bands=_errbands,
    zmin=0.01,
    zmax=3.1,
    nzbins=310,
    ssp_file="ssp_data_fsps_v3.2_lgmet_age.h5",
    filter_dict=lsst_filts_dict,
    wlmin=500.,
    wlmax=12000.,
    dwl=5.,
    no_prior=not(use_prior),
    chunk_size=5000
)

# SPS estimator, fed with the templates and model handles of the SPS informer.
run_shire_estimate_sps = ShireEstimator.make_stage(
    name="shireSPS_estimate_lsstSimhp10552_demo"+_suffix+"_splithi",
    output=f"shireSPS_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5",
    **default_dict_estimate,
    templ_type="SPS",
    templates=run_shire_inform_sps.get_handle("templates"),
    model=run_shire_inform_sps.get_handle("model")
)

# Legacy estimator, fed with the templates and model handles of the Legacy informer.
run_shire_estimate_legacy = ShireEstimator.make_stage(
    name="shireLEG_estimate_lsstSimhp10552_demo"+_suffix+"_splithi",
    output=f"shireLEG_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5",
    **default_dict_estimate,
    templ_type="Legacy",
    templates=run_shire_inform_legacy.get_handle("templates"),
    model=run_shire_inform_legacy.get_handle("model")
)

### Run a reference Estimator (BPZ)

In [None]:
# Save the Legacy informer's best-fit type of every training galaxy to the
# file that the BPZ informer below receives as `type_file`.
typefile = 'training_types_legacy.hdf5'
traintypes = np.array(run_shire_inform_legacy.besttypes)
typ_df = pd.DataFrame({'types': traintypes})
tables_io.write(typ_df, typefile)
typ_df['types']

In [None]:
from rail.estimation.algos.bpz_lite import BPZliteInformer, BPZliteEstimator

from rail.core.data import ModelHandle

# Prefer the original author's RAIL checkout when it exists; otherwise keep
# the RAILDIR imported at the top of the notebook so the cell runs elsewhere.
_local_rail_src = "/global/u2/j/jcheval/rail_base/src"
if os.path.isdir(_local_rail_src):
    RAILDIR = _local_rail_src
# Folder holding the BPZ estimation data, derived once from RAILDIR instead
# of being repeated as an absolute literal in each stage configuration.
_bpz_data_path = os.path.join(RAILDIR, "rail/examples_data/estimation_data/data")

cosmospriorfile = os.path.join(RAILDIR, "rail/examples_data/estimation_data/data/COSMOS31_HDFN_prior.pkl")
cosmosprior = DS.read_file("cosmos_prior", ModelHandle, cosmospriorfile)
sedfile = "COSMOS_seds.list" #os.path.join(RAILDIR, "rail/examples_data/estimation_data/data/SED/COSMOS_seds.list")

# "u_lsst" -> "DC2LSST_u", and so on for every band.
filter_list = [f"DC2LSST_{b.split('_')[0]}" for b in lsst_filts_dict.keys()]

# Estimator options; the redshift grid matches the SHIRE estimators.
cosmos_dict = dict(
    hdf5_groupname="photometry", #"", #
    output=f"BPZ_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5",
    spectra_file=sedfile,
    bands=_bands,
    err_bands=_errbands,
    filter_list=filter_list,
    wlmin=500.,
    wlmax=12000.,
    dwl=5.,
    zmin=0.01,
    zmax=3.1,
    nzbins=310,
    prior_band="mag_i_lsst",
    data_path=_bpz_data_path,
    no_prior=not(use_prior),
    chunk_size=5000
)

# BPZ informer, using the galaxy types written to `typefile`.
inform_bpz = BPZliteInformer.make_stage(
    name="BPZ_inform_lsstSimhp10552_demo_splithi",
    hdf5_groupname="photometry", #"", #
    nondetect_val=jnp.nan,
    bands=_bands,
    err_bands=_errbands,
    filter_list=filter_list,
    prior_band="mag_i_lsst",
    wlmin=500.,
    wlmax=12000.,
    dwl=5.,
    data_path=_bpz_data_path,
    nt_array=[10, 5, 16],
    type_file=typefile
)

In [None]:
%%time
# Inform BPZ on the faint training handle. The estimator built in the next
# cell is given the COSMOS prior handle, not this stage's model.
inform_bpz.inform(train_hiz)

In [None]:
# BPZ estimator using the COSMOS prior handle as its model; the handle of the
# informed model is kept as a commented alternative.
estimate_bpz = BPZliteEstimator.make_stage(
    name="BPZ_estimate_lsstSimhp10552_demo"+_suffix+"_splithi",
    model= cosmosprior, # inform_bpz.get_handle("model"), #
    **cosmos_dict
)

In [None]:
%%time
# Run BPZ on the faint test handle.
estimate_bpz.estimate(test_hiz)

In [None]:
# Display the model dictionary held by the BPZ estimator.
estimate_bpz.modeldict

In [None]:
# Read the BPZ output back (same file name as the `output` entry of cosmos_dict).
pdfs_file_bpz = f"BPZ_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5"
custom_res_bpz = qp.read(pdfs_file_bpz)
# 'redshift' column of the faint test table, used as the reference in the plots below.
sz = jnp.array(test_hiz()['photometry']['redshift'])

In [None]:
# Estimated ("zmode") vs reference redshift for BPZ. The label is "BPZ" plus
# the output file name without extension and without its first two
# underscore-separated tokens.
a = plot_zp_zs_ensemble(
    custom_res_bpz, sz,
    z_grid=None, key_estim="zmode",
    label='_'.join(['BPZ']+(os.path.splitext(pdfs_file_bpz)[0]).split('_')[2:]))
plt.show()

### Run the "SPS" estimation

In [None]:
%%time
# Run the SPS estimator on the faint test handle.
run_shire_estimate_sps.estimate(test_hiz)

In [None]:
# Read the SPS output back (same file name as the `output` of the SPS estimator).
pdfs_file_sps = f"shireSPS_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5"
custom_res_sps = qp.read(pdfs_file_sps)
# Recomputed from the same handle and column as in the BPZ results cell.
sz = jnp.array(test_hiz()['photometry']['redshift'])

In [None]:
# Estimated ("zmode") vs reference redshift for the SPS estimator; the label
# is built the same way as for BPZ, with a "SHIRE_SPS" prefix.
a = plot_zp_zs_ensemble(
    custom_res_sps, sz,
    z_grid=None, key_estim="zmode",
    label='_'.join(['SHIRE_SPS']+(os.path.splitext(pdfs_file_sps)[0]).split('_')[2:]))
plt.show()

### Run the "Legacy" estimation

In [None]:
# Display the model dictionary held by the Legacy estimator.
run_shire_estimate_legacy.modeldict

In [None]:
%%time
# Run the Legacy estimator on the faint test handle.
run_shire_estimate_legacy.estimate(test_hiz)

In [None]:
# Read the Legacy output back (same file name as the `output` of the Legacy estimator).
pdfs_file_legacy = f"shireLEG_results_lsstSimhp10552_demo{_suffix}_splithi.hdf5"
custom_res_legacy = qp.read(pdfs_file_legacy)
# `sz` from the earlier results cells is reused; recomputing it is unnecessary:
# sz = jnp.array(test_data()['photometry']['redshift'])

In [None]:
# Estimated ("zmode") vs reference redshift for the Legacy estimator; the
# label is built the same way as above, with a "SHIRE_Legacy" prefix.
a = plot_zp_zs_ensemble(
    custom_res_legacy, sz,
    z_grid=None, key_estim="zmode",
    label='_'.join(['SHIRE_Legacy']+(os.path.splitext(pdfs_file_legacy)[0]).split('_')[2:]))
plt.show()

### Compare the outlier distributions of the three estimators (SHIRE-SPS, SHIRE-Legacy and BPZ)

In [None]:
# Outlier histograms of the three ensembles (SPS, Legacy, BPZ) against the
# same reference redshifts, labelled as in the individual plots.
hist_outliers(
    custom_res_sps, sz, label1='_'.join(['SHIRE_SPS']+(os.path.splitext(pdfs_file_sps)[0]).split('_')[2:]),
    qp_ens_2=custom_res_legacy, label2='_'.join(['SHIRE_Legacy']+(os.path.splitext(pdfs_file_legacy)[0]).split('_')[2:]),
    qp_ens_3=custom_res_bpz, label3='_'.join(['BPZ']+(os.path.splitext(pdfs_file_bpz)[0]).split('_')[2:])
)
plt.show()

## Evaluate posteriors using `RAIL` 

Check out [Evaluation_demo_LSSTsim.ipynb](Evaluation_demo_LSSTsim.ipynb) !

## Build a pipeline with `ceci`

In [None]:
import ceci
# Assemble an interactive ceci pipeline from the SPS inform and estimate stages.
pipe = ceci.Pipeline.interactive()
stages = [run_shire_inform_sps, run_shire_estimate_sps]
for stage in stages:
    pipe.add_stage(stage)
# Execution config of the SPS estimate stage, looked up by stage name: one process.
pipe.stage_execution_config[f'shireSPS_estimate_lsstSimhp10552_demo{_suffix}_splithi'].nprocess=1

In [None]:
# Initialise the pipeline. First dict: input tags mapped to the faint
# train/test files. Second dict: output/log directories and the resume flag.
# NOTE(review): the third positional argument is passed as None — presumably
# the stages-config file; confirm against the ceci API.
pipe.initialize(
    dict(
        training_data_faint=trainFile_faint,
        test_data_faint=testFile_faint
    ),
    dict(
        output_dir='.',
        log_dir='.',
        resume=False
    ),
    None
)

In [None]:
# Save the pipeline definition to a YAML file whose name carries the prior suffix.
pipe.save(f'rail_shireSPS_lsstsim_pz_demo{_suffix}_splithi.yml')

In [None]:
# Reload the pipeline from the YAML file saved in the previous cell.
pr = ceci.Pipeline.read(f'rail_shireSPS_lsstsim_pz_demo{_suffix}_splithi.yml')

In [None]:
# Left disabled on purpose; uncomment to execute the reloaded pipeline.
#pr.run()