# Apply the trained classifier to the full image

In this notebook, we will apply the trained `RandomForestClassifier` to the full image to predict waterbody and non-waterbody pixels across the entire AoI. We will use Dask to distribute the prediction across the chunks of the image, so that large images can be processed without loading them into memory all at once.

In [None]:
from pathlib import Path
import rioxarray
from matplotlib import pyplot as plt
import dask.array as da

import joblib  # For saving and loading the model

from dask.distributed import Client, LocalCluster, Lock
import xarray as xr

## Overview of the AoI

First, let's take another look at the area of interest (AoI) by visualizing the Cloud-Optimized GeoTIFF (COG) at reduced resolution, using one of its overview levels.

In [None]:
# Load and visualize the full RGB with overviews
path_rgb_full = Path("data/sentinel2_rgb_res_20_size_8000_cog.tif")
rgb_full = rioxarray.open_rasterio(
    path_rgb_full, overview_level=1)

fig, ax = plt.subplots(figsize=(8, 8))
rgb_full.plot.imshow(ax=ax, robust=True)

## Load data and trained model

In the following steps, we will load the full RGB image chunk-wise. Depending on the computing infrastructure, the chunk size can be adjusted accordingly. With the `chunks` argument of `rioxarray.open_rasterio`, the large image is loaded "lazily": a task graph for loading the data chunk by chunk is created, and it is only executed when the data is actually needed.

In [None]:
# Configure chunksize for the later processing
CHUNKSIZE = 2000

# Lazy loading of the RGB data, automatically chunked
rgb = rioxarray.open_rasterio(path_rgb_full, chunks={'band': -1, 'y': CHUNKSIZE, 'x': CHUNKSIZE})
rgb
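The chunk size determines how many tasks Dask creates and how much memory each task needs. The sketch below does this arithmetic on a lazy placeholder array instead of the real file. It assumes an 8000 x 8000 pixel, 3-band `uint8` image (the file name suggests the size; the data type is an assumption) and the `CHUNKSIZE` of 2000 used above.

```python
import dask.array as da
import numpy as np

# Assumption: a 3-band, 8000 x 8000 uint8 image, as suggested by the file name
shape = (3, 8000, 8000)
CHUNKSIZE = 2000

# da.zeros only builds a lazy task graph; no image-sized buffer is allocated
lazy_rgb = da.zeros(shape, dtype=np.uint8, chunks=(-1, CHUNKSIZE, CHUNKSIZE))

print(lazy_rgb.numblocks)    # (1, 4, 4): all bands in one chunk, 4 x 4 spatial tiles
print(lazy_rgb.npartitions)  # 16 chunks, i.e. 16 independent prediction tasks later on

# Rough memory per chunk: uint8 input vs. float64 probabilities for 2 classes
input_bytes = 3 * CHUNKSIZE * CHUNKSIZE * lazy_rgb.dtype.itemsize
output_bytes = 2 * CHUNKSIZE * CHUNKSIZE * np.dtype("float64").itemsize
print(input_bytes / 1e6, "MB in,", output_bytes / 1e6, "MB out per chunk")  # 12.0 MB in, 64.0 MB out per chunk
```

Doubling `CHUNKSIZE` would quarter the number of tasks but quadruple the memory each task needs, so the value is a trade-off against worker memory.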

Next, we load the trained classifier from the previous step.

In [None]:
classifier = joblib.load('binary_classifier_waterbody.pkl')

## Distributed prediction on large data

Next, we initialize a Dask client, which connects to a Dask cluster for distributed computing. On an HPC system with SLURM, one can use `dask-jobqueue` to start a Dask cluster on SLURM (`SLURMCluster`) and connect to it with the client. When running locally, we can use `LocalCluster` to start a local Dask cluster.

The progress of the computation can be monitored through the dashboard link shown in the client output.

In [None]:
# EITHER: CONNECT TO AN EXISTING DASK CLUSTER HERE (e.g. a SLURM cluster via dask-jobqueue)

# OR: USE A LOCAL DASK CLUSTER
local_cluster = LocalCluster(n_workers=3)
client = Client(local_cluster)

client
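The resources of a local cluster can also be set explicitly, which helps to match worker memory to the chunk size. The following is a minimal sketch with hypothetical values; it uses in-process (threaded) workers and `set_as_default=False`, so it does not replace the client created above. The `SLURMCluster` lines in the comments only indicate the general shape of a `dask-jobqueue` setup, and their resource values are placeholders.

```python
from dask.distributed import Client, LocalCluster

# Hypothetical sizing: 2 workers, 1 thread each (one chunk at a time per worker)
demo_cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    processes=False,         # in-process workers, quick to start for a test
    memory_limit="1GiB",     # per worker
    dashboard_address=":0",  # any free port for the dashboard
)
demo_client = Client(demo_cluster, set_as_default=False)
demo_client.wait_for_workers(2)

print(demo_client.dashboard_link)
workers = demo_client.scheduler_info()["workers"]
print(len(workers), "workers")

# On a SLURM system, dask-jobqueue provides an analogous cluster object, e.g.
#   from dask_jobqueue import SLURMCluster
#   cluster = SLURMCluster(cores=4, memory="16GB", walltime="01:00:00")
#   cluster.scale(jobs=2)
#   client = Client(cluster)
# (placeholder resource values, not recommendations)

demo_client.close()
demo_cluster.close()
```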

In [None]:
# Prediction function, applied to each chunk (all bands at once)
def predict_chunk(chunk, classifier):
    """Predict class probabilities for one chunk of shape (band, y, x)."""
    # Inside map_blocks, `chunk` is an in-memory xarray DataArray
    original_shape = chunk.shape
    # Flatten to (n_pixels, n_bands): one row of band values per pixel
    reshaped = chunk.data.reshape((chunk.shape[0], -1)).T
    
    # Predict probabilities, shape (n_pixels, n_classes)
    probs = classifier.predict_proba(reshaped)

    # Reshape back to (class, y, x)
    result = probs.T.reshape((2, original_shape[1], original_shape[2]))
    
    # Return as xarray DataArray with proper coordinates
    return xr.DataArray(
        result,
        dims=['band', 'y', 'x'],
        coords={
            'band': [0, 1],  # non-waterbody, waterbody
            'y': chunk['y'],
            'x': chunk['x']
        }
    )
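`predict_chunk` relies on a reshape round trip: the `(band, y, x)` chunk becomes a `(pixel, band)` feature matrix for scikit-learn, and the `(pixel, class)` probabilities become a `(class, y, x)` array again. The toy NumPy example below checks that pixel `(y, x)` ends up in row `y * nx + x` and that the round trip restores the spatial layout; the fake "probabilities" are simply copies of two bands.

```python
import numpy as np

# Toy chunk: 3 bands on a 2 x 4 pixel grid, filled with distinct values
chunk = np.arange(3 * 2 * 4).reshape(3, 2, 4)
bands, ny, nx = chunk.shape

# (band, y, x) -> (pixel, band): one row per pixel, one column per band
features = chunk.reshape(bands, -1).T
print(features.shape)  # (8, 3)

# Row y * nx + x holds the band values of pixel (y, x)
y, x = 1, 2
assert (features[y * nx + x] == chunk[:, y, x]).all()

# Fake (pixel, class) "probabilities", shaped like predict_proba output
probs = np.stack([features[:, 0], features[:, 1]], axis=1)
restored = probs.T.reshape(2, ny, nx)
assert (restored == chunk[:2]).all()
```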

Now we apply the trained classifier to the full image chunk by chunk. Many Python libraries (for example, most `numpy` functions) work with Dask arrays directly, in which case the Dask array can simply be passed as input. However, scikit-learn's `predict_proba` does not process Dask arrays chunk by chunk, so we use xarray's `map_blocks` to call `predict_chunk` on each chunk. This is suitable here because every pixel is classified independently, so each chunk can be predicted without information from its neighbours. The `template` argument describes the output (two probability bands instead of three RGB bands, same chunking in `y` and `x`), which lets xarray build the result lazily.

In [None]:
# Apply prediction function using xarray.map_blocks
predictions = xr.map_blocks(
    predict_chunk,
    rgb,
    args=[classifier],
    template=xr.DataArray(
        da.zeros((2, rgb.sizes['y'], rgb.sizes['x']), chunks=(-1, CHUNKSIZE, CHUNKSIZE)),
        dims=['band', 'y', 'x'],
        coords={
            'band': [0, 1],
            'y': rgb['y'],
            'x': rgb['x']
        }
    )
)

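# map_blocks builds the output from the template, which carries no CRS,
# so copy the CRS of the input image to keep the saved rasters georeferenced
predictions = predictions.rio.write_crs(rgb.rio.crs)

predictions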
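To check that chunk-wise prediction gives the same result as predicting the whole image at once, the same `map_blocks` pattern can be tried on synthetic data. The sketch below uses a tiny random 3-band image and a `RandomForestClassifier` trained on a made-up rule (third band brighter than the first); neither the data nor the rule has anything to do with the real waterbody model, and `predict_block` is a hypothetical stand-in for `predict_chunk`.

```python
import dask.array as da
import numpy as np
import xarray as xr
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Synthetic 3-band image (40 x 40 pixels), chunked into 2 x 2 spatial blocks of 20 x 20
image = rng.integers(0, 256, size=(3, 40, 40)).astype(np.uint8)
rgb_toy = xr.DataArray(
    da.from_array(image, chunks=(-1, 20, 20)),
    dims=["band", "y", "x"],
    coords={"band": [1, 2, 3], "y": np.arange(40), "x": np.arange(40)},
)

# Toy classifier: class 1 when the third band is brighter than the first
features = image.reshape(3, -1).T
labels = (features[:, 2] > features[:, 0]).astype(int)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(features, labels)


def predict_block(block, model):
    """Same idea as predict_chunk: (band, y, x) -> (class, y, x) probabilities."""
    n_bands, ny, nx = block.shape
    probs = model.predict_proba(block.values.reshape(n_bands, -1).T)
    return xr.DataArray(
        probs.T.reshape(probs.shape[1], ny, nx),
        dims=["band", "y", "x"],
        coords={"band": [0, 1], "y": block["y"], "x": block["x"]},
    )


# Template for the lazy output: 2 classes instead of 3 bands, same spatial chunks
template = xr.DataArray(
    da.zeros((2, 40, 40), chunks=(-1, 20, 20)),
    dims=["band", "y", "x"],
    coords={"band": [0, 1], "y": rgb_toy["y"], "x": rgb_toy["x"]},
)
blockwise = xr.map_blocks(predict_block, rgb_toy, args=[clf], template=template).compute()

# Reference: predict all pixels at once, in memory
whole = clf.predict_proba(features).T.reshape(2, 40, 40)
print(np.allclose(blockwise.values, whole))  # True
```

Because each pixel's prediction depends only on its own band values, splitting the image into chunks does not change the result.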

One can also visualize the task graph of the prediction process using Dask's `visualize` function. Note that this requires the `graphviz` package to be installed in the environment, which has been added to the `environment.yml` file.

In [None]:
import dask
dask.visualize(predictions)
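For large inputs, rendering the full task graph can be slow. A lighter check is to count the tasks through the Dask collection protocol method `__dask_graph__()`, which needs no `graphviz`. The sketch below does this for a hypothetical stand-in pipeline with the same 16-chunk layout; in the notebook, the same call could be made on `predictions.data`.

```python
import dask.array as da

# Stand-in pipeline with the same 4 x 4 spatial chunk layout as the notebook
stand_in = da.ones((3, 8000, 8000), chunks=(-1, 2000, 2000))
result = (stand_in * 2).sum(axis=0)

print(result.npartitions, "output chunks")  # 16 output chunks

# Count the tasks across all steps (creation, multiplication, reduction)
n_tasks = len(result.__dask_graph__())
print(n_tasks, "tasks")
```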

## Save predictions

At this point, `predictions` has not been computed yet: it is a lazy, Dask-backed DataArray that only holds the task graph. There are multiple ways to evaluate it. For example, one can call `compute` to run the graph and gather the full result in the memory of the client (the notebook process), provided the result fits in memory.

Alternatively, one can save the predictions directly to a file, writing each chunk as it is computed, without first gathering the full result in memory.

In this example, we save the predictions as COG files, whose internal tiling and overviews make large rasters quick to visualize. Unfortunately, `rioxarray` does not support parallel writing of COG files, so here we compute the results band by band and save each band as a separate COG file.

In the later cells (commented out by default), we also provide examples of:
- parallel saving to a normal GeoTIFF file
- parallel saving to Zarr


In [None]:
# Save each band of the predictions to separate COG
# takes ~7 mins
predictions.isel(band=0).rio.to_raster("./predictions_non_waterbody.tif", driver="COG")
predictions.isel(band=1).rio.to_raster("./predictions_waterbody.tif", driver="COG")

In [None]:
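# Parallel save to a regular (tiled) GeoTIFF. The distributed lock ensures
# that only one worker writes to the file at a time.
# Note: this does not produce a COG.

# tiff_output_non_water = "./predictions_waterbody_full_non_water.tif"
# tiff_output_water = "./predictions_waterbody_full_water.tif"

# predictions.isel(band=0).rio.to_raster(
#     tiff_output_non_water,
#     tiled=True,
#     lock=Lock("rio"),
# )
# predictions.isel(band=1).rio.to_raster(
#     tiff_output_water,
#     tiled=True,
#     lock=Lock("rio"),
# )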

In [None]:
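# For very large outputs (e.g. beyond the 4 GB limit of classic, non-BigTIFF GeoTIFF),
# Zarr is often a better option: each chunk is stored as a separate object,
# so workers can write in parallel without a lock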

# predictions.to_zarr(
#     "predictions_waterbody_full.zarr",
#     mode="w",
# )
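The chunk-wise writing described above can be illustrated with `dask.array.store`, which computes each chunk and writes it into the matching slice of a target array. In this minimal sketch, the target is a NumPy memory-mapped file in a temporary directory, and the lazy input is a small hypothetical stand-in for one prediction band. `scheduler="threads"` keeps all writes in the local process, since a memory map handed to separate worker processes would be copied rather than written to.

```python
import os
import tempfile

import dask.array as da
import numpy as np

# Hypothetical stand-in for one lazily computed prediction band: 8 x 8, in 4 chunks
lazy_band = da.full((8, 8), 0.25, chunks=(4, 4)) * 2

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "band.dat")
    # Pre-allocate the output file; nothing has been computed yet
    target = np.memmap(path, dtype="float64", mode="w+", shape=lazy_band.shape)

    # Each chunk is computed and written to its own slice of the target, so the
    # full array is never assembled in memory first
    da.store(lazy_band, target, lock=True, scheduler="threads")
    target.flush()
    del target

    # Read the file back (as an in-memory copy) to check what was written
    stored = np.array(np.memmap(path, dtype="float64", mode="r", shape=(8, 8)))

print(stored.mean())  # 0.5
```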

## Plot predictions

Now we can visualize the predictions. First, let's load the saved predictions at overview level 1, a downsampled copy stored inside each COG, and inspect the results for the entire AoI.

In [None]:
# Load the saved predictions with overview level
predictions_water = rioxarray.open_rasterio("./predictions_waterbody.tif", overview_level=1)
predictions_non_water = rioxarray.open_rasterio("./predictions_non_waterbody.tif", overview_level=1)

In [None]:
# Visualize the full predictions
# Use the raster bounds (outer pixel edges) as extent, so the overlay aligns with the RGB image
left, bottom, right, top = predictions_water.rio.bounds()
img_extent = (left, right, bottom, top)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
rgb_full.plot.imshow(ax=axes[0], robust=True)
im_non_water = axes[0].imshow(predictions_non_water.data.squeeze(), cmap='Reds', alpha=0.8, extent=img_extent, vmin=0, vmax=1)
axes[0].set_title('Non-waterbody Probability')
axes[0].axis('off')
plt.colorbar(im_non_water, ax=axes[0], shrink=0.7)
rgb_full.plot.imshow(ax=axes[1], alpha=0.6, robust=True)
im_water = axes[1].imshow(predictions_water.data.squeeze(), cmap='Blues', alpha=0.8, extent=img_extent, vmin=0, vmax=1)
axes[1].set_title('Waterbody Probability')
axes[1].axis('off')
plt.colorbar(im_water, ax=axes[1], shrink=0.7)
plt.tight_layout()

We can also load the results lazily and zoom into specific areas.

In [None]:
# Lazily load the full-resolution predictions, chunked like the input image
predictions_water = rioxarray.open_rasterio("./predictions_waterbody.tif", chunks={'band': -1, 'y': CHUNKSIZE, 'x': CHUNKSIZE})
predictions_non_water = rioxarray.open_rasterio("./predictions_non_waterbody.tif", chunks={'band': -1, 'y': CHUNKSIZE, 'x': CHUNKSIZE})

In [None]:
# Select a cutout (by pixel index) for visualization
# South-west Friesland
y_idx_range = slice(100, 2100)
x_idx_range = slice(200, 2200)
predictions_water_cutout = predictions_water.isel(
    y=y_idx_range,
    x=x_idx_range
)
predictions_non_water_cutout = predictions_non_water.isel(
    y=y_idx_range,
    x=x_idx_range
)
rgb_cutout = rgb.isel(
    y=y_idx_range,
    x=x_idx_range
)
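Because the predictions are loaded lazily, selecting a cutout with `isel` only touches the chunks that overlap it. The sketch below shows this on a hypothetical 80 x 80 raster chunked in 20 x 20 blocks, together with the label-based alternative `sel`, which takes map coordinates instead of pixel indices. As in most GeoTIFFs, `y` is descending here, so the `y` slice runs from the larger to the smaller value.

```python
import dask.array as da
import numpy as np
import xarray as xr

n, chunk, res = 80, 20, 20.0
# Hypothetical 80 x 80 raster with 20 m pixels, chunked 20 x 20; y descends
x = res * (np.arange(n) + 0.5)
y = res * n - res * (np.arange(n) + 0.5)
raster = xr.DataArray(
    da.from_array(np.arange(n * n).reshape(n, n), chunks=chunk),
    dims=["y", "x"],
    coords={"y": y, "x": x},
)

# Index-based cutout: rows/columns 10..29 overlap 2 x 2 of the 4 x 4 chunks
cutout = raster.isel(y=slice(10, 30), x=slice(10, 30))
print(raster.data.npartitions, "->", cutout.data.npartitions)  # 16 -> 4

# Label-based equivalent in map coordinates (inclusive on both ends)
cutout_sel = raster.sel(y=slice(y[10], y[29]), x=slice(x[10], x[29]))
print(cutout_sel.equals(cutout))  # True
```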

In [None]:
# Visualize the cutout predictions
# Use the cutout bounds (outer pixel edges) as extent, so the overlay aligns with the RGB image
left, bottom, right, top = predictions_water_cutout.rio.bounds(recalc=True)
img_extent = (left, right, bottom, top)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
rgb_cutout.plot.imshow(ax=axes[0], robust=True)
# .values triggers the (small) computation of the lazily loaded cutout
im_non_water = axes[0].imshow(predictions_non_water_cutout.squeeze().values, cmap='Reds', alpha=0.8, extent=img_extent, vmin=0, vmax=1)
axes[0].set_title('Non-waterbody Probability')
axes[0].axis('off')
plt.colorbar(im_non_water, ax=axes[0], shrink=0.7)
rgb_cutout.plot.imshow(ax=axes[1], alpha=0.6, robust=True)
im_water = axes[1].imshow(predictions_water_cutout.squeeze().values, cmap='Blues', alpha=0.8, extent=img_extent, vmin=0, vmax=1)
axes[1].set_title('Waterbody Probability')
axes[1].axis('off')
plt.colorbar(im_water, ax=axes[1], shrink=0.7)
plt.tight_layout()