Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `MultiResult` helper class to build process graphs with multiple result nodes ([#391](https://github.com/Open-EO/openeo-python-client/issues/391))

### Changed

### Removed
Expand Down
9 changes: 9 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ openeo.rest.mlmodel
:inherited-members:


openeo.rest.multiresult
-----------------------

.. automodule:: openeo.rest.multiresult
:members: MultiResult
:inherited-members:
:special-members: __init__


openeo.metadata
----------------

Expand Down
52 changes: 52 additions & 0 deletions docs/datacube_construction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,55 @@ Re-parameterization
```````````````````

TODO



.. _multi-result-process-graphs:

Building process graphs with multiple result nodes
===================================================

.. note::
    Multi-result support was added in version 0.35.0.

Most openEO use cases are just about building a single result data cube,
which is readily covered in the openEO Python client library through classes like
:py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube`.
It is straightforward to create a batch job from these, or execute/download them synchronously.

The openEO API also allows multiple result nodes in a single process graph,
for example to persist intermediate results or produce results in different output formats.
To support this, the openEO Python client library provides the :py:class:`~openeo.rest.multiresult.MultiResult` class,
which makes it possible to group multiple :py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube` objects
in a single entity that can be used to create or run batch jobs. For example:


.. code-block:: python

from openeo import MultiResult

cube1 = ...
cube2 = ...
multi_result = MultiResult([cube1, cube2])
job = multi_result.create_job()


Moreover, it is not necessary to explicitly create such a
:py:class:`~openeo.rest.multiresult.MultiResult` object,
as the :py:meth:`Connection.create_job() <openeo.rest.connection.Connection.create_job>` method
directly supports passing multiple data cube objects in a list,
which will be automatically grouped as a multi-result:

.. code-block:: python

cube1 = ...
cube2 = ...
job = connection.create_job([cube1, cube2])


.. important::

Only a single :py:class:`~openeo.rest.connection.Connection` can be in play
when grouping multiple results like this.
As everything is to be merged in a single process graph
to be sent to a single backend,
it is not possible to mix cubes created from different connections.
1 change: 1 addition & 0 deletions openeo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class BaseOpenEoException(Exception):
from openeo.rest.datacube import UDF, DataCube
from openeo.rest.graph_building import collection_property
from openeo.rest.job import BatchJob, RESTJob
from openeo.rest.multiresult import MultiResult


def client_version() -> str:
Expand Down
46 changes: 41 additions & 5 deletions openeo/internal/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import abc
import collections
import copy
import json
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from openeo.api.process import Parameter
from openeo.internal.process_graph_visitor import (
Expand Down Expand Up @@ -243,7 +244,7 @@ def walk(x) -> Iterator[PGNode]:
yield from walk(self.arguments)


def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]:
def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin], Any]) -> Dict[str, dict]:
"""
    Convert given object to an internal flat dict graph representation.
"""
Expand All @@ -252,12 +253,15 @@ def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, di
# including `{"process_graph": {nodes}}` ("process graph")
# or just the raw process graph nodes?
if isinstance(x, dict):
# Assume given dict is already a flat graph representation
return x
elif isinstance(x, FlatGraphableMixin):
return x.flat_graph()
elif isinstance(x, (str, Path)):
# Assume a JSON resource (raw JSON, path to local file, JSON url, ...)
return load_json_resource(x)
elif isinstance(x, (list, tuple)) and all(isinstance(i, FlatGraphableMixin) for i in x):
return MultiLeafGraph(x).flat_graph()
raise ValueError(x)


Expand Down Expand Up @@ -322,20 +326,29 @@ def generate(self, process_id: str):

class GraphFlattener(ProcessGraphVisitor):

def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None, multi_input_mode: bool = False):
    """
    :param node_id_generator: generator of unique node ids for the flat graph
        (a fresh default one is created when not given)
    :param multi_input_mode: allow consuming multiple process graphs
        through repeated :py:meth:`flatten` calls (multi-result support)
    """
    super().__init__()
    # Flat-graph state accumulated across `accept_node` calls.
    self._flattened: Dict[str, dict] = {}
    self._argument_stack = []
    self._node_cache = {}
    self._last_node_id = None
    self._node_id_generator = node_id_generator or FlatGraphNodeIdGenerator()
    self._multi_input_mode = multi_input_mode

def flatten(self, node: PGNode) -> Dict[str, dict]:
    """Consume given nested process graph and return flat dict representation"""
    # Guard against accidentally merging unrelated graphs:
    # consuming a second graph is only allowed in multi-input mode.
    already_used = bool(self._flattened)
    if already_used and not self._multi_input_mode:
        raise RuntimeError("Flattening multiple graphs, but not in multi-input mode")
    self.accept_node(node)
    assert len(self._argument_stack) == 0
    # In multi-input mode the "result" flag is deferred to a final
    # `flattened()` call; in single-input mode set it right away.
    return self.flattened(set_result_flag=not self._multi_input_mode)

def flattened(self, set_result_flag: bool = True) -> Dict[str, dict]:
    """
    Return (a deep copy of) the flat graph accumulated so far.

    :param set_result_flag: whether to mark the most recently visited
        node with ``"result": True``
    """
    snapshot = copy.deepcopy(self._flattened)
    if not set_result_flag:
        return snapshot
    # TODO #583 an "end" node is not necessarily a "result" node
    snapshot[self._last_node_id]["result"] = True
    return snapshot

def accept_node(self, node: PGNode):
# Process reused nodes only first time and remember node id.
Expand Down Expand Up @@ -438,3 +451,26 @@ def _process_from_parameter(self, name: str) -> Any:
if name not in self._parameters:
raise ProcessGraphVisitException("No substitution value for parameter {p!r}.".format(p=name))
return self._parameters[name]


class MultiLeafGraph(FlatGraphableMixin):
    """
    Container for process graphs with multiple leaf/result nodes.

    Groups multiple flat-graph-able objects (``PGNode`` instances or
    objects exposing ``from_node()``) so they can be flattened together
    into one process graph.
    """

    __slots__ = ["_leaves"]

    def __init__(self, leaves: Iterable[FlatGraphableMixin]):
        # Materialize once so the graph can be flattened repeatedly.
        self._leaves = list(leaves)

    def flat_graph(self) -> Dict[str, dict]:
        """Flatten all leaves into a single flat dict graph representation."""
        flattener = GraphFlattener(multi_input_mode=True)
        for leaf in self._leaves:
            # Normalize each leaf to a PGNode before flattening.
            if isinstance(leaf, PGNode):
                node = leaf
            elif isinstance(leaf, _FromNodeMixin):
                node = leaf.from_node()
            else:
                raise ValueError(f"Unsupported type {type(leaf)}")
            flattener.flatten(node)
        return flattener.flattened(set_result_flag=True)
81 changes: 78 additions & 3 deletions openeo/rest/_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import collections
import json
import re
from typing import Callable, Iterator, Optional, Sequence, Union
from typing import Callable, Iterable, Iterator, Optional, Sequence, Tuple, Union

from openeo import Connection, DataCube
from openeo.rest.vectorcube import VectorCube
Expand All @@ -19,8 +21,12 @@ class DummyBackend:
and allows inspection of posted process graphs
"""

# TODO: move to openeo.testing

__slots__ = (
"_requests_mock",
"connection",
"file_formats",
"sync_requests",
"batch_jobs",
"validation_requests",
Expand All @@ -33,8 +39,14 @@ class DummyBackend:
# Default result (can serve both as JSON or binary data)
DEFAULT_RESULT = b'{"what?": "Result data"}'

def __init__(self, requests_mock, connection: Connection):
def __init__(
self,
requests_mock,
connection: Connection,
):
self._requests_mock = requests_mock
self.connection = connection
self.file_formats = {"input": {}, "output": {}}
self.sync_requests = []
self.batch_jobs = {}
self.validation_requests = []
Expand Down Expand Up @@ -69,6 +81,59 @@ def __init__(self, requests_mock, connection: Connection):
)
requests_mock.post(connection.build_url("/validation"), json=self._handle_post_validation)

@classmethod
def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] = None) -> DummyBackend:
    """
    Factory to build dummy backend from given root URL
    including creation of connection and mocking of capabilities doc.

    :param root_url: root URL of the backend to mock
    :param requests_mock: `requests_mock` mocker to register responses on
    :param capabilities: optional overrides for the mocked capabilities document
    """
    root_url = root_url.rstrip("/") + "/"
    # Fix: fall back to an empty dict, not `None`:
    # `**(capabilities or None)` raised TypeError whenever `capabilities` was omitted.
    requests_mock.get(root_url, json=build_capabilities(**(capabilities or {})))
    connection = Connection(root_url)
    return cls(requests_mock=requests_mock, connection=connection)

def setup_collection(
    self,
    collection_id: str,
    *,
    temporal: Union[bool, Tuple[str, str]] = True,
    bands: Sequence[str] = ("B1", "B2", "B3"),
):
    """
    Register a mocked ``GET /collections/{collection_id}`` response.

    :param collection_id: id of the collection to mock
    :param temporal: whether to include a temporal dimension;
        pass a ``(start, end)`` tuple to set an explicit extent
    :param bands: band names for the bands dimension
        (pass a falsy value to omit the dimension)
    :return: self, for chaining
    """
    # TODO: also mock `/collections` overview
    # TODO: option to override cube_dimensions as a whole, or override dimension names
    cube_dimensions = {
        "x": {"type": "spatial"},
        "y": {"type": "spatial"},
    }
    if temporal:
        cube_dimensions["t"] = {
            "type": "temporal",
            # Normalize tuple to list for clean JSON serialization.
            "extent": list(temporal) if isinstance(temporal, tuple) else [None, None],
        }
    if bands:
        cube_dimensions["bands"] = {"type": "bands", "values": list(bands)}

    self._requests_mock.get(
        self.connection.build_url(f"/collections/{collection_id}"),
        # TODO: add more metadata?
        json={
            "id": collection_id,
            # Fix: actually use the `cube_dimensions` built above.
            # Previously a hard-coded dict was returned here, silently
            # ignoring the `temporal` and `bands` arguments.
            "cube:dimensions": cube_dimensions,
        },
    )
    return self

def setup_file_format(self, name: str, type: str = "output", gis_data_types: Iterable[str] = ("raster",)):
    """
    Register a file format and (re-)mock the ``GET /file_formats`` response.

    :param name: file format name (e.g. "GTiff")
    :param type: "input" or "output"
    :param gis_data_types: GIS data types of the format
    :return: self, for chaining
    """
    entry = {
        "title": name,
        "gis_data_types": list(gis_data_types),
        "parameters": {},
    }
    self.file_formats[type][name] = entry
    self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats)
    return self

def _handle_post_result(self, request, context):
"""handler of `POST /result` (synchronous execute)"""
pg = request.json()["process"]["process_graph"]
Expand Down Expand Up @@ -150,10 +215,20 @@ def get_sync_pg(self) -> dict:
return self.sync_requests[0]

def get_batch_pg(self) -> dict:
    """
    Get process graph of the one and only batch job.
    Fails when there is none or more than one.
    """
    assert len(self.batch_jobs) == 1
    # Single job guaranteed by the assert above.
    (job_id,) = self.batch_jobs.keys()
    return self.batch_jobs[job_id]["pg"]

def get_validation_pg(self) -> dict:
    """
    Get process graph of the one and only validation request.
    """
    assert len(self.validation_requests) == 1
    # Single request guaranteed by the assert above.
    (pg,) = self.validation_requests
    return pg

def get_pg(self, process_id: Optional[str] = None) -> dict:
"""
Get one and only batch process graph (sync or batch)
Expand Down
Loading