# Notebook to run photo-z estimation with `process_fors2.photoZ`
Created by J. Chevalier on October 28, 2024

## Requirements
This notebook requires the `process_fors2` and `dsps` packages. Installation instructions can be found at [this link](https://github.com/JospehCeh/process_fors2.git).

Then, create a `jupyter` kernel associated with your installation environment and use it to run this notebook.

## Imports

In [None]:
import numpy as np  # simply more convenient for a couple of things
from jax import numpy as jnp

from process_fors2.fetchData import json_to_inputs, readPhotoZHDF5
from process_fors2.photoZ import run_from_inputs

## Default run

In [None]:
# Default settings
conf_file = "../../src/data/defaults.json"

input_settings = json_to_inputs(conf_file)
print(input_settings)

In [None]:
inputs_pz = input_settings["photoZ"]
inputs_pz

Here we have loaded the default settings that ship with the package, intended for a minimal (and not at all optimized) run.
They are stored as a nested dictionary, so any setting can easily be changed before the run. For example:

In [None]:
input_settings["photoZ"]["Estimator"] = "delight"
inputs_pz["Estimator"]
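
Since `inputs_pz` is the very same dictionary object as `input_settings["photoZ"]` (not a copy), the change made through `input_settings` shows up through `inputs_pz`. The sketch below illustrates this aliasing on a small stand-in dictionary (its keys mirror those used in this notebook, but the values are hypothetical placeholders, not the package defaults), and shows how `copy.deepcopy` could be used to build an independent variant of the configuration.

```python
import copy

# Hypothetical stand-in for the nested settings returned by json_to_inputs
settings = {"photoZ": {"Estimator": "placeholder", "prior": False, "run name": "demo"}}

pz = settings["photoZ"]  # a reference to the nested dict, not a copy
settings["photoZ"]["Estimator"] = "delight"
print(pz["Estimator"])  # prints "delight": both names point to the same dict

# A deep copy gives an independent configuration for an alternative run
variant = copy.deepcopy(settings)
variant["photoZ"]["prior"] = True
print(settings["photoZ"]["prior"], variant["photoZ"]["prior"])  # prints "False True"
```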

We may also activate the prior. In this case, it would be wise to reflect this in the `input_settings["photoZ"]["run name"]` value as well, although we will not do it here for the sake of laziness (and to avoid generating too many files).

In [None]:
input_settings["photoZ"]["prior"] = True
inputs_pz["prior"]
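
If you do want the run name to record whether the prior is active, a small helper could append a suffix. This is a hypothetical sketch (the function `with_prior_tag` and the `_prior`/`_noprior` suffixes are not part of `process_fors2`); it only relies on the `"run name"` and `"prior"` keys used in this notebook.

```python
def with_prior_tag(pz_settings):
    """Return a copy of the 'photoZ' settings whose run name records the prior flag.

    Hypothetical helper for illustration; it is not part of process_fors2.
    """
    new = dict(pz_settings)  # shallow copy: only the top-level "run name" string is changed
    suffix = "_prior" if new.get("prior", False) else "_noprior"
    if not new["run name"].endswith(suffix):  # avoid appending the suffix twice
        new["run name"] += suffix
    return new


example = {"run name": "demo", "prior": True}
print(with_prior_tag(example)["run name"])  # prints "demo_prior"
```

In the notebook, `input_settings["photoZ"] = with_prior_tag(input_settings["photoZ"])` would apply it. Note that this rebinds the entry, so `inputs_pz` would then still point to the old dictionary.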

Note that as of now, the setting we have modified is actually not used in the code, so it won't affect our run and only serves as an example.

## Photometric redshifts
Now we will run the code from the notebook. It may take some time, up to a couple of hours on larger datasets, and the JAX implementation does not make it easy to add progress bars, so please be patient.

In [None]:
pz_res_tree = run_from_inputs(input_settings)  # It is necessary here to use the overall dictionary and not the 'photoZ' subset
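
Without a progress bar, it can still help to know how long a run took, for instance to plan runs on larger datasets. A minimal sketch with the standard `time` module; the wrapper `timed` is a hypothetical helper, not part of `process_fors2`.

```python
import time


def timed(func, *args, **kwargs):
    """Call func(*args, **kwargs) and print the wall-clock time it took (hypothetical helper)."""
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    name = getattr(func, "__name__", repr(func))
    print(f"{name} finished in {elapsed / 60:.2f} min")
    return result


# In the notebook this could wrap the call above, e.g.:
# pz_res_tree = timed(run_from_inputs, input_settings)
total = timed(sum, [1, 2, 3])
```

Because JAX dispatches work asynchronously, arrays in the returned tree may still be computing when the timer stops; calling `jax.block_until_ready` on the result before reading the clock would make the measurement more faithful.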

Now we may want to save the results to disk, so let us write them to an `HDF5` file with the included utility `photoZtoHDF5`:

In [None]:
if input_settings["photoZ"]["save results"]:
    from process_fors2.fetchData import photoZtoHDF5

    # Write the results tree to "<run name>_posteriors_dict.h5"; the returned value is passed to readPhotoZHDF5 below
    resfile = photoZtoHDF5(f"{input_settings['photoZ']['run name']}_posteriors_dict.h5", pz_res_tree)
else:
    resfile = "Run terminated correctly but results were not saved, please check your input configuration."
print(resfile)

Alternatively, the steps above can be performed from a terminal with the command:
`python -m process_fors2.photoZ $HOME/process_fors2/src/data/defaults.json` (or using any other appropriate `JSON` configuration file).
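
Settings modified in the notebook could be reused from the terminal by writing them back to a `JSON` file. A minimal sketch with the standard `json` module, assuming that the file read by `json_to_inputs` has the same nested layout as the dictionary it returns (this is an assumption; compare with `defaults.json` before relying on it). The file name `my_settings.json` and the stand-in dictionary are hypothetical.

```python
import json
from pathlib import Path

# Hypothetical stand-in; in the notebook this would be `input_settings`
settings = {"photoZ": {"run name": "demo_prior", "prior": True, "save results": True}}

out = Path("my_settings.json")
with out.open("w") as fh:
    json.dump(settings, fh, indent=4)

# Round-trip check: the file reloads to the same dictionary
with out.open() as fh:
    assert json.load(fh) == settings
```

The resulting file could then be passed as `python -m process_fors2.photoZ my_settings.json`. `json.dump` will fail if the dictionary holds NumPy or JAX arrays, which would first need converting, e.g. with `.tolist()`.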

## Let's look at the results
Here we can either read the file we have created using the provided function `readPhotoZHDF5`, or directly use our `pz_res_tree` object.

**Note:**
_If the results were saved in the step above, it is highly recommended to reload them: `photoZtoHDF5` relies on `dict.pop()`, so the PDFs may have been removed from the `pz_res_tree` object._
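
Why this happens: `dict.pop()` removes the key from the dictionary it is called on, and the caller holds that same dictionary. The toy function below is hypothetical and is not the code of `photoZtoHDF5` (whose internals may differ, e.g. it may pop from nested dictionaries); it only reproduces the mechanism.

```python
def save_and_pop(results):
    """Toy saver that removes the arrays it writes (hypothetical, not photoZtoHDF5 itself)."""
    written = {}
    written["PDZ"] = results.pop("PDZ")  # pop() deletes the key from the caller's dict
    return written


res_a = {"PDZ": [[0.1, 0.9]], "z_mean": [0.5]}
save_and_pop(res_a)
print("PDZ" in res_a)  # prints False: the PDFs are gone from the original object

res_b = {"PDZ": [[0.1, 0.9]], "z_mean": [0.5]}
save_and_pop(dict(res_b))  # a shallow copy protects the caller's top-level keys
print("PDZ" in res_b)  # prints True
```

A shallow copy only protects top-level keys; if the pops happen inside nested dictionaries, `copy.deepcopy` would be needed instead. Reloading the file, as recommended above, avoids the question altogether.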


In [None]:
load_from_file = input_settings["photoZ"]["save results"]
if load_from_file:
    pz_res_tree = readPhotoZHDF5(resfile)

In [None]:
zg = inputs_pz["Z_GRID"]
z_grid = jnp.arange(zg["z_min"], zg["z_max"] + zg["z_step"], zg["z_step"])  # + z_step so that z_max is included
# or equivalently
z_grid = pz_res_tree["z_grid"]
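
Adding `z_step` to `z_max` is what makes the grid include `z_max`, since `arange` excludes its stop value. With a non-integer step, however, `arange` can occasionally return one point more or fewer than expected because of floating-point rounding; building the grid from an explicit number of points with `linspace` avoids this. A sketch with NumPy (the same functions exist in `jax.numpy`); the bounds below are placeholder values, not the package defaults.

```python
import numpy as np

z_min, z_max, z_step = 0.01, 3.0, 0.01  # placeholder values

grid_arange = np.arange(z_min, z_max + z_step, z_step)

# Number of intervals between z_min and z_max, rounded to absorb floating-point error
n_points = int(round((z_max - z_min) / z_step)) + 1
grid_linspace = np.linspace(z_min, z_max, n_points)

print(len(grid_arange), len(grid_linspace))
print(grid_linspace[0], grid_linspace[-1])
```

Reusing the stored `pz_res_tree["z_grid"]`, as in the cell above, sidesteps the issue entirely.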

In [None]:
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline


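def _plot_pdz(pz_res_dict, obsid):
    """Plot the redshift PDF of galaxy `obsid` with its spectroscopic, mean and median redshifts."""
    z = pz_res_dict["z_grid"]
    pdz = pz_res_dict["PDZ"][:, obsid]  # PDZ has one column per galaxy
    mean = pz_res_dict["z_mean"][obsid]
    zs = pz_res_dict["z_spec"][obsid]
    medz = pz_res_dict["z_med"][obsid]

    plt.semilogy(z, pdz)
    if jnp.isfinite(zs):  # z_spec may be missing (non-finite) for some galaxies
        plt.axvline(zs, c="k", label="z_spec")
    plt.axvline(mean, c="r", label="Mean")
    plt.axvline(medz, c="g", label="Median")
    plt.xlabel("z")
    plt.ylabel("PDF")
    plt.legend()
    plt.show()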
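
The mean and median drawn above are summary statistics of the PDF sampled on the redshift grid. The sketch below shows one common way to compute them (trapezoidal integration, then interpolation of the cumulative distribution at 0.5); it is an illustration, not necessarily how `process_fors2` computes `z_mean` and `z_med`. The function name `pdf_summary` is hypothetical.

```python
import numpy as np


def pdf_summary(z, pdz):
    """Mean and median of a PDF sampled on the increasing grid z (hypothetical helper)."""
    z = np.asarray(z, dtype=float)
    pdz = np.asarray(pdz, dtype=float)
    dz = np.diff(z)
    # Trapezoidal integral of the PDF over each grid interval
    seg = 0.5 * (pdz[1:] + pdz[:-1]) * dz
    norm = seg.sum()
    # Mean: integral of z * p(z) dz, divided by the normalisation
    zp = z * pdz
    mean = (0.5 * (zp[1:] + zp[:-1]) * dz).sum() / norm
    # Median: redshift where the normalised cumulative distribution reaches 0.5
    cdf = np.concatenate(([0.0], np.cumsum(seg))) / norm
    median = np.interp(0.5, cdf, z)
    return mean, median
```

Applied to `pz_res_tree["z_grid"]` and `pz_res_tree["PDZ"][:, randomid]`, the results could be compared with `pz_res_tree["z_mean"][randomid]` and `pz_res_tree["z_med"][randomid]`.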

In [None]:
randomid = np.random.choice(pz_res_tree["PDZ"].shape[1])
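
The draw above changes at every execution, and it may land on a galaxy without a spectroscopic redshift. A sketch of a reproducible alternative with NumPy's `Generator` API, restricted to galaxies with a finite `z_spec`; the seed and the placeholder `z_spec` array are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(seed=42)  # fixed seed: the same galaxy is drawn on every rerun

z_spec = np.array([0.3, np.nan, 1.2, 0.8])  # placeholder; in the notebook use np.array(pz_res_tree["z_spec"])
with_spec = np.flatnonzero(np.isfinite(z_spec))  # indices of galaxies with a spectroscopic redshift
randomid = int(rng.choice(with_spec))
print(randomid)
```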

In [None]:
_plot_pdz(pz_res_tree, randomid)

In [None]:
f, a = plt.subplots(1, 1, figsize=(6, 6))
zp = np.array(pz_res_tree["z_ML"])
zs = np.array(pz_res_tree["z_spec"])

# Keep only galaxies for which both redshifts are finite (z_spec may be missing)
valid = np.isfinite(zp) & np.isfinite(zs)
zp, zs = zp[valid], zs[valid]

# Outliers: |z_phot - z_spec| > 0.15 * (1 + z_spec)
abs_err = np.abs(zp - zs)
outl_rate = 100.0 * np.count_nonzero(abs_err > 0.15 * (1 + zs)) / len(zs)

a.scatter(zs, zp, s=4, alpha=0.2, label=f"SPS: {outl_rate:.3f}% outliers", color="green")
a.plot(z_grid, z_grid, c="k", ls=":", lw=1)
a.plot(z_grid, z_grid + 0.15 * (1 + z_grid), c="k", lw=2)
a.plot(z_grid, z_grid - 0.15 * (1 + z_grid), c="k", lw=2)
a.set_xlabel("z_spec")
a.set_ylabel("z_phot")
a.legend()
a.grid()
# a.set_xlim(0., 3.1)
# a.set_ylim(0., 3.1)

In [None]:
print(f"SPS templates : {outl_rate:.3f}% outliers out of {len(zp)} successful fits.")
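
The outlier criterion used above, |z_phot - z_spec| > 0.15 (1 + z_spec), can be wrapped in a small reusable function. In this sketch (the name `outlier_rate` and its `threshold` parameter are hypothetical additions for convenience), pairs where either redshift is not finite are skipped, so galaxies without a spectroscopic redshift do not enter the denominator.

```python
import numpy as np


def outlier_rate(z_phot, z_spec, threshold=0.15):
    """Percentage of galaxies with |z_phot - z_spec| > threshold * (1 + z_spec).

    Only pairs where both redshifts are finite are counted; returns (rate, n_valid).
    """
    z_phot = np.asarray(z_phot, dtype=float)
    z_spec = np.asarray(z_spec, dtype=float)
    valid = np.isfinite(z_phot) & np.isfinite(z_spec)
    dz = np.abs(z_phot[valid] - z_spec[valid]) / (1.0 + z_spec[valid])
    n_valid = int(valid.sum())
    rate = 100.0 * np.count_nonzero(dz > threshold) / n_valid if n_valid else float("nan")
    return rate, n_valid


rate, n = outlier_rate([0.5, 1.0, 2.0, 0.3], [0.5, 0.5, 2.1, np.nan])
print(f"{rate:.3f}% outliers out of {n} galaxies")  # prints "33.333% outliers out of 3 galaxies"
```

In the notebook, `outlier_rate(pz_res_tree["z_ML"], pz_res_tree["z_spec"])` would give the same kind of figure as the cell above.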

## Some checks
Let's have a look at the files that were created along the way.

In [None]:
import pandas as pd

from process_fors2.fetchData import readTemplatesHDF5

In [None]:
!ls

In [None]:
h5cat = "COSMOS2020_emu_hscOnly_CC_allzinf3.h5"
h5inp = "pz_inputs_COSMOS2020_emu_hscOnly_CC_allzinf3.h5"
h5templ = "SEDtempl_SPS_mags+rews_1_to_10.h5"
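
These file names correspond to the configuration used when the notebook was written; with other settings, the files listed by `!ls` may be named differently. A small check before reading them makes a mismatch explicit. The helper `missing_files` is a hypothetical illustration using only `pathlib`.

```python
from pathlib import Path


def missing_files(*paths):
    """Return the paths that do not exist on disk (hypothetical helper)."""
    return [p for p in paths if not Path(p).exists()]


# In the notebook: missing = missing_files(h5cat, h5inp, h5templ)
missing = missing_files("definitely_not_here.h5")
if missing:
    print("Not found:", ", ".join(missing))
```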

In [None]:
df_cat = pd.read_hdf(h5cat, key="catalog")
df_cat

In [None]:
df_inp = pd.read_hdf(h5inp, key="pz_inputs")
df_inp

In [None]:
dict_templ = readTemplatesHDF5(h5templ)
dict_templ