Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve automatic adding of save_result #403

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
148 changes: 116 additions & 32 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from openeo.udf import XarrayDataCube


# Fallback output file format: default `format` of `save_result`, and the format used
# when a `save_result` node has to be added automatically without an explicit format.
DEFAULT_RASTER_FORMAT = "GTiff"

# Logger named after this module.
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -1797,22 +1799,28 @@ def atmospheric_correction(
})

@openeo_process
def save_result(
    self, format: str = DEFAULT_RASTER_FORMAT, options: Optional[dict] = None
) -> "DataCube":
    """
    Append a `save_result` process node to this cube's process graph.

    :param format: output file format; compared case-insensitively against
        the output formats listed by the connection.
    :param options: (optional) file format options; an empty dict is sent when unset.
    :return: the cube returned by ``self.process`` for the new `save_result` node.
    :raises ValueError: if ``format`` matches none of the listed output formats.
    """
    supported = set(self._connection.list_output_formats().keys())
    lowercased = {name.lower() for name in supported}
    # TODO: map format to correct casing too?
    if format.lower() not in lowercased:
        raise ValueError(
            "Invalid format {f!r}. Should be one of {s}".format(f=format, s=supported)
        )

    arguments = {
        "data": THIS,
        "format": format,
        # TODO: leave out options if unset?
        "options": options or {},
    }
    return self.process(process_id="save_result", arguments=arguments)

def download(
self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
options: Optional[dict] = None
self,
outputfile: Optional[Union[str, pathlib.Path]] = None,
format: Optional[str] = None,
options: Optional[dict] = None,
):
"""
Download image collection, e.g. as GeoTIFF.
Expand All @@ -1824,9 +1832,28 @@ def download(
:param options: Optional, file format options
:return: None if the result is stored to disk, or a bytes object returned by the backend.
"""
if self.result_node().process_id == "save_result":
# There is already a `save_result` node: check if it is consistent with given format/options
args = self.result_node().arguments
if format is None and outputfile is not None:
format = guess_format(outputfile)
cube = self._ensure_save_result(format=format, options=options)
return self._connection.download(cube.flat_graph(), outputfile)

