In [None]:
%load_ext line_profiler

# Single Side Band ptychography example

This example uses an implementation of the single side band method
https://doi.org/10.1016/j.ultramic.2014.09.013 to reconstruct an amplitude and phase
image from a simulated 4D scanning transmission electron microscopy dataset with an additional synthetic
potential modulation.

The dataset can be downloaded at https://zenodo.org/record/5113235.

In [None]:
import os
import functools

import matplotlib.pyplot as plt
import libertem.api as lt
import numpy as np
from matplotlib import colors
import sparse

# note: visualization requires empyre
# install empyre: pip install empyre
from empyre.vis.colors import ColormapCubehelix

from libertem import masks
from libertem.udf.base import UDF
from libertem.udf.sum import SumUDF
from libertem.executor.inline import InlineJobExecutor
from libertem.common.container import MaskContainer
from libertem.common.backend import set_use_cuda, set_use_cpu, get_device_class
from libertem.corrections.coordinates import identity
from libertem.viz.bqp import BQLive2DPlot
from libertem.viz.mpl import MPLLive2DPlot
from libertem.corrections.coordinates import flip_y, rotate_deg, identity

In [None]:
from ptychography40.reconstruction.ssb.udf import SSB_UDF, SSB_Base, generate_masks, rmatmul_csc_fourier
from ptychography40.reconstruction.common import wavelength, get_shifted

In [None]:
%matplotlib nbagg

## Create the LiberTEM context

An inline executor executes LiberTEM UDFs in the process where this script is running instead of running them on worker processes. This is currently more efficient for this SSB implementation since it allows precalculated data to be re-used more effectively. For most other applications, running on distributed workers is faster. See also https://github.com/LiberTEM/LiberTEM/issues/335

In [None]:
# NOTE(review): the surrounding text describes an inline executor, but this
# context is created with default arguments (no executor is passed). The
# explicit inline context (`inline_ctx`) is only created further below.
# Confirm which executor is intended for the SSB runs that use `ctx`.
ctx = lt.Context()

## Open the input data

This creates a LiberTEM dataset. The data is not loaded yet, but only when analyses are run on it.

In [None]:
# Path to the MIB header file of the dataset. The default is the location used
# when this notebook was written; set the SSB_MIB_PATH environment variable to
# point to a local copy instead of editing the notebook.
file_params = {
    'path': os.environ.get(
        'SSB_MIB_PATH',
        r'E:\LargeData\LargeData\ER-C-1\groups\data_science\data\reference\MIB\20200518 165148\default.hdr',
    ),
}
ds = ctx.load("MIB", **file_params)


## Reconstruction parameters

These have to be adapted for each dataset.

In [None]:

# Acceleration voltage in keV
U = 300

# Parameters describing the experiment geometry. They are unpacked as keyword
# arguments into generate_masks() and SSB_UDF() further below.
rec_params = dict(
    dtype=np.float32,
    lamb=wavelength(U),
    dpix=12.7e-12,
    semiconv=22.1346e-3,  # 2020-05-18
    semiconv_pix=31,  # 2020-05-18
    # The matrix product is applied right to left: flip_y first, then rotation
    transformation=rotate_deg(88) @ flip_y(),
    cx=123,
    cy=126,
    # number of pixels: trotters smaller than this will be removed
    cutoff=16,
)
cutoff_freq = np.inf

# Parameters that only concern the generation of the trotter (mask) stack.
mask_params = dict(
    # Shape of the reconstructed area
    reconstruct_shape=tuple(ds.shape.nav),
    # Shape of a detector frame
    mask_shape=tuple(ds.shape.sig),
    # Use the faster shifting method to generate trotters
    method='shift',
)

## Initial analysis of the dataset

We sum up all frames and confirm that size and position of the zero order beam match.

In [None]:
# Sum of all detector frames, shown in a live plot that is passed to run_udf().
sum_udf = SumUDF()
live_plot = MPLLive2DPlot(dataset=ds, udf=sum_udf)
live_plot.display()

