# Análise dos resultados - Métricas

_Autores: Andreia Dourado, Bruno Moraes_

_Adaptado dos notebooks do [rail_tpz](https://github.com/LSSTDESC/rail_tpz) e do [Demo: RAIL Evaluation](https://rail-hub.readthedocs.io/projects/rail-notebooks/en/latest/rendered/evaluation_examples/Evaluation_Demo.html)_

__Descrição: Análise das métricas para os resultados gerados na etapa Estimate para o TPZ.__

### 1. Importando as bibliotecas:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rail
import qp
from rail.core.data import TableHandle, PqHandle, ModelHandle, QPHandle, DataHandle, Hdf5Handle
from rail.core.stage import RailStage

In [None]:
# Shared RAIL data store; let keys be re-registered so cells can be re-run.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
from qp import Ensemble
from matplotlib import gridspec
from qp import interp
from qp.metrics.pit import PIT
from rail.evaluation.metrics.cdeloss import *
from rail.evaluation.evaluator import OldEvaluator
from rail.evaluation.point_to_point_evaluator import PointToPointEvaluator
from rail.estimation.algos.point_est_hist import PointEstHistSummarizer
from rail.evaluation.metrics.cdeloss import *
from utils import plot_pit_qq, ks_plot
import os
from rail.estimation.algos.naive_stack import NaiveStackSummarizer
from scipy.interpolate import UnivariateSpline


%matplotlib inline
%reload_ext autoreload
%autoreload 

### 2. Leitura dos arquivos

##### As células abaixo definem os diretórios base e servem apenas para facilitar a transição entre os meus diretórios

In [None]:
# Base directory of the test catalogue used in the estimate stage.
path_true = "../dados_tcc/run_files/"

In [None]:
# Base directory of the TPZ estimate outputs.
path_mode = "../dados_tcc/output/"

#### 2.1 Arquivo de teste utilizado no estimate:

In [None]:
# Test catalogue holding the true redshifts.
ztrue_file = path_true + 'test_file.hdf5'
ztrue_data = DS.read_file('ztrue_data', TableHandle, ztrue_file)

In [None]:
# TPZ photo-z PDFs for the test sample. Built from `path_mode` (defined above
# precisely to ease switching directories) instead of repeating the literal path.
pdfs_file = f'{path_mode}output_test.hdf5'
tpzdata = DS.read_file('pdfs_data', QPHandle, pdfs_file)

#### 2.2 Múltiplos arquivos:

In [None]:
# Run with magnitudes + colours as features (minleaf=30, per the file name).
pdfs_file1 = f'{path_mode}output_test_mags+colors_minleaf30.hdf5'
# Dedicated data-store key: re-using 'pdfs_data' replaced the store entry
# registered for `tpzdata`.
tpzdata1 = DS.read_file('pdfs_data_mags_colors', QPHandle, pdfs_file1)

# Mode of each PDF on the x-grid stored in the ensemble's generator
# (assumes an interpolated-grid representation - TODO confirm).
zgrid1 = tpzdata1.data[0].gen_obj.xvals
photoz_mode1 = tpzdata1().mode(grid=zgrid1)
z_mode1 = np.squeeze(photoz_mode1)

In [None]:
# Run with magnitudes only as features (minleaf=30, per the file name).
pdfs_file2 = f'{path_mode}output_test_mags_minleaf30.hdf5'
# Dedicated data-store key: re-using 'pdfs_data' replaced the store entry
# registered for `tpzdata`.
tpzdata2 = DS.read_file('pdfs_data_mags', QPHandle, pdfs_file2)

# Mode of each PDF on the x-grid stored in the ensemble's generator
# (assumes an interpolated-grid representation - TODO confirm).
zgrid2 = tpzdata2.data[0].gen_obj.xvals
photoz_mode2 = tpzdata2().mode(grid=zgrid2)
z_mode2 = np.squeeze(photoz_mode2)

In [None]:
# Run with colours only as features (minleaf=30, per the file name).
pdfs_file3 = f'{path_mode}output_test_colors_minleaf30.hdf5'
# Dedicated data-store key: re-using 'pdfs_data' replaced the store entry
# registered for `tpzdata`.
tpzdata3 = DS.read_file('pdfs_data_colors', QPHandle, pdfs_file3)

# Mode of each PDF on the x-grid stored in the ensemble's generator
# (assumes an interpolated-grid representation - TODO confirm).
zgrid3 = tpzdata3.data[0].gen_obj.xvals
photoz_mode3 = tpzdata3().mode(grid=zgrid3)
z_mode3 = np.squeeze(photoz_mode3)

In [None]:
# True redshifts: the 'redshift' column of the test catalogue.
_catalogue = ztrue_data()
ztrue = _catalogue['redshift']

In [None]:
# Common evaluation grid: 301 evenly spaced points from z=0 to z=3 (step 0.01).
zgrid = np.linspace(0.0, 3.0, num=301)

In [None]:
# Sanity check: every run should yield one mode per true redshift.
tuple(len(values) for values in (ztrue, z_mode1, z_mode2, z_mode3))

In [None]:
#tpzdata1 = DS.read_file('pdfs_data', QPHandle, pdfs_file)

#### 2.3 Lendo os valores de redshift true e gerados no estimate:

In [None]:
# True redshifts, a common 0-3 evaluation grid, and the mode of each TPZ PDF on it.
zgrid = np.linspace(0, 3, 301)
ztrue = ztrue_data()['redshift']
_tpz_ensemble = tpzdata()
photoz_mode = _tpz_ensemble.mode(grid=zgrid)
# np.squeeze drops any length-1 axes returned by mode().
z_mode = np.squeeze(photoz_mode)

In [None]:
# Register the truth table and the TPZ ensemble in the data store.
truth = DS.add_data(
    'truth',
    ztrue_data(),
    TableHandle,
)
ensemble = DS.add_data(
    'ensemble',
    tpzdata(),
    QPHandle,
)

In [None]:
# Sanity check: one photo-z mode per true redshift.
(len(z_mode), len(ztrue))

### 3. Métricas

__Caminho para salvar as imagens:__

In [None]:
# Output directory for the validation figures.
path = "../dados_tcc/metrics/validation/"

#### 3.1 Zphot x Ztrue

In [None]:
def plot_scatter(zphot,
                 ztrue,
                 zmin=0,
                 zmax=3,
                 bins=150,
                 cmap='viridis',
                 line_color='red',
                 line_width=0.2,
                 title='$z_{true}$ vs $z_{phot}$',
                 xlabel='z$_{true}$',
                 ylabel='z$_{phot}$',
                 fontsize_title=18,
                 fontsize_labels=15,
                 path_to_save=''):
    """Plot a 2-D histogram of photometric vs. true redshift, save and show it.

    Parameters
    ----------
    zphot, ztrue : array-like
        Photometric (y axis) and true (x axis) redshifts.
    zmin, zmax : float
        Axis limits; also the ends of the z_phot = z_true reference line.
    bins, cmap :
        Forwarded to ``seaborn.histplot``.
    line_color, line_width :
        Style of the identity line.
    title, xlabel, ylabel, fontsize_title, fontsize_labels :
        Text settings.
    path_to_save : str
        Output file. When empty, keeps the previous behaviour of writing
        ``scatter_mags.png`` into the notebook-level ``path`` directory.
    """
    h = sns.histplot(x=ztrue, y=zphot, bins=bins, cmap=cmap)
    # Identity line spans the plotted range (it was hard-coded to [0, 3]).
    plt.plot([zmin, zmax], [zmin, zmax], color=line_color, linewidth=line_width)
    plt.xlim(zmin, zmax)
    plt.ylim(zmin, zmax)
    plt.xlabel(xlabel, fontsize=fontsize_labels)
    plt.ylabel(ylabel, fontsize=fontsize_labels)
    plt.title(title, fontsize=fontsize_title)
    plt.colorbar(h.collections[0], label='Número de objetos')

    # `path_to_save` used to be accepted but silently ignored.
    target = path_to_save if path_to_save else f'{path}scatter_mags.png'
    plt.savefig(target)

    plt.show()

In [None]:
# 2-D histogram of the TPZ mode against the true redshift.
plot_scatter(zphot=z_mode, ztrue=ztrue)

In [None]:
# Scatter of the TPZ mode against the true redshift; the vertical lines mark
# the band-transition redshifts given in their labels.
# (This cell previously plotted `z_mode4`, which is never defined -> NameError.)
band_transitions = [('g to r', 0.45), ('r to i', 0.8), ('i to z', 1.2), ('z to y', 1.45)]

plt.figure(figsize=(8, 8))
plt.scatter(ztrue, z_mode, s=1, c='k')
plt.plot([0, 3], [0, 3], 'b--')
for band_label, z_cut in band_transitions:
    plt.axvline(x=z_cut, color='red', linestyle='--', alpha=0.7, label=f'{band_label}: {z_cut}')
plt.xlabel("redshift", fontsize=15)
plt.ylabel("TPZ mode", fontsize=15)
plt.legend(loc='upper right')
plt.savefig(f'{path}scatter_bands.png')

In [None]:
# Outlier diagnostic. `ptp_stage_single`, `ez` and `zmode` were never defined in
# this notebook (NameError), so the quantities are computed here directly:
#   ez        = (z_phot - z_true) / (1 + z_true)
#   sigma_IQR = (75th - 25th percentile of ez) / 1.349
#   outlier   : |ez| > max(0.06, 3 * sigma_IQR)
ztrue_arr = np.asarray(ztrue, dtype=float)
zphot_arr = np.asarray(z_mode, dtype=float)
ez = (zphot_arr - ztrue_arr) / (1 + ztrue_arr)
q25, q75 = np.percentile(ez, [25, 75])
sigma_iqr = (q75 - q25) / 1.349
cutcriterion_all = max(0.06, 3 * sigma_iqr)
mask = np.fabs(ez) > cutcriterion_all
points = np.linspace(0, 3.3, 1000)

plt.scatter(ztrue_arr[mask], zphot_arr[mask], s=0.1, color='purple')
plt.scatter([], [], color='purple', s=13, label='outliers')
plt.scatter(ztrue_arr[~mask], zphot_arr[~mask], s=0.1, color='black')
# Envelope z_phot = z_true +/- cut * (1 + z_true): exactly the boundary of `mask`
# (it previously used 3*sigma_IQR, ignoring the 0.06 floor of the criterion).
plt.scatter(points, points + cutcriterion_all * (1 + points), color='blue', s=0.1)
plt.scatter(points, points - cutcriterion_all * (1 + points), color='blue', s=0.1)
plt.xlim(0, 3.1)
plt.ylim(0, 3.1)
plt.legend(fontsize=16, loc=2)
plt.xlabel('ztrue', fontsize=16)
plt.ylabel('zphot', fontsize=16)
plt.savefig(f'{path}scatter_outliers.png')

#### 3.2. PDF individual

In [None]:
import random

# Plot the PDFs of randomly chosen galaxies (fixed seed for reproducibility),
# marking the true redshift and the photo-z mode of each one.
n_rows, n_cols = 4, 3
random.seed(60)
# Sample from the actual catalogue size instead of a hard-coded 29912.
numeros = random.sample(range(len(ztrue)), n_rows * n_cols)
print(numeros)

tpz_ens = tpzdata()
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10))
for ax, which in zip(axs.flatten(), numeros):
    tpz_ens.plot_native(key=which, axes=ax, label=f"PDF for galaxy {which}")
    ax.axvline(ztrue[which], c='r', ls='--', label="spec-z")
    ax.axvline(z_mode[which], c='black', ls='--', label="photo-z mode")
    ax.legend(loc='upper right', fontsize=10)
    ax.set_xlabel("redshift")
    ax.set_title(f"Galaxy {which}")
