In [92]:
#%pip install koordinates -q
#%pip install spyndex -q
#%pip install jupyter_bokeh -q
#%pip install stackstac -q
#%pip install dask[dataframe] -q



In [93]:
# general packages:
import os
import pandas as pd
import numpy as np
import xarray as xr

# geospatial:
import pystac_client
import spyndex
import stackstac

import cartopy.crs as ccrs

import rioxarray
import geopandas as gpd

# import rasterio
# from rasterio.windows import Window, from_bounds
# from rasterio import plot
# from rasterio.plot import show
# from rasterio.transform import from_origin
# from rasterio.warp import calculate_default_transform, reproject, Resampling



import owslib
from owslib.wfs import WebFeatureService
from owslib.fes import *

# api specific:
#import requests
#from requests import Request
#import koordinates


# data viz:
#from IPython.display import display, JSON
#import datashader
import holoviews as hv
import geoviews as gv
import datashader as ds
import panel as pn

from holoviews import opts
from holoviews.operation.datashader import rasterize, shade


hv.extension('bokeh')
pn.extension() 

os.environ['AWS_ACCESS_KEY_ID'] = 'your_access_key'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'your_secret_key'

In [94]:
def initialize_stac_client(stac_url):
    """
    Initialize and return a STAC client for a given STAC API URL.
    Parameters:
    - stac_url (str): The URL of the STAC API.
    Returns:
    - A pystac_client.Client object
    """
    client = pystac_client.Client.open(stac_url)
    return client


def query_stac_api(client, collections, bbox, start_date=None, end_date=None, max_items=None):
    """
    Query a STAC API for items within a bounding box and date range for specific collections.
    Parameters:
    - client: The STAC client initialized with `initialize_stac_client`.
    - bbox (list): The bounding box for the query [min_lon, min_lat, max_lon, max_lat].
    - collections (list): A list of collection IDs to include in the query.
    - start_date (str, optional): The start date for the query (YYYY-MM-DD). Defaults to None.
    - end_date (str, optional): The end date for the query (YYYY-MM-DD). Defaults to None.
    - limit (int): Maximum number of items to return.    
    Returns:
    - A list of STAC Items that match the query parameters.
    """

    search_params = {
        "bbox": bbox,
        "collections": [collections],
        "max_items": max_items
    }
    if start_date and end_date:
        search_params["datetime"] = f"{start_date}/{end_date}"

    search = client.search(**search_params).item_collection()
    return search

def stac_to_array(stac_items, bbox, requested_bands):
    """
    Converts STAC item collection returned by query_stac_api() to a data array that includes the requested bands and is clipped to bounding box.

    Parameters:
    - stac_items (list): A list of STAC items to be converted.
    - requested_bands (list): A list of bands to include in the datacube.
    - bbox

    Returns:
    - stack (DataArray): The array containing the requested bands within the specified bounding box.
    """
    stack = stackstac.stack(stac_items,
                        resolution = 10,
                        assets = requested_bands,
                        bounds_latlon=bbox,
                        chunksize=2048,
                        epsg=3857)
    return stack

def save_array_to_geotiff(dataset, array_name, output_path):
    """
    Reprojects an xarray.Dataset to the specified CRS if necessary, then iterates through each 
    time segment of a single data array (e.g., NDVI), and saves each time segment as a separate GeoTIFF.

    Parameters:
    - dataset: xarray.Dataset containing the selected array with spatial metadata.
    - array_name: String name of the array to save (e.g., 'NDVI').
    - output_path: String base path to the output GeoTIFF files. Filenames will include the time slice 
      information in YYYY-MM format.
    - target_crs: String of the target CRS in EPSG code (default 'EPSG:3857').
    """


    for time_idx in range(dataset.sizes['time']):
        # Format the output filename to include the time slice in YYYY-MM format
        time_value = dataset['time'].isel(time=time_idx)
        time_str = pd.to_datetime(str(time_value.values)).strftime('%Y-%m')
        filename = f"{array_name}_{time_str}.tif"
        full_output_path = f"{output_path}/{filename}"
        
        dataset.rio.write_crs("epsg:3857", inplace=True)
        # Save each slice as a GeoTIFF
        dataset.isel(time=time_idx)[array_name].rio.to_raster(full_output_path)
        

def check_and_reproject(df):
    """
    Checks the coordinate reference system (CRS) of the input DataFrame and reprojects it to EPSG:4326 if necessary.

    Parameters:
    df (GeoDataFrame): The input DataFrame containing spatial data.

    Returns:
    GeoDataFrame: The reprojected DataFrame.
    """
    if df.crs != 'EPSG:4326':
        df = df.to_crs('EPSG:4326')
    else:
        df=df
    return df


# Function to list all .tif files in the specified directory
def list_tif_files(path):
    return [f for f in os.listdir(path) if f.endswith('.tif')]

