# Using cloud-native tools to map invasive species with supervised machine learning and Sentinel 2

## Overview

In this notebook, we will use existing data on verified land cover and invasive species locations to extract remote sensing data from the European Space Agency's Sentinel 2 satellites. We will then train a machine learning model to predict invasive plant occurrence, and finally, we will apply this model to Sentinel 2 data acquired at multiple dates to monitor the spread of invasive plants and the clearing efforts undertaken as part of The Nature Conservancy's (TNC) Greater Cape Town Water Fund.

### 1. Load Python packages

In [None]:
#core
import xarray as xr
import numpy as np
import rioxarray as riox
import geopandas as gpd
import xvec
from shapely.geometry import box, mapping
#data search
import stackstac
import pystac_client
#plotting
import hvplot.xarray
import holoviews as hv
hvplot.extension('bokeh')
#ml
import xgboost as xgb
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay
#other
from dask.diagnostics import ProgressBar

#note: hvplot's rasterize=True also requires datashader, and tiles with project=True requires geoviews

### 2. Load invasive plant data

First we load our land cover and invasive plant location data. We created this dataset in a desktop GIS (ArcGIS or QGIS) using field data and manually inspected high-resolution imagery.

We will use the Python package `geopandas` to handle spatial vector data. GeoPandas is a Python library for handling and analyzing geospatial data, similar to ArcGIS or the sf package in R. It extends the popular pandas library to support spatial data, allowing you to work with vector formats like shapefiles, GeoJSON, GeoPackage and more. GeoPandas integrates well with other Python libraries and lets you perform spatial operations like overlays, joins, and buffering in a way that will feel familiar if you are used to sf or ArcGIS workflows.

In [None]:
#read data in geopackage with geopandas

gdf = gpd.read_file('gctwf_invasive.gpkg')
gdf = gdf.to_crs("EPSG:32734")
bbox = gdf.total_bounds
#interactive plot
gdf[['name','geometry']].explore('name',tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}', attr='Google')


In [None]:
#We will need this later to turn codes back to names
# Get unique values
name_table = gdf[['class', 'name']].drop_duplicates().sort_values(by=['class'])
name_table

### 3. Search for remote sensing data using STAC

STAC (SpatioTemporal Asset Catalog) is a standardized specification for describing geospatial datasets and their metadata, making it easier to discover, access, and share spatial-temporal data like satellite imagery, aerial photos, and elevation models. STAC catalogs and APIs provide structured, searchable metadata that allow users to query datasets based on criteria like geographic location, time range, resolution, or cloud coverage.
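As a concrete illustration of such a query, the sketch below builds the kind of JSON request body that a STAC API item search accepts, using the same location, dates, and cloud-cover filter as the search further down. It assumes the STAC API Item Search specification plus the Query extension (which provides the `eo:cloud_cover` filter syntax). The helper `build_stac_search` is hypothetical; `pystac-client` normally constructs the request for you.

```python
import json


def build_stac_search(lon, lat, start, end, collection, max_cloud, limit=100):
    """Build a STAC API item-search request body (sketch, Query extension assumed)."""
    return {
        "collections": [collection],
        # GeoJSON geometry: coordinates are in [longitude, latitude] order
        "intersects": {"type": "Point", "coordinates": [lon, lat]},
        # RFC 3339 interval covering the whole of both dates
        "datetime": f"{start}T00:00:00Z/{end}T23:59:59Z",
        # Query extension: keep items whose cloud cover is below max_cloud percent
        "query": {"eo:cloud_cover": {"lt": max_cloud}},
        "limit": limit,
    }


body = build_stac_search(19.20, -33.80, "2018-09-01", "2018-09-30", "sentinel-2-l2a", 10)
print(json.dumps(body, indent=2))
```

Note the `[lon, lat]` order in the GeoJSON point, which is also why the search cell below passes `coordinates=[lon, lat]`.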

On AWS, many public geospatial datasets are listed in the [Registry of Open Data](https://registry.opendata.aws/), and many of them can be searched through STAC catalogs or APIs. Examples include satellite imagery from Landsat, Sentinel-2, and MODIS. You can use tools like the `pystac-client` Python library or STAC browser interfaces to explore these catalogs and retrieve data directly from AWS S3 buckets.

Once we have used STAC to filter the data we want, we get URLs pointing to the location of that data in AWS S3 object storage. If our own data are already stored on S3, we can skip this step and use those URLs directly, or, if we have a large collection of data, we can create our own STAC catalog.
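If your own rasters sit in a public S3 bucket, the `s3://bucket/key` form and an HTTPS URL are two addresses for the same object. This minimal sketch, using only the Python standard library, converts one into AWS's virtual-hosted-style HTTPS form (`https://<bucket>.s3.<region>.amazonaws.com/<key>`). The helper `s3_to_https`, the bucket, the key, and the region are all hypothetical, and private buckets additionally need credentials, which this sketch does not handle.

```python
from urllib.parse import urlparse


def s3_to_https(s3_url, region):
    """Convert s3://bucket/key into a virtual-hosted-style HTTPS URL."""
    parsed = urlparse(s3_url)
    if parsed.scheme != "s3":
        raise ValueError(f"not an s3:// URL: {s3_url}")
    bucket = parsed.netloc            # the bucket name is the URL "host"
    key = parsed.path.lstrip("/")     # the object key is the path without its leading slash
    return f"https://{bucket}.s3.{region}.amazonaws.com/{key}"


# hypothetical bucket and key, for illustration only
print(s3_to_https("s3://my-bucket/rasters/scene_B04.tif", "us-west-2"))
# https://my-bucket.s3.us-west-2.amazonaws.com/rasters/scene_B04.tif
```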

In [None]:

#the location of the catalog we want to search (find this on AWS Registry of Open Data)
URL = "https://earth-search.aws.element84.com/v1"
catalog = pystac_client.Client.open(URL)

#we want data that intersect with this location
lat, lon = -33.80, 19.20
#within this date range
datefilter = "2018-09-01/2018-09-30"

#search!
items = catalog.search(
    intersects=dict(type="Point", coordinates=[lon, lat]),
    collections=["sentinel-2-l2a"],
    datetime=datefilter,
    query={"eo:cloud_cover": {"lt": 10}}  # Filter for cloud cover less than 10%
).item_collection()

#how many matches?
len(items)

Let's look at the metadata for the first matching Sentinel 2 scene.

In [None]:
items[0]

### 4. Load our data
Now that we have found the data we want, we can load it into an xarray object. We could pass the S3 URL of the data directly to xarray, but the `stackstac` package does a lot of additional data handling and returns a neat result with many helpful metadata. For example, given a list of several Sentinel 2 scenes, `stackstac` will align them and stack them along a time dimension.

In [None]:
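#load the first scene, cropped to the extent of our labelled data
#(bounds are interpreted in the stack's CRS, which must match gdf's EPSG:32734)
stack = stackstac.stack(items[0],bounds=(bbox[0], bbox[1], bbox[2], bbox[3]))

#stackstac creates a time dimension, but we only have 1 date, so we drop it
stack = stack.squeeze()

#select only the bands we need
stack = stack.sel(band=['blue', 'coastal', 'green', 'nir', 'nir08', 'nir09', 'red', 'rededge1', 'rededge2', 'rededge3', 'swir16', 'swir22'])

#combine all bands into one chunk, so each chunk holds the full spectrum of its pixels
stack = stack.chunk({'band':-1})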

Let's make a plot to see what the scene looks like in true color. We will use the `hvplot` package, which makes it very easy to create interactive plots from xarray objects.

In [None]:
stack.sel(band=['red','green','blue']).hvplot.rgb(
    x='x', y='y', bands='band',rasterize=True,robust=True,data_aspect=1,title="True colour")

What about a single band?

In [None]:
stack.sel(band='red').hvplot(
    x='x', y='y',rasterize=True,robust=True,data_aspect=1,cmap='magma',clim=(0,0.2),title='Red reflectance')

#### Shadow Masking 
In the image above, some areas are shadowed by mountains. We are unlikely to be able to predict land cover reliably in these areas, so let's mask them out. We will use a simple rule: if the reflectance in both the red and near-infrared bands is below a threshold, we set those pixels to nodata.
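The rule is easy to see on a plain NumPy array. This sketch assumes a `(band, y, x)` reflectance array and hypothetical band positions for red and NIR; it builds the same two-band condition as the cell below and broadcasts the 2-D mask across all bands, so a shadowed pixel becomes NaN in every band. The helper `mask_shadows` is hypothetical.

```python
import numpy as np


def mask_shadows(bands, red_idx, nir_idx, threshold=0.05):
    """Set pixels that are dark in both red and NIR to NaN in every band.

    bands: float array shaped (band, y, x) of surface reflectance.
    """
    red = bands[red_idx]
    nir = bands[nir_idx]
    shadow = (red < threshold) & (nir < threshold)        # (y, x) boolean mask
    # broadcast the 2-D mask across the band axis
    return np.where(shadow[np.newaxis, :, :], np.nan, bands)


# toy 2-band, 2x2 scene: band 0 = red, band 1 = nir
toy = np.array([
    [[0.02, 0.10], [0.03, 0.04]],   # red
    [[0.01, 0.30], [0.20, 0.02]],   # nir
])
masked = mask_shadows(toy, red_idx=0, nir_idx=1)
print(np.isnan(masked[0]))
```

Only the top-left and bottom-right pixels are dark in both bands, so only those become NaN; a pixel that is dark in red alone (bottom-left) is kept.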

In [12]:
#select red and nir bands
red = stack.sel(band='red')
nir = stack.sel(band='nir')

# Set the reflectance threshold
threshold = 0.05

# Create a shadow mask: pixels that are dark in both the red and NIR bands
shadow_condition = (red < threshold) & (nir < threshold)

# Set shadowed pixels to nodata
stack = stack.where(~shadow_condition)

### 5. Extract Sentinel data at point locations
Now we will extract the reflectance data at the locations where we have validated land cover. The `xvec` package makes this easy, and it returns the result as an xarray object.
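Conceptually, extracting values at points means converting each point's map coordinates into a row and column of the raster grid and reading the pixel there. The NumPy sketch below does this by nearest pixel for a north-up grid, assuming a known top-left corner `(x0, y0)` and a square pixel size `res`. `extract_nearest` is a hypothetical helper, not part of `xvec`, which handles this for you through xarray indexing.

```python
import numpy as np


def extract_nearest(raster, x0, y0, res, xs, ys):
    """Sample a (band, y, x) raster at map coordinates by nearest pixel.

    Assumes a north-up grid: (x0, y0) is the top-left corner of the
    top-left pixel and res is the square pixel size in map units.
    Returns an array shaped (band, n_points), NaN for points off the grid.
    """
    cols = np.floor((np.asarray(xs) - x0) / res).astype(int)
    rows = np.floor((y0 - np.asarray(ys)) / res).astype(int)   # y decreases down the rows
    n_bands, n_rows, n_cols = raster.shape
    inside = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
    out = np.full((n_bands, len(cols)), np.nan)
    out[:, inside] = raster[:, rows[inside], cols[inside]]
    return out


raster = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)   # 2 bands, 3 rows, 4 columns
values = extract_nearest(raster, x0=300000.0, y0=6260000.0, res=10.0,
                         xs=[300015.0, 300035.0, 299990.0],
                         ys=[6259995.0, 6259975.0, 6259995.0])
print(values)
```

The third point lies west of the grid, so it comes back as NaN, which is why the notebook later drops points without valid data.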

In [13]:
# Extract points
point = stack.xvec.extract_points(gdf['geometry'], x_coords="x", y_coords="y",index=True)
point = point.swap_dims({'geometry': 'index'}).to_dataset(name='s2')


In [14]:
#now compute the result: this is when the pixel values are actually read from S3
with ProgressBar():
    point = point.compute()

In [None]:
point

In [None]:
#keep only points with valid (non-NaN) data in every band, e.g. drop shadow-masked points
condition = point.s2.notnull().all(dim='band')

# Apply the mask to keep only the valid indices
point = point.where(condition, drop=True)
point

In [None]:
#add label from geopandas
gxr =gdf[['class','group']].to_xarray()
point = point.merge(gxr.astype(int),join='left')
point

Let's select a single point and visualize the spectral data we will use to train our model.

In [None]:
pointp = point.isel(index=0)
pointp['center_wavelength'] = pointp['center_wavelength'].astype(float)
pointp['s2'].hvplot.scatter(x='center_wavelength',by='index',
                                         color='green',ylim=(0,0.3),alpha=0.5,legend=False,title = "Single point data")

### 6. Train ML model
We will be using a model called XGBoost. There are many different kinds of machine learning models; XGBoost is a library that implements gradient-boosted decision trees, a family of models related to random forests. When used for classification, a random forest builds many decision trees, each trained on a random subset of the data and features, and then averages their predictions to improve accuracy and reduce overfitting. Gradient-boosted trees differ in that they build trees sequentially, with each new tree fitted to correct the errors of the trees before it. This sequential approach lets XGBoost build highly accurate models by iteratively refining its predictions and addressing the weaknesses of earlier trees.
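To make "each new tree corrects the errors of the previous ones" concrete, here is a minimal NumPy sketch of gradient boosting with squared loss, using one-split trees (stumps) on a 1-D regression problem. It illustrates the idea only and is not how XGBoost is implemented: XGBoost adds second-order gradient information and regularization, and for multi-class problems like ours it uses a softmax loss. The functions `fit_stump` and `gradient_boost` are hypothetical.

```python
import numpy as np


def fit_stump(x, residual):
    """Fit a one-split regression tree (a stump) to residual by least squares."""
    best = None
    for t in np.unique(x)[:-1]:  # candidate split points; keeps both sides non-empty
        left = x <= t
        left_val, right_val = residual[left].mean(), residual[~left].mean()
        sse = np.sum((residual - np.where(left, left_val, right_val)) ** 2)
        if best is None or sse < best[0]:
            best = (sse, t, left_val, right_val)
    _, t, left_val, right_val = best
    return lambda z: np.where(z <= t, left_val, right_val)


def gradient_boost(x, y, n_trees=20, learning_rate=0.3):
    """Squared-loss gradient boosting: each new stump is fitted to the current errors."""
    pred = np.full(y.shape, y.mean())  # start from a constant prediction
    trees, history = [], []
    for _ in range(n_trees):
        residual = y - pred  # errors left by the trees so far (the negative gradient)
        tree = fit_stump(x, residual)
        pred = pred + learning_rate * tree(x)  # shrink each correction
        trees.append(tree)
        history.append(np.mean((y - pred) ** 2))  # training MSE after this tree
    return trees, history


rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0, 10, 200))
y = np.sin(x) + rng.normal(0, 0.1, x.size)
trees, mse = gradient_boost(x, y)
print(f"training MSE after 1 tree: {mse[0]:.3f}, after {len(trees)} trees: {mse[-1]:.3f}")
```

The learning rate here plays the same role as `learning_rate` in the hyperparameter search below: smaller values make each tree's correction more cautious, so more trees are needed.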

Each point in our dataset has a `group` label indicating whether it belongs to the training set (1) or the test set (2). We will use this label to split the data.

In [19]:
#split into train (group 1) and test (group 2)
dtrain = point.where(point['group']==1,drop=True)
dtest = point.where(point['group']==2,drop=True)

#create separate arrays for labels and features
y_train = dtrain['class'].values.astype(int)
y_test = dtest['class'].values.astype(int)
#the model expects features shaped (n_samples, n_bands)
X_train = dtrain['s2'].transpose('index','band').values
X_test = dtest['s2'].transpose('index','band').values

The steps we will go through to train the model are:

First, we define the hyperparameter grid (`param_grid`), listing two candidate values for each of four XGBoost hyperparameters, which gives 16 possible combinations.

Next, we create an XGBoost classifier object using the `XGBClassifier` class from the XGBoost library.

We then set up a `RandomizedSearchCV` object using our XGBoost model and the hyperparameter grid. Rather than trying every combination, `RandomizedSearchCV` evaluates a random sample of them (`n_iter=10` of the 16 here) and keeps the one that gives the best model performance. We use 5-fold cross-validation (`cv=5`), meaning the training data are split into five subsets and each candidate model is validated on each subset in turn while being trained on the other four. We use accuracy as our scoring metric.

We fit the `RandomizedSearchCV` object to our training data (`X_train` and `y_train`). This trains multiple models with different hyperparameter combinations and evaluates their performance using cross-validation. Our goal is to identify the set of hyperparameters that yields the highest accuracy.

Once the search completes, we print the best set of hyperparameters and the corresponding best score. The `random_search.best_params_` attribute gives the combination of hyperparameters that achieved the highest cross-validation accuracy, and `random_search.best_score_` gives that accuracy. Finally, we extract the best model (`best_model`) from the search results. Because `RandomizedSearchCV` refits the best combination on the full training set by default, this model is ready for making predictions or further analysis in our classification task.
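The search has two ingredients: drawing `n_iter` distinct combinations from a grid of lists, and splitting sample indices into `k` folds so that every sample is used for validation exactly once. The sketch below shows both with the standard library and NumPy. The helpers `sample_candidates` and `kfold_indices` are hypothetical and simplified: scikit-learn uses its own random number generator (so it will not pick the same 10 combinations) and, for classifiers, stratified folds that preserve class proportions.

```python
import itertools
import random

import numpy as np

param_grid = {
    "max_depth": [5, 9],
    "learning_rate": [0.1, 0.2],
    "subsample": [0.6, 0.9],
    "n_estimators": [50, 200],
}


def sample_candidates(grid, n_iter, seed=42):
    """Draw n_iter distinct combinations (without replacement) from a grid of lists."""
    keys = list(grid)
    combos = [dict(zip(keys, values)) for values in itertools.product(*grid.values())]
    return random.Random(seed).sample(combos, k=min(n_iter, len(combos)))


def kfold_indices(n_samples, k=5, seed=42):
    """Yield (train_idx, val_idx) pairs; each sample lands in exactly one validation fold."""
    order = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(order, k)
    for i in range(k):
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, folds[i]


candidates = sample_candidates(param_grid, n_iter=10)
print(f"{len(candidates)} of 16 combinations will be evaluated")
for train_idx, val_idx in kfold_indices(20, k=5):
    print(len(train_idx), "train /", len(val_idx), "validation")
```

For each sampled combination, the search trains on `train_idx` and scores on `val_idx` for every fold, then averages the five scores; the combination with the best mean score wins.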

The search in the cell below takes approximately 5-20 seconds to run.

In [None]:
param_grid = {
    'max_depth': [5, 9],         
    'learning_rate': [0.1, 0.2], 
    'subsample': [0.6, 0.9],     
    'n_estimators': [50, 200]    
}

# Create the XGBoost model object
xgb_model = xgb.XGBClassifier(tree_method='hist')

# Create the search object
random_search = RandomizedSearchCV(
    xgb_model,
    param_grid,     
    n_iter=10, 
    cv=5,
    scoring='accuracy', 
    random_state=42, 
    n_jobs=-1)

# Fit the search object to the training data
random_search.fit(X_train, y_train)

# Print the best set of hyperparameters and the corresponding score
print("Best set of hyperparameters: ", random_search.best_params_)
print("Best score: ", random_search.best_score_)
best_model = random_search.best_estimator_

We will use our best model to predict the classes of the test data and calculate accuracy.

Next, we assess how well the model predicts Pine trees by calculating its precision and recall for that class. Precision measures the accuracy of the positive predictions. It answers the question, "Of all the instances we labeled as Pines, how many were actually Pines?" Recall measures the model's ability to find all actual positive instances. It answers the question, "Of all the actual Pines, how many did we correctly identify?" You may also be familiar with the terms user's and producer's accuracy: precision is equivalent to user's accuracy, and recall to producer's accuracy.
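Both quantities come straight from counts of matching labels. This NumPy sketch computes them for a single class on made-up labels, with class `2` standing in for Pine as in the evaluation code below. For one label, the result should match scikit-learn's `precision_score` and `recall_score` called with `labels=[2]`. The helper `precision_recall_for_class` is hypothetical.

```python
import numpy as np


def precision_recall_for_class(y_true, y_pred, cls):
    """Precision (user's accuracy) and recall (producer's accuracy) for one class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == cls) & (y_true == cls))   # correctly labeled as cls
    predicted = np.sum(y_pred == cls)                # everything the model labeled as cls
    actual = np.sum(y_true == cls)                   # everything that truly is cls
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return float(precision), float(recall)


# made-up labels; class 2 plays the role of "Pine"
y_true = [2, 2, 2, 2, 0, 1, 0, 1]
y_pred = [2, 2, 2, 0, 2, 2, 0, 1]
print(precision_recall_for_class(y_true, y_pred, cls=2))  # (0.6, 0.75)
```

Here the model labels five points as Pine but only three of them are Pine (precision 3/5), and it finds three of the four true Pines (recall 3/4).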

Finally, we create and display a confusion matrix to visualize the model's predictions across all classes. We normalize it over the predicted labels, so each value on the diagonal is that class's precision.

In [None]:
y_pred = best_model.predict(X_test)

# Step 1: overall accuracy on the test set
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc}")

# Step 2: precision and recall for Pine (class code 2)
precision_pine = precision_score(y_test, y_pred, labels=[2], average='macro', zero_division=0)
recall_pine = recall_score(y_test, y_pred, labels=[2], average='macro', zero_division=0)

print(f"Precision for Pines: {precision_pine}")
print(f"Recall for Pines: {recall_pine}")

# Step 3: confusion matrix, normalized over predicted labels (columns)
# passing labels keeps one row/column per class, in the same order as the names
class_codes = name_table['class'].tolist()
conf_matrix = confusion_matrix(y_test, y_pred, labels=class_codes, normalize='pred')

ConfusionMatrixDisplay(confusion_matrix=conf_matrix,display_labels=list(name_table['name'])).plot()


### 7. Predict over entire study area
We now have a trained model and are ready to deploy it across an entire Sentinel 2 scene to map the distribution of invasive plants. This involves a large volume of data, so we need to write the code to handle it efficiently. We will apply the `.predict_proba()` method of our trained model in parallel across the chunks of the Sentinel 2 xarray. The model receives one chunk at a time, so no single piece of data is too large to process, and because many chunks are processed in parallel, the whole operation does not take too long.
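Chunked prediction is safe because each pixel is predicted independently: splitting the pixels into chunks, predicting each chunk, and concatenating the results gives the same answer as predicting everything at once. The sketch below illustrates this with NumPy and a thread pool standing in for dask's scheduler. `toy_predict_proba` is a hypothetical stand-in for a trained model's `predict_proba`, and the chunk size is arbitrary.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def toy_predict_proba(features):
    """Hypothetical stand-in for model.predict_proba: 2 classes from band 0."""
    p = 1.0 / (1.0 + np.exp(-features[:, 0]))
    return np.column_stack([1.0 - p, p])


def predict_in_chunks(samples, predict_proba, chunk_size, max_workers=4):
    """Split (n_samples, n_features) into row chunks and predict them concurrently."""
    chunks = [samples[i:i + chunk_size] for i in range(0, len(samples), chunk_size)]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map returns results in the same order as the input chunks
        results = list(pool.map(predict_proba, chunks))
    return np.concatenate(results, axis=0)


rng = np.random.default_rng(1)
pixels = rng.normal(size=(10_000, 12))       # 10,000 pixels x 12 bands
chunked = predict_in_chunks(pixels, toy_predict_proba, chunk_size=2_048)
print(chunked.shape)                                     # (10000, 2)
print(np.allclose(chunked, toy_predict_proba(pixels)))   # True
```

Dask does the same kind of splitting, but over the chunks already defined on the xarray, which is why the band dimension was merged into a single chunk earlier: every chunk must contain all the features for its pixels.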

Our model was only trained on data covering natural vegetation in a specific area, so it is important that we only predict in areas that match our training data. We will therefore mask out areas that are not natural vegetation using a polygon.
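Masking with a polygon amounts to deciding, for each pixel, whether it falls inside the polygon. rioxarray's `clip` does this by rasterizing the geometry, and by default (`all_touched=False`) it keeps roughly the pixels whose centres fall inside. The NumPy sketch below applies the even-odd (ray casting) point-in-polygon rule to pixel centres to show the idea; `polygon_mask` is a hypothetical helper that only handles simple polygons without holes.

```python
import numpy as np


def polygon_mask(xs, ys, polygon):
    """Boolean (y, x) mask of grid cells whose centres fall inside a polygon.

    xs, ys: 1-D arrays of pixel-centre coordinates.
    polygon: list of (x, y) vertices of a simple polygon without holes.
    """
    X, Y = np.meshgrid(xs, ys)             # both shaped (len(ys), len(xs))
    inside = np.zeros(X.shape, dtype=bool)
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # does this edge straddle the horizontal line through each pixel centre?
        crosses = (y1 > Y) != (y2 > Y)
        with np.errstate(divide="ignore", invalid="ignore"):
            # x where the edge meets that line (meaningless for horizontal edges,
            # but those never satisfy `crosses`)
            x_cross = x1 + (Y - y1) * (x2 - x1) / (y2 - y1)
        # a ray towards +x that crosses an edge toggles inside/outside
        inside ^= crosses & (X < x_cross)
    return inside


xs = np.arange(0.5, 5.0, 1.0)      # pixel-centre x coordinates: 0.5 ... 4.5
ys = np.arange(4.5, 0.0, -1.0)     # north-up: y decreases down the rows
square = [(1, 1), (4, 1), (4, 4), (1, 4)]
mask = polygon_mask(xs, ys, square)
print(mask.astype(int))
```

The 3 x 3 block of centres inside the square is selected; in the notebook, `rio.clip` sets everything outside the polygon to nodata in the same spirit.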

In [22]:
geodf = gpd.read_file('aoi.gpkg').to_crs("EPSG:32734")
geoms = geodf.geometry.apply(mapping)

Here is the function that we will actually apply to each chunk. Simple really. The hard work is getting the data into and out of this function

In [23]:
def predict_on_chunk(chunk, model):
    probabilities = model.predict_proba(chunk)
    return probabilities

Now we define the function that takes the Sentinel 2 xarray as input and passes it to the prediction function. It has three parts:

Part 1: Prepares the data for the model. It builds a condition that identifies valid pixels (reflectance greater than zero in at least one band), used later to mask nodata, and then stacks the spatial dimensions (x and y) into a single `sample` dimension.

Part 2: Applies the machine learning model to the stacked data in parallel, predicting class probabilities for each pixel with xarray's `apply_ufunc` function. Most of the call involves declaring what happens to the dimensions of the input and of the new output.

Part 3: Unstacks the data to restore its original dimensions, sets the spatial dimensions and coordinate reference system (CRS), clips the result to the area-of-interest polygon and masks nodata pixels, and transposes the dimensions into the order rioxarray expects before returning the result.
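The same three parts can be written with plain NumPy reshapes, which may help clarify what `stack`, `apply_ufunc`, and `unstack` do to the dimensions. This sketch assumes a `(band, y, x)` array and a hypothetical three-class `toy_predict_proba` in place of the trained model. It flattens pixels in `(y, x)` order rather than the `('x', 'y')` order used in `predict_xr`, which does not matter as long as the reshape back uses the same order. Clipping to a polygon is left out, and `predict_raster` is a hypothetical helper.

```python
import numpy as np


def toy_predict_proba(samples):
    """Hypothetical 3-class model: softmax of a fixed linear map of 2 bands."""
    weights = np.array([[1.0, -1.0, 0.5],
                        [0.2, 0.3, -0.4]])          # (band, class)
    logits = samples @ weights
    logits -= logits.max(axis=1, keepdims=True)     # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)


def predict_raster(cube, predict_proba, n_classes):
    """Predict per-pixel class probabilities for a (band, y, x) array."""
    n_bands, ny, nx = cube.shape
    # part 1: valid-pixel mask, then flatten pixels into (sample, band)
    valid = (cube > 0).any(axis=0)                  # (y, x)
    samples = cube.reshape(n_bands, ny * nx).T      # (y*x, band)
    # part 2: one row of class probabilities per pixel
    proba = predict_proba(samples)                  # (y*x, class)
    # part 3: restore the grid as (class, y, x) and blank out nodata pixels
    out = proba.T.reshape(n_classes, ny, nx)
    return np.where(valid, out, np.nan)


rng = np.random.default_rng(0)
cube = rng.uniform(0.01, 0.3, size=(2, 3, 4))      # 2 bands, 3 rows, 4 columns
cube[:, 0, 0] = 0.0                                # one nodata pixel
result = predict_raster(cube, toy_predict_proba, n_classes=3)
print(result.shape)  # (3, 3, 4)
```

The output has the class dimension first, matching the `('class', 'y', 'x')` order that `predict_xr` transposes to before writing with rioxarray.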

In [24]:
def predict_xr(ds,geometries):

    #part 1 - data prep
    #condition to use for masking no data later
    condition = (ds > 0).any(dim='band')

    #stack the data into a single dimension. This will be important for applying the model later
    ds = ds.stack(sample=('x','y'))


    #part 2 - apply the model over chunks
    result = xr.apply_ufunc(
        predict_on_chunk,
        ds,
        input_core_dims=[['band']],#input dim with features
        output_core_dims=[['class']],  # name for the new output dim
        exclude_dims=set(('band',)),  #dims to drop in result
        dask_gufunc_kwargs={'output_sizes': {'class': len(best_model.classes_)}}, #one output per class the model was trained on
        output_dtypes=[np.float32],
        dask="parallelized",
        kwargs={'model': best_model}
    )

    #part 3 - post-processing
    result = result.unstack('sample') #remove the stack
    result = result.rio.set_spatial_dims(x_dim='x',y_dim='y') #set the spatial dims
    result = result.rio.write_crs("EPSG:32734") #set the CRS
    result = result.rio.clip(geometries).where(condition) #clip to the protected areas and no data
    result = result.transpose('class', 'y', 'x') #transpose the data - rio expects it this way
    return result.compute()

Now we can actually run this. It should take about 30-60 s (to process a 10GB Sentinel 2 scene!).

In [None]:
with ProgressBar():
    predicted  = predict_xr(stack,geoms)

In [None]:
predicted

Now we can view our result. We will plot the probability that a pixel is covered by invasive Pine trees.

In [29]:
#reproject
predicted = predicted.rio.reproject("EPSG:4326",nodata=np.nan)
#select only pines
predicted_plot = predicted.isel({'class':2})
#set pixels with probability of 0.5 or less to NaN
predicted_plot = predicted_plot.where(predicted_plot > 0.5, np.nan)

In [None]:
#plot with a satellite basemap
predicted_plot.hvplot(tiles=hv.element.tiles.EsriImagery(), 
                              project=True,clim=(0,1),
                              cmap='magma',frame_width=800,data_aspect=1,alpha=0.7,title='Pine probability')

Lastly, we export to a geotiff. We can use rioxarray to do this. Now we can explore the map in our desktop GIS if desired.

In [31]:
predicted.rio.to_raster('gctwf_invasive.tiff',driver="COG")

### Credits
This lesson borrows heavily from the following resources, which are also great places to learn more about handling large geospatial data in Python:

[The Carpentries Geospatial Python lesson by Ryan Avery](https://carpentries-incubator.github.io/geospatial-python/)  

[The xarray user guide](https://docs.xarray.dev/en/stable/user-guide/index.html) 

[An Introduction to Earth and Environmental Data Science](https://earth-env-data-science.github.io/intro.html)

Another good place to start learning more is the [Cloud-Native Geospatial Foundation](https://cloudnativegeo.org/), which curates a community using and developing cloud-native geospatial tools