In [14]:
from pystac_client import Client
from odc.stac import load, configure_s3_access
from odc.geo import Geometry, BoundingBox
from dask.distributed import Client as DaskClient

import numpy as np
import geopandas as gpd
import folium
import xarray as xr

from utils import WGS84GRID30, USGSCATALOG, USGSLANDSAT, http_to_s3_url

In [2]:
# Reading USGS Landsat straight from S3 needs AWS credentials: the bucket is requester-pays
configure_s3_access(requester_pays=True, cloud_defaults=True)

# Open the USGS STAC catalogue used for the item search below
client = Client.open(USGSCATALOG)

In [37]:
# Load our areas of interest
gdf = gpd.read_file("aois.geojson")

# Row 0 is Fiji, 1 is the Caribbean and 2 is Belize
geom = Geometry(gdf.geometry[0], crs="epsg:4326")

# All the grid tiles that intersect the AOI
tiles = WGS84GRID30.tiles_from_geopolygon(geom)

# This is northern Viti Levu in Fiji
tile = WGS84GRID30.tile_geobox((119, -13))

# Optional, even smaller area of interest. odc.geo's BoundingBox takes
# (left, bottom, right, top), i.e. (min lon, min lat, max lon, max lat).
box = BoundingBox(177.47563, -17.74963, 177.54938, -17.68436, crs="epsg:4326")

# Subset the tile to load a smaller area. `tile` must stay a GeoBox: later
# cells use it for the STAC search, as the load geobox and for the map centre.
tile = tile.crop([4000, 4000])

# Show the extent on an interactive map (the map is displayed, not stored)
tile.extent.explore()

In [None]:
# Search for Tier 1 Landsat scenes over the tile for 2024
items = client.search(
    collections=[USGSLANDSAT],
    intersects=tile.geographic_extent,
    datetime="2024",
    query={"landsat:collection_category": {"in": ["T1"]}},
).item_collection()

print(f"Found {len(items)} items")

# Fail early with a clear message instead of an obscure error inside `load`
if len(items) == 0:
    raise ValueError(
        "No Landsat items found for this tile and date range; check the AOI, dates and query."
    )


In [None]:
# Lazily load the bands onto the tile's grid as dask-backed arrays, one
# layer per solar day; qa_pixel carries the per-pixel quality bit flags.
data = load(
    items,
    geobox=tile,
    measurements=["red", "green", "blue", "nir08", "qa_pixel"],
    chunks={"x": 2048, "y": 2048},
    groupby="solar_day",
    dtype="uint16",
    nodata=0,
    resampling={"qa_pixel": "nearest"},
    patch_url=http_to_s3_url,
).rename_vars({"nir08": "nir"})  # shorter name used by the index formulas

data

In [None]:
# Create a cloud mask, scale values to 0-1 and set nodata to NaN

# Landsat Collection 2 QA_PIXEL bits: 0 = fill (nodata), 3 = cloud, 4 = cloud shadow.
# (Bit 1 is the USGS dilated-cloud flag, which could be added to widen the mask.)
FILL_BIT = 1 << 0
CLOUD_BIT = 1 << 3
CLOUD_SHADOW_BIT = 1 << 4
bitflags = FILL_BIT | CLOUD_BIT | CLOUD_SHADOW_BIT  # 0b00011001

# Bitwise AND selects any pixel flagged as fill, cloud or cloud shadow
cloud_mask = (data.qa_pixel & bitflags) != 0
# Note: it might be a good idea to dilate the mask here to catch pixels adjacent to clouds

# Pixels the loader filled with the `nodata=0` value (e.g. outside any scene footprint)
nodata_mask = data.qa_pixel == 0

# Combine the masks
mask = cloud_mask | nodata_mask

# Mask the original data; inserting NaN promotes the uint16 bands to float
masked = data.where(~mask, other=np.nan).drop_vars("qa_pixel")

# Apply the Collection 2 Level-2 surface reflectance scale and offset, treating 0 as nodata
scaled = (masked.where(masked != 0) * 0.0000275 + -0.2).clip(0, 1)

scaled

In [None]:
time = 0
# Raw digital numbers for one date, so the colour stretch is in DN units
scene = data.isel({"time": time})
scene.odc.explore(vmin=7000, vmax=8000)

In [None]:
scaled.isel({"time": time}).odc.explore(vmax=0.2, vmin=0)  # same date, masked and scaled to reflectance

In [None]:
# Add vegetation indices computed from the scaled reflectance bands
nir = scaled["nir"]
red = scaled["red"]

scaled = scaled.assign(
    # NDVI: normalised difference vegetation index
    ndvi=(nir - red) / (nir + red),
    # MSAVI: modified soil-adjusted vegetation index
    msavi=0.5 * ((2 * nir + 1) - np.sqrt((2 * nir + 1) ** 2 - 8 * (nir - red))),
    # EVI2: two-band enhanced vegetation index
    evi2=2.5 * (nir - red) / (nir + 2.4 * red + 1),
)

scaled

In [66]:
# Reduce each index over time to its maximum, mean and median (still lazy);
# each output is named "<index>_<stat>" so they merge into one dataset.
results = []

for index in ("ndvi", "msavi", "evi2"):
    series = scaled[index]
    maximum = series.max("time").rename(f"{index}_max")
    mean = series.mean("time").rename(f"{index}_mean")
    median = series.median("time").rename(f"{index}_median")
    results += [maximum, mean, median]

indices = xr.merge(results)

In [None]:
# Set up a local dask cluster and compute. Use a distinct name so the STAC
# `client` opened earlier is not shadowed by a (closed) dask client.
with DaskClient(n_workers=1, threads_per_worker=16) as dask_client:
    averages = indices.compute()

averages

In [None]:
averages["ndvi_max"].plot.imshow(cmap="RdYlGn");  # temporal max NDVI from the computed dataset

In [None]:
# folium wants (lat, lon) while the geometry coordinates are (lon, lat), so reverse them
centroid = tile.geographic_extent.centroid
center = centroid.coords[0][::-1]

m = folium.Map(location=center, zoom_start=8)

# (variable suffix, layer label) for each temporal statistic, in layer order
stat_labels = (("max", "Max"), ("mean", "Mean"), ("median", "Median"))

for index in ["ndvi", "msavi", "evi2"]:
    # NDVI gets a diverging red-yellow-green ramp; the other indices use viridis
    opts = {"vmin": -1, "vmax": 1, "cmap": "RdYlGn" if index == "ndvi" else "viridis"}
    for suffix, label in stat_labels:
        averages[f"{index}_{suffix}"].odc.add_to(m, name=f"{index.upper()} {label}", **opts)

folium.LayerControl().add_to(m)

m.save("indices.html")

m

In [None]:
# Debug check: raw centroid coordinates of the tile (presumably (lon, lat) order)
centroid = tile.geographic_extent.centroid
centroid.coords