diff --git a/CHANGELOG.md b/CHANGELOG.md
index 394d43c63..04c5840bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added `MultiResult` helper class to build process graphs with multiple result nodes ([#391](https://github.com/Open-EO/openeo-python-client/issues/391))
+
 ### Changed
 
 ### Removed
diff --git a/docs/api.rst b/docs/api.rst
index 012719b69..a20747cc4 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -47,6 +47,15 @@ openeo.rest.mlmodel
     :inherited-members:
 
 
+openeo.rest.multiresult
+-----------------------
+
+.. automodule:: openeo.rest.multiresult
+    :members: MultiResult
+    :inherited-members:
+    :special-members: __init__
+
+
 openeo.metadata
 ----------------
diff --git a/docs/datacube_construction.rst b/docs/datacube_construction.rst
index 79163228e..5422a7a87 100644
--- a/docs/datacube_construction.rst
+++ b/docs/datacube_construction.rst
@@ -196,3 +196,55 @@ Re-parameterization
 ```````````````````
 
 TODO
+
+
+
+.. _multi-result-process-graphs:
+Building process graphs with multiple result nodes
+===================================================
+
+.. note::
+    Multi-result support was added in version 0.35.0.
+
+Most openEO use cases are just about building a single result data cube,
+which is readily covered in the openEO Python client library through classes like
+:py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube`.
+It is straightforward to create a batch job from these, or to execute/download them synchronously.
+
+The openEO API also allows multiple result nodes in a single process graph,
+for example to persist intermediate results or to produce results in different output formats.
+To support this, the openEO Python client library provides the :py:class:`~openeo.rest.multiresult.MultiResult` class,
+which allows grouping multiple :py:class:`~openeo.rest.datacube.DataCube` and :py:class:`~openeo.rest.vectorcube.VectorCube` objects
+into a single entity that can be used to create or run batch jobs. For example:
+
+
+.. code-block:: python
+
+    from openeo import MultiResult
+
+    cube1 = ...
+    cube2 = ...
+    multi_result = MultiResult([cube1, cube2])
+    job = multi_result.create_job()
+
+
+Moreover, it is not necessary to explicitly create such a
+:py:class:`~openeo.rest.multiresult.MultiResult` object,
+as the :py:meth:`Connection.create_job() <openeo.rest.connection.Connection.create_job>` method
+directly supports passing multiple data cube objects in a list,
+which will be automatically grouped as a multi-result:
+
+.. code-block:: python
+
+    cube1 = ...
+    cube2 = ...
+    job = connection.create_job([cube1, cube2])
+
+
+.. important::
+
+    Only a single :py:class:`~openeo.rest.connection.Connection` can be in play
+    when grouping multiple results like this.
+    Because everything has to be merged into a single process graph
+    that is sent to a single backend,
+    it is not possible to mix cubes created from different connections.
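+
+As a closing example, here is a fuller sketch
+(hypothetical: it assumes a backend that offers a ``SENTINEL2_L2A`` collection
+and supports both GTiff and netCDF output)
+that persists the same cube in two output formats through a single batch job:
+
+.. code-block:: python
+
+    import openeo
+    from openeo import MultiResult
+
+    connection = openeo.connect("openeo.example").authenticate_oidc()
+    cube = connection.load_collection("SENTINEL2_L2A")
+
+    # Two result nodes on the same source cube: one per output format.
+    save_tiff = cube.save_result(format="GTiff")
+    save_netcdf = cube.save_result(format="netCDF")
+
+    multi_result = MultiResult([save_tiff, save_netcdf])
+    job = multi_result.execute_batch(title="Multi-format job")
+    job.get_results().download_files("results/")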
diff --git a/openeo/__init__.py b/openeo/__init__.py index 536c4a54c..e6b5edcaf 100644 --- a/openeo/__init__.py +++ b/openeo/__init__.py @@ -18,6 +18,7 @@ class BaseOpenEoException(Exception): from openeo.rest.datacube import UDF, DataCube from openeo.rest.graph_building import collection_property from openeo.rest.job import BatchJob, RESTJob +from openeo.rest.multiresult import MultiResult def client_version() -> str: diff --git a/openeo/internal/graph_building.py b/openeo/internal/graph_building.py index 1f1c6211a..6f5918ea2 100644 --- a/openeo/internal/graph_building.py +++ b/openeo/internal/graph_building.py @@ -10,11 +10,12 @@ import abc import collections +import copy import json import sys from contextlib import nullcontext from pathlib import Path -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from openeo.api.process import Parameter from openeo.internal.process_graph_visitor import ( @@ -243,7 +244,7 @@ def walk(x) -> Iterator[PGNode]: yield from walk(self.arguments) -def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]: +def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin], Any]) -> Dict[str, dict]: """ Convert given object to a internal flat dict graph representation. """ @@ -252,12 +253,15 @@ def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, di # including `{"process_graph": {nodes}}` ("process graph") # or just the raw process graph nodes? if isinstance(x, dict): + # Assume given dict is already a flat graph representation return x elif isinstance(x, FlatGraphableMixin): return x.flat_graph() elif isinstance(x, (str, Path)): # Assume a JSON resource (raw JSON, path to local file, JSON url, ...) return load_json_resource(x) + elif isinstance(x, (list, tuple)) and all(isinstance(i, FlatGraphableMixin) for i in x): + return MultiLeafGraph(x).flat_graph() raise ValueError(x) @@ -322,20 +326,29 @@ def generate(self, process_id: str): class GraphFlattener(ProcessGraphVisitor): - def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None): + def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None, multi_input_mode: bool = False): super().__init__() self._node_id_generator = node_id_generator or FlatGraphNodeIdGenerator() self._last_node_id = None self._flattened: Dict[str, dict] = {} self._argument_stack = [] self._node_cache = {} + self._multi_input_mode = multi_input_mode def flatten(self, node: PGNode) -> Dict[str, dict]: """Consume given nested process graph and return flat dict representation""" + if self._flattened and not self._multi_input_mode: + raise RuntimeError("Flattening multiple graphs, but not in multi-input mode") self.accept_node(node) assert len(self._argument_stack) == 0 - self._flattened[self._last_node_id]["result"] = True - return self._flattened + return self.flattened(set_result_flag=not self._multi_input_mode) + + def flattened(self, set_result_flag: bool = True) -> Dict[str, dict]: + flat_graph = copy.deepcopy(self._flattened) + if set_result_flag: + # TODO #583 an "end" node is not necessarily a "result" node + flat_graph[self._last_node_id]["result"] = True + return flat_graph def accept_node(self, node: PGNode): # Process reused nodes only first time and remember node id. 
@@ -438,3 +451,26 @@ def _process_from_parameter(self, name: str) -> Any:
         if name not in self._parameters:
             raise ProcessGraphVisitException("No substitution value for parameter {p!r}.".format(p=name))
         return self._parameters[name]
+
+
+class MultiLeafGraph(FlatGraphableMixin):
+    """
+    Container for process graphs with multiple leaf/result nodes.
+    """
+
+    __slots__ = ["_leaves"]
+
+    def __init__(self, leaves: Iterable[FlatGraphableMixin]):
+        self._leaves = list(leaves)
+
+    def flat_graph(self) -> Dict[str, dict]:
+        flattener = GraphFlattener(multi_input_mode=True)
+        for leaf in self._leaves:
+            if isinstance(leaf, PGNode):
+                flattener.flatten(leaf)
+            elif isinstance(leaf, _FromNodeMixin):
+                flattener.flatten(leaf.from_node())
+            else:
+                raise ValueError(f"Unsupported type {type(leaf)}")
+
+        return flattener.flattened(set_result_flag=True)
diff --git a/openeo/rest/_testing.py b/openeo/rest/_testing.py
index c62aa98fb..7dc079d76 100644
--- a/openeo/rest/_testing.py
+++ b/openeo/rest/_testing.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import collections
 import json
 import re
-from typing import Callable, Iterator, Optional, Sequence, Union
+from typing import Callable, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 from openeo import Connection, DataCube
 from openeo.rest.vectorcube import VectorCube
@@ -19,8 +21,12 @@ class DummyBackend:
     and allows inspection of posted process graphs
     """
 
+    # TODO: move to openeo.testing
+    __slots__ = (
+        "_requests_mock",
         "connection",
+        "file_formats",
         "sync_requests",
         "batch_jobs",
         "validation_requests",
@@ -33,8 +39,14 @@ class DummyBackend:
     # Default result (can serve both as JSON or binary data)
     DEFAULT_RESULT = b'{"what?": "Result data"}'
 
-    def __init__(self, requests_mock, connection: Connection):
+    def __init__(
+        self,
+        requests_mock,
+        connection: Connection,
+    ):
+        self._requests_mock = requests_mock
         self.connection = connection
+        self.file_formats = {"input": {}, "output": {}}
         self.sync_requests = []
         self.batch_jobs = {}
         self.validation_requests = []
@@ -69,6 +81,59 @@ def __init__(self, requests_mock, connection: Connection):
         )
         requests_mock.post(connection.build_url("/validation"), json=self._handle_post_validation)
 
+    @classmethod
+    def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] = None) -> DummyBackend:
+        """
+        Factory to build dummy backend from given root URL,
+        including creation of connection and mocking of capabilities doc
+        """
+        root_url = root_url.rstrip("/") + "/"
+        requests_mock.get(root_url, json=build_capabilities(**(capabilities or {})))
+        connection = Connection(root_url)
+        return cls(requests_mock=requests_mock, connection=connection)
+
+    def setup_collection(
+        self,
+        collection_id: str,
+        *,
+        temporal: Union[bool, Tuple[str, str]] = True,
+        bands: Sequence[str] = ("B1", "B2", "B3"),
+    ):
+        # TODO: also mock `/collections` overview
+        # TODO: option to override cube_dimensions as a whole, or override dimension names
+        cube_dimensions = {
+            "x": {"type": "spatial"},
+            "y": {"type": "spatial"},
+        }
+
+        if temporal:
+            cube_dimensions["t"] = {
+                "type": "temporal",
+                "extent": temporal if isinstance(temporal, tuple) else [None, None],
+            }
+        if bands:
+            cube_dimensions["bands"] = {"type": "bands", "values": list(bands)}
+
+        self._requests_mock.get(
+            self.connection.build_url(f"/collections/{collection_id}"),
+            # TODO: add more metadata?
+            json={
+                "id": collection_id,
+                "cube:dimensions": cube_dimensions,
+            },
+        )
+        return self
+
+    def setup_file_format(self, name: str, type: str = "output", gis_data_types: Iterable[str] = ("raster",)):
+        self.file_formats[type][name] = {
+            "title": name,
+            "gis_data_types": list(gis_data_types),
+            "parameters": {},
+        }
+        self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats)
+        return self
+
     def _handle_post_result(self, request, context):
         """handler of `POST /result` (synchronous execute)"""
         pg = request.json()["process"]["process_graph"]
@@ -150,10 +215,20 @@ def get_sync_pg(self) -> dict:
         return self.sync_requests[0]
 
     def get_batch_pg(self) -> dict:
-        """Get one and only batch process graph"""
+        """
+        Get process graph of the one and only batch job.
+        Fails when there is none or more than one.
+        """
         assert len(self.batch_jobs) == 1
         return self.batch_jobs[max(self.batch_jobs.keys())]["pg"]
 
+    def get_validation_pg(self) -> dict:
+        """
+        Get process graph of the one and only validation request.
+        """
+        assert len(self.validation_requests) == 1
+        return self.validation_requests[0]
+
     def get_pg(self, process_id: Optional[str] = None) -> dict:
         """
         Get one and only batch process graph (sync or batch)
diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py
index 65126cf9a..79ee478f8 100644
--- a/openeo/rest/connection.py
+++ b/openeo/rest/connection.py
@@ -21,6 +21,7 @@
     List,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
 )
@@ -53,7 +54,7 @@
     OpenEoClientException,
     OpenEoRestError,
 )
-from openeo.rest._datacube import build_child_callback
+from openeo.rest._datacube import _ProcessGraphAbstraction, build_child_callback
 from openeo.rest.auth.auth import BasicBearerAuth, BearerAuth, NullAuth, OidcBearerAuth
 from openeo.rest.auth.config import AuthConfig, RefreshTokenStore
 from openeo.rest.auth.oidc import (
@@ -1128,11 +1129,19 @@ def user_defined_process(self, user_defined_process_id: str) -> RESTUserDefinedP
         """
         return RESTUserDefinedProcess(user_defined_process_id=user_defined_process_id, connection=self)
 
-    def validate_process_graph(self, process_graph: Union[dict, FlatGraphableMixin, Any]) -> List[dict]:
+    def validate_process_graph(
+        self, process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]]
+    ) -> List[dict]:
         """
         Validate a process graph without executing it.
 
-        :param process_graph: (flat) dict representing process graph
+        :param process_graph: openEO-style (flat) process graph representation,
+            or an object that can be converted to such a representation:
+            a dictionary, a :py:class:`~openeo.rest.datacube.DataCube` object,
+            a string with a JSON representation,
+            a local file path or URL to a JSON representation,
+            a :py:class:`~openeo.rest.multiresult.MultiResult` object, ...
+
         :return: list of errors (dictionaries with "code" and "message" fields)
         """
         pg_with_metadata = self._build_request_with_process_graph(process_graph)["process"]
@@ -1608,12 +1617,19 @@ def upload_file(
         metadata = resp.json()
         return UserFile.from_metadata(metadata=metadata, connection=self)
 
-    def _build_request_with_process_graph(self, process_graph: Union[dict, FlatGraphableMixin, Any], **kwargs) -> dict:
+    def _build_request_with_process_graph(
+        self,
+        process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
+        **kwargs,
+    ) -> dict:
         """
         Prepare a json payload with a process graph to submit to /result, /services, /jobs, ...
         :param process_graph: flat dict representing a "process graph with metadata" ({"process": {"process_graph": ...}, ...})
         """
         # TODO: make this a more general helper (like `as_flat_graph`)
+        connections = extract_connections(process_graph)
+        if any(c != self for c in connections):
+            raise OpenEoClientException(f"Mixing different connections: {self} and {connections}.")
         result = kwargs
         process_graph = as_flat_graph(process_graph)
         if "process_graph" not in process_graph:
@@ -1656,7 +1672,7 @@ def _preflight_validation(self, pg_with_metadata: dict, *, validate: Optional[bo
     # TODO: unify `download` and `execute` better: e.g. `download` always writes to disk, `execute` returns result (raw or as JSON decoded dict)
     def download(
         self,
-        graph: Union[dict, FlatGraphableMixin, str, Path],
+        graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
         outputfile: Union[Path, str, None] = None,
         *,
         timeout: Optional[int] = None,
@@ -1695,7 +1711,7 @@
     def execute(
         self,
-        process_graph: Union[dict, str, Path],
+        process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
         *,
         timeout: Optional[int] = None,
         validate: Optional[bool] = None,
@@ -1732,7 +1748,7 @@
     def create_job(
         self,
-        process_graph: Union[dict, str, Path, FlatGraphableMixin],
+        process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
         *,
         title: Optional[str] = None,
         description: Optional[str] = None,
@@ -1744,8 +1760,12 @@
         """
         Create a new job from given process graph on the back-end.
 
-        :param process_graph: (flat) dict representing a process graph, or process graph as raw JSON string,
-            or as local file path or URL
+        :param process_graph: openEO-style (flat) process graph representation,
+            or an object that can be converted to such a representation:
+            a dictionary, a :py:class:`~openeo.rest.datacube.DataCube` object,
+            a string with a JSON representation,
+            a local file path or URL to a JSON representation,
+            a :py:class:`~openeo.rest.multiresult.MultiResult` object, ...
         :param title: job title
         :param description: job description
         :param plan: The billing plan to process and charge the job with
@@ -1755,6 +1775,9 @@
         :param validate: Optional toggle to enable/prevent validation of the process graphs before execution
             (overruling the connection's ``auto_validate`` setting).
         :return: Created job
+
+        .. versionchanged:: 0.35.0
+            Add :ref:`multi-result support <multi-result-process-graphs>`.
         """
         # TODO move all this (BatchJob factory) logic to BatchJob?
@@ -1968,3 +1991,25 @@ def paginate(con: Connection, url: str, params: Optional[dict] = None, callback:
         url = next_links[0]["href"]
         page += 1
         params = {}
+
+
+def extract_connections(
+    data: Union[_ProcessGraphAbstraction, Sequence[_ProcessGraphAbstraction], Any]
+) -> Set[Connection]:
+    """
+    Extract the :py:class:`Connection` object(s) linked from a given data construct.
+    Typical use case is to get the connection from a :py:class:`DataCube`,
+    but can also extract multiple connections from a list of data cubes.
+    """
+    connections = set()
+    # TODO: define some kind of "Connected" interface/mixin/protocol
+    #   for objects that contain a connection instead of just checking for _ProcessGraphAbstraction
+    # TODO: also support extracting connections from other objects like BatchJob, ...
+ if isinstance(data, _ProcessGraphAbstraction) and data.connection: + connections.add(data.connection) + elif isinstance(data, (list, tuple, set)): + for item in data: + if isinstance(item, _ProcessGraphAbstraction) and item.connection: + connections.add(item.connection) + + return connections diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index 2c0f8950d..c91fb722a 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -85,7 +85,9 @@ class DataCube(_ProcessGraphAbstraction): # TODO: set this based on back-end or user preference? _DEFAULT_RASTER_FORMAT = "GTiff" - def __init__(self, graph: PGNode, connection: Optional[Connection], metadata: Optional[CollectionMetadata] = None): + def __init__( + self, graph: PGNode, connection: Optional[Connection] = None, metadata: Optional[CollectionMetadata] = None + ): super().__init__(pgnode=graph, connection=connection) self.metadata: Optional[CollectionMetadata] = metadata diff --git a/openeo/rest/multiresult.py b/openeo/rest/multiresult.py new file mode 100644 index 000000000..20733b68d --- /dev/null +++ b/openeo/rest/multiresult.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from openeo import BatchJob +from openeo.internal.graph_building import FlatGraphableMixin, MultiLeafGraph +from openeo.rest import OpenEoClientException +from openeo.rest.connection import Connection, extract_connections + + +class MultiResult(FlatGraphableMixin): + """ + Helper to create and run batch jobs with process graphs + that contain multiple result nodes + or, more generally speaking, multiple process graph "leaf" nodes. + + Provide multiple + :py:class:`~openeo.rest.datacube.DataCube`/:py:class:`~openeo.rest.vectorcube.VectorCube` + instances to the constructor, + and start a batch job from that, + for example as follows: + + .. code-block:: python + + from openeo import MultiResult + + cube1 = ... + cube2 = ... + multi_result = MultiResult([cube1, cube2]) + job = multi_result.create_job() + + .. seealso:: + + :ref:`multi-result-process-graphs` + + .. versionadded:: 0.35.0 + """ + + __slots__ = ("_multi_leaf_graph", "_connection") + + def __init__(self, leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None): + """ + Build a :py:class:`MultiResult` instance from multiple leaf nodes + + :param leaves: list of objects that can be + converted to an openEO-style (flat) process graph representation, + typically :py:class:`~openeo.rest.datacube.DataCube` + or :py:class:`~openeo.rest.vectorcube.VectorCube` instances. + :param connection: Optional connection to use for creating/starting batch jobs, + for special use cases where the provided leaf instances + are not already associated with a connection. + """ + self._multi_leaf_graph = MultiLeafGraph(leaves=leaves) + self._connection = self._extract_connection(leaves=leaves, connection=connection) + + @staticmethod + def _extract_connection(leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None) -> Connection: + """ + Extract common connection from leaves and/or explicitly provided connection. + Fails if there are multiple or none. 
+ """ + connections = set() + if connection: + connections.add(connection) + connections.update(extract_connections(leaves)) + + if len(connections) == 1: + return connections.pop() + elif len(connections) == 0: + raise OpenEoClientException("No connection in any of the MultiResult leaves") + else: + raise OpenEoClientException("MultiResult with multiple different connections") + + def flat_graph(self) -> Dict[str, dict]: + return self._multi_leaf_graph.flat_graph() + + def create_job( + self, + *, + title: Optional[str] = None, + description: Optional[str] = None, + job_options: Optional[dict] = None, + validate: Optional[bool] = None, + ) -> BatchJob: + return self._connection.create_job( + process_graph=self._multi_leaf_graph, + title=title, + description=description, + additional=job_options, + validate=validate, + ) + + def execute_batch( + self, + *, + title: Optional[str] = None, + description: Optional[str] = None, + job_options: Optional[dict] = None, + validate: Optional[bool] = None, + ) -> BatchJob: + job = self.create_job(title=title, description=description, job_options=job_options, validate=validate) + return job.run_synchronous() diff --git a/openeo/rest/udp.py b/openeo/rest/udp.py index aea78f093..0df9015ab 100644 --- a/openeo/rest/udp.py +++ b/openeo/rest/udp.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from pathlib import Path from typing import List, Optional, Union from openeo.api.process import Parameter @@ -16,7 +17,7 @@ def build_process_dict( - process_graph: Union[dict, FlatGraphableMixin], + process_graph: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin]], process_id: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, diff --git a/tests/internal/test_graphbuilding.py b/tests/internal/test_graphbuilding.py index 1bdf56015..44f304c4b 100644 --- a/tests/internal/test_graphbuilding.py +++ b/tests/internal/test_graphbuilding.py @@ -1,5 +1,7 @@ import io +import re import textwrap +from pathlib import Path import pytest @@ -7,11 +9,15 @@ from openeo.api.process import Parameter from openeo.internal.graph_building import ( FlatGraphNodeIdGenerator, + GraphFlattener, + MultiLeafGraph, PGNode, PGNodeGraphUnflattener, ReduceNode, + as_flat_graph, ) from openeo.internal.process_graph_visitor import ProcessGraphVisitException +from openeo.rest.datacube import DataCube def test_pgnode_process_id(): @@ -143,6 +149,91 @@ def test_flat_graph_key_generate(): assert g.generate("foo") == "foo3" +class TestGraphFlattener: + def test_simple(self): + node = PGNode("foo", bar="meh") + flattener = GraphFlattener() + assert flattener.flatten(node) == {"foo1": {"process_id": "foo", "arguments": {"bar": "meh"}, "result": True}} + + def test_chain(self): + a = PGNode("a", bar="meh") + b = PGNode("b", a=a) + c = PGNode("c", a=a, b=b) + flattener = GraphFlattener() + assert flattener.flatten(c) == { + "a1": {"process_id": "a", "arguments": {"bar": "meh"}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + "c1": { + "process_id": "c", + "arguments": {"a": {"from_node": "a1"}, "b": {"from_node": "b1"}}, + "result": True, + }, + } + + def test_no_multi_input_mode(self): + a = PGNode("a") + b = PGNode("b", a=a) + flattener = GraphFlattener() + flat_graph = flattener.flatten(a) + assert flat_graph == {"a1": {"process_id": "a", "arguments": {}, "result": True}} + with pytest.raises(RuntimeError, match="not in multi-input mode"): + flattener.flatten(b) + assert flat_graph == {"a1": {"process_id": "a", 
"arguments": {}, "result": True}} + + def test_multi_input_mode(self): + a = PGNode("a") + b = PGNode("b", a=a) + c = PGNode("c", a=a) + flattener = GraphFlattener(multi_input_mode=True) + # Flatten b + assert flattener.flatten(b) == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + } + assert flattener.flattened() == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}, "result": True}, + } + # Flatten c + assert flattener.flatten(c) == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + "c1": {"process_id": "c", "arguments": {"a": {"from_node": "a1"}}}, + } + assert flattener.flattened() == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + "c1": {"process_id": "c", "arguments": {"a": {"from_node": "a1"}}, "result": True}, + } + + def test_multi_input_mode_mutation(self): + """Verify that previously produced flat graphs are not silently mutated""" + a = PGNode("a") + b = PGNode("b", a=a) + flattener = GraphFlattener(multi_input_mode=True) + a_flat = flattener.flatten(a) + assert a_flat == { + "a1": {"process_id": "a", "arguments": {}}, + } + b_flat = flattener.flatten(b) + assert b_flat == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + } + assert flattener.flattened() == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}, "result": True}, + } + # Original graphs are not mutated silently + assert a_flat == { + "a1": {"process_id": "a", "arguments": {}}, + } + assert b_flat == { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + } + + def test_build_and_flatten_simple(): node = PGNode("foo") assert node.flat_graph() == {"foo1": {"process_id": "foo", "arguments": {}, "result": True}} @@ -412,3 +503,68 @@ def test_walk_nodes_nested(): walk = list(node.walk_nodes()) assert all(isinstance(n, PGNode) for n in walk) assert set(n.process_id for n in walk) == {"load1", "max", "foo", "load2", "add", "five"} + + +def test_as_flat_graph_dict(): + pg = {"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": True}} + assert as_flat_graph(pg) == {"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": True}} + + +def test_as_flat_graph_pgnode(): + node = PGNode("foo", color="red") + assert as_flat_graph(node) == {"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": True}} + + +def test_as_flat_graph_path(tmp_path): + path = tmp_path / "graph.json" + with path.open("w") as f: + f.write('{"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": true}}') + assert as_flat_graph(str(path)) == {"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": True}} + assert as_flat_graph(Path(path)) == {"foo1": {"process_id": "foo", "arguments": {"color": "red"}, "result": True}} + + +def test_as_flat_graph_pgnode_list(): + a = PGNode("a") + b = PGNode("b", a=a) + c = PGNode("c", a=a) + expected = { + "a1": {"process_id": "a", "arguments": {}}, + "b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}}, + "c1": {"process_id": "c", "arguments": {"a": {"from_node": "a1"}}, "result": True}, + } + assert as_flat_graph([b, c]) == expected + assert as_flat_graph((b, c)) 
== expected + + +class TestMultiLeafGraph: + def test_simple(self): + multi = MultiLeafGraph([PGNode("foo"), PGNode("bar")]) + assert multi.flat_graph() == { + "foo1": {"process_id": "foo", "arguments": {}}, + "bar1": {"process_id": "bar", "arguments": {}, "result": True}, + } + + def test_simple_duplicates(self): + multi = MultiLeafGraph([PGNode("foo"), PGNode("foo")]) + assert multi.flat_graph() == { + "foo1": {"process_id": "foo", "arguments": {}}, + "foo2": {"process_id": "foo", "arguments": {}, "result": True}, + } + + def test_multi_save_result_same_root(self): + load_collection = DataCube(PGNode("load_collection", collection_id="S2")) + save_a = load_collection.save_result(format="GTiff") + save_b = load_collection.save_result(format="NetCDF") + multi = MultiLeafGraph([save_a, save_b]) + assert multi.flat_graph() == { + "loadcollection1": {"process_id": "load_collection", "arguments": {"collection_id": "S2"}}, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "NetCDF", "options": {}}, + "result": True, + }, + } diff --git a/tests/rest/conftest.py b/tests/rest/conftest.py index c193bfd73..5879929c5 100644 --- a/tests/rest/conftest.py +++ b/tests/rest/conftest.py @@ -101,5 +101,21 @@ def con120(requests_mock, api_capabilities): @pytest.fixture -def dummy_backend(requests_mock, con100) -> DummyBackend: - yield DummyBackend(requests_mock=requests_mock, connection=con100) +def dummy_backend(requests_mock, con120) -> DummyBackend: + dummy_backend = DummyBackend(requests_mock=requests_mock, connection=con120) + dummy_backend.setup_collection("S2") + dummy_backend.setup_file_format("GTiff") + dummy_backend.setup_file_format("netCDF") + return dummy_backend + + +@pytest.fixture +def another_dummy_backend(requests_mock) -> DummyBackend: + root_url = "https://openeo.other.test/" + another_dummy_backend = DummyBackend.at_url( + root_url, requests_mock=requests_mock, capabilities={"api_version": "1.2.0"} + ) + another_dummy_backend.setup_collection("S2") + another_dummy_backend.setup_file_format("GTiff") + another_dummy_backend.setup_file_format("netCDF") + return another_dummy_backend diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 2a96e4fe2..a7e7d318b 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -25,7 +25,7 @@ OpenEoClientException, OpenEoRestError, ) -from openeo.rest._testing import build_capabilities +from openeo.rest._testing import DummyBackend, build_capabilities from openeo.rest.auth.auth import BearerAuth, NullAuth from openeo.rest.auth.oidc import OidcException from openeo.rest.auth.testing import ABSENT, OidcMock @@ -35,6 +35,7 @@ Connection, RestApiConnection, connect, + extract_connections, paginate, ) from openeo.rest.vectorcube import VectorCube @@ -3648,3 +3649,194 @@ def test_create_job_validation( else: assert caplog.messages == [] assert dummy_backend.validation_requests == [] + + +def test_extract_connections_elementary(): + assert extract_connections(123) == set() + assert extract_connections("foo") == set() + assert extract_connections([1, 2, 3]) == set() + assert extract_connections((1, 2, 3)) == set() + assert extract_connections({1, 2, 3}) == set() + assert extract_connections({"a": "b", "c": "d"}) == set() + + +def test_extract_connections_cube(dummy_backend): + con = 
dummy_backend.connection + cube = con.load_collection("S2") + assert extract_connections(cube) == {con} + + +def test_extract_connections_cube_list(dummy_backend): + con = dummy_backend.connection + cube1 = con.load_collection("S2") + cube2 = con.load_collection("S2") + assert extract_connections([cube1, cube2]) == {con} + + +def test_extract_connections_cube_list_mixed(dummy_backend, another_dummy_backend): + con1 = dummy_backend.connection + con2 = another_dummy_backend.connection + cube1 = con1.load_collection("S2") + cube2 = con2.load_collection("S2") + assert extract_connections([cube1]) == {con1} + assert extract_connections([cube2]) == {con2} + assert extract_connections([cube1, cube2]) == {con1, con2} + assert extract_connections((cube1, cube2)) == {con1, con2} + + +def test_create_job_mixed_connections(dummy_backend, another_dummy_backend): + con = dummy_backend.connection + cube = con.load_collection("S2") + + other_connection = another_dummy_backend.connection + with pytest.raises(OpenEoClientException, match="Mixing different connections"): + other_connection.create_job(cube) + + +class TestMultiResultHandling: + + def test_create_job_with_cube_list(self, con120, dummy_backend): + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + con120.create_job([save1, save2]) + assert dummy_backend.get_batch_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + } + + def test_download_with_cube_list(self, con120, dummy_backend, tmp_path): + dummy_backend.next_result = b"-:[ZIP data]:-" + + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + output_path = tmp_path / "result.zip" + con120.download([save1, save2], outputfile=output_path) + assert dummy_backend.get_sync_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + } + assert output_path.read_bytes() == b"-:[ZIP data]:-" + + def test_synchronous_execute_with_cube_list(self, con120, dummy_backend): + dummy_backend.next_result = b"-:[ZIP data]:-" + + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + res = con120.execute([save1, save2], auto_decode=False) + assert dummy_backend.get_sync_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", 
"options": {}}, + "result": True, + }, + } + assert res.content == b"-:[ZIP data]:-" + + def test_validate_with_cube_list(self, con120, dummy_backend): + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + con120.validate_process_graph([save1, save2]) + assert dummy_backend.get_validation_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + } + + def test_create_job_with_mixed_connections(self, con120, dummy_backend, another_dummy_backend): + other_connection = another_dummy_backend.connection + + save1 = con120.load_collection("S2").save_result(format="GTiff") + save2 = other_connection.load_collection("S2").save_result(format="netCDF") + + # Same connection should work + con120.create_job([save1]) + other_connection.create_job([save2]) + + # Mixing connections + with pytest.raises(OpenEoClientException, match="Mixing different connections"): + con120.create_job([save1, save2]) + + with pytest.raises(OpenEoClientException, match="Mixing different connections"): + other_connection.create_job([save1, save2]) + + def test_create_job_intermediate_resultst(self, con120, dummy_backend): + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + reduced = cube.reduce_temporal("mean") + save2 = reduced.save_result(format="GTiff") + con120.create_job([save1, save2]) + assert dummy_backend.get_batch_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "reducedimension1": { + "process_id": "reduce_dimension", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "dimension": "t", + "reducer": { + "process_graph": { + "mean1": { + "arguments": {"data": {"from_parameter": "data"}}, + "process_id": "mean", + "result": True, + } + } + }, + }, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "reducedimension1"}, "format": "GTiff", "options": {}}, + "result": True, + }, + } diff --git a/tests/rest/test_multiresult.py b/tests/rest/test_multiresult.py new file mode 100644 index 000000000..3b5486c50 --- /dev/null +++ b/tests/rest/test_multiresult.py @@ -0,0 +1,112 @@ +import pytest + +from openeo import BatchJob +from openeo.rest._testing import DummyBackend +from openeo.rest.multiresult import MultiResult + + +class TestMultiResultHandling: + def test_flat_graph(self, dummy_backend): + cube = dummy_backend.connection.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + multi_result = MultiResult([save1, save2]) + assert multi_result.flat_graph() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { 
+ "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + } + + def test_create_job_method(self, dummy_backend): + con = dummy_backend.connection + cube = con.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + multi_result = MultiResult([save1, save2]) + multi_result.create_job(title="multi result test") + assert dummy_backend.batch_jobs == { + "job-000": { + "job_id": "job-000", + "pg": { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + }, + "status": "created", + "title": "multi result test", + } + } + + def test_create_job_through_connection(self, con120, dummy_backend): + con = dummy_backend.connection + cube = con120.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + multi_result = MultiResult([save1, save2]) + con.create_job(multi_result) + assert dummy_backend.get_batch_pg() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + } + + def test_execute_batch(self, dummy_backend): + con = dummy_backend.connection + cube = con.load_collection("S2") + save1 = cube.save_result(format="GTiff") + save2 = cube.save_result(format="netCDF") + multi_result = MultiResult([save1, save2]) + job = multi_result.execute_batch(title="multi result test") + assert isinstance(job, BatchJob) + assert dummy_backend.batch_jobs == { + "job-000": { + "job_id": "job-000", + "pg": { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + }, + "saveresult2": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "netCDF", "options": {}}, + "result": True, + }, + }, + "status": "finished", + "title": "multi result test", + } + }