Skip to content

Commit

Permalink
Issue #401 Improve automatic adding of save_result
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Apr 11, 2023
1 parent d505757 commit 0f10972
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 38 deletions.
105 changes: 69 additions & 36 deletions openeo/rest/datacube.py
Expand Up @@ -58,6 +58,9 @@ class DataCube(_ProcessGraphAbstraction):
and this process graph can be "grown" to a desired workflow by calling the appropriate methods.
"""

# TODO: set this based on back-end or user preference?
_DEFAULT_RASTER_FORMAT = "GTiff"

def __init__(self, graph: PGNode, connection: 'openeo.Connection', metadata: CollectionMetadata = None):
super().__init__(pgnode=graph, connection=connection)
self.metadata = CollectionMetadata.get_or_create(metadata)
Expand Down Expand Up @@ -1810,36 +1813,41 @@ def atmospheric_correction(
})

@openeo_process
def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube':
def save_result(
    self, format: str = _DEFAULT_RASTER_FORMAT, options: Optional[dict] = None
) -> "DataCube":
    """
    Add a ``save_result`` node to the process graph.

    :param format: output file format; checked (case-insensitively) against
        the output formats advertised by the connected back-end.
    :param options: (optional) output file format parameters.
    :return: new DataCube with the ``save_result`` node as final node.
    :raises ValueError: when ``format`` is not supported by the back-end.
    """
    supported = set(self._connection.list_output_formats().keys())
    # TODO: map format to correct casing too?
    if format.lower() not in {f.lower() for f in supported}:
        raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=supported))
    arguments = {
        "data": THIS,
        "format": format,
        # TODO: leave out options if unset?
        "options": options if options else {},
    }
    return self.process(process_id="save_result", arguments=arguments)

def download(
self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
options: Optional[dict] = None
):
def _ensure_save_result(
self, format: Optional[str] = None, options: Optional[dict] = None
) -> "DataCube":
"""
Download image collection, e.g. as GeoTIFF.
If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
The bytes object can be passed on to a suitable decoder for decoding.
Make sure there is a (final) `save_result` node in the process graph.
If there is already one: check that it is consistent with the given format/options (if any);
otherwise, add a new `save_result` node.
:param outputfile: Optional, an output file if the result needs to be stored on disk.
:param format: Optional, an output format supported by the backend.
:param options: Optional, file format options
:return: None if the result is stored to disk, or a bytes object returned by the backend.
:param format: (optional) desired `save_result` file format
:param options: (optional) desired `save_result` file format parameters
:return:
"""
if self.result_node().process_id == "save_result":
# There is already a `save_result` node: check if it is consistent with given format/options
args = self.result_node().arguments
# TODO: move to generic data cube parent class (not only for raster cubes, but also vector cubes)
result_node = self.result_node()
if result_node.process_id == "save_result":
# There is already a `save_result` node:
# check if it is consistent with given format/options (if any)
args = result_node.arguments
if format is not None and format.lower() != args["format"].lower():
raise ValueError(
f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
Expand All @@ -1851,10 +1859,28 @@ def download(
cube = self
else:
# No `save_result` node yet: automatically add it.
if not format:
format = guess_format(outputfile) if outputfile else "GTiff"
cube = self.save_result(format=format, options=options)
cube = self.save_result(
format=format or self._DEFAULT_RASTER_FORMAT, options=options
)
return cube

def download(
    self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
    options: Optional[dict] = None
):
    """
    Download image collection, e.g. as GeoTIFF.

    If ``outputfile`` is provided, the result is stored on disk locally;
    otherwise a ``bytes`` object is returned,
    which can be passed on to a suitable decoder for decoding.

    :param outputfile: Optional, an output file if the result needs to be stored on disk.
    :param format: Optional, an output format supported by the backend.
    :param options: Optional, file format options
    :return: None if the result is stored to disk, or a bytes object returned by the backend.
    """
    if outputfile is not None and format is None:
        # No explicit format given: infer it from the output file's extension.
        format = guess_format(outputfile)
    # Make sure the graph ends in a `save_result` node before submitting it.
    prepared = self._ensure_save_result(format=format, options=options)
    return self._connection.download(prepared.flat_graph(), outputfile)

def validate(self) -> List[dict]:
Expand All @@ -1869,27 +1895,36 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
return self._connection.create_service(self.flat_graph(), type=type, **kwargs)

def execute_batch(
self,
outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
print=print, max_poll_interval=60, connection_retry_interval=30,
job_options=None, **format_options) -> BatchJob:
self,
outputfile: Optional[Union[str, pathlib.Path]] = None,
out_format: Optional[str] = None,
*,
print: typing.Callable[[str], None] = print,
max_poll_interval: float = 60,
connection_retry_interval: float = 30,
job_options: Optional[dict] = None,
# TODO: avoid `format_options` as keyword arguments
**format_options,
) -> BatchJob:
"""
Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
For very long-running jobs, you probably do not want to keep the client running.
:param job_options:
:param outputfile: The path of a file to which a result can be written
:param out_format: (optional) Format of the job result.
:param format_options: String Parameters for the job result format
:param out_format: (optional) File format to use for the job result.
:param job_options:
:param format_options: output file format parameters.
"""
if "format" in format_options and not out_format:
out_format = format_options["format"] # align with 'download' call arg name
if not out_format:
out_format = guess_format(outputfile) if outputfile else "GTiff"
job = self.create_job(out_format, job_options=job_options, **format_options)
if not out_format and outputfile:
out_format = guess_format(outputfile)

job = self.create_job(
format=out_format, job_options=job_options, format_options=format_options
)
return job.run_synchronous(
outputfile=outputfile,
print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval
Expand All @@ -1904,6 +1939,7 @@ def create_job(
plan: Optional[str] = None,
budget: Optional[float] = None,
job_options: Optional[dict] = None,
# TODO: avoid `format_options` as keyword arguments
**format_options,
) -> BatchJob:
"""
Expand All @@ -1914,22 +1950,19 @@ def create_job(
it still needs to be started and tracked explicitly.
Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.
:param out_format: String Format of the job result.
:param out_format: output file format.
:param title: job title
:param description: job description
:param plan: billing plan
:param budget: maximum cost the request is allowed to produce
:param job_options: A dictionary containing (custom) job options
:param format_options: String Parameters for the job result format
:param job_options: custom job options.
:param format_options: output file format parameters.
:return: Created job.
"""
# TODO: add option to also automatically start the job?
# TODO: avoid using all kwargs as format_options
# TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
cube = self
if out_format:
# add `save_result` node
cube = cube.save_result(format=out_format, options=format_options)
cube = self._ensure_save_result(format=out_format, options=format_options)
return self._connection.create_job(
process_graph=cube.flat_graph(),
title=title,
Expand Down
2 changes: 1 addition & 1 deletion openeo/util.py
Expand Up @@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
raise ValueError("No keys given")


def guess_format(filename: Union[str, Path]):
def guess_format(filename: Union[str, Path]) -> str:
"""
Guess the output format from a given filename and return the corrected format.
Any names not in the dict get passed through.
Expand Down
79 changes: 78 additions & 1 deletion tests/rest/datacube/test_datacube.py
Expand Up @@ -4,16 +4,18 @@
- 1.0.0-style DataCube
"""

import functools
from datetime import date, datetime
import pathlib

import mock
import numpy as np
import pytest
import shapely
import shapely.geometry

from openeo.capabilities import ComparableVersion
from openeo.internal.warnings import UserDeprecationWarning
from openeo.rest import BandMathException
from openeo.rest.datacube import DataCube
from .conftest import API_URL
Expand Down Expand Up @@ -446,3 +448,78 @@ def result_callback(request, context):
requests_mock.post(API_URL + '/result', content=result_callback)
result = connection.load_collection("S2").download(format=format)
assert result == b"data"


class TestExecuteBatch:
    """Tests for automatic `save_result` handling in `DataCube.execute_batch`."""

    @pytest.fixture
    def get_create_job_pg(self, connection):
        """Fixture to help intercepting the process graph that was passed to Connection.create_job"""
        with mock.patch.object(connection, "create_job") as create_job:

            def get() -> dict:
                assert create_job.call_count == 1
                return create_job.call_args.kwargs["process_graph"]

            yield get

    def _save_result_node(self, format: str) -> dict:
        # Helper (not collected by pytest): expected `save_result` node for given format.
        return {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": format,
                "options": {},
            },
            "result": True,
        }

    def test_basic(self, connection, s2cube, get_create_job_pg, recwarn, caplog):
        s2cube.execute_batch()
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == self._save_result_node("GTiff")
        # No warnings or log records expected for the default happy path.
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["format", "expected"],
        [(None, "GTiff"), ("GTiff", "GTiff"), ("gtiff", "gtiff"), ("NetCDF", "NetCDF")],
    )
    def test_format(
        self, connection, s2cube, get_create_job_pg, format, expected, recwarn, caplog
    ):
        s2cube.execute_batch(format=format)
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == self._save_result_node(expected)
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["out_format", "expected"],
        [("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
    )
    def test_out_format(
        self, connection, s2cube, get_create_job_pg, out_format, expected
    ):
        # Legacy `out_format` argument should still work, but with a deprecation warning.
        with pytest.warns(
            UserDeprecationWarning,
            match="`out_format`.*is deprecated.*use `format` instead",
        ):
            s2cube.execute_batch(out_format=out_format)
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == self._save_result_node(expected)

0 comments on commit 0f10972

Please sign in to comment.