In [1]:
import abtem
import ase.build
import dask
from IPython.display import Image, display 

dask.config.set({"array.svg.size": 90});

(walkthrough:parallelization)=
# Parallelization

The computational cost of multislice simulations can grow large, depending on the number of probe positions, frozen phonon configurations and many other factors. This cost can be mitigated through parallelism. Much of the work in abTEM is [embarrassingly parallel](https://en.wikipedia.org/wiki/Embarrassingly_parallel): for example, every probe position is independent, so each CPU core may calculate a batch of positions on its own, only requiring communication after finishing a run of the multislice algorithm.

abTEM is parallelized using [Dask](https://www.dask.org/){cite}`dask`. Dask allows scaling from a single laptop to hundreds of nodes at high-performance computing (HPC) facilities with minimal changes to the code. 

In this document, we introduce how abTEM uses Dask. This is not required knowledge for running abTEM on a single machine; nonetheless, it may help you optimize your simulations. If you are already an experienced Dask user, most of what you know applies directly to abTEM. If you are new to Dask, you may benefit from watching [this introduction](https://www.youtube.com/watch?v=nnndxbr_Xq4) before continuing. Dask is also used in several other libraries for the analysis of electron microscopy data, for example, [HyperSpy](https://hyperspy.org/), [LiberTEM](https://libertem.github.io/LiberTEM/) and [py4DSTEM](https://py4dstem.readthedocs.io/en/latest/); hence, knowing this library may benefit you more generally.

## Task graphs

Simulating TEM experiments requires executing multiple tasks, where each task may depend on the output of previous tasks. In Dask this is represented as a [*task graph*](https://docs.dask.org/en/stable/graphs.html), where each task is a node and an edge connects two nodes when one task depends on the output of the other. The simulation result is obtained by executing each task (node) in the graph with a Dask scheduler, on a single machine or a cluster.
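To make the idea concrete, the sketch below builds a tiny task graph by hand with `dask.delayed`. It mimics the shape of the frozen phonon graph discussed below: several independent branches followed by one reduction. The functions `run_multislice` and `ensemble_mean` are hypothetical stand-ins that return plain numbers, not abTEM's internal tasks.

```python
import dask

# Hypothetical stand-ins for abTEM tasks; they only return numbers
@dask.delayed
def run_multislice(config):
    # Pretend each frozen phonon configuration yields one intensity value
    return float(config) ** 2

@dask.delayed
def ensemble_mean(values):
    # Reduction task that depends on every branch
    return sum(values) / len(values)

# Four independent branches (one per configuration) merged by a single task
branches = [run_multislice(i) for i in range(4)]
mean_intensity = ensemble_mean(branches)

# Nothing has been executed yet; compute() hands the graph to a scheduler
result = mean_intensity.compute(scheduler="threads")
print(result)  # 3.5
```

Calling `mean_intensity.visualize()` (which needs Graphviz) would draw the four branches joining in the reduction node.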

In [2]:
display(
    Image(url="https://docs.dask.org/en/stable/_images/dask-overview.svg", width=600)
)

Below we create the task graph for running a multislice simulation using plane waves with gold in the $\left<100\right>$ zone axis with 4 frozen phonons. We refer to previous walkthroughs for details. 

In [3]:
atoms = ase.build.bulk("Au", cubic=True) * (5, 5, 2)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=False
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

probe = abtem.PlaneWave(energy=200e3)

exit_waves = probe.multislice(potential)

The result is an ensemble of $4$ wave functions of shape $512\times512$, which may be represented as a 3D array, where the first dimension represents the frozen phonon ensemble and the last $2$ dimensions represent the 2D wave functions.

As we have not executed the task graph yet, the wave functions are represented as a [Dask Array](https://docs.dask.org/en/stable/array.html). We can think of the Dask array as being composed of many smaller NumPy arrays, called *chunks*, and operations may be applied to each chunk rather than the full array. This enables: 

1. Parallelism over the chunks 
2. Representing a larger-than-memory array as many smaller arrays, each of which fits in memory
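As a minimal sketch of the chunk layout, independent of abTEM, the snippet below creates a lazy Dask array with the same shape, dtype and chunking as the plane-wave ensemble. It uses zeros as placeholder data; only the layout is meant to match.

```python
import numpy as np
import dask.array as da

# 4 frozen phonon configurations of 512 x 512 complex64 wave functions,
# one configuration per chunk; nothing is allocated yet
waves = da.zeros((4, 512, 512), dtype=np.complex64, chunks=(1, 512, 512))

print(waves.numblocks)       # (4, 1, 1): four chunks along the phonon axis
print(waves.chunksize)       # (1, 512, 512)
print(waves.nbytes / 2**20)  # 8.0 MiB for the full array

# Computing a single chunk gives back an ordinary NumPy array
first_chunk = waves.blocks[0].compute()
print(type(first_chunk), first_chunk.shape)
```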

The Dask array `__repr__` shows how the chunks are laid out.

In [4]:
exit_waves.array

|       | Array           | Chunk         |
|-------|-----------------|---------------|
| Bytes | 8.00 MiB        | 2.00 MiB      |
| Shape | (4, 512, 512)   | (1, 512, 512) |
| Count | 15 Graph Layers | 4 Chunks      |
| Type  | complex64       | numpy.ndarray |


We see that the Dask array has the shape `(4, 512, 512)`, requiring 8 MiB of memory, and is composed of chunks with shape `(1, 512, 512)`, requiring 2 MiB each. We stress that the Dask array only represents a task graph; memory is consumed only when it is computed.

Each chunk of the Dask array created above represents the wave function for one frozen phonon configuration. This reflects that, in the multislice algorithm, each frozen phonon configuration is independent and may be calculated in parallel. On the other hand, a single wave function should not be split across several chunks, because every part of the wave function is affected by every other part (for example, through the Fourier transforms used to propagate it between slices).
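Dask itself enforces a version of this rule for Fourier transforms. In the sketch below (plain Dask, a reduced $64\times64$ grid, ones as placeholder data), `da.fft.fft2` works when each wave function sits in a single chunk, but raises a `ValueError` once the transformed axes are split into several chunks.

```python
import numpy as np
import dask.array as da

# Chunked along the frozen phonon axis only, as in the ensemble above
by_phonon = da.ones((4, 64, 64), dtype=np.complex64, chunks=(1, 64, 64))

# Same data, but each wave function split into 2 x 2 spatial chunks
by_space = by_phonon.rechunk((1, 32, 32))

# The 2D FFT couples every pixel of a wave function, so Dask only allows it
# along axes that are held in a single chunk
spectrum = da.fft.fft2(by_phonon)  # fine: one chunk per wave function

try:
    da.fft.fft2(by_space)
except ValueError as err:
    print("ValueError:", err)
```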

We can visualize the task graph using Dask's [`visualize`](https://docs.dask.org/en/stable/graphviz.html) method. The task graph consists of 4 fully independent branches, one for each frozen phonon configuration.

````{note}
Drawing Dask task graphs with the cytoscape engine requires the `ipycytoscape` python library. To reproduce the result below you need to:
```
  python -m pip install ipycytoscape
```
and restart Jupyter.
````

In [5]:
exit_waves.array.visualize(engine="cytoscape")


We usually take the mean across the frozen phonon dimension; thus, we end up with an image represented as a single chunk.
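A plain-Dask sketch of the same reduction, using ones as placeholder exit waves: the intensity keeps one chunk per configuration, while the mean over the phonon axis collapses the four chunks into a single one.

```python
import numpy as np
import dask.array as da

# Placeholder exit waves: one chunk per frozen phonon configuration
waves = da.ones((4, 512, 512), dtype=np.complex64, chunks=(1, 512, 512))

# Intensity per configuration, then the mean over the phonon axis
intensity = abs(waves) ** 2
image = intensity.mean(axis=0)

print(intensity.numblocks)  # (4, 1, 1)
print(image.numblocks)      # (1, 1): the reduction merged the four branches
```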

In [6]:
hrtem_image = exit_waves.intensity().mean(0)

hrtem_image.array

|       | Array           | Chunk         |
|-------|-----------------|---------------|
| Bytes | 1.00 MiB        | 1.00 MiB      |
| Shape | (512, 512)      | (512, 512)    |
| Count | 19 Graph Layers | 1 Chunks      |
| Type  | float32         | numpy.ndarray |


Taking the mean across the frozen phonon chunks requires communicating the exit wave intensities. The task graph shows how the branches are merged where results have to be communicated between workers.

In [7]:
hrtem_image.array.visualize(engine="cytoscape")


## Chunks
To further explore the role of chunks in abTEM, we create the task graph for running a STEM simulation with gold in the $\left<100\right>$ zone axis with 4 frozen phonon configurations. We do not immediately apply a detector; hence, we obtain an ensemble of exit wave functions.

In [8]:
probe = abtem.Probe(energy=200e3, semiangle_cutoff=20)

scan = abtem.GridScan.from_fractional_coordinates(
    potential,
    start=(0, 0),
    end=(1 / 5, 1 / 5),
)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=True
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

exit_waves_stem = probe.multislice(potential, scan=scan)

exit_waves_stem.axes_metadata

type               label           coordinates
-----------------  --------------  -------------------
FrozenPhononsAxis  Frozen phonons  -
ScanAxis           x [Å]           0.00 0.27 ... 3.81
ScanAxis           y [Å]           0.00 0.27 ... 3.81
RealSpaceAxis      x [Å]           0.00 0.04 ... 20.36
RealSpaceAxis      y [Å]           0.00 0.04 ... 20.36

The wave functions are represented as a 5D Dask array: a 3D ensemble of 2D wave functions. The ensemble is composed of one frozen phonon dimension and two scan dimensions, one for each of the $x$ and $y$ directions.

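The full array has shape `(4, 15, 15, 512, 512)`, requiring 1.76 GiB of memory, and is cut into chunks of shape `(1, 8, 7, 512, 512)` of 112 MiB each. The 15 positions of the first scan dimension are split into 2 chunks of at most 8, and those of the second into 3 chunks of at most 7; hence, there is a total of $4 \times 2 \times 3 = 24$ chunks.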
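The chunk arithmetic can be checked with a lazy plain-Dask array of the same shape. This sketch assumes Dask's regular splitting (15 into `(8, 7)` and `(7, 7, 1)`); abTEM may distribute the 15 positions differently, although the chunk count and largest chunk match the table below.

```python
import numpy as np
import dask.array as da

# Lazy stand-in for the STEM exit-wave ensemble; only metadata is created
shape = (4, 15, 15, 512, 512)
waves = da.zeros(shape, dtype=np.complex64, chunks=(1, 8, 7, 512, 512))

print(waves.chunks[1])       # (8, 7)
print(waves.chunks[2])       # (7, 7, 1)
print(waves.numblocks)       # (4, 2, 3, 1, 1)
print(waves.npartitions)     # 24
print(waves.nbytes / 2**30)  # about 1.76 GiB

# Largest chunk: 8 x 7 wave functions of 512 x 512 complex64 (8 bytes each)
largest_chunk_mib = 8 * 7 * 512 * 512 * 8 / 2**20
print(largest_chunk_mib)     # 112.0
```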

In [9]:
exit_waves_stem.array

|       | Array                 | Chunk               |
|-------|-----------------------|---------------------|
| Bytes | 1.76 GiB              | 112.00 MiB          |
| Shape | (4, 15, 15, 512, 512) | (1, 8, 7, 512, 512) |
| Count | 17 Graph Layers       | 24 Chunks           |
| Type  | complex64             | numpy.ndarray       |


We do not make a chunk for every probe position; instead, each chunk of the scan dimensions represents a batch of wave functions. This is done partly to limit [the overhead](https://docs.dask.org/en/stable/best-practices.html#avoid-very-large-graphs) that comes with every chunk; more importantly, larger batches enable efficient thread parallelization within each run of the multislice algorithm.
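Why batching helps can be sketched in NumPy: a batch of wave functions can be propagated with one FFT call over a stacked array, instead of one call per wave function, and the result is the same. The propagator phase and grid size here are hypothetical choices for illustration, not abTEM's propagator. NumPy's own FFT is single-threaded, so this sketch only shows the equivalence; the speedup comes from multithreaded backends such as FFTW or MKL, which can spread one large batched call across threads.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch of 16 probe wave functions on a 128 x 128 grid
batch = (rng.standard_normal((16, 128, 128))
         + 1j * rng.standard_normal((16, 128, 128))).astype(np.complex64)

# Hypothetical Fresnel-like propagator phase, shared by the whole batch
k = np.fft.fftfreq(128)
k2 = k[:, None] ** 2 + k[None, :] ** 2
propagator = np.exp(-1j * np.pi * 50.0 * k2).astype(np.complex64)

def propagate_one_by_one(waves):
    # One forward/inverse FFT pair per wave function
    return np.stack([np.fft.ifft2(np.fft.fft2(w) * propagator) for w in waves])

def propagate_batched(waves):
    # One FFT call over the whole stack (fft2 acts on the last two axes);
    # the propagator broadcasts over the batch axis
    return np.fft.ifft2(np.fft.fft2(waves) * propagator)

max_difference = np.abs(propagate_one_by_one(batch) - propagate_batched(batch)).max()
print(max_difference)  # tiny, limited by floating-point rounding
```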

We can change the number of wave functions in each batch using the `max_batch` keyword. Below we set `max_batch=4`, giving batches of $2 \times 2$ scan positions; each 15-position scan dimension is thus split into 8 chunks, resulting in a total of $4 \times 8 \times 8 = 256$ chunks.

In [10]:
exit_waves_stem = probe.multislice(potential, scan=scan, max_batch=4)

exit_waves_stem.array

|       | Array                 | Chunk               |
|-------|-----------------------|---------------------|
| Bytes | 1.76 GiB              | 8.00 MiB            |
| Shape | (4, 15, 15, 512, 512) | (1, 2, 2, 512, 512) |
| Count | 16 Graph Layers       | 256 Chunks          |
| Type  | complex64             | numpy.ndarray       |


The default value of `max_batch` is `"auto"`. With this setting, the number of wave functions in each batch is chosen such that each batch represents approximately 128 MB of memory; this number may be changed through the configuration.
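The arithmetic behind such a memory budget is simple. The helper below is hypothetical (it is not abTEM's batching code) and assumes the budget is 128 MiB and that wave functions are stored as `complex64` (8 bytes per grid point). It gives 64 wave functions for a $512\times512$ grid, consistent with the chunks of $8 \times 7 = 56$ wave functions (112 MiB) seen above.

```python
import numpy as np

def waves_per_batch(gpts, budget_bytes=128 * 2**20, dtype=np.complex64):
    """Hypothetical helper: how many gpts x gpts wave functions fit in a budget."""
    bytes_per_wave = gpts * gpts * np.dtype(dtype).itemsize
    return budget_bytes // bytes_per_wave

print(waves_per_batch(512))             # 64 (each wave function is 2 MiB)
print(waves_per_batch(256))             # 256
print(waves_per_batch(512, 8 * 2**20))  # 4, the same batch size as max_batch=4
```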

Before computing, we apply a HAADF detector and calculate the ensemble mean, which reduces the total size of the output to just 900 B. We note that the 1.76 GiB ensemble of wave functions never needs to be in memory simultaneously: each chunk of exit wave functions is reduced immediately after completing the multislice algorithm.
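The reduction can be mimicked in plain Dask with a toy model. In this sketch the exit waves are random numbers on a reduced $64\times64$ grid, and the annular mask uses hypothetical radii in pixels rather than scattering angles; it is not abTEM's `AnnularDetector` implementation. Each chunk is transformed, integrated over the annulus and averaged over the phonon axis, so the final result is a $15\times15$ `float32` array of 900 B split into 64 chunks, like the table below.

```python
import numpy as np
import dask.array as da

n = 64  # reduced grid (the text uses 512) to keep the sketch light

# Toy exit-wave ensemble: (phonons, scan x, scan y, ny, nx), 2 x 2 positions per chunk
waves = da.random.random((4, 15, 15, n, n), chunks=(1, 2, 2, n, n)).astype(np.complex64)

# Hypothetical annular mask in reciprocal space, radii in pixels
k = np.fft.fftfreq(n) * n
kr = np.hypot(k[:, None], k[None, :])
mask = (kr >= 10) & (kr < 30)

# Per chunk: diffraction intensities, integrated over the annulus -> (4, 15, 15)
signal = (abs(da.fft.fft2(waves)) ** 2 * mask).sum(axis=(-2, -1))

# Mean over the frozen phonon axis, stored as float32
haadf = signal.mean(axis=0).astype(np.float32)

print(haadf.shape, haadf.nbytes, haadf.numblocks)  # (15, 15) 900 (8, 8)
values = haadf.compute()
```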

In [11]:
detector = abtem.AnnularDetector(inner=65, outer=200)

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.array

|       | Array           | Chunk         |
|-------|-----------------|---------------|
| Bytes | 900 B           | 16 B          |
| Shape | (15, 15)        | (2, 2)        |
| Count | 21 Graph Layers | 64 Chunks     |
| Type  | float32         | numpy.ndarray |


## Schedulers

After generating a task graph, it needs to be executed on (parallel) hardware. This is the job of a [task scheduler](https://docs.dask.org/en/stable/scheduler-overview.html). Dask provides several task schedulers, each will compute a task graph and give the same result, but with different performance characteristics.

Every time you call the `compute` method, a Dask scheduler is used. abTEM adopts the default Dask scheduler configuration, and every keyword argument given to the `compute` method in abTEM is forwarded to the Dask `compute` function.

### Local scheduler

The default scheduler is the threaded scheduler, which is based on a [`ThreadPoolExecutor`](https://docs.dask.org/en/stable/scheduling.html#local-threads). Keyword arguments for the scheduler may be passed through the `compute` method. For example, the threaded scheduler takes a `num_workers` keyword, which sets the number of threads to use (defaults to the number of cores).
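Since abTEM forwards these keywords to Dask, the same options can be tried on a plain Dask array. A minimal sketch: the threaded scheduler with an explicit thread count, the single-threaded `"synchronous"` scheduler (useful for debugging), and `dask.compute` evaluating several collections in one pass.

```python
import dask
import dask.array as da

x = da.arange(1_000_000, chunks=100_000)
total = x.sum()

# Threaded scheduler (the default for Dask arrays) with an explicit thread count
threaded = total.compute(scheduler="threads", num_workers=2)

# Single-threaded scheduler, handy for debugging
synchronous = total.compute(scheduler="synchronous")

# dask.compute accepts the same keywords and shares work between collections
mean, maximum = dask.compute(x.mean(), x.max(), scheduler="threads", num_workers=2)

print(threaded, synchronous, mean, maximum)
```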

In [28]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.compute(
    scheduler="threads", num_workers=8
)


We can switch to the scheduler based on a `ProcessPoolExecutor` as below.

```python
haadf_images = detector.detect(exit_waves_stem).compute(scheduler="processes", num_workers=4)
```

Using `abtem.config.set` the scheduler can be set either as a context manager or globally.
```python
# As a context manager
with abtem.config.set(scheduler="processes"):
    haadf_images.compute()

# Set globally
abtem.config.set(scheduler="processes")
haadf_images.compute()
```
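We assume `abtem.config.set` mirrors Dask's own configuration system here; the plain-Dask version of the same pattern is sketched below. Inside the `with` block the scheduler setting applies, and it is restored when the block exits.

```python
import dask
import dask.array as da

x = da.ones((1000, 1000), chunks=(250, 250))

# As a context manager: only computations inside the block are affected
with dask.config.set(scheduler="synchronous"):
    inside = dask.config.get("scheduler")
    total = x.sum().compute()

# Outside the block the previous setting is restored
outside = dask.config.get("scheduler", None)

# Without "with", e.g. dask.config.set(scheduler="threads"), the setting is global
print(inside, outside, total)
```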

To improve performance, we need to be able to profile it. Profiling parallel code can be challenging, but Dask provides functionality to aid in profiling and inspecting execution. The diagnostic tools are quite different depending on whether you use a local or distributed scheduler.

Dask allows local diagnostics by adding callbacks that collect information about your code execution. You can use the profilers as context managers, as described in the [Dask documentation](https://docs.dask.org/en/stable/diagnostics-local.html). For convenience, the abTEM `compute` methods implement keywords for adding Dask profilers.

Below we use the [`Profiler`](https://docs.dask.org/en/stable/diagnostics-local.html#profiler) to monitor task execution by setting `profiler=True` and a [`ResourceProfiler`](https://docs.dask.org/en/stable/diagnostics-local.html#resourceprofiler) to monitor the CPU usage and memory consumption by setting `resource_profiler=True`. We rerun the simulation above with these profilers. 
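For reference, this is roughly what the context-manager form looks like in plain Dask, applied to a random matrix product instead of an abTEM simulation. The sketch uses only the `Profiler`; a `ResourceProfiler` can be stacked in the same `with` statement, but it additionally requires the `psutil` package.

```python
import dask.array as da
from dask.diagnostics import Profiler

x = da.random.random((1000, 1000), chunks=(250, 250))

# Record every task executed by the local (threaded) scheduler
with Profiler() as prof:
    result = (x @ x.T).mean().compute(scheduler="threads")

# One record per task, with fields key, task, start_time, end_time and worker_id
first = prof.results[0]
print(len(prof.results), first.worker_id)

# prof.visualize() would draw the task timeline with Bokeh (not called here)
```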

In [24]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images, profilers = haadf_images.compute(profiler=True, resource_profiler=True)

To display the results, Dask uses the Bokeh plotting library. Bokeh is an optional Dask dependency; it is included, for example, with `pip install "dask[diagnostics]"`. To display the plots in a Jupyter notebook, we need to run the commands below.

In [21]:
from bokeh.io import output_notebook
output_notebook()

We first show the result from the `Profiler` object. It shows the execution time of each task as a rectangle, organized along the y-axis by worker (in this case, threads); white space represents idle time. The task types are grouped by color and, by hovering over each task, one can see the key and task that each block represents. For this calculation, there is only one significant task, shown in yellow; it encompasses building the wave function, running the multislice algorithm (calculating the potential on demand), and calculating and integrating the diffraction patterns.

In [25]:
profilers[0].visualize();

The result from the `ResourceProfiler` is shown below: This shows two lines, one for total CPU percentage used by all the workers, and one for total memory usage. The CPU usage is scaled so each worker contributes up to $100 \ \%$, i.e. two fully utilized workers use $200 \ \%$.

In [26]:
profilers[1].visualize();

We ran the calculation on a single computer with a 4-core CPU with 2 threads per core, for 8 threads in total. Our peak CPU usage was $\sim 600 \ \%$, which is fairly typical for a single machine; overhead and background processes prevented us from reaching the $800 \ \%$ corresponding to using every available thread maximally.

We also see that the total memory use reached around 800 MB. If you are running the calculation on a system with more threads, your memory consumption may be larger; it may even exceed the total memory cost of all the wave functions. This is because every parallel run of the multislice algorithm requires significant memory for intermediate results (such as potential slices and Fresnel propagators). If your calculation runs out of memory, you can lower the number of workers, trading computational speed for lower memory consumption.

We note that the overhead in both CPU usage and memory becomes less significant for larger simulations on more powerful hardware.

### The distributed scheduler (locally)
The Dask distributed scheduler is necessary for running your simulation on a cluster; however, it also runs [locally on a personal machine](https://docs.dask.org/en/stable/scheduling.html#dask-distributed-local). You can find details in the Dask documentation; we demonstrate the basics below.

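You can use the Dask distributed scheduler by simply initializing a Dask `Client`. The `Client` takes keyword arguments such as `n_workers`, the number of workers, each of which may run several threads (note that this is different from the `num_workers` keyword of the local schedulers).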

In [18]:
from dask.distributed import Client

client = Client(n_workers=6)
client

| Client            |                                                |
|-------------------|------------------------------------------------|
| Connection method | Cluster object (`distributed.LocalCluster`)    |
| Dashboard         | http://127.0.0.1:55126/status                  |
| Scheduler comm    | tcp://127.0.0.1:55127                          |
| Workers           | 6 (2 threads and 2.67 GiB of memory each)      |
| Total threads     | 12                                             |
| Total memory      | 16.00 GiB                                      |
| Status            | running                                        |
| Using processes   | True                                           |


After initializing the client object, any abTEM computation will use the Dask distributed scheduler.

In [20]:
haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()
haadf_images.compute()

<abtem.measurements.Images object at 0x7f9b9022bca0>

A benefit of using the distributed scheduler on a single machine is the live diagnostic dashboard. You can access it through the link shown in the `__repr__` of the `Client` above; for details, see [this](https://www.youtube.com/watch?v=N_GqzcuGLCY) video walkthrough. If you are using JupyterLab, the [Dask labextension](https://github.com/dask/dask-labextension) provides the same information as panels inside the JupyterLab editor.

We can get back to the local scheduler by closing the `Client`.

In [29]:
client.close()
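Alternatively, the `Client` can be used as a context manager, which closes it automatically. The sketch below assumes the `distributed` package is installed; `processes=False` (threads in the current process) and `dashboard_address=None` (no dashboard) are choices made here only to keep it lightweight, and a plain Dask array stands in for an abTEM computation.

```python
import dask.array as da
from dask.distributed import Client

x = da.ones((1000, 1000), chunks=(250, 250))

# Small in-process cluster: 1 worker with 2 threads and no dashboard
with Client(processes=False, n_workers=1, threads_per_worker=2,
            dashboard_address=None) as client:
    # While the client is open, it is the default scheduler for compute()
    distributed_total = x.sum().compute()

# Leaving the block closes the client; Dask falls back to the local scheduler
local_total = x.sum().compute()
print(distributed_total, local_total)
```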

### Running abTEM on HPC clusters

Dask (and thus abTEM) has robust tools for deployment on high-performance compute clusters. We recommend consulting your HPC provider on how to deploy Dask applications on your available cluster. For general advice on deployment see the [Dask documentation](https://docs.dask.org/en/stable/deploying.html).

#### Submitting job scripts 

As an overview, Dask provides a number of different cluster managers, so you can use Dask distributed with a range of platforms. These cluster managers deploy a scheduler and the necessary workers as determined by communicating with the resource manager. All cluster managers follow the same interface but have platform specific configuration options.

For example for deployment using SLURM, your script might look something like:

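```python
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Resources of a single SLURM job; queue and account names are site specific
cluster = SLURMCluster(
    queue="regular",
    account="myaccount",
    cores=32,
    memory="128 GB",
)

# No workers start until you scale; here, e.g., submit 2 SLURM jobs
cluster.scale(jobs=2)

client = Client(cluster)

# Your abTEM code goes here
```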

Dask also supports deployment from within an existing MPI environment, such as one created with the common MPI command-line launcher `mpirun`; see [here](http://mpi.dask.org/en/latest/) for more information. You can turn your batch Python script into an MPI executable with the `dask_mpi.initialize` function.

```python
from dask_mpi import initialize
initialize()

from dask.distributed import Client
client = Client()  # Connect this local process to remote workers

# Your abTEM code goes here
```

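This makes your Python script launchable directly with `mpirun`. With `dask_mpi.initialize`, MPI rank 0 runs the Dask scheduler, rank 1 runs your client script and the remaining ranks become workers, so the command below starts two workers.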
```
mpirun -np 4 python my_client_script.py
```

## Using GPUs

Almost every part of abTEM can be accelerated using a GPU through the [CuPy](https://cupy.dev/) library. We have only tested abTEM on CUDA-compatible GPUs; however, any GPU compatible with CuPy should work.

If you have a compatible GPU and a working installation of CuPy, you can accelerate your image simulations by simply changing the configuration at the top of your document as below:

In [22]:
abtem.config.set({"device": "gpu"});

Note (perhaps obviously) that Dask does not manage GPU threads. This makes the choice of batch size (i.e. the number of wave functions propagated together in a single batch) extremely important for fully utilizing your GPU. The default batch size in GPU calculations in abTEM is 512 MB, 4 times larger than the CPU batch size; however, if your GPU has 8 GB of memory or more, you will likely be able to squeeze out more performance by increasing this number to at least 2048 MB. Note that the batch size only limits the memory taken up by the wave functions in a batch, hence you need to leave room in memory for intermediate results.

While the above is enough for running abTEM on a single GPU, we recommend installing `dask_cuda` if you are using an NVIDIA GPU; it is necessary for multi-GPU calculations.

In [None]:
from dask_cuda import LocalCUDACluster
from dask.distributed import Client

cluster = LocalCUDACluster()
client = Client(cluster)

We note that, by default, abTEM sets the FFT plan cache size of CuPy to zero, since we find that in most cases the memory consumed by the plans is not worth the small speedup they provide. You can change this through the abTEM config.

## Performance tips

If your performance is unexpectedly low, you can try the following.

### Running out of memory?



### Optimize your simulation parameters

The most effective way of speeding up your simulation is by optimizing your simulation parameters.

### Use PRISM

### Change the FFT backend

The Fast Fourier Transform (FFT) is the most important algorithm determining the speed, hence, ensuring that this is 

abTEM supports two different FFT libraries: the open-source [FFTW](https://www.fftw.org/) and Intel's MKL FFT implementation.

We have found MKL to be faster, hence it is the default in abTEM; however, it may be worth trying FFTW.

You can change 

### Change the batch size

### Use power of 2 `gpts`

