In [None]:
from pathlib import Path
import rioxarray
from matplotlib import pyplot as plt
import dask.array as da

import joblib # For save and load of the model

from dask.distributed import Client, LocalCluster, Lock
import xarray as xr

## Overview of the area of interest (AoI)

In [None]:
# Location of the Sentinel-2 RGB raster; the same path is reused for the lazy load later on
path_rgb_full = Path("data/sentinel2_rgb_res_20_size_8000_cog.tif")

# Open overview level 1 of the raster instead of the full-resolution data
rgb_full = rioxarray.open_rasterio(path_rgb_full, overview_level=1)

# Quick-look plot of the overview on a square figure
fig, ax = plt.subplots(figsize=(8, 8))
rgb_full.plot.imshow(robust=True, ax=ax)

## Load data and trained model

In [None]:
# Edge length, in pixels, of the square y/x dask chunks used for the later processing
CHUNKSIZE = 1024

In [None]:
# Open the raster lazily as a chunked array: each chunk holds every band
# of a CHUNKSIZE x CHUNKSIZE window in y/x
rgb = rioxarray.open_rasterio(
    path_rgb_full,
    chunks={"band": -1, "y": CHUNKSIZE, "x": CHUNKSIZE},
)
rgb

In [None]:
# Load the trained waterbody classifier from disk.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load model files that come from a trusted source.
classifier = joblib.load('binary_classifier_waterbody.pkl')

## Distributed prediction on large data

In [None]:
# Placeholder: to run on an existing Dask cluster, import / connect to it here

# Otherwise start a local Dask cluster with two workers and attach a client to it
local_cluster = LocalCluster(n_workers=2)
client = Client(local_cluster)

In [None]:
# Per-chunk prediction; intended to be applied lazily, one block at a time
def predict_chunk(chunk, classifier):
    """Predict per-pixel class probabilities for one chunk of a raster.

    Parameters
    ----------
    chunk : xarray.DataArray
        Three-dimensional block, assumed to be ordered ``(band, y, x)``, with
        ``y`` and ``x`` coordinates. Every pixel becomes one sample whose
        features are its band values.
    classifier
        Fitted estimator exposing ``predict_proba(samples)`` that returns an
        array of shape ``(n_samples, n_classes)``.

    Returns
    -------
    xarray.DataArray
        Array with dims ``('band', 'y', 'x')`` where ``band`` is the class
        index (``0 .. n_classes - 1``; for the binary waterbody model
        0 = non-waterbody, 1 = waterbody) and ``y``/``x`` are taken from
        ``chunk``.
    """
    n_rows, n_cols = chunk.shape[1], chunk.shape[2]

    # Flatten the two spatial axes and transpose so that rows are pixels and
    # columns are input bands: (n_pixels, n_bands)
    samples = chunk.data.reshape((chunk.shape[0], -1)).T

    # Predict probabilities, shape (n_pixels, n_classes)
    probs = classifier.predict_proba(samples)

    # Take the number of output bands from the classifier output instead of
    # hard-coding 2, so the reshape stays valid for any number of classes
    n_classes = probs.shape[1]

    # Back to spatial layout with the class axis first
    result = probs.T.reshape((n_classes, n_rows, n_cols))

    # Return as xarray DataArray carrying the chunk's spatial coordinates
    return xr.DataArray(
        result,
        dims=['band', 'y', 'x'],
        coords={
            'band': list(range(n_classes)),
            'y': chunk['y'],
            'x': chunk['x'],
        },
    )

In [None]:
# Apply prediction function using xarray.map_blocks
# The template describes the expected result (two probability bands on the
# y/x grid of `rgb`, chunked like `rgb`), so nothing is computed here.
prediction_template = xr.DataArray(
    da.zeros((2, rgb.sizes['y'], rgb.sizes['x']), chunks=(-1, CHUNKSIZE, CHUNKSIZE)),
    dims=['band', 'y', 'x'],
    coords={
        'band': [0, 1],
        'y': rgb['y'],
        'x': rgb['x']
    }
)

predictions = xr.map_blocks(
    predict_chunk,
    rgb,
    args=[classifier],
    template=prediction_template,
)

# The template lists only band/y/x coordinates, so the CRS of the input
# raster is not part of the result. Re-attach it (metadata only, stays lazy)
# so that rasters written from `predictions` keep their CRS.
if rgb.rio.crs is not None:
    predictions = predictions.rio.write_crs(rgb.rio.crs)

predictions

## Save predictions

In [None]:
# Write each probability band to its own Cloud Optimized GeoTIFF
# (band 0 = non-waterbody, band 1 = waterbody)
for band_index, output_path in (
    (0, "./predictions_non_waterbody.tif"),
    (1, "./predictions_waterbody.tif"),
):
    predictions.isel(band=band_index).rio.to_raster(output_path, driver="COG")

In [None]:
# # When the data is larger than 4 GB, Zarr is a better option than GeoTIFF
# # This takes about 2 minutes
# predictions.to_zarr(
#     "predictions_waterbody_full.zarr",
#     mode="w",
# )

In [None]:
# Parallel save to GeoTIFF
# However, this approach does not support COG

# from dask.distributed import Lock

# tiff_output_non_water = "./predictions_waterbody_full_none_water.tif"
# tiff_output_water = "./predictions_waterbody_full_water.tif"

# predictions.isel(band=0).rio.to_raster(
#     tiff_output_non_water,
#     tiled=True,
#     lock=Lock("rio"),
# )

## Plot predictions

In [None]:
# Read overview level 1 of both saved prediction rasters; the reduced
# resolution is enough for visualization
predictions_water = rioxarray.open_rasterio(
    "./predictions_waterbody.tif", overview_level=1
)
predictions_non_water = rioxarray.open_rasterio(
    "./predictions_non_waterbody.tif", overview_level=1
)

In [None]:
# Outer bounds of the prediction raster. Using the bounds (pixel edges)
# instead of the min/max of the pixel-centre coordinates keeps the overlay
# aligned with the RGB backdrop, which xarray draws out to the pixel edges.
left, bottom, right, top = predictions_water.rio.bounds()
img_extent = (left, right, bottom, top)

# One entry per panel:
# (probability raster, colormap, title, extra kwargs for the RGB backdrop).
# NOTE(review): only the waterbody panel fades the backdrop (alpha=0.6);
# kept as in the original figure - confirm whether both panels should match.
panels = (
    (predictions_non_water, 'Reds', 'Non-waterbody Probability', {}),
    (predictions_water, 'Blues', 'Waterbody Probability', {'alpha': 0.6}),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for panel_ax, (probability, cmap, title, rgb_kwargs) in zip(axes, panels):
    # RGB backdrop first, probability overlay on top
    rgb_full.plot.imshow(ax=panel_ax, robust=True, **rgb_kwargs)
    overlay = panel_ax.imshow(
        probability.data.squeeze(), cmap=cmap, alpha=0.7, extent=img_extent
    )
    panel_ax.set_title(title)
    panel_ax.axis('off')
    # Use the overlay handle directly rather than indexing into ax.images
    fig.colorbar(overlay, ax=panel_ax, shrink=0.7)
fig.tight_layout()