plt.tight_layout()
plt.savefig(f'{path}example_pdfs_mags.png')
plt.show()

#### 3.3 Métricas básicas

##### Functions

In [None]:
def compute_photoz_metrics(zspec, zphot):
    """Summarise the normalised photo-z residual ``dz = (zphot - zspec) / (1 + zspec)``.

    Returns
    -------
    dict with keys, in this order:
        'bias'       : mean of dz
        'sigma_68'   : half the distance between the 15.87th and 84.13th
                       percentiles of dz
        'sigma'      : standard deviation of dz (ddof=0)
        'out_2sigma' : fraction of objects with |dz| > 2 * sigma
        'out_3sigma' : fraction of objects with |dz| > 3 * sigma
        'RMS'        : root mean square of dz
    """
    dz = (zphot - zspec) / (1 + zspec)
    n_obj = len(dz)

    spread = np.std(dz)
    p_low = np.percentile(dz, 15.87)
    p_high = np.percentile(dz, 84.13)
    abs_dz = np.abs(dz)

    return {
        'bias': np.mean(dz),
        'sigma_68': 0.5 * (p_high - p_low),
        'sigma': spread,
        'out_2sigma': np.sum(abs_dz > 2 * spread) / n_obj,
        'out_3sigma': np.sum(abs_dz > 3 * spread) / n_obj,
        'RMS': np.sqrt(np.mean(dz ** 2)),
    }