# Overlay a circle with the configured centre and radius of the zero order
# beam so that the parameters can be checked visually against the data.
beam_center = (rec_params["cx"], rec_params["cy"])
circ_a = plt.Circle(beam_center, rec_params["semiconv_pix"], color='red', fill=False)
live_plot.axes.add_artist(circ_a)
live_plot.fig.colorbar(live_plot.im_obj)

sum_result = ctx.run_udf(
    udf=sum_udf,
    dataset=ds,
    plots=[live_plot],
    progress=True,
)

## Center of mass analysis

This is used to confirm that the coordinate system between scan and detector is properly adjusted. The beam should be deflected towards the nuclei for high resolution STEM data. That means that the field should have little curl and negative divergence at the position of the nuclei. Furthermore, x and y deflection should point towards the nuclei. See also https://libertem.github.io/LiberTEM/concepts.html#coordinate-system

In [None]:
# Keyword arguments of the centre of mass analysis. flip_y and scan_rotation
# repeat the flip and the 88 degree rotation of rec_params["transformation"];
# keep both places in sync when adjusting the coordinate system.
com_kwargs = dict(
    cx=rec_params["cx"],
    cy=rec_params["cy"],
    # integrate over a disk that is 30 px larger than the zero order beam
    mask_radius=rec_params["semiconv_pix"] + 30,
    flip_y=True,
    scan_rotation=88,
)
com_analysis = ctx.create_com_analysis(dataset=ds, **com_kwargs)
com_result = ctx.run(com_analysis, progress=True)
print(com_result)

In [None]:
# Vector field of the centre of mass shift, encoded as colour.
fig, axes = plt.subplots()
axes.set_title("field")
y_centers, x_centers = com_result.field.raw_data
ch = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)
field_vectors = np.broadcast_arrays(y_centers, x_centers, 0)
axes.imshow(ch.rgb_from_vector(field_vectors))

# One figure with a colour bar for each scalar result channel; the channel
# name doubles as the figure title.
for channel_name in ("magnitude", "divergence", "curl", "x", "y"):
    fig, axes = plt.subplots()
    axes.set_title(channel_name)
    p = axes.imshow(getattr(com_result, channel_name).raw_data)
    fig.colorbar(p)

## Pre-calculate the trotter stack

This takes some time and can be re-used for given reconstruction parameters. The stack is in a sparse matrix format.

In [None]:
%%time
trotters = generate_masks(**rec_params, **mask_params)

In [None]:
fig, axes = plt.subplots()
axes.imshow(trotters[1].todense())

Uncomment to use GPU processing on device 0 with the inline executor

In [None]:
# set_use_cuda(0)

We create a LiberTEM `MaskContainer` from the mask stack. The `MaskContainer` is used to calculate and cache subsets of the mask stack with optimized properties for a fast dot product and tiled processing.

In [None]:
mask_container = MaskContainer(
    mask_factories=lambda: trotters, dtype=trotters.dtype, count=trotters.shape[0]
)

## Instantiate and run the SSB UDF
The mask_container is passed to the UDF to allow re-use. This is a work-around for https://github.com/LiberTEM/LiberTEM/issues/335

In [None]:
udf = SSB_UDF(**rec_params, mask_container=mask_container)

We create dedicated plots so that we can add a colorbar and only plot amplitude and phase. These plots will be updated with calculation results when running the UDF in the cell after.

In [None]:
# One live plot per result channel of interest. Each plot is displayed right
# after it is created and collected so it can be passed to run_udf().
ssb_plots = []
for channel in ('amplitude', 'phase'):
    p = BQLive2DPlot(dataset=ds, udf=udf, channel=channel)
    p.display()
    ssb_plots += [p]

In [None]:
# We use the inline executor since the sparse matrix stack is quite large.
# Instead of process-based parallelism, we use multithreading in the UDF.
# This allows to re-use the masks and the cache of MaskContainer between partitions.
udf_result = ctx.run_udf(udf=udf, dataset=ds, plots=ssb_plots, progress=True)

In [None]:
print(rec_params, mask_params)

