In [None]:
# Progress marker so script-mode runs show when execution starts.
print("Begin script")

In [None]:
# usual imports
import os
import numpy as np
import matplotlib.pyplot as plt
from rail.core.utils import RAILDIR
from rail.core import RailStage
from rail.core.data import TableHandle
#from rail.estimation.algos.sompz_version.utils import RAIL_SOMPZ_DIR
#from rail.pipelines.estimation.estimate_all import EstimatePipeline
#from rail.core import common_params
#from rail.pipelines.utils.name_factory import NameFactory, DataType, CatalogType, ModelType, PdfType
import qp
import ceci

In [None]:
from rail.estimation.algos.sompz import SOMPZEstimator

In [None]:
import rail.estimation.algos.sompz as sompz_

In [None]:
from rail.sompz.utils import mean_of_hist

In [None]:
# Show which installed copy of the SOMPZ estimator module is actually imported.
getattr(sompz_, "__file__")

In [None]:
# Grab the shared RAIL data store and allow entries to be overwritten,
# so cells that re-register the same data tags can be re-run.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
#from rail.core.utils import find_rail_file

# Locations can be overridden via environment variables so the notebook runs
# outside the original author's NERSC account; the defaults reproduce the original run.
# change to your rail location
RAIL_SOMPZ_DIR = os.environ.get(
    "RAIL_SOMPZ_DIR", "/global/u2/j/jmyles/repositories/LSSTDESC/rail_sompz/src")

datadir = os.environ.get("SOMPZ_DATADIR", '/pscratch/sd/j/jmyles/sompz_buzzard/2024-06-24/')
datadir_run = os.path.join(datadir, 'run-2024-07-01')
outdir = datadir_run
# Create the output directory in-process: portable, no shell involved (so no
# injection via the path), and a failure raises instead of being silently ignored
# as a nonzero os.system return code would be.
os.makedirs(outdir, exist_ok=True)

# Input catalogs (earlier runs used ./datafiles/romandesc_*_data_*.hdf5 files).
testFileSpec = os.path.join(datadir, 'spec_data.h5')
testFileBalrog = os.path.join(datadir, 'balrog_data_subcatalog.h5')
testFileWide = os.path.join(datadir, 'wide_data_subsample.hdf5')

spec_data = DS.read_file("input_spec_data", TableHandle, testFileSpec)
balrog_data = DS.read_file("input_balrog_data", TableHandle, testFileBalrog)
wide_data = DS.read_file("input_wide_data", TableHandle, testFileWide)

# Model file for the SOMPZ estimator, stored alongside this run's outputs.
model_file = os.path.join(datadir_run, "DEMO_CARDINAL_model_2024-06-24.pkl")

print('Catalogs specified')

### Define metadata for SOMPZ inference

In [None]:
# Build the catalog column names fed to the SOMs.
# (An earlier configuration used bands U,G,R,I,Z,Y,J,H,K with FLUX_{band} /
#  FLUX_ERR_{band} columns: all nine for the deep sample, the first six for the
#  wide sample, and the 4th entry of each list as the reference band.)

# Deep-field bands: LSST ugriz plus VISTA Y, J, H, Ks.
bands_deep = ['lsst_u', 'lsst_g', 'lsst_r', 'lsst_i', 'lsst_z',
              'VISTA_Filters_at80K_forETC_Y', 'VISTA_Filters_at80K_forETC_J',
              'VISTA_Filters_at80K_forETC_H', 'VISTA_Filters_at80K_forETC_Ks']
# Wide-field bands ('U', 'Y', 'J', 'H', 'K' currently left out).
bands_wide = ['G', 'R', 'I', 'Z']

deepbands = [f'TRUEMAG_{name}' for name in bands_deep]
deeperrs = [f'TRUEMAG_ERR_{name}' for name in bands_deep]
zeropts = [30.] * len(bands_deep)  # one identical zero point per deep band

widebands = [f'FLUX_{name}' for name in bands_wide]
wideerrs = [f'FLUX_ERR_{name}' for name in bands_wide]

# Reference band is i in both samples.
refband_deep = deepbands[3]
refband_wide = widebands[2]

In [None]:
# Configuration shared by the SOMPZ estimator stage: input columns for the deep
# and wide samples, flux conversion/threshold options, wide-SOM settings, and the
# spec-z column name.
sompz_params = {
    'inputs_deep': deepbands,
    'input_errs_deep': deeperrs,
    'zero_points_deep': zeropts,
    'inputs_wide': widebands,
    'input_errs_wide': wideerrs,
    'convert_to_flux_deep': True,
    'convert_to_flux_wide': False,
    'set_threshold_deep': True,
    'thresh_val_deep': 1.e-5,
    'som_shape_wide': (32, 32),
    'som_minerror_wide': 0.005,
    'som_take_log_wide': False,
    'som_wrap_wide': False,
    'specz_name': 'Z',
    'debug': False,
}

