In [None]:
import os
import qp
import jax
import matplotlib.pyplot as plt
import numpy as np
from jax import numpy as jnp
import pandas as pd
import tables_io

from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR

from rail.shire import ShireInformer, ShireEstimator
from rail.core import SHARED_PARAMS

# Switch JAX to 64-bit (double-precision) types for everything computed below.
# Done in the first cell so it is in effect before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

In [None]:
# Shared RAIL data store used by every stage in this notebook.
# The overwrite flag is set on the store's class (so it applies to all instances),
# presumably so that re-running a cell can replace an existing entry of the same name.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
# Input catalogues are read from ./data relative to the notebook's working directory
# (per the file names: ~10k training objects and ~50k test objects).
# Alternative inputs, from the RAIL DC2 example test data:
#   training: os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_training_9816.hdf5')
#   test:     os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_validation_9816.hdf5')
DATA_DIR = os.path.abspath('data')
trainFile = os.path.join(DATA_DIR, 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5')
testFile = os.path.join(DATA_DIR, 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5')

# Fail early with a clear message instead of an error deep inside the reader.
for _path in (trainFile, testFile):
    if not os.path.isfile(_path):
        raise FileNotFoundError(f"Input catalogue not found: {_path}")

training_data = DS.read_file("training_data", TableHandle, trainFile)
test_data = DS.read_file("test_data", TableHandle, testFile)

In [None]:
# Configuration for the SHIRE informer (template-building) stage.
default_dict_inform = {
    "hdf5_groupname": "photometry",  # HDF5 group holding the input columns
    # NOTE(review): '.hf5' differs from the '.hdf5' suffix used for the estimator
    # output below -- confirm the suffix is intended.
    "output": "shire_templates_lsstSimhp10552.hf5",
    "data_path": "./data",
    "spectra_file": "dsps_valid_fits_F2_GG_DESI_SM3.h5",
    "ssp_file": "ssp_data_fsps_v3.2_lgmet_age.h5",
    # maps each "<band>_lsst" column name to the "filt_lsst" filter entry
    "filter_dict": {f"{band}_lsst": "filt_lsst" for band in "ugrizy"},
    # wavelength grid: lower bound, upper bound and step (units as expected by
    # ShireInformer -- presumably Angstrom, confirm)
    "wlmin": 100.,
    "wlmax": 15000.,
    "dwl": 100.,
    "colrsbins": 60,
}

In [None]:
# Build the informer stage from the configuration dict defined above.
run_shire_inform = ShireInformer.make_stage(
    name="shire_inform_lsstSimhp10552",
    **default_dict_inform,
)

In [None]:
%%time
# Run the informer on the training catalogue; it produces the "templates" output.
# %%time reports how long this step takes.
run_shire_inform.inform(training_data)

In [None]:
# Fetch the informer's "templates" output handle and load it; the last
# expression displays what was read.
templ = run_shire_inform.get_handle("templates")
templ.read()

In [None]:
# Configuration for the SHIRE estimator stage.
# The photometry group, data directory, SSP library, filter set and wavelength
# grid must match what the informer used to build the templates. They are
# copied from `default_dict_inform` rather than retyped, because retyped copies
# drift apart silently when only one of the two dicts is edited.
_shared_inform_keys = ("hdf5_groupname", "data_path", "ssp_file", "wlmin", "wlmax", "dwl")
default_dict_estimate = dict(
    {key: default_dict_inform[key] for key in _shared_inform_keys},
    # defensive copy so the two stage configs do not share one mutable dict
    filter_dict=dict(default_dict_inform["filter_dict"]),
    output="shire_results_lsstSimhp10552_noprior.hdf5",
    templ_type="SPS",
    no_prior=True,  # hence the `_noprior` tag in the output file name
    chunk_size=50,
)

# Estimator stage wired to the templates produced by the informer stage.
templates_handle = run_shire_inform.get_handle("templates")
run_shire_estimate = ShireEstimator.make_stage(
    name="shire_estimate_lsstSimhp10552",
    templates=templates_handle,
    **default_dict_estimate,
)

In [None]:
%%time
# Run photo-z estimation on the test catalogue with the configured estimator;
# the results are read back from its `output` HDF5 file in a later cell.
run_shire_estimate.estimate(test_data)

In [None]:
# Read the estimator results back as a qp ensemble. The path is taken from the
# estimator config instead of repeating the literal, so renaming the output in
# one place cannot leave this cell reading a stale file.
custom_res = qp.read(default_dict_estimate["output"])

In [None]:
# True (spectroscopic) redshifts of the test objects, used as the x-axis below.
test_table = test_data()  # indexed by HDF5 group, then by column name
sz = test_table['photometry']['redshift']

In [None]:
# Photo-z mode vs. true redshift. The dashed 1:1 line spans the data range
# rather than a fixed [0, 3], so it is never truncated by higher-z objects.
zmode = np.ravel(custom_res.ancil['zmode'])  # ancil 'zmode' is flattened to 1-D
z_max = float(np.nanmax([np.nanmax(sz), np.nanmax(zmode)]))  # ignore NaN estimates

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(sz, zmode, s=2, c='k', label='zmode')
ax.plot([0, z_max], [0, z_max], 'r--', label='1:1')
ax.set_xlabel("redshift")
ax.set_ylabel("photo-z mode")
ax.set_title("SHIRE photo-z mode vs. true redshift (no prior)")
ax.legend(loc='upper center', fontsize=10);

In [None]:
# Shape of the per-object photo-z mode array (displayed as the cell output).
np.shape(custom_res.ancil['zmode'])

In [None]:
# Native-representation PDF of one ensemble member, on a logarithmic y-axis.
EXAMPLE_KEY = 110  # ensemble entry selected via plot_native(key=...)
fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(10, 6))
custom_res.plot_native(key=EXAMPLE_KEY, axes=axs, label='example')
axs.set(xlabel="redshift", ylabel="PDF", yscale='log')
axs.legend(loc="upper center", fontsize=10)