In [None]:
# usual imports
import os
import numpy as np
import matplotlib.pyplot as plt
from rail.core.utils import RAILDIR
from rail.core import RailStage
from rail.core.data import TableHandle
#from rail.estimation.algos.sompz_version.utils import RAIL_SOMPZ_DIR
#from rail.pipelines.estimation.estimate_all import EstimatePipeline
#from rail.core import common_params
#from rail.pipelines.utils.name_factory import NameFactory, DataType, CatalogType, ModelType, PdfType
import qp
import ceci

In [None]:
from rail.estimation.algos.sompz import SOMPZEstimator

In [None]:
# Keep a short alias to the data store shared by all RAIL stages, and switch on
# its `allow_overwrite` flag at class level (not just on this instance).
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
# Locations of the rail_sompz source tree and of the demo input catalogs.
# The fallback strings are the original author's paths. To point the notebook
# at your own copies either edit the fallbacks below or export the environment
# variables RAIL_SOMPZ_DIR / SOMPZ_DATA_PATH before starting the kernel.
RAIL_SOMPZ_DIR = os.environ.get(
    "RAIL_SOMPZ_DIR",
    "/global/u2/j/jmyles/repositories/LSSTDESC/rail_sompz/src",
)
data_path = os.environ.get(
    "SOMPZ_DATA_PATH",
    '/global/cfs/projectdirs/des/jmyles/sompz_desc/',
)

In [None]:
#from rail.core.utils import find_rail_file
# Reuse the directory chosen in the configuration cell (`data_path`) instead of
# repeating the hard-coded string, so changing the location there really
# changes which files are loaded here.
datadir = data_path

# Input catalogs. Alternative demo files noted by the original author:
#   ./datafiles/romandesc_deep_data_3700.hdf5 (spec / balrog)
#   ./datafiles/romandesc_wide_data_5000.hdf5 (wide)
testFileSpec = os.path.join(datadir, 'spec_data.h5')
testFileBalrog = os.path.join(datadir, 'balrog_data.h5')
testFileWide = os.path.join(datadir, 'wide_data.h5')

# Load each file through the data store as a TableHandle registered under the
# given name.
spec_data = DS.read_file("input_spec_data", TableHandle, testFileSpec)
balrog_data = DS.read_file("input_balrog_data", TableHandle, testFileBalrog)
wide_data = DS.read_file("input_wide_data", TableHandle, testFileWide)

### Define metadata for SOMPZ inference

In [None]:
# Band labels used to build the catalog column names.
bands = ['U','G','R','I','Z','Y','J','H','K']

# Deep sample: one flux column, one flux-error column and one zero point (30)
# for every band.
deepbands = [f'FLUX_{band}' for band in bands]
deeperrs = [f'FLUX_ERR_{band}' for band in bands]
zeropts = [30.] * len(bands)

# Wide sample: same column naming, restricted to the first six bands.
widebands = [f'FLUX_{band}' for band in bands[:6]]
wideerrs = [f'FLUX_ERR_{band}' for band in bands[:6]]

# Reference band for both samples is the entry at index 3 ('FLUX_I').
refband_deep = deepbands[3]
refband_wide = widebands[3]

In [None]:
# Options that are later unpacked into SOMPZEstimator.make_stage(**sompz_params).
sompz_params = {
    # deep-sample columns and zero points
    'inputs_deep': deepbands,
    'input_errs_deep': deeperrs,
    'zero_points_deep': zeropts,
    # wide-sample columns
    'inputs_wide': widebands,
    'input_errs_wide': wideerrs,
    # the inputs are already FLUX_* columns, so no conversion is requested
    'convert_to_flux_deep': False,
    'convert_to_flux_wide': False,
    'set_threshold_deep': True,
    'thresh_val_deep': 1.e-5,
    # wide SOM configuration
    'som_shape_wide': (32, 32),
    'som_minerror_wide': 0.005,
    'som_take_log_wide': False,
    'som_wrap_wide': False,
    # name of the redshift column
    'specz_name': 'Z',
    'debug': False,
}

In [None]:
# bands = 'grizy'
# maglims = [27.66, 27.25, 26.6, 26.24, 25.35]
# maglim_dict={}
# for band,limx in zip(bands, maglims):
#     maglim_dict[f"HSC{band}_cmodel_dered"] = limx

### Prepare and run SOMPZ Estimation

In [None]:
# Build the estimator stage. The *_groupname options are all set to "key",
# the same group name used when the spectroscopic table is accessed as
# spec_data.data['key'] further down; the remaining options come from
# `sompz_params`.
som_estimate = SOMPZEstimator.make_stage(
    name="sompz_estimator",
    spec_groupname="key",
    balrog_groupname="key",
    wide_groupname="key",
    model="DEMO_CARDINAL_model.pkl",
    data_path='./',
    **sompz_params,
)

In [None]:
# Quick look at the spectroscopic table stored under the 'key' group
# (assumes a pandas-like table whose .head() shows the first rows — confirm).
spec_data.data['key'].head()

In [None]:
# Run the estimator on the spectroscopic, Balrog and wide catalogs.
# NOTE(review): `output` presumably holds the handles listed in the
# commented-out cell below (nz, assignments, pz_c, ...) — confirm.
output = som_estimate.estimate(spec_data, balrog_data, wide_data)

In [None]:
# output = {
#     'nz': som_estimate.get_handle("nz"),
#     'spec_data_deep_assignment': som_estimate.get_handle("spec_data_deep_assignment"),
#     'balrog_data_deep_assignment': som_estimate.get_handle("balrog_data_deep_assignment"),
#     'wide_data_assignment': som_estimate.get_handle("wide_data_assignment"), 
#     'pz_c': som_estimate.get_handle("pz_c"), 
#     'pz_chat': som_estimate.get_handle("pz_chat"), 
#     'pc_chat': som_estimate.get_handle("pc_chat"), 
# }

## Display $n(z)$

In [None]:
# Read the n(z) ensemble written by the estimator directly from its hdf5 file
# with qp.
# The file name follows the "<output tag>_<stage name>.hdf5" pattern, so it has
# to match the name="sompz_estimator" given to make_stage; the previous
# 'nz_som_estimator.hdf5' did not, and would load a stale file or fail.
qp_single_nz_sompz = qp.read('./nz_sompz_estimator.hdf5')

In [None]:
# Evaluate the SOMPZ n(z) on 600 evenly spaced redshift values spanning [0, 6].
z_grid = np.linspace(start=0, stop=6, num=600)
nz_sompz_grid = qp_single_nz_sompz.pdf(z_grid)

In [None]:
# Some spectroscopic entries failed and carry the flag value z = -99.
# Keep only the objects with a positive redshift.
spec_z_all = spec_data.data['key']['Z']
specz_good = spec_z_all[spec_z_all > 0.0]

### Make a plot to compare the SOMPZ $n(z)$ estimate with the spec-z distribution

In [None]:
# Overlay the SOMPZ n(z) curve (first row of the evaluated ensemble) on the
# density-normalised histogram of the good spectroscopic redshifts.
fig, ax = plt.subplots()
ax.plot(z_grid, nz_sompz_grid[0], label='SOMPZ')
ax.hist(specz_good, density=True, bins=600, histtype='step', label='Spec-z')
ax.set_xlim(-0.1, 3)
ax.set_ylim(0, 1.2)
ax.set_xlabel('z')
ax.set_ylabel('n(z)')
ax.legend()