## **07.a Binning**
#### Authors: **Amanda Farias (afariassantos2@gmail.com), Iago Lopes (iagolops2012@gmail.com), Bruno Moraes (bruno.a.l.moraes@gmail.com)**
#### Creation date: **09/13/2024**
#### Last Verified to Run: **11/19/2024** (by @iago)


The objective of this notebook is to obtain the binned redshift distribution of the dataset. First, we create the redshift distribution of each sample. For the Roman Rubin (true) redshifts, we use the *TrueNZHistogrammer* to generate histograms. For the observed redshifts, we stack the PDFs with the *NaiveStackSummarizer*. The distributions are computed per tomographic bin, and the binning follows the __[LSST DESC SRD](https://arxiv.org/pdf/1809.01669)__. The lens sample, selected with the MagLim cut, is divided into 5 bins of width $\Delta z = 0.2$ spanning 0.2 to 1.2. The source sample is divided into 5 bins that each contain an equal number of galaxies.
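
The fixed-width lens binning can be expressed compactly with `np.digitize`. The following is a minimal sketch under stated assumptions: `lens_edges` and `assign_lens_bin` are hypothetical names and are not part of the `binning` module. The notebook's own masks use strict inequalities on both sides (`z > lo` and `z < hi`), so a galaxy lying exactly on an edge is dropped there but assigned to the upper bin here.

```python
import numpy as np

# Lens bin edges: 5 bins of width 0.2 spanning z = 0.2 to 1.2
lens_edges = np.linspace(0.2, 1.2, 6)

def assign_lens_bin(zphot, edges=lens_edges):
    """Return the 0-based lens bin of each photo-z, or -1 outside the edges.

    np.digitize places z in bin k when edges[k] <= z < edges[k + 1].
    """
    zphot = np.asarray(zphot, dtype=float)
    idx = np.digitize(zphot, edges) - 1
    # Below the first edge gives -1 already; at or above the last edge gives len(edges) - 1
    idx[(idx < 0) | (idx >= len(edges) - 1)] = -1
    return idx

# Usage on a few hypothetical photo-z values
print(assign_lens_bin([0.1, 0.25, 0.5, 1.19, 1.5]))
```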


##### Logistics: This notebook is intended to be run through the NERSC JupyterLab interface, available at __[Jupyter nersc](https://jupyter.nersc.gov/)__, using the **desc-python** kernel.

<div class="alert alert-block alert-info">
<b>ATTENTION:</b> We intend to replace these binning functions with RAIL functions; this is not implemented yet.
</div>

In [None]:
# Standard library
import gc
from enum import Enum

# Scientific stack
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import scipy
from scipy import integrate
from scipy.interpolate import interp1d, UnivariateSpline
from scipy.special import erf

# DESC / RAIL
import qp
import tables_io
from rail.core.stage import RailStage
from rail.core.data import QPHandle, TableHandle, Hdf5Handle
from rail.estimation.algos.naive_stack import NaiveStackSummarizer
from rail.estimation.algos.true_nz import TrueNZHistogrammer
from rail.evaluation.dist_to_dist_evaluator import DistToDistEvaluator

# Local helper module (binning.py). The cells below use create_filtered_hdf5_files,
# process_bins, create_histograms, metrics_sample, plot_nz_from_bins and assign_class_id from it.
from binning import *

In [None]:
DS = RailStage.data_store
# Allow data-store keys to be overwritten so that cells can be re-run
DS.__class__.allow_overwrite = True

## Loading the dataset
This dataset was generated in notebook 04a. It contains the true redshift, the observed redshift, and the PDF of each galaxy.

<div class="alert alert-block alert-warning">
<b>ATTENTION:</b> To run this notebook on your NERSC account, please set <code>nersc_name</code> in the cell below to your NERSC username.
</div> 

In [None]:
nersc_name = 'iago'

In [None]:
path = f"/pscratch/sd/{nersc_name[0]}/{nersc_name}"

In [None]:
############## Load the catalog with the photo-z PDFs #############

result = DS.read_file('pdfs_data', QPHandle,
                      f'{path}/results_y1/output_estimate_a_roman_fzb_y1_10sig.hdf5')
tables = result().build_tables()

############### Max redshift of the PDF grid ##############
z_max = tables['meta']['xvals'][0][-1]

############ X values for plots ############
zgrid = np.linspace(0, z_max, 301)

################## PDFs ##################
pdfs = tables['data']['yvals']

########## Mode of each PDF (used as the point photo-z) ##########
mode = result().mode(zgrid)

#### Flatten to a 1-D array of photo-z with 132891 galaxies
zphot = np.asarray(mode).ravel()

##############################################
################ True redshift ###############
##############################################

catalog = pd.read_csv(f'{path}/roman_rubin_y1_a_test_10sig.csv', sep=' ')
catalog['zphot'] = zphot
ztrue = catalog['redshift']
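
The point photo-z used throughout is the mode of each PDF. The following minimal NumPy sketch shows what a mode-on-a-grid estimate does: it evaluates each PDF on the grid and takes the grid value at the maximum. The Gaussian PDFs, the `toy_` arrays, and `mode_on_grid` are hypothetical and are not part of qp. We assume qp's `mode(grid)` works along these lines. Here, `argmax` along axis 1 already returns one value per galaxy, which is what the notebook obtains by flattening qp's output.

```python
import numpy as np

# Hypothetical toy ensemble: three Gaussian PDFs tabulated on a 0-3 grid (step 0.01)
toy_zgrid = np.linspace(0.0, 3.0, 301)
toy_centers = np.array([0.45, 1.10, 2.30])
toy_pdf_vals = np.exp(-0.5 * ((toy_zgrid[None, :] - toy_centers[:, None]) / 0.05) ** 2)

def mode_on_grid(pdf_vals, grid):
    """Return, for each row of pdf_vals, the grid value where that PDF peaks."""
    return grid[np.argmax(pdf_vals, axis=1)]

toy_modes = mode_on_grid(toy_pdf_vals, toy_zgrid)  # 1-D: one photo-z per galaxy
```

The mode can only take grid values, so its resolution is limited by the grid spacing. This is one reason a fine grid, such as the 301 points used here, is convenient.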

# Observed redshift distribution

The naive stack method sums the PDFs of all galaxies without applying weights or any other correction. RAIL's ``NaiveStackSummarizer`` evaluates every PDF on a common redshift grid, sums the values at each grid point, and normalizes the summed distribution. Its ``nsamples`` parameter sets how many bootstrap realizations of the stacked distribution it also produces.

For our configuration, we use the redshift range of the estimator output, from 0 to `z_max`, and divide it into `nzbins=301` bins. The stages below set `zmin=0.0` and `zmax=3.0`.

For a more detailed explanation of each component in the ``NaiveStackSummarizer``, please refer to the __[RAIL documentation](https://github.com/LSSTDESC/rail/blob/6f4e15315962b3010dbd52eb2c4e308710df9b87/docs/source/new_rail_stage.rst#L90)__ on GitHub.
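
The stacking step can be sketched in a few lines of NumPy. This sketch uses hypothetical Gaussian PDFs, and the functions `naive_stack` and `bootstrap_stacks` are illustrative names rather than RAIL's API. Bootstrap resampling of galaxies stands in for the `nsamples` realizations. We assume, but have not verified here, that RAIL's internals follow the same logic.

```python
import numpy as np

def trapezoid(y, x):
    """Trapezoidal integral of y over x along the last axis."""
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)

def naive_stack(pdfs, zgrid):
    """Sum the PDFs (one per row) on a common grid and normalize to unit area."""
    nz = pdfs.sum(axis=0)
    return nz / trapezoid(nz, zgrid)

def bootstrap_stacks(pdfs, zgrid, nsamples, rng):
    """Stack nsamples bootstrap resamplings of the galaxies (drawn with replacement)."""
    n_gal = pdfs.shape[0]
    return np.array([naive_stack(pdfs[rng.integers(0, n_gal, size=n_gal)], zgrid)
                     for _ in range(nsamples)])

# Hypothetical toy data: 1000 normalized Gaussian PDFs on a 0-3 grid
rng = np.random.default_rng(42)
toy_zgrid = np.linspace(0.0, 3.0, 301)
toy_centers = rng.uniform(0.3, 1.5, size=1000)
toy_pdfs = np.exp(-0.5 * ((toy_zgrid[None, :] - toy_centers[:, None]) / 0.08) ** 2)
toy_pdfs /= trapezoid(toy_pdfs, toy_zgrid)[:, None]

toy_nz = naive_stack(toy_pdfs, toy_zgrid)
toy_nz_samples = bootstrap_stacks(toy_pdfs, toy_zgrid, nsamples=25, rng=rng)
```

Because every galaxy enters with equal weight, the mean redshift of the stack equals the average of the individual PDF means. The spread of `toy_nz_samples` gives a rough sampling uncertainty on the stacked n(z).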

### LENS MagLim

In [None]:
a_vector = [3.0,3.5,4.0,4.5]
b_vector = [17.5,18.0,18.5,19.0]
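
The MagLim cut applied in the next cells keeps galaxies with `mag_i_lsst < b + a * zphot`. The following sketch runs the same selection over the `(a, b)` grid on a hypothetical uniform toy catalog. `maglim_mask` and the `toy_` arrays are illustrative names. Because photo-z values are non-negative, raising either `a` or `b` loosens the limit, so the selected sample can only grow along each axis of the grid.

```python
import numpy as np

def maglim_mask(mag_i, zphot, a, b):
    """MagLim selection: keep galaxies with mag_i < b + a * zphot."""
    return np.asarray(mag_i) < b + a * np.asarray(zphot)

# Hypothetical toy catalog with uniform photo-z and i-band magnitudes
rng = np.random.default_rng(0)
toy_zphot = rng.uniform(0.0, 2.0, size=10_000)
toy_mag_i = rng.uniform(17.0, 26.0, size=10_000)

# Number of selected galaxies for every (a, b) pair of the grid above
toy_counts = {(a, b): int(maglim_mask(toy_mag_i, toy_zphot, a, b).sum())
              for a in [3.0, 3.5, 4.0, 4.5] for b in [17.5, 18.0, 18.5, 19.0]}
```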

#### Analyzing the metrics for all cuts

In [None]:
%%capture
metrics = {}
for a in a_vector:
    for b in b_vector:
        # MagLim selection: mag_i < b + a * z_phot
        mask = catalog['mag_i_lsst'] < b + a*catalog['zphot']

        # Helpers from the local binning module (not shown here)
        create_filtered_hdf5_files(mask, path, zphot)
        process_bins()
        create_histograms()

        sizes = []  # number of galaxies in each lens bin
        for i in range(5):
            # Lens bin i: 0.2 + 0.2*i < z_phot < 0.4 + 0.2*i, plus a bright cut mag_i > 17.5
            mask_iter = (mask
                         & (catalog['mag_i_lsst'] > 17.5)
                         & (catalog['zphot'] > 0.2 + 0.2*i)
                         & (catalog['zphot'] < 0.4 + 0.2*i))
            pdfs_mask = result.data[mask_iter]
            tables = pdfs_mask.build_tables()
            sizes.append(tables['data']['yvals'].shape[0])
            pdfs_qp = DS.add_data(key=f'maglim_pdfs_{i}', handle_class=QPHandle, data=pdfs_mask)

            # Stack the PDFs of this lens bin
            naive_stack_lens_phot = NaiveStackSummarizer.make_stage(
                zmin=0.0, zmax=3.0, nzbins=301, nsamples=25,
                hdf5_groupname='', chunk_size=400000,
                name=f'naive_stack_lens_phot_bin{i}')
            naive_results_lens_phot = naive_stack_lens_phot.summarize(pdfs_qp)

        # Per-bin metrics for this (a, b) pair, computed by the binning module
        metrics[f'{a}_{b}'] = metrics_sample(num_bins=5, param=(a,b), sizes=sizes)
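
The exact definitions inside `metrics_sample` live in the `binning` module and are not shown here. The following is a hypothetical sketch of one common way to obtain per-bin quantities like those plotted below, and `metrics_sample` may compute them differently. The mean bias is taken as the difference in mean redshift between the photometric and true n(z). The sigma bias is the difference in their standard deviations. The number density divides the galaxy count by the survey area in square arcminutes. All function names, and the area argument, are illustrative assumptions.

```python
import numpy as np

def trapezoid(y, x):
    """Trapezoidal integral of a 1-D array y over x."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))

def nz_mean_std(z, nz):
    """Mean and standard deviation of a tabulated n(z)."""
    norm = trapezoid(nz, z)
    mean = trapezoid(z * nz, z) / norm
    std = np.sqrt(trapezoid((z - mean) ** 2 * nz, z) / norm)
    return mean, std

def bin_biases(z, nz_phot, nz_true):
    """Hypothetical per-bin metrics: (mean_phot - mean_true, std_phot - std_true)."""
    mean_p, std_p = nz_mean_std(z, nz_phot)
    mean_t, std_t = nz_mean_std(z, nz_true)
    return mean_p - mean_t, std_p - std_t

def number_density_arcmin2(n_gal, area_deg2):
    """Galaxies per square arcminute (1 deg^2 = 3600 arcmin^2)."""
    return n_gal / (area_deg2 * 3600.0)

# Toy check: two Gaussians of equal width whose centers differ by 0.02
toy_z = np.linspace(0.0, 3.0, 3001)
toy_nz_true = np.exp(-0.5 * ((toy_z - 0.50) / 0.05) ** 2)
toy_nz_phot = np.exp(-0.5 * ((toy_z - 0.52) / 0.05) ** 2)
mean_bias, sigma_bias = bin_biases(toy_z, toy_nz_phot, toy_nz_true)
```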

In [None]:
colors = ['#4d4d4d', '#08306b', '#6baed6', '#ffcc00', '#ffb347']

def plot_metrics(metric_index, title, ylabel):
    plt.figure(figsize=(12, 6))
    for i in range(5):
        x_values = []
        y_values = []
        for a in a_vector:
            for b in b_vector:
                x_values.append(f'{a}_{b}')
                y_values.append(metrics[f'{a}_{b}'][metric_index][i])
        plt.plot(x_values, y_values, color=colors[i], marker='o', label=f'Bin {i+1}')
    plt.axhline(ls='--', color='black')
    plt.title(title, fontsize=20)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(rotation=45)
    plt.legend(loc='best', fontsize=16)
    plt.tight_layout()
    plt.show()

plot_metrics(metric_index=0, title='Mean Bias', ylabel='Mean Bias')
plot_metrics(metric_index=1, title='Sigma Bias', ylabel='Std Bias')
plot_metrics(metric_index=2, title=r'$\frac{N}{arcmin^2}$', ylabel=r'$\frac{N}{arcmin^2}$')
plot_metrics(metric_index=3, title=r'Sigma truth', ylabel=r'Sigma truth')


### Binning for each sample

In [None]:
for a in a_vector:
    for b in b_vector:
        # MagLim selection: mag_i < b + a * z_phot
        mask = catalog['mag_i_lsst'] < b + a*catalog['zphot']

        # Helpers from the local binning module (not shown here)
        create_filtered_hdf5_files(mask, path, zphot)
        process_bins()
        create_histograms()

        sizes = []  # number of galaxies in each lens bin
        for i in range(5):
            # Lens bin i: 0.2 + 0.2*i < z_phot < 0.4 + 0.2*i, plus a bright cut mag_i > 17.5
            mask_iter = (mask
                         & (catalog['mag_i_lsst'] > 17.5)
                         & (catalog['zphot'] > 0.2 + 0.2*i)
                         & (catalog['zphot'] < 0.4 + 0.2*i))
            pdfs_mask = result.data[mask_iter]
            tables = pdfs_mask.build_tables()
            sizes.append(tables['data']['yvals'].shape[0])
            pdfs_qp = DS.add_data(key=f'maglim_pdfs_{i}', handle_class=QPHandle, data=pdfs_mask)

            # Stack the PDFs of this lens bin
            naive_stack_lens_phot = NaiveStackSummarizer.make_stage(
                zmin=0.0, zmax=3.0, nzbins=301, nsamples=25,
                hdf5_groupname='', chunk_size=400000,
                name=f'naive_stack_lens_phot_bin{i}')
            naive_results_lens_phot = naive_stack_lens_phot.summarize(pdfs_qp)

        # Plot the n(z) of each bin with the plotting helper from the binning module
        plot_nz_from_bins(num_bins=5, param=(a,b), sizes=sizes)

### SOURCE

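After choosing the lens sample, we create the source sample by removing the lens galaxies from the catalog. The next two cells use `source_catalog` and `mask`, which are defined in the first cell of the SOURCE subsection under *Roman Rubin sims redshift distribution*. The plot also reads the true-redshift histograms produced later in that subsection, so run that subsection first.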
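
Removing the lens galaxies is a set difference on galaxy IDs. The following sketch applies the same `np.isin` step used in the notebook to a hypothetical five-galaxy catalog stored as plain arrays (all `toy_` names are illustrative). When IDs are unique, the result matches simply negating the lens mask.

```python
import numpy as np

# Hypothetical toy catalog columns
toy_galaxy_id = np.array([10, 11, 12, 13, 14])
toy_mag_i = np.array([19.0, 24.5, 20.0, 25.0, 21.0])
toy_zphot = np.array([0.3, 0.4, 0.9, 1.2, 0.5])

a, b = 3.0, 19.0
lens_mask = toy_mag_i < a * toy_zphot + b      # MagLim lens selection
toy_lens_ids = toy_galaxy_id[lens_mask]

# Source sample: every galaxy whose ID is not in the lens sample
source_mask = ~np.isin(toy_galaxy_id, toy_lens_ids)
toy_source_ids = toy_galaxy_id[source_mask]
```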

In [None]:
num_galaxies = len(source_catalog)

num_bins = 5
galaxies_per_bin = num_galaxies // num_bins

# Sort the source galaxies by photo-z so that consecutive slices form equal-number bins
sorted_indices = np.argsort(source_catalog['zphot'].to_numpy())

for bin_i in range(num_bins):
    start_index = bin_i * galaxies_per_bin
    # The last bin also takes the remainder of the integer division
    end_index = start_index + galaxies_per_bin if bin_i < num_bins - 1 else num_galaxies
    bin_indices = sorted_indices[start_index:end_index]

    # PDFs of the source galaxies in this bin (`mask` selects the source sample)
    pdfs_src = result.data[mask][bin_indices]
    src_qp = DS.add_data(key='src_pdfs', handle_class=QPHandle, data=pdfs_src)

    naive_stack_src_phot = NaiveStackSummarizer.make_stage(
        zmin=0.0, zmax=3.0, nzbins=301, nsamples=25,
        hdf5_groupname='', chunk_size=400000,
        name=f'naive_stack_src_phot_bin{bin_i}'
    )
    naive_results_src_phot = naive_stack_src_phot.summarize(src_qp)

    # Free memory before processing the next bin
    del pdfs_src, src_qp, naive_results_src_phot
    gc.collect()
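
The equal-number split above can be isolated into a small, testable function. The following sketch mirrors the notebook's scheme, in which each bin gets `len(z) // num_bins` galaxies and the last bin absorbs the remainder. `equal_number_bins` is a hypothetical helper name. For comparison, `np.array_split` spreads the remainder over the first bins instead, which gives slightly different bin sizes.

```python
import numpy as np

def equal_number_bins(zphot, num_bins=5):
    """Split galaxy indices into num_bins photo-z-ordered groups of (almost) equal size.

    Each bin gets len(zphot) // num_bins galaxies; the last bin also takes the remainder.
    """
    zphot = np.asarray(zphot)
    n = len(zphot)
    per_bin = n // num_bins
    order = np.argsort(zphot)
    bins = []
    for k in range(num_bins):
        start = k * per_bin
        end = start + per_bin if k < num_bins - 1 else n
        bins.append(order[start:end])
    return bins

# Usage on hypothetical photo-z values
rng = np.random.default_rng(3)
toy_z = rng.uniform(0.0, 3.0, size=1003)
toy_bins = equal_number_bins(toy_z, num_bins=5)
```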

In [None]:
colors = ['#4d4d4d', '#08306b', '#6baed6', '#ffcc00', '#ffb347']

srd_dir = f'/global/u1/{nersc_name[0]}/{nersc_name}'  # Change here the location of the SRD files
lens_srd = pd.read_csv(f'{srd_dir}/lens_SRD', sep=' ', index_col=False).T
src_srd = pd.read_csv(f'{srd_dir}/src_SRD', sep=' ', index_col=False).T

# Redshift grid of the SRD source n(z): the file's column names, moved to the index by .T
bins = np.round(np.array([float(x) for x in src_srd.index]), 4)

def plot_src_nz_vs_srd(num_bins):
    """Plot the true and photometric n(z) of each source bin against the SRD Y1 curves."""
    plt.figure(figsize=(12, 8))

    for i in range(num_bins):
        bin_num = i + 1
        # True-z histogram (TrueNZHistogrammer) and stacked photo-z n(z) (NaiveStackSummarizer)
        ens_true = qp.read(f'true_NZ_true_nz_src_{bin_num}.hdf5')
        ens_phot = qp.read(f'single_NZ_naive_stack_src_phot_bin{i}.hdf5')

        y_true = ens_true.objdata()['pdfs'][0]
        x_true = ens_true.metadata()['bins'][0]  # bin edges: one more entry than y_true

        tables_phot = ens_phot.build_tables()
        y_phot = tables_phot['data']['yvals'][0]
        x_phot = tables_phot['meta']['xvals'][0]

        # Smooth the true histogram (sampled at the left bin edges) before normalizing
        cs_true = UnivariateSpline(x_true[:-1], y_true)
        cs_true.set_smoothing_factor(1)  # Adjust the smoothing factor here
        smoothed_y_true = cs_true(x_true[:-1])

        # Normalize both curves to unit area
        y_true_normalized = smoothed_y_true / np.trapz(smoothed_y_true, x_true[:-1])
        y_phot_normalized = y_phot / np.trapz(y_phot, x_phot)

        # Plot the photometric, true and SRD curves
        plt.plot(x_phot, y_phot_normalized, color=colors[i], linewidth=2)
        plt.plot(x_true[:-1], y_true_normalized, color=colors[i], linestyle='--', linewidth=2)
        plt.plot(bins, src_srd[i], color='red', linewidth=2)

    # Shade a fixed reference redshift range for each bin color
    for (zlo, zhi), color in zip([(0, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0), (1.0, 3)], colors):
        plt.axvspan(zlo, zhi, color=color, alpha=0.3)

    # Legend entries for the line styles
    plt.plot([], [], label='True Bin', linewidth=2, ls='--', color='black')
    plt.plot([], [], label='Phot Bin', linewidth=2, color='black')
    plt.plot([], [], label='LSST DESC SRD Y1', linewidth=2, color='red')

    # Customizing the plot
    plt.xlabel('Redshift (z)', fontsize=18)
    plt.ylabel('N(z)', fontsize=18)
    plt.title('N(z) of the source sample: true vs. observed redshifts', fontsize=20)
    plt.legend(fontsize=14, loc=1)
    plt.ylim(0, 7)  # Adjusted for normalized values
    plt.xlim(0, 3)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.show()

# Plot N(z) for the 5 source bins
plot_src_nz_vs_srd(num_bins=5)
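
The SRD curves and the RAIL outputs are tabulated on different redshift grids. To compare them numerically rather than only visually, one option is to interpolate both onto a common grid and normalize to unit area. The following is a minimal sketch with a hypothetical smooth reference shape. It is not the SRD parametrization, and `normalize_nz` and `to_common_grid` are illustrative names.

```python
import numpy as np

def trapezoid(y, x):
    """Trapezoidal integral of a 1-D array y over x."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))

def normalize_nz(z, nz):
    """Rescale a tabulated n(z) to unit area."""
    return nz / trapezoid(nz, z)

def to_common_grid(z_src, nz_src, z_target):
    """Linearly interpolate n(z) onto z_target (zero outside the source range) and renormalize."""
    nz_t = np.interp(z_target, z_src, nz_src, left=0.0, right=0.0)
    return normalize_nz(z_target, nz_t)

# Hypothetical coarse reference table and a 301-point grid like the RAIL output
toy_z_ref = np.linspace(0.0, 3.0, 61)
toy_nz_ref = toy_z_ref ** 2 * np.exp(-(toy_z_ref / 0.3) ** 0.9)
toy_z_grid = np.linspace(0.0, 3.0, 301)
toy_nz_on_grid = to_common_grid(toy_z_ref, toy_nz_ref, toy_z_grid)
```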


# Roman Rubin sims redshift distribution

### SOURCE

In [None]:
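# MagLim parameters adopted for the final lens sample
a_final = 3.0
b_final = 19.0

lens_final = catalog[catalog['mag_i_lsst'] < a_final*catalog['zphot'] + b_final]

# Source sample: every galaxy that is not in the lens sample
mask = ~np.isin(catalog['galaxy_id'], lens_final['galaxy_id'])
source_catalog = catalog[mask]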

In [None]:
# Source selection: galaxies fainter than the final MagLim limit
src_mask = (catalog['mag_i_lsst'] > a_final*catalog['zphot'] + b_final).to_numpy()

# Columns copied into each per-bin file ('galaxy_id' is stored as 'id')
columns_to_keep = [
    "mag_err_g_lsst", "mag_err_i_lsst", "mag_err_r_lsst",
    "mag_err_u_lsst", "mag_err_y_lsst", "mag_err_z_lsst",
    "mag_g_lsst", "mag_i_lsst", "mag_r_lsst",
    "mag_u_lsst", "mag_y_lsst", "mag_z_lsst",
    "redshift", "galaxy_id"
]
num_bins = 5

# Open the original HDF5 file (read mode)
with h5py.File(f'{path}/roman_rubin_y1_a_test_10sig.hdf5', 'r') as old_file:
    # Navigate inside the 'photometry' group in the old file
    if 'photometry' in old_file:
        old_photometry_group = old_file['photometry']

        if 'zphot' in old_photometry_group:
            # Photo-z of the source galaxies, used to define the equal-number bins
            redshift_data = old_photometry_group['zphot'][:][src_mask]
            num_galaxies = len(redshift_data)
            galaxies_per_bin = num_galaxies // num_bins
            sorted_indices = np.argsort(redshift_data)

            for bin_i in range(1, num_bins + 1):
                start_index = (bin_i - 1) * galaxies_per_bin
                # The last bin also takes the remainder of the integer division
                end_index = start_index + galaxies_per_bin if bin_i < num_bins else num_galaxies
                bin_indices = sorted_indices[start_index:end_index]

                # Create a new HDF5 file for each bin (write mode)
                with h5py.File(f'roman_rubin_test_binning_src_{bin_i}.hdf5', 'w') as new_file:
                    photometry_group = new_file.create_group('photometry')

                    for column in columns_to_keep:
                        if column in old_photometry_group:
                            # Keep the source galaxies, then the rows belonging to this bin
                            filtered_data = old_photometry_group[column][:][src_mask][bin_indices]
                            dataset_name = "id" if column == "galaxy_id" else column
                            photometry_group.create_dataset(dataset_name, data=filtered_data)
                        else:
                            print(f"Column {column} not found in the 'photometry' group.")


In [None]:
def create_bin_file(bin_num, bin_indices, class_ids):
    """Write the row indices and class IDs of one source bin to its own HDF5 file."""
    output_file = f'output_tomo_binned_src_{bin_num}.hdf5'

    with h5py.File(output_file, 'w') as outfile:
        outfile.create_dataset('row_index', data=bin_indices)
        outfile.create_dataset('class_id', data=class_ids)

    print(f"HDF5 file '{output_file}' created successfully!")


# Source selection: galaxies fainter than the final MagLim limit
src_mask = (catalog['mag_i_lsst'] > a_final*catalog['zphot'] + b_final).to_numpy()
num_bins = 5

with h5py.File(f'{path}/roman_rubin_y1_a_test_10sig.hdf5', 'r') as old_file:
    # Navigate inside the 'photometry' group in the old file
    if 'photometry' in old_file:
        old_photometry_group = old_file['photometry']

        if 'zphot' in old_photometry_group:
            redshifts = old_photometry_group['zphot'][:][src_mask]
            num_galaxies = len(redshifts)

            sorted_indices = np.argsort(redshifts)
            sorted_redshifts = redshifts[sorted_indices]
            galaxies_per_bin = num_galaxies // num_bins

            # Loop over each bin from 1 to 5
            for bin_num in range(1, num_bins + 1):
                start_index = (bin_num - 1) * galaxies_per_bin
                end_index = start_index + galaxies_per_bin if bin_num < num_bins else num_galaxies
                bin_indices = sorted_indices[start_index:end_index]

                # assign_class_id comes from the local binning module
                class_ids = np.array([assign_class_id(z) for z in sorted_redshifts[start_index:end_index]])

                create_bin_file(bin_num, bin_indices, class_ids)
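
A cheap sanity check on the tomography files is that the five `row_index` arrays are disjoint and together cover every source galaxy exactly once. In the notebook, the arrays would be read back from the `row_index` datasets of the `output_tomo_binned_src_*.hdf5` files. The following sketch builds them from hypothetical toy data instead, and `check_partition` is an illustrative helper name.

```python
import numpy as np

def check_partition(row_index_per_bin, n_total):
    """True if the bins' row indices are disjoint and cover 0..n_total-1 exactly once."""
    all_rows = np.concatenate(row_index_per_bin)
    return all_rows.size == n_total and np.array_equal(np.sort(all_rows), np.arange(n_total))

# Hypothetical toy bins: photo-z-ordered slices of 1003 galaxies
rng = np.random.default_rng(4)
toy_order = np.argsort(rng.uniform(0.0, 3.0, size=1003))
toy_row_index = np.array_split(toy_order, 5)
```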

In [None]:
%%capture
for i in range(5):
    
    true_nz_file = f'roman_rubin_test_binning_src_{i+1}.hdf5'
    true_nz = DS.read_file('true_nz', path=true_nz_file, handle_class=TableHandle)
    
    # Create the histogram stage for the ith bin
    nz_hist = TrueNZHistogrammer.make_stage(
        name=f'true_nz_src_{i+1}',  
        hdf5_groupname='photometry',
        redshift_col='redshift',
        zmin=0.0,
        zmax=3.0,
        nzbins=301
    )
    
    
    tomo_file = f"output_tomo_binned_src_{i+1}.hdf5"  
    tomo_bins = DS.read_file('tomo_bins', path=tomo_file, handle_class=TableHandle)
    
    out_hist = nz_hist.histogram(true_nz, tomo_bins)
    
    print(f"Histogram for bin {i+1} created successfully.")
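
With `zmin=0.0`, `zmax=3.0` and `nzbins=301`, the true-redshift n(z) of a bin is essentially a histogram of the true redshifts. The following NumPy sketch assumes the histogrammer uses `nzbins + 1` equally spaced edges, which is consistent with the `x_true[:-1]` slicing in the plots. We have not checked its normalization convention, so the sketch normalizes to unit area explicitly. `true_nz_histogram` and the toy redshifts are hypothetical.

```python
import numpy as np

def true_nz_histogram(z_true, zmin=0.0, zmax=3.0, nzbins=301):
    """Unit-area histogram of true redshifts on nzbins equal-width bins.

    Returns (edges, values); edges has nzbins + 1 entries.
    """
    edges = np.linspace(zmin, zmax, nzbins + 1)
    values, _ = np.histogram(z_true, bins=edges, density=True)
    return edges, values

# Hypothetical true redshifts, clipped into [0, 3) so that every galaxy is counted
rng = np.random.default_rng(1)
toy_ztrue = rng.normal(0.8, 0.2, size=50_000).clip(0.0, 2.99)
toy_edges, toy_values = true_nz_histogram(toy_ztrue)
```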

In [None]:
# Function to read and plot results from each bin file
def plot_nz_from_bins(num_bins):
    plt.figure(figsize=(10, 6))  
    
    for bin_num in range(1, num_bins + 1):
        # Read the data from the HDF5 file
        input_file = qp.read(f'true_NZ_true_nz_src_{bin_num}.hdf5')
        
        y_true = input_file.objdata()['pdfs'][0]
        x_true = input_file.metadata()['bins'][0]

        cs = UnivariateSpline(x_true[:-1], y_true)
        cs.set_smoothing_factor(8) # adjust the smoothing of ztrue here !!!!

        plt.plot(x_true[:-1], y_true, label=f'Bin {bin_num}')
            
    plt.xlabel('Redshift (z)')
    plt.ylabel('N(z)')
    plt.title('N(z) Distribution for src sample for the true redshifts')
    plt.legend()
    plt.grid()
    plt.ylim(0, 7.5)
    plt.xlim(0, 3.0)
    plt.show()

plot_nz_from_bins(num_bins=5)
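
The smoothing applied to the true histograms relies on `scipy.interpolate.UnivariateSpline`. Its smoothing factor `s` is an upper bound on the sum of squared residuals, so a larger `s` allows fewer knots and a smoother curve, while `s=0` interpolates every point. The following sketch shows this on a hypothetical noisy histogram. It evaluates the spline at bin centers rather than left edges and clips the result, because a spline can dip below zero. The notebook's pattern of constructing the spline and then calling `set_smoothing_factor(s)` should behave similarly to passing `s` directly, although we have not checked that the two give identical knots. `smooth_histogram` is an illustrative name.

```python
import numpy as np
from scipy.interpolate import UnivariateSpline

def smooth_histogram(x, y, s):
    """Fit a smoothing spline to (x, y) and return it with its non-negative evaluation at x."""
    spline = UnivariateSpline(x, y, s=s)
    return spline, np.clip(spline(x), 0.0, None)

# Hypothetical noisy histogram on 301 bins between z = 0 and 3
rng = np.random.default_rng(2)
toy_edges = np.linspace(0.0, 3.0, 302)
toy_centers = 0.5 * (toy_edges[1:] + toy_edges[:-1])
toy_y = np.exp(-0.5 * ((toy_centers - 0.8) / 0.2) ** 2) + rng.normal(0.0, 0.05, toy_centers.size)

spline_interp, _ = smooth_histogram(toy_centers, toy_y, s=0)            # passes through every point
spline_smooth, toy_y_smooth = smooth_histogram(toy_centers, toy_y, s=1)  # residual budget of 1
```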