In [1]:
%matplotlib widget
import numpy as np
import cupy as cp

import matplotlib.pyplot as plt 
from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.axes_grid1 import make_axes_locatable

import realtime_ptycho as sm
from realtime_ptycho.core import Sparse4DData, Metadata4D
from realtime_ptycho.util import get_qx_qy_1D, get_qx_qy_2D, disk_overlap_function, plotcxmosaic, sector_mask, plot, single_sideband_reconstruction, imsave, mosaic, sparse_to_dense_datacube, wavelength

from numpy.fft import fftshift

from py4DSTEM.process.dpc import get_phase_from_CoM
from pathlib import Path

from tifffile import  imwrite
import time

from ipywidgets import AppLayout, FloatSlider, GridspecLayout, VBox, IntSlider, FloatLogSlider, HBox
import ipywidgets as widgets

import time

import cupy as cp
from cupyx.scipy.fft import fft2

from tqdm import trange
from skimage.filters import gaussian

In [2]:
scan_number = 147

# All inputs/outputs are resolved relative to the notebook's working directory.
# (Path.cwd() avoids relying on `os`, which is not imported in this notebook.)
base_path = Path.cwd()
adfpath = base_path
sparse_path = base_path
results_path = base_path / 'results'
# Idempotent: safe to re-run the cell.
results_path.mkdir(parents=True, exist_ok=True)

filename4d = sparse_path / f'data_scan{scan_number}_th4.0_electrons.h5'
filenameadf = adfpath / f'scan{scan_number}.dm4'

# Crop radius as a multiple of the detected bright-field radius
# (an earlier value of 1.2 was overridden by this one).
alpha_max_factor = 1.05

In [3]:
print('1: data loading')
d = Sparse4DData.from_4Dcamera_file(filename4d)
metadata = Metadata4D.from_dm4_file(filenameadf)

# Probe/scan parameters hard-coded here rather than read from the input files.
metadata.alpha_rad = 25e-3
metadata.rotation_deg = 0
metadata.wavelength = wavelength(metadata.E_ev)

# Locate the bright-field disk automatically, then crop every frame around it.
center, radius = d.determine_center_and_radius(manual=False, size=200)
print(f'center: {center}')
print(f'radius: {radius}')
print('2: cropping')
d.crop_symmetric_center_(center, radius * alpha_max_factor)
print('3: sum diffraction pattern')
s = d.sum_diffraction()
print('4: plotting')

# Linear and log-scaled views of the summed, cropped diffraction pattern.
f, ax = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (s, f'Scan {scan_number} sum after cropping'),
    (np.log10(s + 1), f'Scan {scan_number} log10(sum) after cropping'),
)
for panel_ax, (panel_img, panel_title) in zip(ax, panels):
    imax = panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
# Colorbar belongs to the last (log-scaled) panel.
plt.colorbar(imax)
plt.tight_layout()


1: data loading
center: [292.20951103 160.73790153]
radius: 123.21053157858441
2: cropping
old frames frame_dimensions: [576 576]
new frames frame_dimensions: [258 258]
3: sum diffraction pattern
4: plotting


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [4]:
# Dwell time per scan position in seconds (87 kHz); fluence / flux in the
# report below equals scan positions * dwell_time.
dwell_time = 1/87e3
# Detector-count to real-electron conversion (1/0.56 was considered; 1 = no correction).
detector_to_real_fluence_80kv = 1

fluence, flux = (
    raw * detector_to_real_fluence_80kv
    for raw in (d.fluence(metadata.dr[0]), d.flux(metadata.dr[0], dwell_time))
)

report_lines = [
    f"dR              = {metadata.dr} Å",
    f"scan       size = {d.scan_dimensions}",
    f"detector   size = {d.frame_dimensions}",
    f"scan       FOV  = {d.scan_dimensions*metadata.dr/10} nm",
    f"fluence         ~ {fluence} e/Å^2",
    f"flux            ~ {flux} e/Å^2/s",
]
print("\n".join(report_lines))

dR              = [0.31626087 0.31626087] Å
scan       size = [512 512]
detector   size = [258 258]
scan       FOV  = [16.19255676 16.19255676] nm
fluence         ~ 21622.471675663513 e/Å^2
flux            ~ 7176.036971217062 e/Å^2/s


