### usual imports


In [None]:
import os
import numpy as np
from rail.core.utils import RAILDIR
# RAIL_BPZ_DIR is set by hand in the next cell instead of being imported here.
#from rail.estimation.algos.bpz_version.utils import RAIL_BPZ_DIR
from rail.pipelines.estimation.estimate_all import EstimatePipeline
from rail.core import common_params

import qp
import ceci

import rail.stages
# Called before the wildcard import below, presumably so that every installed
# RAIL stage is attached to the rail.stages namespace and exported by `*`.
rail.stages.import_and_attach_all()
from rail.stages import *

from rail.pipelines.utils.name_factory import NameFactory, DataType, CatalogType, ModelType, PdfType
from rail.core.stage import RailStage, RailPipeline

import ceci  # NOTE(review): duplicate of the import above; harmless no-op
# Shared NameFactory instance; used later to resolve the estimator model directory.
namer = NameFactory()

In [None]:
# Location of the rail_bpz source checkout; override with $RAIL_BPZ_DIR.
RAIL_BPZ_DIR = os.environ.get('RAIL_BPZ_DIR', '../../../rail_bpz/rail_bpz/src/')
# Directory that stores the input data; override with $RAIL_DATA_DIR.
# os.path.join(..., "") guarantees a trailing separator, which the file names
# built later as f'{DATA_DIR}<file>' depend on.
DATA_DIR = os.path.join(
    os.environ.get('RAIL_DATA_DIR', "/net/home/fohlen14/yanza21/DATA/data_for_rail/"),
    "",
)

### The following two cells define the common parameters used when building the stages. These parameters are shared by all stages; parameters changed in an individual stage do not propagate back to common_params.

In [None]:
calib_file = f'{DATA_DIR}dered_223501_sz_match_pdr3_dud_NONDET_v2.hdf5'
test_file = f'{DATA_DIR}hecto_deredden_nondet.hdf5'

from rail.core.utilStages import ColumnMapper, TableConverter

bands = 'grizy'
maglims = [27.66, 27.25, 26.6, 26.24, 25.35]
if len(maglims) != len(bands):
    raise ValueError(
        f"need one magnitude limit per band: {len(bands)} bands, {len(maglims)} limits"
    )

# Build the photometry column names once so that `bands`, `mag_limits` and
# `ref_band` all refer to the same columns (the limits and the reference band
# must be keyed by entries of `bands`, otherwise they match no band at all).
# NOTE(review): assumes '{band}_cmodel_mag_dered' is the real column naming in
# the input files (it matches err_bands' '{band}_cmodel_magerr') — confirm.
band_cols = [f'{band}_cmodel_mag_dered' for band in bands]
err_cols = [f'{band}_cmodel_magerr' for band in bands]
maglim_dict = dict(zip(band_cols, maglims))

common_params.set_param_defaults(
    bands=band_cols,
    err_bands=err_cols,
    nondetect_val=np.nan,
    ref_band='i_cmodel_mag_dered',
    redshift_col='specz_redshift',
    mag_limits=maglim_dict,
    zmax=6.0,
)

from rail.core import *
RailEnv.print_rail_stage_dict()

In [None]:
pipe = EstimatePipeline()

# BPZ reads its column layout from the example .columns file under the
# rail_bpz checkout, and gets a zero-point error of 0.01 for each of 5 bands.
bpz_columns_file = os.path.join(
    RAIL_BPZ_DIR,
    'rail',
    'examples_data',
    'estimation_data',
    'configs',
    'test_bpz_hsc.columns',
)
pipe.estimate_bpz.config.update(
    zp_errors=[0.01] * 5,
    columns_file=bpz_columns_file,
)

In [None]:
# All trained estimator models live in the same NameFactory directory;
# resolve it once instead of repeating the identical lookup for every model.
model_dir = namer.get_data_dir(DataType.model, ModelType.estimator)

# Pipeline inputs: one model file per estimator, plus the catalogue passed as
# `input` and the catalogue passed as `spec_input`.
input_dict = dict(
    spec_groupname="",
    model_fzboost=os.path.join(model_dir, "model_FZBoost.hdf5"),
    model_somoclu=os.path.join(model_dir, "model_somoclu.hdf5"),
    model_bpz=os.path.join(model_dir, "model_bpz.hdf5"),
    model_simplesom=os.path.join(model_dir, "model_simplesom.hdf5"),
    model_trainz=os.path.join(model_dir, "model_trainz.pkl"),
    model_knn=os.path.join(model_dir, "model_knn.pkl"),
    input=test_file,
    spec_input=calib_file,
)


