In [1]:
# line_profiler supplies the %lprun magic used for the line-by-line profiling run further down
%load_ext line_profiler
# autoreload supplies the %autoreload magic that is invoked before the binned UDF is built
%load_ext autoreload

# Single Side Band ptychography example

This example uses an implementation of the single side band method
https://doi.org/10.1016/j.ultramic.2014.09.013 to reconstruct an amplitude and phase
image from a simulated 4D scanning transmission electron microscopy dataset with an additional synthetic
potential modulation.

The dataset can be downloaded at https://zenodo.org/record/5113235.

In [2]:
import os
import functools

import matplotlib.pyplot as plt
import libertem.api as lt
import numpy as np
from matplotlib import colors
import sparse

# note: visualization requires empyre
# install empyre: pip install empyre
from empyre.vis.colors import ColormapCubehelix

from libertem import masks
from libertem.udf.base import UDF
from libertem.udf.sum import SumUDF
from libertem.executor.inline import InlineJobExecutor
from libertem.common.container import MaskContainer
from libertem.common.backend import set_use_cuda, set_use_cpu, get_device_class
from libertem.corrections.coordinates import identity
from libertem.viz.mpl import MPLLive2DPlot
from libertem.corrections.coordinates import flip_y, rotate_deg, identity

In [3]:
from ptychography40.reconstruction.ssb import SSB_UDF, BinnedSSB_UDF, generate_masks, crop_bin_params
from ptychography40.reconstruction.common import wavelength, get_shifted

In [4]:
# Use matplotlib's `nbagg` notebook backend for all figures in this notebook
%matplotlib nbagg

## Create the LiberTEM context

A threaded executor runs LiberTEM UDFs in threads of the process where this notebook is running instead of on separate worker processes. This is currently more efficient for this SSB implementation since it allows better re-use of precalculated data. For most other applications, running on distributed workers is faster. See also https://github.com/LiberTEM/LiberTEM/issues/335

In [5]:
# Main context of this notebook; the 'threads' spec requests multithreading
# instead of process-based workers.
ctx = lt.Context.make_with('threads')

## Open the input data

This creates a LiberTEM dataset. The data is not loaded yet, but only when analyses are run on it.

In [6]:
# Directory that contains `default.hdr`. Set the SSB_DATA_DIR environment variable
# to point at a local copy of the data; the default is the original storage location.
data_dir = os.environ.get(
    'SSB_DATA_DIR',
    r'/storage/er-c-data/adhoc/libertem/libertem-test-data/20200518 165148',
)
file_params = {'path': os.path.join(data_dir, 'default.hdr')}
# Opens the dataset in MIB format
ds = ctx.load("MIB", **file_params)


## Reconstruction parameters

These have to be adapted for each dataset.

In [7]:

# Acceleration voltage in kV
U = 300

# Parameters shared by mask generation and the SSB UDF.
# NOTE(review): units are not stated in this notebook; dpix is presumably the scan
# step in m and semiconv the convergence semi-angle in rad - confirm.
rec_params = dict(
    dtype=np.float32,
    lamb=wavelength(U),
    dpix=12.7e-12,
    semiconv=22.1346e-3,  # 2020-05-18
    semiconv_pix=31,  # 2020-05-18
    # applied right to left
    transformation=rotate_deg(88) @ flip_y(),
    cx=123,
    cy=126,
    cutoff=16,  # number of pixels: trotters smaller than this will be removed
)
cutoff_freq = np.inf

# Parameters that are only needed to generate the trotter masks
mask_params = dict(
    # Shape of the reconstructed area
    reconstruct_shape=tuple(ds.shape.nav),
    # Shape of a detector frame
    mask_shape=tuple(ds.shape.sig),
    # Use the faster shifting method to generate trotters
    method='shift',
)

## Initial analysis of the dataset

We sum up all frames and confirm that size and position of the zero order beam match.

In [8]:
sum_udf = SumUDF()

# Live plot that is updated while the sum over all frames is calculated
live_plot = MPLLive2DPlot(dataset=ds, udf=sum_udf)
live_plot.display()

