Commit bdea97d

Issue #401 Improve automatic adding of save_result

soxofaan committed Apr 4, 2023
1 parent eda8e59 commit bdea97d

Showing 3 changed files with 195 additions and 34 deletions.
148 changes: 116 additions & 32 deletions openeo/rest/datacube.py
@@ -45,6 +45,8 @@
from openeo.udf import XarrayDataCube


DEFAULT_RASTER_FORMAT = "GTiff"

log = logging.getLogger(__name__)


@@ -1797,22 +1799,28 @@ def atmospheric_correction(
        })

    @openeo_process
    def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube':
    def save_result(
        self, format: str = DEFAULT_RASTER_FORMAT, options: Optional[dict] = None
    ) -> "DataCube":
        formats = set(self._connection.list_output_formats().keys())
        # TODO: map format to correct casing too?
        if format.lower() not in {f.lower() for f in formats}:
            raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
        return self.process(
            process_id="save_result",
            arguments={
                "data": THIS,
                "format": format,
                # TODO: leave out options if unset?
                "options": options or {}
            }
        )
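
For reference, a minimal usage sketch of the updated `save_result` (the connection and collection id below are placeholders, and the set of accepted formats depends on what the backend advertises):

    cube = connection.load_collection("S2")     # placeholder collection id
    tiff = cube.save_result(format="GTiff")     # appends a save_result node
    nc = cube.save_result(format="netcdf")      # accepted if the backend lists NetCDF (matching is case-insensitive)
    cube.save_result(format="bogus")            # raises ValueError: Invalid format 'bogus'. ...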

    def download(
            self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
            options: Optional[dict] = None
        self,
        outputfile: Optional[Union[str, pathlib.Path]] = None,
        format: Optional[str] = None,
        options: Optional[dict] = None,
    ):
        """
        Download image collection, e.g. as GeoTIFF.
@@ -1824,9 +1832,28 @@ def download(
        :param options: Optional, file format options
        :return: None if the result is stored to disk, or a bytes object returned by the backend.
        """
        if self.result_node().process_id == "save_result":
            # There is already a `save_result` node: check if it is consistent with given format/options
            args = self.result_node().arguments
        if format is None and outputfile is not None:
            format = guess_format(outputfile)
        cube = self._ensure_save_result(format=format, options=options)
        return self._connection.download(cube.flat_graph(), outputfile)

    def _ensure_save_result(
        self, format: Optional[str], options: Optional[dict]
    ) -> "DataCube":
        """
        Make sure there is a (final) `save_result` node in the process graph.
        If there is already one: check that it is consistent with the given format/options
        (raise a ValueError otherwise); if there is none: add a new `save_result` node.

        :param format: (optional) desired `save_result` file format
        :param options: (optional) desired `save_result` file format parameters
        :return: a DataCube ending in a `save_result` node
        """
        result_node = self.result_node()
        if result_node.process_id == "save_result":
            # There is already a `save_result` node:
            # check if it is consistent with given format/options (if any)
            args = result_node.arguments
            if format is not None and format.lower() != args["format"].lower():
                raise ValueError(
                    f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
@@ -1838,11 +1865,10 @@ def download(
            cube = self
        else:
            # No `save_result` node yet: automatically add it.
            if not format:
                format = guess_format(outputfile) if outputfile else "GTiff"
            cube = self.save_result(format=format, options=options)

        return self._connection.download(cube.flat_graph(), outputfile)
            cube = self.save_result(
                format=format or DEFAULT_RASTER_FORMAT, options=options
            )
        return cube
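
In effect, `download` now funnels through `_ensure_save_result`; a sketch of the resulting behavior (collection id and file names are placeholders):

    cube = connection.load_collection("S2")
    cube.download("out.nc")                                      # no save_result yet: format guessed from the ".nc" extension
    cube.save_result(format="GTiff").download("out.tiff")        # existing node consistent with guessed format: reused as-is
    cube.save_result(format="GTiff").download(format="netCDF")   # conflicting format: raises ValueError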

    def validate(self) -> List[dict]:
        """
@@ -1856,35 +1882,74 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
        return self._connection.create_service(self.flat_graph(), type=type, **kwargs)

    def execute_batch(
            self,
            outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
            print=print, max_poll_interval=60, connection_retry_interval=30,
            job_options=None, **format_options) -> BatchJob:
        self,
        outputfile: Optional[Union[str, pathlib.Path]] = None,
        out_format: Optional[str] = None,
        *,
        format: Optional[str] = None,
        print: typing.Callable[[str], None] = print,
        max_poll_interval: float = 60,
        connection_retry_interval: float = 30,
        job_options: Optional[dict] = None,
        format_options: Optional[dict] = None,
        **kwargs,
    ) -> BatchJob:
"""
Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
For very long-running jobs, you probably do not want to keep the client running.
:param job_options:
:param outputfile: The path of a file to which a result can be written
:param out_format: (optional) Format of the job result.
:param out_format: (optional) File format to use for the job result.
This argument is deprecated: use ``format`` instead
:param format: (optional) File format to use for the job result.
:param job_options:
:param format_options: String Parameters for the job result format
.. versionchanged:: 0.16.0 deprecate argument ``out_format`` in favor of new argument ``format``
"""
if "format" in format_options and not out_format:
out_format = format_options["format"] # align with 'download' call arg name
if not out_format:
out_format = guess_format(outputfile) if outputfile else "GTiff"
job = self.create_job(out_format, job_options=job_options, **format_options)
if out_format:
warnings.warn(
"Argument `out_format` in `execute_batch` is deprecated since 0.16.0, use `format` instead.",
category=UserDeprecationWarning,
)
assert format is None
format = out_format

if not format and outputfile:
format = guess_format(outputfile)

if kwargs:
warnings.warn(
"Deprecated usage of keyword arguments in `execute_batch()` to set file format options. "
"Instead, use the `format_options` argument or an explicit `save_result()` method call "
"to properly set file format options.",
category=UserDeprecationWarning,
)
assert format_options is None
format_options = kwargs

job = self.create_job(
format=format, job_options=job_options, format_options=format_options
)
return job.run_synchronous(
outputfile=outputfile,
print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval
)
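
A sketch of the migration this enables (the cube and file names are placeholders; "tiled" is a hypothetical format option):

    cube.execute_batch("result.tiff")                    # format guessed from the filename
    cube.execute_batch(format="netCDF")                  # new-style explicit format
    cube.execute_batch(out_format="netCDF")              # still works, but emits a UserDeprecationWarning
    cube.execute_batch(format="GTiff", format_options={"tiled": True})  # options as an explicit dict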

    def create_job(
            self, out_format=None, title: str = None, description: str = None, plan: str = None, budget=None,
            job_options=None, **format_options
        self,
        out_format: Optional[str] = None,
        *,
        format: Optional[str] = None,
        title: Optional[str] = None,
        description: Optional[str] = None,
        plan: Optional[str] = None,
        budget: Optional[float] = None,
        job_options: Optional[dict] = None,
        format_options: Optional[dict] = None,
        **kwargs,
    ) -> BatchJob:
"""
Sends the datacube's process graph as a batch job to the back-end
@@ -1894,18 +1959,37 @@ def create_job(
        it still needs to be started and tracked explicitly.
        Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.

        :param out_format: String Format of the job result.
        :param job_options: A dictionary containing (custom) job options
        :param format_options: String Parameters for the job result format
        :return: status: Job resulting job.
        :param out_format: (optional) output file format.
        :param job_options: (optional) custom job options.
        :param format_options: (optional) output file format parameters.

        .. versionchanged:: 0.16.0 deprecate argument ``out_format`` in favor of new argument ``format``
        """
        # TODO: add option to also automatically start the job?
        img = self
        if out_format:
            # add `save_result` node
            img = img.save_result(format=out_format, options=format_options)
            warnings.warn(
                "Argument `out_format` in `create_job` is deprecated since 0.16.0, use `format` instead.",
                category=UserDeprecationWarning,
            )
            assert format is None
            format = out_format

        if kwargs:
            warnings.warn(
                "Deprecated usage of keyword arguments in `create_job()` to set file format options. "
                "Instead, use the `format_options` argument or an explicit `save_result()` method call "
                "to properly set file format options.",
                category=UserDeprecationWarning,
            )
            # Legacy usage pattern of setting format options through keyword arguments
            assert format_options is None
            format_options = kwargs

        cube = self
        cube = cube._ensure_save_result(format=format, options=format_options)
        return self._connection.create_job(
            process_graph=img.flat_graph(),
            process_graph=cube.flat_graph(),
            title=title, description=description, plan=plan, budget=budget, additional=job_options
        )
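
With `_ensure_save_result` in place, `create_job` no longer blindly appends a second `save_result` node; a sketch (the title is illustrative, and starting the job stays explicit):

    job = cube.create_job(format="GTiff", title="my job")   # save_result node added if missing
    job = cube.save_result(format="GTiff").create_job()     # existing save_result node is reused
    job.start_and_wait()                                    # create_job only creates the job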

2 changes: 1 addition & 1 deletion openeo/util.py
@@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
        raise ValueError("No keys given")


def guess_format(filename: Union[str, Path]):
def guess_format(filename: Union[str, Path]) -> str:
    """
    Guess the output format from a given filename and return the corrected format.
    Any names not in the dict get passed through.
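
The body of guess_format lies outside this hunk; a rough sketch of the extension-based idea (the actual mapping in openeo/util.py may differ):

    def guess_format(filename: Union[str, Path]) -> str:
        # Sketch: map common file extensions to canonical openEO format names.
        extension = str(filename).rsplit(".", 1)[-1].lower()
        known = {"tif": "GTiff", "tiff": "GTiff", "gtiff": "GTiff", "nc": "netCDF", "netcdf": "netCDF"}
        return known.get(extension, extension)  # unknown names pass through unchanged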
79 changes: 78 additions & 1 deletion tests/rest/datacube/test_datacube.py
@@ -4,16 +4,18 @@
- 1.0.0-style DataCube
"""

import functools
from datetime import date, datetime
import pathlib

import mock
import numpy as np
import pytest
import shapely
import shapely.geometry

from openeo.capabilities import ComparableVersion
from openeo.internal.warnings import UserDeprecationWarning
from openeo.rest import BandMathException
from openeo.rest.datacube import DataCube
from .conftest import API_URL
@@ -446,3 +448,78 @@ def result_callback(request, context):
    requests_mock.post(API_URL + '/result', content=result_callback)
    result = connection.load_collection("S2").download(format=format)
    assert result == b"data"


class TestExecuteBatch:
    @pytest.fixture
    def get_create_job_pg(self, connection):
        """Fixture to help intercepting the process graph that was passed to Connection.create_job"""
        with mock.patch.object(connection, "create_job") as create_job:

            def get() -> dict:
                assert create_job.call_count == 1
                return create_job.call_args.kwargs["process_graph"]

            yield get

    def test_basic(self, connection, s2cube, get_create_job_pg, recwarn, caplog):
        s2cube.execute_batch()
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": "GTiff",
                "options": {},
            },
            "result": True,
        }
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["format", "expected"],
        [(None, "GTiff"), ("GTiff", "GTiff"), ("gtiff", "gtiff"), ("NetCDF", "NetCDF")],
    )
    def test_format(
        self, connection, s2cube, get_create_job_pg, format, expected, recwarn, caplog
    ):
        s2cube.execute_batch(format=format)
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": expected,
                "options": {},
            },
            "result": True,
        }
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["out_format", "expected"],
        [("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
    )
    def test_out_format(
        self, connection, s2cube, get_create_job_pg, out_format, expected
    ):
        with pytest.warns(
            UserDeprecationWarning,
            match="`out_format`.*is deprecated.*use `format` instead",
        ):
            s2cube.execute_batch(out_format=out_format)
        pg = get_create_job_pg()
        assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
        assert pg["saveresult1"] == {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": expected,
                "options": {},
            },
            "result": True,
        }
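
A natural companion case, not part of this commit's diff but a sketch reusing the same fixtures (and assuming the mocked connection advertises both formats), would assert the consistency check when a `save_result` node already exists:

    def test_existing_save_result_format_conflict(self, connection, s2cube):
        cube = s2cube.save_result(format="GTiff")
        with pytest.raises(ValueError, match="different format"):
            cube.execute_batch(format="netCDF")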
