# General Guide for Parallelizing xCDAT Operations with Dask

Author: [Tom Vo](https://github.com/tomvothecoder)


## Overview

This notebook serves as a general guide for parallel computing with xCDAT. It covers the
following topics:

- Basics of Dask Arrays
- General Dask Best Practices
- How to use Dask with Xarray
- How to use Dask with xCDAT, including real-world examples and performance metrics
- Dask Schedulers and using a local distributed scheduler for more resource-intensive needs

_The data used in the code examples can be found through the [Earth System Grid Federation (ESGF) search portal](https://aims2.llnl.gov/search)._

### More Resources

To learn more in-depth about Dask and Xarray, please check these resources out:

- [Official Xarray Parallel Computing with Dask Guide](https://docs.xarray.dev/en/stable/user-guide/dask.html)
- [Official Xarray Parallel Computing with Dask Jupyter Notebook Tutorial](https://tutorial.xarray.dev/intermediate/xarray_and_dask.html)
- [Official Dask guide for Xarray with Dask Arrays](https://examples.dask.org/xarray.html)
- [Project Pythia: Dask Arrays with Xarray](https://foundations.projectpythia.org/core/xarray/dask-arrays-xarray.html)


## Notebook Setup

Create an Anaconda environment for this notebook using the command below. You can
substitute `conda` with `mamba` if you are using Mamba instead.

```bash
conda create -n xcdat_dask_guide -c conda-forge python xarray netcdf4 dask xcdat flox matplotlib nc-time-axis jupyter jupyter-server-proxy
```

- [flox](https://flox.readthedocs.io/en/latest/) is a package that is used to improve the xarray `groupby()` performance by making it parallelizable.
- [matplotlib](https://matplotlib.org/) is an optional dependency required for plotting with xarray.
- [nc-time-axis](https://nc-time-axis.readthedocs.io/en/latest/) is an optional dependency required for `matplotlib` to plot `cftime` coordinates.
- [jupyter-server-proxy](https://github.com/jupyterhub/jupyter-server-proxy) is a package used
  for co-locating the Jupyter server with the Dask Scheduler so that they share the same port,
  allowing for the Dask dashboard to route the connection to Jupyter (and vice versa).


## Dask Best Practices

- **Use NumPy**
  - If your data fits comfortably in RAM and you are not performance bound, then using NumPy might be the right choice.
  - Dask adds another layer of complexity which may get in the way.
  - If you are just looking for speedups rather than scalability, then you may want to consider a project like [Numba](https://numba.pydata.org/).
- **Select a good chunk size**
  - A common performance problem among Dask Array users is that they have chosen a chunk size that is either too small (leading to lots of overhead) or poorly aligned with their data (leading to inefficient reading).
- **Orient your chunks**
  - When reading data you should align your chunks with your storage format.
- **Avoid Oversubscribing Threads**
  - By default Dask will run as many concurrent tasks as you have logical cores. It assumes that each task will consume about one core. However, many array-computing libraries are themselves multi-threaded, which can cause contention and low performance.
- **Consider Xarray**
  - The Xarray package wraps around Dask Array, and so offers the same scalability, but also adds convenience when dealing with complex datasets.
- **Build your own Operations**
  - Often we want to perform computations for which there is no exact function in Dask Array. In these cases we may be able to use some of the more generic functions to build our own.

&mdash; <cite>https://docs.dask.org/en/stable/array-best-practices.html#best-practices</cite>
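
One way to act on the "Avoid Oversubscribing Threads" advice is to cap the internal thread pools of the numerical libraries before they are imported. The sketch below assumes the installed libraries honor the commonly used `OMP_NUM_THREADS`, `MKL_NUM_THREADS`, and `OPENBLAS_NUM_THREADS` environment variables; which of them actually apply depends on how NumPy was built on your system.

```python
import os

# Cap the internal thread pools of common numerical backends to one thread
# each, so that Dask's own worker threads do not compete with them.
# These must be set before NumPy (or any library linked against
# OpenMP/MKL/OpenBLAS) is imported for the first time.
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")

for var in THREAD_ENV_VARS:
    os.environ[var] = "1"

print({var: os.environ[var] for var in THREAD_ENV_VARS})
```

In a notebook, this cell would need to run first, before any cell that imports NumPy, Xarray, or Dask.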


## The Basics of Dask Arrays

- **Dask divides arrays** into many small pieces, called **"chunks"** (each presumed to be small enough to fit into memory)
- Dask Array **operations are lazy**
  - Operations **queue** up a series of tasks mapped over blocks
  - No computation is performed until values need to be computed (hence "lazy")
  - Data is loaded into memory and **computation** is performed in **streaming fashion**, **block-by-block**
- Computation is controlled by a multi-processing or thread pool

<div style="text-align:center">
  <img src="../_static/dask-array.png" alt="Dask Array" style="display: inline-block; width:300px;">
</div>

&mdash; <cite>https://docs.xarray.dev/en/stable/user-guide/dask.html</cite>
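
The points above can be seen in a minimal sketch that uses `dask.array` directly. The array shape and chunk sizes are arbitrary choices for illustration, not values taken from this guide's dataset.

```python
import dask.array as da

# Describe a 4000 x 4000 array split into sixteen 1000 x 1000 chunks.
# Nothing is allocated yet; Dask only records how to build each chunk.
x = da.ones((4000, 4000), chunks=(1000, 1000))

# Queue up more work. This is still lazy and returns immediately.
total = (x + 1).sum()

print(x.numblocks)  # number of chunks along each dimension
print(type(total))  # still a Dask array, not a number

# Only now are the chunks created, processed block-by-block, and reduced.
result = total.compute()
print(result)
```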


## Xarray and Dask

<div style="text-align: center">
    <img src="../_static/xarray-logo.png" alt="xarray logo" style="display: inline-block; margin-right: 50px; width:400px;">
</div>


**Why does Xarray integrate with Dask?**

> Xarray integrates with Dask to support parallel computations and streaming computation
> on datasets that don’t fit into memory. Currently, Dask is an entirely optional feature
> for xarray. However, the benefits of using Dask are sufficiently strong that Dask may
> become a required dependency in a future version of xarray.
>
> &mdash; <cite>https://docs.xarray.dev/en/stable/use</cite>

**Which Xarray features support Dask?**

> Nearly all existing xarray methods (including those for indexing, computation,
> concatenating and grouped operations) have been extended to work automatically with
> Dask arrays. When you load data as a Dask array in an xarray data structure, almost
> all xarray operations will keep it as a Dask array; when this is not possible, they
> will raise an exception rather than unexpectedly loading data into memory.
>
> &mdash; <cite>https://docs.xarray.dev/en/stable/user-guide/dask.html#using-dask-with-xarray</cite>
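
As a small illustration of operations staying lazy, the sketch below wraps a synthetic Dask array in an `xarray.DataArray`. The variable name, dimensions, and chunk sizes are assumptions made up for this example.

```python
import dask
import dask.array as da
import xarray as xr

# Synthetic stand-in for a (time, lat, lon) variable; the name, shape, and
# chunk sizes are illustrative assumptions.
tas = xr.DataArray(
    da.zeros((120, 45, 90), chunks=(12, 45, 90)),
    dims=("time", "lat", "lon"),
    name="tas",
)

# Indexing and reducing keep the result as a lazy Dask-backed array.
subset_mean = tas.isel(lat=slice(0, 10)).mean(dim="time")
is_lazy = dask.is_dask_collection(subset_mean)
print(is_lazy)

# Values are only produced when explicitly requested.
values = subset_mean.compute()
print(values.shape)
```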

**What is the default Dask behavior for distributing work on compute hardware?**

> By default, dask uses its multi-threaded scheduler, which distributes work across
> multiple cores and allows for processing some datasets that do not fit into memory.
> For running across a cluster, [setup the distributed scheduler](https://docs.dask.org/en/latest/setup.html).
>
> &mdash; <cite>https://docs.xarray.dev/en/stable/user-guide/dask.html#using-dask-with-xarray</cite>

**How do I use Dask arrays in an `xarray.Dataset`?**

> The usual way to create a Dataset filled with Dask arrays is to load the data from a
> netCDF file or files. You can do this by supplying a `chunks` argument to [open_dataset()](https://docs.xarray.dev/en/stable/generated/xarray.open_dataset.html#xarray.open_dataset)
> or using the [open_mfdataset()](https://docs.xarray.dev/en/stable/generated/xarray.open_mfdataset.html#xarray.open_mfdataset) function.

**What happens if I don't specify `chunks` with `open_mfdataset()`?**

> `open_mfdataset()` called without `chunks` argument will return dask arrays with
> chunk sizes equal to the individual files. Re-chunking the dataset after creation
> with `ds.chunk()` will lead to an ineffective use of memory and is not recommended.
>
> &mdash; <cite>https://docs.xarray.dev/en/stable/user-guide/dask.html#reading-and-writing-data</cite>


## First, let's learn about chunking arrays


> For performance, a good choice of `chunks` follows the following rules:
>
> 1. A chunk should be small enough to fit comfortably in memory. We'll
>    have many chunks in memory at once
> 2. A chunk must be large enough so that computations on that chunk take
>    significantly longer than the 1ms overhead per task that Dask scheduling
>    incurs. A task should take longer than 100ms
> 3. Chunk sizes between 10MB-1GB are common, depending on the availability of
>    RAM and the duration of computations
> 4. Chunks should align with the computation that you want to do.
>    - For example, if you plan to frequently slice along a particular dimension,
>      then it's more efficient if your chunks are aligned so that you have to
>      touch fewer chunks. If you want to add two arrays, then it's convenient if
>      those arrays have matching chunks patterns
> 5. Chunks should align with your storage, if applicable.
>    - Array data formats are often chunked as well. When loading or saving data,
>      it is useful to have Dask array chunks that are aligned with the chunking
>      of your storage, often an even multiple times larger in each direction
>
> &mdash; <cite>https://docs.dask.org/en/latest/array-chunks.html</cite>
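
To check a candidate chunk against rule 3, it helps to work out its size in memory. The helper below is a hypothetical convenience function written for this guide (it is not part of Dask or Xarray), and the chunk shape and 4-byte element size are assumed values.

```python
import math


def chunk_size_mb(chunk_shape, itemsize_bytes):
    """Return the in-memory size of one chunk in megabytes (10**6 bytes)."""
    return math.prod(chunk_shape) * itemsize_bytes / 1e6


# Hypothetical chunk of a (time, lat, lon) variable stored as 4-byte floats.
chunk_shape = (600, 145, 192)
size_mb = chunk_size_mb(chunk_shape, itemsize_bytes=4)

print(f"{size_mb:.1f} MB per chunk")
print("within the 10 MB - 1 GB range:", 10 <= size_mb <= 1000)
```

The same arithmetic applies to any dtype: multiply the number of elements in a chunk by the element size in bytes.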


### Chunking with Xarray


The `chunks` parameter has critical performance implications when using Dask arrays.

- **If your chunks are too small**, queueing up operations will be extremely slow.

  - Dask will translate each operation into a huge number of operations mapped across chunks.
  - Computation on Dask arrays with small chunks can also be slow, because each operation on a chunk has some fixed overhead from the Python interpreter and the Dask task executor.

- **If your chunks are too big**, some of your computation may be wasted. Dask only computes results one chunk at a time.

&mdash; <cite>https://docs.xarray.dev/en/stable/user-guide/dask.html#chunking-and-performance</cite>
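
The effect of chunk size on the amount of scheduling work can be sketched by counting chunks for the same array under two chunking choices. Both chunk shapes below are illustrative assumptions.

```python
import dask.array as da

shape = (8000, 8000)

# Same array, two chunking choices (both chunk shapes are illustrative).
small_chunks = da.zeros(shape, chunks=(100, 100))
large_chunks = da.zeros(shape, chunks=(2000, 2000))

# Every elementwise operation becomes one task per chunk, so the number of
# chunks is a rough proxy for scheduling overhead.
print(small_chunks.npartitions)  # 6400 chunks of 10,000 elements each
print(large_chunks.npartitions)  # 16 chunks of 4,000,000 elements each
```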


### Good rule of thumb

**Create arrays with a minimum chunksize of at least one million elements (e.g., a 1000x1000 matrix).**

**With large arrays (10+ GB)**, the cost of queueing up Dask operations can be noticeable and **you may need even larger chunksizes**.


### Or let Dask try to figure out chunking for you

Dask Arrays can look for a `.chunks` attribute and use that to provide a good chunking.
This can help prevent users from specifying "too many chunks" and "too few chunks" which
can lead to performance issues.

To do this in `open_dataset()`/`open_mfdataset()`, specify `chunks` on a specific dimension(s) or all dimensions, as shown below:

1. `chunks={"time": "auto"}` - auto-scale the specified dimension(s) to accommodate ideal chunk sizes. In this example, replace `"time"` and/or add additional dims to the dictionary for additional auto-scaling.
2. `chunks="auto"` - allow chunking _all_ dimensions to accommodate ideal chunk sizes

&mdash; <cite>https://docs.dask.org/en/latest/array-chunks.html#automatic-chunking</cite>


> DISCLAIMER: Although Dask's chunk auto-scaling tries its best to optimally align chunks to the ideal sizes, the auto-scaling might not always be optimal. For these cases, it is recommended
> that you manually chunk for ideal sizes.
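
To see what `"auto"` produces without opening a file, the sketch below uses the same keyword with `dask.array` directly. The `(time, lat, lon)` shape is an illustrative assumption, and the chunk sizes Dask picks depend on its `array.chunk-size` configuration setting, so the printed values may differ between environments.

```python
import dask.array as da

# Illustrative (time, lat, lon) shape; da.zeros is lazy, so no data is
# allocated here.
shape = (1980, 145, 192)

# Let Dask pick the chunking for every dimension.
all_auto = da.zeros(shape, chunks="auto")

# Auto-scale only the first (time) dimension and keep the others whole (-1).
time_auto = da.zeros(shape, chunks=("auto", -1, -1))

print(all_auto.chunksize)
print(time_auto.chunksize)
```

Inspecting `.chunksize` (or `.chunks`) this way is a quick check on whether the automatic choice is reasonable before falling back to manual chunking.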


## Code Example - Setup


In [None]:
import logging

import xarray as xr
import xcdat  # noqa: F401 (registers the `.temporal` and `.spatial` accessors used below)

# Silence flox logger info messages.
logger = logging.getLogger("flox")
logger.setLevel(logging.WARNING)

# Disclaimer: The dataset used in the example is only a few hundred MBs to make
# downloading the file quick. A file this small should normally **NOT** be
# chunked since computational performance will most likely suffer.
filepath = "http://esgf.nci.org.au/thredds/dodsC/master/CMIP6/CMIP/CSIRO/ACCESS-ESM1-5/historical/r10i1p1f1/Amon/tas/gn/v20200605/tas_Amon_ACCESS-ESM1-5_historical_r10i1p1f1_gn_185001-201412.nc"

## Code Example - Parallelism with Xarray + Dask

The code example below demonstrates chunking a dataset in Xarray and grouping the data
in parallel across chunks.

**By default, dask uses its multi-threaded scheduler**, which distributes work across multiple cores and allows for processing some datasets that do not fit into memory.

If you are interested in using a distributed scheduler (local or cluster) for more resource-intensive computational operations, there is more information below in this notebook.


We're letting Dask auto-scale all dimensions to get a good chunk size using `chunks="auto"`, which references the `.chunks` attribute.


In [None]:
ds = xr.open_dataset(filepath, chunks="auto")

In [None]:
ds

Now we compute a yearly average using the `groupby` API, grouping on `ds.time.dt.year`.

`flox` must be installed to make this API parallelizable and Xarray will use `flox` by
default if it is installed.


In [None]:
tas_daily_avg_xr = (
    ds["tas"].groupby(ds.time.dt.year).mean(method="cohorts", engine="flox")
)

tas_daily_avg_xr

`.load()` or `.compute()` will trigger the computation, which loads the data into
memory. This also automatically happens when writing out the data to a file.
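
The two methods differ in what happens to the original object. The sketch below uses a small synthetic array (the shape and chunks are made up for illustration) to show that `.compute()` returns a new in-memory object, while `.load()` loads the data in place.

```python
import dask.array as da
import xarray as xr


def make_lazy():
    # Small synthetic Dask-backed array; shape and chunks are illustrative.
    return xr.DataArray(da.ones((24, 10), chunks=(12, 10)), dims=("time", "x"))


# .compute() returns a new in-memory object and leaves the original lazy.
lazy_a = make_lazy()
computed = lazy_a.compute()
print(lazy_a.chunks is not None, computed.chunks is None)

# .load() loads the data in place, so the original object is now in memory.
lazy_b = make_lazy()
lazy_b.load()
print(lazy_b.chunks is None)
```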


In [None]:
tas_daily_avg_xr.compute()

## Code Example - Parallelism with xCDAT + Dask

Many core [xCDAT computational APIs](https://xcdat.readthedocs.io/en/latest/api.html#methods),
including spatial averaging and temporal averaging, inherit Xarray's Dask support by operating on `xarray.Dataset` objects and making calls to parallelized Xarray APIs.

**Just chunk the xarray.Dataset object as you normally would before calling any of the parallelizable xCDAT APIs**.

Here's an example with xCDAT's `temporal.group_average()` API.


In [None]:
tas_daily_avg_xc = ds.temporal.group_average("tas", freq="month")

tas_daily_avg_xc

`.load()` or `.compute()` will trigger the computation, which loads the data into
memory. This also automatically happens when writing out the data to a file.


In [None]:
tas_daily_avg_xc.compute()

## Dask Task Scheduling

> All of the large-scale Dask collections like Dask Array, Dask DataFrame, and Dask Bag and the fine-grained APIs like delayed and futures generate task graphs where each node in the graph is a normal Python function and edges between nodes are normal Python objects that are created by one task as outputs and used as inputs in another task. After Dask generates these task graphs, it needs to execute them on parallel hardware. This is the job of a task scheduler. Different task schedulers exist, and each will consume a task graph and compute the same result, but with different performance characteristics.
> Dask has two families of task schedulers:
>
> 1.  **Single-machine scheduler**: This scheduler provides basic features on a local process or thread pool. This scheduler was made first and is the default. It is simple and cheap to use, although it can only be used on a single machine and does not scale
> 2.  **Distributed scheduler**: This scheduler is more sophisticated, offers more features, but also requires a bit more effort to set up. It can run locally or distributed across a cluster
>
> &mdash; <cite>https://docs.dask.org/en/stable/scheduling.html</cite>

<div style="text-align:center">
  <img src="../_static/dask-overview-schedulers.svg" alt="Dask Schedulers" style="display: inline-block;">
</div>
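
The single-machine scheduler can be selected per call with the `scheduler` keyword of `.compute()`. The sketch below runs the same small task graph with the threaded and the synchronous (single-threaded) variants to show that they return the same result; the array size is an arbitrary choice.

```python
import dask.array as da

# Small integer array split into ten chunks.
x = da.arange(1_000_000, chunks=100_000, dtype="int64")
total = (x * 2).sum()

# Default for Dask arrays: a thread pool on the local machine.
threaded = total.compute(scheduler="threads")

# Single-threaded execution in the calling thread; handy for debugging.
synchronous = total.compute(scheduler="synchronous")

print(threaded, synchronous)
```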


### Set up a local Dask Distributed Scheduler


Xarray is set up to use Dask's default single-machine, multi-threaded scheduler. However, Dask advises users to use the Dask distributed scheduler for more advanced functionality and more resource-intensive needs.

&mdash; <cite>https://docs.dask.org/en/stable/scheduling.html#dask-distributed-local</cite>


#### 1. Set up the Dask Client for the local distributed scheduler

Xarray will automatically use the Dask Client when calling `.compute()` or `.load()`
to trigger queued up operations in the Dask task graph.

You can configure the Dask Client (e.g., memory limit) to your needs. Check these
API docs out:

- https://distributed.dask.org/en/latest/api.html#client
- https://distributed.dask.org/en/latest/api.html#distributed.LocalCluster
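
As a sketch of what such a configuration might look like, the example below builds a `LocalCluster` with explicit settings and attaches a `Client` to it. The worker count, thread count, and memory limit are illustrative assumptions rather than recommendations, and `processes=False` is used only to keep the example lightweight by running workers as threads.

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Illustrative settings: two in-process workers with two threads each and a
# 2 GiB memory limit per worker. Tune these to your machine.
with LocalCluster(
    n_workers=2,
    threads_per_worker=2,
    memory_limit="2GiB",
    processes=False,  # run workers as threads in this process
) as cluster, Client(cluster) as client:
    # While the client is active, .compute() is routed to this cluster.
    x = da.ones((1000, 1000), chunks=(250, 250))
    total = x.sum().compute()
    n_workers = len(cluster.workers)

print(total, n_workers)
```

The context managers shut the cluster down when the block exits. In a long-running notebook session you would more likely keep the client open and call `client.close()` when finished.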


In [None]:
from dask.distributed import Client

client = Client()

#### 2. Open the Dask Dashboard UI

The Dask distributed scheduler provides an interactive dashboard containing many plots
and tables with live information.

Check this [Dask documentation page](https://docs.dask.org/en/stable/dashboard.html) to learn how to interpret the information. Here's an example:

<div style="text-align:center">
  <img src="../_static/dask-dashboard-example.png" alt="Dask Dashboard UI Example" style="display: inline-block; width:800px;">
</div>


In [None]:
# Get the link to the dashboard
client.dashboard_link

#### 3. Run an Xarray/xCDAT computation while viewing the dashboards


In [None]:
tas_daily_avg3 = ds.temporal.average("tas")

tas_daily_avg3

In [None]:
tas_daily_avg3.compute()

## FAQs


### Are there any other optimization tips for working with Dask and Xarray?

We HIGHLY recommend checking out the [Optimization Tips](https://docs.xarray.dev/en/stable/user-guide/dask.html#optimization-tips) section if you are using Dask with Xarray.

### Are there cases where xCDAT loads Dask arrays into memory?

As of `xarray=2023.5.0`, Xarray does not support updating/setting multi-dimensional dask
arrays. The following error is raised if this is attempted: `xarray can't set arrays with multiple array indices to dask yet`.

As a workaround, xCDAT loads coordinate bounds into memory if they are multi-dimensional
Dask arrays before performing operations or computations. This loading occurs in the
following APIs:

- `xcdat.axis.swap_lon_axis`
  - swapping longitude axis orientation
  - aligning longitude bounds to (0, 360) axis
- `xarray.Dataset.spatial.average`
  - generating weights using lat/lon coordinate bounds
  - swapping longitude axis orientation
  - scaling domain bounds to a specified region
- `xarray.Dataset.temporal.<average|group_average|climatology|departures>`
  - generating weights using time coordinate bounds