In [None]:
def _binned_residual_stats(zspec, zphot, edges):
    """Per-bin statistics of ez = (zphot - zspec) / (1 + zspec), binned by zphot.

    Bins are half-open ``[lo, hi)`` except the last, which is closed so an
    object with ``zphot == edges[-1]`` is kept. Returns a dict of 1-D arrays
    (one value per bin, NaN for empty bins):
    'bias' (mean ez), 'sigma' (std of ez), 'sigma68' (68th percentile of |ez|,
    nearest-rank), 'out2' / 'out3' (fraction with |ez - bias| > 2 / 3 sigma).
    """
    zspec = np.asarray(zspec, dtype=float)
    zphot = np.asarray(zphot, dtype=float)
    keys = ('bias', 'sigma', 'sigma68', 'out2', 'out3')
    stats = {key: [] for key in keys}
    n_bins = len(edges) - 1

    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        upper = (zphot <= hi) if i == n_bins - 1 else (zphot < hi)
        sel = (zphot >= lo) & upper
        if not sel.any():
            for key in keys:
                stats[key].append(np.nan)
            continue

        zs = zspec[sel]
        ez = (zphot[sel] - zs) / (1 + zs)
        bias = np.mean(ez)
        sigma = np.std(ez)
        abs_sorted = np.sort(np.abs(ez))
        # Outliers are measured on ez, in the same normalised units as bias and
        # sigma (the old code compared the un-normalised zphot - zspec to them).
        dev = np.abs(ez - bias)

        stats['bias'].append(bias)
        stats['sigma'].append(sigma)
        stats['sigma68'].append(abs_sorted[int(len(abs_sorted) * 0.68)])
        stats['out2'].append(np.mean(dev > 2 * sigma))
        stats['out3'].append(np.mean(dev > 3 * sigma))

    return {key: np.asarray(values) for key, values in stats.items()}


