In [None]:
from pathlib import Path
import rioxarray
import geopandas as gpd
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt

from sklearn.ensemble import RandomForestClassifier
import joblib # For save and load of the model

## Load data

In [None]:
# Directory that holds the input raster and the label files. Change this one
# constant to point the notebook at a different data location.
DATA_DIR = Path("data")

# RGB raster cutout used both for training and for prediction
path_rgb_cutout = DATA_DIR / "sentinel2_rgb_res_20_cutout.tif"
# Vector labels: geometries marking waterbody (positive) areas
path_positive_label = DATA_DIR / "waterbody_labels.gpkg"
# Vector labels: geometries marking non-waterbody (negative) areas
path_negative_label = DATA_DIR / "none_waterbody_labels.gpkg"

In [None]:
# Open the RGB cutout as an xarray.DataArray; later cells index it by its
# `band`, `y` and `x` dimensions/coordinates.
rgb_cutout = rioxarray.open_rasterio(path_rgb_cutout)
# Read the positive and negative label geometries into GeoDataFrames
gpd_positive = gpd.read_file(path_positive_label)
gpd_negative = gpd.read_file(path_negative_label)

## Investigate the data

In [None]:
# Show the RGB cutout with both label sets drawn on top of it
fig, ax = plt.subplots(figsize=(8, 8))
rgb_cutout.plot.imshow(ax=ax, robust=True)
# Positive labels are drawn in cyan, negative labels in red
gpd_positive.plot(ax=ax, color='cyan', alpha=0.6, edgecolor='k')
gpd_negative.plot(ax=ax, color='red', alpha=0.6, edgecolor='k')

## Convert labels from vector to raster

In [None]:
def generate_label_array(raster, gdf_positive, gdf_negative):
    """
    Generate a label array from the raster and positive/negative GeoDataFrames.

    The labels are burned onto the spatial grid of ``raster``. Label geometries
    whose CRS differs from the raster CRS are reprojected to the raster CRS
    first, so that they land on the right pixels.

    Parameters:
    raster (xarray.DataArray): The input raster data. It must have a ``band``
        dimension and coordinate; only the first band is used, as a template
        for the output grid.
    gdf_positive (geopandas.GeoDataFrame): GeoDataFrame containing positive labels.
    gdf_negative (geopandas.GeoDataFrame): GeoDataFrame containing negative labels.

    Returns:
    xarray.DataArray: A label array where positive labels are 1, negative labels are 0,
    and areas without labels are -1. Pixels covered by both a positive and a
    negative geometry end up as 0.
    """
    # Single-band template that carries the raster's spatial grid
    template = raster.isel(band=0).drop_vars("band")
    raster_crs = raster.rio.crs

    def _in_raster_crs(gdf):
        # Reproject the labels only when both CRS are known and they differ;
        # otherwise the geometries are used unchanged.
        if gdf.crs is not None and raster_crs is not None and gdf.crs != raster_crs:
            return gdf.to_crs(raster_crs)
        return gdf

    def _rasterize(gdf, inside_value):
        # Grid filled with `inside_value`; clipping with drop=False keeps the
        # full extent and sets every pixel outside the geometries to the
        # nodata value (-1).
        mask = xr.full_like(template, fill_value=inside_value, dtype=np.int32)
        mask = mask.rio.write_nodata(-1)
        return mask.rio.clip(_in_raster_crs(gdf)["geometry"], drop=False)

    # Make positive labels: 1 inside the positive geometries, -1 elsewhere
    pos_label_array = _rasterize(gdf_positive, 1)

    # Make negative labels: 0 inside the negative geometries, -1 elsewhere
    neg_label_array = _rasterize(gdf_negative, 0)

    # Combine positive and negative labels:
    #   positive only : -( 1 * -1) =  1
    #   negative only : -(-1 *  0) =  0
    #   unlabelled    : -(-1 * -1) = -1
    #   both          : -( 1 *  0) =  0
    label_array = -(pos_label_array * neg_label_array)

    return label_array

In [None]:
# Convert labels from vector to raster
label_array = generate_label_array(rgb_cutout, gpd_positive, gpd_negative)

# Plot labels over the RGB image
fig, ax = plt.subplots(figsize=(8, 8))
rgb_cutout.plot.imshow(ax=ax, robust=True)
# vmin/vmax span the three label values: -1 (unlabelled), 0 (negative), 1 (positive)
label_array.plot.imshow(ax=ax, vmin=-1, vmax=1, alpha=0.5)

## Prepare training data

