In [2]:
import abtem
import ase
import ase.cluster  # `import ase` alone does not load the cluster subpackage used below
import dask
import matplotlib.pyplot as plt
import numpy as np
import scrapbook as sb
from abtem.core import config

In [3]:
# Benchmark parameters.
# NOTE(review): these look like papermill parameters (results are recorded with
# scrapbook) — keep one plain assignment per line so they can be overridden;
# confirm this cell carries the "parameters" tag.
gpts = 1024        # grid points per side of the square simulation grid
sampling = 0.05    # real-space sampling per grid point (presumably Å)
num_configs = 8    # number of frozen-phonon configurations
threads = 1        # thread count handed to the FFT backend via config.set

In [4]:
# Side length of the square simulation cell (grid points x sampling).
extent = gpts * sampling

# Decahedron size parameters grow linearly with the cell extent.
# NOTE(review): 12 / 60.2762953922335 is an unexplained empirical ratio —
# presumably chosen so that p = 12 fits a ~60.28 Å cell; confirm.
slope = 12 / 60.2762953922335
p = int(np.floor(extent * slope))
q = max(1, p // 5)

# Gold decahedral cluster, tilted 30 degrees about x.
atoms = ase.cluster.Decahedron("Au", p, q, 0)
atoms.rotate("x", 30)

# Square lateral cell of side `extent`; the z extent is set by the
# vacuum-padded centering that follows.
atoms.cell[1, 1] = extent
atoms.cell[0, 0] = extent

atoms.center()                   # place the cluster in the middle of the cell
atoms.center(axis=2, vacuum=4)   # redo z so the cluster gets 4 Å of vacuum on each side

In [18]:
# Frozen-phonon ensemble: `num_configs` thermally displaced copies of the
# cluster (displacement spread `sigmas=0.1`, presumably Å — confirm in the abTEM docs).
frozen_phonons = abtem.FrozenPhonons(atoms, num_configs, sigmas=0.1)

# NOTE(review): the potential is built from the static `atoms` unless this is
# True, so `num_configs` only takes effect when frozen phonons are enabled.
# Defaults to False to keep array shapes and recorded timings unchanged.
use_frozen_phonons = False

potential = abtem.Potential(
    frozen_phonons if use_frozen_phonons else atoms,
    gpts=gpts,
    projection="infinite",
    slice_thickness=2,
)

# Probe series: 200 keV, 30 mrad semi-angle cutoff and 10 defocus values
# spanning 0-100 (abTEM units: eV, mrad, Å — confirm).
wave = abtem.Probe(energy=200e3, semiangle_cutoff=30, defocus=np.linspace(0, 100, 10))
# Give the probe the potential's grid so the two are compatible for multislice.
wave.grid.match(potential)

In [6]:
# Load memray's IPython extension; it provides the %%memray_flamegraph
# cell magic used by the memory-profiling cell further down.
%load_ext memray

In [19]:
# Eagerly build the probe ensemble (one probe per defocus value) and check it
# landed on the potential's gpts x gpts grid before showing its shape.
waves = wave.build(lazy=False)
assert waves.shape[-2:] == (gpts, gpts), f"unexpected probe grid {waves.shape}"

waves.shape

(10, 1024, 1024)

In [23]:
# Lazy counterpart of the eager build above: only a dask graph is created,
# so the repr reports size/chunking without computing the probes.
lazy_waves = wave.build(lazy=True)
lazy_waves.array

Unnamed: 0,Array,Chunk
Bytes,80.00 MiB,80.00 MiB
Shape,"(10, 1024, 1024)","(10, 1024, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,complex64 numpy.ndarray,complex64 numpy.ndarray
"Array Chunk Bytes 80.00 MiB 80.00 MiB Shape (10, 1024, 1024) (10, 1024, 1024) Dask graph 1 chunks in 5 graph layers Data type complex64 numpy.ndarray",1024  1024  10,

Unnamed: 0,Array,Chunk
Bytes,80.00 MiB,80.00 MiB
Shape,"(10, 1024, 1024)","(10, 1024, 1024)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,complex64 numpy.ndarray,complex64 numpy.ndarray


In [29]:
%%memray_flamegraph
with config.set({"fft": "mkl", "fftw.threads": threads}):
    wave.build(lazy=False).multislice(potential)

#with config.set({"fft": "fftw", "fftw.threads": threads}):
#    exit_wave.compute()

Output()

Output()

In [6]:
%%timeit -n1 -r1 -o

exit_wave = wave.multislice(potential)

with config.set({"fft": "fftw", "fftw.threads": threads}):
    exit_wave.compute()

[########################################] | 100% Completed | 1.25 ss
1.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<TimeitResult : 1.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>

In [7]:
# Publish the FFTW timing statistics with scrapbook.
# `_` is the TimeitResult returned by the `%%timeit -o` cell directly above;
# capture it once and check its type so that running another cell in between
# fails loudly instead of gluing unrelated values.
fftw_timing = _
assert type(fftw_timing).__name__ == "TimeitResult", (
    "Run the FFTW %%timeit cell immediately before this cell."
)

for stat in ("average", "stdev", "best", "worst"):
    sb.glue(f"fftw.{stat}", getattr(fftw_timing, stat))

In [8]:
%%timeit -n1 -r1 -o

exit_wave = wave.multislice(potential)

with config.set({"fft": "mkl", "mkl.threads": threads}):
    exit_wave.compute()

[########################################] | 100% Completed | 1.62 ss
1.66 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<TimeitResult : 1.66 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>

In [9]:
# Publish the MKL timing statistics with scrapbook.
# `_` is the TimeitResult returned by the `%%timeit -o` cell directly above;
# capture it once and check its type so that running another cell in between
# fails loudly instead of gluing unrelated values.
mkl_timing = _
assert type(mkl_timing).__name__ == "TimeitResult", (
    "Run the MKL %%timeit cell immediately before this cell."
)

for stat in ("average", "stdev", "best", "worst"):
    sb.glue(f"mkl.{stat}", getattr(mkl_timing, stat))