In [None]:
import os

import ceci
import jax
import matplotlib.pyplot as plt
import qp
import seaborn as sns
from jax import numpy as jnp

from rail.core.common_params import SHARED_PARAMS
from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR

from rail.shire import ShireInformer, ShireEstimator, hist_outliers, plot_zp_zs_ensemble

# Switch JAX to 64-bit (double-precision) mode; by default JAX works in 32-bit.
# Set once here, right after the imports and before any JAX array is created.
jax.config.update("jax_enable_x64", True)

## Select and load data into the datastore

In [None]:
# Shared RAIL data store; every handle read below is registered in it.
DS = RailStage.data_store
# The flag lives on the store's class, so it is set there rather than on the instance.
# It lets cells be re-run without the store refusing to replace an existing entry.
type(DS).allow_overwrite = True

In [None]:
# Input catalogues live in ./data (relative to the notebook's working directory).
# Alternative inputs tried previously: the RAIL example files under
# os.path.join(RAILDIR, 'rail/examples_data/testdata/'), i.e.
# test_dc2_training_9816.hdf5 and test_dc2_validation_9816.hdf5.
_data_dir = os.path.join('.', 'data')
trainFile = os.path.abspath(
    os.path.join(_data_dir, 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5')
)
testFile = os.path.abspath(
    os.path.join(_data_dir, 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5')
)

# Register both tables in the data store under the tags used by the stages below.
training_data = DS.read_file("training_data", TableHandle, trainFile)
test_data = DS.read_file("test_data", TableHandle, testFile)

In [None]:
# Filter name -> associated string, passed to the stages as `filter_dict`.
# NOTE(review): the values ("filt_lsst" / "") look like a filter sub-directory or
# prefix -- confirm against the SHIRE documentation.
lsst_filts_dict = dict.fromkeys((f"{_n}_lsst" for _n in "ugrizy"), "filt_lsst")
roman_filts_dict = dict.fromkeys(
    ("roman_wfi_f106", "roman_wfi_f129", "roman_wfi_f158", "roman_wfi_f184"),
    ""
)

# Magnitude / magnitude-error column names, LSST bands first, then Roman.
_all_filts = [*lsst_filts_dict, *roman_filts_dict]
_bands = [f"mag_{_k}" for _k in _all_filts]
_errbands = [f"mag_err_{_k}" for _k in _all_filts]
# Start from RAIL's shared `mag_limits` and add a limit of 28 for every Roman band,
# i.e. every entry of `_bands` whose name contains 'roman'.
_roman_maglims = {_band: 28 for _band in _bands if 'roman' in _band}
_maglims = {**SHARED_PARAMS['mag_limits'], **_roman_maglims}
print(_maglims)

## Inform the estimator, i.e. select a subset of galaxies as templates

In [None]:
# Configuration shared by both informer stages (SPS and Legacy); only the stage
# name, output file and `templ_type` differ between them.
default_dict_inform = {
    "hdf5_groupname": "photometry",
    "data_path": "./data",
    "bands": _bands,
    "err_bands": _errbands,
    "mag_limits": _maglims,
    "spectra_file": "dsps_valid_fits_F2SM3_GG_DESI.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": {**lsst_filts_dict, **roman_filts_dict},
    # Wavelength grid settings -- presumably in Angstrom; TODO confirm.
    "wlmin": 1000.,
    "wlmax": 25000.,
    "dwl": 10.,
    "colrsbins": 40,
}

### Prepare two informers, one for each method: 'SPS' and 'Legacy'
- 'SPS' recomputes an SED at every redshift based on the star-formation history of the template galaxy before synthesizing the colours for every value of $z$ along the grid
- 'Legacy' computes the SED once at the template galaxy's redshift and uses it to synthesize colours at all $z$ values with the usual transformation $\lambda_\mathrm{obs} = (1+z)\lambda_\mathrm{em}$

In [None]:
# Informer using the 'SPS' template type.
_sps_inform_kwargs = {**default_dict_inform, "templ_type": "SPS"}
run_shire_inform_sps = ShireInformer.make_stage(
    name="shireSPS_inform_lsstSimhp10552",
    output="shireSPS_templates_lsstSimhp10552.hf5",
    **_sps_inform_kwargs
)

# Informer using the 'Legacy' template type; same shared configuration otherwise.
_legacy_inform_kwargs = {**default_dict_inform, "templ_type": "Legacy"}
run_shire_inform_legacy = ShireInformer.make_stage(
    name="shireLEG_inform_lsstSimhp10552",
    output="shireLEG_templates_lsstSimhp10552.hf5",
    **_legacy_inform_kwargs
)

### Inform the 'SPS' templates

In [None]:
%%time
# Run the SPS informer on the training table (`%%time` must stay the first line of the cell).
run_shire_inform_sps.inform(training_data)

In [None]:
# Fetch the "templates" data handle of the SPS informer and read its content;
# being the last expression of the cell, the result of `read()` is displayed.
templ = run_shire_inform_sps.get_handle("templates")
templ.read()

In [None]:
# Diagnostic plot from the SPS informer: `plot_colrs_templates`.
run_shire_inform_sps.plot_colrs_templates()

In [None]:
# Diagnostic plot from the SPS informer: `plot_sfh_templates`
# (presumably the star-formation histories of the templates -- TODO confirm).
run_shire_inform_sps.plot_sfh_templates()

In [None]:
# Diagnostic plot from the SPS informer: `hist_colrs_templates`.
run_shire_inform_sps.hist_colrs_templates()

In [None]:
# Diagnostic plot from the SPS informer: `plot_bpt_templates`
# (presumably a BPT emission-line diagram -- TODO confirm).
run_shire_inform_sps.plot_bpt_templates()

In [None]:
# NOTE(review): `_nuvk_classif` is a private method of the informer; it is called
# directly here and may change without notice.
all_templs_df_sps = run_shire_inform_sps._nuvk_classif()

# g-r vs r-i of the SPS templates, coloured by the "CAT_NUVK" column and
# sized by the "z_p" column.
sns.scatterplot(
    data=all_templs_df_sps,
    x="g_lsst-r_lsst",
    y="r_lsst-i_lsst",
    hue="CAT_NUVK",
    size='z_p',
    sizes=(10, 100),
    alpha=0.5,
)

In [None]:
# Diagnostic plot from the SPS informer: `plot_templ_seds`.
run_shire_inform_sps.plot_templ_seds()

In [None]:
# Single-SED diagnostic from the SPS informer.
# NOTE(review): 7 is presumably a template index and 0.46 the redshift at which
# it is shown -- confirm against the `plot_line_sed` signature.
run_shire_inform_sps.plot_line_sed(7, redshift=0.46)

### Inform the 'Legacy' templates

In [None]:
%%time
# Run the Legacy informer on the same training table (`%%time` must stay the first line of the cell).
run_shire_inform_legacy.inform(training_data)

In [None]:
# Diagnostic plot from the Legacy informer: `plot_templ_seds`.
run_shire_inform_legacy.plot_templ_seds()

In [None]:
# Diagnostic plot from the Legacy informer: `plot_colrs_templates`.
run_shire_inform_legacy.plot_colrs_templates()

In [None]:
# NOTE(review): `_nuvk_classif` is a private method of the informer; it is called
# directly here and may change without notice.
all_templs_df_leg = run_shire_inform_legacy._nuvk_classif()

# g-r vs r-i of the Legacy templates, coloured by the "CAT_NUVK" column and
# sized by the "z_p" column.
sns.scatterplot(
    data=all_templs_df_leg,
    x="g_lsst-r_lsst",
    y="r_lsst-i_lsst",
    hue="CAT_NUVK",
    size='z_p',
    sizes=(10, 100),
    alpha=0.5,
)

## Run the photometric redshift estimation

### Build two estimators
Again, one is for the "SPS" method and the other is for the "Legacy" method. Though both `estimate` stages could work with templates from either `inform` stage, it makes more sense to keep things consistent and load the appropriate `handles` from the corresponding `inform` stage.

In [None]:
# Configuration shared by both estimator stages (SPS and Legacy); only the stage
# name, output file, `templ_type` and input handles differ between them.
default_dict_estimate = {
    "hdf5_groupname": "photometry",
    "data_path": "./data",
    "bands": _bands,
    "err_bands": _errbands,
    # Redshift grid: 310 bins between 0.01 and 3.1.
    "zmin": 0.01,
    "zmax": 3.1,
    "nzbins": 310,
    "mag_limits": _maglims,
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": {**lsst_filts_dict, **roman_filts_dict},
    # Same wavelength grid settings as the informer stages.
    "wlmin": 1000.,
    "wlmax": 25000.,
    "dwl": 10.,
    # Prior disabled -- the output file names below carry the "_noprior" suffix.
    "no_prior": True,
    "chunk_size": 5000,
}

# SPS estimator: shared settings plus the 'SPS' template type and the
# "templates" / "model" handles of the SPS informer.
_sps_estimate_kwargs = {
    **default_dict_estimate,
    "templ_type": "SPS",
    "templates": run_shire_inform_sps.get_handle("templates"),
    "model": run_shire_inform_sps.get_handle("model"),
}
run_shire_estimate_sps = ShireEstimator.make_stage(
    name="shireSPS_estimate_lsstSimhp10552",
    output="shireSPS_results_lsstSimhp10552_noprior.hdf5",
    **_sps_estimate_kwargs
)

# Legacy estimator: shared settings plus the 'Legacy' template type and the
# "templates" / "model" handles of the Legacy informer.
# The output file name follows the same pattern as the SPS one
# ("shire<TYPE>_results_lsstSimhp10552_noprior.hdf5"); it is the file the
# notebook reads back with `qp.read` after the Legacy estimation. The previous
# value ("shireLEG_resultsSPS_...") did not match that read.
# NOTE(review): the stage name still carries an "estimateSPS" infix; it is left
# unchanged here in case external configuration refers to it.
run_shire_estimate_legacy = ShireEstimator.make_stage(
    name="shireLEG_estimateSPS_lsstSimhp10552",
    output="shireLEG_results_lsstSimhp10552_noprior.hdf5",
    **default_dict_estimate,
    templ_type="Legacy",
    templates=run_shire_inform_legacy.get_handle("templates"),
    model=run_shire_inform_legacy.get_handle("model")
)

### Run the "SPS" estimation

In [None]:
%%time
# Run the SPS estimator on the test table (`%%time` must stay the first line of the cell).
run_shire_estimate_sps.estimate(test_data)

In [None]:
# Reference redshifts from the "redshift" column of the test table's "photometry" group;
# `sz` is reused by the Legacy plots and the outlier comparison further down.
_test_photometry = test_data()['photometry']
sz = jnp.array(_test_photometry['redshift'])

# Read the SPS photo-z results back from disk as a qp object.
pdfs_file_sps = "shireSPS_results_lsstSimhp10552_noprior.hdf5"
custom_res_sps = qp.read(pdfs_file_sps)

In [None]:
# Build the plot label from the results file name: drop the extension, skip the
# first two '_'-separated fields and prefix what is left with 'SHIRE_SPS'.
_run_tag_sps = os.path.splitext(pdfs_file_sps)[0].split('_')[2:]
_label_sps = '_'.join(['SHIRE_SPS', *_run_tag_sps])

# Photo-z (key "zmode") against the reference redshifts for the SPS run.
a = plot_zp_zs_ensemble(
    custom_res_sps,
    sz,
    z_grid=None,
    key_estim="zmode",
    label=_label_sps
)
plt.show()

### Run the "Legacy" estimation

In [None]:
%%time
# Run the Legacy estimator on the same test table (`%%time` must stay the first line of the cell).
run_shire_estimate_legacy.estimate(test_data)

In [None]:
# Read the Legacy photo-z results back from disk as a qp object.
pdfs_file_legacy = "shireLEG_results_lsstSimhp10552_noprior.hdf5"
custom_res_legacy = qp.read(pdfs_file_legacy)
# The reference redshifts `sz` built after the SPS estimation are reused as-is:
# both runs use the same test table, so they are not recomputed here.

In [None]:
# Build the plot label from the results file name: drop the extension, skip the
# first two '_'-separated fields and prefix what is left with 'SHIRE_Legacy'.
_run_tag_legacy = os.path.splitext(pdfs_file_legacy)[0].split('_')[2:]
_label_legacy = '_'.join(['SHIRE_Legacy', *_run_tag_legacy])

# Photo-z (key "zmode") against the reference redshifts for the Legacy run.
a = plot_zp_zs_ensemble(
    custom_res_legacy,
    sz,
    z_grid=None,
    key_estim="zmode",
    label=_label_legacy
)
plt.show()

### Compare the outlier distributions of the two methods

In [None]:
# Labels are derived from the results file names exactly as for the zp-zs plots:
# extension dropped, first two '_'-separated fields skipped, method prefix added.
_outliers_label_sps = '_'.join(
    ['SHIRE_SPS'] + os.path.splitext(pdfs_file_sps)[0].split('_')[2:]
)
_outliers_label_legacy = '_'.join(
    ['SHIRE_Legacy'] + os.path.splitext(pdfs_file_legacy)[0].split('_')[2:]
)

# Outlier histograms of the SPS and Legacy runs against the same reference redshifts.
hist_outliers(
    custom_res_sps,
    sz,
    label1=_outliers_label_sps,
    qp_ens_2=custom_res_legacy,
    label2=_outliers_label_legacy
)
plt.show()

## Evaluate posteriors using `RAIL` 

## Build a pipeline with `ceci`

In [None]:
# `ceci` is imported with the other dependencies in the first cell of the notebook.
# Assemble an interactive pipeline from the SPS informer and SPS estimator stages.
pipe = ceci.Pipeline.interactive()
stages = [run_shire_inform_sps, run_shire_estimate_sps]
for stage in stages:
    pipe.add_stage(stage)
# Run the estimation stage (looked up by its stage name) on a single process.
pipe.stage_execution_config['shireSPS_estimate_lsstSimhp10552'].nprocess=1

In [None]:
# Overall pipeline inputs: data-store tag -> file on disk.
_pipeline_inputs = {
    "training_data": trainFile,
    "test_data": testFile,
}
# Run configuration: outputs and logs go to the current directory, no resume.
_pipeline_run_config = {
    "output_dir": '.',
    "log_dir": '.',
    "resume": False,
}
# Third positional argument left as None, as before.
pipe.initialize(_pipeline_inputs, _pipeline_run_config, None)

In [None]:
# Write the pipeline definition to a YAML file in the current directory.
pipe.save('rail_shireSPS_pz.yml')

In [None]:
# Load the pipeline back from the YAML file saved in the previous cell.
pr = ceci.Pipeline.read('rail_shireSPS_pz.yml')

In [None]:
# Execute the reloaded pipeline (SPS inform + estimate stages).
pr.run()