In [None]:
import os
import qp
import jax
import json
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from jax import numpy as jnp
from tqdm import tqdm

from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.core.common_params import SHARED_PARAMS

from rail.shire import ShireInformer, ShireEstimator, hist_outliers, plot_zp_zs_ensemble

# Enable double precision in JAX (JAX uses 32-bit floats unless this flag is set);
# the metric helper below builds its redshift grid with dtype=jnp.float64.
jax.config.update("jax_enable_x64", True)

## Select and load data into the datastore

In [None]:
# Shared RAIL data store that holds every handle read or produced in this notebook.
DS = RailStage.data_store
# Allow handles to be overwritten when a tag is reused. This is presumably needed because
# the estimation loop re-creates stages with identical names on every iteration.
type(DS).allow_overwrite = True

In [None]:
# Training and test catalogues, expected under ./data relative to the working directory.
# Alternatives tried during development: ../../magszgalaxies_lsstroman_gold_hp10552.h5, and the
# RAIL DC2 samples os.path.join(RAILDIR, 'rail/examples_data/testdata/test_dc2_{training,validation}_9816.hdf5').
trainFile = os.path.abspath(os.path.join('data', 'train_magszgalaxies_lsstroman_gold_hp10552_10k.h5'))
testFile = os.path.abspath(os.path.join('data', 'test_magszgalaxies_lsstroman_gold_hp10552_50k.h5'))

# Register both tables in the data store under the tags "training_data" and "test_data".
training_data = DS.read_file("training_data", TableHandle, trainFile)
test_data = DS.read_file("test_data", TableHandle, testFile)

In [None]:
# Filter dictionary passed as `filter_dict` to the SHIRE stages: one key per LSST band,
# all mapped to the same "filt_lsst" value.
lsst_filts_dict = dict.fromkeys((f"{band}_lsst" for band in "ugrizy"), "filt_lsst")

# Magnitude and magnitude-error column names (passed as `bands` / `err_bands`),
# in the same band order as the filter dictionary.
_bands = [f"mag_{filt}" for filt in lsst_filts_dict]
_errbands = [f"mag_err_{filt}" for filt in lsst_filts_dict]

## Run successive estimations with random templates
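For each estimation, we store the outlier rate $\eta$ and the normalised MAD $\sigma_\mathrm{MAD}$ for both the SPS and the Legacy estimations, along with the redshift histogram of the outliers. With $\Delta z = (z_\mathrm{p} - z_\mathrm{s})/(1 + z_\mathrm{s})$, outliers are the galaxies with $|\Delta z| > 0.15$, and $\sigma_\mathrm{MAD} = 1.4826\,\mathrm{median}(|\Delta z|)$.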

In [None]:
# Seeded generator so that the random template-set sizes are reproducible.
gen = np.random.default_rng(seed=46)

nestim = 100  # number of successive estimations
# Number of random templates for each estimation, drawn uniformly in [5, 60] (both ends included).
ntempls = gen.integers(low=5, high=60, size=nestim, endpoint=True)

# Per-estimation metric accumulators for the SPS and Legacy runs:
# outlier rates, sigma_MAD values and per-bin outlier counts.
etas_sps, etas_legacy = [], []
nmad_sps, nmad_legacy = [], []
outl_histos_sps, outl_histos_legacy = [], []

# True (spectroscopic) redshifts of the test sample and their 50-bin histogram.
# The bin edges are reused for the outlier histograms so both can be compared bin by bin.
sz = jnp.array(test_data()['photometry']['redshift'])
zs_counts, z_bins = np.histogram(sz, bins=50, density=False)

def calc_eta_nmad_outl(ens_PDFs, z_true, z_grid=None, key_estim='zmode', bins=None, outlier_pct=15.0):
    """Compute photo-z quality metrics of an ensemble against the true redshifts.

    Parameters
    ----------
    ens_PDFs : qp.Ensemble
        Ensemble of photo-z PDFs. When ``key_estim`` is None, the point estimate
        is the mode of each PDF evaluated on ``z_grid``. Otherwise it is read
        from ``ens_PDFs.ancil[key_estim]``.
    z_true : array-like
        True (spectroscopic) redshifts, one per PDF.
    z_grid : array-like, optional
        Grid used to compute the mode. Only needed when ``key_estim`` is None;
        defaults to the ``dist.xvals`` of the first PDF of the ensemble.
    key_estim : str or None, optional
        Name of the ancillary point estimate to use (default ``'zmode'``).
    bins : int, str or array-like, optional
        Bins of the outlier histogram, passed to ``np.histogram``. Defaults to
        ``np.histogram_bin_edges(zs, bins='auto')``.
    outlier_pct : float, optional
        Outlier threshold on ``|zp - zs| / (1 + zs)``, in percent (default 15).

    Returns
    -------
    outl_rate : float
        Fraction (not percent) of objects flagged as outliers; NaN for empty input.
    sig_mad : scalar
        Normalised MAD ``1.4826 * median(|zp - zs| / (1 + zs))``; NaN for empty input.
    outl_counts : numpy.ndarray
        Histogram of the true redshifts of the outliers.
    outl_binedges : numpy.ndarray
        Bin edges of that histogram.
    """
    if key_estim is None:
        if z_grid is None:
            # Only touch the PDF grid when the mode actually has to be computed.
            z_grid = jnp.squeeze(jnp.array(ens_PDFs[0].dist.xvals, dtype=jnp.float64))
        zp = ens_PDFs.mode(z_grid)
    else:
        zp = ens_PDFs.ancil[key_estim]

    # Squeeze both sides. A (N, 1) point-estimate column would otherwise silently
    # broadcast against the (N,) truth into an (N, N) matrix. atleast_1d keeps a
    # single-object input indexable.
    zp = jnp.atleast_1d(jnp.squeeze(jnp.asarray(zp)))
    zs = jnp.atleast_1d(jnp.squeeze(jnp.asarray(z_true)))

    abs_errz = jnp.abs((zp - zs) / (1 + zs))
    n_obj = zs.shape[0]

    # Normalised MAD. The deviation is taken about 0, not about the median scatter.
    sig_mad = 1.4826 * jnp.median(abs_errz) if n_obj else float('nan')

    is_outlier = abs_errz * 100.0 > outlier_pct
    zs_outl = np.asarray(zs)[np.asarray(is_outlier)]
    outl_rate = zs_outl.shape[0] / n_obj if n_obj else float('nan')

    zbins = np.histogram_bin_edges(zs, bins='auto') if bins is None else bins
    outl_counts, outl_binedges = np.histogram(zs_outl, bins=zbins, density=False)
    return outl_rate, sig_mad, outl_counts, outl_binedges

In [None]:
%%time

import io
from contextlib import redirect_stdout, redirect_stderr
trap = io.StringIO()

for nestim, ntempl in tqdm(enumerate(ntempls), total=ntempls.shape[0]):
    #print(f"Estimation {nestim}/{nestim.shape[0]}: {ntempl} random templates.")
    with redirect_stdout(trap):
        with redirect_stderr(trap):
            default_dict_inform = dict(
                hdf5_groupname="photometry",
                data_path="./data",
                bands=_bands,
                err_bands=_errbands,
                spectra_file="dsps_valid_fits_F2SM3_GG_DESI.h5",# "dsps_valid_fits_GG_DESI.h5", #
                ssp_file="ssp_data_fsps_v3.2_lgmet_age.h5",
                filter_dict=lsst_filts_dict,
                wlmin=500.,
                wlmax=12000.,
                dwl=5.,
                randomsel=True, # Here we specify that we don't want to run proper training
                ntemplates=ntempl # In that case this flag is used to select the number of templates randomly drawn from `spectra_file`
            )
        
            run_shire_inform_rand = ShireInformer.make_stage(
                name="shireRandSucc_inform_lsstSimhp10552",
                output="shireRandSucc_templates_lsstSimhp10552.hf5",
                **default_dict_inform,
                templ_type="SPS",
            )
            run_shire_inform_rand.inform(training_data)
        
            default_dict_estimate = dict(
                hdf5_groupname="photometry",
                data_path="./data",
                bands=_bands,
                err_bands=_errbands,
                zmin=0.01,
                zmax=3.1,
                nzbins=310,
                ssp_file="ssp_data_fsps_v3.2_lgmet_age.h5",
                filter_dict=lsst_filts_dict,
                wlmin=500.,
                wlmax=12000.,
                dwl=5.,
                no_prior=True,
                chunk_size=10000
            )
            
            run_shire_estimate_sps = ShireEstimator.make_stage(
                name="shireRandSuccSPS_estimate_lsstSimhp10552",
                output="shireRandSuccSPS_results_lsstSimhp10552_noprior.hdf5",
                **default_dict_estimate,
                templ_type="SPS",
                templates=run_shire_inform_rand.get_handle("templates"),
                model=run_shire_inform_rand.get_handle("model")
            )
            run_shire_estimate_sps.estimate(test_data)
            
            run_shire_estimate_legacy = ShireEstimator.make_stage(
                name="shireRandSuccLEG_estimate_lsstSimhp10552",
                output="shireRandSuccLEG_results_lsstSimhp10552_noprior.hdf5",
                **default_dict_estimate,
                templ_type="Legacy",
                templates=run_shire_inform_rand.get_handle("templates"),
                model=run_shire_inform_rand.get_handle("model")
            )
            run_shire_estimate_legacy.estimate(test_data)
        
            pdfs_file_sps = "shireRandSuccSPS_results_lsstSimhp10552_noprior.hdf5"
            custom_res_sps = qp.read(pdfs_file_sps)
            _etasps, _nmadsps, _outl_cnt_sps, _ = calc_eta_nmad_outl(custom_res_sps, sz, z_grid=None, key_estim='zmode', bins=z_bins)
            etas_sps.append(_etasps)
            nmad_sps.append(_nmadsps)
            outl_histos_sps.append(_outl_cnt_sps)
            
            pdfs_file_legacy = "shireRandSuccLEG_results_lsstSimhp10552_noprior.hdf5"
            custom_res_legacy = qp.read(pdfs_file_legacy)
            _etaleg, _nmadleg, _outl_cnt_leg, _ = calc_eta_nmad_outl(custom_res_legacy, sz, z_grid=None, key_estim='zmode', bins=z_bins)
            etas_legacy.append(_etaleg)
            nmad_legacy.append(_nmadleg)
            outl_histos_legacy.append(_outl_cnt_leg)
etas_sps = 100.0*np.array(etas_sps)
etas_legacy = 100.0*np.array(etas_legacy)
nmad_sps = np.array(nmad_sps)
nmad_legacy = np.array(nmad_legacy)
outl_histos_sps = np.column_stack(outl_histos_sps)
outl_histos_legacy = np.column_stack(outl_histos_legacy)

In [None]:
# Mean and spread over all runs of each metric, for both template types.
eta_mean_sps, eta_std_sps = np.nanmean(etas_sps), np.nanstd(etas_sps)
nmad_mean_sps, nmad_std_sps = np.nanmean(nmad_sps), np.nanstd(nmad_sps)
eta_mean_legacy, eta_std_legacy = np.nanmean(etas_legacy), np.nanstd(etas_legacy)
nmad_mean_legacy, nmad_std_legacy = np.nanmean(nmad_legacy), np.nanstd(nmad_legacy)

# Horizontal extent of the shaded +/- 1 std bands: the range of template counts.
_dumx = np.linspace(ntempls.min(), ntempls.max(), 5, endpoint=True)


def _mean_std_band(ax, x, mean, std, label, color):
    """Draw a horizontal line at `mean` and a labelled band mean +/- std over `x`."""
    ax.axhline(mean, color=color)
    ax.fill_between(
        x,
        np.full_like(x, mean + std),
        np.full_like(x, mean - std),
        label=label,
        color=color,
        alpha=0.5,
    )


f, a = plt.subplots(2, 1, sharex=True)
eta_ax, nmad_ax = a

# Top panel: outlier rate of every run, with the SPS and Legacy means +/- std.
eta_ax.scatter(ntempls, etas_sps, label='SPS', c='green')
eta_ax.scatter(ntempls, etas_legacy, label='Legacy', c='orange')
_mean_std_band(
    eta_ax, _dumx, eta_mean_sps, eta_std_sps,
    r'$\overline{\eta}_\mathrm{SPS}=$'+f'{eta_mean_sps:.1f}'+r'$\pm$'+f'{eta_std_sps:.2f} [%]',
    'green',
)
_mean_std_band(
    eta_ax, _dumx, eta_mean_legacy, eta_std_legacy,
    r'$\overline{\eta}_\mathrm{Legacy}=$'+f'{eta_mean_legacy:.1f}'+r'$\pm$'+f'{eta_std_legacy:.2f} [%]',
    'orange',
)

# Bottom panel: sigma_MAD of every run, with the SPS and Legacy means +/- std.
nmad_ax.scatter(ntempls, nmad_sps, c='green')
nmad_ax.scatter(ntempls, nmad_legacy, c='orange')
_mean_std_band(
    nmad_ax, _dumx, nmad_mean_sps, nmad_std_sps,
    r'$\overline{\sigma_\mathrm{MAD}}_{\mathrm{, SPS}}=$'+f'{nmad_mean_sps:.3f}'+r'$\pm$'+f'{nmad_std_sps:.4f}',
    'green',
)
_mean_std_band(
    nmad_ax, _dumx, nmad_mean_legacy, nmad_std_legacy,
    r'$\overline{\sigma_\mathrm{MAD}}_{\mathrm{, Legacy}}=$'+f'{nmad_mean_legacy:.3f}'+r'$\pm$'+f'{nmad_std_legacy:.4f}',
    'orange',
)

nmad_ax.set_xlabel('N. of templates')
eta_ax.set_ylabel(r'$\eta\ \mathrm{[\%]}$')
nmad_ax.set_ylabel(r'$\sigma_\mathrm{MAD}\ \mathrm{[-]}$')
f.legend()
plt.show()

In [None]:
# Per-redshift-bin mean and spread of the outlier counts across all runs.
# The stacked histograms hold one column per run, so statistics are taken along axis=1.
mean_hist_sps = np.nanmean(outl_histos_sps, axis=1)
std_hist_sps = np.nanstd(outl_histos_sps, axis=1)
mean_hist_legacy = np.nanmean(outl_histos_legacy, axis=1)
std_hist_legacy = np.nanstd(outl_histos_legacy, axis=1)

In [None]:
f, a = plt.subplots(1, 1)

# For each template type: mean outlier count per redshift bin (line) and the
# +/- 1 std band across runs (shaded, labelled).
for mean_hist, std_hist, color, label in (
    (mean_hist_sps, std_hist_sps, 'green', 'SPS'),
    (mean_hist_legacy, std_hist_legacy, 'orange', 'Legacy'),
):
    a.stairs(mean_hist, edges=z_bins, color=color)
    a.stairs(mean_hist + std_hist, baseline=mean_hist - std_hist, fill=True, edges=z_bins, alpha=0.5, color=color, label=label)

a.set_xlabel("Redshift")
a.set_ylabel("N. of outliers in bins [-]")
# Trailing ';' suppresses the Legend object's repr in the cell output.
a.legend();

In [None]:
f, a = plt.subplots(1, 1)


def _pct_of_bin(counts):
    """Return 100 * counts / zs_counts per redshift bin, NaN where the bin holds no galaxy.

    Empty bins are set to NaN explicitly instead of evaluating 0/0, which would
    emit a RuntimeWarning.
    """
    return np.divide(100.0 * counts, zs_counts, out=np.full(np.shape(counts), np.nan), where=zs_counts > 0)


# For each template type: mean percentage of outliers per redshift bin (line)
# and the +/- 1 std band across runs (shaded, labelled).
for mean_hist, std_hist, color, label in (
    (mean_hist_sps, std_hist_sps, 'green', 'SPS'),
    (mean_hist_legacy, std_hist_legacy, 'orange', 'Legacy'),
):
    a.stairs(_pct_of_bin(mean_hist), edges=z_bins, color=color)
    a.stairs(_pct_of_bin(mean_hist + std_hist), baseline=_pct_of_bin(mean_hist - std_hist), fill=True, edges=z_bins, alpha=0.5, color=color, label=label)

a.set_xlabel("Redshift")
a.set_ylabel("Outliers in bins [%]")
a.grid()
# Trailing ';' suppresses the Legend object's repr in the cell output.
a.legend();