From b5eff6d0b6aee02fcfcc63c0903c4283c557d700 Mon Sep 17 00:00:00 2001 From: Drew Oldag Date: Thu, 27 Jul 2023 15:04:41 -0700 Subject: [PATCH] Introduce a new config parameters, `calculated_point_estimates`, in `SHARED_PARAMS` that will be used by subclasses of `CatEstimator`. --- src/rail/core/common_params.py | 4 +++- src/rail/estimation/estimator.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/rail/core/common_params.py b/src/rail/core/common_params.py index 985050b5..2113624f 100644 --- a/src/rail/core/common_params.py +++ b/src/rail/core/common_params.py @@ -27,7 +27,9 @@ err_bands=Param(list, lsst_mag_err_cols, msg="Names of columns for magnitgude errors by filter band"), mag_limits=Param(dict, lsst_def_maglims, msg="Limiting magnitdues by filter"), ref_band=Param(str, "mag_i_lsst", msg="band to use in addition to colors"), - redshift_col=Param(str, 'redshift', msg="name of redshift column") + redshift_col=Param(str, 'redshift', msg="name of redshift column"), + calculated_point_estimates=Param(dtype=list, default=[], + msg="List of strings defining which point estimates to automatically calculate using `qp.Ensemble`. Options include, 'mean', 'mode', 'median'.") ) diff --git a/src/rail/estimation/estimator.py b/src/rail/estimation/estimator.py index 360b8c44..0f4edf2a 100644 --- a/src/rail/estimation/estimator.py +++ b/src/rail/estimation/estimator.py @@ -2,8 +2,10 @@ Abstract base classes defining redshift estimations Informers and Estimators """ +from rail.core.common_params import SHARED_PARAMS from rail.core.data import TableHandle, QPHandle, ModelHandle from rail.core.stage import RailStage + import gc class CatEstimator(RailStage): @@ -19,7 +21,10 @@ class CatEstimator(RailStage): name = 'CatEstimator' config_options = RailStage.config_options.copy() - config_options.update(chunk_size=10000, hdf5_groupname=str) + config_options.update( + chunk_size=10000, + hdf5_groupname=SHARED_PARAMS['hdf5_groupname'], + calculated_point_estimates=SHARED_PARAMS['calculated_point_estimates']) inputs = [('model', ModelHandle), ('input', TableHandle)] outputs = [('output', QPHandle)]