In [1]:
%pip install spyndex -q
%pip install jupyter_bokeh -q
#%pip install odc-stac -q
#%pip install stackstac -q

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import pandas as pd
import numpy as np
import xarray as xr
import rioxarray as rxr
import geopandas as gpd

import rasterio
from rasterio.windows import Window, from_bounds
from rasterio import plot
from rasterio.plot import show
from rasterio.transform import from_origin
from rasterio.warp import calculate_default_transform, reproject, Resampling

import rioxarray

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pystac_client

import spyndex
import stackstac

import holoviews as hv
import geoviews as gv
import rioxarray
import panel as pn

hv.extension('bokeh')
gv.extension('bokeh')

os.environ['AWS_ACCESS_KEY_ID'] = 'your_access_key'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'your_secret_key'


In [3]:
def initialize_stac_client(stac_url):
    """
    Initialize and return a STAC client for a given STAC API URL.
    Parameters:
    - stac_url (str): The URL of the STAC API.
    Returns:
    - A pystac_client.Client object
    """
    client = pystac_client.Client.open(stac_url)
    return client


def query_stac_api(client, collections, bbox, start_date=None, end_date=None, max_items=10):
    """
    Query a STAC API for items within a bounding box and date range for specific collections.
    Parameters:
    - client: The STAC client initialized with `initialize_stac_client`.
    - bbox (list): The bounding box for the query [min_lon, min_lat, max_lon, max_lat].
    - collections (list): A list of collection IDs to include in the query.
    - start_date (str, optional): The start date for the query (YYYY-MM-DD). Defaults to None.
    - end_date (str, optional): The end date for the query (YYYY-MM-DD). Defaults to None.
    - limit (int): Maximum number of items to return.    
    Returns:
    - A list of STAC Items that match the query parameters.
    """

    search_params = {
        "bbox": bbox,
        "collections": [collections],
        "max_items": max_items
    }
    if start_date and end_date:
        search_params["datetime"] = f"{start_date}/{end_date}"

    search = client.search(**search_params).item_collection()
    return search

def stac_to_array(stac_items, bbox, requested_bands):
    """
    Converts STAC item collection returned by query_stac_api() to a data array that includes the requested bands and is clipped to bounding box.

    Parameters:
    - stac_items (list): A list of STAC items to be converted.
    - requested_bands (list): A list of bands to include in the datacube.
    - bbox

    Returns:
    - stack (DataArray): The array containing the requested bands within the specified bounding box.
    """
    stack = stackstac.stack(stac_items,
                        resolution = 50,
                        assets = requested_bands,
                        bounds_latlon=bbox,
                        chunksize=2048,
                        epsg=3857)
    return stack
    

# def apply_scl_mask(data_array, variable_name='scl'):
#     """
#     Apply the SCL cloud mask to each time slice of the data array, preserving metadata.
    
#     Parameters:
#     - data_array: xarray DataArray including spectral bands and an SCL layer over time.
    
#     Returns:
#     - The data array with cloud pixels masked, with preserved metadata.
#     """
#     # Create a dictionary to hold the attributes of the variables
#     variable_attrs = {var: stack[var].attrs for var in stack.data_vars}
    
    
#     scl = data_array.sel(band='scl')
    
#     def mask_time_slice(slice, attrs):
#         scl_cloud_values = [3, 8, 9, 10, 11]  # Cloud and other features to mask
#         scl_mask = slice.sel(band='scl').isin(scl_cloud_values)
#         masked_slice = slice.where(~scl_mask, other=np.nan)
#         # Drop 'scl' but preserve other variable attributes
#         for band in masked_slice.band.values:
#             if band != 'scl':
#                 masked_slice.sel(band=band).attrs = attrs[band]
#         return masked_slice.drop_vars('scl')
    
    
    
#     masked_data_array = data_array.groupby('solar_day').apply(mask_time_slice, attrs=variable_attrs)
#     masked_data_array.attrs = global_attrs  # Re-assign global attributes
    
#     return masked_data_array


In [4]:
def apply_scl_mask(data_array):
    scl_cloud_values = [3, 8, 9, 10, 11]  # Cloud and other features to mask
    bitmask = 0
    
    for field in scl_cloud_values:
        bitmask |= 1 << field
    
    grouped_stack = data_array.groupby('time')
    
    qa = grouped_stack['scl']
    bad = qa & bitmask
    
    grouped_stack.where(bad == 0)
    
    return grouped_stack
    

In [5]:
def save_array_to_geotiff(dataset, array_name, output_path):
    """
    Reprojects an xarray.Dataset to the specified CRS if necessary, then iterates through each 
    time segment of a single data array (e.g., NDVI), and saves each time segment as a separate GeoTIFF.

    Parameters:
    - dataset: xarray.Dataset containing the selected array with spatial metadata.
    - array_name: String name of the array to save (e.g., 'NDVI').
    - output_path: String base path to the output GeoTIFF files. Filenames will include the time slice 
      information in YYYY-MM format.
    - target_crs: String of the target CRS in EPSG code (default 'EPSG:3857').
    """


    for time_idx in range(dataset.sizes['time']):
        # Format the output filename to include the time slice in YYYY-MM format
        time_value = dataset['time'].isel(time=time_idx)
        time_str = pd.to_datetime(str(time_value.values)).strftime('%Y-%m')
        filename = f"{array_name}_{time_str}.tif"
        full_output_path = f"{output_path}/{filename}"
        
        dataset.rio.write_crs("epsg:3857", inplace=True)
        # Save each slice as a GeoTIFF
        dataset.isel(time=time_idx)[array_name].rio.to_raster(full_output_path)