In [None]:
def crop_bin_params(rec_params, mask_params, binning_factor: int):
    """
    Derive parameters and binning matrices for a cropped and binned reconstruction.

    A square window centred on (rec_params['cy'], rec_params['cx']) is cropped
    from the detector frame and binned by `binning_factor` along both axes.
    The window spans ``2 * ceil(semiconv_pix / binning_factor)`` binned pixels
    per axis. Cropping and binning are expressed as two dense 0/1 matrices so
    that ``y_binner @ frame @ x_binner`` yields the binned window.

    Detector pixels that the window would need outside of the frame are
    treated as missing: the corresponding bins only sum the pixels that exist
    (possibly none).

    Parameters
    ----------
    rec_params : dict
        Reconstruction parameters. The keys 'semiconv_pix', 'cutoff', 'cy'
        and 'cx' are read. The dict is not modified.
    mask_params : dict
        Mask parameters. The key 'mask_shape' (detector frame shape as
        (height, width)) is read. The dict is not modified.
    binning_factor : int
        Number of detector pixels merged into one binned pixel per axis.

    Returns
    -------
    tuple
        ``(new_rec_params, new_mask_params, y_binner, x_binner)`` where the
        new parameter dicts are shallow copies adapted to the binned geometry,
        `y_binner` has shape (size, mask_shape[0]) and `x_binner` has shape
        (mask_shape[1], size).

    Raises
    ------
    ValueError
        If `binning_factor` is smaller than 1.
    """
    if binning_factor < 1:
        raise ValueError(f"binning_factor must be >= 1, got {binning_factor!r}")

    center = int(np.ceil(rec_params["semiconv_pix"] / binning_factor))
    size = 2 * center

    def crop_bin_vector(length, origin):
        # Column i selects the detector pixels that are summed into bin i.
        bins = np.zeros((length, size), dtype=np.float32)
        for i in range(size):
            start = origin + i*binning_factor
            stop = start + binning_factor
            # Clip to the detector: negative indices would otherwise wrap
            # around to the far end of the axis, or produce an empty slice
            # for a bin that straddles index 0.
            start = min(max(start, 0), length)
            stop = min(max(stop, 0), length)
            bins[start:stop, i] = 1
        return bins

    new_rec_params = rec_params.copy()
    # The beam is centred in the cropped and binned window
    new_rec_params['cy'] = center
    new_rec_params['cx'] = center
    new_rec_params['semiconv_pix'] = rec_params['semiconv_pix'] / binning_factor
    # The cutoff is a pixel count (area), so it scales with the square of the binning
    new_rec_params['cutoff'] = int(np.ceil(rec_params['cutoff'] / binning_factor**2))
    new_mask_params = mask_params.copy()
    new_mask_params['mask_shape'] = (size, size)
    new_mask_params['method'] = 'subpix'

    # Transposed so that it can be applied from the left: (size, height)
    y_binner = crop_bin_vector(
        length=mask_params['mask_shape'][0],
        origin=int(rec_params['cy']) - binning_factor * center,
    ).T

    # Applied from the right: (width, size)
    x_binner = crop_bin_vector(
        length=mask_params['mask_shape'][1],
        origin=int(rec_params['cx']) - binning_factor * center,
    )

    return (new_rec_params, new_mask_params, y_binner, x_binner)

In [None]:
test_rec_params = rec_params.copy()
test_mask_params = mask_params.copy()
binning_factor = 5
# test_rec_params['cy'] = 40
# test_rec_params['cx'] = 200

In [None]:
binned_rec_params, binned_mask_params, y_binner, x_binner = crop_bin_params(test_rec_params, test_mask_params, binning_factor)

In [None]:
data = np.zeros((128*128, 256, 256), dtype=np.float32)

In [None]:
import threadpoolctl

In [None]:
with threadpoolctl.threadpool_limits(1):
    %timeit y_binner @ data @ x_binner

In [None]:
%%timeit
data.reshape((128*128, 64, 4, 64, 4)).sum(axis=(-1, -3))

In [None]:
%%time
binned_masks = generate_masks(**binned_rec_params, **binned_mask_params)

