# Using cloud-native tools to map invasive species with supervised machine learning and Sentinel 2

## Overview

### Aims

This notebook demonstrates a complete end-to-end workflow for mapping invasive species across large geographic areas using cloud-native geospatial tools and supervised machine learning. The specific objectives are:

- Learn how to discover and access satellite imagery from cloud-based data repositories using STAC catalogs
- Extract spectral information from satellite data at validated field locations
- Train a machine learning classifier (XGBoost) to distinguish invasive species from other land cover types
- Deploy the trained model efficiently across entire satellite scenes using parallel processing
- Generate spatially-explicit predictions of invasive species distribution for conservation planning and monitoring

### Structure

The notebook is organized into a data pipeline with three main phases:

1. **Data Preparation** (Sections 2-5): Load validation data, search for and retrieve Sentinel-2 imagery from AWS, and extract spectral features at known locations
2. **Model Development** (Section 6): Train and evaluate an XGBoost classifier on held-out test data
3. **Deployment & Export** (Section 7): Apply the trained model across the entire study area using parallel processing and export results as a cloud-optimized GeoTIFF

The example focuses on mapping invasive Pine trees in the Greater Cape Town water fund area, demonstrating how these techniques can support conservation monitoring and resource management decisions.

### 1. Load Python packages

In [None]:
#core python geospatial packages
import xarray as xr
import numpy as np
import rioxarray as riox
import geopandas as gpd
import xvec
from shapely.geometry import box, mapping

#data search
import stackstac
import pystac_client

#plotting
import hvplot.xarray
import holoviews as hv
#interactive plots
hvplot.extension('bokeh')

#static plots
#hvplot.extension('matplotlib')
#import matplotlib.pyplot as plt
#%matplotlib inline

#ml
import xgboost as xgb
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay

#other
from dask.diagnostics import ProgressBar


### 2. Load invasive plant data

First, we load our land cover and invasive plant location data. These points were created from field data and/or manually inspected high-resolution imagery.

We will use the Python package `geopandas` to handle spatial vector data. GeoPandas is a Python library designed to handle and analyze geospatial data, similar to how ArcGIS or the `sf` package in R work. It extends the popular `pandas` library to support spatial data, allowing you to work with vector data formats like shapefiles, GeoJSON, GeoPackage and more. GeoPandas integrates well with other Python libraries and lets you perform spatial operations like overlays, joins, and buffering in a way that's familiar if you're used to `sf` or ArcGIS workflows.

In [None]:
#read data in geopackage with geopandas

gdf = gpd.read_file('gctwf_invasive.gpkg')
gdf = gdf.to_crs("EPSG:32734")
bbox = gdf.total_bounds
gdf

#### Plot the data

In [None]:
#interactive plot
gdf[['name','geometry']].explore('name',tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}', attr='Google')


In [None]:
#We will need this later to turn codes back to names
# Get unique values
name_table = gdf[['class', 'name']].drop_duplicates().sort_values(by=['class'])

### 3. Search for remote sensing data using STAC

STAC (SpatioTemporal Asset Catalog) is a standardized specification for describing geospatial datasets and their metadata, making it easier to discover, access, and share spatiotemporal data like satellite imagery, aerial photos, and elevation models. STAC catalogs and APIs provide structured, searchable metadata that allow users to query datasets based on criteria like geographic location, time range, resolution, or cloud coverage.

On AWS, STAC datasets can be found in the [Registry of Open Data](https://registry.opendata.aws/), which hosts numerous public geospatial datasets in STAC format. Examples include satellite imagery from Landsat, Sentinel-2, and MODIS. You can use tools like the `pystac-client` Python library or STAC browser interfaces to explore and retrieve data directly from AWS S3 buckets.

Once we have used STAC to filter the data we want, we get URLs to the location of that data on AWS S3 object storage. If we have our own data directly stored on S3, we can skip this part and just use that URL directly, or we can create our own STAC catalog if we have a large collection of data.

In [None]:

#the location of the catalog we want to search (find this on AWS Registry of Open Data)
URL = "https://earth-search.aws.element84.com/v1"
catalog = pystac_client.Client.open(URL)

#we want data that intersect with this location
lat, lon = -33.80, 19.20
#in this time
datefilter = "2018-09-01/2018-09-30"

#search!
items = catalog.search(
    intersects=dict(type="Point", coordinates=[lon, lat]),
    collections=["sentinel-2-l2a"],
    datetime=datefilter,
    query={"eo:cloud_cover": {"lt": 10}}  # Filter for cloud cover less than 10%
).item_collection()

#how many matches?
len(items)

Let's print some info about this Sentinel-2 scene.

In [None]:
items[0]
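The search in this section boils down to a small set of filters. As a rough sketch, and assuming the catalog follows the STAC API item-search conventions (with the Query extension for the cloud cover filter), the request body a client sends looks something like the dictionary below. The helper name `build_search_body` is hypothetical and exists only for illustration; nothing here contacts the network.

```python
import json

# Search parameters copied from the search cell in this section
lat, lon = -33.80, 19.20
datefilter = "2018-09-01/2018-09-30"

def build_search_body(lon, lat, datetime_range, max_cloud):
    """Hypothetical helper: assemble a STAC item-search request body."""
    return {
        "collections": ["sentinel-2-l2a"],
        # GeoJSON lists longitude first, then latitude
        "intersects": {"type": "Point", "coordinates": [lon, lat]},
        # Closed interval written as start/end
        "datetime": datetime_range,
        # Property filter (assumes the catalog supports the Query extension)
        "query": {"eo:cloud_cover": {"lt": max_cloud}},
    }

body = build_search_body(lon, lat, datefilter, max_cloud=10)
print(json.dumps(body, indent=2))
```

Note the coordinate order: the notebook defines `lat, lon` but GeoJSON expects `[lon, lat]`, which is an easy place to introduce a silent error.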

### 4. Load our data

Now that we have the data we want, we can load it into an xarray. We could pass the S3 URL of the data directly to xarray, but the package `stackstac` does a bunch of additional data handling to return a neat result with lots of helpful metadata. We can, for example, give `stackstac` a list of multiple Sentinel-2 scenes and it will align and stack them along a time dimension.

In [None]:
stack = stackstac.stack(items[0],bounds=(bbox[0], bbox[1], bbox[2], bbox[3]),epsg=32734)

#stackstac creates a time dim, but we only have 1 date so we drop this
stack = stack.squeeze()

#select only the bands we need
stack = stack.sel(band=['blue', 'coastal', 'green', 'nir', 'nir08', 'nir09', 'red', 'rededge1', 'rededge2', 'rededge3', 'swir16', 'swir22'])

#combine bands into one chunk
stack = stack.chunk({'band':-1})

#look at the xarray
stack

Let's make a plot to see what it looks like in true color. We will use the package `hvplot`, which makes it very easy to create interactive plots from xarrays.

In [None]:
stack.sel(band=['red','green','blue']).hvplot.rgb(
    x='x', y='y', bands='band',rasterize=True,robust=True,data_aspect=1,title="True colour")

#### Shadow masking

In the image above, some areas are shadowed by mountains. It is unlikely that we will be able to predict land cover in these areas, so let's mask them out. We will use a simple rule: if the reflectance in both the red and near-infrared bands is below a threshold, drop those pixels.

In [None]:
#select red and nir bands
red = stack.sel(band='red')
nir = stack.sel(band='nir')

# Set the reflectance threshold
threshold = 0.05

# Create a shadow mask: identify pixels that are dark in both the red and NIR bands
shadow_condition = (red < threshold) & (nir < threshold)

# Set shadowed pixels to nodata
stack = stack.where(~shadow_condition)
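To see what the shadow rule does, here is a minimal sketch on a tiny NumPy array. The reflectance values are made up for illustration and are not real Sentinel-2 data; only the threshold matches the cell above.

```python
import numpy as np

# Toy reflectance values for a 2x3 image (assumed values)
red = np.array([[0.02, 0.10, 0.04],
                [0.03, 0.01, 0.20]])
nir = np.array([[0.03, 0.30, 0.25],
                [0.04, 0.02, 0.40]])
threshold = 0.05

# A pixel counts as shadow only when BOTH bands are dark
shadow = (red < threshold) & (nir < threshold)

# Replace shadowed pixels with NaN and keep everything else
red_masked = np.where(shadow, np.nan, red)

print(shadow)
print(red_masked)
```

The top-right pixel has low red reflectance (0.04) but high near-infrared reflectance (0.25), as healthy vegetation typically does, so it is kept. Requiring both bands to be dark is what stops the rule from discarding vegetation.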

### 5. Extract Sentinel-2 data at point locations

Now we will extract the reflectance data at the locations where we have validated land cover. The package `xvec` makes this easy. It returns the result as an xarray.

In [None]:
# Extract points
point = stack.xvec.extract_points(gdf['geometry'], x_coords="x", y_coords="y",index=True)
point = point.swap_dims({'geometry': 'index'}).to_dataset(name='s2')


In [None]:
#let's actually run this and get the result
with ProgressBar():
    point = point.compute()

In [None]:
#keep points with valid data in at least one band (drops fully masked points)
condition = point.s2.notnull().any(dim='band')

# Apply the mask to keep only the valid indices
point = point.where(condition, drop=True)
point
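The filter above reduces over the `band` dimension with `any`. The choice of reduction matters: `any` keeps a point if at least one band is valid, while `all` keeps only points where every band is valid. A small NumPy sketch with made-up values shows the difference.

```python
import numpy as np

# Toy samples: rows are points, columns are bands (assumed values)
samples = np.array([
    [0.1, 0.2, 0.3],           # every band valid
    [0.1, np.nan, 0.3],        # one band missing
    [np.nan, np.nan, np.nan],  # fully masked
])

valid = ~np.isnan(samples)
keep_any = valid.any(axis=1)   # at least one valid band
keep_all = valid.all(axis=1)   # every band valid

print(keep_any)
print(keep_all)
```

When a mask is applied to all bands at once, as with the shadow mask in this notebook, the two reductions give the same result. They differ only when individual bands are missing.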

In [None]:
#add label from geopandas
gxr = gdf[['class','group']].to_xarray()
point = point.merge(gxr.astype(int),join='left')
point

Let's select a single point and visualize the data we will be using to train our model.

In [None]:
pointp = point.isel(index=0)
pointp['center_wavelength'] = pointp['center_wavelength'].astype(float)
pointp['s2'].hvplot.scatter(x='center_wavelength',by='index',
                                         color='green',ylim=(0,0.3),alpha=0.5,legend=False,title = "Single point data")

Finally, our model will be trained on data covering natural vegetation in a specific area. It is important that we only predict in the areas that match our training data. We will therefore mask out non-natural vegetation using a polygon.

In [None]:
geodf = gpd.read_file('aoi.gpkg').to_crs("EPSG:32734")
geoms = geodf.geometry.apply(mapping)

### 6. Train ML model

We will be using a model called XGBoost. There are many different kinds of ML models. XGBoost belongs to a class of models called gradient boosted trees, which are related to random forests. When used for classification, random forests work by creating multiple decision trees, each trained on a random subset of the data and features, and then combining their predictions (by voting or averaging) to improve accuracy and reduce overfitting. Gradient boosted trees differ in that they build trees sequentially, with each new tree focusing on correcting the errors of the previous ones. This sequential approach allows XGBoost to create highly accurate models by iteratively refining predictions and addressing the weaknesses of earlier trees.
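The idea of building trees sequentially can be sketched in a few lines. The example below is a deliberately simplified illustration, not how XGBoost is implemented: it does regression with squared error on synthetic data, uses one-split trees ("stumps"), and leaves out the regularization and classification loss that XGBoost adds. The function `fit_stump` is a hypothetical helper.

```python
import numpy as np

def fit_stump(x, residual):
    """Find the single split on x that best fits the residuals (least squares)."""
    best = None
    for split in np.unique(x)[:-1]:          # skip the max so both sides are non-empty
        left = residual[x <= split].mean()
        right = residual[x > split].mean()
        pred = np.where(x <= split, left, right)
        err = ((residual - pred) ** 2).sum()
        if best is None or err < best[0]:
            best = (err, split, left, right)
    _, split, left, right = best
    return lambda q: np.where(q <= split, left, right)

# Synthetic 1-D data (assumed for illustration)
rng = np.random.default_rng(0)
x = np.linspace(0, 6, 80)
y = np.sin(x) + rng.normal(0, 0.1, x.size)

learning_rate = 0.1
prediction = np.full_like(y, y.mean())       # start from a constant guess
errors = [np.mean((y - prediction) ** 2)]

for _ in range(100):
    stump = fit_stump(x, y - prediction)     # fit the current errors
    prediction += learning_rate * stump(x)   # take a small corrective step
    errors.append(np.mean((y - prediction) ** 2))

print(f"MSE before boosting: {errors[0]:.4f}")
print(f"MSE after 100 stumps: {errors[-1]:.4f}")
```

Each stump on its own is a weak model, but because every new stump is fitted to what the ensemble still gets wrong, the training error shrinks step by step. The `learning_rate` and number of rounds here play the same role as `learning_rate` and `n_estimators` in `XGBClassifier`.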

Our dataset has a label indicating which set (training or test) our data belong to. We will use this to split it.

In [None]:
#split into train and test
dtrain = point.where(point['group']==1,drop=True)
dtest = point.where(point['group']==2,drop=True)

#create separate datasets for labels and features
y_train = dtrain['class'].values.astype(int)
y_test = dtest['class'].values.astype(int)
X_train = dtrain['s2'].values.T
X_test = dtest['s2'].values.T
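The `.T` in the cell above matters because the extracted data are laid out as (band, point), while scikit-learn-style models expect (samples, features). A minimal NumPy sketch of the same split, using made-up stand-ins for `point['s2']`, `group` and `class`:

```python
import numpy as np

# Toy stand-ins (assumed): 4 bands x 5 points, laid out like point['s2']
s2 = np.arange(20, dtype=float).reshape(4, 5)
group = np.array([1, 1, 2, 1, 2])      # 1 = train, 2 = test
labels = np.array([0, 2, 2, 1, 0])

train = group == 1
test = group == 2

# Select points by group, then transpose to (n_samples, n_features)
X_train = s2[:, train].T
X_test = s2[:, test].T
y_train = labels[train]
y_test = labels[test]

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
```

Each row of `X_train` is now the full spectrum of one training point, and the rows line up one-to-one with the entries of `y_train`.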

To train the model, we:

1. Create an XGBoost classifier object using the `XGBClassifier` class from the XGBoost library, specifying a set of reasonable hyperparameters.
2. Fit the model to our training data (`X_train` and `y_train`).

The model learns to associate spectral signatures with land cover classes. Once training is complete, it is ready to be evaluated on the held-out test set and deployed across the full study area.

This will take a few seconds.

In [None]:
# Create and train the XGBoost model
model = xgb.XGBClassifier(
    max_depth=7,
    learning_rate=0.1,
    subsample=0.8,
    n_estimators=100,
    tree_method='hist'
)

# Fit the model to the training data
model.fit(X_train, y_train)

We will use our trained model to predict the classes of the test data and calculate accuracy.

Next, we assess how well the model performs for predicting Pine trees by calculating its precision and recall. Precision measures the accuracy of the positive predictions. It answers the question, "Of all the instances we labeled as Pines, how many were actually Pines?". Recall measures the model's ability to identify all actual positive instances. It answers the question, "Of all the actual Pines, how many did we correctly identify?". You may also be familiar with the terms user's accuracy and producer's accuracy: precision is the user's accuracy, and recall is the producer's accuracy.

Finally, we create and display a confusion matrix to visualize the model's prediction accuracy across all classes.

In [None]:
# Step 1: Predict classes for the test set
y_pred = model.predict(X_test)

# Step 2: Calculate overall accuracy across all classes
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc}")

# Step 3: Calculate precision and recall for Pine
precision_pine = precision_score(y_test, y_pred, labels=[2], average='macro', zero_division=0)
recall_pine = recall_score(y_test, y_pred, labels=[2], average='macro', zero_division=0)

print(f"Precision for Pines: {precision_pine}")
print(f"Recall for Pines: {recall_pine}")

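# Step 4: Plot the confusion matrix
# normalize='pred' scales each column (predicted class) to sum to 1,
# so the diagonal shows per-class precision
conf_matrix = confusion_matrix(y_test, y_pred, normalize='pred')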

ConfusionMatrixDisplay(confusion_matrix=conf_matrix,display_labels=list(name_table['name'])).plot()
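Precision and recall for a single class can also be computed by hand, which makes the definitions concrete. The labels below are made up for illustration, with class 2 standing in for Pine.

```python
import numpy as np

# Toy labels (assumed): class 2 stands in for Pine
y_true = np.array([2, 2, 2, 0, 1, 0, 2, 1])
y_hat = np.array([2, 0, 2, 2, 1, 0, 2, 1])
pine = 2

true_pos = np.sum((y_hat == pine) & (y_true == pine))    # predicted Pine, is Pine
false_pos = np.sum((y_hat == pine) & (y_true != pine))   # predicted Pine, is not
false_neg = np.sum((y_hat != pine) & (y_true == pine))   # missed Pine

precision = true_pos / (true_pos + false_pos)   # user's accuracy
recall = true_pos / (true_pos + false_neg)      # producer's accuracy

print(f"Precision: {precision}")   # 0.75
print(f"Recall: {recall}")         # 0.75
```

Here three of the four Pine predictions are correct (precision 0.75) and three of the four true Pines are found (recall 0.75). The two numbers coincide only because this toy example has one false positive and one false negative.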

### 7. Predict over entire study area

We now have a trained model and are ready to deploy it across an entire Sentinel-2 scene to map the distribution of invasive plants. This involves a large volume of data, so rather than processing it all at once, we will apply the model's `.predict()` method in parallel across manageable chunks of the Sentinel-2 xarray.

#### Parallel vs. sequential processing

In serial processing, tasks execute one after another, so total runtime is the sum of all steps. In parallel processing, the work is split into independent chunks that run simultaneously across multiple workers, with results combined at the end. This can dramatically reduce processing time — a 3-hour serial job might take roughly 1 hour across 3 workers. Geospatial workflows are naturally suited to this approach because operations on individual tiles or regions are typically independent of one another.

![](img/series_parallel.png)
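The split, process, combine pattern can be illustrated with Python's standard library. This is only a sketch of the general idea using a thread pool and a made-up `process_chunk` function; Dask uses its own scheduler, and actual speed-ups depend on the workload and hardware.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def process_chunk(chunk):
    """Stand-in for per-tile work; here just a sum of squares."""
    return float(np.sum(chunk ** 2))

data = np.arange(1_000_000, dtype=np.float64)
chunks = np.array_split(data, 4)    # four independent pieces

# Serial: one chunk after another
serial = [process_chunk(c) for c in chunks]

# Parallel: the same chunks handed to a pool of workers
with ThreadPoolExecutor(max_workers=4) as pool:
    parallel = list(pool.map(process_chunk, chunks))

# Combine the per-chunk results at the end
total_serial = sum(serial)
total_parallel = sum(parallel)
print(total_serial == total_parallel)
```

Because each chunk is processed independently, the order in which workers finish does not affect the combined result.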

#### How parallel processing works with Dask

When we loaded our Sentinel-2 data earlier, `stackstac` returned it as a Dask-backed xarray object, already divided into smaller pieces called **chunks** (think of them as tiles in a mosaic). Dask, the parallel computing library working behind the scenes, uses these chunks to orchestrate efficient computation.

Rather than computing results immediately, Dask builds a **task graph** — a recipe mapping out which operations to perform on which chunks. When we define operations (like applying our model), Dask records the instructions without executing them. Only when we call `.compute()` does Dask execute the graph, processing multiple chunks simultaneously across available CPU cores.

![](img/dummy_graph.png)

This has two key advantages: (1) we can work with datasets larger than memory because only the active chunks need to be loaded at any given time, and (2) operations run faster because chunks are processed concurrently rather than one at a time.
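Lazy evaluation is easy to see on a small Dask array. The sketch below uses a synthetic array of ones rather than satellite data; the array size and chunk size are arbitrary choices for illustration.

```python
import numpy as np
import dask.array as da

# A 1000x1000 array split into 250x250 chunks (16 chunks in total)
arr = da.from_array(np.ones((1000, 1000)), chunks=(250, 250))

# Defining the operation only builds a task graph; nothing is computed yet
lazy_total = (arr * 2).sum()
print(type(lazy_total))      # still a Dask array, not a number
print(arr.numblocks)         # (4, 4)

# compute() executes the graph chunk by chunk
total = lazy_total.compute()
print(total)                 # 2000000.0
```

Until `.compute()` is called, `lazy_total` is just a description of the work to be done, which is why defining operations on a very large array returns almost instantly.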

**Monitoring progress with the Dask Dashboard**: If you run Dask with a `dask.distributed` client, you can track the progress of operations in real time using the Dask Dashboard. In Jupyter environments like SageMaker Studio Lab, look for a dashboard link in the client's output. It displays task execution, memory usage, and worker activity, which is useful for spotting bottlenecks and confirming your computation is progressing as expected.

**Learning more about Dask**: To go deeper, see the [Dask documentation](https://docs.dask.org/) for tutorials on task graphs, distributed computing, and optimization. The [Dask tutorial](https://tutorial.dask.org/) is also an excellent interactive starting point.

Here is the function that we will actually apply to each chunk. Simple really. The hard work is getting the data into and out of this function.

In [None]:
def predict_on_chunk(chunk, model):
    probabilities = model.predict_proba(chunk)
    return probabilities

Now we define the function that takes as input the Sentinel-2 xarray and passes it to the predict function. This is composed of three parts:

**Part 1:** Applies all the transformations that need to be done before the data goes to the model. It sets a condition to identify valid data points where reflectance values are greater than zero and stacks the spatial dimensions (x and y) into a single dimension.

**Part 2:** Applies the machine learning model to the data in parallel, predicting class probabilities for each data point using xarray's `apply_ufunc` method. Most of the function involves defining what to do with the dimensions of the old dataset and the new output.

**Part 3:** Unstacks the data to restore its original dimensions, sets spatial dimensions and coordinate reference system (CRS), clips the data, and transposes the data to match expected formats before returning the results.

In [None]:
def predict_xr(ds,geometries):

    #part 1 - data prep
    #condition to use for masking no data later
    condition = (ds > 0).any(dim='band')

    #stack the data into a single dimension. This will be important for applying the model later
    ds = ds.stack(sample=('x','y'))


    #part 2 - apply the model over chunks
    result = xr.apply_ufunc(
        predict_on_chunk,
        ds,
        input_core_dims=[['band']],#input dim with features
        output_core_dims=[['class']],  # name for the new output dim
        exclude_dims=set(('band',)),  #dims to drop in result
        output_sizes={'class': 10}, #length of the new dimension; must equal the number of classes the model predicts
        output_dtypes=[np.float32],
        dask="parallelized",
        kwargs={'model': model}
    )

    #part 3 - post-processing
    result = result.unstack('sample') #remove the stack
    result = result.rio.set_spatial_dims(x_dim='x',y_dim='y') #set the spatial dims
    result = result.rio.write_crs("EPSG:32734") #set the CRS
    result = result.rio.clip(geometries).where(condition) #clip to the area of interest and mask nodata pixels
    result = result.transpose('class', 'y', 'x') #transpose the data - rio expects it this way
    return result.compute()
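The stack, apply, unstack pattern in `predict_xr` can be tried on a tiny in-memory array without Dask or a trained model. In the sketch below, `fake_predict_proba` is a hypothetical stand-in for `model.predict_proba` (a softmax over the first few bands), and the image is random data. Because the array is not Dask-backed, the `dask`, `output_sizes` and `output_dtypes` arguments are not needed here.

```python
import numpy as np
import xarray as xr

N_CLASSES = 3  # assumed number of classes for this toy example

def fake_predict_proba(features):
    """Hypothetical stand-in for model.predict_proba.

    Takes (n_samples, n_bands) and returns (n_samples, N_CLASSES).
    """
    scores = features[:, :N_CLASSES]
    exp = np.exp(scores - scores.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
image = xr.DataArray(
    rng.random((4, 5, 6)),
    dims=("band", "y", "x"),
    coords={"y": np.arange(5), "x": np.arange(6)},
)

# Part 1: flatten the two spatial dims into one 'sample' dim
stacked = image.stack(sample=("x", "y"))

# Part 2: 'band' is consumed by the function, 'class' is produced by it
probs = xr.apply_ufunc(
    fake_predict_proba,
    stacked,
    input_core_dims=[["band"]],
    output_core_dims=[["class"]],
)

# Part 3: restore the spatial dims and put 'class' first
result = probs.unstack("sample").transpose("class", "y", "x")
print(result.dims, result.shape)
```

xarray moves the core dimension (`band`) to the last axis before calling the function, which is why the function sees an array shaped (samples, bands), the same layout used for training.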

Now we can actually run this. It should take about 30–60 seconds (to go through a 10 GB Sentinel-2 scene!).

In [None]:
with ProgressBar():
    predicted  = predict_xr(stack,geoms)

In [None]:
predicted

Now we can view our result. We will plot the probability that a pixel is covered in invasive Pine trees.

In [None]:
#reproject
predicted = predicted.rio.reproject("EPSG:4326",nodata=np.nan)
#select only pines
predicted_plot = predicted.isel({'class':2})
#set low probability to NaN
predicted_plot = predicted_plot.where(predicted_plot > 0.5, np.nan)

In [None]:
#plot with a satellite basemap
predicted_plot.hvplot(tiles=hv.element.tiles.EsriImagery(), 
                              project=True,clim=(0,1),
                              cmap='magma',frame_width=800,data_aspect=1,alpha=0.7,title='Pine probability')

Lastly, we export the result as a cloud-optimized GeoTIFF (COG) using rioxarray. We can then explore the map in our desktop GIS if desired.

In [None]:
predicted.rio.to_raster('gctwf_invasive.tiff',driver="COG")

This writes the file to the disk of our machine. If we are doing this on the cloud, it is wiser to write to cloud storage (S3), as this is much cheaper than local disk and almost infinitely scalable. We can also access this data from other machines on AWS, or share it with external collaborators.

In [None]:
import s3fs

# writing requires credentials; anonymous access is normally read-only
fs = s3fs.S3FileSystem(anon=False)

with fs.open("s3://my-bucket/invasive_data/gctwf_invasive.tiff", "wb") as f:
    predicted.rio.to_raster(f, driver="COG")

### Write to Icechunk

We could also use Icechunk to store this data. Choosing between Icechunk/Zarr and GeoTIFF depends on your use case:

GeoTIFF:
- Single-file format, good for simple, static raster data
- Limited to 2D or 3D arrays
- Poor for very large datasets
- No built-in versioning

Icechunk/Zarr:
- Chunked storage, enabling efficient partial data access
- Scales to massive, multi-dimensional datasets
- Cloud-native, parallel I/O for faster processing
- Supports versioning (Icechunk) and collaborative workflows
- Ideal for time-series analysis or complex data cubes

In [None]:
import icechunk
from icechunk.xarray import to_icechunk

# 1. Configure S3 storage (writing requires credentials; here they are read from the environment)
storage = icechunk.s3_storage(
    bucket="my-bucket",
    prefix="invasive_data/gctwf_invasive",
    region="us-west-2",
    from_env=True,
)

# 2. Create a new repository
repo = icechunk.Repository.create(storage)

# 3. Open a writable session on the main branch
session = repo.writable_session("main")

# 4. Write the predictions as an xarray dataset (the variable name is our choice)
to_icechunk(predicted.to_dataset(name="class_probability"), session)

# 5. Commit
snapshot_id = session.commit("initial write")

### Credits

This lesson has borrowed heavily from the following resources, which are also a great place to learn more about handling large geospatial data in Python:

- [The Carpentries Geospatial Python lesson by Ryan Avery](https://carpentries-incubator.github.io/geospatial-python/)
- [The xarray user guide](https://docs.xarray.dev/en/stable/user-guide/index.html)
- [An Introduction to Earth and Environmental Data Science](https://earth-env-data-science.github.io/intro.html)

Another good place to start learning more is the [Cloud-Native Geospatial Foundation](https://cloudnativegeo.org/), which curates a community using and developing cloud-native geospatial tools.

A deeper dive into using remote sensing for invasive species mapping can be found on the [NASA Applied Remote Sensing Training Program](https://www.earthdata.nasa.gov/learn/trainings/airborne-data-applications-invasive-species-mapping). This training contains an example based on this notebook that dives deeper into some of the machine learning concepts and data analysis.