-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from LSSTDESC/issue/10/tomo_bins
Issue/10/tomo bins
- Loading branch information
Showing
5 changed files
with
415 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
A classifier that uses pz point estimate to assign | ||
tomographic bins containing equal numbers of objects.
""" | ||
|
||
import numpy as np | ||
from ceci.config import StageParameter as Param | ||
from rail.estimation.classifier import PZClassifier | ||
from rail.core.data import TableHandle | ||
|
||
class EqualCountClassifier(PZClassifier):
    """Assign tomographic bins that hold (nearly) equal numbers of objects.

    Objects are ranked by a photo-z point estimate taken from the ancillary
    data of the input ensemble.  The ranked list is then cut into ``nbins``
    consecutive groups, which are labelled ``1 .. nbins``.

    The ``zmin``, ``zmax`` and ``no_assign`` options are declared for parity
    with the other classifiers but are not read by :meth:`run`.
    """

    name = 'EqualCountClassifier'
    config_options = PZClassifier.config_options.copy()
    config_options.update(
        id_name=Param(str, "", msg="Column name for the object ID in the input data, if empty the row index is used as the ID."),
        point_estimate=Param(str, 'zmode', msg="Which point estimate to use"),
        zmin=Param(float, 0.0, msg="Minimum redshift of the sample"),
        zmax=Param(float, 3.0, msg="Maximum redshift of the sample"),
        nbins=Param(int, 5, msg="Number of tomographic bins"),
        no_assign=Param(int, -99, msg="Value for no assignment flag"),
    )
    outputs = [('output', TableHandle)]

    def __init__(self, args, comm=None):
        """Initialize the classifier by delegating to `PZClassifier`."""
        PZClassifier.__init__(self, args, comm=comm)

    def run(self):
        """Compute the bin index of every object and register it as 'output'.

        The output is a dict with two entries: the object IDs (keyed by
        ``id_name``, or ``"row_index"`` when ``id_name`` is empty) and
        ``"class_id"``, an integer array of bin labels in ``1 .. nbins``.

        Raises
        ------
        KeyError
            If the configured point estimate is missing from the ancillary
            data of the input.
        """
        test_data = self.get_data('input')
        npdf = test_data.npdf

        try:
            zb = test_data.ancil[self.config.point_estimate]
        except KeyError as err:
            raise KeyError(f"{self.config.point_estimate} is not contained in the data ancil, you will need to compute it explicitly.") from err

        # NOTE(review): the ranking below assumes zb is one-dimensional with
        # one entry per object -- confirm against the ensemble's ancil layout.
        nobj = len(zb)
        nbins = self.config.nbins
        sortind = np.argsort(zb)

        # Integer labels, consistent with the integer output of
        # UniformBinningClassifier (np.zeros defaults to float64).
        bin_index = np.zeros(nobj, dtype=int)

        if nobj > 0:
            # Cumulative fraction of the ranked sample; it does not depend on
            # the bin, so compute it once.  Guarded so that an empty sample
            # yields an empty assignment instead of failing.
            cum_frac = np.arange(1, nobj + 1) / nobj
            for ii in range(nbins):
                lower = ii / nbins
                upper = (ii + 1) / nbins
                in_bin = (cum_frac > lower) & (cum_frac <= upper)
                bin_index[sortind[in_bin]] = ii + 1

        # TODO: read the IDs from test_data[self.config.id_name] once the
        # input data carries an ID column; until then the row index is used
        # whether or not id_name is set.
        obj_id = np.arange(npdf)
        if self.config.id_name == "":
            self.config.id_name = "row_index"

        class_id = {self.config.id_name: obj_id, "class_id": bin_index}
        self.add_data('output', class_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
A classifier that uses pz point estimate to assign | ||
tomographic bins with uniform binning. | ||
""" | ||
|
||
import numpy as np | ||
from ceci.config import StageParameter as Param | ||
from rail.estimation.classifier import PZClassifier | ||
from rail.core.data import TableHandle | ||
|
||
class UniformBinningClassifier(PZClassifier):
    """Assign tomographic bins from a point estimate using fixed bin edges.

    The edges are either given explicitly through ``zbin_edges`` or, when
    that list has fewer than two entries, spaced linearly between ``zmin``
    and ``zmax`` to form ``nbins`` bins.  Objects that fall outside the
    edges receive the ``no_assign`` value.
    """

    name = 'UniformBinningClassifier'
    config_options = PZClassifier.config_options.copy()
    config_options.update(
        id_name=Param(str, "", msg="Column name for the object ID in the input data, if empty the row index is used as the ID."),
        point_estimate=Param(str, 'zmode', msg="Which point estimate to use"),
        zbin_edges=Param(list, [], msg="The tomographic redshift bin edges. If this is given (contains two or more entries), all settings below will be ignored."),
        zmin=Param(float, 0.0, msg="Minimum redshift of the sample"),
        zmax=Param(float, 3.0, msg="Maximum redshift of the sample"),
        nbins=Param(int, 5, msg="Number of tomographic bins"),
        no_assign=Param(int, -99, msg="Value for no assignment flag"),
    )
    outputs = [('output', TableHandle)]

    def __init__(self, args, comm=None):
        """Initialize the classifier by delegating to `PZClassifier`."""
        PZClassifier.__init__(self, args, comm=comm)

    def run(self):
        """Compute the bin index of every object and register it as 'output'.

        The output is a dict with two entries: the object IDs (keyed by
        ``id_name``, or ``"row_index"`` when ``id_name`` is empty) and
        ``"class_id"``, holding the 1-based bin index returned by
        ``np.digitize`` or ``no_assign`` for objects outside the edges.

        Raises
        ------
        KeyError
            If the configured point estimate is missing from the ancillary
            data of the input.
        """
        test_data = self.get_data('input')
        npdf = test_data.npdf

        try:
            zb = test_data.ancil[self.config.point_estimate]
        except KeyError as err:
            raise KeyError(f"{self.config.point_estimate} is not contained in the data ancil, you will need to compute it explicitly.") from err

        if len(self.config.zbin_edges) >= 2:
            # Explicit edges take precedence over zmin, zmax and nbins.
            edges = self.config.zbin_edges
        else:
            # Linear binning defined by zmin, zmax and nbins.
            edges = np.linspace(self.config.zmin, self.config.zmax, self.config.nbins + 1)

        bin_index = np.digitize(zb, edges)
        # np.digitize returns 0 below the first edge and len(edges) at or
        # above the last one; neither is a valid bin.  The mask is built
        # before anything is overwritten.
        outside = (bin_index == 0) | (bin_index == len(edges))
        bin_index[outside] = self.config.no_assign

        # TODO: read the IDs from test_data[self.config.id_name] once the
        # input data carries an ID column; until then the row index is used
        # whether or not id_name is set.
        obj_id = np.arange(npdf)
        if self.config.id_name == "":
            self.config.id_name = "row_index"

        class_id = {self.config.id_name: obj_id, "class_id": bin_index}
        self.add_data('output', class_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
""" | ||
Abstract base classes defining classifiers. | ||
""" | ||
from rail.core.data import QPHandle, TableHandle, ModelHandle | ||
from rail.core.stage import RailStage | ||
|
||
|
||
class CatClassifier(RailStage):  #pragma: no cover
    """Base class for assigning classes to a catalogue-like table.

    The classifier uses a generic "model", the details of which depend on
    the sub-class.  A CatClassifier takes a catalogue-like table as "input",
    assigns each object to a tomographic bin, and provides tabular data as
    "output" that can be appended to the catalogue.
    """

    name = 'CatClassifier'
    config_options = RailStage.config_options.copy()
    config_options.update(chunk_size=10000, hdf5_groupname=str)
    inputs = [('model', ModelHandle),
              ('input', TableHandle)]
    outputs = [('output', TableHandle)]

    def __init__(self, args, comm=None):
        """Initialize the classifier and attach the model named in `args`."""
        RailStage.__init__(self, args, comm=comm)
        self._output_handle = None
        self.model = None
        # open_model takes keyword arguments, so a namespace-like `args`
        # is converted to a dict first.
        if not isinstance(args, dict):  #pragma: no cover
            args = vars(args)
        self.open_model(**args)

    def open_model(self, **kwargs):
        """Load the model and/or attach it to this Classifier.

        Parameters
        ----------
        model : `object`, `str` or `ModelHandle`
            Either an object with a trained model,
            a path pointing to a file that can be read to obtain the trained model,
            or a `ModelHandle` providing access to the trained model.
            ``None`` and the string ``'None'`` both mean "no model".

        Returns
        -------
        self.model : `object`
            Whatever `set_data` returns for the 'model' tag, or ``None``
            when no model was given.
        """
        model = kwargs.get('model', None)
        if model is None:
            self.model = None
            return self.model
        if isinstance(model, str):
            # The 'None' sentinel is only compared against real strings;
            # `==` on an arbitrary model object (e.g. an array) may not
            # produce a plain bool.
            if model == 'None':
                self.model = None
                return self.model
            self.model = self.set_data('model', data=None, path=model)
            self.config['model'] = model
            return self.model
        if isinstance(model, ModelHandle) and model.has_path:
            self.config['model'] = model.path
        # Handles and bare model objects are both attached directly.
        self.model = self.set_data('model', model)
        return self.model

    def classify(self, input_data):
        """Attach the input, run the classifier and return the output handle.

        The input is attached to this `CatClassifier` through
        `set_data('input', ...)`, then `run()` and `finalize()` are called.
        `run()` is not defined in this class; sub-classes are expected to
        provide it and to register their result with
        `self.add_data('output', output_data)`.

        Parameters
        ----------
        input_data : `dict`
            A dictionary of all input data.

        Returns
        -------
        output
            The handle registered under the 'output' tag (declared as a
            `TableHandle` in `outputs`), giving access to the class
            assignment of each galaxy.
        """
        self.set_data('input', input_data)
        self.run()
        self.finalize()
        return self.get_handle('output')
|
||
|
||
|
||
class PZClassifier(RailStage):
    """Base class for assigning classes (tomographic bins) to per-galaxy PZ estimates.

    A PZClassifier takes a `qp.Ensemble` with per-galaxy PDFs as "input" and
    provides tabular data as "output" that can be appended to the catalogue.
    """

    name = 'PZClassifier'
    config_options = RailStage.config_options.copy()
    config_options.update(chunk_size=10000)
    inputs = [('input', QPHandle)]
    outputs = [('output', TableHandle)]

    def __init__(self, args, comm=None):
        """Initialize the classifier by delegating to `RailStage`."""
        RailStage.__init__(self, args, comm=comm)

    def classify(self, input_data):
        """Attach the input, run the classifier and return the output handle.

        The input is attached to this `PZClassifier` through
        `set_data('input', ...)`, then `run()` and `finalize()` are called.
        `run()` is not defined in this class; sub-classes are expected to
        provide it and to register their result with
        `self.add_data('output', output_data)`.

        Parameters
        ----------
        input_data : `qp.Ensemble`
            Per-galaxy p(z), and any ancillary data associated with it.

        Returns
        -------
        output
            The handle registered under the 'output' tag (declared as a
            `TableHandle` in `outputs`), giving access to the class
            assignment of each galaxy.
        """
        self.set_data('input', input_data)
        self.run()
        self.finalize()
        output_handle = self.get_handle('output')
        return output_handle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.