def plot_metrics_bias_sep(zspec, zphot, maximum, path_to_save='', title=None, initial=0):
    """Plot photo-z quality metrics in z_phot bins of width 0.1.

    Top panel: sigma68 and the 2-sigma / 3-sigma outlier fractions per bin.
    Bottom panel: per-bin bias of ez = (zphot - zspec) / (1 + zspec) with a
    +/-1 sigma band.

    Parameters
    ----------
    zspec, zphot : array-like
        True and photometric redshifts.
    maximum : float
        Upper end of the binned range; the bins extend far enough to include it
        (``np.arange(initial, maximum, 0.1)`` used to drop the last partial bin).
    path_to_save : str, optional
        Output file. When empty, keeps the previous behaviour of writing
        ``metrics.png`` into the notebook-level ``path`` directory.
    title : str, optional
        Figure suptitle.
    initial : float, optional
        Lower end of the binned range.

    Side effect: updates the global seaborn context/style and matplotlib rcParams.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_context("paper", font_scale=1.3)
    sns.set_style("whitegrid")
    plt.rcParams.update({
        "font.family": "serif",
        "axes.edgecolor": "black",
        "axes.linewidth": 1.2,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.major.size": 5,
        "ytick.major.size": 5
    })

    step = 0.1
    edges = np.arange(initial, maximum + step, step)
    centers = edges[:-1] + step / 2
    stats = _binned_residual_stats(zspec, zphot, edges)

    fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(10, 8), sharex=True, gridspec_kw={'height_ratios': [2, 1]})
    plt.subplots_adjust(hspace=0.05)

    ax_top.plot(centers, stats['sigma68'], 'o-', label=r'$\sigma_{68}$', color='forestgreen')
    ax_top.plot(centers, stats['out2'], 'o-', label=r'Outliers 2$\sigma$', color='darkorange')
    ax_top.plot(centers, stats['out3'], 'o-', label=r'Outliers 3$\sigma$', color='crimson')
    ax_top.set_ylabel("Métricas", fontsize=13)
    ax_top.set_xlim(initial, edges[-1])
    ax_top.legend()
    ax_top.grid(True)

    ax_bot.plot(centers, stats['bias'], 'o-', color='royalblue', label='Bias (Δz)')
    ax_bot.fill_between(centers,
                        stats['bias'] - stats['sigma'],
                        stats['bias'] + stats['sigma'],
                        color='royalblue', alpha=0.2, label='±1σ')
    ax_bot.axhline(0, linestyle='--', color='gray', lw=1)
    # Objects are binned by z_phot, so that is the x axis (it was labelled z_spec).
    ax_bot.set_xlabel(r'$z_{\mathrm{phot}}$', fontsize=13)
    ax_bot.set_ylabel(r'$\Delta z/(1+z_{\mathrm{spec}})$', fontsize=13)
    ax_bot.set_xlim(initial, edges[-1])
    ax_bot.grid(True)
    ax_bot.legend()

    if title:
        fig.suptitle(title, fontsize=16)

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # `path_to_save` used to be accepted but silently ignored.
    target = path_to_save if path_to_save else f'{path}metrics.png'
    plt.savefig(target, dpi=300, bbox_inches='tight')
    plt.show()


##### Plots

In [None]:
# Binned metrics over the full range of the TPZ mode.
plot_metrics_bias_sep(ztrue, z_mode, maximum=max(z_mode))

In [None]:
# Global point-estimate metrics of the TPZ mode.
compute_photoz_metrics(zspec=ztrue, zphot=z_mode)

#### 3.4 PIT QQ

In [None]:
# Probability Integral Transform of the TPZ ensemble at the true redshifts.
tpz_ensemble = tpzdata()
pitobj = PIT(tpz_ensemble, ztrue)
quant_ens = pitobj.pit
metamets = pitobj.calculate_pit_meta_metrics()

In [None]:
# Display the PIT meta-metrics returned by calculate_pit_meta_metrics().
metamets

In [None]:
# Copy the PIT samples into a NumPy array and display them.
pit_samples = pitobj.pit_samps
pit_vals = np.array(pit_samples)
pit_vals

In [None]:
# PIT outlier rate, first as reported in the meta-metrics and then recomputed
# directly. The two prints had identical text, so they could not be told apart.
# The second value is the one passed on to plot_pit_qq.
pit_out_rate = metamets['outlier_rate']
print(f"PIT outlier rate of this sample (meta-metrics): {pit_out_rate:.6f}")
pit_out_rate = pitobj.evaluate_PIT_outlier_rate()
print(f"PIT outlier rate of this sample (evaluate_PIT_outlier_rate): {pit_out_rate:.6f}")

In [None]:
# PDFs evaluated on `zgrid`, the grid they are paired with in plot_pit_qq.
# The stored 'yvals' live on the ensemble's native grid, which need not match zgrid.
pdfs = tpzdata().pdf(zgrid)

In [None]:
# Inspect the `pdfs` array.
pdfs

In [None]:
# PIT-QQ plot of the TPZ PDFs against the true redshifts.
# NOTE(review): the title still says "toy data", apparently inherited from the
# RAIL demo notebook; confirm whether it should describe the TPZ test sample.
from utils import ks_plot, plot_pit_qq

plot_pit_qq(
    pdfs, zgrid, ztrue,
    title="PIT-QQ - toy data",
    code="TPZ",
    pit_out_rate=pit_out_rate,
    savefig=True,
)

#### 3.5 N(z)

In [None]:
# Naive stacking of the TPZ PDFs into an n(z) with 301 bins over 0 <= z <= 3.
stacker = NaiveStackSummarizer.make_stage(
    zmin=0.0,
    zmax=3,
    nzbins=301,
    nsamples=20,
    hdf5_groupname=None,
    output="Naive_sample.hdf5",
    single_NZ="NaiveStack_TPZ.hdf5",
)

In [None]:
# Run the naive stacking stage on the TPZ ensemble handle.
naive_results = stacker.summarize(tpzdata)

In [None]:
# Normalised histograms of the true redshifts and of the TPZ mode.
fig = plt.figure(figsize=(8, 6))

plt.xlabel('redshift', fontsize=17)
plt.ylabel('density', fontsize=17)
# Optional overlays (band transitions, other feature sets), disabled:
#plt.grid(color='gray', linewidth=0.5)
#plt.axvline(x=0.45, color='black', linestyle=':', alpha=0.7, label='g to r: 0.45')
#plt.axvline(x=0.8, color='black', linestyle=':', alpha=0.7, label='r to i: 0.8')
#plt.axvline(x=1.2, color='black', linestyle=':', alpha=0.7, label='i to z: 1.2')
#plt.axvline(x=1.45, color='black', linestyle=':', alpha=0.7, label='z to y: 1.45')
# Filled, semi-transparent histogram of z_true; `z` is reused by the spline cell.
z = plt.hist(ztrue, bins=50, density=True, color='gray', label='z_true', alpha=0.5)
# Outline-only (step) histogram of the TPZ mode, using the 1-D squeezed z_mode.
zmode = plt.hist(z_mode, bins=50, density=True, color='red', label='z_phot', alpha=1, histtype='step')
#zmode2 = plt.hist(photoz_mode2, bins=50, density=True, color='blue', label='z_phot magnitudes', alpha=1, histtype='step', linestyle='--')
#zmode1 = plt.hist(photoz_mode1, bins=50, density=True, color='green', label='z_phot magnitudes+cores', alpha=1, histtype='step')
#zmode3 = plt.hist(photoz_mode3, bins=50, density=True, color='red', label='z_phot cores', alpha=1, histtype='step', linestyle='-.' )

plt.legend(fontsize=10)
#plt.title('Minleaf = 30')

# Save in high resolution (the original comment promised this but no dpi was set);
# reuse `path` instead of repeating the directory literal.
plt.savefig(f'{path}hist_z_true_mags.png', dpi=300)
plt.show()

In [None]:
# Spline-smoothed version of the z_true histogram from the previous cell.
# `z` is the (counts, edges, patches) tuple of plt.hist(..., bins=50), so its 50
# counts must be paired with the 50 bin centres; pairing them with zgrid[:-1]
# (300 points) raised a length-mismatch error.
counts, edges = z[0], z[1]
bin_centres = 0.5 * (edges[:-1] + edges[1:])
cs = UnivariateSpline(bin_centres, counts)
cs.set_smoothing_factor(0.2)

In [None]:
# Stacked n(z) of the TPZ PDFs (presumably written by the NaiveStack stage above)
# compared with the spline-smoothed z_true histogram.
varinf_nz = qp.read("NaiveStack_TPZ.hdf5")
#varinf_nz1 = qp.read(f"NaiveStack_TPZ_1.hdf5")
#varinf_nz2 = qp.read(f"NaiveStack_TPZ_2.hdf5")
#varinf_nz4 = qp.read(f"../pkl-files/metricas/NaiveStack_GPZ.hdf5")

# Ensemble.pdf may return one row per distribution; squeeze to 1-D for plotting.
nz_phot = np.squeeze(varinf_nz.pdf(zgrid))
plt.plot(zgrid, nz_phot, color='red', label='z$_{phot}$')
#plt.plot(zgrid,varinf_nz1.pdf(zgrid), color = 'blue', label = 'z$_{phot}$ magnitudes+cores', linestyle='--')
#plt.plot(zgrid,varinf_nz2.pdf(zgrid), color = 'red', label = 'z$_{phot}$ magnitudes')
#plt.plot(zgrid,varinf_nz4.pdf(zgrid), color = 'darkorange', label = 'z$_{phot}$ GPZ', linestyle='-.')
# The spline can dip below zero (and extrapolates beyond the histogram range);
# a density cannot be negative, so clip it at zero.
plt.fill_between(zgrid, np.clip(cs(zgrid), 0, None), color='gray', alpha=0.5, label='z$_{spec}$')
plt.legend(fontsize=15)
plt.xlabel('z', fontsize=17)
plt.ylabel('p(z)', fontsize=17)
plt.savefig(f'{path}n(z)_mags.png')