
Squashed commit of PR #338
squashed PR #338 (by clausmichele)
fixed merge conflict
and did black/darker cleanups
soxofaan committed Feb 3, 2023
1 parent d30e731 commit c0ebd3d
Showing 8 changed files with 480 additions and 0 deletions.
1 change: 1 addition & 0 deletions openeo/local/__init__.py
@@ -0,0 +1 @@
from openeo.local.connection import LocalConnection
271 changes: 271 additions & 0 deletions openeo/local/collections.py
@@ -0,0 +1,271 @@
import logging
from pathlib import Path
from typing import List

import numpy as np
import rioxarray
import xarray as xr
from pyproj import Transformer

_log = logging.getLogger(__name__)


def _get_dimension(dims: dict, candidates: List[str]) -> str:
    for name in candidates:
        if name in dims:
            return name
    error = (
        f"No dimension matching one of the candidates {candidates} was found! "
        f"Available dimensions: {dims}. Please rename the dimension accordingly and try again. "
        f"This local collection will be skipped."
    )
    raise Exception(error)
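# A hypothetical example of the lookup above: for a dataset with dims {"time": 12, "x": 100, "y": 100},
# _get_dimension(data.dims, ["t", "time", "temporal", "DATE"]) returns "time"; if no candidate
# matches, the raised exception causes the collection to be skipped by the scanner functions below.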


def _get_netcdf_zarr_metadata(file_path):
if ".zarr" in file_path.suffixes:
data = xr.open_dataset(file_path.as_posix(), chunks={}, engine="zarr")
else:
        data = xr.open_dataset(
            file_path.as_posix(), chunks={}
        )  # Add decode_coords="all" if the CRS exposed as a band causes issues
file_path = file_path.as_posix()
try:
t_dim = _get_dimension(data.dims, ["t", "time", "temporal", "DATE"])
    except Exception:
t_dim = None
try:
x_dim = _get_dimension(data.dims, ["x", "X", "lon", "longitude"])
y_dim = _get_dimension(data.dims, ["y", "Y", "lat", "latitude"])
except Exception as e:
        _log.warning(e)
raise Exception(f"Error creating metadata for {file_path}") from e
metadata = {}
metadata["stac_version"] = "1.0.0-rc.2"
metadata["type"] = "Collection"
metadata["id"] = file_path
data_attrs_lowercase = [x.lower() for x in data.attrs]
data_attrs_original = [x for x in data.attrs]
data_attrs = dict(zip(data_attrs_lowercase, data_attrs_original))
if "title" in data_attrs_lowercase:
metadata["title"] = data.attrs[data_attrs["title"]]
else:
metadata["title"] = file_path
if "description" in data_attrs_lowercase:
metadata["description"] = data.attrs[data_attrs["description"]]
else:
metadata["description"] = ""
if "license" in data_attrs_lowercase:
metadata["license"] = data.attrs[data_attrs["license"]]
else:
metadata["license"] = ""
providers = [{"name": "", "roles": ["producer"], "url": ""}]
if "providers" in data_attrs_lowercase:
providers[0]["name"] = data.attrs[data_attrs["providers"]]
metadata["providers"] = providers
elif "institution" in data_attrs_lowercase:
providers[0]["name"] = data.attrs[data_attrs["institution"]]
metadata["providers"] = providers
else:
metadata["providers"] = providers
if "links" in data_attrs_lowercase:
metadata["links"] = data.attrs[data_attrs["links"]]
else:
metadata["links"] = ""
x_min = data[x_dim].min().item(0)
x_max = data[x_dim].max().item(0)
y_min = data[y_dim].min().item(0)
y_max = data[y_dim].max().item(0)

crs_present = False
bands = list(data.data_vars)
if "crs" in bands:
bands.remove("crs")
crs_present = True
extent = {}
if crs_present:
if "crs_wkt" in data.crs.attrs:
transformer = Transformer.from_crs(data.crs.attrs["crs_wkt"], "epsg:4326")
lat_min, lon_min = transformer.transform(x_min, y_min)
lat_max, lon_max = transformer.transform(x_max, y_max)
extent["spatial"] = {"bbox": [[lon_min, lat_min, lon_max, lat_max]]}

if t_dim is not None:
t_min = str(data[t_dim].min().values)
t_max = str(data[t_dim].max().values)
extent["temporal"] = {"interval": [[t_min, t_max]]}

metadata["extent"] = extent

t_dimension = {}
if t_dim is not None:
t_dimension = {t_dim: {"type": "temporal", "extent": [t_min, t_max]}}

x_dimension = {x_dim: {"type": "spatial", "axis": "x", "extent": [x_min, x_max]}}
y_dimension = {y_dim: {"type": "spatial", "axis": "y", "extent": [y_min, y_max]}}
if crs_present:
if "crs_wkt" in data.crs.attrs:
x_dimension[x_dim]["reference_system"] = data.crs.attrs["crs_wkt"]
y_dimension[y_dim]["reference_system"] = data.crs.attrs["crs_wkt"]

b_dimension = {}
if len(bands) > 0:
b_dimension = {"bands": {"type": "bands", "values": bands}}

metadata["cube:dimensions"] = {
**t_dimension,
**x_dimension,
**y_dimension,
**b_dimension,
}

return metadata
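# For a hypothetical file /data/sample.nc, the metadata produced above looks roughly like:
# {
#     "stac_version": "1.0.0-rc.2",
#     "type": "Collection",
#     "id": "/data/sample.nc",
#     "extent": {"spatial": {"bbox": [[...]]}, "temporal": {"interval": [[...]]}},
#     "cube:dimensions": {"time": {...}, "x": {...}, "y": {...}, "bands": {...}},
# }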


def _get_geotiff_metadata(file_path):
data = rioxarray.open_rasterio(file_path.as_posix(), chunks={})
file_path = file_path.as_posix()
try:
t_dim = _get_dimension(data.dims, ["t", "time", "temporal", "DATE"])
    except Exception:
t_dim = None
try:
x_dim = _get_dimension(data.dims, ["x", "X", "lon", "longitude"])
y_dim = _get_dimension(data.dims, ["y", "Y", "lat", "latitude"])
except Exception as e:
        _log.warning(e)
raise Exception(f"Error creating metadata for {file_path}") from e

metadata = {}
metadata["stac_version"] = "1.0.0-rc.2"
metadata["type"] = "Collection"
metadata["id"] = file_path
data_attrs_lowercase = [x.lower() for x in data.attrs]
data_attrs_original = [x for x in data.attrs]
data_attrs = dict(zip(data_attrs_lowercase, data_attrs_original))
if "title" in data_attrs_lowercase:
metadata["title"] = data.attrs[data_attrs["title"]]
else:
metadata["title"] = file_path
if "description" in data_attrs_lowercase:
metadata["description"] = data.attrs[data_attrs["description"]]
else:
metadata["description"] = ""
if "license" in data_attrs_lowercase:
metadata["license"] = data.attrs[data_attrs["license"]]
else:
metadata["license"] = ""
providers = [{"name": "", "roles": ["producer"], "url": ""}]
if "providers" in data_attrs_lowercase:
providers[0]["name"] = data.attrs[data_attrs["providers"]]
metadata["providers"] = providers
elif "institution" in data_attrs_lowercase:
providers[0]["name"] = data.attrs[data_attrs["institution"]]
metadata["providers"] = providers
else:
metadata["providers"] = providers
if "links" in data_attrs_lowercase:
metadata["links"] = data.attrs[data_attrs["links"]]
else:
metadata["links"] = ""
x_min = data[x_dim].min().item(0)
x_max = data[x_dim].max().item(0)
y_min = data[y_dim].min().item(0)
y_max = data[y_dim].max().item(0)

crs_present = False
coords = list(data.coords)
if "spatial_ref" in coords:
# bands.remove('crs')
crs_present = True
# TODO: list bands if more available
bands = []
if "band" in coords:
bands = list(data["band"].values)
if len(bands) > 0:
            # The JSON encoder does not handle numpy integer types, so convert them in advance
            if isinstance(bands[0], (np.int8, np.int16, np.int32, np.int64)):
                bands = [int(b) for b in bands]
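                # For example, json.dumps(np.int64(1)) raises TypeError ("Object of type
                # int64 is not JSON serializable"), while json.dumps(int(np.int64(1))) serializes fine.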
extent = {}
if crs_present:
if "crs_wkt" in data.spatial_ref.attrs:
transformer = Transformer.from_crs(
data.spatial_ref.attrs["crs_wkt"], "epsg:4326"
)
lat_min, lon_min = transformer.transform(x_min, y_min)
lat_max, lon_max = transformer.transform(x_max, y_max)
extent["spatial"] = {"bbox": [[lon_min, lat_min, lon_max, lat_max]]}

if t_dim is not None:
t_min = str(data[t_dim].min().values)
t_max = str(data[t_dim].max().values)
extent["temporal"] = {"interval": [[t_min, t_max]]}

metadata["extent"] = extent

t_dimension = {}
if t_dim is not None:
t_dimension = {t_dim: {"type": "temporal", "extent": [t_min, t_max]}}

x_dimension = {x_dim: {"type": "spatial", "axis": "x", "extent": [x_min, x_max]}}
y_dimension = {y_dim: {"type": "spatial", "axis": "y", "extent": [y_min, y_max]}}
if crs_present:
if "crs_wkt" in data.spatial_ref.attrs:
x_dimension[x_dim]["reference_system"] = data.spatial_ref.attrs["crs_wkt"]
y_dimension[y_dim]["reference_system"] = data.spatial_ref.attrs["crs_wkt"]

b_dimension = {}
if len(bands) > 0:
b_dimension = {"bands": {"type": "bands", "values": bands}}

metadata["cube:dimensions"] = {
**t_dimension,
**x_dimension,
**y_dimension,
**b_dimension,
}

return metadata


def _get_netcdf_zarr_collections(local_collections_path):
if isinstance(local_collections_path, str):
local_collections_path = [local_collections_path]
local_collections_list = []
for flds in local_collections_path:
local_collections_netcdf_zarr = [
p for p in Path(flds).rglob("*") if p.suffix in [".nc", ".zarr"]
]
for local_file in local_collections_netcdf_zarr:
try:
metadata = _get_netcdf_zarr_metadata(local_file)
local_collections_list.append(metadata)
except Exception as e:
_log.error(e)
continue

local_collections_dict = {"collections": local_collections_list}

return local_collections_dict
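# Sketch of the returned structure (file paths hypothetical):
# {"collections": [{"id": "/data/sample.nc", "type": "Collection", ...}, ...]}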


def _get_geotiff_collections(local_collections_path):
if isinstance(local_collections_path, str):
local_collections_path = [local_collections_path]
local_collections_list = []
for flds in local_collections_path:
local_collections_geotiffs = [
p for p in Path(flds).rglob("*") if p.suffix in [".tif", ".tiff"]
]
for local_file in local_collections_geotiffs:
try:
metadata = _get_geotiff_metadata(local_file)
local_collections_list.append(metadata)
except Exception as e:
_log.error(e)
continue

local_collections_dict = {"collections": local_collections_list}

return local_collections_dict
113 changes: 113 additions & 0 deletions openeo/local/connection.py
@@ -0,0 +1,113 @@
import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import xarray as xr
from openeo_pg_parser_networkx.graph import OpenEOProcessGraph

from openeo.internal.graph_building import PGNode, as_flat_graph
from openeo.internal.jupyter import VisualDict, VisualList
from openeo.local.collections import (
_get_geotiff_collections,
_get_geotiff_metadata,
_get_netcdf_zarr_collections,
_get_netcdf_zarr_metadata,
)
from openeo.local.processing import PROCESS_REGISTRY
from openeo.metadata import CollectionMetadata
from openeo.rest.datacube import DataCube


class LocalConnection:
"""
Connection to no backend, for local processing.
"""

def __init__(self, local_collections_path: Union[str, List]):
"""
        Constructor of LocalConnection.

        :param local_collections_path: String or list of strings: path(s) to the folder(s)
            with the local collections in netCDF, GeoTIFF or Zarr format.
"""
self.local_collections_path = local_collections_path

def list_collections(self) -> List[dict]:
"""
        List basic metadata of all collections provided in the local collections folder.

        :return: list of dictionaries with basic collection metadata.
"""
data_nc = _get_netcdf_zarr_collections(self.local_collections_path)[
"collections"
]
data_tif = _get_geotiff_collections(self.local_collections_path)["collections"]
data = data_nc + data_tif
return VisualList("collections", data=data)

def describe_collection(self, collection_id: str) -> dict:
"""
        Get full collection metadata for given collection id.

        .. seealso::
            :py:meth:`~openeo.rest.connection.Connection.list_collection_ids`
            to list all collection ids provided by the back-end.

        :param collection_id: collection id
        :return: collection metadata.
"""
local_collection = Path(collection_id)
if ".nc" in local_collection.suffixes or ".zarr" in local_collection.suffixes:
data = _get_netcdf_zarr_metadata(local_collection)
elif (
".tif" in local_collection.suffixes or ".tiff" in local_collection.suffixes
):
data = _get_geotiff_metadata(local_collection)
return VisualDict("collection", data=data)

def collection_metadata(self, name) -> CollectionMetadata:
# TODO: duplication with `Connection.describe_collection`: deprecate one or the other?
return CollectionMetadata(metadata=self.describe_collection(name))

def load_collection(
self,
collection_id: str,
spatial_extent: Optional[Dict[str, float]] = None,
temporal_extent: Optional[
List[Union[str, datetime.datetime, datetime.date]]
] = None,
bands: Optional[List[str]] = None,
properties: Optional[Dict[str, Union[str, PGNode, Callable]]] = None,
fetch_metadata=True,
) -> DataCube:
"""
        Load a DataCube by collection id.

        :param collection_id: image collection identifier
:param spatial_extent: limit data to specified bounding box or polygons
:param temporal_extent: limit data to specified temporal interval
:param bands: only add the specified bands
:param properties: limit data by metadata property predicates
:return: a datacube containing the requested data
"""
return DataCube.load_collection(
collection_id=collection_id,
connection=self,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=properties,
fetch_metadata=fetch_metadata,
)

def execute(self, process_graph: Union[dict, str, Path]) -> xr.DataArray:
"""
        Execute the process graph locally and return the result as an xarray.DataArray.

        :param process_graph: (flat) dict representing a process graph, a process graph as raw
            JSON string, or a path to a local process graph file.
        :return: the requested data as an xarray.DataArray
"""
process_graph = as_flat_graph(process_graph)
return OpenEOProcessGraph(process_graph).to_callable(PROCESS_REGISTRY)()
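For context, a minimal usage sketch of the local processing API introduced by this PR (the folder and file paths are hypothetical, not part of this commit):

from openeo.local import LocalConnection

local_conn = LocalConnection("./local_collections")  # hypothetical folder with .nc/.tif/.zarr files
local_conn.list_collections()  # scans the folder and returns basic STAC-like metadata

cube = local_conn.load_collection("./local_collections/sample.nc")  # hypothetical file
result = cube.execute()  # builds a process graph and executes it locally via PROCESS_REGISTRY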