In [None]:
def prepare_training_data(image, labels, random_state=None):
    """
    Build a shuffled training set from the labelled pixels of an image.

    Parameters:
    image (numpy.ndarray): Array whose first axis indexes the features (bands).
        All remaining axes are flattened into pixels, so the number of pixels
        per band must equal ``labels.size``.
    labels (numpy.ndarray): Per-pixel labels. Pixels equal to 1 become positive
        samples, pixels equal to 0 become negative samples, and every other
        value (e.g. -1) is ignored.
    random_state (int, numpy.random.Generator or None): Seed or generator for
        the shuffle. With the default ``None`` the global NumPy RNG is used, so
        ``np.random.seed`` still controls the order.

    Returns:
    tuple: ``(train_data, train_labels)`` where ``train_data`` has shape
    [n_instances, n_features] and ``train_labels`` has shape [n_instances],
    holding 1/0 in the dtype of ``image``.
    """
    # Reshape input data to [n_features, n_pixels] once and reuse it for both classes
    pixels = image.reshape((image.shape[0], -1))
    flattened = labels.flatten()

    # Select the labelled pixels and transpose to [n_instances, n_features]
    positive_data = pixels[:, flattened == 1].transpose()
    negative_data = pixels[:, flattened == 0].transpose()
    positive_labels = np.full_like(positive_data[:, 0], 1)
    negative_labels = np.full_like(negative_data[:, 0], 0)
    train_data = np.concatenate((positive_data, negative_data))
    train_labels = np.concatenate((positive_labels, negative_labels))

    # Shuffle the training data; one permutation is applied to both arrays so
    # every sample keeps its label.
    n_instances = train_data.shape[0]
    if random_state is None:
        # Global RNG keeps the previous behaviour (honours np.random.seed)
        indices_shuffled = np.random.permutation(n_instances)
    else:
        indices_shuffled = np.random.default_rng(random_state).permutation(n_instances)
    train_data = train_data[indices_shuffled]
    train_labels = train_labels[indices_shuffled]

    return train_data, train_labels

In [None]:
# prepare_training_data shuffles with NumPy's global RNG. Seed it so that the
# order of the training samples, and therefore the fitted model, is the same
# on every run.
np.random.seed(42)
train_data, train_labels = prepare_training_data(rgb_cutout.data, label_array.data)

print(f"dimensions of training data: {train_data.shape}")
print(f"dimensions of training labels: {train_labels.shape}")

## Train the RandomForestClassifier

In [None]:
# Fit a 100-tree random forest on the labelled pixels; random_state fixes the
# forest's own randomness.
# NOTE(review): train_data is derived from `rgb_cutout.data`, and the raster is
# opened without `chunks=`, so it is presumably already a NumPy array rather
# than a dask array that needs computing — confirm.
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(train_data, train_labels)

In [None]:
# Export the trained classifier; the relative path is resolved against the
# current working directory.
joblib.dump(classifier, 'binary_classifier_waterbody.pkl')

## Prediction on the cutout

In [None]:
input = rgb_cutout.data.reshape((rgb_cutout.data.shape[0], -1)).transpose()
input.shape

In [None]:
# Load the trained classifier
classifier = joblib.load('binary_classifier_waterbody.pkl')

predictions = classifier.predict_proba(input)
predictions

In [None]:
# Maps of class probabilities: reshape [n_pixels, 2] -> [2, height, width].
# Index 0 / 1 follow the column order of predict_proba and are plotted below
# as the non-waterbody / waterbody probability.
prediction_map = predictions.transpose().reshape((2, rgb_cutout.data.shape[1], rgb_cutout.data.shape[2]))

In [None]:
# Coordinate range of the cutout, used to place the probability overlays
img_extent = (rgb_cutout.x.min(), rgb_cutout.x.max(), rgb_cutout.y.min(), rgb_cutout.y.max())
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# One entry per panel: (index into prediction_map, colormap, title)
panel_specs = [
    (0, 'Reds', 'Non-waterbody Probability'),
    (1, 'Blues', 'Waterbody Probability'),
]

for panel_ax, (class_idx, cmap_name, panel_title) in zip(axes, panel_specs):
    # RGB image as a semi-transparent backdrop
    rgb_cutout.plot.imshow(ax=panel_ax, alpha=0.6, robust=True)
    # Probability map of this class drawn on top
    panel_ax.imshow(prediction_map[class_idx], cmap=cmap_name, alpha=0.7, extent=img_extent)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
    # images[0] is the RGB backdrop, images[1] the probability overlay added above
    plt.colorbar(panel_ax.images[1], ax=panel_ax, shrink=0.7)

plt.tight_layout()
