# Thinking like Xarray: High-level computational patterns

**Deepak Cherian, CGD**

March 9, 2022

-----

This version includes additional material inserted during the live lecture.

## Motivation / Learning goals

From https://toolz.readthedocs.io/en/latest/control.html
> The Toolz library contains dozens of patterns like map and groupby. Learning a core set (maybe a dozen) covers the vast majority of common programming tasks often done by hand.
> A rich vocabulary of core control functions conveys the following benefits:
>    - You identify new patterns
>    - You make fewer errors in rote coding
>    - You can depend on well tested and benchmarked implementations


The same is true for xarray.

## Xarray's high-level patterns


Xarray allows you to leverage dataset metadata to write more readable analysis code. The metadata is stored with the data, not in your head.
1. Dimension names: `dim="latitude"` instead of `axis=0`
2. Coordinate "labels" (think axis tick labels): `data.sel(latitude=45)` instead of `data[10]`


Xarray also provides high-level computational patterns that cover many data analysis tasks.

1. `rolling` :
   [Operate on rolling windows of your data e.g. running mean](https://xarray.pydata.org/en/stable/computation.html#rolling-window-operations)
1. `coarsen` :
   [Downsample your data](https://xarray.pydata.org/en/stable/computation.html#coarsen-large-arrays)
1. `groupby` :
   [Bin data into groups and reduce](https://xarray.pydata.org/en/stable/groupby.html)
1. `groupby_bins`: GroupBy after discretizing a numeric variable.
1. `resample` :
   [Groupby specialized for time axes. Either downsample or upsample your data.](https://xarray.pydata.org/en/stable/time-series.html#resampling-and-grouped-operations)
1. `weighted` :
   [Weight your data before reducing](https://xarray.pydata.org/en/stable/computation.html#weighted-array-reductions)
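As a quick illustration of `weighted`, here is a minimal sketch of an area-weighted mean. It assumes weights proportional to cos(latitude), a common choice for regular latitude-longitude grids; the two-point dataset is made up for the example.

```python
import numpy as np
import xarray as xr

# hypothetical toy data: temperatures at two latitudes
temp = xr.DataArray([300.0, 270.0], dims="lat", coords={"lat": [0.0, 60.0]})

# assumed area weights: proportional to cos(latitude); weights must not contain NaN
weights = np.cos(np.deg2rad(temp.lat))

area_mean = temp.weighted(weights).mean("lat")  # (300 * 1 + 270 * 0.5) / 1.5
plain_mean = temp.mean("lat")                   # (300 + 270) / 2

print(float(area_mean), float(plain_mean))  # roughly 290 vs 285
```

The weighted mean leans toward the low-latitude value because that point carries twice the weight.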


## Load example dataset

In [None]:
import numpy as np
import xarray as xr

xr.set_options(keep_attrs=True, display_expand_data=False)

da = xr.tutorial.load_dataset("air_temperature", engine="netcdf4").air
monthly = da.resample(time="M").mean()
data = da.isel(time=0)
data.plot()

-----

## Concept: "index space" vs "label space"


Windowed operations with a window of a fixed size (a fixed number of points) work in index space:

- ``rolling``: sliding window operations e.g. running mean
- ``coarsen``: decimating; reshaping

In [None]:
data

In [None]:
# index space
data[10, :]  # element at index 10 along the first axis; ¯\_(ツ)_/¯

In [None]:
# slightly better index space
data.isel(lat=10)  # element at index 10 along the "lat" dimension

In [None]:
# "label" space
data.sel(lat=50)  # much better! lat=50°N

In [None]:
# What I wanted to do
data.sel(lat=50)

# What I had to do (if I wasn't using xarray)
data[10, :]

-----

## Xarray provides high-level patterns in both "index space" and "label space"

#### Index space
1. `rolling` :
   [Operate on rolling windows of your data e.g. running mean](https://xarray.pydata.org/en/stable/computation.html#rolling-window-operations)
1. `coarsen` :
   [Downsample your data](https://xarray.pydata.org/en/stable/computation.html#coarsen-large-arrays)
   
#### Label space
1. `groupby` :
   [Bin data into groups and reduce](https://xarray.pydata.org/en/stable/groupby.html)
1. `groupby_bins`: GroupBy after discretizing a numeric variable.
1. `resample` :
   [Groupby specialized for time axes. Either downsample or upsample your data.](https://xarray.pydata.org/en/stable/time-series.html#resampling-and-grouped-operations)

----- 

## Index space: windows of fixed width


### Sliding windows of fixed length: ``rolling``

- returns object of same shape as input
- pads with NaNs to make this happen
- supports multiple dimensions
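A minimal 1D sketch of the first two points, on a made-up array: the centered running mean keeps the input length and is NaN wherever the window is incomplete. The `min_periods` argument is one way to relax that at the edges.

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(6.0), dims="time")

# centered 3-point running mean: same length as x, NaN where the window is incomplete
smoothed = x.rolling(time=3, center=True).mean()
print(smoothed.values)  # values: nan, 1, 2, 3, 4, nan

# min_periods lets the edge windows use however many valid points they have
partial = x.rolling(time=3, center=True, min_periods=1).mean()
print(partial.values)  # values: 0.5, 1, 2, 3, 4, 4.5
```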

Here's the dataset

In [None]:
data.plot()

And now smoothed with a 5-point running mean in lat and lon:

In [None]:
data.rolling(lat=5, lon=5, center=True).mean().plot()

#### Apply an existing numpy-only function with ``reduce``

Tip: The `reduce` method expects a function that can receive and return plain arrays (e.g. numpy).  The `map` method expects a function that can receive and return Xarray objects.

Here's an example function: `np.mean`

**Exercise** Calculate the 5-point rolling mean along both latitude and longitude using [`rolling(...).reduce`](https://docs.xarray.dev/en/stable/generated/xarray.core.rolling.DataArrayRolling.reduce.html)

**Answer:**

In [None]:
# equivalent to data.rolling(lat=5, lon=5, center=True).mean()
data.rolling(lat=5, lon=5, center=True).reduce(np.mean).plot()
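Any numpy-style function works with `reduce`, as long as it accepts an `axis` argument: xarray calls it on the windowed numpy data together with the axis (or axes) to collapse. A minimal sketch, where `value_range` is a hypothetical helper that computes the same thing as `np.ptp`:

```python
import numpy as np
import xarray as xr

def value_range(arr, axis=None):
    """Hypothetical helper: max minus min, written for plain numpy arrays."""
    # rolling(...).reduce passes the windowed numpy data plus the axis to collapse
    return np.max(arr, axis=axis) - np.min(arr, axis=axis)

x = xr.DataArray(np.array([0.0, 1.0, 4.0, 9.0, 16.0]), dims="time")

# centered 3-point windows; edge windows are incomplete and come back as NaN
ranges = x.rolling(time=3, center=True).reduce(value_range)
print(ranges.values)  # values: nan, 4, 8, 12, nan
```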

#### For more complicated analysis, construct a new array with a new dimension.
This allows things like short-time Fourier transforms, spectrograms, windowed rolling operations, etc.

In [None]:
simple = xr.DataArray(np.arange(10), dims="time", coords={"time": np.arange(10)})
simple

In [None]:
# adds a new dimension "window"
simple.rolling(time=5, center=True).construct("window")

**Exercise** Calculate the 5 point running mean in time using `rolling.construct`

**Answer**

In [None]:
(
    simple
    .rolling(time=5, center=True)
    .construct("window")
    .mean("window")
)
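One subtlety with this answer: `construct` fills the missing edge values with NaN, and `.mean()` skips NaN by default, so the edge points become averages of partial windows. Passing `skipna=False` reproduces `rolling(...).mean()`, which is NaN wherever a window is incomplete. A small sketch comparing the two:

```python
import numpy as np
import xarray as xr

simple = xr.DataArray(np.arange(10), dims="time", coords={"time": np.arange(10)})
windows = simple.rolling(time=5, center=True).construct("window")

# default skipna: edge points average whatever part of the window exists
skipping = windows.mean("window")
print(skipping.values)  # values: 1, 1.5, 2, 3, 4, 5, 6, 7, 7.5, 8

# skipna=False matches rolling(...).mean(): NaN where the window is incomplete
strict = windows.mean("window", skipna=False)
direct = simple.rolling(time=5, center=True).mean()
```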

``construct`` is clever:
1. It constructs a **view** of the original array, so it is memory-efficient, but you didn't have to know that.
1. It does something sensible for dask arrays (though generally you want big chunk sizes along the dimension you're sliding along).
1. It also works with rolling along multiple dimensions!
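To make the short-time Fourier transform idea concrete, here is a minimal spectrogram sketch on a made-up signal. It assumes 8-sample windows with 50% overlap (via the `stride` argument of `construct`), no tapering, and simply drops the windows that overlap the NaN padding at the start.

```python
import numpy as np
import xarray as xr

# hypothetical signal: a pure tone with a period of 4 samples
sig = xr.DataArray(np.sin(2 * np.pi * np.arange(64) / 4), dims="time")

windows = (
    sig.rolling(time=8)
    .construct("window", stride=4)  # keep every 4th window: 50% overlap
    .dropna("time")                 # drop windows touching the NaN padding
    .transpose("time", "window")
)

# power spectrum of each window (no taper, for brevity)
power = np.abs(np.fft.rfft(windows.values, axis=-1)) ** 2

# the tone lands in frequency bin 8 / 4 = 2 of every window
peak_bin = power.argmax(axis=-1)
```

For real data you would normally apply a taper (for example a Hann window) before the FFT.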

#### Advanced: Another ``construct`` example

This is a 2D rolling example; we need to provide two new dimension names

In [None]:
(
    data
    .rolling(lat=5, lon=5, center=True)
    .construct(lat="lat_roll", lon="lon_roll")
)

-----

### Block windows of fixed length: ``coarsen``

For non-overlapping windows or "blocks" use ``coarsen``. The syntax is very similar to `rolling`. You will need to specify ``boundary`` if the length of the dimension is not a multiple of the block size.
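A minimal sketch of the `boundary` options on a made-up 7-point array with 3-point blocks: the default `"exact"` raises an error, `"trim"` drops the leftover point, and `"pad"` fills the last block with NaN.

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(7.0), dims="time")  # 7 is not a multiple of 3

try:
    x.coarsen(time=3).mean()
    exact_failed = False
except ValueError:
    exact_failed = True  # default boundary="exact" refuses ragged blocks

trimmed = x.coarsen(time=3, boundary="trim").mean()  # blocks [0, 1, 2] and [3, 4, 5]
padded = x.coarsen(time=3, boundary="pad").mean()    # a third block holding 6 plus NaN padding
print(trimmed.values, padded.sizes["time"])
```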

In [None]:
data

In [None]:
data.plot()

In [None]:
data.coarsen(lat=5, lon=5, boundary="trim").std()

In [None]:
(
    data
    .coarsen(lat=5, lon=5, boundary="trim")
    .mean()
    .plot()
)

#### coarsen supports ``reduce`` for custom reductions

**Exercise** Use ``coarsen.reduce`` to apply `np.mean` in 5x5 (lat x lon) point blocks of `data`

**Answer**

In [None]:
(
    data.coarsen(lat=5, lon=5, boundary="trim")
    .reduce(np.mean)
    .plot()
)

#### coarsen supports ``construct`` for block reshaping

This is usually a good alternative to `np.reshape`.

A simple example splits a 2-year-long monthly 1D time series into a 2D array shaped (year x month).

In [None]:
months = xr.DataArray(np.tile(np.arange(1, 13), reps=2), dims="time", coords={"time": np.arange(1,25)})
months

In [None]:
# break "time" into two new dimensions: "year", "month"
months.coarsen(time=12).construct(time=("year", "month"))

Note two things:
1. The `time` coordinate was also reshaped, into a 2D (year x month) array.
1. The new dimensions `year` and `month` don't have any coordinate labels associated with them.
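For comparison, the same reshape can be written with plain numpy; the values come out in the same (year x month) layout, but the dimension names and the reshaped `time` coordinate are lost. A small sketch:

```python
import numpy as np
import xarray as xr

months = xr.DataArray(
    np.tile(np.arange(1, 13), reps=2), dims="time", coords={"time": np.arange(1, 25)}
)
year_month = months.coarsen(time=12).construct(time=("year", "month"))

# plain numpy: consecutive blocks of 12 become rows, but there are no names or coordinates
plain = months.values.reshape(2, 12)
```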


What if the data had, say, 23 instead of 24 values? In that case we specify a different `boundary`; here we pad to 24 values.

In [None]:
months.isel(time=slice(1, None)).coarsen(time=12, boundary="pad").construct(time=("year", "month"))

This pads at the end of the array, which is not sensible here because the missing value is the first month (January). To pad at the beginning, we use [DataArray.pad](https://xarray.pydata.org/en/stable/generated/xarray.DataArray.pad.html).

In [None]:
(
    months.isel(time=slice(1, None))
    .pad(time=(1, 0), constant_values=-1)
    .coarsen(time=12)
    .construct(time=("year", "month"))
)

**Exercise** Reshape the `time` dimension of the DataArray `monthly` to year x month and visualize the seasonal cycle for two years at 250°E

**Answer**

In [None]:
# splits time dimension into year x month
year_month = monthly.coarsen(time=12).construct(time=("year", "month"))

# assign a nice coordinate value for month
year_month["month"] = ["jan", "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct", "nov", "dec"]

# assign a nice coordinate value for year
year_month["year"] = [2013, 2014]

# seasonal cycle for two years
year_month.sel(lon=250).plot.contourf(col="year", x="month", y="lat")

This exercise came up during the live lecture. 

**Exercise** Calculate the rolling 4-month average, averaged across years.

**Answer** 
1. We first reshape using `coarsen.construct` to add `year` as a new dimension.
2. Then `rolling` on the month dimension.
3. It turns out that `roll.mean(["year", "month"])` doesn't work, so we use `roll.construct` to get a DataArray with a new dimension `window` and then take the mean over `window` and `year`.

In [None]:
reshaped = months.coarsen(time=12).construct(time=("year", "month"))
roll = reshaped.rolling(month=4, center=True)
roll.construct("window").mean(["window", "year"])

### Index space summary

Use `rolling` and `coarsen` for fixed size windowing operations.
1. `rolling` for overlapping windows
1. `coarsen` for non-overlapping windows.

Both provide the usual reductions as methods (`.mean()` and friends), and also `reduce` and `construct` for custom operations.

-----

## Label space "windows" or bins: GroupBy

GroupBy generalizes ``coarsen``: sometimes the windows you want are not regular.

- ``groupby``: e.g. climatologies, composites; works when "groups" are exact values such as strings or integers, not floats
- ``groupby_bins``: binning operations e.g. histograms
- ``resample``: groupby but specialized for time grouping (so far)

**tip** Both `groupby_bins` and `resample` are implemented as `GroupBy` with a specific way of constructing group labels.


### Deconstructing GroupBy


Commonly called "split-apply-combine". 

1. "split" : break dataset into groups
1. "apply" : apply an operation, usually a reduction like `mean`
1. "combine" : concatenate results from apply step along new "group" dimension

But really there is a first step: "identifying groups", also called "factorization" (or "binning"). Usually this is the hard part.

So "identify groups" → "split into groups" → "apply function" → "combine results".
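The four steps can be written out by hand for a toy example, assuming `pandas.factorize` for the "identify groups" step (the data and labels below are made up). Xarray's `groupby` gives the same answer in one line:

```python
import numpy as np
import pandas as pd
import xarray as xr

# hypothetical toy data: 8 values tagged with a string group label
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
labels = np.array(["a", "b", "a", "c", "b", "a", "c", "c"])

# 1. identify groups: integer codes plus the unique labels
codes, uniques = pd.factorize(labels)
# 2. split and 3. apply: one reduction per group
means = [values[codes == i].mean() for i in range(len(uniques))]
# 4. combine: stack the results along a new "group" dimension
manual = xr.DataArray(means, dims="group", coords={"group": uniques})

# the same computation with xarray
da = xr.DataArray(values, dims="x", coords={"group": ("x", labels)})
auto = da.groupby("group").mean()
```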




In [None]:
da.groupby("time.month")

In [None]:
da.groupby("time.month").mean()

This is how xarray identifies "groups" for the monthly climatology computation

In [None]:
da.time.dt.month.plot()

Similarly for binning, 

In [None]:
data.groupby_bins("lat", bins=[20, 35, 40, 45, 50])

and resampling...

In [None]:
da.resample(time="M")
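To see that `resample` is a specialized GroupBy, here is a small sketch on made-up 6-hourly data: daily resampling gives the same result as grouping by timestamps floored to the day. The name `day` is just a label chosen for the grouping variable.

```python
import numpy as np
import pandas as pd
import xarray as xr

# hypothetical 6-hourly series covering two days
time = pd.date_range("2013-01-01", periods=8, freq=pd.Timedelta(hours=6))
x = xr.DataArray(np.arange(8.0), dims="time", coords={"time": time})

daily = x.resample(time="D").mean()

# the same grouping, spelled as a groupby over timestamps floored to the day
day = x.time.dt.floor("D").rename("day")
by_day = x.groupby(day).mean()
```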

### Constructing group labels

Xarray uses `pandas.factorize` for `groupby` and `pandas.cut` for `groupby_bins`. 

If the automatic group detection doesn't work for your problem, these functions are useful for constructing "group labels":

1. [numpy.digitize](https://numpy.org/doc/stable/reference/generated/numpy.digitize.html) (binning)
1. [numpy.searchsorted](https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html) supports many other data types
1. [pandas.factorize](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.factorize.html) supports characters, strings etc.
1. [pandas.cut](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.cut.html) for binning
1. [DataArray.isin](https://xarray.pydata.org/en/stable/generated/xarray.DataArray.isin.html)
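For example, `numpy.digitize` can turn a continuous variable into integer class labels that `groupby` then uses directly. A minimal sketch; the temperatures and the class edges are made up for illustration:

```python
import numpy as np
import xarray as xr

temps = xr.DataArray([261.0, 275.0, 280.0, 290.0, 301.0, 268.0], dims="time")

# np.digitize (default right=False): 0 below 270, 1 for [270, 285), 2 for >= 285
edges = [270.0, 285.0]
temp_class = xr.DataArray(np.digitize(temps.values, edges), dims="time", name="temp_class")

# the named label array becomes the new "temp_class" dimension of the result
class_means = temps.groupby(temp_class).mean()
```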


#### More commonly useful are ["datetime components"](https://xarray.pydata.org/en/stable/user-guide/time-series.html#datetime-components)

See a full list [here](https://xarray.pydata.org/en/stable/generated/xarray.core.accessor_dt.DatetimeAccessor.html?highlight=DatetimeAccessor)

Accessed using ``DataArray.dt.*``

In [None]:
da.time

In [None]:
da.time.dt.day

In [None]:
da["time.day"]

In [None]:
da.time.dt.season

**Demo** Grouping over a custom definition of seasons using numpy.isin.

We want to group over 4 seasons: `DJF`, `MAM`, `JJAS`, `ON` - this makes physical sense in the Indian Ocean basin

Start by extracting months.

In [None]:
month = da.time.dt.month.data
month

Create a new array of 4-character strings (just wide enough for "JJAS") to hold the season labels:

In [None]:
season = np.full(month.shape, "    ")
season

Use `np.isin` to assign custom seasons:

In [None]:
season[np.isin(month, [12, 1, 2])] = "DJF"
season[np.isin(month, [3, 4, 5])] = "MAM"
season[np.isin(month, [6, 7, 8, 9])] = "JJAS"
season[np.isin(month, [10, 11])] = "ON"
season = da.time.copy(data=season)
season

In [None]:
(
    # Calculate climatology
    da.groupby(season).mean()
    # reindex to get seasons in logical order (not alphabetical order)
    .reindex(time=["DJF", "MAM", "JJAS", "ON"])
    .plot(col="time")
)

#### `floor`, `ceil` and `round` time

This is basically "resampling" the timestamps themselves.

In [None]:
da.time

In [None]:
# remove roundoff error in timestamps
# floor to daily frequency
da.time.dt.floor("D")
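A small sketch contrasting `floor`, `round` and `ceil` on two made-up timestamps that are slightly off the hour, as roundoff error might produce:

```python
import numpy as np
import pandas as pd
import xarray as xr

# hypothetical timestamps slightly off the hour
time = xr.DataArray(
    pd.to_datetime(["2013-01-01 05:59:59.999", "2013-01-01 18:00:00.001"]), dims="time"
)

floored = time.dt.floor("D")  # both become 2013-01-01
rounded = time.dt.round("D")  # 2013-01-01 and 2013-01-02
ceiled = time.dt.ceil("D")    # both become 2013-01-02
```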

#### `strftime` can be extremely useful

So useful and so unintuitive that it has its own website: https://strftime.org/

For a daily climatology, grouping by `dayofyear` merges Feb 29 of leap years with Mar 1 of other years (both are day 60). Grouping by a "%b-%d" string label avoids this:

In [None]:
da.time.dt.strftime("%b-%d")
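A minimal check of the collision, using three hand-picked dates (2012 is a leap year, 2013 is not). The `%b` month abbreviation assumes an English or C locale.

```python
import pandas as pd
import xarray as xr

time = xr.DataArray(pd.to_datetime(["2012-02-29", "2012-03-01", "2013-03-01"]), dims="time")

doy = time.dt.dayofyear            # 60, 61, 60: Feb 29 (2012) collides with Mar 1 (2013)
label = time.dt.strftime("%b-%d")  # "Feb-29", "Mar-01", "Mar-01": no collision
```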

### groupby supports `reduce` for custom reductions

This also applies to `groupby_bins` and `resample`.

In [None]:
(
    da.groupby("time.month")
    .reduce(np.mean)
    .plot(col="month", col_wrap=4)
)

**tip** `map` is for functions that expect and return xarray objects (see also ``Dataset.map``). `reduce` is for functions that expect and return plain arrays (like numpy or scipy functions)
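A small sketch of the difference, computing an interquartile range both ways on made-up data. `iqr_numpy` and `iqr_xarray` are hypothetical helpers: the first receives a numpy array plus an `axis`, the second receives an xarray object and can use dimension names.

```python
import numpy as np
import xarray as xr

values = xr.DataArray(
    [1.0, 2.0, 3.0, 4.0, 5.0, 7.0, 11.0, 13.0],
    dims="x",
    coords={"group": ("x", ["a"] * 4 + ["b"] * 4)},
)

def iqr_numpy(arr, axis=None):
    # for .reduce: a plain numpy array and the axis to collapse
    return np.percentile(arr, 75, axis=axis) - np.percentile(arr, 25, axis=axis)

def iqr_xarray(obj, dim):
    # for .map: an xarray object in, an xarray object out
    return obj.quantile(0.75, dim=dim) - obj.quantile(0.25, dim=dim)

via_reduce = values.groupby("group").reduce(iqr_numpy, dim="x")
via_map = values.groupby("group").map(iqr_xarray, dim="x")
```

Both use linear interpolation between data points by default, so the two results agree.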

### GroupBy does not provide construct

This is because the groups need not all be the same "length" (e.g. months can have 28, 29, 30, or 31 days).
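A quick check of the ragged group sizes, using a made-up daily series for 2013 (not a leap year):

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2013-01-01", "2013-12-31", freq="D")
daily = xr.DataArray(np.ones(time.size), dims="time", coords={"time": time})

# number of days in each month-group: the groups cannot be stacked into a regular array
days_per_month = daily.groupby("time.month").count()
```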

### Instead, looping over groupby objects is possible

Maybe you want to plot data in each group separately?

In [None]:
for label, group in da.groupby("time.month"):
    print(label)

After the loop, `group` holds the last group: a DataArray containing data for all December days.

In [None]:
group

Maybe you want a histogram of December temperatures?

In [None]:
group.plot.hist()

### In most cases, avoid a for loop using ``map``

Apply functions that expect xarray Datasets or DataArrays.

This avoids having to manually combine results using `concat`.

In [None]:
def iqr(da, dim):
    """ Calculates interquartile range """
    return (da.quantile(q=0.75, dim=dim) - da.quantile(q=0.25, dim=dim)).rename("iqr")


da.groupby("time.month").map(iqr, dim="time")

---

## Summary

Xarray provides methods for high-level analysis patterns:
1. `rolling` :
   [Operate on rolling windows of your data e.g. running mean](https://xarray.pydata.org/en/stable/computation.html#rolling-window-operations)
1. `coarsen` :
   [Downsample your data](https://xarray.pydata.org/en/stable/computation.html#coarsen-large-arrays)
1. `groupby` :
   [Bin data into groups and reduce](https://xarray.pydata.org/en/stable/groupby.html)
1. `groupby_bins`: GroupBy after discretizing a numeric variable.
1. `resample` :
   [Groupby specialized for time axes. Either downsample or upsample your data.](https://xarray.pydata.org/en/stable/time-series.html#resampling-and-grouped-operations)
1. `weighted` :
   [Weight your data before reducing](https://xarray.pydata.org/en/stable/computation.html#weighted-array-reductions)


## More resources

1. More tutorials here: https://xarray-contrib.github.io/xarray-tutorial/
1. Answers to common questions on "how to do X" are here: https://xarray.pydata.org/en/stable/howdoi.html