In [None]:
# Binds the name `scipy` and makes sure scipy.sparse is loaded; the
# BinnedSSB_UDF definition further below refers to scipy.sparse.
import scipy.sparse

try:
    # scipy.misc.face() was deprecated in SciPy 1.10 and removed in 1.12;
    # scipy.datasets.face() is its replacement (requires the optional
    # `pooch` package to download the image on first use).
    from scipy.datasets import face
except ImportError:
    # Older SciPy releases without scipy.datasets
    from scipy.misc import face

# RGB test image cropped to 256x256 and rearranged to (channel, x, y) so that
# the colour channels take the place of the navigation axis.
testdata = np.swapaxes(face()[190:190+256, 500:500+256], 0, 2)

In [None]:
size = y_binner.shape[0]
orig_size = int(size * binning_factor)
origin = (int(rec_params['cy'] - orig_size / 2), int(rec_params['cx'] - orig_size / 2))

In [None]:
crop = testdata[..., origin[0]:origin[0]+orig_size, origin[1]:origin[1]+orig_size]
folded = crop.reshape((-1, size, binning_factor, size, binning_factor))
binned = folded.sum(axis=(-1, -3))

In [None]:
matrix_res = y_binner @ testdata @ x_binner

In [None]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(np.swapaxes(testdata, 0, 2))
axes[1].imshow(np.swapaxes(binned, 0, 2) / np.max(binned))
axes[2].imshow(np.swapaxes(matrix_res, 0, 2) / np.max(matrix_res))

In [None]:
inline_ctx = lt.Context.make_with('inline')

In [None]:
y_binner.shape, x_binner.shape

In [None]:
def get_binner(constructor, y_binner, x_binner):
    """
    Create a memoised lookup of binning matrices for a signal slice.

    The returned function accepts a hashable object with `origin` and `shape`
    attributes (each a pair ``(y, x)``) and returns a tuple
    ``(y_matrix, x_matrix, noop)``:

    * `y_matrix` is `constructor` applied to the columns of `y_binner` that
      belong to the slice,
    * `x_matrix` is `constructor` applied to the rows of `x_binner` that
      belong to the slice,
    * `noop` is true if either selection is all zero, i.e. binning a tile
      with this slice can only give zeros.

    Results are cached per slice with `functools.lru_cache` default settings.
    """
    def _binner_for(sig_slice):
        (y0, x0), (height, width) = sig_slice.origin, sig_slice.shape
        y_part = y_binner[:, y0:y0 + height]
        x_part = x_binner[x0:x0 + width]
        y_out = constructor(y_part)
        x_out = constructor(x_part)
        contributes_nothing = np.allclose(y_part, 0) or np.allclose(x_part, 0)
        return (y_out, x_out, contributes_nothing)

    return functools.lru_cache()(_binner_for)

