In [None]:
import os

# JAX only honours `jax_enable_x64` when it is set at startup, before any JAX
# arrays exist. Enable it immediately after importing jax and *before* importing
# packages such as rail.shire, which may build jnp arrays at import time and
# would otherwise keep them in float32.
import jax
jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import qp
from jax import numpy as jnp

from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.core.common_params import SHARED_PARAMS

from rail.shire import ShireInformer, ShireEstimator, hist_outliers, plot_zp_zs_ensemble

## Select and load data into the datastore

In [None]:
# Shared RAIL data store used by every stage in this notebook. Overwriting is
# enabled on the store's class (so for every store instance), presumably so that
# re-running cells can re-register the same data tags.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
# Training and test catalogues (10k / 50k galaxies according to the file names),
# resolved relative to the notebook's working directory.
data_dir = os.path.abspath(os.path.join('.', 'data'))
trainFile = os.path.join(data_dir, 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5')
testFile = os.path.join(data_dir, 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5')
# Alternative inputs used previously:
#   train: os.path.abspath(os.path.join('./data', 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5'))
#          os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
#   test:  os.path.abspath(os.path.join('../..', 'magszgalaxies_lsstroman_gold_hp10552.h5'))
#          os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')

# Fail early with an explicit message instead of relying on whatever error the
# data-store reader raises for a missing file.
for _path in (trainFile, testFile):
    if not os.path.isfile(_path):
        raise FileNotFoundError(f"Input catalogue not found: {_path}")

training_data = DS.read_file("training_data", TableHandle, trainFile)
test_data = DS.read_file("test_data", TableHandle, testFile)

In [None]:
# Band name -> value passed to SHIRE through `filter_dict`. LSST bands map to
# "filt_lsst"; the Roman WFI entries are left empty (meaning defined by
# rail.shire — TODO confirm).
lsst_filts_dict = {f"{_n}_lsst": "filt_lsst" for _n in "ugrizy"}
roman_filts_dict = {
    "roman_wfi_f106": "",
    "roman_wfi_f129": "",
    "roman_wfi_f158": "",
    "roman_wfi_f184": ""
}
# Merge once; iteration order is LSST ugrizy followed by the Roman bands.
all_filts_dict = {**lsst_filts_dict, **roman_filts_dict}

# Catalogue column names for magnitudes and their errors, in filter order.
_bands = [f"mag_{_k}" for _k in all_filts_dict]
_errbands = [f"mag_err_{_k}" for _k in all_filts_dict]

# Magnitude limits: RAIL's shared defaults plus one common limit for every Roman band.
ROMAN_MAG_LIMIT = 28
_maglims = {
    **SHARED_PARAMS['mag_limits'],
    **{_band: ROMAN_MAG_LIMIT for _band in _bands if 'roman' in _band},
}

# Every band used downstream needs a limit; catch a key-name mismatch here
# rather than somewhere inside the informer/estimator.
_missing_maglims = [_band for _band in _bands if _band not in _maglims]
assert not _missing_maglims, f"No magnitude limit for bands: {_missing_maglims}"
print(_maglims)

## Inform the estimator, i.e., select a subset of training galaxies as templates

In [None]:
# Configuration of the SHIRE informer stage, which selects template galaxies
# from the training sample. wlmin/wlmax/dwl presumably define the wavelength
# grid (units not stated here — TODO confirm).
default_dict_inform = {
    "hdf5_groupname": "photometry",
    "output": "shire_templates_lsstSimhp10552.hf5",
    "data_path": "./data",
    # photometric columns and their limits
    "bands": _bands,
    "err_bands": _errbands,
    "mag_limits": _maglims,
    # spectra / stellar-population inputs and filters
    "spectra_file": "dsps_valid_fits_F2_GG_DESI_SM3.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": {**lsst_filts_dict, **roman_filts_dict},
    "wlmin": 1000.,
    "wlmax": 25000.,
    "dwl": 10.,
    "colrsbins": 40,
}

In [None]:
# Instantiate the informer stage from the configuration above.
run_shire_inform = ShireInformer.make_stage(
    name="shire_inform_lsstSimhp10552",
    **default_dict_inform,
)

In [None]:
%%time
# Run the informer on the training catalogue; its "templates" and "model" handles feed the estimator below.
run_shire_inform.inform(training_data)

In [None]:
# Handle to the templates produced by the informer.
templ = run_shire_inform.get_handle("templates")
# Last expression of the cell, so whatever read() returns is displayed.
templ.read()

In [None]:
# Diagnostic plot of the selected templates. plt.show() renders the figure once
# and keeps any returned object's repr out of the output, like the evaluation cells.
run_shire_inform.plot_colrs_templates()
plt.show()

In [None]:
# Diagnostic plot of the selected templates. plt.show() renders the figure once
# and keeps any returned object's repr out of the output, like the evaluation cells.
run_shire_inform.plot_sfh_templates()
plt.show()

In [None]:
# Diagnostic histogram of the selected templates. plt.show() renders the figure
# once and keeps any returned object's repr out of the output, like the evaluation cells.
run_shire_inform.hist_colrs_templates()
plt.show()

In [None]:
# Diagnostic plot of the selected templates. plt.show() renders the figure once
# and keeps any returned object's repr out of the output, like the evaluation cells.
run_shire_inform.plot_bpt_templates()
plt.show()

## Run the photometric redshift estimation

In [None]:
# Configuration of the SHIRE estimator stage: SPS templates, a redshift grid set
# by zmin/zmax/nzbins, the same photometric and wavelength settings as the
# informer, no prior, and a chunk_size for processing the test catalogue.
default_dict_estimate = {
    "hdf5_groupname": "photometry",
    "output": "shire_results_lsstSimhp10552_noprior.hdf5",
    "data_path": "./data",
    "templ_type": "SPS",
    "bands": _bands,
    "err_bands": _errbands,
    # redshift grid
    "zmin": 0.01,
    "zmax": 3.1,
    "nzbins": 310,
    "mag_limits": _maglims,
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    "filter_dict": {**lsst_filts_dict, **roman_filts_dict},
    "wlmin": 1000.,
    "wlmax": 25000.,
    "dwl": 10.,
    "no_prior": True,
    "chunk_size": 5000,
}

# Build the estimator, wiring in the informer's "templates" and "model" handles.
run_shire_estimate = ShireEstimator.make_stage(
    name="shire_estimate_lsstSimhp10552",
    **default_dict_estimate,
    templates=run_shire_inform.get_handle("templates"),
    model=run_shire_inform.get_handle("model"),
)

In [None]:
%%time
# Estimate photo-z PDFs for the test catalogue; the next cell reads them back from the file named by `output`.
run_shire_estimate.estimate(test_data)

In [None]:
# Read back the estimated PDFs. The file name is taken from the estimator
# configuration so the two can never drift apart.
pdfs_file = default_dict_estimate["output"]
custom_res = qp.read(pdfs_file)
# Reference redshifts of the test galaxies, compared against the photo-z PDFs below.
sz = jnp.array(test_data()['photometry']['redshift'])
# One PDF per test galaxy is expected; otherwise z_phot and z_spec would be misaligned.
assert custom_res.npdf == sz.shape[0], (
    f"{custom_res.npdf} PDFs for {sz.shape[0]} test galaxies"
)

In [None]:
a = plot_zp_zs_ensemble(custom_res, sz, z_grid=None, key_estim="zmode", label='_'.join(['SHIRE']+(os.path.splitext(pdfs_file)[0]).split('_')[2:]))
plt.show()

In [None]:
hist_outliers(
    custom_res, sz, label1='_'.join(['SHIRE']+(os.path.splitext(pdfs_file)[0]).split('_')[2:])
)
plt.show()

## Build a pipeline with `ceci`

In [None]:
import ceci

# Assemble the same informer/estimator stages into a ceci pipeline so it can be
# saved to YAML and re-run outside this notebook.
pipe = ceci.Pipeline.interactive()
stages = [run_shire_inform, run_shire_estimate]
for stage in stages:
    pipe.add_stage(stage)
# Run the estimator with 2 processes. Look the stage up by its own instance name
# instead of repeating the "shire_estimate_..." literal, so renaming the stage
# cannot silently break this line.
pipe.stage_execution_config[run_shire_estimate.instance_name].nprocess = 2

In [None]:
# Map the pipeline's overall inputs to the catalogue files; write outputs and
# logs to the current directory without resuming a previous run.
pipe.initialize(
    {"training_data": trainFile, "test_data": testFile},
    {"output_dir": '.', "log_dir": '.', "resume": False},
    None,  # third argument (presumably a stages-config file) left unset
)

In [None]:
pipe.save('rail_shire_pz.yml')

In [None]:
pr = ceci.Pipeline.read('rail_shire_pz.yml')

In [None]:
pr.run()