# 2 B: Adding more dimensions, without loops

We'll continue with our heatwave example, this time expanding the analysis to the full 3D dataset.

We won't use any explicit loops. Instead we'll rely on Dask to order the array operations for us automatically.

To start, we'll load some libraries and the Dask distributed client, so that the analysis runs in parallel.

In [None]:
%matplotlib inline
import numpy
import xarray
import dask.array
from dask.diagnostics import ProgressBar

In [None]:
from dask.distributed import Client, progress
# Uncomment the next line to start a local distributed cluster
#Client()

## Adding dimensions

For the most part, working with multi-dimensional data is exactly the same as working with 1D data. 'Split' functions take the name of the dimension to work on in either case, so operations like calculating the 15-day rolling average work exactly the same way.

In [None]:
cmip_tasmax  = '/g/data/rr3/publications/CMIP5/output1/CSIRO-BOM/ACCESS1-3/historical/day/atmos/day/r1i1p1/latest/tasmax/tasmax*.nc'

ds = xarray.open_mfdataset(cmip_tasmax, chunks={'lat': 50, 'lon': 50, 'time': 1000})
tasmax = ds.tasmax

In [None]:
tasmax_15day = tasmax.rolling(time=15, center=True).mean()
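To see that the rolling call really is dimension-agnostic, here is a minimal sketch on a small synthetic array. The array `toy`, its grid and its random values are made up for illustration and stand in for `tasmax`.

```python
import numpy
import pandas
import xarray

# Hypothetical stand-in for tasmax: 30 days on a 2 x 3 grid of random values
time = pandas.date_range('2000-01-01', periods=30, freq='D')
data = numpy.random.default_rng(0).normal(size=(30, 2, 3))
toy = xarray.DataArray(
    data,
    coords={'time': time, 'lat': [-10.0, 0.0], 'lon': [100.0, 110.0, 120.0]},
    dims=['time', 'lat', 'lon'],
)

# Rolling mean over the whole 3D array, naming only the time dimension
toy_3d = toy.rolling(time=15, center=True).mean()

# The same call on the 1D series at a single grid point
toy_1d = toy.isel(lat=0, lon=0).rolling(time=15, center=True).mean()

# The 3D result at that grid point should match the 1D result
same = numpy.allclose(
    toy_3d.isel(lat=0, lon=0).values, toy_1d.values, equal_nan=True
)
print(same)
```

The call names only the `time` dimension, so the other dimensions are carried along untouched.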

## Sorting - avoid it

There are some types of operations that you really need to avoid when working with large datasets.

Calculating things like min/max, mean and standard deviation is fine, because they can be calculated in a 'rolling' manner, one chunk at a time, so there's no need to load the entire dataset to do the calculation.
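As a rough illustration of why these statistics are cheap, the sketch below accumulates a mean and standard deviation one chunk at a time with plain numpy. The helper `streaming_mean_std` is a hypothetical name, not what Dask does internally, and it uses an explicit loop only to show the idea. Dask handles this kind of bookkeeping for you.

```python
import numpy

def streaming_mean_std(chunks):
    """Accumulate the count, sum and sum of squares one chunk at a time."""
    n = 0
    total = 0.0
    total_sq = 0.0
    for chunk in chunks:
        # Only the current chunk needs to be in memory
        n += chunk.size
        total += chunk.sum()
        total_sq += (chunk ** 2).sum()
    mean = total / n
    # Population standard deviation from the accumulated moments
    std = numpy.sqrt(total_sq / n - mean ** 2)
    return mean, std

# Synthetic 'temperatures', split into 10 chunks
values = numpy.random.default_rng(0).normal(loc=25.0, scale=5.0, size=10_000)
mean, std = streaming_mean_std(numpy.array_split(values, 10))

print(numpy.isclose(mean, values.mean()), numpy.isclose(std, values.std()))
```

The sum-of-squares form is the simplest to read, but it can lose precision when the mean is large compared with the spread. It is a sketch of the principle, not a recommended implementation.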

Calculating a percentile, however, requires sorting the data, which is one of the worst things you can do with a large dataset. Sorting requires loading the entire dataset into memory.

In this case we could assume that the temperatures are normally distributed, and use that distribution to get the percentiles. This assumption might not always be valid, so you should check for yourself that it holds for your data.

**Make an estimate of the 90th percentile at each day of the year at each grid point, assuming temperatures are normally distributed**

[`scipy.stats.norm.ppf(q, loc=mean, scale=std)`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html) could be useful

<a href="#ans1" data-toggle="collapse">Answer</a>
<div class="collapse" id="ans1">
<pre><code>
from scipy.stats import norm

# Show a fancy progress bar
with ProgressBar():
    tasmax_mean = tasmax_15day.groupby('time.dayofyear').mean('time')
    tasmax_std = tasmax_15day.groupby('time.dayofyear').std('time')
    
    # The ppf function isn't dask aware, so it will load the mean and stddev data
    threshold = xarray.DataArray(norm.ppf(0.9, loc=tasmax_mean, scale=tasmax_std), coords=tasmax_mean.coords, dims=tasmax_mean.dims)
</code></pre>
</div>

**How does our estimate of the 90th percentile compare with the actual value? (Look at a single grid point.)**

<a href="#ans2" data-toggle="collapse">Answer</a>
<div class="collapse" id="ans2">
<pre><code>
with ProgressBar():
    threshold.sel(lat=-37.8136, lon=144.9631, method='nearest').plot()
    tasmax_15day.sel(lat=-37.8136, lon=144.9631, method='nearest').load().groupby('time.dayofyear').reduce(numpy.nanpercentile, q=90, dim='time').plot()
</code></pre>
</div>
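The normal-distribution shortcut can also be checked away from the real data. The sketch below compares the estimated and sorted 90th percentiles for two synthetic samples, one drawn from a normal distribution and one from a skewed (gamma) distribution. The sample sizes and distribution parameters are arbitrary choices for illustration.

```python
import numpy
from scipy.stats import norm

rng = numpy.random.default_rng(0)

# Case 1: data that really is normally distributed
sample = rng.normal(loc=30.0, scale=4.0, size=100_000)
estimate = norm.ppf(0.9, loc=sample.mean(), scale=sample.std())
actual = numpy.percentile(sample, 90)

# Case 2: skewed data, where the normal assumption is only approximate
skewed = rng.gamma(shape=2.0, scale=2.0, size=100_000)
skewed_estimate = norm.ppf(0.9, loc=skewed.mean(), scale=skewed.std())
skewed_actual = numpy.percentile(skewed, 90)

print('normal sample:', estimate, actual)
print('skewed sample:', skewed_estimate, skewed_actual)
```

For the normal sample the two numbers should agree closely. For the skewed sample the gap is usually larger, which is the kind of mismatch to look for in your own data.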


## Lazy functions

To get our filter function to work nicely with dask we have to make a couple of changes. Where we call numpy functions directly, like `numpy.logical_and()`, these need to be replaced with the dask version, `dask.array.logical_and()` in this case.

You'll also need to add an `allow_lazy=True` flag when you call `.reduce()`, to let Xarray know that the filter is a dask-aware function.
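As a small sketch of what 'lazy' means here, the dask versions of these functions return another dask array straight away, and nothing is calculated until `.compute()` is called. The four-element input is made up for illustration.

```python
import numpy
import dask.array

# Tiny made-up series, split into two chunks
values = numpy.array([numpy.nan, 1.0, 2.0, numpy.nan])
x = dask.array.from_array(values, chunks=2)

# Each call only adds tasks to the graph, no data is touched yet
before_is_nan = dask.array.isnan(x[:-1])
this_is_finite = dask.array.isfinite(x[1:])
starts = dask.array.logical_and(before_is_nan, this_is_finite)

# The result is still a dask array rather than a numpy array
is_lazy = isinstance(starts, dask.array.Array)

# Only now is the calculation actually run
result = starts.compute()
print(is_lazy, result)
```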


### Remember - if we've done things right, operations on dask arrays should return immediately!

In [None]:
def heatwave_start_filter(x, axis):
    """
    Returns 1 if a heatwave starts at this time, otherwise nan
    
    Matches the pattern [*, nan, finite, finite, finite] on the rolling dimension,
    where nan marks a time that was masked out and finite a time that was kept
    
    Should be called with x.rolling(time=5, center=True).reduce(heatwave_start_filter)
    """

    assert axis % x.ndim == x.ndim - 1
    assert x.shape[axis] == 5

    left  = dask.array.isnan(x[..., 1]) # Time before this one

    right = dask.array.isfinite(x[..., 2:]).all(axis=axis) # This time and two after
    
    test  = dask.array.logical_and(left, right)
    return dask.array.where(test, 1.0, numpy.nan)
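To make the window logic concrete, here is a numpy-only sketch of the same test on a short made-up series. It assumes the rolling windows are centred and padded with NaN at the edges, and it builds them by hand with `sliding_window_view`. This is a stand-in for what `.rolling(...).reduce()` passes to the filter, not xarray's actual implementation.

```python
import numpy
from numpy.lib.stride_tricks import sliding_window_view

nan = numpy.nan
# Made-up series: finite values are the times kept, nan the times masked out
series = numpy.array([nan, nan, 1.0, 1.0, 1.0, nan, 1.0, 1.0, nan, nan])

# Build centred 5-wide windows by hand, padding the edges with nan (assumption)
padded = numpy.pad(series, 2, constant_values=nan)
windows = sliding_window_view(padded, 5)  # shape (10, 5), window axis last

# Same logic as the filter, with numpy in place of dask.array
left = numpy.isnan(windows[..., 1])                     # time before this one
right = numpy.isfinite(windows[..., 2:]).all(axis=-1)   # this time and two after
starts = numpy.where(numpy.logical_and(left, right), 1.0, nan)

start_indices = numpy.flatnonzero(numpy.isfinite(starts)).tolist()
print(start_indices)
```

Only the run of three finite values counts as a start. The later run of two is too short to match the pattern.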

In [None]:
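# Keep only the times above the threshold for their day of year, the rest become nan
candidates = tasmax.where(tasmax.groupby('time.dayofyear') > threshold)
with ProgressBar():
    hw_starts = candidates.rolling(time=5, center=True, min_periods=1).reduce(heatwave_start_filter, allow_lazy=True)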

## Optimisation tips

 * Narrow down your region with `.sel()` or `.isel()` before you use a split function
 * After setting up climatologies, call `.load()` or save them to a file so they are pre-calculated
 * Use smaller chunk sizes to reduce what dask loads
 
http://xarray.pydata.org/en/stable/dask.html#optimization-tips

For instance, put the analysis into a function, then pass in a selected subset of the data as the input.

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

with ProgressBar():
    ax = plt.axes(projection=ccrs.Orthographic(central_longitude=140))
    hw_starts.sel(time='1998').count(dim='time').plot(ax=ax, transform=ccrs.PlateCarree())
    ax.coastlines()

In [None]:
def find_hw_starts(t):
    candidates = t.where(t.groupby('time.dayofyear') > threshold)
    return candidates.rolling(time=5, center=True, min_periods=1).reduce(heatwave_start_filter, allow_lazy=True)

In [None]:
with ProgressBar():
    ax = plt.axes(projection=ccrs.Orthographic(central_longitude=140))
    find_hw_starts(tasmax.sel(time=slice('1997','1999'))).count(dim='time').plot(ax=ax, transform=ccrs.PlateCarree())
    ax.coastlines()
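The effect of narrowing down a region first can be sketched with a bare dask array. The array shape, chunk sizes and index ranges below are made up for illustration. The point is only that selecting a region reduces the number of chunks dask has to read and process.

```python
import dask.array

# Hypothetical (time, lat, lon) array, chunked like the dataset above
full = dask.array.ones((3000, 145, 192), chunks=(1000, 50, 50))

# Select a made-up sub-region before doing any further work
region = full[:, 20:60, 100:140]

# Number of chunks dask would have to process in each case
print(full.npartitions, region.npartitions)
```

Any operation applied to `region` only builds tasks for the chunks that overlap the selection, so the task graph is smaller and less data is loaded.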