def _ensure_save_result(
self, format: Optional[str], options: Optional[dict]
) -> "DataCube":
"""
Make sure there is a (final) `save_result` node in the process graph.
If there is already one: check if it is consistent with the given format/options
and add a new one otherwise.

:param format: (optional) desired `save_result` file format
:param options: (optional) desired `save_result` file format parameters
:return:
"""
result_node = self.result_node()
if result_node.process_id == "save_result":
# There is already a `save_result` node:
# check if it is consistent with given format/options (if any)
args = result_node.arguments
if format is not None and format.lower() != args["format"].lower():
raise ValueError(
f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
Expand All @@ -1838,11 +1865,10 @@ def download(
cube = self
else:
# No `save_result` node yet: automatically add it.
if not format:
format = guess_format(outputfile) if outputfile else "GTiff"
cube = self.save_result(format=format, options=options)

return self._connection.download(cube.flat_graph(), outputfile)
cube = self.save_result(
format=format or DEFAULT_RASTER_FORMAT, options=options
)
return cube

def validate(self) -> List[dict]:
"""
Expand All @@ -1856,35 +1882,74 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
return self._connection.create_service(self.flat_graph(), type=type, **kwargs)

def execute_batch(
    self,
    outputfile: Optional[Union[str, pathlib.Path]] = None,
    out_format: Optional[str] = None,
    *,
    format: Optional[str] = None,
    print: typing.Callable[[str], None] = print,
    max_poll_interval: float = 60,
    connection_retry_interval: float = 30,
    job_options: Optional[dict] = None,
    format_options: Optional[dict] = None,
    **kwargs,
) -> BatchJob:
    """
    Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
    This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.

    For very long-running jobs, you probably do not want to keep the client running.

    :param outputfile: (optional) path of a file to which the result can be written.
        When no format is given explicitly, the format is guessed from this file name.
    :param out_format: (optional) File format to use for the job result.
        This argument is deprecated: use ``format`` instead
    :param format: (optional) File format to use for the job result.
    :param print: callable passed on to ``run_synchronous`` of the created job.
    :param max_poll_interval: passed on to ``run_synchronous`` of the created job.
    :param connection_retry_interval: passed on to ``run_synchronous`` of the created job.
    :param job_options: (optional) custom job options, passed on to :py:meth:`create_job`.
    :param format_options: (optional) file format parameters, passed on to :py:meth:`create_job`.
    :param kwargs: ``title``, ``description``, ``plan`` and ``budget`` are forwarded
        to :py:meth:`create_job`. Any other keyword argument is (deprecated usage)
        interpreted as a file format option.
    :return: the result of ``run_synchronous`` on the created job.
    :raises ValueError: if both ``out_format`` and ``format`` are given,
        or if file format options are given both through ``format_options``
        and through additional keyword arguments.

    .. versionchanged:: 0.16.0 deprecate argument ``out_format`` in favor of new argument ``format``
    """
    if out_format:
        # Explicit check instead of `assert`: asserts are stripped when running with `-O`,
        # which would let `out_format` silently override `format`.
        if format is not None:
            raise ValueError(
                "Arguments `out_format` and `format` can not both be set in `execute_batch`: use `format` only."
            )
        warnings.warn(
            "Argument `out_format` in `execute_batch` is deprecated since 0.16.0, use `format` instead.",
            category=UserDeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller of `execute_batch`
        )
        format = out_format

    if not format and outputfile:
        format = guess_format(outputfile)

    # Job metadata keywords were always forwarded to `create_job` (they never were file format options):
    # keep doing that instead of mistaking them for legacy format options below.
    job_metadata = {
        key: kwargs.pop(key)
        for key in ("title", "description", "plan", "budget")
        if key in kwargs
    }

    if kwargs:
        # Legacy usage pattern of setting format options through keyword arguments
        if format_options is not None:
            raise ValueError(
                "File format options can not be set through both `format_options` "
                "and additional keyword arguments in `execute_batch()`."
            )
        warnings.warn(
            "Deprecated usage of keyword arguments in `execute_batch()` to set file format options. "
            "Instead, use the `format_options` argument or an explicit `save_result()` method call "
            "to properly set file format options.",
            category=UserDeprecationWarning,
            stacklevel=2,
        )
        format_options = kwargs

    job = self.create_job(
        format=format,
        job_options=job_options,
        format_options=format_options,
        **job_metadata,
    )
    return job.run_synchronous(
        outputfile=outputfile,
        print=print,
        max_poll_interval=max_poll_interval,
        connection_retry_interval=connection_retry_interval,
    )

def create_job(
self, out_format=None, title: str = None, description: str = None, plan: str = None, budget=None,
job_options=None, **format_options
self,
out_format: Optional[str] = None,
*,
format: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
plan: Optional[str] = None,
budget: Optional[float] = None,
job_options: Optional[dict] = None,
format_options: Optional[dict] = None,
**kwargs,
) -> BatchJob:
"""
Sends the datacube's process graph as a batch job to the back-end
Expand All @@ -1894,18 +1959,37 @@ def create_job(
it still needs to be started and tracked explicitly.
Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.

:param out_format: String Format of the job result.
:param job_options: A dictionary containing (custom) job options
:param format_options: String Parameters for the job result format
:return: status: Job resulting job.
:param out_format: (optional) output file format.
:param job_options: (optional) custom job options.
:param format_options: (optional) output file format parameters.

.. versionchanged:: 0.16.0 deprecate argument ``out_format`` in favor of new argument ``format``

"""
# TODO: add option to also automatically start the job?
img = self
if out_format:
# add `save_result` node
img = img.save_result(format=out_format, options=format_options)
warnings.warn(
"Argument `out_format` in `create_job` is deprecated since 0.16.0, use `format` instead.",
category=UserDeprecationWarning,
)
assert format is None
format = out_format

if kwargs:
warnings.warn(
"Deprecated usage of keyword arguments in `create_job()` to set file format options. "
"Instead, use the `format_options` argument or an explicit `save_result()` method call "
"to properly set file format options.",
category=UserDeprecationWarning,
)
# Legacy usage pattern of setting format options through keyword arguments
assert format_options is None
format_options = kwargs

cube = self
cube = cube._ensure_save_result(format=format, options=format_options)
return self._connection.create_job(
process_graph=img.flat_graph(),
process_graph=cube.flat_graph(),
title=title, description=description, plan=plan, budget=budget, additional=job_options
)

Expand Down
2 changes: 2 additions & 0 deletions openeo/rest/mlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class MlModel(_ProcessGraphAbstraction):

.. versionadded:: 0.10.0
"""

# TODO
def __init__(self, graph: PGNode, connection: 'Connection'):
super().__init__(pgnode=graph, connection=connection)

Expand Down
5 changes: 5 additions & 0 deletions openeo/rest/vectorcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def run_udf(

@openeo_process
def save_result(self, format: str = "GeoJson", options: dict = None):
# TODO?
# TODO: check format against supported formats
# TODO: should not return a VectorCube again, but bool wrapper
# TODO: should save_result also work on non-cube data types, e.g. arrays, scalars?
return self.process(
process_id="save_result",
arguments={
Expand All @@ -105,6 +109,7 @@ def execute(self) -> dict:

def download(self, outputfile: str, format: str = "GeoJSON", options: dict = None):
    """
    Add a `save_result` node with the given format/options and download the result.

    :param outputfile: passed on as target to the connection's ``download``.
    :param format: file format for the `save_result` node.
    :param options: (optional) file format options for the `save_result` node.
    :return: whatever the connection's ``download`` returns.
    """
    # TODO: only add save_result, when not already present (see DataCube.download)
    with_save_result = self.save_result(format=format, options=options)
    process_graph = with_save_result.flat_graph()
    return self._connection.download(process_graph, outputfile)

Expand Down
2 changes: 1 addition & 1 deletion openeo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
raise ValueError("No keys given")


def guess_format(filename: Union[str, Path]):
def guess_format(filename: Union[str, Path]) -> str:
"""
Guess the output format from a given filename and return the corrected format.
Any names not in the dict get passed through.
Expand Down
79 changes: 78 additions & 1 deletion tests/rest/datacube/test_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
- 1.0.0-style DataCube

"""

import functools
from datetime import date, datetime
import pathlib

import mock
import numpy as np
import pytest
import shapely
import shapely.geometry

from openeo.capabilities import ComparableVersion
from openeo.internal.warnings import UserDeprecationWarning
from openeo.rest import BandMathException
from openeo.rest.datacube import DataCube
from .conftest import API_URL
Expand Down Expand Up @@ -446,3 +448,78 @@ def result_callback(request, context):
requests_mock.post(API_URL + '/result', content=result_callback)
result = connection.load_collection("S2").download(format=format)
assert result == b"data"


class TestExecuteBatch:
    """Checks on the `save_result` node in the process graph that `execute_batch` submits."""

    @staticmethod
    def _expected_save_result(format: str) -> dict:
        """Build the `save_result` node expected directly after `loadcollection1`."""
        return {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": format,
                "options": {},
            },
            "result": True,
        }

    @pytest.fixture
    def get_create_job_pg(self, connection):
        """Fixture to help intercept the process graph that was passed to Connection.create_job"""
        with mock.patch.object(connection, "create_job") as create_job_mock:

            def fetch_process_graph() -> dict:
                # Exactly one job must have been created by the call under test.
                assert create_job_mock.call_count == 1
                return create_job_mock.call_args.kwargs["process_graph"]

            yield fetch_process_graph

    def test_basic(self, connection, s2cube, get_create_job_pg, recwarn, caplog):
        s2cube.execute_batch()
        process_graph = get_create_job_pg()
        assert set(process_graph.keys()) == {"loadcollection1", "saveresult1"}
        assert process_graph["saveresult1"] == self._expected_save_result("GTiff")
        # Plain usage must not trigger warnings or log records.
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["format", "expected"],
        [(None, "GTiff"), ("GTiff", "GTiff"), ("gtiff", "gtiff"), ("NetCDF", "NetCDF")],
    )
    def test_format(
        self, connection, s2cube, get_create_job_pg, format, expected, recwarn, caplog
    ):
        s2cube.execute_batch(format=format)
        process_graph = get_create_job_pg()
        assert set(process_graph.keys()) == {"loadcollection1", "saveresult1"}
        assert process_graph["saveresult1"] == self._expected_save_result(expected)
        # The new `format` argument must not trigger warnings or log records.
        assert recwarn.list == []
        assert caplog.records == []

    @pytest.mark.parametrize(
        ["out_format", "expected"],
        [("GTiff", "GTiff"), ("NetCDF", "NetCDF")],
    )
    def test_out_format(
        self, connection, s2cube, get_create_job_pg, out_format, expected
    ):
        # Legacy `out_format` still works, but has to warn about its deprecation.
        with pytest.warns(
            UserDeprecationWarning,
            match="`out_format`.*is deprecated.*use `format` instead",
        ):
            s2cube.execute_batch(out_format=out_format)
        process_graph = get_create_job_pg()
        assert set(process_graph.keys()) == {"loadcollection1", "saveresult1"}
        assert process_graph["saveresult1"] == self._expected_save_result(expected)