In [5]:
def get_phase_from_CoM(CoMx, CoMy, theta, flip, regLowPass=0.5, regHighPass=100, paddingfactor=2,
                                                stepsize=1, n_iter=10, phase_init=None):
    """
    Calculate the phase of the sample transmittance from the diffraction centers of mass.

    NOTE: this CuPy (GPU) implementation redefines the py4DSTEM function of the same name
    imported at the top of the notebook, and additionally returns the inverse-operator
    denominator so it can be visualised.

    A bare bones description of the approach taken here is below - for detailed discussion of the
    relevant theory, see, e.g.:
        Ishizuka et al, Microscopy (2017) 397-405
        Close et al, Ultramicroscopy 159 (2015) 124-137
        Waddell and Chapman, Optik 54 (1979) No. 2, 83-96

    The idea here is that the deflection of the center of mass of the electron beam in the
    diffraction plane scales linearly with the gradient of the phase of the sample transmittance.
    When this correspondence holds, it is therefore possible to invert the differential equation and
    extract the phase itself.* The primary assumption made is that the sample is well
    described as a pure phase object (i.e. the real part of the transmittance is 1). The inversion
    is performed in this algorithm in Fourier space, i.e. using the Fourier transform property
    that derivatives in real space are turned into multiplication in Fourier space.

    *Note: because in DPC a differential equation is being inverted - i.e. the fundamental theorem
    of calculus is invoked - one might be tempted to call this "integrated differential phase
    contrast". Strictly speaking, this term is redundant - performing an integration is simply how
    DPC works.

    Accepts:
        CoMx            (2D array) the diffraction space centers of mass x coordinates
                        (NumPy or CuPy; moved to the GPU)
        CoMy            (2D array) the diffraction space centers of mass y coordinates
        theta           (float) the rotational offset between real and diffraction space
                        coordinates, in radians
        flip            (bool) whether or not the real and diffraction space coords contain a
                        relative flip
        regLowPass      (float) low pass regularization term for the Fourier integration operators
        regHighPass     (float) high pass regularization term for the Fourier integration operators
        paddingfactor   (int) padding to add to the CoM arrays for boundary condition handling.
                        1 corresponds to no padding, 2 to doubling the array size, etc.
        stepsize        (float) the stepsize in the iteration step which updates the phase. It also
                        scales the Fourier operators; only the update step is halved when the
                        error increases.
        n_iter          (int) the number of iterations
        phase_init      (2D array) initial guess for the phase, same shape as CoMx

    Returns:
        phase           (2D cupy array) the phase of the sample transmittance, shape of CoMx
        error           (1D cupy array) the error - RMSD of the phase gradients compared to the
                        CoM - at each iteration step
        denominator     (2D cupy array) the regularized inverse operator on the padded rfft grid,
                        shape (R_Nx*paddingfactor, R_Ny*paddingfactor//2 + 1); DC term set to 0
    """
    CoMx = cp.asarray(CoMx)
    CoMy = cp.asarray(CoMy)

    # Coordinates
    R_Nx, R_Ny = CoMx.shape
    R_Nx_padded, R_Ny_padded = R_Nx*paddingfactor, R_Ny*paddingfactor
    shape_padded = (R_Nx_padded, R_Ny_padded)

    qx = cp.fft.fftfreq(R_Nx_padded)
    qy = cp.fft.rfftfreq(R_Ny_padded)
    qr2 = qx[:, None]**2 + qy[None, :]**2

    # Inverse operators. np.seterr does not govern CuPy kernels (a zero denominator just
    # yields inf), so the DC term is simply overwritten after the division.
    denominator = 1./(qr2 + regHighPass + qr2**2*regLowPass)
    denominator[0, 0] = 0
    f = 1j * 0.25*stepsize
    qxOperator = f*qx[:, None]*denominator
    qyOperator = f*qy[None, :]*denominator

    # Perform rotation and flipping
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    if flip:
        CoMx_rot = CoMx*cos_t + CoMy*sin_t
        CoMy_rot = CoMx*sin_t - CoMy*cos_t
    else:
        CoMx_rot = CoMx*cos_t - CoMy*sin_t
        CoMy_rot = CoMx*sin_t + CoMy*cos_t
    # Row-major flattening matches the order of dx[mask] for the top-left unpadded region.
    CoMx_flat = CoMx_rot.ravel()
    CoMy_flat = CoMy_rot.ravel()

    # Initializations
    phase = cp.zeros(shape_padded)
    dx = cp.zeros(shape_padded)
    dy = cp.zeros(shape_padded)
    error = cp.zeros((n_iter,))
    mask = cp.zeros(shape_padded, dtype=bool)
    mask[:R_Nx, :R_Ny] = True
    maskInv = ~mask
    if phase_init is not None:
        phase[:R_Nx, :R_Ny] = cp.asarray(phase_init)

    # Iterative reconstruction
    for i in range(n_iter):

        # Update gradient estimates using measured CoM values; padding region is zeroed
        dx[mask] -= CoMx_flat
        dy[mask] -= CoMy_flat
        dx[maskInv] = 0
        dy[maskInv] = 0

        # Calculate reconstruction update. s= pins the output shape: without it irfft2
        # always returns an even last axis, which breaks for odd padded widths.
        update = cp.fft.irfft2(cp.fft.rfft2(dx)*qxOperator + cp.fft.rfft2(dy)*qyOperator,
                               s=shape_padded)

        # Apply update
        phase += stepsize*update

        # Measure current phase gradients (central differences, periodic boundaries)
        dx = (cp.roll(phase, (-1, 0), axis=(0, 1)) - cp.roll(phase, (1, 0), axis=(0, 1))) / 2.
        dy = (cp.roll(phase, (0, -1), axis=(0, 1)) - cp.roll(phase, (0, 1), axis=(0, 1))) / 2.

        # Estimate error from cost function, RMS deviation of gradients
        xDiff = dx[mask] - CoMx_flat
        yDiff = dy[mask] - CoMy_flat
        error[i] = cp.sqrt(cp.mean((xDiff - cp.mean(xDiff))**2 + (yDiff - cp.mean(yDiff))**2))

        # Halve step size if error is increasing
        if i > 0 and error[i] > error[i-1]:
            stepsize /= 2

    return phase[:R_Nx, :R_Ny], error, denominator