In [None]:
# bands = 'grizy'
# maglims = [27.66, 27.25, 26.6, 26.24, 25.35]
# maglim_dict={}
# for band,limx in zip(bands, maglims):
#     maglim_dict[f"HSC{band}_cmodel_dered"] = limx

### Prepare and run SOMPZ Estimation

In [None]:
print('make stage')
# Build the estimator stage. Spec and Balrog tables are read from the "key"
# group; the wide table uses an empty group name (presumably the file's top
# level — confirm against the wide file layout).
som_estimate = SOMPZEstimator.make_stage(
    name="cardinal_som_estimator",
    spec_groupname="key",
    balrog_groupname="key",
    wide_groupname="",  # "key"
    model=model_file,
    # data_path=outdir,  # TODO enable setting outdir for output files
    **sompz_params
)

In [None]:
# Peek at the first rows of the spectroscopic table.
spec_data.data["key"].head(n=5)

In [None]:
#som_estimate.estimate?

In [None]:
print('estimate')
# Run the SOMPZ estimation on the spec, Balrog and wide samples (argument order as in the original call).
output = som_estimate.estimate(
    spec_data,
    balrog_data,
    wide_data,
)

In [None]:
# Persist the estimator's return value next to the run outputs.
# NOTE(review): if `output` is not a plain ndarray, np.save stores it as a pickled
# object array, and reading it back needs np.load(..., allow_pickle=True).
outfile = os.path.join(outdir, 'output.npy')
np.save(file=outfile, arr=output)

In [None]:
# Progress marker between the estimation and plotting sections.
print("Finished Estimation. Proceed to plotting.")

## Display $n(z)$

In [None]:
# Load the n(z) ensemble written by the estimator stage directly with qp.
# TODO make the hdf5 go in a dir we want
qp_file = os.path.join(
    RAIL_SOMPZ_DIR,
    '../examples/nz_cardinal_som_estimator.hdf5',
)
qp_single_nz_sompz = qp.read(qp_file)

In [None]:
# Evaluate every n(z) in the ensemble on a uniform redshift grid spanning [0, 6].
nbins = 600
z_grid = np.linspace(start=0, stop=6, num=nbins)
nz_sompz_grid = qp_single_nz_sompz.pdf(z_grid)

In [None]:
# Failed spectroscopic measurements carry z = -99; keep only positive redshifts.
spec_table = spec_data.data['key']
has_valid_z = spec_table['Z'] > 0.0
specz_good = spec_table[has_valid_z]['Z']

### Compare the SOMPZ $n(z)$ estimates with the spec-z calibration sample

In [None]:
# Overlay each tomographic-bin n(z) estimate on the spec-z calibration histogram.
# Bins alternate between the two panels (bins 1, 3, ... on top; 2, 4, ... below).
colors = ['tab:blue', 'tab:orange', 'tab:red', 'tab:green']
# atleast_2d guards against a single-pdf ensemble evaluating to a 1-D array.
nz_per_bin = np.atleast_2d(nz_sompz_grid)

fig, axarr = plt.subplots(2, 1, figsize=(12, 9))

# Draw the calibration histogram exactly once per panel. It used to be redrawn for
# every bin, which stacked two copies per panel and made it darker than intended.
for ax in axarr:
    ax.hist(specz_good, density=True, bins=nbins, histtype='step',
            label='SOMPZ spec-z calibration sample', color='k', alpha=0.25)

# Iterate over the bins actually present rather than assuming exactly four;
# colors cycle if there are more bins than colors.
for i, nz in enumerate(nz_per_bin):
    axarr[i % 2].plot(z_grid, nz, label=f'SOMPZ Estimate -- Bin {i+1}',
                      color=colors[i % len(colors)])

# Limits, labels and legends are per-panel settings: apply them once, after plotting.
for ax in axarr:
    ax.set_xlim(0, 3)
    ax.set_ylim(0, 4)
    ax.set_ylabel('n(z)')
    ax.legend()
axarr[1].set_xlabel('z')
fig.suptitle('SOMPZ n(z) estimates vs. spec-z calibration sample')

outfile = os.path.join(outdir, 'nz_sompz_est_script.png')
fig.savefig(outfile, dpi=150)