In [None]:
import os

import ceci
import jax
import matplotlib.pyplot as plt
import qp
import seaborn as sns
from jax import numpy as jnp

from rail.core.data import TableHandle
from rail.core.stage import RailStage

from rail.shire import ShireInformer, ShireEstimator, plot_zp_zs_ensemble

#jax.config.update("jax_enable_x64", True)

## Select and load data into the datastore

In [None]:
# Grab the data store exposed as a class attribute of RailStage.
DS = RailStage.data_store
# Set the flag on the store's class so existing entries may be overwritten
# (presumably what allows cells to be re-run without key clashes — confirm in the RAIL docs).
DS.__class__.allow_overwrite = True

In [None]:
# Training and test catalogues, resolved to absolute paths below ./data.
# Alternative inputs kept for reference:
#   training: os.path.abspath(os.path.join('./data', 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5'))
#             os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
#   test:     os.path.abspath(os.path.join('../..', 'magszgalaxies_lsstroman_gold_hp10552.h5'))
#             os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')
trainFile = os.path.abspath(
    os.path.join('.', 'data', 'minitrain_magszgalaxies_lsstroman_gold_hp10552_1k.h5')
)
testFile = os.path.abspath(
    os.path.join('.', 'data', 'minitest_magszgalaxies_lsstroman_gold_hp10552_3k.h5')
)

# Hand both files to the data store as TableHandle entries keyed
# "training_data" and "test_data".
training_data = DS.read_file("training_data", TableHandle, trainFile)
test_data = DS.read_file("test_data", TableHandle, testFile)

In [None]:
# One key per band ("u_lsst" ... "y_lsst"), all mapped to the same string
# "filt_lsst" (presumably the filter-set identifier SHIRE expects — confirm).
lsst_filts_dict = dict.fromkeys((f"{band}_lsst" for band in "ugrizy"), "filt_lsst")

# Column names derived from the filter keys, in the same order.
_bands = ["mag_" + key for key in lsst_filts_dict]
_errbands = ["mag_err_" + key for key in lsst_filts_dict]

## Inform the estimator, i.e. select a subset of galaxies as templates

In [None]:
# Keyword arguments shared by the informer stage(s) built from ShireInformer.
default_dict_inform = {
    # Input catalogue: group name, data directory and magnitude/error columns.
    "hdf5_groupname": "photometry",
    "data_path": "./data",
    "bands": _bands,
    "err_bands": _errbands,
    # Auxiliary files and the filter mapping.
    "spectra_file": "dsps_valid_fits_F2SM3_GG_DESI.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": lsst_filts_dict,
    # Presumably the wavelength range and step (units not stated here — confirm).
    "wlmin": 900.,
    "wlmax": 12000.,
    "dwl": 20.,
    # Presumably the redshift grid bounds and size, and the number of templates to keep.
    "zmin": 0.01,
    "zmax": 3.0,
    "nzbins": 50,
    "ntemplates": 20,
}

### Prepare the informer: either for method 'SPS' or 'Legacy'
- 'SPS' recomputes an SED at every redshift based on the star-formation history of the template galaxy before synthesizing the colours for every value of $z$ along the grid
- 'Legacy' computes the SED once at the template galaxy's redshift and uses it to synthesize colours at all $z$ values with the usual transformation $\lambda_\mathrm{obs} = (1+z)\lambda_\mathrm{em}$

In [None]:
# Build the informer stage for the 'SPS' template type from the shared settings.
run_shire_inform = ShireInformer.make_stage(
    name="shireSPS_inform_lsstSimhp10552_demo",
    output="shireSPS_templates_lsstSimhp10552_demo.hf5",
    **default_dict_inform,
    templ_type="SPS"
)

# Alternative: the 'Legacy' template type. Uncommenting this rebinds the same
# variable, `run_shire_inform`, so it replaces the SPS stage above.
# run_shire_inform = ShireInformer.make_stage(
#     name="shireLEG_inform_lsstSimhp10552_demo",
#     output="shireLEG_templates_lsstSimhp10552_demo.hf5",
#     **default_dict_inform,
#     templ_type="Legacy"
# )

### Inform the templates

In [None]:
%%time
# Run the informer on the training-data handle; %%time reports the cell's run time.
run_shire_inform.inform(training_data)

In [None]:
# Fetch the informer's "templates" handle and read it; the value returned by
# read() is shown as the cell output.
templ = run_shire_inform.get_handle("templates")
templ.read()

In [None]:
# Table returned by the informer's `_nuvk_classif` helper (a private method —
# NOTE(review): check whether a public accessor exists).
all_templs_df_sps = run_shire_inform._nuvk_classif()

# Colour-colour scatter (g-r against r-i): hue from the "CAT_NUVK" column,
# marker size from the "z_p" column.
f, a = plt.subplots()
sns.scatterplot(
    data=all_templs_df_sps,
    x="g_lsst-r_lsst",
    y="r_lsst-i_lsst",
    hue="CAT_NUVK",
    hue_order=['E_S0', 'Sbc/Scd', 'Irr'],
    size='z_p',
    sizes=(10, 100),
    palette='tab10',
    alpha=0.5,
    ax=a,
)
a.grid()

In [None]:
# Call the informer's `plot_templ_seds` helper (presumably draws the SEDs of
# the selected templates — confirm against the SHIRE docs).
run_shire_inform.plot_templ_seds()

In [None]:
# Display the informer's `model` attribute as the cell output.
run_shire_inform.model

