## Satellite image processing on SageMaker

This notebook explores the use of SageMaker for basic geospatial analysis, including using a STAC browser to query and process cloud-optimized GeoTIFFs (COGs).

## Set-up
### Installs and imports

In [None]:
%%capture
# Install the geospatial packages used by this notebook into the kernel's
# environment. `%%capture` hides the pip output.
%pip install geopandas
%pip install shapely
# GDAL wheels are looked up via the girder large_image_wheels index.
%pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL
%pip install rasterio
# Werkzeug, matplotlib and folium are pinned to specific versions here.
# NOTE(review): presumably for compatibility with leafmap/localtileserver -
# confirm before changing the pins or the install order.
%pip install Werkzeug==2.3.7
%pip install leafmap localtileserver matplotlib==3.6.3 folium==0.13.0
# %pip install leafmap localtileserver matplotlib folium
%pip install jupyter-server-proxy
%pip install sat-search
# %pip install psutil
# %pip install boto3

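Restart the kernel after the installs above have finished so that the newly installed packages are picked up, then continue with the imports below.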

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
import re
from subprocess import run

import pandas as pd
import numpy as np
import geopandas as gpd
from shapely.geometry import Point, Polygon, box
import matplotlib
import matplotlib.pyplot as plt 

import boto3
import rasterio as rio

from rasterio.features import bounds
from rasterio.plot import show
from pyproj import Transformer
from rasterio.transform import Affine

import satsearch

# URL prefix under which the tile server is reached through the Jupyter proxy.
# "{port}" is kept as a literal placeholder in the value.
# NOTE(review): presumably read by localtileserver, hence set before it is
# imported below - confirm.
proxy_path = "studiolab/default/jupyter/proxy/{port}"
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = proxy_path
    
import localtileserver
from localtileserver import get_folium_tile_layer, TileClient, examples
import leafmap.foliumap as leafmap

### Helper functions

In [None]:
def base_map():
    """Create a leafmap map with satellite imagery as background.

    Adds leafmap's "SATELLITE" basemap and an ESRI World Imagery tile layer.

    Returns:
        The configured ``leafmap.Map`` instance.
    """
    esri_tiles = (
        'https://server.arcgisonline.com/ArcGIS/rest/services/'
        'World_Imagery/MapServer/tile/{z}/{y}/{x}'
    )
    satellite_map = leafmap.Map()
    satellite_map.add_basemap("SATELLITE")
    satellite_map.add_tile_layer(url=esri_tiles, name="ESRI", attribution="ESRI")
    return satellite_map

def poly_box(lon, lat, delta):
    """Build a square GeoJSON polygon centred on a point.

    Args:
        lon: x coordinate of the centre.
        lat: y coordinate of the centre.
        delta: Half the side length, in the same units as ``lon``/``lat``.

    Returns:
        A GeoJSON-style ``Polygon`` dict with a single closed ring whose
        corners are ordered (+,+), (+,-), (-,-), (-,+) relative to the centre.
    """
    east, west = lon + delta, lon - delta
    north, south = lat + delta, lat - delta
    ring = [[east, north], [east, south], [west, south], [west, north]]
    # Close the ring by repeating the first corner.
    ring.append(ring[0])
    return {"type": "Polygon", "coordinates": [ring]}

def get_subset(geotiff_file, geometry):
    """Read band 1 of a GeoTIFF for the bounding box of a polygon.

    Uses the module-level ``aws_session`` for the rasterio environment, so
    that session must exist before this function is called.

    Args:
        geotiff_file: Path or URL passed to ``rio.open``.
        geometry: GeoJSON-like dict; the first ring of
            ``geometry["coordinates"]`` is treated as EPSG:4326 coordinates.

    Returns:
        The array returned by ``src.read`` for the window covering the
        polygon's bounding box (read with ``boundless=True``).
    """
    with rio.Env(aws_session), rio.open(geotiff_file) as src:
        # Reproject the ROI into the raster's CRS before taking its bounds.
        roi = gpd.GeoSeries([Polygon(geometry["coordinates"][0])])
        roi = roi.set_crs(4326).to_crs(src.crs)
        west, south, east, north = bounds(roi)[:4]

        window = rio.windows.from_bounds(
            west, south, east, north, transform=src.transform
        )
        # Only the requested window is read from the file.
        subset = src.read(1, window=window, boundless=True)
    return subset

def plotNDVI(nir, red, filename):
    """Compute NDVI from NIR and red arrays and save it as an image.

    The inputs are converted to floating point first. With unsigned integer
    band data, ``nir - red`` would otherwise wrap around wherever red exceeds
    NIR and yield huge positive values instead of negative NDVI.

    Args:
        nir: Array-like of near-infrared values.
        red: Array-like of red values with the same shape as ``nir``.
        filename: Path the figure is saved to via ``plt.savefig``.

    Returns:
        None. The figure is written to ``filename`` and then closed.
    """
    nir = np.asarray(nir, dtype="float64")
    red = np.asarray(red, dtype="float64")
    total = nir + red

    # Pixels where both bands are zero have no defined NDVI; mark them as NaN
    # without emitting divide warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        ndvi = np.where(total == 0, np.nan, (nir - red) / total)

    # Clamp to the valid NDVI range on both sides.
    ndvi = np.clip(ndvi, -1, 1)

    plt.imshow(ndvi)
    plt.savefig(filename)
    plt.close()

## Get administrative areas

### Download and extract dataset

In [None]:
za_base_url = "https://biogeo.ucdavis.edu/data/diva/adm/"
za_bounds_file = 'ZMB_adm.zip'

data_dir = f"{os.environ['HOME']}/data"
if not os.path.isdir(data_dir):
    os.makedirs(data_dir, exist_ok=True)

if not os.path.isfile(f"{data_dir}/{za_bounds_file}"):
    !wget {za_base_url}{za_bounds_file} -P {data_dir}
    !unzip -o {data_dir}/{za_bounds_file} -d {data_dir}

### Read in shape

In [None]:
# Collect the extracted files whose name contains "adm2.shp" and load the
# first match as a GeoDataFrame.
file = [
    f"{data_dir}/{name}"
    for name in os.listdir(data_dir)
    if "adm2.shp" in name
]
zambia = gpd.read_file(file[0])
zambia.head()

In [None]:
# Draw all districts in white and highlight Kabwe in red on the satellite map.
m = base_map()
admin_layers = (
    (zambia, "white", "Zambia"),
    (zambia[zambia.NAME_2 == "Kabwe"], "red", "Kabwe"),
)
for layer_gdf, layer_color, layer_label in admin_layers:
    m.add_gdf(layer_gdf, style={"color": layer_color},
              layer_name=layer_label, zoom_to_layer=True)
m

## Get satellite data

We are going to use the SpatioTemporal Asset Catalog (STAC), together with tools designed for browsing STAC catalogs, to query the Sentinel-2 imagery available on AWS. The steps below follow the examples in [this tutorial](https://www.matecdev.com/posts/landsat-sentinel-aws-s3-python.html).

First, run an unfiltered query against the whole catalog to retrieve the total number of records it contains.

In [None]:
# Unfiltered search: only used to report how many items the catalog holds.
stac_url = "https://earth-search.aws.element84.com/v0"
sentinel_stac = satsearch.Search.search(url=stac_url)
print(f"Found {sentinel_stac.found()} items")

### Filter for a specific area and time

#### Create ROI

In [None]:
# Square ROI centred on the centroid of the Kabwe district. The half-width is
# in the shapefile's coordinate units (the polygon is treated as EPSG:4326
# further down).
roi_half_width = 0.01
xy = zambia[zambia.NAME_2 == "Kabwe"].geometry.centroid
geometry = poly_box(xy.x.iloc[0], xy.y.iloc[0], roi_half_width)

# Start/end of the search period as a single interval string.
time_range = '2023-04-01/2023-05-01'

#### Query STAC catalog

In [None]:
# Restrict the search to the ROI, the time range and the Sentinel-2 L2A COG
# collection.
s2_query = {
    "url": "https://earth-search.aws.element84.com/v0",
    "intersects": geometry,
    "datetime": time_range,
    "collections": ['sentinel-s2-l2a-cogs'],
}
s2_search = satsearch.Search.search(**s2_query)
s2_search

In [None]:
# Fetch the items matching the search and print a summary of them.
sentinel_items = s2_search.items()
print(sentinel_items.summary())

# Print the B04 asset href of every returned item.
# NOTE: `red_s3` keeps the href of the last item after the loop; the
# "Convert geometry to GeoSeries" cell further down opens that href to read
# the raster metadata, so this loop must run before it.
for item in sentinel_items:
    red_s3 = item.assets['B04']['href']
    print(red_s3)

In [None]:
# Show which assets the first returned item provides.
item = sentinel_items[0]
asset_names = item.assets.keys()
print(asset_names)

### Load Bands 4 and 8 for a spatial subset

#### Set-up AWS session

In [None]:
# Point curl at the system CA bundle before any HTTPS reads are made.
os.environ['CURL_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'

print("Creating AWS Session")
# Session object consumed by get_subset() through rio.Env; requests are
# flagged as requester-pays.
boto_session = boto3.Session()
aws_session = rio.session.AWSSession(boto_session, requester_pays=True)

#### Convert geometry to GeoSeries

In [None]:
# Read the metadata of the B04 raster referenced by `red_s3`.
with rio.open(red_s3) as src:
    src_meta = src.meta

# ROI polygon: declared as EPSG:4326, then reprojected into the raster's CRS.
roi_wgs84 = gpd.GeoSeries([Polygon(geometry["coordinates"][0])]).set_crs(4326)
geom_poly = roi_wgs84.to_crs(src_meta["crs"])

# Bounding box of the reprojected ROI.
bbox = bounds(geom_poly)

Where is the ROI in relation to an S2 tile? 

In [None]:
# Footprint of the raster opened in the previous cell, as a rectangle.
src_extent = src.bounds
tile_corners = (src_extent.left, src_extent.bottom,
                src_extent.right, src_extent.top)
img_poly = gpd.GeoSeries([box(*tile_corners)])

# ROI reprojected from EPSG:4326 into the raster's CRS.
roi_wgs84 = gpd.GeoSeries([Polygon(geometry["coordinates"][0])]).set_crs(4326)
roi_poly = roi_wgs84.to_crs(src.crs)

# Tile footprint in green with the ROI drawn on top in red.
fig, ax = plt.subplots(figsize=(10, 10))
for layer, layer_color in ((img_poly, "green"), (roi_poly, "red")):
    layer.plot(color=layer_color, ax=ax)
None

#### Read in subset from each band and date

In [None]:
# Folder for the per-date NDVI previews; created up front so that saving the
# figures does not fail on a fresh environment.
image_dir = f"{os.environ['HOME']}/sagemaker-studiolab-notebooks/images"
os.makedirs(image_dir, exist_ok=True)

# Per-item ROI subsets of the red (B04) and NIR (B08) bands plus their NDVI.
redl = []
nirl = []
ndvil = []
for i, item in enumerate(sentinel_items):
    red_s3 = item.assets['B04']['href']
    nir_s3 = item.assets['B08']['href']
    # First 10 characters of the datetime string, i.e. the date part.
    date = item.properties['datetime'][0:10]

    print(f"Sentinel item number {i}/{len(sentinel_items)} {date}")

    # Convert to float before any arithmetic: with unsigned integer band data
    # `nir - red` wraps around wherever red exceeds NIR.
    red = get_subset(red_s3, geometry).astype("float32")
    nir = get_subset(nir_s3, geometry).astype("float32")
    # The small constant avoids division by zero where both bands are 0.
    ndvi = (nir - red) / (nir + red + 0.00001)

    redl.append(red)
    nirl.append(nir)
    ndvil.append(ndvi)

    plotNDVI(nir, red, f"{image_dir}/{date}_{i}_ndvi.png")

## Homework

1. Calculate median NIR image
2. Calculate median red image
3. Calculate NDVI from median red and NIR
4. Write to geotiff on disk (under data folder)
5. Display resulting NDVI image in leafmap using the add_raster function

### Median images

In [None]:
# Stack the per-date subsets, treat zeros as missing, and take the per-pixel
# median across dates while ignoring the missing values.
red_stack = np.array(redl)
nir_stack = np.array(nirl)

redmed = np.nanmedian(np.where(red_stack == 0, np.nan, red_stack), axis=0)
nirmed = np.nanmedian(np.where(nir_stack == 0, np.nan, nir_stack), axis=0)

### NDVI

In [None]:
# NDVI of the median images; the small constant keeps the denominator non-zero.
band_difference = nirmed - redmed
band_sum = nirmed + redmed + 0.0001
ndvi = band_difference / band_sum

### Write to geotiff

#### Get and adjust metadata

In [None]:
# Start from the metadata of the raster referenced by `red_s3`.
with rio.open(red_s3) as src:
    src_meta = src.meta

# Keep the source pixel size/rotation terms but move the origin to the
# upper-left corner of the ROI bounding box (min x, max y).
bbox = bounds(geom_poly)
src_transform = src_meta["transform"]
roi_transform = Affine(
    src_transform.a, src_transform.b, bbox[0],
    src_transform.d, src_transform.e, bbox[3],
)

# Output profile: float32 data with the shape of the NDVI array.
dst_meta = src_meta.copy()
dst_meta.update(
    dtype=np.float32,
    height=ndvi.shape[0],
    width=ndvi.shape[1],
    transform=roi_transform,
)

#### To disk

In [None]:
out_file = f"{data_dir}/ndvi_median.tif"

# Write the median NDVI as band 1. This has to run on every fresh execution,
# otherwise the read below (and the later cells using `out_file`) has no
# up-to-date file to work with. The array is cast to the dtype declared in
# the output profile.
with rio.open(out_file, 'w', **dst_meta) as dst:
    dst.write(ndvi.astype(dst_meta["dtype"]), 1)

# Re-open the file to display it; the context manager closes the handle.
with rio.open(out_file) as ndvi_src:
    show(ndvi_src)

### Map it

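Note: `add_raster` only works when the local tile server can be reached from the browser (see the `LOCALTILESERVER_CLIENT_PREFIX` proxy setting in the imports cell).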

In [None]:
# ROI outline in blue with the median NDVI raster drawn on the satellite map.
m = base_map()
roi_gdf = gpd.GeoDataFrame({"id": 1, "geometry": geom_poly})
m.add_gdf(roi_gdf, style={"color": "blue"}, layer_name="ROI",
          zoom_to_layer=True)
m.add_raster(out_file, cmap="PRGn", layer_name="S2 NDVI")
m

## Create cloud-native assets

### Cloud-Optimized Geotiffs

In [None]:
%%capture
# rio-cogeo provides the `rio cogeo` command line used in the cells below.
%pip install rio-cogeo

Run rio-cogeo to create and validate tifs

In [None]:
# Derive the COG path from the median GeoTIFF path. The dot is escaped and
# the pattern anchored so that only the trailing ".tif" extension is replaced.
cog_file = re.sub(r"\.tif$", "_cog.tif", out_file)

# Convert band 1 of the median NDVI GeoTIFF into a COG.
cmd = ['rio', 'cogeo', 'create', '-b', '1', out_file, cog_file]
p = run(cmd, capture_output=True)
msg = p.stderr.decode().splitlines()
if p.returncode != 0:
    # Surface the tool's own error output instead of continuing without a COG.
    raise RuntimeError("rio cogeo create failed:\n" + "\n".join(msg))
# Last line written to stderr by the tool (empty if it wrote nothing).
print(f"...{msg[-1] if msg else ''}")

# Validate the result and report the first line of the tool's stdout.
cmd = ['rio', 'cogeo', 'validate', cog_file]
p = run(cmd, capture_output=True)
msg = p.stdout.decode().splitlines()
print(f"...{msg[0] if msg else ''}")

Check specs of imagery

In [None]:
# Print the gdalinfo report for the COG; the line for the source GeoTIFF is
# kept commented out for optional comparison.
# !gdalinfo {out_file}
!gdalinfo {cog_file}

### Make an NDVI geotiff for each date

In [None]:
# Write one GeoTIFF per acquisition and convert each into a COG.
for i, ndvi_i in enumerate(ndvil):
    # ".../ndvi_median.tif" -> ".../ndvi_<i>.tif"; the dot is escaped and the
    # pattern anchored so only the file-name suffix is replaced.
    ndvi_file = re.sub(r"median\.tif$", f"{i}.tif", out_file)
    print(f"Making {ndvi_file}")

    # Cast to the dtype declared in the output profile before writing band 1.
    with rio.open(ndvi_file, 'w+', **dst_meta) as dst:
        dst.write(ndvi_i.astype(dst_meta["dtype"]), 1)

    # NOTE(review): input and output are the same path, i.e. the file is
    # converted in place - confirm rio-cogeo supports this, or write to a
    # separate "*_cog" file as done for the median image.
    cmd = ['rio', 'cogeo', 'create', '-b', '1', ndvi_file,
           ndvi_file]
    p = run(cmd, capture_output=True)
    msg = p.stderr.decode().splitlines()
    if p.returncode != 0:
        # Stop with the tool's own error output rather than validating a
        # file that was never converted.
        raise RuntimeError("rio cogeo create failed:\n" + "\n".join(msg))
    print(f"...{msg[-1] if msg else ''}")

    cmd = ['rio', 'cogeo', 'validate', ndvi_file]
    p = run(cmd, capture_output=True)
    msg = p.stdout.decode().splitlines()
    print(f"...{msg[0] if msg else ''}")