diff --git a/algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/WAC_inference_africa.json b/algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/wac_inference.json
similarity index 62%
rename from algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/WAC_inference_africa.json
rename to algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/wac_inference.json
index 6fcb0e38..2845a40d 100644
--- a/algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/WAC_inference_africa.json
+++ b/algorithm_catalog/wur/worldagrocommodities/benchmark_scenarios/wac_inference.json
@@ -7,21 +7,22 @@
     "process_graph": {
       "parcel_delineation1": {
         "process_id": "wac_inference_africa",
-        "namespace": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/b43116632be5e3cd6150fbff5ae21cb135763654/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json",
+        "namespace": "https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/c2d74e55afbc76ddb8eca9679196d1eab282e98c/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json",
         "arguments": {
-          "spatial_extent": {
-            "west": 677736,
-            "south": 624010,
-            "east": 682736,
-            "north": 629010,
-            "crs": "EPSG:32630"
-          },
+          "spatial_extent": {
+            "west": 749000,
+            "south": -2378000,
+            "east": 750000,
+            "north": -2377000,
+            "crs": "EPSG:32638"
+          },
           "temporal_extent": [
             "2023-01-01",
             "2024-01-01"
           ],
-          "crs": "EPSG:32629"
+          "crs": "EPSG:32638",
+          "model_id": "WorldAgriCommodities_Africa_v1"
         },
         "result": true
diff --git a/algorithm_catalog/wur/worldagrocommodities/openeo_udp/README.md b/algorithm_catalog/wur/worldagrocommodities/openeo_udp/README.md
index abfc35da..334acf25 100644
--- a/algorithm_catalog/wur/worldagrocommodities/openeo_udp/README.md
+++ b/algorithm_catalog/wur/worldagrocommodities/openeo_udp/README.md
@@ -34,4 +34,9 @@ Patch-based sliding window (size 128×128 px, overlap 64 px) via apply_neighborh
 ONNX U-Net model loaded at runtime
 
+### 5 Auxiliary Data
+
+After model inference, a preliminary Tree Cover Density product (reference year 2020) is merged into the output as an additional band.
+
+
\ No newline at end of file
diff --git a/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json b/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json
index e6623378..462e4659 100644
--- a/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json
+++ b/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json
@@ -133,7 +133,7 @@
               "from_parameter": "data"
             },
             "probabilities": [
-              0.95
+              0.85
             ]
           },
           "result": true
@@ -550,7 +550,7 @@
               "from_parameter": "data"
             },
             "probabilities": [
-              0.95
+              0.85
             ]
           },
           "result": true
@@ -642,7 +642,7 @@
               "from_parameter": "x"
             },
             "runtime": "Python",
-            "udf": "import numpy as np\nimport xarray as xr\nimport logging\nfrom pyproj import Transformer\nfrom openeo.udf import XarrayDataCube\n\n# Setup logging\ndef _setup_logging() -> logging.Logger:\n    logging.basicConfig(level=logging.INFO, format=\"%(message)s\")\n    return logging.getLogger(__name__)\n\nfrom openeo.udf.udf_data import UdfData\n\n\nlogger = _setup_logging()\n\ndef apply_udf_data(udf_data: UdfData) -> UdfData:\n    \"\"\"This is the actual openeo UDF that will be executed by the backend.\"\"\"\n\n    cube = udf_data.datacube_list[0]\n    arr = cube.get_array().transpose(\"bands\", \"y\", \"x\")\n\n    crs = udf_data.proj\n    if crs is not None:\n        crs = crs[\"EPSG\"]\n\n    logger.info(f\"EPSG code determined for feature extraction: {crs}\")\n\n    transformer = 
Transformer.from_crs(crs, \"EPSG:4326\", always_xy=True)\n longitudes, latitudes = transformer.transform(arr.x, arr.y)\n lon_grid, lat_grid = np.meshgrid(longitudes, latitudes)\n\n logger.info(f\"Transformed longitudes range: {longitudes.min()}, {longitudes.max()}\")\n logger.info(f\"Transformed latitudes range: {latitudes.min()}, {latitudes.max()}\")\n\n combined = xr.DataArray(\n data=np.stack([lon_grid, lat_grid], axis=0), # shape: (2, y, x)\n dims=(\"bands\", \"y\", \"x\"),\n coords={\n \"bands\": [\"lon\", \"lat\"],\n \"x\": arr.coords[\"x\"],\n \"y\": arr.coords[\"y\"]\n }\n )\n\n cube_output = XarrayDataCube(combined) \n udf_data.datacube_list = [cube_output]\n\n return udf_data\n\n" + "udf": "import logging\n\nimport numpy as np\nimport xarray as xr\nfrom openeo.udf import XarrayDataCube\nfrom pyproj import Transformer\n\n\n# Setup logging\ndef _setup_logging() -> logging.Logger:\n logging.basicConfig(level=logging.INFO, format=\"%(message)s\")\n return logging.getLogger(__name__)\n\n\nfrom openeo.udf.udf_data import UdfData\n\nlogger = _setup_logging()\n\n\ndef apply_udf_data(udf_data: UdfData) -> UdfData:\n \"\"\"This is the actual openeo UDF that will be executed by the backend.\"\"\"\n\n cube = udf_data.datacube_list[0]\n arr = cube.get_array().transpose(\"bands\", \"y\", \"x\")\n\n crs = udf_data.proj\n if crs is not None:\n crs = crs[\"EPSG\"]\n\n logger.info(f\"EPSG code determined for feature extraction: {crs}\")\n\n transformer = Transformer.from_crs(crs, \"EPSG:4326\", always_xy=True)\n longitudes, latitudes = transformer.transform(arr.x, arr.y)\n lon_grid, lat_grid = np.meshgrid(longitudes, latitudes)\n\n logger.info(f\"Transformed longitudes range: {longitudes.min()}, {longitudes.max()}\")\n logger.info(f\"Transformed latitudes range: {latitudes.min()}, {latitudes.max()}\")\n\n combined = xr.DataArray(\n data=np.stack([lon_grid, lat_grid], axis=0), # shape: (2, y, x)\n dims=(\"bands\", \"y\", \"x\"),\n coords={\"bands\": [\"lon\", \"lat\"], \"x\": arr.coords[\"x\"], \"y\": arr.coords[\"y\"]},\n )\n\n cube_output = XarrayDataCube(combined)\n udf_data.datacube_list = [cube_output]\n\n return udf_data\n" }, "result": true } @@ -699,11 +699,16 @@ "runudf2": { "process_id": "run_udf", "arguments": { + "context": { + "model_id": { + "from_parameter": "model_id" + } + }, "data": { "from_parameter": "x" }, "runtime": "Python", - "udf": "import numpy as np\nimport xarray as xr\nimport logging\nfrom typing import Dict, Tuple, List\n\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO, format=\"%(message)s\")\nlogger = logging.getLogger(__name__)\n\n# Normalization specifications\nNORMALIZATION_SPECS = {\n \"optical\": {\n \"B02\": (1.7417268007636313, 2.023298706048351),\n \"B03\": (1.7261204997060209, 2.038905204308012),\n \"B04\": (1.6798346251414997, 2.179592821212937),\n \"B05\": (2.3828939530384052, 2.7578332604178284),\n \"B06\": (1.7417268007636313, 2.023298706048351),\n \"B07\": (1.7417268007636313, 2.023298706048351),\n \"B08\": (1.7417268007636313, 2.023298706048351),\n \"B11\": (1.7417268007636313, 2.023298706048351),\n \"B12\": (1.7417268007636313, 2.023298706048351),\n },\n \"linear\": {\n \"NDVI\": (0, 1), #TODO should not get normalized\n \"NDRE\": (-1, 1),\n \"EVI\": (-1, 1),\n \"VV\": (-25, 0),\n \"VH\": (-30, -5),\n \"DEM\": (-400, 8000),\n \"lon\": (-180, 180),\n \"lat\": (-60, 60)\n },\n}\n\n\ndef _normalize_optical(arr: np.ndarray, min_spec: float, max_spec: float) -> np.ndarray:\n \"\"\"Log-based normalization for optical bands.\"\"\"\n arr = 
np.log(arr * 0.005 + 1)\n arr = (arr - min_spec) / (max_spec)\n arr = np.exp(arr * 5 - 1)\n return arr / (arr + 1)\n\n\ndef _normalize_linear(arr: np.ndarray, min_spec: float, max_spec: float) -> np.ndarray:\n \"\"\"Linear min\u2013max normalization for continuous variables.\"\"\"\n arr = np.clip(arr, min_spec, max_spec)\n return (arr - min_spec) / (max_spec - min_spec)\n\n\nNORMALIZE_FUNCS = {\n \"optical\": _normalize_optical,\n \"linear\": _normalize_linear,\n}\n\ndef get_expected_bands() -> List[str]:\n \"\"\"\n Derive expected band order directly from NORMALIZATION_SPECS.\n Preserves the order in which groups and bands were defined.\n \"\"\"\n expected = []\n for group_bands in NORMALIZATION_SPECS.values():\n expected.extend(group_bands.keys())\n return expected\n\ndef validate_bands(cube: xr.DataArray, expected_bands: list):\n \"\"\"\n Validate presence and order of required bands in a data cube.\n\n Ensures that:\n 1. All required bands are present.\n 2. Bands are in the correct order.\n\n Args:\n cube (xr.DataArray):\n Input data cube with a 'bands' coordinate.\n expected_bands (list):\n Ordered list of band names required for processing.\n\n Returns:\n xr.DataArray:\n Data cube with bands in the correct order.\n\n Raises:\n ValueError: If any required bands are missing.\n \"\"\"\n band_names = list(cube.coords[\"bands\"].values)\n logger.info(f\"Input bands: {band_names}\")\n\n # Check for missing bands\n missing_bands = [b for b in expected_bands if b not in band_names]\n if missing_bands:\n raise ValueError(f\"Missing required bands: {missing_bands}. Got: {band_names}\")\n\n # Reorder if needed\n if band_names != expected_bands:\n raise ValueError(f\"Band order mismatch: {band_names} vs {expected_bands}\")\n\n\n\ndef apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:\n\n \"\"\"\n Normalize all bands in an input data cube according to predefined specifications.\n\n Steps:\n 1. Derive expected band order from NORMALIZATION_SPECS.\n 2. Validate band presence and order.\n 3. 
Apply normalization function per band based on its group.\n\n Args:\n cube (xr.DataArray):\n Input data cube with dimensions (\"bands\", \"y\", \"x\").\n\n Returns:\n xr.DataArray:\n Normalized data cube with same shape, dimensions, and band names.\n\n Raises:\n ValueError: If required bands are missing or in the wrong order.\n \"\"\"\n\n logger.info(f\"Received data with shape: {cube.shape}, dims: {cube.dims}\")\n\n # --- Validate & reorder bands in one call ---\n expected_bands = get_expected_bands()\n validate_bands(cube, expected_bands)\n\n # --- Normalization logic stays unchanged ---\n band_names = list(cube.coords[\"bands\"].values)\n logger.info(f\"Normalizing bands: {band_names}\")\n\n img_values = cube.values\n normalized_bands = []\n output_band_names = []\n for band in band_names:\n arr = img_values[band_names.index(band)]\n pre_stats = (arr.min(), arr.max(), arr.mean())\n # Find which group this band belongs to\n group = None\n for g, specs in NORMALIZATION_SPECS.items():\n if band in specs:\n group = g\n min_spec, max_spec = specs[band]\n norm_func = NORMALIZE_FUNCS[group]\n normalized = norm_func(arr, min_spec, max_spec)\n post_stats = (normalized.min(), normalized.max(), normalized.mean())\n logger.info(\n f\"Band {band}: group={group}, \"\n f\"min={pre_stats[0]:.3f}->{post_stats[0]:.3f}, \"\n f\"max={pre_stats[1]:.3f}->{post_stats[1]:.3f}, \"\n f\"mean={pre_stats[2]:.3f}->{post_stats[2]:.3f}\"\n )\n normalized_bands.append(normalized)\n output_band_names.append(band)\n break\n\n if group is None:\n logger.warning(f\"Band {band}: no normalization defined, leaving unchanged.\")\n post_stats = pre_stats\n logger.info(\n f\"Band {band}: kept as-is, \"\n f\"min={pre_stats[0]:.3f}, max={pre_stats[1]:.3f}, mean={pre_stats[2]:.3f}\"\n )\n normalized_bands.append(arr.astype(np.float32))\n output_band_names.append(band)\n\n # Stack back into DataArray\n result_array = np.stack(normalized_bands, axis=0)\n da = xr.DataArray(\n result_array,\n dims=(\"bands\", \"y\", \"x\"),\n coords={\n \"bands\": output_band_names,\n \"x\": cube.coords[\"x\"],\n \"y\": cube.coords[\"y\"],\n },\n )\n logger.info(f\"Normalization complete. 
Output bands: {output_band_names}\")\n return da\n" + "udf": "import functools\nimport logging\nfrom typing import List\n\nimport numpy as np\nimport requests\nimport xarray as xr\nfrom openeo.metadata import CollectionMetadata\n\n# Configure logging\nlogging.basicConfig(level=logging.INFO, format=\"%(message)s\")\nlogger = logging.getLogger(__name__)\n\n\n@functools.lru_cache(maxsize=1)\ndef get_model_metadata_from_stac(\n model_id: str, stac_api_url: str = \"https://stac.openeo.vito.be\"\n) -> dict:\n \"\"\"Fetch model metadata from STAC API\"\"\"\n try:\n # Get collection and item\n collection_id = \"world-agri-commodities-models\"\n url = f\"{stac_api_url}/collections/{collection_id}/items/{model_id}\"\n\n response = requests.get(url)\n response.raise_for_status()\n\n item = response.json()\n properties = item.get(\"properties\", {})\n\n logger.info(f\"Retrieved model metadata for {model_id}\")\n return {\n \"input_bands\": properties.get(\"input_channels\", []),\n \"input_shape\": properties.get(\"input_shape\", 0),\n }\n\n except Exception as e:\n logger.error(f\"Failed to fetch model metadata: {e}\")\n raise\n\n\ndef get_normalization_specs(input_bands: List[str]) -> dict:\n \"\"\"Dynamically generate normalization specs based on input bands\"\"\"\n specs = {\n \"optical\": {\n \"B02\": (1.7417268007636313, 2.023298706048351),\n \"B03\": (1.7261204997060209, 2.038905204308012),\n \"B04\": (1.6798346251414997, 2.179592821212937),\n \"B05\": (2.3828939530384052, 2.7578332604178284),\n \"B06\": (1.7417268007636313, 2.023298706048351),\n \"B07\": (1.7417268007636313, 2.023298706048351),\n \"B08\": (1.7417268007636313, 2.023298706048351),\n \"B11\": (1.7417268007636313, 2.023298706048351),\n \"B12\": (1.7417268007636313, 2.023298706048351),\n },\n \"linear\": {\n \"NDVI\": (0, 1),\n \"NDRE\": (-1, 1),\n \"EVI\": (-1, 1),\n \"VV\": (-25, 0),\n \"VH\": (-30, -5),\n \"DEM\": (-400, 8000),\n \"lon\": (-180, 180),\n \"lat\": (-60, 60),\n },\n }\n\n # Filter specs to only include bands that are actually in the input\n filtered_specs = {\"optical\": {}, \"linear\": {}}\n\n for band in input_bands:\n if band in specs[\"optical\"]:\n filtered_specs[\"optical\"][band] = specs[\"optical\"][band]\n elif band in specs[\"linear\"]:\n filtered_specs[\"linear\"][band] = specs[\"linear\"][band]\n\n return filtered_specs\n\n\ndef _normalize_optical(arr: np.ndarray, min_spec: float, max_spec: float) -> np.ndarray:\n \"\"\"Log-based normalization for optical bands.\"\"\"\n arr = np.log(arr * 0.005 + 1)\n arr = (arr - min_spec) / (max_spec)\n arr = np.exp(arr * 5 - 1)\n return arr / (arr + 1)\n\n\ndef _normalize_linear(arr: np.ndarray, min_spec: float, max_spec: float) -> np.ndarray:\n \"\"\"Linear min\u2013max normalization for continuous variables.\"\"\"\n arr = np.clip(arr, min_spec, max_spec)\n return (arr - min_spec) / (max_spec - min_spec)\n\n\nNORMALIZE_FUNCS = {\n \"optical\": _normalize_optical,\n \"linear\": _normalize_linear,\n}\n\n\ndef validate_bands(cube: xr.DataArray, expected_bands: list):\n band_names = list(cube.coords[\"bands\"].values)\n logger.info(f\"Input bands: {band_names}\")\n logger.info(f\"Expected bands from model: {expected_bands}\")\n\n # Check for missing required bands\n missing_bands = [b for b in expected_bands if b not in band_names]\n if missing_bands:\n raise ValueError(f\"Missing required bands: {missing_bands}. 
Got: {band_names}\")\n\n # Check order\n if band_names != expected_bands:\n logger.warning(\n f\"Band order mismatch: reordering from {band_names} to {expected_bands}\"\n )\n cube = cube.sel(bands=expected_bands)\n\n return cube\n\n\ndef apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:\n \"\"\"\n Normalize bands based on model metadata from STAC API.\n \"\"\"\n logger.info(f\"Received data with shape: {cube.shape}, dims: {cube.dims}\")\n\n # Get model ID from context\n model_id = context.get(\"model_id\")\n if not model_id:\n raise ValueError(\"model_id must be provided in context\")\n\n # Fetch model metadata from STAC\n model_metadata = get_model_metadata_from_stac(model_id)\n expected_bands = model_metadata[\"input_bands\"]\n\n logger.info(f\"Using model: {model_id} with expected bands: {expected_bands}\")\n logger.info(f\"Model expects {model_metadata['input_shape']} input bands\")\n\n # Validate and reorder bands\n cube = validate_bands(cube, expected_bands)\n\n # Get normalization specs for this specific model\n normalization_specs = get_normalization_specs(expected_bands)\n\n band_names = list(cube.coords[\"bands\"].values)\n logger.info(f\"Normalizing bands: {band_names}\")\n\n img_values = cube.values\n normalized_bands = []\n\n for band in band_names:\n arr = img_values[band_names.index(band)]\n pre_stats = (arr.min(), arr.max(), arr.mean())\n\n # Find which group this band belongs to\n group = None\n for g, specs in normalization_specs.items():\n if band in specs:\n group = g\n min_spec, max_spec = specs[band]\n norm_func = NORMALIZE_FUNCS[group]\n normalized = norm_func(arr, min_spec, max_spec)\n post_stats = (normalized.min(), normalized.max(), normalized.mean())\n logger.info(\n f\"Band {band}: group={group}, \"\n f\"min={pre_stats[0]:.3f}->{post_stats[0]:.3f}, \"\n f\"max={pre_stats[1]:.3f}->{post_stats[1]:.3f}\"\n )\n normalized_bands.append(normalized)\n break\n\n if group is None:\n logger.warning(f\"Band {band}: no normalization defined, leaving unchanged.\")\n normalized_bands.append(arr.astype(np.float32))\n\n # Stack back into DataArray\n result_array = np.stack(normalized_bands, axis=0)\n da = xr.DataArray(\n result_array,\n dims=(\"bands\", \"y\", \"x\"),\n coords={\n \"bands\": band_names,\n \"x\": cube.coords[\"x\"],\n \"y\": cube.coords[\"y\"],\n },\n )\n logger.info(f\"Normalization complete for model {model_id}\")\n return da\n\n\ndef apply_metadata(metadata: CollectionMetadata, context: dict) -> CollectionMetadata:\n model_id = context.get(\"model_id\")\n\n # Fetch model metadata from STAC\n model_metadata = get_model_metadata_from_stac(model_id)\n input_bands = model_metadata.get(\"input_bands\", [])\n\n logger.info(f\"Applying metadata with input bands: {input_bands}\")\n return metadata.rename_labels(dimension=\"bands\", target=input_bands)\n" }, "result": true } @@ -720,12 +725,12 @@ "overlap": [ { "dimension": "x", - "value": 32, + "value": 16, "unit": "px" }, { "dimension": "y", - "value": 32, + "value": 16, "unit": "px" } ], @@ -735,14 +740,15 @@ "process_id": "run_udf", "arguments": { "context": { - "model_path": "dynamic_models//best_weights_att_unet_lagtime_5_Fused3_2023_totalLoss6V1_without_loss_sentAfrica6.onnx" + "model_id": { + "from_parameter": "model_id" + } }, "data": { "from_parameter": "data" }, "runtime": "Python", - "version": "3.8", - "udf": "import sys\nimport functools\nimport numpy as np\nimport xarray as xr\nimport logging\nfrom typing import Dict, Tuple\nfrom scipy.special import expit\n\n\n# Setup logger\ndef 
_setup_logging():\n logging.basicConfig(level=logging.INFO)\n return logging.getLogger(__name__)\n\nlogger = _setup_logging()\n\n# Add ONNX paths\nsys.path.append(\"onnx_deps\")\nsys.path.append(\"onnx_models\")\nimport onnxruntime as ort\n\n# Constants for sanitization\n_INF_REPLACEMENT = 1e6\n_NEG_INF_REPLACEMENT = -1e6\n\n@functools.lru_cache(maxsize=1)\ndef _load_ort_session(model_name: str) -> ort.InferenceSession:\n \"\"\"Loads an ONNX model and returns a cached ONNX runtime session.\"\"\"\n return ort.InferenceSession(f\"onnx_models/{model_name}\")\n\ndef preprocess_image(cube: xr.DataArray) -> Tuple[np.ndarray, Dict[str, xr.Coordinate], np.ndarray]:\n \"\"\"\n Prepare the input cube for inference:\n - Transpose to (y, x, bands)\n - Sanitize NaN/Inf\n - Return batch tensor, coords, and invalid-value mask\n \"\"\"\n # Reorder dims\n reordered = cube.transpose(\"y\", \"x\", \"bands\")\n values = reordered.values.astype(np.float32)\n\n # Mask invalid entries\n mask_invalid = ~np.isfinite(values)\n\n # Replace NaN with 0, inf with large sentinel\n sanitized = np.where(np.isnan(values), 0.0, values)\n sanitized = np.where(np.isposinf(sanitized), _INF_REPLACEMENT, sanitized)\n sanitized = np.where(np.isneginf(sanitized), _NEG_INF_REPLACEMENT, sanitized)\n\n # Add batch dimension\n input_tensor = sanitized[None, ...]\n logger.info(f\"Preprocessed tensor shape={input_tensor.shape}\")\n return input_tensor, reordered.coords, mask_invalid\n\n\ndef run_inference(\n session: ort.InferenceSession,\n input_name: str,\n input_tensor: np.ndarray\n) -> np.ndarray:\n \"\"\"Run ONNX session and remove batch dimension from output.\"\"\"\n outputs = session.run(None, {input_name: input_tensor})\n pred = np.squeeze(outputs[0], axis=0)\n logger.info(f\"Inference output shape={pred.shape}\")\n return pred\n\n#TODO\ndef postprocess_output(\n pred: np.ndarray, # Shape: [y, x, bands]\n coords: Dict[str, xr.Coordinate],\n mask_invalid: np.ndarray # Shape: [y, x, bands]\n) -> xr.DataArray:\n \"\"\"\n Appends winning class index as new band to predictions:\n - Keeps original prediction values\n - Adds new band (-1 for invalid, 0..n-1 for winning class)\n \"\"\"\n\n # Apply sigmoid\n #sigmoid_probs = expit(pred) # shape [y, x, bands]\n\n # Optionally pick highest prob if needed\n #class_index = np.argmax(pred, axis=-1, keepdims=True)\n\n # Identify invalid pixels (any invalid in input bands)\n class_index = np.argmax(pred, axis=-1, keepdims=True) # shape [y, x, 1]\n\n invalid_mask = np.any(mask_invalid, axis=-1, keepdims=True)\n class_index = np.where(invalid_mask, -1, class_index).astype(np.float32)\n\n # Update band coordinates\n new_band_coords = np.arange(pred.shape[-1] + 1)\n\n combined = np.concatenate([pred, class_index], axis=-1)\n\n return xr.DataArray(\n combined,\n dims=(\"y\", \"x\", \"bands\"),\n coords={\n \"y\": coords[\"y\"],\n \"x\": coords[\"x\"],\n \"bands\": new_band_coords\n },\n attrs={\"description\": \"Original preds, sigmoid probs, class index\"}\n )\n\n\n\ndef apply_model(\n cube: xr.DataArray,\n model_path: str\n) -> xr.DataArray:\n \"\"\"\n Full inference pipeline: preprocess, infer, postprocess.\n \"\"\"\n input_tensor, coords, mask_invalid = preprocess_image(cube)\n session = _load_ort_session(model_path)\n input_name = session.get_inputs()[0].name\n raw_pred = run_inference(session, input_name, input_tensor)\n\n #TODO evaluate reprocessing\n result = postprocess_output(raw_pred, coords, mask_invalid)\n #logger.info(f\"apply_model result shape={result.shape}\")\n return 
result\n\n\ndef apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:\n \"\"\"\n Apply ONNX model per timestep in the datacube.\n \"\"\"\n logger.info(f\"apply_datacube received shape={cube.shape}, dims={cube.dims}\")\n\n model_path = str(context.get(\"model_path\" ))\n\n logger.info(f\"Applying model: {model_path}\")\n\n cube = cube.transpose('y', 'x', 'bands', 't')\n\n if 't' in cube.dims:\n logger.info(\"Applying model per timestep via groupby-map.\")\n return cube.groupby('t').map(lambda da: apply_model(da, model_path))\n else:\n logger.info(\"Single timestep: applying model once.\")\n return apply_model(cube, model_path)" + "udf": "import functools\nimport hashlib\nimport logging\nimport os\nimport sys\nimport tempfile\nimport threading\nfrom typing import Dict, Tuple\nimport logging\nimport numpy as np\nimport requests\nimport xarray as xr\nfrom openeo.metadata import CollectionMetadata\n\nlogger = logging.getLogger(__name__)\n\n# Global lock dictionary for thread-safe model downloading\n_model_locks: Dict[str, threading.Lock] = {}\n_model_locks_lock = threading.Lock() # Lock for managing the lock dictionary\n\n\ndef get_model_lock(model_id: str) -> threading.Lock:\n \"\"\"Get or create a lock for a specific model ID (thread-safe).\"\"\"\n with _model_locks_lock:\n if model_id not in _model_locks:\n _model_locks[model_id] = threading.Lock()\n return _model_locks[model_id]\n\n\ndef get_model_cache_path(model_id: str, cache_dir: str = \"/tmp/onnx_models\") -> str:\n \"\"\"Get the cache path for a model, creating directory if needed.\"\"\"\n os.makedirs(cache_dir, exist_ok=True)\n\n # Create a safe filename from model_id\n model_hash = hashlib.md5(model_id.encode()).hexdigest()\n return os.path.join(cache_dir, f\"{model_hash}.onnx\")\n\n\ndef download_model_with_lock(\n model_id: str,\n model_url: str,\n cache_dir: str = \"/tmp/onnx_models\",\n max_file_size_mb: int = 250,\n) -> str:\n \"\"\"\n Download model with thread locking to prevent concurrent downloads.\n \"\"\"\n cache_path = get_model_cache_path(model_id, cache_dir)\n\n # Get the lock for this specific model\n lock = get_model_lock(model_id)\n\n with lock:\n # Check if model already exists in cache\n if os.path.exists(cache_path):\n logger.info(f\"Using cached model: {cache_path}\")\n return cache_path\n\n # Download the model\n logger.info(f\"Downloading model {model_id} from {model_url}\")\n\n try:\n # Create temporary file\n temp_fd, temp_path = tempfile.mkstemp(suffix=\".onnx\", dir=cache_dir)\n\n try:\n with os.fdopen(temp_fd, \"wb\") as temp_file:\n response = requests.get(model_url, stream=True, timeout=300)\n response.raise_for_status()\n\n # Download with size checking\n downloaded_size = 0\n for chunk in response.iter_content(chunk_size=8192):\n if chunk:\n temp_file.write(chunk)\n downloaded_size += len(chunk)\n if downloaded_size > max_file_size_mb * 1024 * 1024:\n raise ValueError(\n f\"Downloaded file exceeds size limit of {max_file_size_mb}MB\"\n )\n\n # Atomic move from temp file to final location\n os.rename(temp_path, cache_path)\n logger.info(f\"Successfully downloaded and cached model: {cache_path}\")\n\n except Exception as e:\n # Clean up temp file on error\n try:\n os.unlink(temp_path)\n except OSError:\n pass\n raise ValueError(f\"Error downloading model {model_id}: {e}\")\n\n except Exception as e:\n logger.error(f\"Failed to download model {model_id}: {e}\")\n raise\n\n return cache_path\n\n\ndef get_model_from_stac(\n model_id: str,\n stac_api_url: str = 
\"https://stac.openeo.vito.be\",\n cache_dir: str = \"/tmp/onnx_models\",\n) -> Tuple[str, dict]:\n \"\"\"Fetch model file and metadata from STAC API with caching.\"\"\"\n try:\n collection_id = \"world-agri-commodities-models\"\n url = f\"{stac_api_url}/collections/{collection_id}/items/{model_id}\"\n\n response = requests.get(url, timeout=30)\n response.raise_for_status()\n\n item = response.json()\n properties = item.get(\"properties\", {})\n assets = item.get(\"assets\", {})\n\n # Get model URL\n model_asset = assets.get(\"model\")\n if not model_asset:\n raise ValueError(f\"No model asset found for {model_id}\")\n\n model_url = model_asset[\"href\"]\n\n # Download model with caching and locking\n model_path = download_model_with_lock(model_id, model_url, cache_dir)\n\n metadata = {\n \"input_bands\": properties.get(\"input_channels\", []),\n \"output_classes\": properties.get(\"output_classes\", []),\n \"output_shape\": properties.get(\"output_shape\", 0),\n \"framework\": properties.get(\"framework\", \"ONNX\"),\n \"region\": properties.get(\"region\", \"Unknown\"),\n \"model_url\": model_url,\n \"cached_path\": model_path,\n }\n\n logger.info(\n f\"Retrieved model {model_id} with {len(metadata['output_classes'])} output classes\"\n )\n return model_path, metadata\n\n except Exception as e:\n logger.error(f\"Failed to fetch model from STAC: {e}\")\n raise\n\n\n# Add ONNX paths\nsys.path.append(\"onnx_deps\")\nimport onnxruntime as ort\n\n# Constants for sanitization\n_INF_REPLACEMENT = 1e6\n_NEG_INF_REPLACEMENT = -1e6\n\n\n@functools.lru_cache(maxsize=1)\ndef _load_ort_session(model_id: str) -> Tuple[ort.InferenceSession, dict]:\n \"\"\"Loads an ONNX model from STAC and returns session with metadata\"\"\"\n model_path, metadata = get_model_from_stac(model_id)\n\n try:\n session = ort.InferenceSession(model_path)\n logger.info(f\"Loaded ONNX model for {model_id} on path {model_path}\")\n return session, metadata\n finally:\n # Clean up temporary file\n try:\n os.unlink(model_path)\n except OSError as e:\n raise RuntimeError(\n f\"Failed to delete temporary model file: {model_path}\"\n ) from e\n\n\ndef apply_metadata(metadata: CollectionMetadata, context: dict) -> CollectionMetadata:\n model_id = context.get(\"model_id\")\n _, metadata_dict = _load_ort_session(model_id)\n\n output_classes = metadata_dict[\"output_classes\"] + [\"ARGMAX\"]\n logger.info(f\"Applying metadata with output classes: {output_classes}\")\n return metadata.rename_labels(dimension=\"bands\", target=output_classes)\n\n\ndef preprocess_image(\n cube: xr.DataArray,\n) -> Tuple[np.ndarray, Dict[str, xr.DataArray], np.ndarray]:\n \"\"\"\n Prepare the input cube for inference:\n - Transpose to (y, x, bands)\n - Sanitize NaN/Inf\n - Return batch tensor, coords, and invalid-value mask\n \"\"\"\n # Reorder dims\n if 't' in cube.dims:\n cube = cube.squeeze('t', drop=True)\n\n reordered = cube.transpose(\"y\", \"x\", \"bands\")\n values = reordered.values.astype(np.float32)\n\n # Mask invalid entries\n mask_invalid = ~np.isfinite(values)\n\n # Replace NaN with 0, inf with large sentinel\n sanitized = np.where(np.isnan(values), 0.0, values) # TODO validate if this is okay\n sanitized = np.where(np.isposinf(sanitized), _INF_REPLACEMENT, sanitized)\n sanitized = np.where(np.isneginf(sanitized), _NEG_INF_REPLACEMENT, sanitized)\n\n # Add batch dimension\n input_tensor = sanitized[None, ...]\n logger.info(f\"Preprocessed tensor shape={input_tensor.shape}\")\n return input_tensor, dict(reordered.coords), mask_invalid\n\n\ndef 
run_inference(\n session: ort.InferenceSession, input_name: str, input_tensor: np.ndarray\n) -> np.ndarray:\n \"\"\"Run ONNX session and remove batch dimension from output.\"\"\"\n outputs = session.run(None, {input_name: input_tensor})\n pred = np.squeeze(outputs[0], axis=0)\n logger.info(f\"Inference output shape={pred.shape}\")\n return pred\n\n\n# TODO\ndef postprocess_output(\n pred: np.ndarray, # Shape: [y, x, bands]\n coords: Dict[str, xr.DataArray],\n mask_invalid: np.ndarray, # Shape: [y, x, bands]\n) -> xr.DataArray:\n \"\"\"\n Appends winning class index as new band to predictions:\n - Keeps original prediction values\n - Adds new band (-1 for invalid, 0..n-1 for winning class)\n \"\"\"\n\n # Apply sigmoid\n # sigmoid_probs = expit(pred) # shape [y, x, bands]\n\n # Optionally pick highest prob if needed\n # class_index = np.argmax(pred, axis=-1, keepdims=True)\n\n # Identify invalid pixels (any invalid in input bands)\n class_index = np.argmax(pred, axis=-1, keepdims=True) # shape [y, x, 1]\n\n invalid_mask = np.any(mask_invalid, axis=-1, keepdims=True)\n class_index = np.where(invalid_mask, -1, class_index).astype(np.float32)\n\n # Update band coordinates\n new_band_coords = np.arange(pred.shape[-1] + 1)\n\n combined = np.concatenate([pred, class_index], axis=-1)\n\n return xr.DataArray(\n combined,\n dims=(\"y\", \"x\", \"bands\"),\n coords={\"y\": coords[\"y\"], \"x\": coords[\"x\"], \"bands\": new_band_coords},\n attrs={\"description\": \"Original preds, sigmoid probs, class index\"},\n )\n\n\ndef apply_model(cube: xr.DataArray, model_id: str) -> xr.DataArray:\n \"\"\"\n Full inference pipeline:\n - Read ONNX model input shape\n - If model expects 15 bands \u2192 drop NDRE/EVI by name\n - Preprocess \u2192 infer \u2192 postprocess\n \"\"\"\n session, metadata = _load_ort_session(model_id)\n\n input_bands = metadata[\"input_bands\"]\n output_classes = metadata[\"output_classes\"]\n\n logger.info(f\"Running inference for model {model_id}\")\n logger.info(f\"Input bands: {input_bands}\")\n logger.info(f\"Output bands: {output_classes}\")\n\n # Validate input bands match expectations\n cube_bands = list(cube.coords[\"bands\"].values)\n if cube_bands != input_bands:\n logger.warning(f\"Band mismatch. 
Cube: {cube_bands}, Model: {input_bands}\")\n cube = cube.sel(bands=input_bands)\n\n input_tensor, coords, mask_invalid = preprocess_image(cube)\n input_name = session.get_inputs()[0].name\n raw_pred = run_inference(session, input_name, input_tensor)\n\n return postprocess_output(raw_pred, coords, mask_invalid)\n\n\n\ndef apply_datacube(cube: xr.DataArray, context: dict) -> xr.DataArray:\n \"\"\"\n Apply ONNX model per timestep in the datacube.\n \"\"\"\n model_id = context.get(\"model_id\")\n if not model_id:\n raise ValueError(\"model_id must be provided in context\")\n\n logger.info(f\"Applying model from STAC: {model_id}\")\n\n # Ensure correct dimension order\n cube = cube.transpose('y', 'x', 'bands', 't')\n\n if \"t\" in cube.dims:\n logger.info(\"Applying model per timestep via groupby-map.\")\n # Use isel to handle time dimension properly\n def process_timestep(da):\n return apply_model(da, model_id)\n\n return cube.groupby('t').map(process_timestep)\n else:\n logger.info(\"Single timestep: applying model once.\")\n return apply_model(cube, model_id)\n\n\n\n\n\n" }, "result": true } @@ -762,41 +768,71 @@ ] } }, + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "bands": [ + "MAP" + ], + "spatial_extent": { + "from_parameter": "spatial_extent" + }, + "url": "https://www.stac.lcfm.dataspace.copernicus.eu/collections/LCFM_TCD-10_CDSE_v100" + } + }, + "reducedimension4": { + "process_id": "reduce_dimension", + "arguments": { + "data": { + "from_node": "loadstac1" + }, + "dimension": "t", + "reducer": { + "process_graph": { + "min1": { + "process_id": "min", + "arguments": { + "data": { + "from_parameter": "data" + } + }, + "result": true + } + } + } + } + }, "renamelabels2": { "process_id": "rename_labels", "arguments": { "data": { - "from_node": "applyneighborhood1" + "from_node": "reducedimension4" }, "dimension": "bands", "target": [ - "Background", - "Other_Large_Scale_Cropland", - "Pasture", - "Mining", - "Other_Small_Scale_Cropland", - "Roads", - "Forest", - "Plantation_Forest", - "Coffee", - "Built_up", - "Water", - "Oil_Palm", - "Rubber", - "Cacao", - "Avocado", - "Soy", - "Sugar", - "Maize", - "Banana", - "Pineapple", - "Rice", - "Wood_Logging", - "Cashew", - "Tea", - "Other", - "ARGMAX" + "Tree_cover_density_2020" ] + } + }, + "mergecubes7": { + "process_id": "merge_cubes", + "arguments": { + "cube1": { + "from_node": "applyneighborhood1" + }, + "cube2": { + "from_node": "renamelabels2" + } + } + }, + "saveresult1": { + "process_id": "save_result", + "arguments": { + "data": { + "from_node": "mergecubes7" + }, + "format": "netCDF", + "options": {} }, "result": true } @@ -810,9 +846,9 @@ "executor-memoryOverhead": "1000m", "python-memory": "3000m", "max-executors": 10, + "allow_empty_cubes": true, "udf-dependency-archives": [ - "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_dependencies_1.16.3.zip#onnx_deps", - "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/WorldAgriCommodities/dynamic_models.zip#onnx_models" + "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_deps_python311.zip#onnx_deps" ] }, "parameters": [ @@ -901,15 +937,7 @@ "description": "Don't filter spatially. 
All data is included in the data cube.", "type": "null" } - ], - "default": { - "west": 677736, - "south": 624010, - "east": 694576, - "north": 638629, - "crs": "EPSG:32629" - }, - "optional": true + ] }, { "name": "temporal_extent", @@ -937,20 +965,22 @@ } ] } - }, - "default": [ - "2023-01-01", - "2024-01-01" - ], - "optional": true + } }, { "name": "crs", "description": "CRS of the output in ", + "schema": { + "type": "string" + } + }, + { + "name": "model_id", + "description": "Model identifier", "schema": { "type": "string" }, - "default": "EPSG:32629", + "default": "WorldAgriCommodities_Africa_v1", "optional": true } ]
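
For reference, the updated UDP can be invoked from the openEO Python client roughly as follows. This is a minimal sketch, not part of the patch: the backend URL and output directory are assumptions, while the argument values mirror the benchmark scenario above.

import openeo

# Assumed backend URL; any openEO backend that can resolve the UDP namespace
# and provides the required collections should work.
connection = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()

# Build the UDP invocation; argument values mirror the benchmark scenario.
cube = connection.datacube_from_process(
    "wac_inference_africa",
    namespace="https://raw.githubusercontent.com/ESA-APEx/apex_algorithms/c2d74e55afbc76ddb8eca9679196d1eab282e98c/algorithm_catalog/wur/worldagrocommodities/openeo_udp/wac_inference_africa.json",
    spatial_extent={"west": 749000, "south": -2378000, "east": 750000, "north": -2377000, "crs": "EPSG:32638"},
    temporal_extent=["2023-01-01", "2024-01-01"],
    crs="EPSG:32638",
    model_id="WorldAgriCommodities_Africa_v1",
)

# The UDP already ends in save_result (netCDF), so the process graph is submitted as-is.
job = connection.create_job(cube.flat_graph(), title="WAC inference Africa (benchmark extent)")
job.start_and_wait()
job.get_results().download_files("wac_inference_results")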