# Notebook to run photo-z estimation with `rail.dsps_fors2_pz`
Created by J. Chevalier on October 25, 2024

## Requirements
This notebook requires the `rail` and `rail_dsps` packages. Installation instructions can be found following [this link](https://rail-hub.readthedocs.io/en/latest/source/installation.html).
It also uses the [rail_dspsXfors2_pz](https://github.com/JospehCeh/rail_dspsXfors2_pz) package and its dependencies. As of now, this package should be installed as follows, in the same `conda` environment `[rail_env_name]` as the one used for the `RAIL` installation:

```bash
conda activate [rail_env_name]
git clone https://github.com/JospehCeh/rail_dspsXfors2_pz.git
cd rail_dspsXfors2_pz
pip install --no-cache-dir .
```

Then, a `jupyter` kernel must be created, associated with your environment `[rail_env_name]`, and used to run this notebook.
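Once the kernel is selected, a quick way to check that the required modules can be found from it is to query the import system. This is a minimal sketch using only the standard library; `is_installed` is a hypothetical helper, and the module names checked are the ones imported in this notebook.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` (possibly dotted) can be found by the current interpreter."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ModuleNotFoundError, ValueError):
        # find_spec imports the parents of a dotted name, so a missing parent raises here
        return False


for name in ("rail", "rail.dsps_fors2_pz", "jax"):
    print(f"{name}: {'found' if is_installed(name) else 'MISSING'}")
```

If `rail.dsps_fors2_pz` is reported missing, the kernel is most likely not attached to `[rail_env_name]`.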

## Imports

In [None]:
import numpy as np # simply more convenient for a couple of things
from jax import numpy as jnp
from rail.dsps_fors2_pz import readPhotoZHDF5, json_to_inputs, run_from_inputs

## Default run

In [None]:
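# Default settings shipped with the package
# (this path points to the author's clone of the repository: adapt it to your own)
conf_file = '/home/chevalier/rail_dspsXfors2_pz/src/rail/dsps_fors2_pz/data/defaults.json'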

input_settings = json_to_inputs(conf_file)
print(input_settings)
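
The configuration path above is specific to one machine. If `defaults.json` is installed together with the package (we have not checked that the package build includes its `data/` folder), a path-independent alternative is to locate it with the standard library's `importlib.resources`. A minimal sketch, where `packaged_file` is a hypothetical helper:

```python
from importlib.resources import files


def packaged_file(package: str, *parts: str):
    """Return a Traversable pointing to a data file shipped inside an installed package."""
    resource = files(package)
    for part in parts:
        resource = resource / part  # descend one path component at a time
    return resource


# Hypothetical usage, valid only if data/defaults.json is installed with the package:
# conf_file = str(packaged_file("rail.dsps_fors2_pz", "data", "defaults.json"))
```

If the file is not found this way, keep using the explicit path to your clone of the repository.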

In [None]:
inputs_pz = input_settings['photoZ']
inputs_pz

Here we have loaded the default settings that ship with the package, for a minimal (and not at all optimized) run.
They are stored as a dictionary, so it is easy to change any setting before the run. For example:

In [None]:
input_settings['photoZ']['Estimator']='delight'
inputs_pz

Note that as of now, the setting we have modified is actually not used in the code, so it won't affect our run and only serves as an example.
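The displayed `inputs_pz` reflects the change because `inputs_pz = input_settings['photoZ']` stores a reference to the same sub-dictionary, not a copy. To keep an untouched version of the defaults while experimenting, a deep copy can be taken first. A self-contained illustration on a toy dictionary (the keys and values below are placeholders, not the actual content of `defaults.json`):

```python
import copy

# Toy stand-in for the dictionary returned by json_to_inputs (placeholder values)
settings = {"photoZ": {"Estimator": "placeholder", "run name": "demo"}}

sub = settings["photoZ"]                 # a reference to the same dictionary
settings["photoZ"]["Estimator"] = "delight"
print(sub["Estimator"])                  # delight: the change is visible through the alias

pristine = copy.deepcopy(settings)       # fully independent copy
pristine["photoZ"]["Estimator"] = "other"
print(settings["photoZ"]["Estimator"])   # delight: the original is unaffected
```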

## Photometric redshifts
Now we will run the code from the notebook. It may take some time, up to a couple of hours on larger datasets, and the JAX implementation does not make it easy to incorporate progress bars, so please be patient.
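Since no progress bar is available, it can still be useful to record how long a run takes, for instance to extrapolate to a larger dataset. A minimal sketch of a wall-clock timer based on `time.perf_counter`; `timer` is a hypothetical helper, not part of `rail.dsps_fors2_pz`:

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(label: str = "run"):
    """Measure and print the wall-clock time spent inside the `with` block."""
    result = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        result["elapsed"] = time.perf_counter() - start
        print(f"{label}: {result['elapsed']:.1f} s")


# Usage in this notebook (commented out here):
# with timer("photo-z run"):
#     pz_res_tree = run_from_inputs(input_settings)
```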

In [None]:
pz_res_tree = run_from_inputs(input_settings) # It is necessary here to use the overall dictionary and not the 'photoZ' subset

Now, we may want to save the results to disk, so let us write them into an `HDF5` file with the included utility `photoZtoHDF5`:

In [None]:
if input_settings["photoZ"]["save results"]:
    from rail.dsps_fors2_pz import photoZtoHDF5
    # Write the posteriors to <run name>_posteriors_dict.h5; the returned value is reused below to reload them
    resfile = photoZtoHDF5(f"{input_settings['photoZ']['run name']}_posteriors_dict.h5", pz_res_tree)
else:
    resfile = "Run terminated correctly but results were not saved, please check your input configuration."
print(resfile)

Alternatively, the steps above can be performed from a terminal by running the following command (adapt the path to your own clone of the repository, or use any other appropriate `JSON` configuration file):
`python -m rail.dsps_fors2_pz $HOME/rail_dspsXfors2_pz/src/rail/dsps_fors2_pz/data/defaults.json`

## Let's look at the results
Here we can either read the file we have created using the provided function `readPhotoZHDF5`, or directly use our `pz_res_tree` object.

**Note:**
_If the results were saved in the step above, it is highly recommended to reload them: otherwise the PDFs might be missing from the `pz_res_tree` object, because `photoZtoHDF5` uses `dict.pop()`._
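The mechanism behind this note is that `dict.pop()` removes the key from the very dictionary passed in, so the caller's object is modified. The toy writer below is hypothetical and only illustrates that side effect; we have not inspected the internals of `photoZtoHDF5`. Handing such a function a deep copy should, in principle, keep the in-memory PDFs, at the cost of temporarily doubling memory usage.

```python
import copy


def write_and_strip(record: dict) -> dict:
    """Hypothetical writer: pops the 'PDZ' entry out of the record it is given."""
    pdz = record.pop("PDZ")  # dict.pop removes the key from the caller's dictionary
    return {"n_points": len(pdz)}


# Toy record using keys seen in this notebook ('z_spec', 'z_ML', 'PDZ'); values are made up
results = [{"z_spec": 0.5, "z_ML": 0.52, "PDZ": [0.1, 0.8, 0.1]}]

write_and_strip(copy.deepcopy(results[0]))  # only the copy is modified
print("PDZ" in results[0])  # True

write_and_strip(results[0])  # the original record loses its PDF
print("PDZ" in results[0])  # False
```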


In [None]:
load_from_file = input_settings["photoZ"]["save results"]
if load_from_file:
    pz_res_tree = readPhotoZHDF5(resfile)

In [None]:
z_grid = jnp.arange(
    inputs_pz["Z_GRID"]["z_min"],
    inputs_pz["Z_GRID"]["z_max"] + inputs_pz["Z_GRID"]["z_step"],
    inputs_pz["Z_GRID"]["z_step"]
)
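
`arange` with a non-integer step can return one point more or fewer than expected because of floating-point rounding (the NumPy documentation recommends `linspace` in that case). A sketch of an alternative that fixes the number of points explicitly; `make_z_grid` is a hypothetical helper, and we have not checked how the package builds its own internal grid (here `z_grid` is only used to draw reference lines):

```python
import numpy as np


def make_z_grid(z_min: float, z_max: float, z_step: float) -> np.ndarray:
    """Evenly spaced grid from z_min to z_max inclusive, with a rounding-safe number of points."""
    n_points = int(round((z_max - z_min) / z_step)) + 1
    return np.linspace(z_min, z_max, n_points)


grid = make_z_grid(0.0, 3.0, 0.01)
print(len(grid), grid[0], grid[-1])
```

In this notebook it would be called as `make_z_grid(inputs_pz["Z_GRID"]["z_min"], inputs_pz["Z_GRID"]["z_max"], inputs_pz["Z_GRID"]["z_step"])`.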

In [None]:
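import matplotlib.pyplot as plt

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever is available
trapezoid = getattr(np, "trapezoid", None) or np.trapz

def stats(z, pdz, zs=None, plot=False):
    """Return the mean and median redshift of a PDZ sampled on grid `z`, optionally plotting it."""
    # cdf[i] is the integral of pdz from z[0] up to and including z[i]
    cdf = np.array(
        [
            trapezoid(pdz[:i + 1], x=z[:i + 1]) for i in range(len(z))
        ]
    )
    medz = z[np.where(cdf >= 0.5)][0]  # first grid point where the CDF reaches 0.5
    mean = trapezoid(z * pdz, x=z)     # assumes the PDZ is normalized to unit integral
    if plot:
        plt.semilogy(z, pdz, label='PDZ')
        plt.semilogy(z, cdf, label='CDF')
        if zs is not None:
            plt.axvline(zs, c='k', label='z_spec')
        plt.axvline(mean, c='r', label='Mean')
        plt.axvline(medz, c='g', label='Median')
        plt.axhline(0.5, c='grey', ls=':')
        plt.legend()
    return mean, medz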

In [None]:
randomid = np.random.choice(len(pz_res_tree))

In [None]:
obs0 = pz_res_tree[randomid]
mean, med = stats(obs0['PDZ'][:, 0], obs0['PDZ'][:, 1], obs0['z_spec'], plot=True)

In [None]:
from tqdm import tqdm

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
zp = []    # maximum-likelihood photo-z of successful fits
zs = []    # matching spectroscopic redshifts
znan = []  # spectroscopic redshifts of objects whose fit failed
for obs in tqdm(pz_res_tree):
    try:
        z_ml = float(obs['z_ML'])
    except (KeyError, TypeError, ValueError):
        z_ml = np.nan  # missing or malformed estimate: treat as a failed fit
    # Append to both lists together so that zp and zs stay aligned
    if np.isfinite(z_ml):
        zp.append(z_ml)
        zs.append(obs['z_spec'])
    else:
        znan.append(obs['z_spec'])

zp = np.array(zp)
zs = np.array(zs)

# Catastrophic outliers: |z_phot - z_spec| > 0.15 * (1 + z_spec)
abs_err = np.abs(zp - zs)
is_outlier = abs_err > 0.15 * (1 + zs)
outl_rate = 100. * np.count_nonzero(is_outlier) / len(zs)

ax.scatter(zs, zp, s=4, alpha=0.2, label=f"SPS: {outl_rate:.3f}% outliers", color='green')
ax.plot(z_grid, z_grid, c='k', ls=':', lw=1)                 # z_phot = z_spec
ax.plot(z_grid, z_grid + 0.15 * (1 + z_grid), c='k', lw=2)  # upper outlier bound
ax.plot(z_grid, z_grid - 0.15 * (1 + z_grid), c='k', lw=2)  # lower outlier bound
ax.set_xlabel('z_spec')
ax.set_ylabel('z_phot')
ax.legend()
ax.grid()

In [None]:
print(f"SPS templates : {outl_rate:.3f}% outliers out of {len(zp)} successful fits.")
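
The outlier criterion used above can be wrapped in a small reusable function, for instance to compare runs or try other thresholds. A sketch (`outlier_rate` is a hypothetical name) with a worked toy example: the third object misses by 0.6, which exceeds 0.15 × (1 + 1.0) = 0.3, while the other three stay within their bounds, so 1 object out of 4 is an outlier.

```python
import numpy as np


def outlier_rate(z_phot, z_spec, threshold: float = 0.15) -> float:
    """Percentage of objects with |z_phot - z_spec| > threshold * (1 + z_spec)."""
    z_phot = np.asarray(z_phot, dtype=float)
    z_spec = np.asarray(z_spec, dtype=float)
    is_outlier = np.abs(z_phot - z_spec) > threshold * (1.0 + z_spec)
    return float(100.0 * np.mean(is_outlier))


print(outlier_rate([0.50, 0.81, 1.60, 2.05], [0.52, 0.80, 1.00, 2.00]))  # 25.0
```

Applied to this notebook's arrays, `outlier_rate(zp, zs)` reproduces `outl_rate`.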