In [None]:
import os
import qp
import jax
import matplotlib.pyplot as plt
from jax import numpy as jnp

from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR

from rail.shire import ShireInformer, ShireEstimator, hist_outliers, plot_zp_zs_ensemble

# Switch JAX to 64-bit (double-precision) mode; JAX defaults to 32-bit otherwise.
# This is set right after the imports, before any JAX arrays are created in later cells.
jax.config.update("jax_enable_x64", True)

In [None]:
# Shorthand for the data store shared by all RAIL stages in this session.
DS = RailStage.data_store
# Set the class-level `allow_overwrite` flag on the data store.
# NOTE(review): presumably this lets cells be re-run without errors when a
# handle with the same tag already exists — confirm against the RAIL docs.
DS.__class__.allow_overwrite = True

In [None]:
# Input catalogues are read from the `data` directory relative to the current
# working directory.
# Commented-out alternative inputs (RAIL example data):
#   os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
#   os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')
data_dir = os.path.abspath(os.path.join('.', 'data'))
trainFile = os.path.join(
    data_dir, 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5'
)
testFile = os.path.join(
    data_dir, 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5'
)

# Register both files in the data store as table handles under fixed tags.
training_data = DS.read_file(
    "training_data",
    TableHandle,
    trainFile,
)
test_data = DS.read_file(
    "test_data",
    TableHandle,
    testFile,
)

In [None]:
# Keyword arguments for the SHIRE informer stage.
# NOTE(review): wlmin / wlmax / dwl presumably define the wavelength grid
# (min, max, step) — units are not visible here, confirm in the SHIRE docs.
default_dict_inform = {
    "hdf5_groupname": "photometry",
    "output": "shire_templates_lsstSimhp10552.hf5",
    "data_path": "./data",
    "spectra_file": "dsps_valid_fits_F2_GG_DESI_SM3.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    # One entry per band letter in "ugrizy", all mapped to the same value.
    "filter_dict": {f"{band}_lsst": "filt_lsst" for band in "ugrizy"},
    "wlmin": 100.0,
    "wlmax": 15000.0,
    "dwl": 100.0,
    "colrsbins": 46,
}

In [None]:
# Instantiate the SHIRE informer stage from the configuration dictionary
# defined in the previous cell.
run_shire_inform = ShireInformer.make_stage(
    name="shire_inform_lsstSimhp10552",
    **default_dict_inform,
)

In [None]:
%%time
# Run the informer stage on the training-data handle (timed by the %%time cell magic).
# NOTE(review): presumably this is what produces the "templates" handle fetched
# in the next cell — confirm against the ShireInformer docs.
run_shire_inform.inform(training_data)

In [None]:
# Fetch the "templates" handle from the informer stage and read its contents.
# As the last expression of the cell, the value returned by read() is displayed.
templ = run_shire_inform.get_handle("templates")
templ.read()

In [None]:
# Keyword arguments for the SHIRE estimator stage.
# NOTE(review): wlmin / wlmax / dwl repeat the values used for the informer
# configuration; presumably the two must stay in sync — confirm.
default_dict_estimate = {
    "hdf5_groupname": "photometry",
    "output": "shire_results_lsstSimhp10552_noprior.hdf5",
    "data_path": "./data",
    "templ_type": "SPS",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    # One entry per band letter in "ugrizy", all mapped to the same value.
    "filter_dict": {f"{band}_lsst": "filt_lsst" for band in "ugrizy"},
    "wlmin": 100.0,
    "wlmax": 15000.0,
    "dwl": 100.0,
    "no_prior": True,
    "chunk_size": 5000,
}

# Build the SHIRE estimator stage from the configuration dictionary above and
# hand it the "templates" handle held by the informer stage.
templates_handle = run_shire_inform.get_handle("templates")
run_shire_estimate = ShireEstimator.make_stage(
    name="shire_estimate_lsstSimhp10552",
    **default_dict_estimate,
    templates=templates_handle,
)

In [None]:
%%time
# Run the estimator stage on the test-data handle (timed by the %%time cell magic).
# NOTE(review): the results are presumably written to the `output` path given in
# default_dict_estimate, which the next cell reads back — confirm.
run_shire_estimate.estimate(test_data)

In [None]:
# Results file to load. The name duplicates the `output` entry of
# default_dict_estimate, so the two must be kept in sync by hand.
pdfs_file = "shire_results_lsstSimhp10552_noprior.hdf5"
# Load the saved results with qp.
# NOTE(review): presumably a qp ensemble of redshift PDFs — confirm.
custom_res = qp.read(pdfs_file)
# Calling the handle returns the table data; take the 'redshift' column of the
# 'photometry' group as a JAX array (used as the reference redshifts in the plots below).
sz = jnp.array(test_data()['photometry']['redshift'])

In [None]:
# Plot label: 'SHIRE' followed by the underscore-separated pieces of the
# results file name (extension stripped), skipping the first two pieces.
results_stem = os.path.splitext(pdfs_file)[0]
plot_label = '_'.join(['SHIRE'] + results_stem.split('_')[2:])

# Draw the comparison plot for the loaded results against the reference
# redshifts, using the "zmode" estimate key.
a = plot_zp_zs_ensemble(
    custom_res,
    sz,
    z_grid=None,
    key_estim="zmode",
    label=plot_label,
)
plt.show()

In [None]:
# Histogram label: 'SHIRE' followed by the underscore-separated pieces of the
# results file name (extension stripped), skipping the first two pieces.
results_stem = os.path.splitext(pdfs_file)[0]
hist_label = '_'.join(['SHIRE'] + results_stem.split('_')[2:])

# Draw the outlier histogram for the loaded results against the reference redshifts.
hist_outliers(custom_res, sz, label1=hist_label)
plt.show()