Skip to content

Commit

Permalink
add predict_catboost_probabilities process
Browse files Browse the repository at this point in the history
  • Loading branch information
JeroenVerstraelen committed May 10, 2022
1 parent 0655889 commit 474f92c
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ def apply_dimension(
process = PGNode.to_process_graph_argument(callback_process_node)
elif code or process:
# TODO EP-3555 unify `code` and `process`
process = self._get_callback(code or process, parent_parameters=["data"])
process = self._get_callback(code or process, parent_parameters=["data", "context"])
else:
raise OpenEoClientException("No UDF code or process given")
arguments = {
Expand Down Expand Up @@ -1902,6 +1902,27 @@ def predict_catboost(self, model: Union[str, RESTJob, MlModel], dimension: str =
reducer = lambda data, context: process('predict_catboost', data=data, model=context)
return self.reduce_dimension(dimension=dimension, reducer=reducer, context=model)

@openeo_process(mode="apply_dimension")
def predict_catboost_probabilities(self, model: Union[str, RESTJob, MlModel], dimension: str = "bands"):
"""
Apply ``apply_dimension`` with `predict_catboost` as process.
:param model: a reference to a trained model, one of
- a :py:class:`MlModel` instance (e.g. loaded from :py:meth:`Connection.load_ml_model`)
- a :py:class:`RESTJob` instance of a batch job that saved a single random forest model
- a job id (``str``) of a batch job that saved a single random forest model
- a STAC item URL (``str``) to load the random forest from.
(The STAC Item must implement the `ml-model` extension.)
:param dimension: dimension along which to apply the ``reduce_dimension`` process.
.. versionadded:: 0.10.1
"""
if not isinstance(model, MlModel):
model = MlModel.load_ml_model(connection=self.connection, id=model)
from openeo.processes import process
p = lambda data, context: process('predict_catboost_probabilities', data=data, model=context)
return self.apply_dimension(dimension=dimension, process=p, context=model)

@openeo_process
def dimension_labels(self, dimension: str) -> "DataCube":
"""
Expand Down

0 comments on commit 474f92c

Please sign in to comment.