In [6]:
outpath = '/workspace/notebooks/sandbox/data/output-data/spectral-indices'

geom_fpath = '/workspace/notebooks/sandbox/data/input-data/nca/dissolved-orana-boundaries.geojson'
geom = gpd.read_file(geom_fpath)

bbox = list(geom.total_bounds)
lat = geom.centroid.x[0]
long = geom.centroid.y[0]

stac_url_sentinel = "https://earth-search.aws.element84.com/v1"
collection_sentinel = "sentinel-2-l2a"

start_date = "2023-01-01"
end_date = "2024-01-01"
sentinel_bands = ['blue','green','red','nir', 'swir16', 'scl'] #scl MUST be included


  lat = geom.centroid.x[0]

  long = geom.centroid.y[0]


In [7]:
client_sentinel = initialize_stac_client(stac_url_sentinel)
items_sentinel = query_stac_api(client_sentinel, collection_sentinel, bbox, start_date, end_date)

items_sentinel

In [8]:

stack = stac_to_array(items_sentinel, bbox, sentinel_bands) #sentinel_bands,

  times = pd.to_datetime(


Use the Spyndex package to map satellite bands to parameters, and then calculate whatever indices we want. 

In [9]:

# Filter by low cloud cover if needed
#lowcloud = cloud_masked_stack[cloud_masked_stack["eo:cloud_cover"] < 10]

lowcloud = stack[stack["eo:cloud_cover"] < 20]

# Create monthly composite excluding NaN values
monthly = lowcloud.resample(time="4Q").median("time", keep_attrs=True, skipna=True)
rgb = monthly.sel(band=["red", "green", "blue"])


  index_grouper = pd.Grouper(


In [10]:
"""
TODO: convert into a function where user can list the indices they want to compute
"""
# using the data array set up using stackstac, we can now comput a bunch of indices:
monthly['NDVI'] = spyndex.indices.NDVI.compute(
        N = monthly.sel(band="nir"),
        R = monthly.sel(band="red"),
).clip(-1,1)

monthly['NDMI'] = spyndex.indices.NDMI.compute(
        N = monthly.sel(band="nir"),
        S1 = monthly.sel(band="swir16")
).clip(-1,1)

monthly['NDWI'] = spyndex.indices.NDWI.compute(
        G = monthly.sel(band="green"),
        N = monthly.sel(band="nir")
).clip(-1,1)

monthly['NDREI'] = spyndex.indices.NDREI.compute(
        N = monthly.sel(band="nir"),
        RE1 = monthly.sel(band="rededge1")
).clip(-1,1)

# monthly['CIRE'] = spyndex.indices.CIRE.compute(
#         N = monthly.sel(band="nir"),
#         RE1 = monthly.sel(band="rededge1")
# )
# monthly

ndvi = monthly['NDVI'].compute()
ndmi = monthly['NDMI'].compute()
ndwi = monthly['NDWI'].compute()
ndrei = monthly['NDREI'].compute()
#cire = monthly['CIRE'].compute()

KeyError: "not all values found in index 'band'. Try setting the `method` keyword argument (example: method='nearest')."

In [None]:
save_array_to_geotiff(ndvi, 'NDVI', outpath)
save_array_to_geotiff(ndmi, 'NDMI', outpath)
save_array_to_geotiff(ndwi, 'NDWI', outpath)
save_array_to_geotiff(ndrei, 'NDREI', outpath)
#save_array_to_geotiff(monthly, 'CIRE', outpath)

RasterioIOError: Attempt to create new tiff file '/workspaces/data-notebooks/notebooks/sandbox/data/output-data/NDVI_2023-12.tif' failed: No such file or directory

In [None]:
# Function to list all .tif files in the specified directory
def list_tif_files(path):
    return [f for f in os.listdir(path) if f.endswith('.tif')]

# Function to load and display the selected .tif file
def load_and_display_tif(filename):
    filepath = os.path.join(outpath, filename)
    img = gv.util.from_xarray(rioxarray.open_rasterio(filepath).rio.reproject('EPSG:3857'))
    
    # Define map tiles and create the map image
    map_tiles = gv.tile_sources.EsriImagery().opts(width=1000, height=600)
    map_img = gv.Image(img, kdims=['x', 'y']).opts(cmap='viridis', title=filename)
    map_combo = map_tiles * map_img
    
    return map_combo

In [None]:
# Create the dropdown menu
tif_files = list_tif_files(outpath)

dropdown = pn.widgets.Select(name='Select a .tif file', options=tif_files)

# Panel DynamicMap to update the map based on the dropdown selection
@pn.depends(dropdown.param.value)
def update_map(selected_file):
    return load_and_display_tif(selected_file)

# Layout using Panel
layout = pn.Column(
    pn.Row(pn.Column('## Geodata file viewer', dropdown)),
    update_map
)

# Display the Panel dashboard
layout.servable()