In [1]:
import abtem
import ase
import dask
from IPython.display import Image, display

dask.config.set({"array.svg.size": 90});

# Parallelization and Dask

The computational cost of running multislice simulations can grow large, depending on the number of probe positions, the number of frozen phonon configurations and many other factors. This cost can be mitigated by using parallelism. Much of the necessary work in abTEM is [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel). For example, every probe position is independent, so each CPU core may calculate a batch of positions independently, only requiring communication after finishing a run of the multislice algorithm.
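
As a rough illustration of this pattern, the sketch below maps independent batches of probe positions over a worker pool using only the Python standard library. The function `simulate_position` is a hypothetical placeholder, not part of abTEM, and the thread pool is used only to show the structure of the computation.

```python
from concurrent.futures import ThreadPoolExecutor


def simulate_position(position):
    """Hypothetical stand-in for one multislice run at a probe position."""
    x, y = position
    return x**2 + y**2  # placeholder "signal", not real physics


def simulate_batch(batch):
    # Positions within a batch are processed one after another.
    return [simulate_position(p) for p in batch]


# A 4 x 4 grid of probe positions, split into batches of 4 positions each.
positions = [(x, y) for x in range(4) for y in range(4)]
batches = [positions[i : i + 4] for i in range(0, len(positions), 4)]

# Each batch is independent, so the batches can be mapped over a worker pool.
with ThreadPoolExecutor(max_workers=4) as pool:
    batch_results = list(pool.map(simulate_batch, batches))

# Communication is only needed at the end, when the results are gathered.
signal = [value for batch in batch_results for value in batch]
serial = [simulate_position(p) for p in positions]
print(signal == serial)
```

The parallel and serial results agree because no batch depends on any other batch.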

abTEM is parallelized using [Dask](https://www.dask.org/){cite}`dask`. Dask allows scaling from a single laptop to hundreds of nodes at high-performance computing (HPC) facilities with minimal changes to the code. 

In this document, we introduce how abTEM uses Dask. This is not required knowledge for running abTEM on a single machine; nonetheless, it may still help you optimize your simulations. If you are already an experienced Dask user, most of what you already know can be applied to using abTEM. If you are new to Dask, you may benefit from watching [this introduction](https://www.youtube.com/watch?v=nnndxbr_Xq4) before continuing. We note that Dask is used in several other libraries in electron microscopy, for example, [hyperspy](https://hyperspy.org/), [libertem](https://libertem.github.io/LiberTEM/) and [py4DSTEM](https://py4dstem.readthedocs.io/en/latest/), and we think that you may benefit from knowing this library more generally.

## Task graphs

Simulating TEM experiments requires executing multiple tasks where each task may depend on the output of previous tasks. In Dask this is represented as a [*task graph*](https://docs.dask.org/en/stable/graphs.html), where each task is a node with edges between nodes if the task is dependent on another task. The simulation result is obtained by executing each task (node) in the graph with a Dask scheduler on a single machine or a cluster.

In [2]:
display(
    Image(url="https://docs.dask.org/en/stable/_images/dask-overview.svg", width=600)
)
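
To make the idea of a task graph concrete, here is a minimal sketch in plain Python. It is not how Dask stores or executes graphs internally; the task names and the placeholder functions are hypothetical and only mirror the steps of a simulation. Each task lists the tasks it depends on, and the standard-library `graphlib` module provides an order in which every dependency runs before the task that needs it.

```python
from graphlib import TopologicalSorter

# Each key is a task name; the value is (function, names of the tasks it depends on).
# All names are hypothetical and only mirror the steps of a simulation.
tasks = {
    "atoms": (lambda: 4.0, ()),
    "potential": (lambda atoms: atoms * 2.0, ("atoms",)),
    "wave": (lambda: 1.0, ()),
    "exit_wave": (lambda wave, potential: wave * potential, ("wave", "potential")),
    "intensity": (lambda exit_wave: abs(exit_wave) ** 2, ("exit_wave",)),
}

# Map every node to its dependencies and obtain an order in which each
# dependency comes before the task that needs it.
graph = {name: set(deps) for name, (_, deps) in tasks.items()}
order = list(TopologicalSorter(graph).static_order())

# Execute the tasks one by one, feeding each task the results it depends on.
results = {}
for name in order:
    func, deps = tasks[name]
    results[name] = func(*(results[d] for d in deps))

print(order)
print(results["intensity"])
```

A scheduler does the same job as the loop above, except that tasks whose dependencies are already satisfied may run at the same time on different workers.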

Below we create the task graph for running a multislice simulation using plane waves with gold in the $\left<100\right>$ zone axis with 4 frozen phonons. 

In [3]:
atoms = ase.build.bulk("Au", cubic=True) * (5, 5, 2)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=False
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

probe = abtem.PlaneWave(energy=200e3)

exit_waves = probe.multislice(potential)

The result is an ensemble of $4$ wave functions of shape $512\times512$, which may be represented as a 3D array, where the first dimension represents the phonon ensemble and the last $2$ dimensions represent the 2D wave functions.

As we have not executed the task graph yet, the wave functions are represented as a [Dask Array](https://docs.dask.org/en/stable/array.html). We can think of the Dask array as being composed of many smaller NumPy arrays, called *chunks*, and operations may be applied to each chunk rather than the full array. This enables:

1. Parallelism over the chunks 
2. Representing a larger-than-memory array as many smaller arrays, each of which fits in memory

The Dask array `__repr__` shows how the chunks are laid out.

In [4]:
exit_waves.array

| | Array | Chunk |
|---|---|---|
| Bytes | 8.00 MiB | 2.00 MiB |
| Shape | (4, 512, 512) | (1, 512, 512) |
| Count | 15 Graph Layers | 4 Chunks |
| Type | complex64 | numpy.ndarray |


We see that the Dask array has the shape `(4, 512, 512)`, requiring 8 MiB of memory. It is composed of chunks with the shape `(1, 512, 512)`, requiring 2 MiB each. We stress that the Dask array just represents a task graph; hence, memory is consumed only when it is computed.
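
The sizes reported for the array can be checked with a few lines of arithmetic. The helper functions below are hypothetical and written only for this check; the item sizes are the standard 8 bytes per `complex64` element and 4 bytes per `float32` element.

```python
from math import prod

ITEMSIZE = {"complex64": 8, "float32": 4}  # bytes per element
MiB = 2**20


def nbytes(shape, dtype):
    # Total number of elements times the size of one element.
    return prod(shape) * ITEMSIZE[dtype]


def num_chunks(shape, chunk_shape):
    # The number of chunks along each axis is the ceiling of the ratio;
    # the total is the product over all axes.
    return prod(-(-n // c) for n, c in zip(shape, chunk_shape))


shape, chunk_shape = (4, 512, 512), (1, 512, 512)

print(nbytes(shape, "complex64") / MiB)        # size of the full array in MiB
print(nbytes(chunk_shape, "complex64") / MiB)  # size of one chunk in MiB
print(num_chunks(shape, chunk_shape))          # total number of chunks
```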

Each chunk of the Dask array created above represents a wave function for a frozen phonon configuration. This reflects that, in the multislice algorithm, each frozen phonon configuration is independent and may be calculated in parallel. On the other hand, we should not have chunks across wave functions because each part of the wave function is affected by every other part.

We can visualize the task graph using Dask's [`visualize`](https://docs.dask.org/en/stable/graphviz.html) method. We see that the task graph consists of 4 fully independent branches, one for each frozen phonon configuration.

````{note}
Drawing Dask graphs with the cytoscape engine requires the `ipycytoscape` Python library. To reproduce the result below you need to run:
```
  python -m pip install ipycytoscape
```
and restart Jupyter.
````

In [5]:
exit_waves.array.visualize(engine="cytoscape")


We usually take the mean across the frozen phonon dimension; thus, we end up with an image represented as a single chunk.

In [6]:
hrtem_image = exit_waves.intensity().mean(0)

hrtem_image.array

| | Array | Chunk |
|---|---|---|
| Bytes | 1.00 MiB | 1.00 MiB |
| Shape | (512, 512) | (512, 512) |
| Count | 19 Graph Layers | 1 Chunks |
| Type | float32 | numpy.ndarray |


Taking the mean across frozen phonon chunks requires communicating the exit wave function intensities. Showing the task graph, we see how the branches are merged where the results have to be communicated between workers.

In [7]:
hrtem_image.array.visualize(engine="cytoscape")


## Chunks
To further explore the role of chunks in abTEM, we create the task graph for running a STEM simulation with gold in the $\left<100\right>$ zone axis with 4 frozen phonons. We do not immediately apply a detector; hence, we get an ensemble of exit wave functions.

In [225]:
# Frozen phonon ensemble and potential (same settings as above)
frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=False
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

probe = abtem.Probe(energy=200e3, semiangle_cutoff=20)

# Scan over one fifth of the supercell along x and y
scan = abtem.GridScan.from_fractional_coordinates(
    potential,
    start=(0, 0),
    end=(1 / 5, 1 / 5),
    sampling=1,
)

# No detector is given, so the result is the ensemble of exit wave functions
exit_waves_stem = probe.multislice(potential, scan=scan)

exit_waves_stem.axes_metadata

type               label           coordinates
-----------------  --------------  -------------------
FrozenPhononsAxis  Frozen phonons  -
ScanAxis           x [Å]           0.00 0.82 ... 3.26
ScanAxis           y [Å]           0.00 0.82 ... 3.26
RealSpaceAxis      x [Å]           0.00 0.04 ... 20.36
RealSpaceAxis      y [Å]           0.00 0.04 ... 20.36

The wave functions are represented as a 5D Dask array: a 3D ensemble of 2D wave functions. The ensemble is composed of one phonon dimension and two scan dimensions, one for each of the $x$ and $y$ directions.

The full array has the shape `(4, 5, 5, 512, 512)`, requiring 200 MiB of memory. It is cut into chunks of shape `(1, 5, 5, 512, 512)` of 50 MiB each. Hence, there is a total of 4 chunks, one for each frozen phonon configuration.

In [226]:
exit_waves_stem.array

| | Array | Chunk |
|---|---|---|
| Bytes | 200.00 MiB | 50.00 MiB |
| Shape | (4, 5, 5, 512, 512) | (1, 5, 5, 512, 512) |
| Count | 16 Graph Layers | 4 Chunks |
| Type | complex64 | numpy.ndarray |


We do not make a chunk for every probe position; instead, each chunk of the scan dimension represents a batch of wave functions. This is done partly to limit [the overhead](https://docs.dask.org/en/stable/best-practices.html#avoid-very-large-graphs) that every chunk comes with. More importantly, larger batches enable efficient thread parallelization within each run of the multislice algorithm.
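
As an illustration of how a scan axis could be divided into batches, the sketch below splits an axis into the fewest chunks that respect a maximum size, as evenly as possible. The helper `split_axis` and its per-axis limit are assumptions made for this example; the rule abTEM actually uses to lay out chunks may differ.

```python
def split_axis(n, max_size):
    """Split n positions into the fewest chunks of at most max_size, as evenly as possible.

    Hypothetical helper for illustration; assumes n >= 1 and max_size >= 1.
    """
    n_chunks = -(-n // max_size)  # ceiling division
    base, extra = divmod(n, n_chunks)
    # The first `extra` chunks get one additional position.
    return tuple(base + 1 if i < extra else base for i in range(n_chunks))


num_phonons = 4
scan_shape = (5, 5)

# With room for 8 positions per axis, each scan axis stays in a single chunk.
scan_chunks = [split_axis(n, 8) for n in scan_shape]
total_chunks = num_phonons * len(scan_chunks[0]) * len(scan_chunks[1])

print(scan_chunks)
print(total_chunks)
print(split_axis(15, 8))  # a longer axis is split into near-equal parts
```

Fewer, larger chunks mean a smaller task graph, at the cost of more memory per task.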

We can change how many wave functions each batch should have using the `max_batch` keyword. Below we set `max_batch=32`, which limits each batch to at most 32 wave functions.

In [202]:
exit_waves_stem = probe.multislice(potential, scan=scan, max_batch=32)

exit_waves_stem.array

| | Array | Chunk |
|---|---|---|
| Bytes | 400.00 MiB | 50.00 MiB |
| Shape | (8, 5, 5, 512, 512) | (1, 5, 5, 512, 512) |
| Count | 28 Graph Layers | 8 Chunks |
| Type | complex64 | numpy.ndarray |


The default value of `max_batch` is `"auto"`. With this setting, the number of wave functions in each batch is determined such that the batch represents approximately 128 MB of memory. This number may be changed through the configuration.
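
To get a feeling for what such a memory target means, the sketch below estimates how many `complex64` wave functions fit in a batch of a given size. The function name is hypothetical, and reading the target as $128 \times 2^{20}$ bytes is an assumption; the exact rule applied by abTEM may differ.

```python
def wave_functions_per_batch(gpts, target_bytes, itemsize=8):
    """Estimate how many wave functions fit within target_bytes.

    Hypothetical helper; itemsize=8 corresponds to complex64.
    """
    bytes_per_wave = gpts[0] * gpts[1] * itemsize
    return max(1, target_bytes // bytes_per_wave)


# Assumption: the memory target is interpreted as 128 * 2**20 bytes.
batch = wave_functions_per_batch((512, 512), 128 * 2**20)
print(batch)

# A 5 x 5 scan has 25 positions, so the whole scan fits in one batch.
print(min(5 * 5, batch))
```

Under these assumptions, all 25 probe positions of the scan fit in a single batch, which would be consistent with chunks that span the full scan dimensions.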

Before running `compute`, we apply a HAADF detector and calculate the ensemble mean. This reduces the total size of the output to just 400 B. It is important to stress that the entire ensemble of wave functions never has to be in memory simultaneously. Each chunk of exit wave functions is reduced immediately after completing the multislice algorithm.

In [222]:
detector = abtem.AnnularDetector(inner=65, outer=200)

haadf_images = detector.detect(exit_waves_stem).compute()

[########################################] | 100% Completed | 3.00 s


<abtem.measurements.Images object at 0x7f9894f16a60>
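
The immediate reduction of each chunk can be sketched with a generator in plain Python. This is only a conceptual illustration, not abTEM or Dask code: `exit_wave_chunks` and `detect` are hypothetical stand-ins, and short lists take the place of large arrays.

```python
def exit_wave_chunks(num_chunks, chunk_size):
    """Lazily yield hypothetical chunks; each list stands in for a large array."""
    for i in range(num_chunks):
        yield [float(i + 1)] * chunk_size


def detect(chunk):
    # Stand-in for a detector: reduce a large chunk to a single number.
    return sum(value**2 for value in chunk) / len(chunk)


# Each chunk is reduced as soon as it is produced and can then be discarded,
# so only one chunk needs to be held in memory at a time.
detected = [detect(chunk) for chunk in exit_wave_chunks(num_chunks=4, chunk_size=1000)]

# Only the small per-chunk results are combined at the end.
ensemble_mean = sum(detected) / len(detected)

print(detected)
print(ensemble_mean)
```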

## Schedulers

After generating a task graph, it needs to be executed on (parallel) hardware. This is the job of a [task scheduler](https://docs.dask.org/en/stable/scheduler-overview.html). Dask provides several task schedulers. Each will compute a task graph and give the same result, but with different performance characteristics.

Every time you call the `compute` method a Dask scheduler is used. abTEM adopts the default Dask scheduler configuration and every keyword argument used with the `compute` method in abTEM is forwarded to the Dask `compute` function. 

The default scheduler is the local threaded scheduler, which is based on a [`ThreadPoolExecutor`](https://docs.dask.org/en/stable/scheduling.html#local-threads). Keyword arguments for the scheduler may be passed through the `compute` method. For example, the threaded scheduler takes a `num_workers` keyword, which sets the number of threads to use (defaults to the number of cores).

In [224]:
haadf_images = detector.detect(exit_waves_stem).compute(
    scheduler="threads", num_workers=4
)

[########################################] | 100% Completed | 3.06 s


We can switch to the process-based scheduler, which uses a `ProcessPoolExecutor`, as below.

```python
haadf_images = detector.detect(exit_waves_stem).compute(scheduler="processes", num_workers=4)
```

Using `abtem.config.set` the scheduler can be set either as a context manager or globally.
```python
# As a context manager
with abtem.config.set(scheduler="processes"):
    haadf_images = detector.detect(exit_waves_stem).compute()

# Set globally
abtem.config.set(scheduler="processes")
haadf_images = detector.detect(exit_waves_stem).compute()
```

### The distributed scheduler
You can use the Dask distributed scheduler by just initializing a `Client`. The distributed scheduler is necessary for running on a cluster; however, it may also be used locally on a personal machine. The main benefit of using the distributed scheduler locally is access to its diagnostic tools; see the [Dask documentation](https://docs.dask.org/en/stable/scheduling.html#dask-distributed-local) for details.

In [209]:
from dask.distributed import Client

client = Client()
client

| Client | |
|---|---|
| Connection method | Cluster object |
| Cluster type | distributed.LocalCluster |
| Dashboard | http://127.0.0.1:8787/status |
| Scheduler comm | tcp://127.0.0.1:52971 |
| Workers | 4 |
| Total threads | 8 |
| Total memory | 16.00 GiB |
| Status | running |
| Using processes | True |

| Worker comm | Nanny | Dashboard | Total threads | Memory |
|---|---|---|---|---|
| tcp://127.0.0.1:52986 | tcp://127.0.0.1:52974 | http://127.0.0.1:52991/status | 2 | 4.00 GiB |
| tcp://127.0.0.1:52988 | tcp://127.0.0.1:52977 | http://127.0.0.1:52990/status | 2 | 4.00 GiB |
| tcp://127.0.0.1:52989 | tcp://127.0.0.1:52976 | http://127.0.0.1:52992/status | 2 | 4.00 GiB |
| tcp://127.0.0.1:52987 | tcp://127.0.0.1:52975 | http://127.0.0.1:52993/status | 2 | 4.00 GiB |

In [210]:
client.close()

### Running abTEM on HPC clusters

## Diagnostics

To improve performance we have to be able to profile it. Profiling parallel code can be challenging, but Dask provides functionality to aid in profiling and inspecting execution. The diagnostic tools are quite different depending on whether you use a local or distributed scheduler.

### Local diagnostics

Dask allows local diagnostics by adding callbacks that collect information about your code execution. You can use the profilers as context managers, as described in the Dask documentation. For convenience, the abTEM `compute` methods implement keywords for adding profilers.
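
The idea of collecting timing information around each task can be sketched with a small context manager. The class `TaskProfiler` below is hypothetical and much simpler than the Dask profilers; it only records a start and end time for every tracked task.

```python
import time
from contextlib import contextmanager


class TaskProfiler:
    """Hypothetical minimal profiler: records (name, start, end) per task."""

    def __init__(self):
        self.records = []

    @contextmanager
    def track(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Store the timing even if the task raises an exception.
            self.records.append((name, start, time.perf_counter()))


profiler = TaskProfiler()

# Placeholder tasks; the names only mirror the steps of a simulation.
for name in ["potential", "multislice", "detect"]:
    with profiler.track(name):
        sum(i * i for i in range(10_000))  # placeholder work

for name, start, end in profiler.records:
    print(f"{name}: {(end - start) * 1e3:.3f} ms")
```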

Below we use the `Profiler` to monitor task execution by setting `profiler=True` and a `ResourceProfiler` to monitor the CPU usage and memory consumption by setting `resource_profiler=True`. We rerun the simulation above with these profilers. 

In [216]:
haadf_images, profilers = detector.detect(exit_waves_stem).compute(
    profiler=True, resource_profiler=True
)

[########################################] | 100% Completed | 2.87 s


The Dask profilers are visualized using the plotting library `bokeh`. We need to run the commands below to show the plots in the notebook.

In [217]:
from bokeh.io import output_notebook

output_notebook()

In [220]:
profilers[0].visualize();

We ran this code on an Intel Core i9 with four cores and two threads per core. Dask shows 


We should not expect $800 \ \%$ , however $400$

This is not because we did not utilize every thread; rather, we utilized each individual thread at less than $100 \ \%$. This is largely because the size  



In [221]:
profilers[1].visualize();

### Distributed diagnostics

The diagnostic tools of the distributed scheduler are even more powerful and may even be a reason to consider using the distributed scheduler locally. The distributed scheduler provides the Dask dashboard, which for the local client created above is served at http://127.0.0.1:8787/status.

## Using GPUs

Almost every part of abTEM can be accelerated using a GPU through the [CuPy](https://cupy.dev/) library. We have only tested abTEM on CUDA-compatible GPUs; however, any GPU compatible with CuPy should work.

If you have a compatible GPU and a working installation of CuPy, you can accelerate your image simulations by simply changing the `device` setting at the top of your document. The cell below sets the device to `"cpu"`; set it to `"gpu"` to run on the GPU instead:

In [19]:
abtem.config.set({"device": "cpu"});