In [1]:
import abtem
import ase
import ase.build  # load the submodule explicitly; ase.build.bulk is used below
import dask
from IPython.display import Image, display 

dask.config.set({"array.svg.size": 90});

# Parallelization and Dask

The computational cost of running multislice simulations can grow large depending on the number of probe positions, the number of frozen phonon configurations and many other factors. This cost can be mitigated by using parallelism. Much of the necessary work in abTEM is [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel). For example, every probe position is independent, so each CPU core may calculate a batch of positions on its own, and communication is only required after finishing a run of the multislice algorithm.
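
As a rough illustration of this pattern (plain Python, not abTEM's internals), the sketch below splits a set of probe positions into batches and processes each batch independently in a thread pool. The function `simulate_batch` is a hypothetical stand-in for a run of the multislice algorithm, and the batch size of 4 is an arbitrary choice.

```python
from concurrent.futures import ThreadPoolExecutor


def simulate_batch(positions):
    # Hypothetical stand-in for one multislice run over a batch of probe
    # positions; here each position just yields a dummy "intensity".
    return [x * x + y * y for x, y in positions]


# A 4 x 4 grid of probe positions, split into batches of 4 positions.
positions = [(x, y) for x in range(4) for y in range(4)]
batches = [positions[i:i + 4] for i in range(0, len(positions), 4)]

# Each batch is independent, so the batches can be processed concurrently.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(simulate_batch, batches))

# Communication happens only here, when the per-batch results are gathered.
intensities = [value for batch in results for value in batch]
print(len(batches), len(intensities))
```

No batch needs the result of any other batch, which is what makes the work embarrassingly parallel.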

abTEM is parallelized using [Dask](https://www.dask.org/){cite}`dask`. Dask allows scaling from a single laptop to hundreds of nodes at high-performance computing (HPC) facilities with minimal changes to the code. 

In this document, we introduce how abTEM uses Dask. This is not required knowledge for running abTEM on a single machine, but it may still help you optimize your simulations. If you are already an experienced Dask user, most of what you already know can be applied to abTEM. If you are new to Dask, you may benefit from watching [this introduction](https://www.youtube.com/watch?v=nnndxbr_Xq4) before continuing. We note that Dask is used in several other electron microscopy libraries, for example, [hyperspy](https://hyperspy.org/), [libertem](https://libertem.github.io/LiberTEM/) and [py4DSTEM](https://py4dstem.readthedocs.io/en/latest/), so knowing the library may benefit you more generally.

## Task graphs

Simulating TEM experiments requires executing multiple tasks, where each task may depend on the output of previous tasks. In Dask this is represented as a [*task graph*](https://docs.dask.org/en/stable/graphs.html), where each task is a node and an edge connects two nodes when one task depends on the output of the other. The simulation result is obtained by executing each task (node) in the graph with a Dask scheduler on a single machine or a cluster.

In [2]:
display(
    Image(url="https://docs.dask.org/en/stable/_images/dask-overview.svg", width=600)
)
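
The idea of a task graph can be sketched in a few lines of standard-library Python. This is an illustrative toy, not Dask's scheduler: the task names and the numbers they return are made up, and the tasks run serially in dependency order instead of in parallel.

```python
from graphlib import TopologicalSorter

# Each key is a task; the value is (function, names of the tasks it depends on).
graph = {
    "potential": (lambda: 2.0, ()),
    "wave": (lambda: 3.0, ()),
    "exit_wave": (lambda potential, wave: potential * wave, ("potential", "wave")),
    "intensity": (lambda exit_wave: exit_wave ** 2, ("exit_wave",)),
}

# Map each task to its dependencies and walk the graph in dependency order.
dependencies = {name: set(deps) for name, (_, deps) in graph.items()}
results = {}
for name in TopologicalSorter(dependencies).static_order():
    func, deps = graph[name]
    results[name] = func(*(results[dep] for dep in deps))

print(results["intensity"])
```

A real scheduler additionally runs tasks concurrently whenever their dependencies are already satisfied, as `"potential"` and `"wave"` are here.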

Below we create the task graph for running a multislice simulation using plane waves with gold in the $\left<100\right>$ zone axis with 4 frozen phonons. 

In [130]:
atoms = ase.build.bulk("Au", cubic=True) * (5, 5, 2)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=False
)

potential = abtem.Potential(frozen_phonons, gpts=1024, slice_thickness=2)

probe = abtem.PlaneWave(energy=200e3)

exit_waves = probe.multislice(potential)

The result is an ensemble of $4$ wave functions of shape $1024\times1024$, which may be represented as a 3D array, where the first dimension represents the phonon ensemble and the last $2$ dimensions represent the 2D wave functions.

As we have not executed the task graph yet, the wave functions are represented as a [Dask Array](https://docs.dask.org/en/stable/array.html). We can think of the Dask array as being composed of many smaller NumPy arrays, called *chunks*, and operations may be applied to each chunk rather than the full array. This enables 

1. Parallelism over the chunks 
2. Representing a larger-than-memory array as many smaller arrays which each fits in memory
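
The two points above can be tried out with a plain Dask array. The sketch below is independent of abTEM and uses a small `(4, 64, 64)` array of ones as a stand-in for an ensemble of wave functions, with one chunk per ensemble member.

```python
import dask.array as da

# A small stand-in for an ensemble of 4 wave functions on a 64 x 64 grid,
# stored with one chunk per ensemble member.
waves = da.ones((4, 64, 64), chunks=(1, 64, 64), dtype="complex64")

print(waves.numblocks)  # number of chunks along each dimension
print(waves.chunks)     # chunk sizes along each dimension

# Nothing has been computed yet; operations only extend the task graph.
mean_intensity = (abs(waves) ** 2).mean(axis=0)

# compute() executes the graph and returns a NumPy array.
result = mean_intensity.compute()
print(result.shape, float(result[0, 0]))
```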

The Dask array `__repr__` shows how the chunks are laid out.

In [131]:
exit_waves.array

| | Array | Chunk |
|---|---|---|
| Bytes | 32.00 MiB | 8.00 MiB |
| Shape | (4, 1024, 1024) | (1, 1024, 1024) |
| Count | 24 Tasks | 4 Chunks |
| Type | complex64 | numpy.ndarray |


We see that the Dask array has the shape `(4, 1024, 1024)`, requiring 32 MiB of memory, and is composed of chunks of shape `(1, 1024, 1024)`, requiring 8 MiB each. We stress that the Dask array just represents a task graph; memory is consumed only when it is computed.

Each chunk of the Dask array created above represents a wave function for a frozen phonon configuration. This reflects that, in the multislice algorithm, each frozen phonon configuration is independent and may be calculated in parallel. On the other hand, we should not have chunks across wave functions because each part of the wave function is affected by every other part.

We can visualize the task graph using Dask's [`visualize`](https://docs.dask.org/en/stable/graphviz.html) method. We see that the task graph consists of 4 fully independent branches, one for each frozen phonon configuration.

````{note}
Drawing Dask graphs with the cytoscape engine requires the `ipycytoscape` Python library. To reproduce the result below you need to run:
```
  python -m pip install ipycytoscape
```
and restart Jupyter.
````

In [132]:
exit_waves.array.visualize(engine="cytoscape")

CytoscapeWidget(cytoscape_layout={'name': 'dagre', 'rankDir': 'BT', 'nodeSep': 10, 'edgeSep': 10, 'spacingFact…

We usually take the mean across the frozen phonon dimension, so we end up with an image represented as a single chunk.

In [133]:
hrtem_image = exit_waves.intensity().mean(0)

hrtem_image.array

| | Array | Chunk |
|---|---|---|
| Bytes | 4.00 MiB | 4.00 MiB |
| Shape | (1024, 1024) | (1024, 1024) |
| Count | 35 Tasks | 1 Chunks |
| Type | float32 | numpy.ndarray |


Taking the mean across frozen phonon chunks requires communicating the exit wave function intensities. The task graph shows how the branches merge at the point where the results have to be communicated between workers.

In [134]:
hrtem_image.array.visualize(engine="cytoscape")

CytoscapeWidget(cytoscape_layout={'name': 'dagre', 'rankDir': 'BT', 'nodeSep': 10, 'edgeSep': 10, 'spacingFact…

## Chunks
To further explore the role of chunks in abTEM, we create the task graph for running a STEM simulation with gold in the $\left<100\right>$ zone axis with 4 frozen phonons. We do not immediately apply a detector, hence we obtain an ensemble of exit wave functions.

In [153]:
probe = abtem.Probe(energy=200e3, semiangle_cutoff=20)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=True
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

# Define the scan after the potential so that it refers to the potential
# used in this simulation; it covers one unit cell of the 5 x 5 supercell.
scan = abtem.GridScan.from_fractional_coordinates(
    potential,
    start=(0, 0),
    end=(1 / 5, 1 / 5),
)

exit_waves_stem = probe.multislice(potential, scan=scan)

exit_waves_stem.axes_metadata

type               label           coordinates
-----------------  --------------  -------------------
FrozenPhononsAxis  Frozen phonons  -
ScanAxis           x [Å]           0.00 0.27 ... 3.81
ScanAxis           y [Å]           0.00 0.27 ... 3.81
RealSpaceAxis      x [Å]           0.00 0.04 ... 20.36
RealSpaceAxis      y [Å]           0.00 0.04 ... 20.36

The wave functions are represented as a 5D Dask array: a 3D ensemble of 2D wave functions. The ensemble is composed of one phonon dimension and 2 scan dimensions, one for each of the $x$ and $y$ directions.

The full array is of shape `(4, 15, 15, 512, 512)`, requiring 1.76 GiB of memory. It is cut into chunks of shape `(1, 8, 7, 512, 512)`, each requiring 112 MiB. Hence, there is a total of $4 \times 2 \times 3 = 24$ chunks.

In [154]:
exit_waves_stem.array

| | Array | Chunk |
|---|---|---|
| Bytes | 1.76 GiB | 112.00 MiB |
| Shape | (4, 15, 15, 512, 512) | (1, 8, 7, 512, 512) |
| Count | 69 Tasks | 24 Chunks |
| Type | complex64 | numpy.ndarray |
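
The sizes and the chunk count reported for this array can be checked with a few lines of arithmetic. The sketch assumes 8 bytes per `complex64` value and that a dimension is split into equally sized chunks plus a smaller remainder, which is how the shapes above divide.

```python
import math

shape = (4, 15, 15, 512, 512)       # full ensemble of exit waves
chunk_shape = (1, 8, 7, 512, 512)   # a single chunk
itemsize = 8                        # bytes per complex64 value

array_bytes = math.prod(shape) * itemsize
chunk_bytes = math.prod(chunk_shape) * itemsize

# Number of chunks along each axis, rounding up for partial chunks.
chunks_per_axis = [math.ceil(n / c) for n, c in zip(shape, chunk_shape)]
num_chunks = math.prod(chunks_per_axis)

print(f"{array_bytes / 2**30:.2f} GiB")
print(f"{chunk_bytes / 2**20:.2f} MiB")
print(chunks_per_axis, num_chunks)
```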


We do not make a chunk for every probe position; instead, each chunk of the scan dimensions represents a batch of wave functions. This is done partly to limit [the overhead](https://docs.dask.org/en/stable/best-practices.html#avoid-very-large-graphs) that every chunk comes with. More importantly, larger batches enable efficient thread parallelization within each run of the multislice algorithm.

We can change how many wave functions each batch should have using the `max_batch` keyword. Below we set `max_batch=32`, resulting in chunks with $6 \times 5 = 30$ probe positions each and a total of $4 \times 3 \times 3 = 36$ chunks.

In [155]:
exit_waves_stem = probe.multislice(potential, scan=scan, max_batch=32)

exit_waves_stem.array

| | Array | Chunk |
|---|---|---|
| Bytes | 1.76 GiB | 60.00 MiB |
| Shape | (4, 15, 15, 512, 512) | (1, 6, 5, 512, 512) |
| Count | 94 Tasks | 36 Chunks |
| Type | complex64 | numpy.ndarray |


The default value of `max_batch` is `"auto"`. With this setting, the number of wave functions in each batch is determined such that the batch represents approximately `128 MB` of memory. This number may be changed through the configuration.
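
A back-of-the-envelope estimate shows what such a memory target means for the `512 x 512` wave functions used here. This is only a sketch: it assumes the target is exactly 128 MiB and 8 bytes per `complex64` value, and it does not reproduce how abTEM actually divides the scan grid into batches.

```python
gpts = (512, 512)
itemsize = 8                      # bytes per complex64 value
target_bytes = 128 * 2**20        # assumed target of 128 MiB per batch

bytes_per_wave = gpts[0] * gpts[1] * itemsize
max_waves_per_batch = target_bytes // bytes_per_wave

print(bytes_per_wave / 2**20, max_waves_per_batch)
```

Under these assumptions a batch holds at most 64 wave functions, which is consistent with the default chunks of $8 \times 7 = 56$ probe positions shown earlier.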

Before computing, we apply a HAADF detector and calculate the ensemble mean, which reduces the total size of the output to just 900 B. We note that the 1.76 GiB ensemble of wave functions never needs to be in memory simultaneously; each chunk of exit wave functions is reduced immediately after completing the multislice algorithm.

In [165]:
detector = abtem.AnnularDetector(inner=65, outer=200)

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.array

| | Array | Chunk |
|---|---|---|
| Bytes | 900 B | 120 B |
| Shape | (15, 15) | (6, 5) |
| Count | 229 Tasks | 9 Chunks |
| Type | float32 | numpy.ndarray |
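
The reason the full ensemble never has to be held in memory can be illustrated with NumPy alone. In the sketch below, `exit_wave_chunk` is a hypothetical stand-in for one multislice run that returns random data, and the detector is replaced by a plain sum of the real-space intensity; only the pattern of reducing each chunk immediately is meant to carry over.

```python
import numpy as np

rng = np.random.default_rng(0)
num_configs, scan_shape, gpts = 4, (3, 3), (32, 32)


def exit_wave_chunk():
    # Hypothetical stand-in for one multislice run: one complex wave
    # function per scan position for a single frozen phonon configuration.
    shape = scan_shape + gpts
    return (rng.normal(size=shape) + 1j * rng.normal(size=shape)).astype("complex64")


image_sum = np.zeros(scan_shape, dtype="float32")
for _ in range(num_configs):
    chunk = exit_wave_chunk()
    # Reduce immediately: integrate the intensity over the last two axes,
    # so the large chunk can be discarded before the next one is created.
    image_sum += (np.abs(chunk) ** 2).sum(axis=(-2, -1))

mean_image = image_sum / num_configs
print(mean_image.shape, mean_image.nbytes)
```

At any time only one chunk of wave functions and the small accumulated image exist, regardless of how many configurations are processed.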


## Schedulers

After generating a task graph, it needs to be executed on (parallel) hardware. This is the job of a [task scheduler](https://docs.dask.org/en/stable/scheduler-overview.html). Dask provides several task schedulers. Each of them will compute a task graph and give the same result, but with different performance characteristics.

Every time you call the `compute` method a Dask scheduler is used. abTEM adopts the default Dask scheduler configuration and every keyword argument used with the `compute` method in abTEM is forwarded to the Dask `compute` function. 
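
The statement that different schedulers give the same result can be checked with a plain Dask array, independent of abTEM. The sketch below runs one task graph with the single-threaded `"synchronous"` scheduler, which is mainly useful for debugging, and with the threaded scheduler.

```python
import dask.array as da

x = da.arange(1000, chunks=100)
total = (x ** 2).sum()

# The same task graph executed by two different local schedulers.
serial = total.compute(scheduler="synchronous")
threaded = total.compute(scheduler="threads", num_workers=4)

print(serial, threaded)
```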

### Local scheduler

The default scheduler is the threaded scheduler, which is based on a [`ThreadPoolExecutor`](https://docs.dask.org/en/stable/scheduling.html#local-threads). Keyword arguments for the scheduler may be passed through the `compute` method. For example, the threaded scheduler takes a `num_workers` keyword, which sets the number of threads to use (defaults to the number of cores).

In [166]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.compute(
    scheduler="threads", num_workers=16
)

[########################################] | 100% Completed | 19.38 s


<abtem.measurements.Images object at 0x000002DB891834F0>

We can switch to the process-based scheduler, which uses a `ProcessPoolExecutor`, as below.

```python
haadf_images = detector.detect(exit_waves_stem).compute(scheduler="processes", num_workers=4)
```

Using `abtem.config.set` the scheduler can be set either as a context manager or globally.
```python
# As a context manager
with abtem.config.set(scheduler="processes"):
    haadf_images.compute()

# Set globally
abtem.config.set(scheduler="processes")
haadf_images.compute()
```

To improve performance, we first have to be able to measure it. Profiling parallel code can be challenging, but Dask provides functionality to aid in profiling and inspecting execution. The diagnostic tools are quite different depending on whether you use a local or distributed scheduler.

Dask allows local diagnostics by adding callbacks that collect information about your code execution. You can use the profilers as context managers, as described in the Dask documentation. For convenience, the abTEM `compute` methods implement keywords for adding profilers.
Below we use the `Profiler` to monitor task execution by setting `profiler=True` and a `ResourceProfiler` to monitor the CPU usage and memory consumption by setting `resource_profiler=True`. We rerun the simulation above with these profilers.

In [167]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images, profilers = haadf_images.compute(profiler=True, resource_profiler=True)

[########################################] | 100% Completed | 19.31 s


To display the results, Dask uses the Bokeh plotting library (this is also installed together with Dask). To display the plots in a Jupyter notebook, we need to run the commands below.

In [168]:
from bokeh.io import output_notebook
output_notebook()

We first show the result from the `Profiler` object. This shows the execution time for each task as a rectangle, organized along the y-axis by worker (in this case threads); white space represents idle time. The task types are grouped by color and, by hovering over each task, one can see the key and task that each block represents. For this calculation there is only one significant task, shown in yellow. The task encompasses building the wave function, running the multislice algorithm (calculating the potential on demand) and calculating and integrating the diffraction patterns.

In [169]:
profilers[0].visualize();

The result from the `ResourceProfiler` is shown below: This shows two lines, one for total CPU percentage used by all the workers, and one for total memory usage. The CPU usage is scaled so each worker contributes up to $100 \ \%$, i.e. two fully utilized workers use $200 \ \%$.

In [170]:
profilers[1].visualize();

We ran the calculation on a single computer with an 8-core CPU with 2 threads per core, for 16 threads in total. We see that our peak CPU usage was $\sim 1300 \ \%$. This is a fairly typical usage statistic for a single machine: overhead and background processes prevented us from reaching the $1600 \ \%$ corresponding to using every worker maximally.

We also see that the total memory use is even higher than the size of every wave function combined. This is because every parallel run of the multislice algorithm requires a significant overhead for intermediate results (such as potential slices and Fresnel propagators). If your calculation runs out of memory, you can lower the number of workers, thus trading away computational speed for lower memory consumption.

We note that the overhead in both CPU usage and memory diminishes for larger simulations on more powerful hardware.
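
The context-manager form of the Dask profilers mentioned earlier can be sketched with a plain Dask array, independent of abTEM. Only `Profiler` is used here; `ResourceProfiler` follows the same pattern but additionally needs the `psutil` package.

```python
import dask.array as da
from dask.diagnostics import Profiler

x = da.ones((4, 256, 256), chunks=(1, 256, 256))

# Record every task executed by the local scheduler inside the block.
with Profiler() as prof:
    result = (x ** 2).mean(axis=0).compute(scheduler="threads")

# Each entry records the task key, start and end time, and the worker id.
for task in prof.results[:3]:
    print(task.key, round(task.end_time - task.start_time, 6))
```

The recorded entries are the data that the profiler's `visualize` method plots.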

### The distributed scheduler (locally)
The Dask distributed scheduler is necessary for running your simulation on a cluster; however, it also runs [locally on a personal machine](https://docs.dask.org/en/stable/scheduling.html#dask-distributed-local). You can find details in the Dask documentation. We demonstrate the basics below.

You can use the Dask distributed scheduler by just initializing a Dask `Client`. The `Client` takes keyword arguments such as `n_workers`.

In [174]:
from dask.distributed import Client

client = Client(n_workers=16)
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 56043 instead


| Property | Value |
|---|---|
| Connection method | Cluster object |
| Cluster type | distributed.LocalCluster |
| Dashboard | http://127.0.0.1:56043/status |
| Scheduler comm | tcp://127.0.0.1:56046 |
| Workers | 16 |
| Total threads | 16 |
| Total memory | 31.93 GiB |
| Status | running |
| Using processes | True |

Each of the 16 workers runs as a separate process with 1 thread and 2.00 GiB of memory.


After initializing the client object, any abTEM computation will use the Dask distributed scheduler.

In [101]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()
haadf_images.compute()

<abtem.measurements.Images object at 0x000002DB865FC160>

A benefit of using the distributed scheduler with abTEM on a single machine is the live diagnostic dashboard.

In [118]:
client.close()

### Running abTEM on HPC clusters

Dask (and thus abTEM) has robust tools for deployment on high-performance compute clusters.

Your HPC provider may be able to advise you on how to best deploy Dask applications on your cluster.

## Using GPUs

Almost every part of abTEM can be accelerated using a GPU through the [CuPy](https://cupy.dev/) library. We have only tested abTEM on CUDA-compatible GPUs; however, any GPU compatible with CuPy should work.

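If you have a compatible GPU and a working installation of CuPy, you can accelerate your image simulations by simply changing the `device` configuration at the top of your document. The cell below shows the call with the value `"cpu"`; replace it with `"gpu"` to run on the GPU.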

In [152]:
abtem.config.set({"device": "cpu"});
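
A script that should run on machines with and without a GPU can choose the device string at run time. The sketch below only checks whether the CuPy package can be found; this is an assumption-laden shortcut, because finding the package does not guarantee that a working GPU is present.

```python
import importlib.util

# Use the GPU only if the CuPy package can be found; otherwise stay on the CPU.
has_cupy = importlib.util.find_spec("cupy") is not None
device = "gpu" if has_cupy else "cpu"

print(device)
```

The resulting string can then be used as the `device` value in the configuration call shown above.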

## Performance tips

### Optimize your simulation

### Use PRISM

### Change the FFT backend

### Change the batch size

### Use power of 2 `gpts`

### Running out of memory