# Mapping vegetation in California with SageMaker Geospatial Processing jobs

In this notebook, we will explore the process of mapping vegetation in California using [SageMaker Geospatial Processing jobs](https://docs.aws.amazon.com/sagemaker/latest/dg/geospatial-custom-operations.html). Our primary data source will be Sentinel-2 satellite imagery, which allows us to compute the normalized difference vegetation index (NDVI). By analyzing the NDVI values, we can identify areas with significant vegetation and compare them to historical data to detect outliers and track trends.

**Contents**
* [Setup SageMaker geospatial capabilities](#1)
* [Query the Sentinel-2 data](#2)
* [SageMaker geospatial processing jobs](#3)
* [Analyze the NDVI over the year](#4)
* [Monitor the vegetation health by comparing to historical data](#5)

<a id='1'></a>
## Setup SageMaker geospatial capabilities

In [None]:
# # install the necessary packages
# !pip install leafmap

In [None]:
import os
import time
import json
import random
import math
import cv2
import numpy as np
from glob import glob
import tifffile
from urllib.parse import urlparse
from IPython.display import JSON
import matplotlib.pyplot as plt

import geopandas
import leafmap.foliumap as leafmap
from shapely.geometry import Polygon
import rioxarray

import boto3
import sagemaker
import sagemaker_geospatial_map
from botocore import UNSIGNED
from botocore.config import Config

session = boto3.Session()
sg_client = session.client(service_name="sagemaker-geospatial")

### Inspect the area of interest

Before we proceed with the data analysis, it's crucial to confirm that we are focusing on the correct location. We will depict the area of interest using a polygon that represents the boundaries of California.

In [None]:
m = leafmap.Map(center=[37, -119], zoom=4)
m.add_basemap("OpenStreetMap")
ca_gdf = geopandas.read_file("./data/ca_polygon.geojson")
m.add_gdf(ca_gdf, layer_name="AOI", style={"color": "red"})
m

<a id='2'></a>

## Query the Sentinel-2 data

Using the defined polygon, we will query the Sentinel-2 Level-2A data for the study. Specifically, we will focus on data from 2022, filtering out images with cloud coverage exceeding 10%. This approach allows us to analyze the ground conditions and obtain accurate insights into the vegetation distribution within the selected area of interest in California.

In [None]:
start_time = time.time()
search_rdc_args = {
    "Arn": "arn:aws:sagemaker-geospatial:us-west-2:378778860802:raster-data-collection/public/nmqj48dcu3g7ayw8",  # sentinel-2 L2A
    "RasterDataCollectionQuery": {
        "AreaOfInterest": {
            "AreaOfInterestGeometry": {
                "PolygonGeometry": {
                    "Coordinates": [
                        [
                            [-124.3499619, 41.961687],
                            [-124.3719345, 41.9780236],
                            [-124.4158798, 40.2733979],
                            [-122.9766709, 37.9638526],
                            [-120.6585556, 34.5270956],
                            [-117.2527939, 32.5122532],
                            [-114.6929795, 32.7065972],
                            [-114.6490341, 34.96942],
                            [-119.9993759, 38.9787893],
                            [-120.0103623, 41.9861904],
                            [-124.3499619, 41.961687],
                        ]
                    ]
                }
            }
        },
        "TimeRangeFilter": {
            "StartTime": "2022-01-01T00:00:00Z",
            "EndTime": "2022-12-31T23:59:59Z",
        },
        "PropertyFilters": {
            "Properties": [{"Property": {"EoCloudCover": {"LowerBound": 0, "UpperBound": 10}}}],
            "LogicalOperator": "AND",
        },
    },
}

s2_items = []
s2_tile_ids = []
s2_geometries = {
    "id": [],
    "geometry": [],
}
while search_rdc_args.get("NextToken", True):
    search_result = sg_client.search_raster_data_collection(**search_rdc_args)
    for item in search_result["Items"]:
        s2_id = item["Id"]
        s2_tile_id = s2_id.split("_")[1]
        # keep one footprint per tile; skip tiles that cover the same area
        if s2_tile_id not in s2_tile_ids:
            s2_tile_ids.append(s2_tile_id)
            s2_geometries["id"].append(s2_id)
            s2_geometries["geometry"].append(Polygon(item["Geometry"]["Coordinates"][0]))
        del item["DateTime"]
        s2_items.append(item)

    search_rdc_args["NextToken"] = search_result.get("NextToken")

print(f"{len(s2_items)} unique Sentinel-2 images found.")
print(f"{time.time() - start_time} seconds")

In [None]:
# plot out the sentinel-2 image footprints as a sanity check
s2_gdf = geopandas.GeoDataFrame(s2_geometries)

m = leafmap.Map(center=[37, -119])
m.add_basemap("OpenStreetMap")
m.add_gdf(s2_gdf, layer_name="Sentinel-2 Tiles", style={"color": "blue"})
m

In [None]:
# check one example of the image
item = s2_items[-1]
JSON(item)

<a id='3'></a>

## SageMaker geospatial processing jobs

- Manifest of data
- Custom script for processing
- Launch processing jobs



### Manifest of data
To process the Sentinel-2 data efficiently, we use input manifest files in which each entry references the metadata of one Sentinel-2 image. We divide the data across 10 jobs, each running on 20 instances. You can adjust the number of jobs and instances to match your workload.

Note: you might need to request a [quota increase](https://docs.aws.amazon.com/servicequotas/latest/userguide/request-quota-increase.html) to use 200 instances for processing jobs.
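
The round-robin split used to build the manifests can be illustrated on toy data. This is a minimal sketch: `split_round_robin`, the example bucket, and the placeholder keys are hypothetical names chosen for illustration, not part of the SageMaker SDK.

```python
def split_round_robin(keys, num_jobs, prefix):
    """Distribute keys across num_jobs manifests in round-robin order.

    Each manifest starts with a {"prefix": ...} entry, followed by
    keys that are relative to that prefix.
    """
    manifests = [[{"prefix": prefix}] for _ in range(num_jobs)]
    for idx, key in enumerate(keys):
        manifests[idx % num_jobs].append(key)
    return manifests


# Hypothetical placeholder keys, for illustration only.
toy_keys = [f"scene-{i}.json" for i in range(7)]
toy_manifests = split_round_robin(
    toy_keys, num_jobs=3, prefix="s3://example-bucket/data/"
)
for job_idx, manifest in enumerate(toy_manifests):
    # Skip the leading prefix entry when printing the keys.
    print(job_idx, manifest[1:])
```

With this scheme the manifests differ in size by at most one entry, so the jobs receive a roughly equal share of the images.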

In [None]:
# generate data manifest for geospatial processing jobs
def s2_item_to_relative_metadata_url(item):
    parts = item["Assets"]["visual"]["Href"].split("/")
    tile_prefix = parts[4:-1]
    return "{}/{}.json".format("/".join(tile_prefix), item["Id"])


num_jobs = 10
num_instances_per_job = 20  # maximum 20

manifest_list = {}
for idx in range(num_jobs):
    manifest = [{"prefix": "s3://sentinel-cogs/sentinel-s2-l2a-cogs/"}]
    manifest_list[idx] = manifest

## split the manifest for N processing jobs
for idx, item in enumerate(s2_items):
    job_idx = idx % num_jobs
    manifest_list[job_idx].append(s2_item_to_relative_metadata_url(item))

In [None]:
# upload the manifest to S3
sagemaker_session = sagemaker.Session()
s3_bucket_name = sagemaker_session.default_bucket()  # Replace with your own bucket if needed
s3_bucket = session.resource("s3").Bucket(s3_bucket_name)

s3_prefix = "ca-ndvi"

s3_client = boto3.client("s3")
s3 = boto3.resource("s3")

manifest_dir = "manifests"
os.makedirs(manifest_dir, exist_ok=True)

for job_idx, manifest in manifest_list.items():
    manifest_file = f"{manifest_dir}/manifest{job_idx}.json"
    s3_manifest_key = s3_prefix + "/" + manifest_file
    with open(manifest_file, "w") as f:
        json.dump(manifest, f)

    s3_client.upload_file(manifest_file, s3_bucket_name, s3_manifest_key)
    print("Uploaded {} to {}".format(manifest_file, s3_manifest_key))

### Custom script for processing

In the next step, we will write custom code to compute the NDVI from the Sentinel-2 data. This script is the core logic executed within the processing job. Using the `red` and `near-infrared` spectral bands of the Sentinel-2 imagery, we calculate the NDVI with the formula $(nir - red) / (nir + red)$. The NDVI ranges from -1 to 1 and lets us quantitatively assess vegetation health and density in the selected area. Higher values indicate dense and healthy vegetation, values near zero suggest areas with little or no vegetation, and negative values are typically associated with water bodies.
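
As a quick illustration of the formula, here is a minimal sketch on three toy pixels. The reflectance values are made up for illustration and do not come from Sentinel-2 data.

```python
import numpy as np

# Toy reflectance values (hypothetical, for illustration only):
# a vegetated pixel, a bare pixel, and a water-like pixel.
red = np.array([0.05, 0.20, 0.30])
nir = np.array([0.45, 0.20, 0.10])

ndvi = (nir - red) / (nir + red)
print(ndvi)
```

The three pixels come out at roughly 0.8, 0.0, and -0.5, matching the interpretation of high, near-zero, and negative NDVI values.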

In [None]:
# one example of NDVI
s2_item = s2_items[2023]
red_band_url = s2_item["Assets"]["red"]["Href"]
nir_band_url = s2_item["Assets"]["nir"]["Href"]
scl_mask_url = s2_item["Assets"]["scl"]["Href"]
red = rioxarray.open_rasterio(red_band_url, masked=True)
nir = rioxarray.open_rasterio(nir_band_url, masked=True)
scl = rioxarray.open_rasterio(scl_mask_url, masked=True)
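# SCL holds class labels, so resample with nearest neighbor rather than linear interpolation
scl_interp = scl.interp(x=red["x"], y=red["y"], method="nearest")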
red_cloud_masked = red.where((scl_interp != 8) & (scl_interp != 9) & (scl_interp != 10))
nir_cloud_masked = nir.where((scl_interp != 8) & (scl_interp != 9) & (scl_interp != 10))

ndvi = (nir_cloud_masked - red_cloud_masked) / (nir_cloud_masked + red_cloud_masked)
tci_url = s2_item["Assets"]["visual"]["Href"]
tci = rioxarray.open_rasterio(tci_url)
tci_vis = np.einsum("ijk -> jki", tci)
ndvi_vis = np.einsum("ijk -> jki", ndvi)

fig, axs = plt.subplots(1, 2, figsize=(10, 20))
im0 = axs[0].imshow(tci_vis)
im1 = axs[1].imshow(ndvi_vis, cmap="RdYlGn")
cax = fig.add_axes([0.93, 0.4, 0.01, 0.2])
cbar = fig.colorbar(im1, cax=cax)
cbar.set_label("NDVI")
plt.show()

In [None]:
script_dir = "scripts"
os.makedirs(script_dir, exist_ok=True)

In [None]:
%%writefile scripts/compute_vi.py

import os
import rioxarray
import json
import gc
import warnings

warnings.filterwarnings("ignore")

if __name__ == "__main__":
    print("Starting processing")

    input_path = "/opt/ml/processing/input"
    output_path = "/opt/ml/processing/output"
    input_files = []
    items = []
    for current_path, sub_dirs, files in os.walk(input_path):
        for file in files:
            if file.endswith(".json"):
                # current_path from os.walk already includes input_path
                full_file_path = os.path.join(current_path, file)
                input_files.append(full_file_path)
                with open(full_file_path, "r") as f:
                    items.append(json.load(f))

    print("Received {} input files".format(len(input_files)))

    for item in items:
        print("Computing NDVI for {}".format(item["id"]))
        red_band_url = item["assets"]["red"]["href"]
        nir_band_url = item["assets"]["nir"]["href"]
        scl_mask_url = item["assets"]["scl"]["href"]
        red = rioxarray.open_rasterio(red_band_url, masked=True)
        nir = rioxarray.open_rasterio(nir_band_url, masked=True)
        scl = rioxarray.open_rasterio(scl_mask_url, masked=True)
        scl_interp = scl.interp(
            x=red["x"], y=red["y"], method="nearest"
        )  # resample SCL to the Red/NIR grid; nearest neighbor keeps class labels intact

        # mask out cloudy pixels using SCL (https://sentinels.copernicus.eu/web/sentinel/technical-guides/sentinel-2-msi/level-2a/algorithm-overview)
        # class 8: cloud medium probability
        # class 9: cloud high probability
        # class 10: thin cirrus
        red_cloud_masked = red.where((scl_interp != 8) & (scl_interp != 9) & (scl_interp != 10))
        nir_cloud_masked = nir.where((scl_interp != 8) & (scl_interp != 9) & (scl_interp != 10))

        ndvi = (nir_cloud_masked - red_cloud_masked) / (nir_cloud_masked + red_cloud_masked)
        # save the ndvi as geotiff
        s2_tile_id = red_band_url.split("/")[-2]
        file_name = f"{s2_tile_id}_ndvi.tif"
        output_file_path = f"{output_path}/{file_name}"
        ndvi.rio.to_raster(output_file_path)
        print("Written output: {}".format(output_file_path))

        # keep memory usage low
        del red
        del nir
        del scl
        del scl_interp
        del red_cloud_masked
        del nir_cloud_masked
        del ndvi

        gc.collect()

### Launch the processing job

Launching the SageMaker geospatial processing job takes four steps. First, we specify the geospatial container image to use for the job. Then, we define the input by referencing the manifest file generated earlier, which contains the image pointers. Next, we specify our custom NDVI script as the code to run within the processing job. Lastly, we define the output location where the results of the processing job are stored.

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor

role = get_execution_role()
geospatial_image_uri = (
    "081189585635.dkr.ecr.us-west-2.amazonaws.com/sagemaker-geospatial-v1-0:latest"
)


def run_job(job_idx):
    s3_manifest = f"s3://{s3_bucket_name}/{s3_prefix}/{manifest_dir}/manifest{job_idx}.json"
    s3_output = f"s3://{s3_bucket_name}/{s3_prefix}/output2022"
    script_processor = ScriptProcessor(
        command=["python3"],
        image_uri=geospatial_image_uri,
        role=role,
        instance_count=num_instances_per_job,
        instance_type="ml.m5.xlarge",
        base_job_name=f"ca-s2-ndvi-{job_idx}",
    )

    script_processor.run(
        code="scripts/compute_vi.py",
        inputs=[
            ProcessingInput(
                source=s3_manifest,
                destination="/opt/ml/processing/input/",
                s3_data_type="ManifestFile",
                s3_data_distribution_type="ShardedByS3Key",
            ),
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/",
                destination=s3_output,
                s3_upload_mode="Continuous",
            )
        ],
    )

#### Launch multiple processing jobs in parallel

In [None]:
from multiprocessing import Pool

with Pool(num_jobs) as pool:
    pool.map(run_job, range(num_jobs))

<a id='4'></a>
## Analyze the NDVI over the year

After the processing jobs complete, we download the computed NDVI data to our local instance for analysis. To keep the download small, we fetch only the results for a single Sentinel-2 tile (`10SGJ`).

In [None]:
def date_to_day_of_year(date):
    """Convert an "MMDD" string to the day of the year (assumes a non-leap year)."""
    month = int(date[:2])
    day = int(date[2:])
    days_in_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    day_of_year = sum(days_in_month[: month - 1]) + day
    return day_of_year


# download NDVI to local instance
sagemaker_session = sagemaker.Session()
s3_bucket = session.resource("s3").Bucket(s3_bucket_name)
ndvi_dir = "data/ndvi2022"
os.makedirs(ndvi_dir, exist_ok=True)
s2_tile_id = "10SGJ"
ndvi_prefix = f"{s3_prefix}/output2022"
ndvi_files = []
for s3_object in s3_bucket.objects.filter(Prefix=ndvi_prefix).all():
    path, filename = os.path.split(s3_object.key)
    if s2_tile_id in filename:
        ndvi_file = ndvi_dir + "/" + filename
        ndvi_files.append(ndvi_file)
        s3_bucket.download_file(s3_object.key, ndvi_file)
        print("Downloaded NDVI: " + ndvi_file)

To gain specific insights into vegetation trends, we select a particular location as an example using image coordinates for analysis. By utilizing the computed NDVI values, we extract the NDVI value for that specific point across the entire year of 2022.
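
To place each observation on a common time axis, the acquisition date is converted to a day of the year. A minimal sketch using only the standard library, which also handles leap years; `day_of_year` is a hypothetical helper, and it assumes the full `YYYYMMDD` date string is available.

```python
from datetime import datetime


def day_of_year(date_str):
    """Convert a 'YYYYMMDD' string to its 1-based day of the year."""
    return datetime.strptime(date_str, "%Y%m%d").timetuple().tm_yday


print(day_of_year("20220101"))  # 1
print(day_of_year("20221231"))  # 365
print(day_of_year("20240301"))  # 61 (2024 is a leap year)
```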

In [None]:
# load the NDVI
poi = [200, 1850]
ndvi_dict2022 = {}
ndvi_poi2022 = {}
ndvi_files = glob("./data/ndvi2022/*.tif")
for ndvi_file in ndvi_files:
    ndvi = tifffile.imread(ndvi_file)
    ndvi_date = ndvi_file.split("_")[2][4:]
    ndvi_day = date_to_day_of_year(ndvi_date)
    ndvi_dict2022[ndvi_day] = ndvi
    ndvi_poi2022[ndvi_day] = ndvi[poi[0], poi[1]]

sorted_ndvi2022 = dict(sorted(ndvi_poi2022.items()))
ndvi_days = list(sorted_ndvi2022.keys())
ndvi_values = list(sorted_ndvi2022.values())

plt.plot(ndvi_days, ndvi_values, marker="o")
plt.xlabel("Day of the Year")
plt.ylabel("NDVI Value")
plt.show()

The NDVI values in the area exhibit clear seasonal variations, reflecting changes in vegetation over time. However, certain dates have missing NDVI values because cloud cover obstructed the satellite view. To fill these gaps, we use simple linear interpolation: each missing value is estimated from the nearest valid observations before and after it, weighted by their distance in days.
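
Gap filling of this kind can be sketched with `np.interp`, which interpolates linearly between the nearest valid observations. The series below is toy data chosen for illustration.

```python
import numpy as np

# Toy time series with two cloud-masked (NaN) observations.
days = np.array([10, 20, 30, 40])
values = np.array([0.2, np.nan, np.nan, 0.8])

# Interpolate at every day using only the valid observations.
valid = ~np.isnan(values)
filled = np.interp(days, days[valid], values[valid])
print(filled)
```

The two gaps are filled with roughly 0.4 and 0.6, the values on the straight line between the observations at day 10 and day 40. Note that `np.interp` holds the first or last valid value constant for gaps at either end of the series.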

In [None]:
## interpolating the missing values
def interpolate_vegetation_index(vegetation_index):
    days_of_year = list(vegetation_index.keys())
    values = list(vegetation_index.values())

    valid_indices = np.where(~np.isnan(values))[0]
    valid_days = np.array(days_of_year)[valid_indices]
    valid_values = np.array(values)[valid_indices]

    interpolated_values = np.interp(days_of_year, valid_days, valid_values)

    interpolated_vegetation_index = dict(zip(days_of_year, interpolated_values))

    return interpolated_vegetation_index


ndvi_interpolated = interpolate_vegetation_index(sorted_ndvi2022)
sorted_ndvi_interpolated = dict(sorted(ndvi_interpolated.items()))

plt.plot(
    list(sorted_ndvi_interpolated.keys()), list(sorted_ndvi_interpolated.values()), "b", marker="o"
)
plt.xlabel("Day of the Year")
plt.ylabel("NDVI Value")
plt.show()

<a id='5'></a>
## Monitor the vegetation health by comparing to historical data

After obtaining the NDVI values for the entire state in 2022, we repeat the process for 2023: rerun the query and the processing jobs with a 2023 time range and write the results to the `output2023` prefix. By comparing the 2023 NDVI values with those from 2022, we can identify significant deviations and assess how the vegetation in the area has changed.
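
One way to quantify such deviations is to resample both years onto a shared day-of-year grid and flag the days where the difference exceeds a threshold. This is a minimal sketch on toy data: the two series, the `resample` helper, and the 0.15 threshold are assumptions for illustration, not values derived from the notebook's results.

```python
import numpy as np

# Hypothetical toy series: day of year -> NDVI at one location.
ndvi_a = {10: 0.30, 100: 0.60, 200: 0.40}  # reference year
ndvi_b = {20: 0.32, 110: 0.35, 190: 0.41}  # year to check


def resample(series, grid):
    """Linearly interpolate a {day: value} series onto the given days."""
    days = sorted(series)
    return np.interp(grid, days, [series[d] for d in days])


grid = np.arange(20, 191, 10)  # days covered by both series
diff = resample(ndvi_b, grid) - resample(ndvi_a, grid)

threshold = 0.15  # assumed cut-off; tune for your use case
flagged = grid[np.abs(diff) > threshold]
print(flagged)
```

Restricting the grid to the days covered by both series avoids comparing values that `np.interp` would otherwise extrapolate as constants.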

In [None]:
## download NDVI to local instance
sagemaker_session = sagemaker.Session()
s3_bucket = session.resource("s3").Bucket(s3_bucket_name)
ndvi_dir = "data/ndvi2023"
os.makedirs(ndvi_dir, exist_ok=True)
s2_tile_id = "10SGJ"
ndvi_prefix = f"{s3_prefix}/output2023"
ndvi_files = []
for s3_object in s3_bucket.objects.filter(Prefix=ndvi_prefix).all():
    path, filename = os.path.split(s3_object.key)
    if s2_tile_id in filename:
        ndvi_file = ndvi_dir + "/" + filename
        ndvi_files.append(ndvi_file)
        s3_bucket.download_file(s3_object.key, ndvi_file)
        print("Downloaded NDVI: " + ndvi_file)

## load the NDVI
poi = [200, 1850]
ndvi_dict2023 = {}
ndvi_poi2023 = {}
ndvi_files = glob("./data/ndvi2023/*.tif")
for ndvi_file in ndvi_files:
    ndvi = tifffile.imread(ndvi_file)
    ndvi_date = ndvi_file.split("_")[2][4:]
    ndvi_day = date_to_day_of_year(ndvi_date)
    ndvi_dict2023[ndvi_day] = ndvi
    ndvi_poi2023[ndvi_day] = ndvi[poi[0], poi[1]]


sorted_ndvi2023 = dict(sorted(ndvi_poi2023.items()))
ndvi_days2023 = list(sorted_ndvi2023.keys())
ndvi_values2023 = list(sorted_ndvi2023.values())
ndvi_interpolated2023 = interpolate_vegetation_index(sorted_ndvi2023)
sorted_ndvi_interpolated2023 = dict(sorted(ndvi_interpolated2023.items()))

In [None]:
plt.plot(
    list(sorted_ndvi_interpolated.keys()),
    list(sorted_ndvi_interpolated.values()),
    "b",
    marker="o",
    label="NDVI 2022",
)
plt.plot(
    list(sorted_ndvi_interpolated2023.keys()),
    list(sorted_ndvi_interpolated2023.values()),
    "r",
    marker="x",
    label="NDVI 2023",
)
plt.legend(loc="upper right")
plt.xlabel("Day of the Year")
plt.ylabel("NDVI Value")
plt.show()