class BinnedSSB_UDF(SSB_Base):
    """
    SSB UDF that crops and bins each tile before applying the trotter stack.

    Cropping and binning are performed with two matrix products
    (``y_binner @ tile @ x_binner``). The binned tile is flattened and
    multiplied with a sparse CSR trotter matrix, and the product is handed to
    `merge_dot_result` of the base class.

    Parameters
    ----------
    y_binner, x_binner
        Binning matrices applied from the left and from the right of each
        frame, for example as returned by `crop_bin_params`.
    csr_trotters : scipy.sparse.csr_matrix
        Trotter stack with one row per trotter and the flattened binned
        signal dimension as columns.
    dtype
        Passed through to the base class.
    """

    def __init__(self, y_binner, x_binner, csr_trotters: scipy.sparse.csr_matrix, dtype=np.float32):
        # make sure the cropped and binned region has a size divisible by two
        binned_size = np.sqrt(csr_trotters.shape[1])
        assert np.allclose(binned_size % 2, 0)
        super().__init__(y_binner=y_binner, x_binner=x_binner, csr_trotters=csr_trotters, dtype=dtype)

    def get_task_data(self):
        """
        Extend the base class task data with the binner lookup and trotters.

        Adds 'binner' (a cached lookup of binning matrices per signal slice,
        built with the array constructor of the active backend) and
        'trotters' (the trotter matrix in a format matching the device).

        Raises
        ------
        RuntimeError
            If the device class is neither 'cpu' nor 'cuda'.
        """
        result = super().get_task_data()
        result['binner'] = get_binner(self.xp.array, self.params.y_binner, self.params.x_binner)
        device_class = self.meta.device_class
        if device_class == 'cpu':
            result['trotters'] = self.params.csr_trotters
        elif device_class == 'cuda':
            # cupy.sparse was removed from CuPy; cupyx.scipy.sparse is its
            # replacement and ships with the same package.
            import cupyx.scipy.sparse
            result['trotters'] = cupyx.scipy.sparse.csr_matrix(self.params.csr_trotters)
        else:
            # Fail here instead of with a missing 'trotters' entry later on.
            raise RuntimeError(f"Unsupported device class {device_class!r}")
        return result

    def process_tile(self, tile):
        """
        Crop, bin and flatten `tile`, then apply the trotter matrix.

        Tiles whose signal slice does not overlap with the binned region
        (all-zero binning matrices) are skipped.
        """
        y_binner, x_binner, noop = self.task_data.binner(self.meta.sig_slice)
        if not noop:
            # Broadcast matrix products over the leading (frame) axis
            binned = y_binner @ tile @ x_binner
            binned_flat = binned.reshape((binned.shape[0], binned.shape[1]*binned.shape[2]))
            masks = self.task_data.trotters
            # (trotters x pixels) . (pixels x frames), transposed to frames x trotters
            dot_result = masks.dot(binned_flat.T).T
            self.merge_dot_result(dot_result)

    def get_backends(self):
        """Return the supported backends: CuPy and NumPy."""
        return ('cupy', 'numpy')

    def get_tiling_preferences(self):
        """
        Return the preferred tile depth and total tile size for the device.

        The depth is derived from the size of one complex result array with
        the reconstruction shape and is always an integer of at least 1.
        """
        dtype = np.result_type(np.complex64, self.params.dtype)
        result_size = np.prod(self.reconstruct_shape) * dtype.itemsize
        if self.meta.device_class == 'cuda':
            free, total = self.xp.cuda.runtime.memGetInfo()
            # Use at most a quarter of the free device memory, capped at 100 MB
            total_size = min(100e6, free // 4)
            # Convert to int for consistency with the CPU branch
            good_depth = int(max(1, total_size / result_size * 4))
            return {
                "depth": good_depth,
                "total_size": total_size,
            }
        else:
            # We limit the depth of a tile so that the intermediate
            # results from processing a tile fit into the CPU cache.
            good_depth = max(1, 1e6 / result_size)
            return {
                "depth": int(good_depth),
                "total_size": 2e6,
            }


In [None]:
# Flatten the signal dimensions of the binned trotter stack so that each
# trotter becomes one row of a 2D sparse matrix, then convert to CSR.
n_trotters = binned_masks.shape[0]
flat_sig_size = np.prod(binned_masks.shape[1:], dtype=np.int64)
binned_mask_flat = binned_masks.reshape((n_trotters, flat_sig_size))
csr_trotters = binned_mask_flat.tocsr()

binned_udf = BinnedSSB_UDF(y_binner=y_binner, x_binner=x_binner, csr_trotters=csr_trotters)



In [None]:
# Keep the result: names assigned inside a %%timeit cell are not available in
# the notebook namespace afterwards, and `bin_res` is needed for the final
# plots at the end of the notebook.
bin_res = inline_ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)
bin_res

In [None]:
%%timeit
bin_res = ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)

In [None]:
with threadpoolctl.threadpool_limits(1):
    %lprun -f binned_udf.process_tile -f binned_udf.merge_dot_result inline_ctx.run_udf(dataset=ds, udf=binned_udf, progress=True)

In [None]:
fig, axes = plt.subplots(1, 2)
axes[0].imshow(bin_res['phase'])
axes[1].imshow(bin_res['amplitude'])