# Red outline with the configured centre and radius, to compare with the zero order beam
beam_center = (rec_params["cx"], rec_params["cy"])
circ_a = plt.Circle(beam_center, rec_params["semiconv_pix"], fill=False, color='red')
live_plot.axes.add_artist(circ_a)
live_plot.fig.colorbar(live_plot.im_obj)

sum_result = ctx.run_udf(dataset=ds, udf=sum_udf, plots=[live_plot], progress=True)

<IPython.core.display.Javascript object>

100%|██████████| 18/18 [00:01<00:00, 10.58it/s]


## Center of mass analysis

This is used to confirm that the coordinate system between scan and detector is properly adjusted. The beam should be deflected towards the nuclei for high resolution STEM data. That means that the field should have little curl and negative divergence at the position of the nuclei. Furthermore, x and y deflection should point towards the nuclei. See also https://libertem.github.io/LiberTEM/concepts.html#coordinate-system

In [9]:
# Centre, mask radius and scan/detector orientation for the centre-of-mass analysis.
# The orientation (flip_y, 88 degree rotation) uses the same values as
# rec_params["transformation"].
com_params = dict(
    cx=rec_params["cx"],
    cy=rec_params["cy"],
    mask_radius=rec_params["semiconv_pix"] + 30,
    flip_y=True,
    scan_rotation=88,
)
com_analysis = ctx.create_com_analysis(dataset=ds, **com_params)
com_result = ctx.run(com_analysis, progress=True)
print(com_result)

100%|██████████| 18/18 [00:02<00:00,  8.54it/s]

[<AnalysisResult: field>, <AnalysisResult: magnitude>, <AnalysisResult: divergence>, <AnalysisResult: curl>, <AnalysisResult: x>, <AnalysisResult: y>]





In [10]:
# Vector field: the y and x centre-of-mass components are converted to RGB
# with the cubehelix colormap.
fig, axes = plt.subplots()
axes.set_title("field")
y_centers, x_centers = com_result.field.raw_data
ch = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)
axes.imshow(ch.rgb_from_vector(np.broadcast_arrays(y_centers, x_centers, 0)))

# Scalar channels: one figure each, titled with the channel name and with a colorbar.
# A loop replaces five copies of the same plotting code and avoids the trailing
# Colorbar repr in the cell output.
for channel in ("magnitude", "divergence", "curl", "x", "y"):
    fig, axes = plt.subplots()
    axes.set_title(channel)
    p = axes.imshow(getattr(com_result, channel).raw_data)
    fig.colorbar(p)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x7f24a81c4f90>

## Pre-calculate the trotter stack

This takes some time and can be re-used for given reconstruction parameters. The stack is in a sparse matrix format.

In [11]:
%%time
# Build the full-resolution trotter stack from the reconstruction and mask parameters
trotters = generate_masks(**rec_params, **mask_params)

CPU times: user 6.52 s, sys: 29.5 ms, total: 6.55 s
Wall time: 6.53 s


In [12]:
fig, axes = plt.subplots()
# Densify the sparse mask at index 1 of the trotter stack for a visual check
axes.imshow(trotters[1].todense())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f241b85c310>

Uncomment to use GPU processing on device 0 with the inline executor

In [13]:
# set_use_cuda(0)

We create a LiberTEM `MaskContainer` from the mask stack. The `MaskContainer` is used to calculate and cache subsets of the mask stack with optimized properties for a fast dot product and tiled processing.

In [14]:
# The factory returns the precalculated stack as-is; dtype and mask count are read from it
mask_container = MaskContainer(
    mask_factories=lambda: trotters, dtype=trotters.dtype, count=trotters.shape[0]
)



## Instantiate and run the SSB UDF
The mask_container is passed to the UDF to allow re-use. This is a work-around for https://github.com/LiberTEM/LiberTEM/issues/335

In [15]:
# SSB UDF configured with the reconstruction parameters and the prebuilt mask container
udf = SSB_UDF(**rec_params, mask_container=mask_container)

We create dedicated plots so that we can add a colorbar and only plot amplitude and phase. These plots will be updated with calculation results when running the UDF in the cell after.

