In [None]:
import dask.array as da
import fsspec
import numpy as np
import pyproj
import pystac
import rioxarray
import stac2dcache
import xarray as xr

# Spring Index Models from Daymet

## 1. Introduction

### 1.1 Overview

In this notebook we calculate spring onset indicators, namely **the day of first leaf appearance**, as 1-km gridded estimates over the conterminous United States (CONUS). As input data, we use variables from the [Daymet dataset](https://daac.ornl.gov/cgi-bin/dsviewer.pl?ds_id=1840), which we have previously downloaded to the [SURF dCache storage](http://doc.grid.surfsara.nl/en/stable/Pages/Service/system_specifications/dcache_specs.html) in the form of a [SpatioTemporal Asset Catalog](https://stacspec.org/) (see [this notebook](https://github.com/RS-DAT/JupyterDask-Examples/blob/main/03-phenology/notebooks/01-download-Daymet4.ipynb)). The same storage system is used for the output spring index products, which we save in [Zarr](https://zarr.readthedocs.io/en/stable/) format. This work is based on the publication [Izquierdo-Verdiguier et al., 2018](https://doi.org/10.1016/j.agrformet.2018.06.028).

### 1.2 The model

The first-leaf spring indices are computed following the Extended Spring Index (SI-x) models from [Schwartz et al., 2013](https://doi.org/10.1002/joc.3625). The input variables, taken from the Daymet dataset, are the daily minimum and maximum temperatures and the daylength.
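These three inputs are combined into growing degree hours (GDH), and the formula can be written out explicitly. The equations below transcribe the `calculate_gdh` function in Section 2.3; they are not copied from the cited publications. The symbols are:

* $T_{min}$ and $T_{max}$: the daily temperature extremes (°F);
* $D$: the daylength (hours);
* $h = 0, \dots, 23$: the hour of the day;
* $T_b = 31$ °F: the base temperature.

$$
T_h = T_{min} + \begin{cases}
\Delta T \sin\!\left(\dfrac{\pi h}{D + 4}\right), & h \le \lfloor D \rfloor,\\[1ex]
\Delta T \sin\!\left(\dfrac{\pi D}{D + 4}\right)\left[1 - \dfrac{\ln(h - \lfloor D \rfloor)}{\ln(24 - D)}\right], & h > \lfloor D \rfloor,
\end{cases}
\qquad \Delta T = T_{max} - T_{min},
$$

$$
\mathrm{GDH} = \sum_{h=0}^{23} \max\left(T_h - T_b,\ 0\right).
$$

The first branch is a sine curve over the daylight hours. The second branch is a logarithmic decay after sunset, starting from the sine-curve value at $h = D$.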

Using the SI-x models, the first-leaf dates are estimated for the *lilac* plant species. For more information, have a look at the original publication [Izquierdo-Verdiguier et al., 2018](https://doi.org/10.1016/j.agrformet.2018.06.028).

### 1.3 Before running this notebook

The input and output datasets, as well as the corresponding metadata, are stored on the SURF dCache system, which we access via bearer-token authentication with a macaroon. The macaroon, generated using [this script](https://github.com/sara-nl/GridScripts/blob/master/get-macaroon), is stored together with other configuration parameters in a JSON fsspec configuration file (also see the [STAC2dCache tutorial](https://github.com/NLeSC-GO-common-infrastructure/stac2dcache/blob/main/notebooks/tutorial.ipynb) and the [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/features.html#configuration) for more information).
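An fsspec configuration file is a JSON document that maps a protocol name to the keyword arguments used to instantiate that filesystem. fsspec reads such files from the directory given by the `FSSPEC_CONFIG_DIR` environment variable (default `~/.config/fsspec`).

The sketch below writes one such file for the `dcache` protocol into a temporary directory. The argument names (`api_url`, `webdav_url`, `token`) are assumptions, and all values are placeholders. Check the STAC2dCache tutorial for the settings actually required on SURF dCache.

```python
import json
import os
import tempfile

# Hypothetical keyword arguments for the "dcache" filesystem: the key names
# are assumptions and the values are placeholders, not working credentials.
config = {
    "dcache": {
        "api_url": "https://dcache-api.example.org/api/v1/",
        "webdav_url": "https://webdav.example.org:2880",
        "token": "<macaroon>",
    }
}

# Write the file into a temporary directory instead of ~/.config/fsspec
config_dir = tempfile.mkdtemp()
config_path = os.path.join(config_dir, "dcache.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

# fsspec reads this directory at import time if FSSPEC_CONFIG_DIR points to it:
#   os.environ["FSSPEC_CONFIG_DIR"] = config_dir  (set before `import fsspec`)
with open(config_path) as f:
    print(f.read())
```

Keeping the macaroon in a configuration file, rather than in the notebook, keeps credentials out of the notebook itself. It also lets every `dcache://` URL opened via fsspec pick up the same settings.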

## 2. Calculating the Spring Indices

### 2.1 Overview

The calculation of the spring index events involves the following steps: 
* opening the input variables from the retrieved collection; 
* performing some preprocessing operations (selecting the spatial and temporal extents of interest from the daily records, carrying out a few unit conversions);
* estimating the spring index dates on the 1-km grid on which input variables are provided;
* saving the output.

All the steps are run year by year, using a [Dask](http://dask.org) cluster to parallelize operations over spatial regions and days of the year. For the purpose of this demo we run the spring index calculation for a single year; see [this notebook](https://github.com/RS-DAT/JupyterDask-Examples/blob/main/03-phenology/notebooks/02-compute-spring-index.ipynb) for an example run covering the full time span of the Daymet dataset (42 years).

### 2.2 Input parameters  

The following variables define the parameters for the spring index calculation. These include the year, the range of days in which to look for the spring onset events, and the boundaries of the area of interest.

In [None]:
# Select one year to compute the spring index
year = 1980

# Day-of-year range (1-based, inclusive) for calculating growing degree hours
startdate = 1
enddate = 300

# Bounding box expressed in lat/lon degrees
bbox_latlon = (-124.784, 24.743, -66.951, 49.346)
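The day range is 1-based and inclusive, whereas positions along the time axis are 0-based. The check below assumes the time axis starts on 1 January; Daymet years always contain 365 days, with 31 December dropped in leap years.

This pure-Python snippet shows which days the positional slice in `preprocess_dataset` selects. It also confirms that the selection has the same length as the `DAYS` array used in the model.

```python
startdate, enddate = 1, 300

# Day-of-year labels of a Daymet year (always 365 days)
days_of_year = list(range(1, 366))

# Same positions as ds.isel(time=slice(startdate-1, enddate)) in preprocess_dataset
selected = days_of_year[slice(startdate - 1, enddate)]

print(selected[0], selected[-1], len(selected))  # → 1 300 300
```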

We also set the dCache path to the STAC catalog where we have archived the Daymet dataset, and the path where the output spring indices will be stored:

In [None]:
# dCache project root path
root_urlpath = "dcache://pnfs/grid.sara.nl/data/remotesensing/disk"

catalog_urlpath = f"{root_urlpath}/daymet-daily-v4/catalog.json"
output_urlpath = f"{root_urlpath}/demo/spring-index-models.zarr"

### 2.3 The model

The SI-x model is implemented in the following functions, which are used to calculate the first-leaf spring index dates. From the input variables extracted from Daymet, the growing degree hours (GDH) are first computed. A set of predictors is then calculated from the GDH, and these predictors are in turn used to estimate the spring onset dates for our plant species.

In [None]:
BASE_TEMP_FAHRENHEIT = 31.

HOURS = xr.DataArray(
    data=da.arange(24), 
    dims=("hours",),
)

DAYS = xr.DataArray(
    data=da.arange(startdate, enddate+1),
    dims=("time",),
)

LEAF_INDEX_COEFFS = xr.DataArray(
    data=da.from_array(
        [
            [3.306, 13.878, 0.201, 0.153],
        ],
        chunks=(1,-1)
    ),
    dims=("plant", "variable"),
    coords={"plant": ["lilac"]}
)

LEAF_INDEX_LIMIT = 637

In [None]:
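def calculate_gdh(dayl, tmin, tmax):
    """
    Calculate growing degree hours (GDH).

    Hourly temperatures follow a sine curve during daylight hours and a
    logarithmic decay after sunset; the hourly degrees above the base
    temperature are then summed over the 24 hours of each day.
    """

    dt = tmax - tmin
    # Temperature rise above tmin at sunset (sine curve evaluated at hour = dayl)
    const = np.sin(np.pi/(dayl + 4) * dayl) * dt

    # Daytime model (sine curve) and night-time model (logarithmic decay)
    eq1 = np.sin(HOURS * np.pi/(dayl + 4)) * dt
    eq2 = (1 - np.log(HOURS - np.floor(dayl))/np.log(24 - dayl)) * const
    # eq2 is NaN or infinite for hours up to floor(dayl): use eq1 there
    t = xr.where(~np.isfinite(eq2), eq1, eq2) + tmin - BASE_TEMP_FAHRENHEIT
    # Only degrees above the base temperature contribute
    t = t.clip(min=0)
    return t.sum(dim="hours", skipna=False)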


def calculate_leaf_predictors(gdh):
    """
    Calculate predictors for first leaf: DDE2, DD57, MDS0, and SYNOP.
    """
    
    # Pad the start of the series with 7 copies of the first GDH value, so that
    # the trailing sums below are defined from the first day of the range
    gdh_padded = gdh.pad(time=(7,0), mode="edge")
    
    # Calculating dde2 - trailing 3 days GDH sum from day i-2 to i
    dde2 = gdh_padded.rolling(time=3, center=False).sum()
    dde2 = dde2.isel(time=slice(7, None))  # drop padded values 
    
    # Calculating dd57 - trailing 5-7 days GDH sum from day i-7 to i-5
    dd57 = gdh_padded.rolling(time=8, center=False).sum() \
        - gdh_padded.rolling(time=5, center=False).sum()
    dd57 = dd57.isel(time=slice(7, None))  # drop padded values
    
    # Calculating mds0 - day of year minus one
    mds0 = DAYS - 1
    
    # Calculating synop - running count of days with dde2 at or above the limit
    synflag = dde2>=LEAF_INDEX_LIMIT
    synop = synflag.cumsum(dim="time")

    return dde2, dd57, mds0, synop


def calculate_first_leaf(dde2, dd57, mds0, synop):
    """
    Calculate the day of first leaf for each plant species from the predictors.
    """

    # Prediction calculation for first leaf
    mdsum = LEAF_INDEX_COEFFS[:,0]*mds0 \
        + LEAF_INDEX_COEFFS[:,1]*synop \
        + LEAF_INDEX_COEFFS[:,2]*dde2 \
        + LEAF_INDEX_COEFFS[:,3]*dd57

    mdbool = mdsum>999.5  # Flag all occurrences of first leaf

    # Vectorized approach to identifying the first day of leaf: argmax returns the
    # 0-based position of the first True value along time (i.e. days since
    # startdate); pixels where the threshold is never exceeded are set to NaN
    outdate = mdbool.argmax(dim="time")
    outdate = outdate.where(mdbool.sum(dim="time")>0)

    return outdate
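The three functions above can be cross-checked without Dask, xarray or dCache. Below is a plain-NumPy transcription for a single pixel, written as a sketch under the following assumptions:

* the helpers `gdh_one_pixel` and `first_leaf_one_pixel` are hypothetical;
* they reuse the same base temperature, lilac coefficients, thresholds and edge padding as the notebook;
* the input series are synthetic (a made-up warming trend), so the resulting date has no physical meaning.

```python
import numpy as np

BASE_TEMP_F = 31.0
# Lilac coefficients for mds0, synop, dde2 and dd57 (as in LEAF_INDEX_COEFFS)
COEFFS = np.array([3.306, 13.878, 0.201, 0.153])
LEAF_INDEX_LIMIT = 637
MDSUM_THRESHOLD = 999.5


def gdh_one_pixel(dayl, tmin, tmax):
    """Daily GDH from 1D arrays of daylength (hours) and temperatures (°F)."""
    hours = np.arange(24)[:, None]  # shape (24, 1), broadcasts against days
    dt = tmax - tmin
    const = np.sin(np.pi / (dayl + 4) * dayl) * dt
    day = np.sin(hours * np.pi / (dayl + 4)) * dt
    # The night-time formula yields NaN/inf for daylight hours; silence warnings
    with np.errstate(divide="ignore", invalid="ignore"):
        night = (1 - np.log(hours - np.floor(dayl)) / np.log(24 - dayl)) * const
    t = np.where(np.isfinite(night), night, day) + tmin - BASE_TEMP_F
    return np.clip(t, 0, None).sum(axis=0)


def first_leaf_one_pixel(gdh, startdate=1):
    """0-based position of the first-leaf day in `gdh`, or None if not reached."""
    ndays = gdh.size
    # Edge padding with 7 copies of the first value, like gdh.pad(time=(7, 0), mode="edge")
    padded = np.concatenate([np.repeat(gdh[:1], 7), gdh])
    csum = np.concatenate([[0.0], np.cumsum(padded)])
    pos = np.arange(7, 7 + ndays)  # positions of the real days in `padded`

    def trailing_sum(n):
        # Sum of the n padded values ending at each real day
        return csum[pos + 1] - csum[pos + 1 - n]

    dde2 = trailing_sum(3)
    dd57 = trailing_sum(8) - trailing_sum(5)
    mds0 = np.arange(startdate, startdate + ndays) - 1
    synop = np.cumsum(dde2 >= LEAF_INDEX_LIMIT)
    mdsum = COEFFS[0] * mds0 + COEFFS[1] * synop + COEFFS[2] * dde2 + COEFFS[3] * dd57
    hits = np.flatnonzero(mdsum > MDSUM_THRESHOLD)
    return int(hits[0]) if hits.size else None


# Synthetic single-pixel inputs for days 1-300 (made-up values)
days = np.arange(1, 301)
tmin = 20 + 0.25 * days          # °F
tmax = tmin + 20                 # °F
dayl = np.full(days.size, 12.0)  # hours

gdh = gdh_one_pixel(dayl, tmin, tmax)
onset = first_leaf_one_pixel(gdh)
print(onset)
```

The trailing windows use differences of cumulative sums. The sum of the n values ending at a given day equals `csum[i + 1] - csum[i + 1 - n]`, which matches the `rolling(time=n).sum()` calls on the padded series.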

### 2.4 Open the input catalog 

The input variables (minimum temperature, maximum temperature and daylength) are extracted from the Daymet catalog, which we downloaded earlier as a STAC catalog (see [this notebook](./01-download-Daymet4.ipynb)). In order to access the data, we load the catalog:

In [None]:
catalog = pystac.Catalog.from_file(catalog_urlpath)

In addition to providing links to the data, the catalog provides all the dataset's metadata. We use the metadata, for example, to convert the bounding box from latitude/longitude degrees to the dataset's coordinate reference system (CRS):

In [None]:
# Extract information about input CRS from metadata
_item = next(catalog.get_all_items())
proj_json = _item.properties["proj:projjson"]
crs_lcc = pyproj.CRS.from_json_dict(proj_json)

# Set up CRS converter
transformer = pyproj.Transformer.from_crs(
    crs_from="EPSG:4326", 
    crs_to=crs_lcc,
    always_xy=True,
)

# Calculate bbox in the dataset's CRS
bbox = transformer.transform_bounds(*bbox_latlon)
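The notebook uses `transform_bounds` rather than transforming only the two bounding-box corners, and the reason is the projection geometry. In a Lambert Conformal Conic projection the parallels are curved, so the projected extent of a lat/lon box is not set by its corners alone. `transform_bounds` samples additional points along each edge before taking the extremes.

The sketch below compares both approaches using a hand-written, Daymet-like LCC definition. The standard parallels (25° and 60°), origin (42.5°N, 100°W) and WGS84 ellipsoid are assumptions made for this example. In the notebook, the CRS is read from the catalog instead.

```python
import pyproj

# Assumed Daymet-like Lambert Conformal Conic CRS (hand-written for this sketch)
crs_lcc = pyproj.CRS.from_proj4(
    "+proj=lcc +lat_1=25 +lat_2=60 +lat_0=42.5 +lon_0=-100 "
    "+x_0=0 +y_0=0 +ellps=WGS84 +units=m +no_defs"
)
transformer = pyproj.Transformer.from_crs("EPSG:4326", crs_lcc, always_xy=True)

bbox_latlon = (-124.784, 24.743, -66.951, 49.346)
west, south, east, north = bbox_latlon

# Densified bounds: points along the edges are sampled, so curved parallels count
bbox = transformer.transform_bounds(*bbox_latlon)

# Naive bounds from the four corners only
xs, ys = transformer.transform([west, east, west, east], [south, south, north, north])
naive = (min(xs), min(ys), max(xs), max(ys))

print("densified:", bbox)
print("corners:  ", naive)
```

The southern parallel bows towards lower y values near the central meridian. The corner-only box would therefore clip part of the southern CONUS.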

### 2.5 Connect to the cluster

Once we are ready to run the calculation, we set up a Dask cluster and create a client connection. This is most easily achieved via the Dask JupyterLab extension (look for the Dask logo in the left sidebar of the JupyterLab interface).

*--DROP DASK `SLURMCluster` HERE--*

Here we create a cluster with 15 nodes. Let's wait for all workers to join the cluster:

In [None]:
client.wait_for_workers(n_workers=15)
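A cluster and client can also be created from code, for instance outside JupyterLab or to try the notebook on a single machine. The sketch below uses a `LocalCluster` with two in-process, single-threaded workers. This worker count is an arbitrary choice for illustration, not the 15-node setup used here. On a SLURM system, the `SLURMCluster` class from the dask-jobqueue package plays the same role, with resource settings that depend on the machine.

```python
from dask.distributed import Client, LocalCluster

# Two single-threaded workers running in this process (no subprocesses)
cluster = LocalCluster(n_workers=2, threads_per_worker=1, processes=False)
client = Client(cluster)

# Same call as above, adapted to the local worker count
client.wait_for_workers(n_workers=2)
n_workers = len(client.scheduler_info()["workers"])
print(n_workers)  # → 2

# Release the resources when done
client.close()
cluster.close()
```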

### 2.6 Run the model

Once the Dask cluster is reachable, we can start the computation! We define a few convenience functions to open the dataset using the Xarray library, preprocess the input variables and save the output products to the storage. Note that by setting the size of the data "chunks" when reading the data, we choose to use Dask arrays as the underlying data structure. All operations on Xarray objects are then evaluated lazily. The spring index calculation is only triggered when results are actually needed, e.g. when they are persisted, plotted or written to storage.
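A minimal illustration of this lazy behaviour with Dask alone follows; the array size and chunking are arbitrary. Arithmetic on a chunked array only extends a task graph, and the computation runs when a result is explicitly requested.

```python
import dask.array as da
import numpy as np

# A chunked Dask array: 1000 values split into 4 chunks of 250
x = da.arange(1000, chunks=250)

# Unit conversion and reduction only build a task graph; nothing runs yet
y = (x * 1.8 + 32).sum()
print(type(y).__name__, x.numblocks)  # → Array (4,)

# The work is done (chunk by chunk, possibly in parallel) only on compute()
total = y.compute()

# Same result as the eager NumPy computation
expected = (np.arange(1000) * 1.8 + 32).sum()
print(np.isclose(total, expected))  # → True
```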

We now compute the spring indices for the selected year:

In [None]:
def open_dataset(urlpaths, **kwargs):
    """
    Open the remote files as a single dataset. 
    """
    
    ofs = fsspec.open_files(urlpaths, block_size=4*2**20)
    return xr.open_mfdataset(
        [of.open() for of in ofs],
        engine="h5netcdf", 
        decode_coords="all",
        drop_variables=("lat", "lon"),
        **kwargs
    )

In [None]:
# Extract urlpaths to Daymet files from catalog
item = catalog.get_item(f"na-{year}", recursive=True)
hrefs = [
    item.assets[var].get_absolute_href() 
    for var in ("tmin", "tmax", "dayl")
]
    
# Open files as a single dataset, using a chunked Dask array
ds = open_dataset(hrefs, chunks={"time": 5, "x": 1000, "y": 1000})
# Average blocks of 20x20 pixels, reducing the grid resolution from 1 km to 20 km
ds = ds.coarsen({"x": 20, "y": 20}, boundary="trim").mean()
ds

In [None]:
# Plot a slice of the dataset
ds["tmax"].isel(time=0).plot.imshow()

In [None]:
def preprocess_dataset(ds, startdate, enddate, bbox):
    """
    Subset the input dataset and make necessary conversions.
    """
    
    # Select time range for GDH calculation
    ds = ds.isel(time=slice(startdate-1, enddate))
    
    # Spatial selection
    ds = ds.rio.clip_box(*bbox)
    
    # Convert temperatures from Celsius to Fahrenheit
    tmax = ds["tmax"] * 1.8 + 32
    tmin = ds["tmin"] * 1.8 + 32

    # Convert daylength from seconds to hours
    dayl = ds["dayl"] / 3600

    return tmax, tmin, dayl

In [None]:
# Extract temporal/spatial ranges, unit conversion
tmax, tmin, dayl = preprocess_dataset(ds, startdate, enddate, bbox)

In [None]:
# Plot same slice after pre-processing
tmax.isel(time=0).plot.imshow()

In [None]:
# Calculate GDH and rechunk to a single chunk along the time axis, so that each
# chunk holds complete time series for the operations along time that follow
gdh = calculate_gdh(dayl, tmin, tmax)
gdh = gdh.chunk({"time": enddate-startdate+1, "x": 500, "y": 500})
gdh

In [None]:
# Plot GDH vs time for the pixel at the center of the (coarsened, clipped) grid
gdh.isel(x=gdh.sizes["x"] // 2, y=gdh.sizes["y"] // 2).plot()

In [None]:
# First leaf index
dde2, dd57, mds0, synop = calculate_leaf_predictors(gdh)
first_leaf = calculate_first_leaf(dde2, dd57, mds0, synop)

In [None]:
# Trigger calculation of spring indices
first_leaf = client.persist(first_leaf)

In [None]:
# Plot the first-leaf date
first_leaf.plot.imshow(col="plant")

In [None]:
def save_to_urlpath(first_leaf, urlpath, group):
    """
    Save output to urlpath in Zarr format. 
    """
    
    fs_map = fsspec.get_mapper(urlpath)
    ds = xr.Dataset({
        "first-leaf": first_leaf,
    })
    ds.to_zarr(fs_map, mode="w", group=group)

In [None]:
# Rechunk and save to storage
save_to_urlpath(
    first_leaf.chunk({"plant": 1, "x": 1000, "y": 1000}),
    output_urlpath, 
    f"{year}",
)

While this run involved a single year for demo purposes, we have run the spring index calculation for the whole time span of the Daymet dataset for North America (42 years) in [this notebook](https://github.com/RS-DAT/JupyterDask-Examples/blob/main/03-phenology/notebooks/02-compute-spring-index.ipynb). Using 15 nodes (60 cores), the overall wall time was ~5 hours.

When done, we shut down the cluster to release resources.