In [6]:
def run_dpc(regLowPass, regHighPass, stepsize, n_iter):
    """Run the DPC reconstruction for the current widget parameters.

    Each argument is a 1-element array so the slider callbacks can update it in
    place; only element ``[0]`` is used. Reads the notebook globals ``comy``,
    ``comx`` and ``metadata``.

    Returns:
        dpc    (2D cupy array) reconstructed phase, same shape as ``comy``
        error  (1D cupy array) RMS gradient error per iteration
        ps     (2D numpy array) log10(|FFT(comy)| * regularized inverse operator + 1),
               fftshifted, same shape as ``comy``
    """
    dpc, error, denominator = get_phase_from_CoM(
        comy, comx, np.deg2rad(metadata.rotation_deg), False,
        regLowPass=regLowPass[0], regHighPass=regHighPass[0], paddingfactor=2,
        stepsize=stepsize[0], n_iter=n_iter[0], phase_init=None)

    fy = fftshift(cp.abs(fft2(comy))).get()

    # `denominator` lives on the padded rfft grid: all qx rows, but only the
    # qy >= 0 columns (DC .. Nyquist). It depends only on qx**2 and qy**2, so the
    # negative-qy columns are the mirror of columns 1 .. Nyquist-1 (DC and Nyquist
    # must not be duplicated).
    dd_half = denominator.get()
    dd_full = np.hstack([dd_half, np.fliplr(dd_half[:, 1:-1])])
    # With paddingfactor=2, every second padded frequency (in unshifted order)
    # coincides with the unpadded FFT grid of `comy`.
    dd = fftshift(dd_full[::2, ::2])

    ps = np.log10((fy * dd) + 1)

    return dpc, error, ps

In [8]:
plt.ioff()
comy, comx = d.center_of_mass()

# Reconstruction parameters are kept in 1-element arrays so that the slider
# callbacks below can update them in place and run_dpc reads the current values.
regLowPass = np.array([1e3])
regHighPass = np.array([5e-1])
stepsize = np.array([0.9])
n_iter = np.array([20]).astype(int)  # np.int was an alias of int and is gone in NumPy >= 1.24

dpc, error, ps = run_dpc(regLowPass, regHighPass, stepsize, n_iter)

# Figure 1: reconstructed phase (plt.cm.get_cmap was removed in Matplotlib 3.9).
fig1, ax1 = plt.subplots(1, 1, figsize=(7.5, 7.5))
im1 = ax1.imshow(dpc.get(), cmap='bone')
ax1.set_title('DPC reconstruction')
ax1.set_xticks([])
ax1.set_yticks([])
ax1.add_artist(ScaleBar(metadata.dr[0]/10, 'nm'))
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im1, cax=cax)

# Figure 2: CoM power spectrum weighted by the regularized inverse operator.
fig2, ax2 = plt.subplots(1, 1, figsize=(7.5, 7.5))
im2 = ax2.imshow(ps, cmap='bone', alpha=1)
ax2.set_title('log10(Power spectrum * regularization)')
ax2.set_xticks([])
ax2.set_yticks([])