In [16]:
# One live plot per result channel of interest; each is displayed right after creation
ssb_plots = []
for channel in ('amplitude', 'phase'):
    p = MPLLive2DPlot(dataset=ds, udf=udf, channel=channel)
    p.display()
    ssb_plots += [p]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [17]:
%%time
# `ctx` uses the threading executor because the sparse matrix stack is quite large:
# multithreading instead of process-based parallelism allows re-using the masks
# and the cache of MaskContainer between partitions.
udf_result = ctx.run_udf(udf=udf, dataset=ds, plots=ssb_plots, progress=True)

100%|██████████| 18/18 [00:02<00:00,  6.95it/s]


CPU times: user 38.5 s, sys: 979 ms, total: 39.5 s
Wall time: 2.83 s


In [18]:
# Record the exact parameters used for the reconstruction above
print(rec_params, mask_params)

{'dtype': <class 'numpy.float32'>, 'lamb': 1.9687489006848795e-12, 'dpix': 1.27e-11, 'semiconv': 0.0221346, 'semiconv_pix': 31, 'transformation': array([[-0.0348995 ,  0.99939083],
       [ 0.99939083,  0.0348995 ]]), 'cx': 123, 'cy': 126, 'cutoff': 16} {'reconstruct_shape': (128, 128), 'mask_shape': (256, 256), 'method': 'shift'}


In [19]:
# Binning factor passed to crop_bin_params
binning_factor = 5
# Shallow copies, so the parameters of the full-resolution run stay untouched
test_rec_params = dict(rec_params)
test_mask_params = dict(mask_params)
# test_rec_params['cy'] = 40
# test_rec_params['cx'] = 200

In [20]:
# Derive the parameters for the cropped and binned reconstruction, together with
# the two matrices that bin a frame along y and x.
(
    binned_rec_params,
    binned_mask_params,
    y_binner,
    x_binner,
) = crop_bin_params(test_rec_params, test_mask_params, binning_factor)

In [21]:
# All-zero stand-in with the dimensions of the full dataset (128*128 frames of 256x256),
# used only for timing the binning variants in the next cells.
# Size: 128*128*256*256 float32 values = 4 GiB.
data = np.zeros((128*128, 256, 256), dtype=np.float32)

In [22]:
import threadpoolctl

In [23]:
# Convert both binning matrices to the GCXS format of the `sparse` package.
# NOTE(review): the converted matrices are not referenced by any later cell shown here.
y_binner_conv = sparse.GCXS(y_binner)
x_binner_conv = sparse.GCXS(x_binner)

In [24]:
# Time matrix-based binning with thread pools limited to 1.
# Without parentheses the product is evaluated left to right: (y_binner @ data) @ x_binner
with threadpoolctl.threadpool_limits(1):
    %time y_binner @ data @ x_binner

CPU times: user 1.08 s, sys: 17.6 ms, total: 1.1 s
Wall time: 1.09 s


In [25]:
# Same timing with the x binning applied first: y_binner @ (data @ x_binner)
with threadpoolctl.threadpool_limits(1):
    %time y_binner @ (data @ x_binner)

CPU times: user 1.25 s, sys: 13.8 ms, total: 1.27 s
Wall time: 1.26 s


In [26]:
%%time
data.reshape((128*128, 64, 4, 64, 4)).sum(axis=(-1, -3))

CPU times: user 4.94 s, sys: 15.9 ms, total: 4.95 s
Wall time: 4.94 s


array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [27]:
%%time
# Trotter stack for the cropped and binned parameters returned by crop_bin_params
binned_masks = generate_masks(**binned_rec_params, **binned_mask_params)

CPU times: user 1.59 s, sys: 23.6 ms, total: 1.62 s
Wall time: 1.59 s


In [28]:
import scipy  # keeps the name `scipy` bound, as the original `import scipy.misc` did

# `scipy.misc.face` was deprecated in SciPy 1.10 and removed in 1.12. The image now
# lives in `scipy.datasets`, which needs the optional `pooch` package to fetch it.
try:
    from scipy.datasets import face
    face_rgb = face()
