In [None]:
%matplotlib inline

import os
import pandas as pd
import numpy as np
from jax import numpy as jnp
from jax import vmap, jit
from jax.tree_util import tree_map
from jax import random as jrn
import seaborn as sns
import matplotlib.pyplot as plt
from process_fors2.photoZ import read_h5_table, read_params, DATALOC, sedpyFilter, load_filt, get_2lists
from process_fors2.fetchData import json_to_inputs
from process_fors2.stellarPopSynthesis import SSPParametersFit, load_ssp
from interpax import interp1d
from tqdm import tqdm

In [None]:
# Reference table of the SPS fit parameters (initial value and allowed range), indexed by name.
_DUMMY_PARS = SSPParametersFit()
_param_table = jnp.column_stack(
    (_DUMMY_PARS.INIT_PARAMS, _DUMMY_PARS.PARAMS_MIN, _DUMMY_PARS.PARAMS_MAX)
)
dumpars_df = pd.DataFrame(
    data=_param_table,
    index=_DUMMY_PARS.PARAM_NAMES_FLAT,
    columns=["INIT", "MIN", "MAX"],
)
# Every parameter name except the one at flat index 13 (presumably 'AV', the entry that the
# color helpers below override with `.at[13]` — TODO confirm).
_all_par_names = _DUMMY_PARS.PARAM_NAMES_FLAT
fixed_pars_names = _all_par_names[:13] + _all_par_names[14:]

In [None]:
# Input files for this run; the commented-out names are alternatives used in earlier runs.
# DSPS template fits (read with read_h5_table and with key 'fit_dsps' below).
dsps_out_h5 = 'dsps_valid_fits_F2_GG_DESI_SM3.h5' #'dsps_40best_fits_F2_GG_DESI_SM3.h5'
# Observation catalogue; its basename is reused to name the photo-z input table (clrh5file).
obs_inp_h5 = 'COSMOS2020_emu_hscOnly_CC_zinf3_noNaN.h5' # 'COSMOS2020_emu_CC.h5'
# JSON run configuration parsed by json_to_inputs.
inputs_json = 'conf_IDRIS_PZ_TemplSel.json' # 'conf_FORS2_SM3.json' #'conf_IDRIS_cosmos2020_allFilts_noPrior.json' #

In [None]:
# Parse the run configuration; its 'photoZ' section holds the settings used throughout.
inputs_glob = json_to_inputs(inputs_json)
inputs_pz = inputs_glob['photoZ']

In [None]:
# Name of the photo-z input table; the "iclrs" variant is used when i_colors is enabled.
_obs_base = os.path.basename(obs_inp_h5)
if inputs_pz["i_colors"]:
    clrh5file = f"pz_inputs_iclrs_{_obs_base}"
else:
    clrh5file = f"pz_inputs_{_obs_base}"

In [None]:
# Display the dataset section of the photo-z configuration.
inputs_pz['Dataset']

In [None]:
# Display the filter section of the photo-z configuration.
inputs_pz['Filters']

In [None]:
filters_dict = inputs_pz["Filters"]
# Filter names in configuration order; consecutive pairs define the colors used below.
filters_names = [filt["name"] for filt in filters_dict.values()]

In [None]:
# List the filter curves sedpy can load (useful to check the configured filter names).
from sedpy import observate
observate.list_available_filters()

In [None]:
from process_fors2.photoZ import load_data_for_run

# Load everything the photo-z run needs for this configuration.
# NOTE: z_grid, pars_arr, zref_arr and templ_classif are rebound by later cells.
(
    z_grid, wl_grid, transm_arr,
    pars_arr, zref_arr, templ_classif,
    i_mag_ab, ab_colors, ab_cols_errs,
    z_specs, ssp_data,
) = load_data_for_run(inputs_glob)

In [None]:
# Sanity check: number of colors versus number of filters.
print(f"{ab_colors.shape} {ab_cols_errs.shape} {len(filters_names)}")

In [None]:
# Colors are differences of consecutive filters, named "<f_k>-<f_k+1>".
_filter_pairs = list(zip(filters_names[:-1], filters_names[1:]))
color_names = [f"{n1}-{n2}" for n1, n2 in _filter_pairs]
color_err_names = [f"{name}_err" for name in color_names]
obs_df = pd.read_hdf(clrh5file, key='pz_inputs')

## Select a training set and a test set

In [None]:
# Draw the training subset with a fixed seed (reproducible): sample row positions without
# replacement, 20% of the catalogue but never more than 20000 rows.
key = jrn.key(717)
key, subkey = jrn.split(key)
train_sel = jrn.choice(subkey, obs_df.shape[0], shape=(min(20*obs_df.shape[0]//100, 20000),), replace=False) # 20% of data is selected, capped at 20000 rows
del subkey

In [None]:
# Sort the sampled positions so the training subset keeps the catalogue row order.
train_sel = jnp.sort(train_sel)

In [None]:
# Training subset of the catalogue. The positions are converted to a NumPy array explicitly
# (they come from JAX), and `.copy()` detaches the result from `obs_df`: bin columns are
# added to `train_df` later, which on a plain slice triggers pandas' SettingWithCopyWarning.
train_df = obs_df.iloc[np.asarray(train_sel)].copy()
train_df.shape

In [None]:
len(obs_df)  # total number of observations in the catalogue

In [None]:
# Test set: every row position that was not drawn for training (boolean mask).
ind_array = np.arange(obs_df.shape[0])
test_sel = ~np.isin(ind_array, train_sel)
test_df = obs_df.iloc[test_sel]
test_df.shape

In [None]:
# Sorted row positions of the training subset.
train_sel

In [None]:
# Boolean mask of the test rows.
test_sel

In [None]:
# The training and test sets must partition the catalogue: assert it instead of relying on
# a visual check of the displayed number.
n_unassigned = obs_df.shape[0] - (test_df.shape[0] + train_df.shape[0])
assert n_unassigned == 0, f"{n_unassigned} rows are in neither the training nor the test set"
n_unassigned

In [None]:
# Compare the redshift distributions of the full sample and of the training subset.
f, a = plt.subplots(1, 1)
for _data, _label in ((obs_df, 'Full sample'), (train_df, 'Training sample')):
    sns.histplot(data=_data, x='redshift', label=_label, stat='density', ax=a)
a.legend(loc='best')

## Binning observations in redshift

In [None]:
# Redshift grid for the templates: bin edges of the training redshifts ('auto' rule).
# NOTE: this replaces the z_grid returned by load_data_for_run.
z_grid = np.histogram_bin_edges(train_df['redshift'], bins='auto')
z_mids = 0.5 * jnp.array(z_grid[1:] + z_grid[:-1])  # bin centres

## Loading templates

In [None]:
from process_fors2.photoZ import make_sps_templates, make_legacy_templates, read_h5_table
from process_fors2.stellarPopSynthesis import istuple, ssp_spectrum_fromparam, vmap_calc_obs_mag

# Four AV values spanning the allowed range of the AV parameter.
_av_min, _av_max = dumpars_df.loc['AV', 'MIN'], dumpars_df.loc['AV', 'MAX']
av_arr = jnp.linspace(_av_min, _av_max, num=4, endpoint=True)
# Template parameter arrays read from the DSPS fits file (replacing those from load_data_for_run).
pars_arr, zref_arr, templ_classif = read_h5_table(dsps_out_h5)

In [None]:
# Table of the DSPS template fits, one row per template (index = template name).
templs_ref_df = pd.read_hdf(dsps_out_h5, key='fit_dsps')
templs_ref_df

In [None]:
def get_colors_templates_av(params, av, wls, z_obs, transm_arr, ssp_data):
    """Model colors of one template at redshift ``z_obs`` with its attenuation overridden.

    Entry 13 of the flat parameter vector ``params`` (presumably AV — TODO confirm) is
    replaced by ``av``. The attenuated SED is then passed through ``transm_arr`` with
    ``vmap_calc_obs_mag``. The differences of consecutive magnitudes, mag_k - mag_{k+1},
    are returned.
    """
    av_params = params.at[13].set(av)
    wave, _unused, sed_att = ssp_spectrum_fromparam(av_params, z_obs, ssp_data)
    mags = vmap_calc_obs_mag(wave, sed_att, wls, transm_arr, z_obs)
    return mags[:-1] - mags[1:]


# Vectorize over AV values, then over observed redshifts, then over templates (rows of params).
vmap_cols_av = vmap(get_colors_templates_av, in_axes=(None, 0, None, None, None, None))
vmap_cols_av_zo = vmap(vmap_cols_av, in_axes=(None, None, None, 0, None, None))
vmap_cols_av_templ = vmap(vmap_cols_av_zo, in_axes=(0, None, None, None, None, None))

In [None]:
def get_colors_templates(params, wls, z_obs, transm_arr, ssp_data):
    """Model colors of one template at redshift ``z_obs``, using its parameters as stored.

    The attenuated SED built from ``params`` is passed through ``transm_arr`` with
    ``vmap_calc_obs_mag``. The differences of consecutive magnitudes, mag_k - mag_{k+1},
    are returned.
    """
    wave, _unused, sed_att = ssp_spectrum_fromparam(params, z_obs, ssp_data)
    mags = vmap_calc_obs_mag(wave, sed_att, wls, transm_arr, z_obs)
    return mags[:-1] - mags[1:]


# Vectorize over observed redshifts, then over templates (rows of params).
vmap_cols_zo = vmap(get_colors_templates, in_axes=(None, None, 0, None, None))
vmap_cols_templ = vmap(vmap_cols_zo, in_axes=(0, None, None, None, None))

In [None]:
%%time
templ_tupl = [tuple(_pars) for _pars in pars_arr]
templ_tupl_sps = tree_map(lambda partup: vmap_cols_zo(jnp.array(partup), wl_grid, z_grid, transm_arr[:-2], ssp_data), templ_tupl, is_leaf=istuple)

In [None]:
# Shape of the color array of the first template (redshifts x colors, presumably).
templ_tupl_sps[0].shape

In [None]:
# One table per template: its colors at each z_grid value, tagged with dataset and name.
# NOTE: assumes templ_tupl_sps follows the row order of templs_ref_df (both come from dsps_out_h5).
templs_as_dict = {}
for _i, (tname, row) in enumerate(templs_ref_df.iterrows()):
    _n_z = z_grid.shape
    _df = pd.DataFrame(columns=color_names, data=templ_tupl_sps[_i])
    _df['z_p'] = z_grid
    _df['Dataset'] = np.full(_n_z, row['Dataset'])
    _df['name'] = np.full(_n_z, tname)
    templs_as_dict[f"{tname}"] = _df

In [None]:
# Long table of all templates: one row per (template, redshift).
all_templs_df = pd.concat(list(templs_as_dict.values()), ignore_index=True)
all_templs_df.shape

In [None]:
import plotly.express as px

# Interactive color-color diagrams of the training set, one per consecutive color pair.
for c1, c2 in zip(color_names[:-1], color_names[1:]):
    fig = px.scatter(train_df, x=c1, y=c2, color='redshift')
    fig.show()

In [None]:
# Redshift versus i-band magnitude of the training set.
fig = px.scatter(train_df, x='i_mag', y='redshift')
fig.show()

In [None]:
# Color-color diagrams: training observations (open circles sized by redshift) overlaid with
# the model colors of the templates, one panel per template dataset.
_datasets = np.unique(all_templs_df['Dataset'])
_markers = ['+', 'x', '*']
for c1, c2 in zip(color_names[:-1], color_names[1:]):
    # squeeze=False keeps `a` 2-D even when there is a single dataset, so a[0, i] is valid.
    f, a = plt.subplots(1, _datasets.shape[0], sharey=True, constrained_layout=True, squeeze=False)
    for _iax, _ds in enumerate(_datasets):
        _ax = a[0, _iax]
        sns.scatterplot(
            data=train_df, x=c1, y=c2,
            edgecolor='k', facecolor='none',
            size='redshift', marker='.', sizes=(10, 100),
            ax=_ax, legend=False, alpha=0.2,
        )
        sns.scatterplot(
            data=all_templs_df[all_templs_df['Dataset'] == _ds], x=c1, y=c2,
            ax=_ax, size='z_p', sizes=(10, 100),
            legend='brief' if _iax == 0 else False,
            # cycle the markers instead of silently dropping datasets beyond the third
            marker=_markers[_iax % len(_markers)],
        )
        if _iax == 0:
            # only the first panel has labelled artists; legend() on the others would just warn
            _ax.legend(loc='best')
        _ax.set_title(_ds)
    plt.show()

## Bin training data and template data in the same color bins

In [None]:
import matplotlib.patches as mpatches

# Legend proxy for the training sample histogram.
train_patch = mpatches.Patch(edgecolor='k', facecolor='grey', label='COSMOS2020', alpha=0.7)

# For each color: 60 equal-width bins over the finite training values (edges kept in
# list_edges for the binning below), and stacked template histograms on the same bins.
list_edges = []
for c in color_names:
    _arr = np.asarray(train_df[c])
    # Only the edges are needed; the histogram counts were never used.
    _edges1d = np.histogram_bin_edges(_arr[np.isfinite(_arr)], bins=60)
    list_edges.append(_edges1d)

    f, a = plt.subplots(1, 1)
    sns.histplot(
        data=train_df, x=c, bins=_edges1d, stat='density',
        label='COSMOS2020', color='grey', ax=a, legend=False,
    )
    sns.histplot(
        data=all_templs_df, x=c, bins=_edges1d, stat='density',
        multiple='stack', hue='Dataset', alpha=0.7, ax=a, legend=True,
    )

    # Rebuild seaborn's hue legend so that it also lists the training sample.
    old_legend = a.get_legend()
    handles = old_legend.legend_handles
    labels = [t.get_text() for t in old_legend.get_texts()]
    title = old_legend.get_title().get_text()
    a.legend(handles=[train_patch] + handles, labels=['COSMOS2020 (Training)'] + labels, title=title, loc='best')

    plt.show()

In [None]:
# Bin index of each training observation along every color (np.digitize: 0 below the first
# edge, len(edges) at or above the last one).
coords = np.column_stack(
    [np.digitize(train_df[c], edges) for c, edges in zip(color_names, list_edges)]
)
coords.shape

In [None]:
# Bin index of each template row along every color, using the training-data edges.
templ_coords = np.column_stack(
    [np.digitize(all_templs_df[c], edges) for c, edges in zip(color_names, list_edges)]
)
templ_coords.shape

In [None]:
# Store the per-color bin indices as '<color>_bin' columns of the training table.
_bin_cols = [f'{c}_bin' for c in color_names]
train_df[_bin_cols] = coords

In [None]:
# Same '<color>_bin' columns for the template table.
_templ_bin_cols = [f'{c}_bin' for c in color_names]
all_templs_df[_templ_bin_cols] = templ_coords

In [None]:
# Inspect the template bin indices.
all_templs_df[[f'{c}_bin' for c in color_names]]

### Compute a score for each template in each bin
$\displaystyle \frac{1}{N_{obs}} \sum_\text{obs. i} \left( \frac{z_p-z_s^i}{1+z_s^i} \right)^2$

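Then, for each template, sum these per-bin scores over all colors and average the result over the template's redshift grid; the template with the smallest score wins.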

In [None]:
# Disabled alternative scoring, kept for reference: squared relative deviation, summed over colors.
if False:
    for itempl, row in tqdm(all_templs_df.iterrows(), total=all_templs_df.shape[0]):
        scores = []
        for c in color_names:
            _sel_df = train_df[train_df[f'{c}_bin'] == row[f'{c}_bin']]
            zs = jnp.array(_sel_df['redshift'].values)
            scores.append(
                jnp.sum(jnp.power((zs - row['z_p']) / (1 + zs), 2)) / zs.shape[0] if zs.shape[0] > 0 else jnp.nan
            )
        scores = jnp.array(scores)
        # jnp.nansum of an all-NaN array is 0.0: a template falling in no populated bin would
        # otherwise get the best possible score. Mark it NaN instead (stored as a plain float).
        score = np.nan if bool(jnp.all(jnp.isnan(scores))) else float(jnp.nansum(scores))
        all_templs_df.loc[itempl, 'score'] = score

    # Average each template's score over its redshift grid, ignoring non-finite scores.
    # Filtering rows avoids overwriting the 'name' column, which later cells rely on.
    score_df = all_templs_df.loc[np.isfinite(all_templs_df['score']), ['name', 'score']]
    grp_mean_score = score_df.groupby(by='name', dropna=True).mean()  # `axis` is deprecated in pandas>=2.1

    templs_score_df = templs_ref_df.join(grp_mean_score, how='inner').sort_values('score', ascending=True)
    display(templs_score_df)

In each color bin, score every template by $\displaystyle \frac{1}{N_{obs}} \sum_\text{obs. i} \frac{\left|z_p-z_s^i \right|}{1+z_s^i}$, select the template with the smallest score provided it is below $0.15$, and keep the unique list of selected templates at the end.

In [None]:
%%time

best_templs_names = []
allbestscores = []
for c in color_names:
    for cbin in tqdm(jnp.unique(train_df[f'{c}_bin'].values)):
    #cbin = row[f'{c}_bin']
        sel = train_df[f'{c}_bin']==cbin
        _sel_df = train_df[sel]
        zs = jnp.array(_sel_df['redshift'].values)
        sel_templ = all_templs_df[f'{c}_bin']==cbin
        _templ_df = all_templs_df[sel_templ]
        scores = jnp.array(
            [
                jnp.sum(jnp.abs(zs-zp)/(1+zs)) / zs.shape[0] if zs.shape[0]>0 else jnp.nan for zp in _templ_df['z_p']
            ]
        )
        if scores.shape[0]>0 and not jnp.all(jnp.isnan(scores)):
            ix_best = int(jnp.nanargmin(scores))
            bestscore = scores[ix_best]
            if bestscore < 0.15:
                best_templs_names.append(_templ_df['name'].iloc[ix_best])
                allbestscores.append(scores[ix_best])
            
best_templ_sels = np.unique(best_templs_names)
allbestscores = jnp.array(allbestscores)

print(len(best_templ_sels), allbestscores.shape)

In [None]:
# Mean best score of each selected template over all the bins in which it won.
_scores_by_name = {}
for _name, _score in zip(best_templs_names, allbestscores):
    _scores_by_name.setdefault(_name, []).append(_score)
meanscores = jnp.array(
    [jnp.nanmean(jnp.array(_scores_by_name[nt])) for nt in tqdm(best_templ_sels)]
)

In [None]:
# Selected templates with their mean score, best first. Scores are converted once to a
# float array (rather than writing 0-d JAX scalars cell by cell), so 'score' gets a plain
# numeric dtype for sorting and HDF5 export. Rows of .loc[best_templ_sels] follow
# best_templ_sels, the same order as meanscores.
templs_score_df = (
    templs_ref_df.loc[best_templ_sels]
    .assign(score=np.asarray(meanscores, dtype=float), name=best_templ_sels)
    .sort_values('score', ascending=True)
)
templs_score_df

In [None]:
# Save the scored selection, which the reload cell further down reads back. The file is
# written only when it does not exist yet: an existing file is never overwritten, while a
# fresh "Restart & Run All" no longer fails on the missing file.
_scored_h5 = 'templ_NEWscoredOnTraining_SPS.h5'
if not os.path.exists(_scored_h5):
    templs_score_df.to_hdf(_scored_h5, key='fit_dsps')

In [None]:
# Keep all scored templates (the commented slice kept only the 40 best).
templ_select_df = templs_score_df #.iloc[:40]

In [None]:
#if True: templ_select_df.to_hdf('templSPS_best40scored_F2_GG_DESI_SM3.h5', key='fit_dsps')

In [None]:
%%time
infile = 'templ_NEWscoredOnTraining_SPS.h5'

pars_arr, zref_arr, templ_classif = read_h5_table(infile)

templ_select_df = pd.read_hdf(infile, key='fit_dsps')

templ_tupl = [tuple(_pars) for _pars in pars_arr]
templ_tupl_sps = tree_map(lambda partup: vmap_cols_zo(jnp.array(partup), wl_grid, z_grid, transm_arr[:-2], ssp_data), templ_tupl, is_leaf=istuple)

templs_as_dict = {}
for it, (tname, row) in enumerate(templ_select_df.iterrows()):
    _colrs = templ_tupl_sps[it]
    _df = pd.DataFrame(columns=color_names, data=_colrs)
    _df['z_p'] = z_grid
    _df['Dataset'] = np.full(z_grid.shape, row['Dataset'])
    _df['name'] = np.full(z_grid.shape, tname)
    templs_as_dict.update({f"{tname}": _df})

all_tsels_df = pd.concat([_df for _, _df in templs_as_dict.items()], ignore_index=True)
all_tsels_df.shape

In [None]:
# Distributions of fit value, redshift and score of the selected templates, stacked by dataset.
f, a = plt.subplots(1, 3, sharey=True)
for _ax, _col in zip(a, ('fun_val', 'redshift', 'score')):
    sns.histplot(templ_select_df, x=_col, hue='Dataset', multiple='stack', ax=_ax)
plt.show()

In [None]:
import matplotlib.lines as mlines

# Legend proxy for the training observations.
leg1 = mlines.Line2D([], [], color='gray', label='COSMOS2020', marker='o', markersize=6, alpha=0.7, ls='')

# Color-color diagrams: training set (grey, sized by redshift) versus the selected templates.
for c1, c2 in zip(color_names[:-1], color_names[1:]):
    f, a = plt.subplots(1, 1, constrained_layout=True)
    sns.scatterplot(
        data=train_df, x=c1, y=c2, c='gray',
        size='redshift', sizes=(10, 100),
        ax=a, legend=False, alpha=0.2,
    )
    sns.scatterplot(
        data=all_tsels_df, x=c1, y=c2, ax=a,
        size='z_p', sizes=(10, 100), alpha=0.5,
        hue='Dataset', style='Dataset', legend='brief',
    )
    handles, labels = a.get_legend_handles_labels()
    # Prepend a "Training set" entry (reusing seaborn's first handle, presumably its invisible
    # section-header handle — TODO confirm) and the COSMOS2020 proxy before seaborn's entries.
    a.legend(handles=[handles[0], leg1] + handles, labels=['Training set', 'COSMOS2020'] + labels)
    plt.show()

In [None]:
from process_fors2.stellarPopSynthesis import mean_sfr, vmap_mean_sfr
from dsps.cosmology import DEFAULT_COSMOLOGY, age_at_z0

# Present-day age of the Universe in the default DSPS cosmology, and a 100-point time grid
# from 0.1 up to it (used as the x axis of the SFH plots).
TODAY_GYR = age_at_z0(*DEFAULT_COSMOLOGY)
T_ARR = jnp.linspace(0.1, TODAY_GYR, num=100)

In [None]:
# Star-formation history of each selected template (one row per row of pars_arr).
all_sfh = vmap_mean_sfr(pars_arr)
all_sfh.shape

In [None]:
# Template datasets present in the selection (sorted, unique).
srcs = np.unique(templ_select_df['Dataset'].to_numpy())
srcs

In [None]:
# One fixed color per template dataset. The original hard-coded srcs[0..2], which raises
# IndexError with fewer than three datasets and leaves extra datasets without a color.
# The first three datasets keep blue/orange/green; further ones cycle through the tab10 names.
_tab_colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
               'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
cdict = {src: _tab_colors[i % len(_tab_colors)] for i, src in enumerate(srcs)}

In [None]:
import matplotlib.lines as mlines

# Star-formation history of every selected template, colored by its dataset.
f, a = plt.subplots(1, 1)
for sfh, src in zip(all_sfh, templ_select_df['Dataset'], strict=True):
    a.plot(T_ARR, sfh, lw=1, ls='-', c=cdict[src])
# Labels and title do not depend on the template: set them once, outside the loop
# (they are then also present when there is nothing to plot).
a.set_xlabel('Age of the Universe [Gyr]')
a.set_ylabel('SFR '+r"$\mathrm{M_\odot.yr}^{-1}$")
a.set_title('SFH of photo-z templates')

# One legend entry per dataset, using the curve colors.
legs = [mlines.Line2D([], [], color=colr, label=src, lw=1) for src, colr in cdict.items()]
a.legend(handles=legs)
plt.show()

In [None]:
# Same star-formation histories on a logarithmic y axis.
f, a = plt.subplots(1, 1)
for sfh, src in zip(all_sfh, templ_select_df['Dataset'], strict=True):
    a.semilogy(T_ARR, sfh, lw=1, ls='-', c=cdict[src])
# Labels and title are template-independent: set them once, outside the loop.
a.set_xlabel('Age of the Universe [Gyr]')
# NOTE(review): the same `all_sfh` curves are labelled as SFR in M_sun/yr in the linear-scale
# SFH plot of this notebook; confirm which of the two y labels is correct.
a.set_ylabel('Specific SFR [-]')
a.set_title('SFH of photo-z templates')
a.set_ylim(1e-3, 5e2)

# One legend entry per dataset, using the curve colors.
legs = [mlines.Line2D([], [], color=colr, label=src, lw=1) for src, colr in cdict.items()]
a.legend(handles=legs)
plt.show()

In [None]:
import matplotlib.patches as mpatches

# Legend proxy for the training sample histogram.
train_patch = mpatches.Patch(edgecolor='k', facecolor='grey', label='COSMOS2020', alpha=0.7)

# Same per-color histograms as before, now for the selected templates (all_tsels_df).
# list_edges is recomputed with identical edges (60 bins over the finite training values).
list_edges = []
for c in color_names:
    _arr = np.asarray(train_df[c])
    # Only the edges are needed; the histogram counts were never used.
    _edges1d = np.histogram_bin_edges(_arr[np.isfinite(_arr)], bins=60)
    list_edges.append(_edges1d)

    f, a = plt.subplots(1, 1)
    sns.histplot(
        data=train_df, x=c, bins=_edges1d, stat='density',
        label='COSMOS2020', color='grey', ax=a, legend=False,
    )
    sns.histplot(
        data=all_tsels_df, x=c, bins=_edges1d, stat='density',
        multiple='stack', hue='Dataset', alpha=0.7, ax=a, legend=True,
    )

    # Rebuild seaborn's hue legend so that it also lists the training sample.
    old_legend = a.get_legend()
    handles = old_legend.legend_handles
    labels = [t.get_text() for t in old_legend.get_texts()]
    title = old_legend.get_title().get_text()
    a.legend(handles=[train_patch] + handles, labels=['COSMOS2020 (Training)'] + labels, title=title, loc='best')

    plt.show()