# USE CASE 1: Spring Index Calculation for Puerto Rico

## Description:

In this practice use-case, you will use the tools you have learnt to calculate the spring indices for the region of Puerto Rico. As you progress through the course, you will be able to complete different sections of this practice use-case and so become familiar with the different tools. The calculations are intended to be performed using the Deployable Analysis Environment (DAT) on the infrastructure provided. This notebook will remain available to you after the summer school.

In this practice use-case, you will perform a similar calculation for Puerto Rico as what was shown for North America in the code-along session. The workflow is split in two parts:

1) Part 1 covers the downloading of the required data-set for Puerto Rico and storing it in the STAC catalog already created in the code-along session. 

2) Part 2 covers the calculation of the leaf and bloom spring indices for three different plant species in Puerto Rico.

## Notes:

This notebook will guide you through the process of setting up your data download and the corresponding calculations. Some parts of the notebook are prefilled to help you with the calculations. Other parts, left empty, are to be completed by you.

## Step 1: Downloading the required data-sets
### ***PLEASE NOTE THAT THIS FIRST PART WILL ONLY WORK AFTER YOU HAVE RUN THE 01-data-access-and-retrieval NOTEBOOK. WITHOUT THAT, PART 1 OF THIS NOTEBOOK WILL NOT WORK.***

In the code-along session, we showed you how to download the daymet4 data-set for Hawaii and to create a STAC catalog stored on dCache. This enables you to access the data-set efficiently for calculation on SURF infrastructure. 

Here you will practice by downloading another part of the data-set corresponding to Puerto Rico, namely the day length variable, and store it in the local STAC catalog created before. 

In [None]:
# Import copy_asset function from stac2dcache
# ~1 line of code
from stac2dcache.utils import copy_asset

In [None]:
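# Define and load catalog for Puerto Rico
# ~2 lines of code
import pystac  # needed here because Part 1 runs before the imports of Part 2

# Assumes the catalog file is named catalog.json, as in the dCache path used in Part 2
catalog = pystac.Catalog.from_file('./daymet-daily-v4/catalog.json')
pr = catalog.get_child("pr")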

In [None]:
# Download and copy the day length variable for Puerto Rico
# ~1 line of code
copy_asset(
    catalog=pr,
    asset_key="dayl",
    update_catalog=True,
    max_workers=2,
)

In [None]:
# Save the catalog
# ~1 line of code
catalog.save()

Confirm that the Puerto Rico data-set has been downloaded and saved locally.
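
One way to do this check is to verify that every asset href now points at an existing, non-empty local file. The sketch below assumes you have already collected the hrefs into a list, for example with `item.assets["dayl"].get_absolute_href()` for each Puerto Rico item. `missing_files` is a hypothetical helper written for this illustration; it is not part of stac2dcache or pystac.

```python
from pathlib import Path


def missing_files(hrefs):
    """Return the hrefs that do not point at an existing, non-empty local file."""
    missing = []
    for href in hrefs:
        path = Path(href)
        # A missing path or a zero-byte file both count as a failed download
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(href)
    return missing
```

An empty list returned by `missing_files(hrefs)` suggests that the download completed. Any href that is returned should be downloaded again.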

## Step 2: Calculate spring indices
### YOU CAN RUN THIS PART INDEPENDENTLY OF PART 1

In this part, you will calculate the first-leaf spring index for Puerto Rico for three plant species ('Lilac', 'Arnold Red', 'Zabeli') and then average the results over the species. First, the required libraries need to be imported.

In [None]:
#DO NOT MODIFY
import dask.array as da
import fsspec
import numpy as np
import pyproj
import pystac
import rioxarray
import stac2dcache
import xarray as xr

Now, run the cell below to load the (slightly modified) parameters and the functions for calculating growing degree hours (GDH) and the first-leaf spring index.

In [None]:
# PREDEFINED FUNCTIONS AND PARAMETERS - DO NOT MODIFY
# Parameters and Coefficients
BASE_TEMP_FAHRENHEIT = 31.

HOURS = xr.DataArray(
    data=da.arange(24), 
    dims=("hours",),
)

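# NOTE: startdate and enddate must already be defined when this cell runs
# (they are set in the "Define input parameters" cell further down).
DAYS = xr.DataArray(
    data=da.arange(startdate, enddate+1),
    dims=("time",),
)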

LEAF_INDEX_COEFFS = xr.DataArray(
    data=da.from_array(
        [
            [3.306, 13.878, 0.201, 0.153], # Coefficients for Lilac
            [4.266, 20.899, 0.000, 0.248], # Coefficients for Arnold Red
            [2.802, 21.433, 0.266, 0.000], # Coefficients for Zabeli
        ],
        chunks=(1,-1)
    ),
    dims=("plant", "variable"),
    coords={"plant": ["lilac", "arnold red", "zabelli"]}
)

LEAF_INDEX_LIMIT = 637

# Required Functions for calculations

def open_dataset(urlpaths, **kwargs):
    """
    Open the remote files as a single dataset. 
    """
    
    ofs = fsspec.open_files(urlpaths, block_size=4*2**20)
    return xr.open_mfdataset(
        [of.open() for of in ofs],
        engine="h5netcdf", 
        decode_coords="all",
        drop_variables=("lat", "lon"),
        **kwargs
    )


def calculate_gdh(dayl, tmin, tmax):
    """ 
    Calculate growing degree hours (GDH). 
    """
    
    dt = tmax - tmin
    const = np.sin(np.pi/(dayl + 4) * dayl) * dt
    
    eq1 = np.sin(HOURS * np.pi/(dayl + 4)) * dt 
    eq2 = (1 - np.log(HOURS - np.floor(dayl))/np.log(24 - dayl)) * const
    t = xr.where(~np.isfinite(eq2), eq1, eq2) + tmin - BASE_TEMP_FAHRENHEIT
    t = t.clip(min=0)
    return t.sum(dim="hours", skipna=False)


def calculate_leaf_predictors(gdh):
    """
    Calculate predictors for first leaf: DDE2, DD57, MDS0, and SYNOP.
    """
    
    # Pad GDH to solve issues with first days of the year
    gdh_padded = gdh.pad(time=(7,0), mode="edge")
    
    # Calculating dde2 - trailing 3 days GDH sum from day i-2 to i
    dde2 = gdh_padded.rolling(time=3, center=False).sum()
    dde2 = dde2.isel(time=slice(7, None))  # drop padded values 
    
    # Calculating dd57 - trailing 5-7 days GDH sum from day i-7 to i-5
    dd57 = gdh_padded.rolling(time=8, center=False).sum() \
        - gdh_padded.rolling(time=5, center=False).sum()
    dd57 = dd57.isel(time=slice(7, None))  # drop padded values
    
    # Calculating mds0
    mds0 = DAYS - 1
    
    # Calculating synop
    synflag = dde2>=LEAF_INDEX_LIMIT
    synop = synflag.cumsum(dim="time")

    return dde2, dd57, mds0, synop


def calculate_first_leaf(dde2, dd57, mds0, synop):
    """
    Calculate day of first leaf for each plant species from GDH.
    """ 
            
    # Prediction calculation for first leaf
    mdsum = LEAF_INDEX_COEFFS[:,0]*mds0 \
        + LEAF_INDEX_COEFFS[:,1]*synop \
        + LEAF_INDEX_COEFFS[:,2]*dde2 \
        + LEAF_INDEX_COEFFS[:,3]*dd57

    mdbool = mdsum>999.5  # Calculate all occurrences of first leaf

    # Vectorized approach to identifying first day of leaf
    outdate = mdbool.argmax(dim="time")
    outdate = outdate.where(mdbool.sum(dim="time")>0)
            
    # Arnold red's first leaf is one day after reaching mdsum limit
    day_shift = xr.DataArray(
        da.array([0, 1, 0]),
        dims=("plant",),
        coords={"plant": ["lilac", "arnold red", "zabelli"]}
    )
    outdate = outdate + day_shift
    return outdate


def add_mean_plant_layer(outdate):
    """
    Average the spring index date over plant species and add the mean
    as a new layer.
    """
    
    mean = outdate.mean(dim="plant", skipna=False).round()
    mean = mean.expand_dims(plant=["mean"])
    return xr.concat([outdate, mean], dim="plant")


def save_to_urlpath(first_leaf, urlpath, group):
    """
    Save output to urlpath in Zarr format. 
    """
    
    fs_map = fsspec.get_mapper(urlpath)
    ds = xr.Dataset({
        "first-leaf": first_leaf, 
    })
    ds.to_zarr(fs_map, mode="w", group=group)
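
To see what `calculate_gdh` computes, it can help to trace the same formula for a single day at a single location in plain Python. The sketch below is an illustration only: `gdh_one_day` is a hypothetical helper, and it assumes a day length strictly between 0 and 23 hours so that both logarithms are well defined. Hours up to `floor(dayl)` use the sine-shaped warming curve (`eq1`); later hours use the logarithmic cooling curve (`eq2`). This mirrors how the `xr.where` call falls back to `eq1` wherever `eq2` is not finite.

```python
import math

BASE_TEMP_F = 31.0  # same value as BASE_TEMP_FAHRENHEIT in the cell above


def gdh_one_day(dayl, tmin, tmax, base=BASE_TEMP_F):
    """Growing degree hours for one day; dayl in hours, temperatures in Fahrenheit.

    Assumes 0 < dayl < 23 so that both logarithms are well defined.
    """
    dt = tmax - tmin
    const = math.sin(math.pi / (dayl + 4) * dayl) * dt
    total = 0.0
    for hour in range(24):
        hours_after = hour - math.floor(dayl)
        if hours_after <= 0:
            # Daytime branch: sine-shaped warming (eq1 in calculate_gdh)
            excess = math.sin(hour * math.pi / (dayl + 4)) * dt
        else:
            # Night-time branch: logarithmic cooling (eq2 in calculate_gdh)
            excess = (1 - math.log(hours_after) / math.log(24 - dayl)) * const
        # Only temperatures above the base contribute to GDH
        total += max(excess + tmin - base, 0.0)
    return total
```

For a day with a constant temperature of 41 °F, every hour lies 10 °F above the base, so `gdh_one_day(12, 41, 41)` returns `240.0`. If the temperature never exceeds 31 °F, the result is `0.0`.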

#### Create and connect dask cluster

Scale your system to 4 workers* and connect the dask cluster to this notebook.

*PLEASE DO NOT SCALE BEYOND 4 WORKERS TO ENSURE FAIR DISTRIBUTION OF WORKLOAD

In [None]:
from dask.distributed import Client

# Example scheduler address only: replace it with the address of your own cluster
client = Client("tcp://10.0.0.28:46245")
client

*--DROP DASK `SLURMCluster` HERE--*

#### Load the Dataset

In [None]:
# DO NOT MODIFY
# This defines the urlpath to the dataset stored on dCache

# dCache project root path
root_urlpath = "dcache://pnfs/grid.sara.nl/data/remotesensing/disk/hdcrs"

# catalog path under root directory
catalog_urlpath = f"{root_urlpath}/daymet-daily-v4/catalog.json"

#Output path
output_urlpath = '~'

In [None]:
# Define input parameters
# You can opt to perform the calculation for:
#   1) One year only as in the example notebook
#   2) For all years in the dataset from 1980 - 2022

# Select year(s) for spring index calculation - use range() function for years
# ~1 line of code
years = range(1980,1981)

# Define the day range (up to 300 days) for calculating growing degree hours
# ~2 lines of code
startdate = 1
enddate = 300

# Load the catalog
# ~1 line of code
catalog = pystac.Catalog.from_file(catalog_urlpath)

In [None]:
# Preprocess your data
# Create here the function to preprocess the dataset which has to:
#   1) select the predefined time range
#   2) convert temperature to Fahrenheit
#   3) convert daylength to hours
# ~4-5 lines of code
def preprocess_dataset(ds, startdate, enddate):
    """
    Subset the input dataset and make necessary conversions.
    """
    
    # Select time range for GDH calculation
    ds = ds.isel(time=slice(startdate-1, enddate))
    
    # Convert temperatures to Fahrenheit
    tmax = ds["tmax"] * 1.8 + 32
    tmin = ds["tmin"] * 1.8 + 32

    # Convert daylength from seconds to hours
    dayl = ds["dayl"] / 3600

    return tmax, tmin, dayl
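
The three preprocessing steps reduce to simple arithmetic, which can be checked on plain numbers before applying it to the full dataset. The helpers below are hypothetical and exist only for this illustration. They assume, as the function above does, that temperatures arrive in degrees Celsius, day length in seconds, and that `startdate` and `enddate` are 1-based, inclusive days of the year.

```python
def day_range_to_slice(startdate, enddate):
    """Convert 1-based, inclusive day-of-year bounds to a 0-based slice."""
    return slice(startdate - 1, enddate)


def celsius_to_fahrenheit(temp_c):
    """Convert a temperature from degrees Celsius to degrees Fahrenheit."""
    return temp_c * 1.8 + 32


def seconds_to_hours(seconds):
    """Convert a duration from seconds to hours."""
    return seconds / 3600


days_of_year = list(range(1, 366))
selected = days_of_year[day_range_to_slice(1, 300)]
print(len(selected), selected[0], selected[-1])
print(celsius_to_fahrenheit(0), seconds_to_hours(43200))
```

The slice keeps exactly `enddate - startdate + 1` days, which is the length that `DAYS` must have for the later predictor calculation to line up.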

#### Calculate leaf spring index

In [None]:
# Loop through years
for year in years:
    # Extract urlpaths to Daymet files
    # ~2 lines of code
    item = catalog.get_item(f"pr-{year}", recursive=True)
    hrefs = [
        item.assets[var].get_absolute_href() 
        for var in ("tmin", "tmax", "dayl")
    ]
    
    # Open dataset using open_dataset() - think about the right chunk size
    # ~1 line of code
    ds = open_dataset(hrefs, chunks={"time": 50, "x": 100, "y": 100})
    
    # Preprocess the data using your preprocess function
    # ~1 line of code
    tmax, tmin, dayl = preprocess_dataset(ds, startdate, enddate)

    # Calculate gdh using the calculate_gdh()
    # ~1 line of code
    gdh = calculate_gdh(dayl, tmin, tmax)
    
    # Rechunk gdh to an appropriate size - what is a good size for chunking?
    # ~1 line of code
    gdh = gdh.chunk({"time": enddate-startdate+1, "x": 100, "y": 100})
    
    # Calculate leaf spring index predictors using calculate_leaf_predictors()
    # ~1 line of code
    dde2, dd57, mds0, synop = calculate_leaf_predictors(gdh)
    
    # Calculate leaf spring index using calculate_first_leaf()
    # ~1 line of code
    first_leaf = calculate_first_leaf(dde2, dd57, mds0, synop)
    
    # Calculate the average over plants using add_mean_plant_layer() - check required inputs
    # ~1 line of code
    first_leaf = add_mean_plant_layer(first_leaf)
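
When thinking about chunk sizes, a quick estimate of the memory per chunk is useful. The sketch below computes it for the chunk shapes used in the loop above. The element sizes are assumptions: 4 bytes (float32) for the input data and 8 bytes (float64) for the intermediate results. `chunk_nbytes` is a hypothetical helper. Note that `calculate_gdh` broadcasts every input chunk against the 24 values of `HOURS`, so its intermediate arrays are 24 times larger than the input chunk in element count.

```python
from math import prod


def chunk_nbytes(shape, itemsize):
    """Memory footprint in bytes of one chunk with the given shape."""
    return prod(shape) * itemsize


# Input chunk as opened in the loop: 50 days x 100 x 100 pixels (float32 assumed)
input_chunk = chunk_nbytes((50, 100, 100), itemsize=4)

# Same chunk broadcast against 24 hours inside calculate_gdh (float64 assumed)
hourly_chunk = chunk_nbytes((24, 50, 100, 100), itemsize=8)

# GDH chunk after rechunking to the full 300-day time range (float64 assumed)
gdh_chunk = chunk_nbytes((300, 100, 100), itemsize=8)

for name, size in [("input", input_chunk), ("hourly", hourly_chunk), ("gdh", gdh_chunk)]:
    print(f"{name}: {size / 1e6:.0f} MB")
```

Under these assumptions the sizes are 2 MB, 96 MB and 24 MB. Keeping the whole time range in one chunk for `gdh` means the rolling sums and the cumulative sum along `time` never have to cross chunk boundaries.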

In [None]:
# Plot the day of first-leaf across Puerto Rico for one year 
first_leaf.plot.imshow(col="plant")
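
To sanity-check values in the map, the per-pixel logic of `calculate_leaf_predictors` and `calculate_first_leaf` can be traced by hand for one GDH time series. The sketch below is a simplified pure-Python illustration for a single plant. `first_leaf_index` is a hypothetical helper, the coefficients are the lilac row of `LEAF_INDEX_COEFFS`, and the plant-specific day shift is left out. Like `argmax` in the notebook function, it returns a 0-based position along the time axis, but it returns `None` (rather than NaN) when the threshold is never reached.

```python
LILAC_COEFFS = (3.306, 13.878, 0.201, 0.153)  # lilac row of LEAF_INDEX_COEFFS
LEAF_INDEX_LIMIT = 637


def first_leaf_index(gdh, coeffs=LILAC_COEFFS, startdate=1):
    """0-based index of the first day with predictor sum above 999.5, else None."""
    c_mds0, c_synop, c_dde2, c_dd57 = coeffs
    # Edge padding with 7 copies of the first value, as in calculate_leaf_predictors
    padded = [gdh[0]] * 7 + list(gdh)
    synop = 0
    for i in range(len(gdh)):
        j = i + 7                            # position of day i in the padded series
        dde2 = sum(padded[j - 2 : j + 1])    # GDH sum over days i-2 .. i
        dd57 = sum(padded[j - 7 : j - 4])    # GDH sum over days i-7 .. i-5
        mds0 = (startdate + i) - 1           # DAYS - 1
        synop += dde2 >= LEAF_INDEX_LIMIT    # running count of high-energy days
        mdsum = c_mds0 * mds0 + c_synop * synop + c_dde2 * dde2 + c_dd57 * dd57
        if mdsum > 999.5:
            return i
    return None
```

With no accumulated heat at all (`gdh` equal to zero on all 300 days) the day-count term alone stays below 999.5, so no first-leaf day is found. This corresponds to the masked pixels produced by `outdate.where(...)` in the notebook function.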

In [None]:
# Save the output to file using save_to_urlpath()
save_to_urlpath(
    first_leaf,
    output_urlpath, 
    "PuertoRico",
)