Skip to content

Commit

Permalink
Issue #401/#449 support format guessing in VectorCube.download
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Jul 17, 2023
1 parent 8cbaf25 commit 4971005
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add support in `VectoCube.download()` to guess output format from extension of a given filename
([#401](https://github.com/Open-EO/openeo-python-client/issues/401), [#449](https://github.com/Open-EO/openeo-python-client/issues/449))

### Changed

### Removed
Expand Down
2 changes: 2 additions & 0 deletions openeo/rest/datacube.py
Expand Up @@ -1943,6 +1943,7 @@ def download(
:return: None if the result is stored to disk, or a bytes object returned by the backend.
"""
if format is None and outputfile is not None:
# TODO #401/#449 don't guess/override format if there is already a save_result with format?
format = guess_format(outputfile)
cube = self._ensure_save_result(format=format, options=options)
return self._connection.download(cube.flat_graph(), outputfile)
Expand Down Expand Up @@ -2062,6 +2063,7 @@ def execute_batch(
if "format" in format_options and not out_format:
out_format = format_options["format"] # align with 'download' call arg name
if not out_format and outputfile:
# TODO #401/#449 don't guess/override format if there is already a save_result with format?
out_format = guess_format(outputfile)

job = self.create_job(
Expand Down
6 changes: 4 additions & 2 deletions openeo/rest/vectorcube.py
Expand Up @@ -9,7 +9,7 @@
from openeo.rest._datacube import _ProcessGraphAbstraction, UDF
from openeo.rest.mlmodel import MlModel
from openeo.rest.job import BatchJob
from openeo.util import dict_no_none
from openeo.util import dict_no_none, guess_format

if typing.TYPE_CHECKING:
# Imports for type checking only (circular import issue at runtime).
Expand Down Expand Up @@ -137,8 +137,10 @@ def execute(self) -> dict:
return self._connection.execute(self.flat_graph())

def download(self, outputfile: Union[str, pathlib.Path], format: Optional[str] = None, options: dict = None):
# TODO #401 guess format from outputfile?
# TODO #401 make outputfile optional (See DataCube.download)
# TODO #401/#449 don't guess/override format if there is already a save_result with format?
if format is None and outputfile is not None:
format = guess_format(outputfile)
cube = self._ensure_save_result(format=format, options=options)
return self._connection.download(cube.flat_graph(), outputfile)

Expand Down
1 change: 1 addition & 0 deletions openeo/util.py
Expand Up @@ -451,6 +451,7 @@ def guess_format(filename: Union[str, Path]) -> str:
format_map = {
"gtiff": "GTiff", "geotiff": "GTiff", "geotif": "GTiff", "tiff": "GTiff", "tif": "GTiff",
"nc": "netCDF", "netcdf": "netCDF",
"geojson": "GeoJSON",
}

return format_map.get(extension, extension.upper())
Expand Down
26 changes: 14 additions & 12 deletions tests/rest/datacube/test_vectorcube.py
Expand Up @@ -85,9 +85,9 @@ def test_raster_to_vector(con100):
@pytest.mark.parametrize(
["filename", "expected_format"],
[
("result.json", "GeoJSON"), # TODO #401 possible to detect "GeoJSON from ".json" extension?
("result.json", "JSON"), # TODO possible to allow "GeoJSON" with ".json" extension?
("result.geojson", "GeoJSON"),
("result.nc", "GeoJSON"), # TODO #401 autodetect format from filename
("result.nc", "netCDF"),
],
)
@pytest.mark.parametrize("path_class", [str, Path])
Expand Down Expand Up @@ -118,11 +118,12 @@ def test_download_auto_save_result_only_file(
("result.json", "JSON", "JSON"),
("result.geojson", "GeoJSON", "GeoJSON"),
("result.nc", "netCDF", "netCDF"),
# TODO #401 more formats to autodetect?
("result.nc", "NETcDf", "NETcDf"), # TODO #401 normalize format
("result.nc", "inV6l1d!!!", "inV6l1d!!!"), # TODO #401 this should fail
("result.json", None, "GeoJSON"), # TODO #401 autodetect format from filename?
("result.nc", None, "GeoJSON"), # TODO #401 autodetect format from filename
("result.nc", "inV6l1d!!!", "inV6l1d!!!"), # TODO #401 this should fail?
("result.json", None, "JSON"),
("result.geojson", None, "GeoJSON"),
("result.nc", None, "netCDF"),
# TODO #449 more formats to autodetect?
],
)
def test_download_auto_save_result_with_format(vector_cube, download_spy, tmp_path, filename, format, expected_format):
Expand Down Expand Up @@ -164,17 +165,18 @@ def test_download_auto_save_result_with_options(vector_cube, download_spy, tmp_p


@pytest.mark.parametrize(
["format", "expected_format"],
["output_file", "format", "expected_format"],
[
(None, "GeoJSON"),
("JSON", "JSON"),
("netCDF", "netCDF"),
("result.geojson", None, "GeoJSON"),
("result.geojson", "GeoJSON", "GeoJSON"),
("result.json", "JSON", "JSON"),
("result.nc", "netCDF", "netCDF"),
],
)
def test_save_result_and_download(vector_cube, download_spy, tmp_path, format, expected_format):
def test_save_result_and_download(vector_cube, download_spy, tmp_path, output_file, format, expected_format):
"""e.g. https://github.com/Open-EO/openeo-geopyspark-driver/issues/477"""
vector_cube = vector_cube.save_result(format=format)
output_path = tmp_path / "result.json"
output_path = tmp_path / output_file
vector_cube.download(output_path)
assert download_spy.only_request == {
"createvectorcube1": {"process_id": "create_vector_cube", "arguments": {}},
Expand Down

0 comments on commit 4971005

Please sign in to comment.