# Land Accounts inference on Sentinel-2 GeoMAD with Random Forest

This workflow demonstrates how to use a
[Sentinel-2](https://www.esa.int/Applications/Observing_the_Earth/Copernicus/Sentinel-2)
[GeoMedian annual satellite imagery composite](https://github.com/digitalearthpacific/dep-geomad)
for segmenting land use / land cover (LULC) using a
[GPU-accelerated Random Forest classifier](https://developer.nvidia.com/blog/accelerating-random-forests-up-to-45x-using-cuml/).
We do this by integrating 2020 ground-truth land use / land cover data
from the Vanuatu Bureau of Statistics (VBoS). To scale the workflow to all of Vanuatu, we process it province by province using an
[administrative boundaries dataset from the Pacific Data Hub](https://pacificdata.org/data/dataset/2016_vut_phc_admin_boundaries/resource/66ae054b-9b67-4876-b59c-0b078c31e800).

In this notebook, we will demonstrate the following:

1. **Data Acquisition**:
   - We use the annual Sentinel-2 GeoMAD composite, computed from **Sentinel-2 L2A** data and accessed via the [Digital Earth Pacific STAC catalog](http://stac.digitalearthpacific.org/). The search is filtered by an area of interest (AOI) and a time range to obtain suitable imagery.

2. **Preprocessing**:
   - The Sentinel-2 imagery contains several spectral bands (e.g., Red, Green, Blue, Near-Infrared, Short-wave Infrared). These are extracted and combined into a single dataset for analysis. Remote sensing indices useful for land use / land cover mapping are calculated from these bands. Additionally, the imagery is masked to remove areas outside the regions of interest so as to focus on the relevant pixels. We use 5 out of 6 provinces making up the nation of Vanuatu for training, and one for testing.

3. **Feature Extraction**:
   - Features for the classifier are derived from the Sentinel-2 spectral bands. From the Red, Green, Blue, Near-Infrared (NIR), and Short-wave Infrared (SWIR) reflectance values we compute four remote sensing indices (NDVI, MNDWI, SAVI, BSI), which form the final feature set.

4. **Ground Truth Data Integration**:
   - A shapefile containing polygons attributed by land cover/land use is loaded into a [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.html). This allows us to create multi-class labels for the pixels in the Sentinel-2 imagery.

5. **Data Splitting**:
   - So that the model is evaluated on pixels it has not seen during training, we split the features and labels into training (80%) and testing (20%) sets. A seed value for the random number generator makes this random split reproducible.

6. **Random Forest Classification**:
   - We train a **Random Forest** classifier to predict land use / land cover on a pixel-wise basis. The `n_estimators` parameter is a key hyperparameter: it sets the number of decision trees in the forest. Each tree is trained on a bootstrap sample of the data and considers a random subset of features at each split; combining the outputs of many such trees typically makes the forest less prone to overfitting than a single decision tree.

7. **Prediction**:
   - We use the trained classifier to predict the LULC class of each pixel in the inference province.

8. **Evaluation**:
   - After making predictions on the test partition, we evaluate the model's performance using metrics such as accuracy and F1-score. This allows us to assess the performance of the Random Forest model and the effectiveness of the selected features.

9. **Visualization**:
   - We visualize the predictions as a classified map, with each LULC class shown in its own color.

At the end, you will have trained a model to predict land use / land cover in Vanuatu.
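Steps 5 and 8 are not implemented in this inference notebook, so here is a minimal NumPy sketch of both under stated assumptions: `train_test_split_indices` and `accuracy_and_macro_f1` are hypothetical helpers (not part of this workflow or of any library), the 80/20 split is a seeded random shuffle of sample indices, and "F1-score" is taken to mean macro F1, the unweighted mean of per-class F1 over the classes that occur in either the labels or the predictions.

```python
import numpy as np


def train_test_split_indices(n_samples, test_fraction=0.2, seed=42):
    """Return (train_idx, test_idx) from a seeded random permutation."""
    rng = np.random.default_rng(seed)  # same seed -> same permutation
    shuffled = rng.permutation(n_samples)
    n_test = int(round(n_samples * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]


def accuracy_and_macro_f1(y_true, y_pred, n_classes):
    """Overall accuracy and macro-averaged F1 from a confusion matrix."""
    # cm[t, p] counts samples of true class t predicted as class p
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # samples predicted as each class
    actual = cm.sum(axis=1)  # samples truly in each class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    present = (predicted + actual) > 0  # ignore classes absent from both
    return tp.sum() / cm.sum(), f1[present].mean()


train_idx, test_idx = train_test_split_indices(10)
print(len(train_idx), len(test_idx))  # 8 2

acc, macro_f1 = accuracy_and_macro_f1(
    np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]), n_classes=2
)
print(acc, round(macro_f1, 4))  # 0.75 0.7333
```

With real data, `y_true` would be the rasterized ROI labels of the held-out pixels and `y_pred` the classifier's predictions for those same pixels.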

In [None]:
!mamba install --channel rapidsai --quiet --yes cuml

In [None]:
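import geopandas as gpd
import hvplot.xarray  # noqa: F401, registers the .hvplot accessor used for plotting
import matplotlib.pyplot as plt
import numpy as np
import odc.stac
import rasterio.features
import rioxarray  # noqa: F401, registers the .rio accessor on xarray objects
import xarray as xr

# Only needed for training (see the commented-out training code further down)
# from cuml import RandomForestClassifier
# from dask_ml.model_selection import train_test_split
# from geocube.api.core import make_geocube
from pystac_client import Client
from shapely.geometry import box, mapping, shape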

## Data Acquisition

Let's read the LULC data into a GeoDataFrame.

A [GeoDataFrame](https://geopandas.org/en/stable/docs/reference/geodataframe.html) is a type of data structure used to store geographic data in Python, provided by the [GeoPandas](https://geopandas.org/en/stable/) library. It extends the functionality of a pandas DataFrame to handle spatial data, enabling geospatial analysis and visualization. Like a pandas DataFrame, a GeoDataFrame is a tabular data structure with labeled axes (rows and columns), but it adds special features to work with geometric objects, such as:
- a geometry column
- a coordinate reference system (CRS)
- built-in spatial operations (e.g., intersection, union, buffering, and spatial joins)

In [None]:
# Version of the LULC model (based on ROIs.zip)
VERSION = "v9"
YEAR = 2020  # year to run inference on
PROVINCE_INFERENCE = "SHEFA"  # Vanuatu province, choose from ["TORBA", "SANMA", "PENAMA", "MALAMPA", "SHEFA", "TAFEA"]

In [None]:
# Download the administrative boundaries (2016_phc_vut_pid_4326.geojson)
!wget https://pacificdata.org/data/dataset/9dba1377-740c-429e-92ce-6a484657b4d9/resource/3d490d87-99c0-47fd-98bd-211adaf44f71/download/2016_phc_vut_pid_4326.geojson

Read and inspect the datasets.

In [None]:
lulc_gdf = gpd.read_file(f"./ROIs_{VERSION}.zip")  # zipped shapefile of labelled ROIs

In [None]:
admin_boundaries_gdf = gpd.read_file("./2016_phc_vut_pid_4326.geojson")

In [None]:
admin_boundaries_gdf

Index the administrative boundaries by province name (`pname`) so that a single province can be selected by name.

In [None]:
if admin_boundaries_gdf.index.name != "pname":
    admin_boundaries_gdf = admin_boundaries_gdf.set_index(
        keys="pname"  # set province name as the index
    )

In [None]:
admin_boundaries_gdf

Get the geometry of the inference province.

In [None]:
GEOM_INFERENCE = admin_boundaries_gdf.loc[PROVINCE_INFERENCE].geometry
GEOM_INFERENCE

Search the STAC catalog for the Sentinel-2 GeoMAD annual composite covering the inference province in `YEAR`, and load it at 20 m resolution.

In [None]:
STAC_URL = "http://stac.digitalearthpacific.org/"
stac_client = Client.open(STAC_URL)

In [None]:
gdf_test = lulc_gdf.query(expr=f"Pname == '{PROVINCE_INFERENCE}'")

s2_search = stac_client.search(
    collections=["dep_s2_geomad"],
    intersects=GEOM_INFERENCE,
    datetime=str(YEAR),
)
# Retrieve all items from search results
s2_items = s2_search.item_collection()
print("len(s2_items): ", len(s2_items))

s2_data_inference = odc.stac.load(
    items=s2_items,
    bands=["blue", "green", "red", "nir08", "swir16"],
    chunks={"x": 1024, "y": 1024, "bands": -1, "time": -1},
    resolution=20,
)
s2_data_inference

Buffer the province geometry slightly so that near-shore pixels, and any ROIs that
overlap the coastline, are not cut off when clipping.

In [None]:
# Keep projection aligned with raster
raster_crs = s2_data_inference.rio.crs
print(raster_crs)

# Select only the inference province and reproject it to the raster CRS
gdf_reprojected_test = admin_boundaries_gdf.loc[[PROVINCE_INFERENCE]].to_crs(
    crs=raster_crs
)

# Buffer distance is in the units of the raster CRS (meters for a projected CRS)
geom_buffered_test = gdf_reprojected_test.buffer(distance=5)[PROVINCE_INFERENCE]
geom_buffered_test

Clip the Sentinel-2 data to the buffered province geometry.

In [None]:
# Clip inference province
s2_clipped_inference = s2_data_inference.rio.clip(geometries=[geom_buffered_test])

In [None]:
# Plot inference province
s2_rgb = s2_clipped_inference[["red", "green", "blue"]]
s2_rgb_array = s2_rgb.to_array("band")  # dims: band, time, y, x
s2_rgb_array_squeezed = s2_rgb_array.squeeze(dim="time", drop=True)  # dims: band, y, x

In [None]:
s2_rgb_array_squeezed.plot.imshow(size=4, vmin=0, vmax=3000)

Calculate remote sensing indices.

In [None]:
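# Calculate remote sensing indices useful for mapping LULC.
# Keep these feature definitions identical to those used when training the classifier.
def compute_indices(ds):
    # Cast to float first: subtracting unsigned-integer bands would wrap around
    ds = ds.astype("float32")
    red = ds["red"]
    green = ds["green"]
    blue = ds["blue"]
    nir = ds["nir08"]
    swir = ds["swir16"]
    eps = 1e-6  # avoids division by zero
    return xr.Dataset(
        {
            "NDVI": (nir - red) / (nir + red + eps),
            "MNDWI": (green - swir) / (green + swir + eps),
            # SAVI with soil factor L = 0.5 (L is defined for reflectance in 0-1)
            "SAVI": ((nir - red) / (nir + red + 0.5 + eps)) * 1.5,
            "BSI": ((swir + red) - (nir + blue)) / ((swir + red) + (nir + blue) + eps),
        }
    )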


index_data_test = compute_indices(s2_clipped_inference).squeeze("time", drop=True)
print(index_data_test)
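As a sanity check on the formulas, the following NumPy-only sketch applies the same four expressions to a single made-up pixel. The reflectance values are invented, and the 0 to 10000 integer scaling is an assumption about how the GeoMAD bands are stored (consistent with the `vmax=3000` used for plotting above). It also illustrates two pitfalls: subtracting `uint16` bands wraps around instead of going negative, and SAVI's soil factor L = 0.5 only has its intended effect on reflectance scaled to 0 to 1.

```python
import numpy as np

EPS = 1e-6  # same small constant as compute_indices


def spectral_indices(blue, green, red, nir, swir):
    """NDVI, MNDWI, SAVI (L = 0.5) and BSI, written as in compute_indices."""
    return {
        "NDVI": (nir - red) / (nir + red + EPS),
        "MNDWI": (green - swir) / (green + swir + EPS),
        "SAVI": ((nir - red) / (nir + red + 0.5 + EPS)) * 1.5,
        "BSI": ((swir + red) - (nir + blue)) / ((swir + red) + (nir + blue) + EPS),
    }


# Made-up vegetated pixel, stored as uint16 reflectance scaled by 10000
# order: blue, green, red, nir, swir
pixel = np.array([300, 600, 400, 3600, 1800], dtype=np.uint16)

# Pitfall: unsigned subtraction wraps around (400 - 3600 -> 62336, not -3200)
wrapped = pixel[2:3] - pixel[3:4]
print(wrapped)  # [62336]

scaled = spectral_indices(*pixel.astype(np.float32))
print({k: round(float(v), 3) for k, v in scaled.items()})
# {'NDVI': 0.8, 'MNDWI': -0.5, 'SAVI': 1.2, 'BSI': -0.279}

# Same pixel as 0-1 reflectance: SAVI changes because L = 0.5 is scale-dependent
unit = spectral_indices(*(pixel.astype(np.float32) / 10000))
print(round(float(unit["SAVI"]), 3))  # 0.533
```

On integer-scaled inputs SAVI is therefore almost exactly 1.5 × NDVI, so it adds little beyond NDVI; whichever scaling is used, it has to match the one used for training.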

Reproject the ROIs to the raster CRS and build the mapping from ROI class names to integer class IDs.

In [None]:
# Raster size, used later to reshape the flat predictions back into a map
width_test, height_test = s2_clipped_inference.x.size, s2_clipped_inference.y.size

gdf_test = gdf_test.to_crs(epsg=s2_clipped_inference.rio.crs.to_epsg())

gdf_rpg = lulc_gdf.to_crs(s2_clipped_inference.rio.crs)

# Class IDs follow the order in which classes first appear in the ROI file, so the
# same ROI file version as for training is needed to reproduce the mapping
unique_classes = gdf_rpg["ROI"].unique()
# class_mapping = {cls: i+1 for i, cls in enumerate(unique_classes)}  # reserves 0 for no data
class_mapping = {
    cls: i for i, cls in enumerate(unique_classes)
}  # zero-based: no class ID is reserved for no data
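Because pandas' `Series.unique()` returns values in order of first appearance, the class IDs depend on the row order of the ROI file. The following plain-Python sketch mirrors that logic with a hypothetical helper, `build_class_mapping`, and made-up class names (the real names come from the `ROI` column). It also shows the 1-based alternative that reserves 0 for no data, and how to invert the mapping to label predictions.

```python
def build_class_mapping(roi_labels, reserve_nodata=False):
    """Assign integer IDs to class names in order of first appearance.

    With reserve_nodata=True, IDs start at 1 so that 0 can mean "no data".
    """
    start = 1 if reserve_nodata else 0
    ordered = list(dict.fromkeys(roi_labels))  # de-duplicate, keep first-seen order
    return {cls: i + start for i, cls in enumerate(ordered)}


# Hypothetical ROI labels, in the order they appear in the file
roi_labels = ["Forest", "Water", "Forest", "Built-up", "Water"]
class_mapping_demo = build_class_mapping(roi_labels)
print(class_mapping_demo)  # {'Forest': 0, 'Water': 1, 'Built-up': 2}

# Invert the mapping to turn predicted IDs back into class names
id_to_class = {i: cls for cls, i in class_mapping_demo.items()}
print([id_to_class[i] for i in [2, 0, 0, 1]])  # ['Built-up', 'Forest', 'Forest', 'Water']
```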

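Flatten the index rasters into a 2-D feature matrix with one row per pixel and one column per index.
During training, the sparse ROI labels mean that unlabeled pixels between ROIs are thrown out; for inference we keep every pixel so that the predictions can be reshaped back into a map.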

In [None]:
# (variable, y, x) -> (flattened_pixel, variable): one row per pixel, one column per index
features_test = index_data_test.to_array().stack(flattened_pixel=("y", "x"))
# labels_test = rasterized_labels_test.to_array().stack(flattened_pixel=("y", "x"))

features_test = features_test.transpose("flattened_pixel", "variable").compute()

print("features_test shape:", features_test.shape)

In [None]:
len(features_test)

In [None]:
features_test
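xarray's `stack(flattened_pixel=("y", "x"))` orders pixels with `x` varying fastest, the same row-major order NumPy uses, which is why `y_pred.reshape((height_test, width_test))` in the Visualization section restores each prediction to its place. The NumPy-only sketch below shows that round trip on random toy data; `argmax` is a hypothetical stand-in for `clf.predict`, and the array sizes are arbitrary.

```python
import numpy as np

# Toy stand-in for index_data_test.to_array(): 4 indices on a 3 x 5 raster
n_vars, height, width = 4, 3, 5
rng = np.random.default_rng(0)
cube = rng.random((n_vars, height, width))  # dims: (variable, y, x)

# Flatten (y, x) in row-major order (x varies fastest), then put pixels first
toy_features = cube.reshape(n_vars, height * width).T  # dims: (pixel, variable)
print(toy_features.shape)  # (15, 4)

# Hypothetical stand-in for clf.predict: index of the largest feature per pixel
toy_pred = toy_features.argmax(axis=1)

# Reshaping with (height, width) puts every prediction back at its (y, x) location
toy_map = toy_pred.reshape((height, width))
assert (toy_map == cube.argmax(axis=0)).all()
```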

## Random Forest Classification

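Now we set up the [random forest classifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html). Here we use cuML's GPU-accelerated `RandomForestClassifier`, which follows the scikit-learn API. The commented-out training code below creates it with `n_estimators=100` trees and a [seed](https://towardsdatascience.com/why-do-we-set-a-random-state-in-machine-learning-models-bb2dc68d8431) (`random_state`) to ensure reproducibility; calling the `.fit()` method on the classifier initiates training. In this notebook, the already-trained classifier is loaded instead.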
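To make "combining many trees" concrete, the NumPy sketch below shows two common ways a forest can aggregate its trees' outputs, applied to made-up per-tree predictions: hard majority voting and averaging of class probabilities (soft voting). The scikit-learn documentation states that its `RandomForestClassifier` averages the trees' probabilistic predictions rather than counting votes; this notebook does not establish which rule cuML uses internally, so treat the sketch as conceptual, not as either library's implementation.

```python
import numpy as np


def majority_vote(tree_predictions):
    """Hard voting: each tree casts one vote per sample.

    tree_predictions: int array of shape (n_trees, n_samples).
    Ties go to the lowest class ID, because argmax returns the first maximum.
    """
    n_classes = int(tree_predictions.max()) + 1
    # votes[c, s] = number of trees predicting class c for sample s
    votes = np.stack([(tree_predictions == c).sum(axis=0) for c in range(n_classes)])
    return votes.argmax(axis=0)


def soft_vote(tree_probabilities):
    """Soft voting: average per-tree class probabilities, then take the argmax.

    tree_probabilities: float array of shape (n_trees, n_samples, n_classes).
    """
    return tree_probabilities.mean(axis=0).argmax(axis=1)


# Made-up outputs of 3 trees for 4 pixels
tree_preds = np.array([[0, 1, 2, 1], [0, 2, 2, 0], [1, 1, 2, 0]])
print(majority_vote(tree_preds))  # [0 1 2 0]

# The rules can disagree: 2 of 3 trees lean to class 1, one is certain of class 0
tree_probs = np.array([[[0.4, 0.6]], [[0.45, 0.55]], [[1.0, 0.0]]])
print(majority_vote(tree_probs.argmax(axis=2)), soft_vote(tree_probs))  # [1] [0]
```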

In [None]:
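%%time
# Train a Random Forest classifier (only needed once; requires the cuml import above)
# clf = RandomForestClassifier(n_estimators=100, random_state=42)  # n_estimators=10
# clf.fit(X_train.data, y_train.data)

# Load the trained Random Forest classifier.
# Assumption: the model was saved with pickle; the file name below is hypothetical.
import pickle

with open(f"./rf_lulc_{VERSION}.pkl", "rb") as f:
    clf = pickle.load(f)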

## Inference

Once the trained classifier has been loaded, we can use it to make predictions on one province.

In [None]:
y_pred = clf.predict(features_test.data)  # pass the underlying NumPy array

As a reminder, this is what each class number represents.

In [None]:
print("Class mapping:")
for key, val in class_mapping.items():
    print(val, key)

## Visualization

In [None]:
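predicted_map = y_pred.reshape((height_test, width_test))
predicted_map_xr = xr.DataArray(
    data=predicted_map,
    dims=("y", "x"),
    coords={"y": s2_clipped_inference.y.values, "x": s2_clipped_inference.x.values},
).rio.write_crs(raster_crs)  # attach the CRS so .rio.crs / .rio.transform() work below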
print(np.unique(y_pred))

In [None]:
predicted_map_xr.hvplot.image(height=600, rasterize=True, cmap="Set1")

In [None]:
# rasterized_labels_test.ROI_numeric.hvplot.image(rasterize=True, cmap="Set1")

In [None]:
compatible_array = predicted_map_xr.astype("int32")

# Vectorize (polygonize) the classified raster: one polygon per connected patch of a class
polygons = list(
    rasterio.features.shapes(
        compatible_array.values, transform=compatible_array.rio.transform()
    )
)

# Convert polygons to a GeoDataFrame in the raster's CRS (instead of a hardcoded EPSG code)
prediction_gdf = gpd.GeoDataFrame(
    [{"geometry": shape(geom), "value": value} for geom, value in polygons],
    crs=compatible_array.rio.crs.to_wkt(),
)
print(prediction_gdf.value.unique())

prediction_gdf.to_file(
    f"./predicted_lulc_utm_{PROVINCE_INFERENCE}_{YEAR}.geojson", driver="GeoJSON"
)

In [None]:
prediction_gdf.head(10)

You can run these predictions for every province, collect the resulting GeoDataFrames, and concatenate them into a single nationwide LULC vector dataset like so (placeholder code: you need to generate the per-province predictions first):

In [None]:
# import pandas as pd
#
# prediction_gdf_merged_nationwide = pd.concat(
#     [
#         prediction_gdf_TORBA,
#         prediction_gdf_SANMA,
#         prediction_gdf_PENAMA,
#         prediction_gdf_MALAMPA,
#         prediction_gdf_SHEFA,
#         prediction_gdf_TAFEA,
#     ],
#     ignore_index=True,
# )
#
# prediction_gdf_merged_nationwide.to_file(
#     f"./predicted_lulc_utm_nationwide_{YEAR}.geojson", driver="GeoJSON"
# )