# Notebook to run photo-z estimation with `process_fors2.photoZ`
Created by J. Chevalier on October 28, 2024

## Requirements
This notebook requires the `process_fors2` and `dsps` packages. Installation instructions can be found by following [this link](https://github.com/JospehCeh/process_fors2.git).

Then, a `jupyter` kernel must be created, associated with your installation environment, and used to run this notebook.

## Imports

In [None]:
import numpy as np  # simply more convenient for a couple of things
from jax import numpy as jnp
import h5py

from process_fors2.fetchData import json_to_inputs, readPhotoZHDF5
from process_fors2.photoZ import run_from_inputs

## Default run

In [None]:
# Default settings: one configuration file for the SPS run, one for the LEGACY run.
# Alternative configuration files:
#   "conf_IDRIS_PZ_TemplSel.json" / "conf_IDRIS_LEGACY-PZ_TemplSel.json"
conf_file_sps = "conf_IDRIS_cosmos2020_allFilts_noPrior.json"
conf_file_legacy = "conf_IDRIS_LEGACY-cosmos2020_allFilts_noPrior.json"

# Parse each configuration file, then keep a handle on its photo-z section.
input_settings_sps = json_to_inputs(conf_file_sps)
input_settings_legacy = json_to_inputs(conf_file_legacy)
inputs_pz_sps, inputs_pz_legacy = input_settings_sps["photoZ"], input_settings_legacy["photoZ"]

# Display the photo-z settings of the SPS run.
inputs_pz_sps

Here we have loaded the default settings that come with the package for a minimal (and not optimized at all) run.
They are structured as a dictionary, so it is easy to change one of the settings before the run. For example:

In [None]:
# Example override, disabled by default: change the condition to apply it.
if False:
    inputs_pz_legacy.update(
        {
            "Templates": dict(
                input="templ_NEWscoredOnTraining_SPS.h5",
                output="SEDtempl_NEWscoredOnTraining_SPS.h5",
                overwrite=True,
            ),
            "run name": "PZ_SPSbutLEGACY_COSMOS2020vis_NEWscored_noprior",
        }
    )
    # Other entries that could be overridden the same way:
    # inputs_pz_legacy['save results'] = True
    # inputs_pz_legacy['Mode'] = 'Legacy'

Note that as of now, the setting we have modified is actually not used in the code, so it won't affect our run and only serves as an example.

## Photometric redshifts
Now we will run the code from the notebook. It may take some time, up to a couple of hours on larger datasets, and the JAX implementation does not make it easy to incorporate progress bars, so please be patient...

In [None]:
%%time
# Disabled by default (`if False`): change the condition to actually launch the estimation.
# The result is bound to `pz_res_tree`.
if False:
    pz_res_tree = run_from_inputs(input_settings_sps)  # It is necessary here to use the overall dictionary and not the 'photoZ' subset

Now, we may want to save the results to disk, so let us write them into an `HDF5` file with the included utility `photoZtoHDF5`:

In [None]:
# Disabled by default: switch the outer condition on to write the posteriors to disk.
if False:
    if not input_settings_sps["photoZ"]["save results"]:
        # Saving was not requested in the configuration: only report it.
        resfile = "Run terminated correctly but results were not saved, please check your input configuration."
    else:
        from process_fors2.fetchData import photoZtoHDF5

        outfile = f"{input_settings_sps['photoZ']['run name']}_posteriors_dict.h5"
        resfile = photoZtoHDF5(outfile, pz_res_tree)
    print(resfile)

Alternatively, the steps above can be performed from a terminal by running the command :
`python -m process_fors2.photoZ $HOME/process_fors2/src/data/defaults.json` (or using any other appropriate `JSON` configuration file).

## Let's look at the results
Here we can either read the file we have created using the provided function `readPhotoZHDF5`, or directly use our `pz_res_tree` object.

**Note :**
_If the results were saved in the step above, it is highly recommended to reload them, otherwise the PDFs might be missing from the `pz_res_tree` object due to the presence of `dict.pop()` methods in `photoZtoHDF5`._


In [None]:
# Reload the posteriors from disk when the run was configured to save them;
# otherwise fall back on the in-memory results of the run above.
load_from_file = input_settings_sps["photoZ"]["save results"]
resfile_sps = f"{input_settings_sps['photoZ']['run name']}_posteriors_dict.h5"
resfile_legacy = f"{input_settings_legacy['photoZ']['run name']}_posteriors_dict.h5"
if load_from_file:
    pz_res_tree_sps = readPhotoZHDF5(resfile_sps)
else:
    # Without this branch `pz_res_tree_sps` would stay undefined and every
    # cell below would fail with a NameError. This requires the run cell to
    # have been executed so that `pz_res_tree` exists.
    pz_res_tree_sps = pz_res_tree

In [None]:
# Redshift grid as stored alongside the results.
z_grid = pz_res_tree_sps["z_grid"]
# Alternative (not used): rebuild the grid from the "Z_GRID" entry of the photo-z settings,
# i.e. jnp.arange(z_min, z_max + z_step, z_step).

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.cm as cmx

%matplotlib widget

from matplotlib.colors import LinearSegmentedColormap

# "Viridis-like" colormap that is white at its lowest value.
white_viridis = LinearSegmentedColormap.from_list(
    name='white_viridis',
    colors=list(
        zip(
            (0, 1e-20, 0.2, 0.4, 0.6, 0.8, 1),
            ('#ffffff', '#440053', '#404388', '#2a788e', '#21a784', '#78d151', '#fde624'),
        )
    ),
    N=256,
)

def _plot_pdz(res_dict, numobj):
    """Plot the redshift PDF of one object on a logarithmic y-axis.

    Column `numobj` of `res_dict["PDZ"]` is drawn against `res_dict["z_grid"]`,
    and vertical lines mark the object's "redshift" (labelled z_spec), "z_mean",
    "z_med" and "z_ML" entries.
    """
    _fig, axis = plt.subplots(1, 1, figsize=(6, 6))

    z_spec = res_dict["redshift"][numobj]
    z_mean = res_dict["z_mean"][numobj]
    z_median = res_dict["z_med"][numobj]
    z_mode = res_dict["z_ML"][numobj]

    axis.semilogy(res_dict["z_grid"], res_dict["PDZ"][:, numobj])

    # (position, colour, legend label) for each vertical marker, in drawing order.
    markers = [
        (z_mean, "r", "Mean"),
        (z_median, "orange", "Median"),
        (z_mode, "g", "Mode"),
    ]
    # NOTE(review): an element taken from an array is never None, so this check
    # presumably always passes - confirm how a missing z_spec is encoded.
    if z_spec is not None:
        markers.insert(0, (z_spec, "k", "z_spec"))

    for position, colour, name in markers:
        axis.axvline(position, c=colour, label=name)

    axis.legend()
    axis.set_xlabel(r"$z_{phot}$")

def plot_zp_zs(res_dict, z_bounds=None, ax=None, label=""):
    """Scatter photometric vs. spectroscopic redshifts and summarise the fit quality.

    Parameters
    ----------
    res_dict : mapping
        Photo-z results; the entries "z_grid", "z_ML" and "redshift" are read.
    z_bounds : tuple (zmin, zmax), optional
        If given, only objects with zmin <= redshift <= zmax are used, and the
        reference lines are drawn on the part of the grid lying within 0.1 of
        those bounds.
    ax : matplotlib Axes, optional
        Axes to draw on. A new 7x6 figure is created when omitted.
    label : str
        Legend entry for the scatter points, also used in the printed summary.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(7, 6))
    # Resolve the figure from the axes so that the figure-level legend also
    # works when the caller supplies `ax` (previously a NameError in that case).
    fig = ax.figure

    if z_bounds is None:
        z_grid = res_dict["z_grid"]
        zp = res_dict['z_ML']
        zs = res_dict['redshift']
    else:
        zmin, zmax = z_bounds
        sel_grid = jnp.logical_and(res_dict["z_grid"]>=zmin-0.1, res_dict["z_grid"]<=zmax+0.1)
        sel_zs = jnp.logical_and(res_dict["redshift"]>=zmin, res_dict["redshift"]<=zmax)
        z_grid = res_dict["z_grid"][sel_grid]
        zp = res_dict['z_ML'][sel_zs]
        zs = res_dict['redshift'][sel_zs]

    # Normalised residuals (zp - zs) / (1 + zs) and their summary statistics.
    errz = (zp - zs)/(1+zs)
    sigscat, medscat = jnp.std(errz), jnp.median(errz)
    # The MAD is taken about zero (the median is not subtracted).
    mad = jnp.median(jnp.abs(errz))
    sig_mad = 1.4826 * mad
    # Outliers: |errz| above 15 %.
    outliers = jnp.nonzero(jnp.abs(errz)*100.0 > 15)
    outl_rate = len(zs[outliers]) / len(zs)

    # Colour each point by its absolute residual in percent; the scale tops out at 20 %.
    cmap = plt.get_cmap('viridis_r')
    cNorm = colors.Normalize(vmin=100*jnp.abs(errz).min(), vmax=20)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)
    all_colors_refs = scalarMap.to_rgba(100*jnp.abs(errz), alpha=1)

    density = ax.scatter(zs, zp, s=4, alpha=0.2, c=all_colors_refs)
    # Identity line and the +/- 0.15 (1 + z) outlier boundaries.
    ax.plot(z_grid, z_grid, c="k", ls=":", lw=1)
    outl, = ax.plot(z_grid, z_grid + 0.15 * (1 + z_grid), c="k", lw=2)
    ax.plot(z_grid, z_grid - 0.15 * (1 + z_grid), c="k", lw=2)

    # Median residual and the +/- one standard deviation band around it.
    med, = ax.plot(z_grid, z_grid + medscat*(1 + z_grid), c="orange", lw=2, ls='-.')
    scat = ax.fill_between(z_grid, z_grid + (medscat+sigscat)*(1 + z_grid), z_grid + (medscat-sigscat)*(1 + z_grid), color="pink", alpha=0.4)

    ax.set_xlabel(r"$z_{spec}$")
    ax.set_ylabel(r"$z_{phot}$")
    ax.set_xlim(z_grid.min()-0.05, z_grid.max()+0.05)
    ax.set_ylim(z_grid.min()-0.05, z_grid.max()+0.05)
    fig.legend(
        [density, outl, (med, scat)],
        [
            label,
            "Outliers:\n"+r"$\left| \frac{z_p-z_s}{1+z_s} \right| > 0.15$",
            r"$\mathrm{median} \left( \zeta_z \right) \pm \sigma_{\zeta_z}=$"+f"\n\t{medscat:.3f}"+r"$\pm$"+f"{sigscat:.3f}"
        ],
        loc='lower right',
        bbox_to_anchor=(1., 0.)
    )
    ax.grid()
    ax.set_title(f"{100.0*outl_rate:.3f}% outliers ;\n"+r"$\sigma_{MAD}=$"+f"{sig_mad:.3f}")
    plt.colorbar(scalarMap, ax=ax, location='right', label=r"$\frac{\left| \Delta z \right|}{1+z}$ [%]")
    print(f"{label}: {100*outl_rate:.3f}% outliers out of {len(zp)} successful fits.\nsigma_mad: {sig_mad:.3f}.")

    return ax

def plot_zp_zs_photzHDF5(pz_hdf5, res_key='pz_outputs', zgrid_key='z_grid', zp_key='z_ML', zs_key='redshift', label='', bins=100):
    """Hexbin plot of photometric vs. spectroscopic redshifts read from an HDF5 file.

    The datasets `zp_key`, `zs_key` and `zgrid_key` are read from the group
    `res_key` of the file `pz_hdf5`. `label` names the density in the legend and
    in the printed summary; `bins` is forwarded to `hexbin` as its `gridsize`.
    A new 7x6 figure is always created and its Axes is returned.
    """
    fig, ax = plt.subplots(1, 1, figsize=(7, 6))

    def _as_array(group, key):
        # Materialise the dataset while the file is still open.
        return jnp.array(group.get(key), dtype=jnp.float64)

    with h5py.File(pz_hdf5, 'r') as h5file:
        results = h5file.get(res_key)
        zp = _as_array(results, zp_key)
        zs = _as_array(results, zs_key)
        grid = _as_array(results, zgrid_key)

    # Normalised residuals and their summary statistics.
    errz = (zp - zs) / (1 + zs)
    abs_errz = jnp.abs(errz)
    sigscat = jnp.std(errz)
    medscat = jnp.median(errz)
    # MAD about zero (the median is not subtracted), scaled by 1.4826.
    sig_mad = 1.4826 * jnp.median(abs_errz)
    # Outliers: absolute residual above 15 %.
    outlier_idx = jnp.nonzero(abs_errz*100.0 > 15)
    outl_rate = len(zs[outlier_idx]) / len(zs)

    hexes = ax.hexbin(zs, zp, bins='log', gridsize=bins)

    # Identity line, then the two outlier boundaries.
    one_plus_z = 1 + grid
    ax.plot(grid, grid, c="k", ls=":", lw=1)
    upper_bound, = ax.plot(grid, grid + 0.15 * one_plus_z, c="k", lw=2)
    ax.plot(grid, grid - 0.15 * one_plus_z, c="k", lw=2)

    # Median residual and the one-standard-deviation band around it.
    median_line, = ax.plot(grid, grid + medscat*one_plus_z, c="orange", lw=2, ls='-.')
    band = ax.fill_between(
        grid,
        grid + (medscat+sigscat)*one_plus_z,
        grid + (medscat-sigscat)*one_plus_z,
        color="pink",
        alpha=0.4,
    )

    ax.set_xlabel(r"$z_{spec}$")
    ax.set_ylabel(r"$z_{phot}$")
    ax.set_xlim(grid.min()-0.05, grid.max()+0.05)
    ax.set_ylim(grid.min()-0.05, grid.max()+0.05)

    legend_handles = [hexes, upper_bound, (median_line, band)]
    legend_texts = [
        label,
        "Outliers:\n"+r"$\left| \frac{z_p-z_s}{1+z_s} \right| > 0.15$",
        r"$\mathrm{median} \left( \zeta_z \right) \pm \sigma_{\zeta_z}=$"+f"\n\t{medscat:.3f}"+r"$\pm$"+f"{sigscat:.3f}"
    ]
    fig.legend(legend_handles, legend_texts, loc='lower right', bbox_to_anchor=(1., 0.))

    ax.grid()
    ax.set_title(f"{100.0*outl_rate:.3f}% outliers ;\n"+r"$\sigma_{MAD}=$"+f"{sig_mad:.3f}")
    fig.colorbar(hexes, label='Density', location='right')
    ax.set_aspect("equal", "box")
    print(f"{label}: {100*outl_rate:.3f}% outliers out of {len(zp)} successful fits.\nsigma_mad: {sig_mad:.3f}.")

    return ax

def _outlier_spec_z(pz_hdf5, res_key, zp_key, zs_key):
    """Return the `zs_key` values of the outliers stored in the HDF5 file `pz_hdf5`."""
    with h5py.File(pz_hdf5, 'r') as h5res:
        pzouts = h5res.get(res_key)
        zp = jnp.array(pzouts.get(zp_key), dtype=jnp.float64)
        zs = jnp.array(pzouts.get(zs_key), dtype=jnp.float64)

    errz = (zp - zs)/(1+zs)
    # Outliers: absolute normalised residual above 15 %.
    return zs[jnp.nonzero(jnp.abs(errz)*100.0 > 15)]


def hist_outliers(pz1_hdf5, pz2_hdf5=None, label1='', label2='', res_key='pz_outputs', zp_key='z_ML', zs_key='redshift'):
    """Histogram the spectroscopic redshifts of the photo-z outliers of one or two runs.

    Parameters
    ----------
    pz1_hdf5 : path
        HDF5 results file of the first run.
    pz2_hdf5 : path, optional
        HDF5 results file of a second run; skipped when None or "". Its
        histogram reuses the bin edges computed for the first run.
    label1, label2 : str
        Legend labels of the two histograms.
    res_key, zp_key, zs_key : str
        Name of the results group and of the photometric / spectroscopic
        redshift datasets inside it.

    Returns
    -------
    The Axes holding the histogram(s), on a new 7x6 figure.
    """
    f, ax = plt.subplots(1, 1, figsize=(7, 6))

    outl_zs1 = _outlier_spec_z(pz1_hdf5, res_key, zp_key, zs_key)
    _n, _bins, _ = ax.hist(outl_zs1, bins='auto', density=False, label=label1, alpha=0.7)

    if pz2_hdf5 is not None and pz2_hdf5!="":
        outl_zs2 = _outlier_spec_z(pz2_hdf5, res_key, zp_key, zs_key)
        # Same bin edges as the first histogram, so that the counts are comparable.
        ax.hist(outl_zs2, bins=_bins, density=False, label=label2, alpha=0.7)

    ax.set_xlabel(r"$z_{spec}$")
    ax.set_ylabel("Outliers count")
    ax.legend()
    return ax

In [None]:
# Pick one object at random; objects are indexed along the second axis of the PDZ array.
n_objects = pz_res_tree_sps["PDZ"].shape[1]
randomid = np.random.choice(n_objects)

In [None]:
# Redshift PDF of the randomly selected object.
_plot_pdz(res_dict=pz_res_tree_sps, numobj=randomid)
plt.show()

In [None]:
# The plot label is made of the 2nd and 3rd underscore-separated fields of the run name.
label_sps = '_'.join(inputs_pz_sps["run name"].split('_')[1:3])
ax_zpzs = plot_zp_zs(pz_res_tree_sps, z_bounds=None, label=label_sps)
plt.show()

In [None]:
# Same diagnostic for the SPS run, read back from its HDF5 results file.
label_sps = '_'.join(inputs_pz_sps["run name"].split('_')[1:3])
plot_zp_zs_photzHDF5(resfile_sps, label=label_sps, bins=100)
plt.show()

In [None]:
# Same diagnostic for the LEGACY run, read from its HDF5 results file.
label_legacy = '_'.join(inputs_pz_legacy["run name"].split('_')[1:3])
plot_zp_zs_photzHDF5(resfile_legacy, label=label_legacy, bins=100)
plt.show()

In [None]:
# Compare where (in spectroscopic redshift) the outliers of the two runs are located.
label_sps = '_'.join(inputs_pz_sps["run name"].split('_')[1:3])
label_legacy = '_'.join(inputs_pz_legacy["run name"].split('_')[1:3])
hist_outliers(resfile_sps, resfile_legacy, label1=label_sps, label2=label_legacy)
plt.show()