Skip to content

Commit

Permalink
Merge pull request #29 from LSSTDESC/issue/11/point-estimate-stage-co…
Browse files Browse the repository at this point in the history
…nfig-value

New config param to define point estimates to calculate when running a `CatEstimator` stage
  • Loading branch information
jfcrenshaw committed Aug 3, 2023
2 parents d3c996b + b5eff6d commit 0c27549
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/rail/core/common_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
err_bands=Param(list, lsst_mag_err_cols, msg="Names of columns for magnitgude errors by filter band"),
mag_limits=Param(dict, lsst_def_maglims, msg="Limiting magnitdues by filter"),
ref_band=Param(str, "mag_i_lsst", msg="band to use in addition to colors"),
redshift_col=Param(str, 'redshift', msg="name of redshift column")
redshift_col=Param(str, 'redshift', msg="name of redshift column"),
calculated_point_estimates=Param(dtype=list, default=[],
msg="List of strings defining which point estimates to automatically calculate using `qp.Ensemble`. Options include, 'mean', 'mode', 'median'.")
)


Expand Down
7 changes: 6 additions & 1 deletion src/rail/estimation/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Abstract base classes defining Estimators of individual galaxy redshift uncertainties
"""

from rail.core.common_params import SHARED_PARAMS
from rail.core.data import TableHandle, QPHandle, ModelHandle
from rail.core.stage import RailStage

import gc

from rail.estimation.informer import CatInformer
Expand All @@ -22,7 +24,10 @@ class CatEstimator(RailStage):

name = 'CatEstimator'
config_options = RailStage.config_options.copy()
config_options.update(chunk_size=10000, hdf5_groupname=str)
config_options.update(
chunk_size=10000,
hdf5_groupname=SHARED_PARAMS['hdf5_groupname'],
calculated_point_estimates=SHARED_PARAMS['calculated_point_estimates'])
inputs = [('model', ModelHandle),
('input', TableHandle)]
outputs = [('output', QPHandle)]
Expand Down

0 comments on commit 0c27549

Please sign in to comment.