<img src="../images/dask_horizontal.svg" align="left" width="30%">
<img src="../images/dataset-diagram-logo.png" align="right" width="30%">

# Dask and Xarray


This notebook demonstrates one of xarray's most powerful features: the ability to wrap dask arrays and allow users to seamlessly execute analysis code in parallel.


## Learning Objectives

- Learn that xarray DataArrays and Datasets are "dask collections", i.e. you can pass them to top-level dask functions such as `dask.visualize(xarray_object)`
- Learn that all xarray built-in operations can transparently use dask
- Learn that xarray provides tools to easily parallelize custom functions across blocks of dask-backed xarray objects

## Prerequisites


| Concepts | Importance | Notes |
| --- | --- | --- |
| Familiarity with Dask Array | Necessary | |
| Familiarity with xarray | Necessary | |

- **Time to learn**: *15-20 minutes*


## Setup

First let's set up a `LocalCluster` using `dask.distributed`. 



In [None]:
import dask
import dask.array as da
import xarray as xr
from dask.distributed import Client, LocalCluster

In [None]:
cluster = LocalCluster()
client = Client(cluster)
client

## Reading data with Dask and Xarray

Recall that a Dask array is made up of many smaller arrays, called chunks:

In [None]:
darr = da.ones((2000, 300), chunks=(200, 50))
darr

In [None]:
darr.compute()
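
As a quick check on the chunk layout, the attributes below summarize how the array defined above is split. This is a small sketch using standard `dask.array` attributes; inspecting them does not trigger any computation.

```python
import dask.array as da

# Same array as above: 2000 x 300, split into 200 x 50 blocks
darr = da.ones((2000, 300), chunks=(200, 50))

print(darr.chunksize)    # shape of the largest chunk: (200, 50)
print(darr.numblocks)    # number of chunks along each axis: (10, 6)
print(darr.npartitions)  # total number of chunks: 60
```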

To read data as dask arrays with xarray, we need to pass the `chunks` argument to the `open_dataset()` function.

In [None]:
ds = xr.open_dataset(
    "data/tos_Omon_CESM2_historical_r11i1p1f1_gr_200001-201412.nc", engine="netcdf4", chunks={}
)
ds

Passing `chunks={}` to `open_dataset()` works, but because we did not tell Dask how to split up (or chunk) the array, Dask creates a single chunk for our array. To control the chunk sizes, pass a dictionary that maps dimension names to chunk lengths:

In [None]:
ds = xr.open_dataset(
    "data/tos_Omon_CESM2_historical_r11i1p1f1_gr_200001-201412.nc",
    engine="netcdf4",
    chunks={"time": 90, "lat": 180, "lon": 360},
)
ds

In [None]:
ds.tos

In [None]:
ds.tos.chunks
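
To see what such a `chunks` mapping produces without needing the data file, here is a minimal sketch on a synthetic array. The shape `(180, 180, 360)` and the name `tos_demo` are assumptions chosen for illustration, not values read from the file.

```python
import dask.array as da
import xarray as xr

# Synthetic stand-in (assumed shape); lazy zeros, so nothing is allocated up front
tos_demo = xr.DataArray(
    da.zeros((180, 180, 360), chunks=(180, 180, 360)),
    dims=("time", "lat", "lon"),
    name="tos_demo",
)

# Rechunk by dimension name, mirroring the `chunks=` mapping passed to open_dataset()
rechunked = tos_demo.chunk({"time": 90, "lat": 180, "lon": 360})

# One tuple of chunk lengths per dimension
print(rechunked.chunks)  # ((90, 90), (180,), (360,))
```

The `time` dimension is split into two chunks of 90, while `lat` and `lon` each stay in a single chunk.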

## Xarray data structures are first-class dask collections

This means you can call the following functions 

- `dask.visualize(...)`
- `dask.compute(...)`
- `dask.persist(...)`

on both xarray DataArrays and Datasets backed by dask arrays.

In [None]:
dask.visualize(ds)
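
The other two functions can be tried the same way. The sketch below uses a small synthetic DataArray (hypothetical data, not the dataset opened above) to show what `dask.compute()` and `dask.persist()` hand back for an xarray object.

```python
import dask
import dask.array as da
import numpy as np
import xarray as xr

# Small synthetic dask-backed DataArray (hypothetical data for illustration)
arr = xr.DataArray(da.ones((4, 6), chunks=(2, 3)), dims=("y", "x"), name="demo")

# Dask recognizes the xarray object as a collection
print(dask.is_dask_collection(arr))

# dask.compute() returns a tuple; the result is an xarray object holding NumPy data
(computed,) = dask.compute(arr)
print(type(computed.data))

# dask.persist() also returns a tuple; the result stays dask-backed,
# but its chunks have already been computed and are held in memory
(persisted,) = dask.persist(arr)
print(dask.is_dask_collection(persisted))
```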

## Parallel and Lazy computation using `dask.array` with xarray


Xarray seamlessly wraps dask so all computation is deferred until explicitly requested. 

In [None]:
z = ds.tos.mean(['lat', 'lon']).dot(ds.tos.T)
z

As you can see, `z` contains a dask array. This is true for all xarray built-in operations, including subsetting:

In [None]:
z.isel(lat=0)

In [None]:
dask.visualize(z)

In [None]:
%%time
z.compute()
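
The same lazy-then-compute pattern can be checked on data small enough to verify by hand. This is a sketch on a synthetic array of ones; the name `tos_demo` and its sizes are assumptions for illustration.

```python
import dask.array as da
import numpy as np
import xarray as xr

# Synthetic stand-in for ds.tos: all ones, 4 time steps (assumed sizes)
tos_demo = xr.DataArray(
    da.ones((4, 3, 5), chunks=(2, 3, 5)),
    dims=("time", "lat", "lon"),
)

# Same expression as above: still lazy, nothing has been computed yet
z_demo = tos_demo.mean(["lat", "lon"]).dot(tos_demo.T)
print(type(z_demo.data))  # a dask array

# .compute() runs the task graph and returns NumPy-backed data
result = z_demo.compute()
print(type(result.data))
print(float(result.min()), float(result.max()))  # 4.0 4.0
```

The spatial mean of an all-ones field is 1 at every time step, and the dot product sums over the shared `time` dimension, so every element of the result equals the number of time steps.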

## Reading multiple datasets with `open_mfdataset`

Xarray provides a built-in function `xr.open_mfdataset()` for opening multiple files as a single dataset. This makes it easy to work with data from multiple files as one logical dataset. 

For demonstration purposes, let's revisit the example from the [Dask Delayed notebook](./08-dask-delayed.ipynb). In that example, we looped over a list of files (one per ensemble member, four in total) and computed the anomaly for each ensemble member as follows:

In [None]:
import pathlib

data_dir = pathlib.Path("data/")
files = sorted(data_dir.glob("tos_Omon_CESM2*"))

results = {}
for file in files:

    # Read in file
    ds = dask.delayed(xr.open_dataset)(file, engine='netcdf4')

    # Compute anomaly
    gb = ds.tos.groupby('time.month')
    tos_anom = gb - gb.mean(dim='time')

    # Save the computed anomaly and record the name of the ensemble member
    results[file.stem.split('_')[-3]] = tos_anom


# Compute the results
# dask.compute() returns a tuple, here with a single item, so take the item at index 0
computed_results = dask.compute(results)[0]
# Combine the per-member anomalies into one DataArray by concatenating along a new `ensemble_member` dimension
dset_anom = xr.concat(list(computed_results.values()), dim='ensemble_member')
dset_anom['ensemble_member'] = list(computed_results.keys())
dset_anom

Instead of explicitly looping over the list of files to construct xarray datasets, we can pass the list of files to [`xr.open_mfdataset()`](https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html#xarray.open_mfdataset) and xarray will construct a single dataset for us:

In [None]:
dset = xr.open_mfdataset(
    sorted(files),
    concat_dim='ensemble_member',
    combine="nested",
    parallel=True,
    data_vars=['tos'],
    engine="netcdf4",
    chunks={'time': 90},
)
# Add coordinate labels for the newly created `ensemble_member` dimension
dset["ensemble_member"] = ['r11i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r9i1p1f1']
dset

<div class="admonition alert alert-info">
    <p class="admonition-title" style="font-weight:bold"></p>
    <ul>
    <li>By default, open_mfdataset() will chunk each netCDF file into a single Dask array; supply the chunks argument to control the size of the resulting Dask arrays.</li>
    <li>In more complex cases, you can open each file individually using open_dataset(..., chunks={...}) and merge the results into a single dataset.</li>
    <li>Passing the keyword argument parallel=True to open_mfdataset() will speed up the reading of large multi-file datasets by executing those read tasks in parallel using dask.delayed.</li>
    </ul>
</div>
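
The second point in the note above can be sketched without any files. Assuming each ensemble member has already been opened as its own dask-backed dataset (built synthetically here; the member labels and array sizes are illustrative), `xr.concat` joins them along a new dimension, roughly what `combine="nested"` with `concat_dim` does for us.

```python
import dask.array as da
import xarray as xr

members = ["r11i1p1f1", "r7i1p1f1"]  # illustrative subset of labels

# Stand-ins for the result of open_dataset(..., chunks={...}) on each file
datasets = [
    xr.Dataset(
        {"tos": (("time", "lat", "lon"), da.full((6, 2, 3), float(i), chunks=(3, 2, 3)))}
    )
    for i in range(len(members))
]

# Concatenate along a new dimension, then label it
combined = xr.concat(datasets, dim="ensemble_member")
combined = combined.assign_coords(ensemble_member=members)

print(dict(combined.sizes))
```

The combined `tos` variable is still backed by dask, so no data has been loaded at this point.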

In [None]:
# Compute anomaly
gb = dset.tos.groupby('time.month')
tos_anom = gb - gb.mean(dim='time')
tos_anom
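
To make the anomaly calculation concrete, here is a minimal sketch on a synthetic two-year monthly series whose answer can be worked out by hand. The variable names and values are hypothetical.

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr

# Two years of monthly values 0..23 (synthetic), chunked by year
time = pd.date_range("2000-01-01", periods=24, freq="MS")
series = xr.DataArray(
    da.from_array(np.arange(24.0), chunks=12),
    dims="time",
    coords={"time": time},
)

# Same pattern as above: subtract each calendar month's mean over time
gb = series.groupby("time.month")
anom = gb - gb.mean(dim="time")

# Each month's mean lies midway between its two years,
# so the anomalies are -6 in the first year and +6 in the second
print(anom.compute().values)
```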

In [None]:
tos_anom.sel(lon=310, lat=50, method='nearest').plot(col='ensemble_member', col_wrap=2, size=4);

<div class="admonition alert alert-warning">
    <p class="admonition-title" style="font-weight:bold"></p>
    Note that plotting automatically triggers computation of the required results.
</div>

So, with xarray's `open_mfdataset()`, the following code

```python
results = {}
for file in files:

    # Read in file
    ds = dask.delayed(xr.open_dataset)(file, engine='netcdf4')

    # Compute anomaly
    gb = ds.tos.groupby('time.month')
    tos_anom = gb - gb.mean(dim='time')

    # Save the computed anomaly and record the name of the ensemble member
    results[file.stem.split('_')[-3]] = tos_anom


# Compute the results
# dask.compute() returns a tuple, here with a single item, so take the item at index 0
computed_results = dask.compute(results)[0]
# Combine the per-member anomalies into one DataArray by concatenating along a new `ensemble_member` dimension
dset_anom = xr.concat(list(computed_results.values()), dim='ensemble_member')
dset_anom['ensemble_member'] = list(computed_results.keys())
```

becomes 


```python
dset = xr.open_mfdataset(sorted(files), concat_dim='ensemble_member', 
                         combine="nested", parallel=True, data_vars=['tos'],
                         engine="netcdf4", chunks={'time': 90})
# Add coordinate labels for the newly created `ensemble_member` dimension
dset["ensemble_member"] = ['r11i1p1f1', 'r7i1p1f1', 'r8i1p1f1', 'r9i1p1f1'] 
# Compute anomaly
gb = dset.tos.groupby('time.month')
tos_anom = gb - gb.mean(dim='time')
```

This latter version is cleaner and easier to maintain than the version with loops. 



In [None]:
# Disconnect the client before shutting down the cluster it is connected to
client.close()
cluster.close()

In [None]:
%load_ext watermark
%watermark --time --python --updated --iversions

---

## Learn More

Visit the [Parallel computing with Dask documentation](https://xarray.pydata.org/en/stable/user-guide/dask.html) and the [dask array best practices](https://docs.dask.org/en/latest/array-best-practices.html), which provide advice on using `dask.array` well.

## Resources and references

* Reference
    * [Dask Docs](https://dask.org/)
    * [Dask Blog](https://blog.dask.org/)
    * [Xarray Docs](https://xarray.pydata.org/)

* Ask for help
    * [`dask`](http://stackoverflow.com/questions/tagged/dask) tag on Stack Overflow, for usage questions
    * [GitHub discussions (dask)](https://github.com/dask/dask/discussions) for general, non-bug discussion and usage questions
    * [GitHub issues (dask)](https://github.com/dask/dask/issues/new) for bug reports and feature requests
    * [GitHub discussions (xarray)](https://github.com/pydata/xarray/discussions) for general, non-bug discussion and usage questions
    * [GitHub issues (xarray)](https://github.com/pydata/xarray/issues/new) for bug reports and feature requests
    
* Pieces of this notebook are adapted from the following sources
  * https://github.com/xarray-contrib/xarray-tutorial/blob/master/scipy-tutorial/06_xarray_and_dask.ipynb
  
  
  
 <div class="admonition alert alert-success">
    <p class="title" style="font-weight:bold">Previous: <a href="./08-dask-delayed.ipynb">Dask Delayed</a></p>
     <p class="title" style="font-weight:bold">Next: <a href="./10-dask-and-xarray.ipynb">Dask and Xarray</a></p>
    
</div>