## Run the photometric redshift estimation

### Build the estimators
Again, one is for the "SPS" method and the other is for the "Legacy" method. Though both `estimate` stages could work with templates from either `inform` stage, it makes more sense to keep things consistent and load the appropriate `handles` from the corresponding `inform` stage.

In [None]:
# Switch controlling whether the estimator uses its prior. When it is off,
# "_noprior" is appended to the stage and file names built later on.
use_prior = True
_suffix = "_noprior" if not use_prior else ""

In [None]:
# Keyword arguments shared by the estimator stage(s) built from ShireEstimator.
default_dict_estimate = {
    # Input catalogue: group name, data directory and magnitude/error columns.
    "hdf5_groupname": "photometry",
    "data_path": "./data",
    "bands": _bands,
    "err_bands": _errbands,
    # Presumably the redshift grid bounds and size used for the estimation — confirm.
    "zmin": 0.01,
    "zmax": 3.1,
    "nzbins": 150,
    # Auxiliary file and the filter mapping.
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": lsst_filts_dict,
    # Presumably the wavelength range and step (units not stated here — confirm).
    "wlmin": 900.,
    "wlmax": 12000.,
    "dwl": 20.,
    # `no_prior` is the logical negation of the notebook-level `use_prior` switch.
    "no_prior": not use_prior,
    "chunk_size": 250,
}

# Build the estimator stage for the 'SPS' template type. Its "templates" and
# "model" inputs are the handles taken from the informer stage; `_suffix`
# is appended to both the stage name and the output file name.
run_shire_estimate = ShireEstimator.make_stage(
    name="shireSPS_estimate_lsstSimhp10552_demo"+_suffix,
    output=f"shireSPS_results_lsstSimhp10552_demo{_suffix}.hdf5",
    **default_dict_estimate,
    templ_type="SPS",
    templates=run_shire_inform.get_handle("templates"),
    model=run_shire_inform.get_handle("model")
)

# Alternative: the 'Legacy' template type. The commented-out Legacy informer in
# this notebook is bound to `run_shire_inform`, so the handles are taken from
# that same variable.
# run_shire_estimate = ShireEstimator.make_stage(
#     name="shireLEG_estimate_lsstSimhp10552_demo"+_suffix,
#     output=f"shireLEG_results_lsstSimhp10552_demo{_suffix}.hdf5",
#     **default_dict_estimate,
#     templ_type="Legacy",
#     templates=run_shire_inform.get_handle("templates"),
#     model=run_shire_inform.get_handle("model")
# )

### Run the estimation

In [None]:
%%time
# Run the estimator on the test-data handle; %%time reports the cell's run time.
run_shire_estimate.estimate(test_data)

In [None]:
# Results file name; same pattern as the `output` passed to
# ShireEstimator.make_stage, so keep the two in sync when switching to Legacy.
pdfs_file = f"shireSPS_results_lsstSimhp10552_demo{_suffix}.hdf5"
#pdfs_file = f"shireLEG_results_lsstSimhp10552_demo{_suffix}.hdf5"
# Load the results file with qp.
custom_res = qp.read(pdfs_file)
# 'redshift' column of the test table's 'photometry' group as a JAX array
# (presumably the reference redshifts to compare against — confirm).
sz = jnp.array(test_data()['photometry']['redshift'])

In [None]:
# Plot label: "SHIRE_SPS" followed by every "_"-separated token of the results
# file stem after the first two (e.g. "SHIRE_SPS_lsstSimhp10552_demo").
_stem = os.path.splitext(pdfs_file)[0]
_label = '_'.join(['SHIRE_SPS'] + _stem.split('_')[2:])

# Pass the loaded results and the reference redshifts to the plotting helper,
# using the "zmode" estimate key.
a = plot_zp_zs_ensemble(
    custom_res,
    sz,
    z_grid=None,
    key_estim="zmode",
    label=_label,
)
plt.show()

## Evaluate the posterior distribution using `RAIL` 

Check out [Evaluation_demo_LSSTsim_v2.ipynb](Evaluation_demo_LSSTsim_v2.ipynb) !

## Build a pipeline with `ceci`

In [None]:
# Assemble an interactive ceci pipeline from the two stages, informer first.
# (`ceci` is imported in the notebook's import cell.)
pipe = ceci.Pipeline.interactive()
stages = [run_shire_inform, run_shire_estimate]
for stage in stages:
    pipe.add_stage(stage)
# Set the estimator stage's process count to 1; the key is the stage name
# given to ShireEstimator.make_stage, including `_suffix`.
pipe.stage_execution_config[f'shireSPS_estimate_lsstSimhp10552_demo{_suffix}'].nprocess=1

In [None]:
# Initialise the pipeline with three positional arguments: the input files
# keyed by data-store tag, the run settings (output/log directories, no
# resume), and None as the third argument.
pipe.initialize(
    {"training_data": trainFile, "test_data": testFile},
    {"output_dir": '.', "log_dir": '.', "resume": False},
    None,
)

In [None]:
# Save the pipeline to a YAML file whose name includes `_suffix`.
pipe.save(f'rail_shire_lsstsim_pz_demo{_suffix}.yml')

In [None]:
# Reload the pipeline from the YAML file passed to pipe.save().
pr = ceci.Pipeline.read(f'rail_shire_lsstsim_pz_demo{_suffix}.yml')

In [None]:
# Uncomment to run the pipeline reloaded from the YAML file.
#pr.run()