except ImportError:
    # Older SciPy without `scipy.datasets`, or `pooch` is missing: use the legacy location
    from scipy.misc import face
    face_rgb = face()

# 256x256 crop of the test image with axes 0 and 2 swapped
testdata = np.swapaxes(face_rgb[190:190+256, 500:500+256], 0, 2)

In [29]:
# Number of binned pixels per axis, read from the row count of the y binning matrix
size = y_binner.shape[0]
# Edge length of the region that is binned, in unbinned pixels
orig_size = int(size * binning_factor)
# (y, x) start of that region, placed symmetrically around (cy, cx)
half_extent = orig_size / 2
origin = tuple(int(rec_params[key] - half_extent) for key in ('cy', 'cx'))

In [30]:
# Reference binning without matrices: crop the region, fold both image axes into
# (size, binning_factor) pairs and sum over the two binning_factor axes.
y0, x0 = origin
crop = testdata[..., y0:y0 + orig_size, x0:x0 + orig_size]
folded_shape = (-1, size, binning_factor, size, binning_factor)
folded = crop.reshape(folded_shape)
binned = folded.sum(axis=(-1, -3))

In [31]:
# Bin the test image with the binning matrices, for comparison with the reshape-based `binned`
matrix_res = y_binner @ testdata @ x_binner

In [32]:
# Side by side: test image, reshape-based binning, matrix-based binning.
# The binned images are divided by their maximum before display.
fig, axes = plt.subplots(1, 3)
ax_input, ax_reshape, ax_matrix = axes
ax_input.imshow(np.swapaxes(testdata, 0, 2))
ax_reshape.imshow(np.swapaxes(binned, 0, 2) / np.max(binned))
ax_matrix.imshow(np.swapaxes(matrix_res, 0, 2) / np.max(matrix_res))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f2484091450>

In [33]:
# Context created with the 'inline' spec; used for the line-profiling run further down
inline_ctx = lt.Context.make_with('inline')

In [34]:
# Context created with default arguments.
# NOTE(review): presumably this starts LiberTEM's default process-based executor - confirm.
dist_ctx = lt.Context()

In [35]:
# Shapes of the y and x binning matrices
y_binner.shape, x_binner.shape

((14, 256), (256, 14))

In [36]:
%autoreload
# Flatten every binned trotter into one row: (number of masks, product of the remaining dims)
binned_mask_flat = binned_masks.reshape((
    binned_masks.shape[0], np.prod(binned_masks.shape[1:], dtype=np.int64)
))
# CSR version of the flattened stack, as handed to BinnedSSB_UDF
csr_trotters = binned_mask_flat.tocsr()
# UDF that receives the two binning matrices and the flattened binned trotters
binned_udf = BinnedSSB_UDF(
    y_binner=y_binner,
    x_binner=x_binner,
    csr_trotters=csr_trotters,
)



In [37]:
%%time
# Timed run of the binned UDF on `dist_ctx` (the context created with default arguments)
bin_res = dist_ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)

100%|██████████| 18/18 [00:01<00:00, 11.75it/s]

CPU times: user 293 ms, sys: 75 ms, total: 368 ms
Wall time: 1.54 s





In [38]:
%%time
# Timed run of the same UDF on the threaded `ctx`; this overwrites `bin_res` from the previous run
bin_res = ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)

100%|██████████| 18/18 [00:01<00:00, 15.22it/s]

CPU times: user 7.91 s, sys: 5.58 s, total: 13.5 s
Wall time: 1.2 s





In [40]:
# Line-profile process_partition and merge_dot_result of the binned UDF while it runs
# on the inline context, with thread pools limited to 1.
with threadpoolctl.threadpool_limits(1):
    %lprun -f binned_udf.process_partition -f binned_udf.merge_dot_result inline_ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)

100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


In [41]:
# Show the binned reconstruction. The panels are titled because phase comes first
# here, unlike the amplitude/phase order of the full-resolution live plots.
fig, axes = plt.subplots(1, 2)
for ax, channel in zip(axes, ('phase', 'amplitude')):
    ax.imshow(bin_res[channel])
    ax.set_title(channel)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f241ba49490>