# Function to load and display the selected .tif file
def load_and_display_tif(filename):
    map_tiles = hv.element.tiles.OSM().opts(width=1000, height=800)
    
    filepath = os.path.join(outpath, filename)
    #img = gv.util.from_xarray(rioxarray.open_rasterio(filepath).rio.reproject('EPSG:3857'))
    img = gv.util.from_xarray(rioxarray.open_rasterio(filepath).rio.reproject('EPSG:3857'))
    
    #map_img = gv.Image(img, kdims=['x', 'y']).opts(cmap='viridis', title=filename)
    
    map_img = hv.HoloMap(gv.Image(img, kdims=['x', 'y'], rtol=10).opts(cmap='viridis', title=filename))
    map_combo = map_tiles * map_img
    
    return map_combo

In [95]:
outpath = '/workspace/notebooks/sandbox/data/output-data/spectral-indices'
target_crs = 'EPSG:2193'  # this is for vector, use another variable for stac


# import and set up geometry
geom_path = '/workspace/notebooks/sandbox/data/input-data/nz_merino/nz_merino_test_random_area.geojson'
geom = gpd.read_file(geom_path)

geom_stac = geom #temporary until I figure out if I can keep both in same crs. STAC can't use 2193


if geom.crs != target_crs:
    geom = geom.to_crs(target_crs)
    
    
bbox = list(geom.total_bounds)
bbox_stac = list(geom_stac.total_bounds) #one bbox for STAC, the other for NZ to account for different projections

lat = geom.centroid.x[0]
long = geom.centroid.y[0]
    

In [96]:

# set up STAC
stac_url_sentinel = "https://earth-search.aws.element84.com/v1"
collection_sentinel = "sentinel-2-l2a"

start_date = "2023-01-01"
end_date = "2024-01-01"
sentinel_bands = ['blue','green','red', 'rededge1', 'nir', 'swir16', 'scl'] 

### Get soil vector data from LRIS portal

In [97]:
wfs_url = "https://lris.scinfo.org.nz/services;key=b16e80c40f4040b5a03ddb8f66b403d8/wfs/layer-48066/?service=WFS&request=GetCapabilities"
wfs = WebFeatureService(url=wfs_url)

In [98]:
feature_type = list(wfs.contents.keys())

for feature_type in wfs.contents:
    print(feature_type)
    
wfs.get_schema(feature_type)

layer = wfs.contents

layer

lris.scinfo.org.nz:layer-48066


{'lris.scinfo.org.nz:layer-48066': <owslib.feature.wfs100.ContentMetadata at 0x7f443e91f790>}

In [99]:
response = wfs.getfeature(typename=layer, bbox=bbox, outputFormat='json')

In [100]:
out = open('/workspace/notebooks/sandbox/adhoc_client_work/nz_merino/test-storedquery.geojson', 'wb')
out.write(response.read())
out.close()

In [101]:
data = gpd.read_file('/workspace/notebooks/sandbox/adhoc_client_work/nz_merino/test-storedquery.geojson')

data_clip = data.clip(bbox)
data_clip.to_file('/workspace/notebooks/sandbox/adhoc_client_work/nz_merino/test-storedquery_clip.geojson', driver='GeoJSON')

### Get vegetation and soil indices using sentinel-2 + STAC + Spyndex

In [102]:
client_sentinel = initialize_stac_client(stac_url_sentinel)
items_sentinel = query_stac_api(client_sentinel, collection_sentinel, bbox_stac, start_date, end_date)

stack = stac_to_array(items_sentinel, bbox_stac, sentinel_bands)

  times = pd.to_datetime(


In [103]:
lowcloud = stack[stack["eo:cloud_cover"] < 20] #removed cloud filtering for now as it appears to be broken

# Create monthly composite excluding NaN values
monthly = lowcloud.resample(time="4Q").median("time", keep_attrs=True, skipna=True)
rgb = monthly.sel(band=["red", "green", "blue"])


  index_grouper = pd.Grouper(


In [104]:
"""
TODO: convert into a function where user can list the indices they want to compute
"""
# using the data array set up using stackstac, we can now comput a bunch of indices:
monthly['NDVI'] = spyndex.indices.NDVI.compute(
        N = monthly.sel(band="nir"),
        R = monthly.sel(band="red"),
).clip(-1,1)

monthly['NDMI'] = spyndex.indices.NDMI.compute(
        N = monthly.sel(band="nir"),
        S1 = monthly.sel(band="swir16")
).clip(-1,1)


monthly['NDREI'] = spyndex.indices.NDREI.compute(
        N = monthly.sel(band="nir"),
        RE1 = monthly.sel(band="rededge1")
).clip(-1,1)


ndvi = monthly['NDVI'].compute()
ndmi = monthly['NDMI'].compute()
ndrei = monthly['NDREI'].compute()

In [105]:
save_array_to_geotiff(ndvi, 'NDVI', outpath)
save_array_to_geotiff(ndmi, 'NDMI', outpath)
save_array_to_geotiff(ndrei, 'NDREI', outpath)

### Map visualisation to check data

In [106]:
# set up map tiles for basemap
map_tiles = hv.element.tiles.EsriImagery().opts(width=600, height=400)

data_map = check_and_reproject(data_clip)
soil = data_map.hvplot(geo=True)

map_tiles * soil