In [None]:
# Initialise the pipeline (outputs and logs in the current directory, no
# resume), write it to YAML, and read it back as a runnable ceci pipeline.
pipeline_yml = 'tmp_estimate_soms.yml'
run_config = dict(output_dir='.', log_dir='.', resume=False)
pipe_info = pipe.initialize(input_dict, run_config, None)
pipe.save(pipeline_yml)

import ceci
pr = ceci.Pipeline.read(pipeline_yml)

In [None]:
# Run the ceci pipeline that was read back from tmp_estimate_soms.yml.
pr.run()

In [None]:
import tables_io

# Register the calibration catalogue in RAIL's shared data store as a
# TableHandle. allow_overwrite is set on the store's class, presumably so that
# re-running this cell can replace the existing "calib_data" entry.
DS = RailStage.data_store
type(DS).allow_overwrite = True
calib_data = DS.add_data("calib_data", tables_io.read(calib_file), TableHandle)

In [None]:
# N(z) outputs written by the two SOM-based summarisers.
infile_somoclu = 'pdf/nz/output_somoclu.hdf5'
infile_simplesom = 'pdf/nz/output_simplesom.hdf5'

qp_somoclu, qp_simplesom = (
    qp.read(path) for path in (infile_somoclu, infile_simplesom)
)

In [None]:
def get_cont_hist(data, bins):
    """Histogram *data* as a probability density and return it with bin centres.

    Parameters
    ----------
    data : array_like
        Values to histogram.
    bins : int or sequence of scalars
        Passed unchanged to ``np.histogram``.

    Returns
    -------
    density : ndarray
        Density-normalised counts (``density=True``).
    centres : ndarray
        Mid-point of each bin; same length as ``density``.
    """
    density, edges = np.histogram(data, bins=bins, density=True)
    centres = 0.5 * (edges[:-1] + edges[1:])
    return density, centres

In [None]:
# Imported here because this cell plots before the later N(z) cell (which
# also imports matplotlib) has run; without it `plt` is undefined.
import matplotlib.pyplot as plt

# Density-normalised spec-z histogram of the calibration sample on 49 bins
# spanning [0, 3], and both summariser N(z) ensembles evaluated at the centres.
redshift_col = common_params.SHARED_PARAMS['redshift_col']
test_nz_hist, zbin = get_cont_hist(calib_data.data[redshift_col], np.linspace(0, 3, 50))
somoclu_nz_hist = np.squeeze(qp_somoclu.pdf(zbin))
simplesom_nz_hist = np.squeeze(qp_simplesom.pdf(zbin))

# FZBoost point estimates vs spec-z. `spec_data` and `qp_FZBoost` were never
# defined in this notebook, so they are loaded here. The estimators ran on
# `input=test_file`, so the spec-z must come from that file for the rows to
# line up with the FZBoost estimates.
# NOTE(review): the 'photometry' group is carried over from the original code,
# and the FZBoost output path is assumed to follow the pdf/nz/output_*.hdf5
# convention of the summariser outputs — confirm both against the run outputs.
infile_fzboost = 'pdf/pz/output_FZBoost.hdf5'
spec_data = DS.add_data("spec_data", tables_io.read(test_file), TableHandle)
qp_FZBoost = qp.read(infile_fzboost)

specz = spec_data.data['photometry']['specz_redshift']
has_specz = specz > 0  # computed once; used to mask both axes
zmode = qp_FZBoost.ancil['zmode'].reshape(-1)

plt.figure(figsize=(12, 12))
plt.plot(specz[has_specz], zmode[has_specz], '.', ms=0.2)
plt.xlabel('specz')
plt.ylabel('FZBoost')
plt.xlim(0, 6)
plt.ylim(0, 6)

In [None]:
import matplotlib.pyplot as plt

# Compare the calibration spec-z histogram with the mean N(z) of each SOM
# summariser.
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
ax.set_xlabel("redshift", fontsize=15)
ax.set_ylabel("N(z)", fontsize=15)
ax.plot(zbin, test_nz_hist, label='True N(z)')
# The summariser N(z) arrays were passed through np.squeeze, which turns a
# single-realisation ensemble into a 1-D array; np.atleast_2d restores the
# leading realisation axis so the mean is always over realisations (a plain
# mean(axis=0) on 1-D input would collapse the curve to a scalar).
summaries = (
    (somoclu_nz_hist, 'somoclu N(z)'),
    (simplesom_nz_hist, 'simplesom N(z)'),
)
for nz_hist, label in summaries:
    ax.plot(zbin, np.atleast_2d(nz_hist).mean(axis=0), label=label)
ax.legend()