# Figure 3: convergence of the iterative reconstruction.
fig3, ax3 = plt.subplots(1, 1, figsize=(7.5, 7.5))
im3 = ax3.plot(error.get())
ax3.set_xlabel('iteration')
ax3.set_ylabel('Error')


gs = GridspecLayout(1, 9)

plot_box1 = HBox(children=[fig1.canvas, fig2.canvas, fig3.canvas])
Cslider_box = VBox()

i = 0  # number of slider-triggered reconstructions, shown in `text`


def set_new_data_and_update(dpc, error, ps):
    """Push a new reconstruction (host arrays) into the three figures and redraw them."""
    im1.set_data(dpc)
    im2.set_data(ps)

    im3[0].set_ydata(error)
    im3[0].set_xdata(np.arange(len(error)))
    ax3.set_xlim(0, len(error))
    ax3.set_ylim(error.min(), error.max())

    im1.set_clim(dpc.min(), dpc.max())
    im2.set_clim(ps.min(), ps.max())

    fig1.canvas.draw()
    fig1.canvas.flush_events()

    fig2.canvas.draw()
    fig2.canvas.flush_events()

    # Also refresh the error curve so it does not go stale.
    fig3.canvas.draw_idle()

    text.value = 'all values set'
    plt.draw()


def _make_param_callback(param):
    """Return an ipywidgets observer that writes the new slider value into
    ``param`` (a 1-element array), re-runs DPC and refreshes the figures."""
    def _on_change(change):
        # Without this declaration `i += 1` raises UnboundLocalError.
        global i
        param[:] = change['new']
        new_dpc, new_error, new_ps = run_dpc(regLowPass, regHighPass, stepsize, n_iter)
        # Update the global array in place so later cells see the latest reconstruction.
        dpc[:] = new_dpc
        set_new_data_and_update(dpc.get(), new_error.get(), new_ps)
        i += 1
        text.value = f'{i}'
    return _on_change


regLowPass_changed = _make_param_callback(regLowPass)
regHighPass_changed = _make_param_callback(regHighPass)
stepsize_changed = _make_param_callback(stepsize)
n_iter_changed = _make_param_callback(n_iter)

s1 = FloatLogSlider(description='regLowPass', value=1e3, base=10, step=0.2, min=0, max=6)
s2 = FloatLogSlider(description='regHighPass', value=5e-1, base=10, step=0.2, min=-2, max=4)
s3 = FloatSlider(description='stepsize', value=0.9, min=0.1, max=2)
s4 = IntSlider(description='n_iter', value=20, min=1, max=500)
sliders = [s1, s2, s3, s4]
for slider, callback in zip(sliders, [regLowPass_changed, regHighPass_changed,
                                      stepsize_changed, n_iter_changed]):
    slider.observe(callback, names='value')

text = widgets.HTML(
    value="1",
    placeholder='',
    description='',
)
Cslider_box.children = sliders + [text]

gs[0, 1:] = plot_box1
gs[0, 0] = Cslider_box

AppLayout(center=gs)

20


AppLayout(children=(GridspecLayout(children=(HBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset …

In [9]:
# Pixels trimmed from each edge before saving (presumably to hide edge artefacts
# of the reconstruction -- TODO confirm).
margin = 10
# Distinct names: reusing fig3/ax3/im3 here would rebind the globals that the
# interactive widget's callbacks use (im3[0] on an AxesImage raises TypeError).
fig_dpc, ax_dpc = plt.subplots(1, 1, dpi=300)
im_dpc = ax_dpc.imshow(dpc.get()[margin:-margin, margin:-margin], cmap='bone')
fig_dpc.savefig(results_path / 'dpc.pdf')
AppLayout(center=fig_dpc.canvas)

AppLayout(children=(Canvas(layout=Layout(grid_area='center'), toolbar=Toolbar(toolitems=[('Home', 'Reset origi…

In [11]:
# Save the reconstruction as an ImageJ-compatible TIFF. metadata.dr is printed in Å,
# so /10 converts to nm and the resolution is given in pixels per nm.
# The explicit .tif extension lets ImageJ/Fiji recognise the file.
imwrite(
    results_path / f'scan{scan_number}_dpc.tif',
    dpc.get().astype('float32'),
    imagej=True,
    resolution=(1. / (metadata.dr[0] / 10), 1. / (metadata.dr[1] / 10)),
    metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'YX'},
)