# Retrieve phase from center of mass
 - uses stempy functions

In [None]:
%matplotlib widget

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import imageio

import stempy.io as stio
import stempy.image as stim

In [None]:
def phase_from_com(com, theta=0, flip=False, reg=1e-10):
    """Integrate 4D-STEM centre of mass (DPC) measurements to obtain the
    object phase.

    The two centre-of-mass components are optionally rotated/mirrored and
    then integrated in Fourier space. Note this version of the
    reconstruction is not quantitative.

    author: Hamish Brown

    Parameters
    ----------
        com : array_like, 3D
            The center of mass for each frame, of shape [2, M, N]. The last
            two dimensions are the image. After rotation, ``com[0]`` is
            paired with the spatial frequencies of array axis 0 (length M)
            and ``com[1]`` with those of array axis 1 (length N).
        theta : float
            The angle between real space and reciprocal space in radians.
        flip : bool
            Whether to flip the com direction to account for a mirror
            across the vertical axis.
        reg : float
            A regularization parameter added to the denominator of the
            Fourier-space solution.

    Returns
    -------
        : ndarray, 2D
            A 2D ndarray of shape [M, N] holding the DPC phase.

    Raises
    ------
        ValueError
            If ``com`` is not three dimensional with a leading dimension
            of size 2.

    """
    # Accept any array-like (e.g. nested lists) and fail early with a clear
    # message instead of an indexing error further down.
    com = np.asarray(com)
    if com.ndim != 3 or com.shape[0] != 2:
        raise ValueError(
            'com must have shape (2, M, N); got {}'.format(com.shape))

    first, second = com[0], com[1]

    # Perform rotation and flipping if needed (from py4dstem).
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    if flip:
        first_rot = first*cos_t + second*sin_t
        second_rot = first*sin_t - second*cos_t
    else:
        first_rot = first*cos_t - second*sin_t
        second_rot = first*sin_t + second*cos_t

    # Fourier coordinates along each image axis.
    ny, nx = com.shape[1:]
    ky = np.fft.fftfreq(ny)
    kx = np.fft.fftfreq(nx)

    # Numerator and denominator of the Fourier-space solution for the phase
    # from centre of mass measurements.
    numerator = (ky[:, None]*np.fft.fft2(first_rot)
                 + kx[None, :]*np.fft.fft2(second_rot))
    denominator = 2*np.pi*1j*((kx**2)[None, :] + (ky**2)[:, None]) + reg

    # Avoid a divide by zero at the origin of the Fourier coordinates; this
    # also forces the zero-frequency term of the result to zero.
    numerator[0, 0] = 0
    denominator[0, 0] = 1

    # Return real part of the inverse Fourier transform.
    return np.real(np.fft.ifft2(numerator/denominator))


In [None]:
# Load one sparse 4D camera data set from disk.

# Which acquisition to open; both values are substituted into the file name.
scan_num = 18
threshold = 4.5

# Directory and file name are kept as separate Path objects and then joined.
dPath = Path('/mnt/hdd1/2021.03.02')
fPath = Path('data_scan{}_th{}_electrons.h5'.format(scan_num, threshold))
fname = dPath / fPath

# Discard figures from earlier runs so that open windows do not accumulate.
plt.close('all')

# The loader is handed the path as a plain string.
electron_events = stio.load_electron_counts(str(fname))

# Report what was loaded.
print('File: {}'.format(fname))
print('Initial scan dimensions = {}'.format(electron_events.scan_dimensions))

In [None]:
# Sum a subset of the sparse frames into one diffraction pattern and locate
# its center.

# Only every 10th entry of the sparse data is included in the sum.
dp = stim.calculate_sum_sparse(
    electron_events.data[::10],
    electron_events.frame_dimensions,
)

# Center of the pattern from its center of mass. To set it by hand instead,
# read the position off the figure below and assign it, e.g.
# center = (248, 284)
center = stim.com_dense(dp)
print(center)

# Show the summed pattern on a log scale (the +1 keeps log() finite where the
# sum is zero) with the center marked in red.
fg, ax = plt.subplots(nrows=1, ncols=1)
ax.imshow(np.log(1 + dp))
ax.scatter(center[0], center[1], c='r')
_ = ax.legend(['center of pattern'])

In [None]:
# Build a virtual bright field and a virtual dark field image.
outer_angle = 30  # in pixels; not referenced by the call below

# Two images are requested in a single call. (0, 180) and (50, 280) are the
# second and third positional arguments -- presumably the inner and outer
# detector radii, i.e. 0-50 for BF and 180-280 for ADF; confirm against the
# stempy documentation. Here center is (col, row).
ims = stim.create_stem_images(
    electron_events, (0, 180), (50, 280), center=center,
)
bf, adf = ims[0], ims[1]

# Side-by-side display with linked zoom/pan.
fg, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
ax[0].imshow(bf)
ax[1].imshow(adf)
ax[0].set(title='vBF')
ax[1].set(title='vADF')

In [None]:
# Center of mass of every frame, computed from the sparse data.
com = stim.com_sparse(electron_events.data, electron_events.frame_dimensions)

# Arrange the result as (2, scan_dimensions[1], scan_dimensions[0]).
# This can be removed in the future.
com = com.reshape(
    2,
    electron_events.scan_dimensions[1],
    electron_events.scan_dimensions[0],
)

# Show both components with a diverging colormap. The color limits are taken
# from the data with the first and last 10 rows left out.
fg, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
axim0 = ax[0].imshow(
    com[0], cmap='bwr',
    vmin=com[0, 10:-10].min(), vmax=com[0, 10:-10].max(),
)
axim1 = ax[1].imshow(
    com[1], cmap='bwr',
    vmin=com[1, 10:-10].min(), vmax=com[1, 10:-10].max(),
)

In [None]:
# Polar form of the COM measurements: subtract the mean of each component
# over the whole scan, then take the magnitude and the angle at every scan
# position.
com_mean = np.mean(com, axis=(1, 2))
com_theta = np.arctan2(com[1] - com_mean[1], com[0] - com_mean[0])
com_r = np.sqrt((com[0] - com_mean[0])**2 + (com[1] - com_mean[1])**2)

# Magnitude on the left (color limits exclude the first and last 10 rows),
# angle on the right with a cyclic colormap.
fg, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
ax[0].imshow(
    com_r, cmap='magma',
    vmin=com_r[10:-10].min(), vmax=com_r[10:-10].max(),
)
ax[1].imshow(com_theta, cmap='twilight')

In [None]:
# Reconstruct the phase from the center of mass maps.
# Settings noted for this instrument:
#   300kV: flip=True and theta=0 + STEM scan rotation
#   80 kV: flip=True and theta=35 works well.
# theta is entered in degrees and converted to radians; it is the rotation
# between diffraction and real space scan directions.
theta = 0 * np.pi / 180.
flip = True

# Calculate the phase
ph = phase_from_com(com, theta=theta, flip=flip, reg=1e-1)

# Phase (scaled by its standard deviation, shown within +/-2) next to the
# virtual ADF image.
fg, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
# Alternative display: unscaled phase with limits from the cropped interior.
#ax[0].imshow(ph, vmin=ph[10:-10,10:-10].min(), vmax=ph[10:-10,10:-10].max())
ax[0].imshow(ph / np.std(ph), vmin=-2, vmax=2)
ax[1].imshow(adf)
ax[0].set(title='DPC')
ax[1].set(title='vADF')

# Log-magnitude of the centered 2D FFT of the phase and of the bright field
# image, for comparison.
fg, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
ax[0].imshow(np.log(np.abs(np.fft.fftshift(np.fft.fft2(ph)))), vmin=1e-3)
ax[1].imshow(np.log(np.abs(np.fft.fftshift(np.fft.fft2(bf)))))

In [None]:
# Three-panel comparison of the bright field, dark field and DPC phase images
# with linked zoom/pan.
fg, ax = plt.subplots(
    nrows=1, ncols=3, sharex=True, sharey=True, figsize=(12, 5),
)
ax[0].imshow(bf)
ax[1].imshow(adf)
# Color limits of the phase exclude the first and last 10 rows.
ax[2].imshow(ph, vmin=ph[10:-10].min(), vmax=ph[10:-10].max())
ax[0].set(title='BF')
ax[1].set(title='ADF')
ax[2].set(title='DPC')

In [None]:
# Save the data: each image is written next to the input file as a float32
# TIFF named after the scan number.
print('Saving COM and DPC for scan number {}'.format(scan_num))

# (file-name template, image) pairs, written in this order. The comx/comy
# files hold com[0] and com[1] respectively.
for _template, _image in (
    ('scan{}_DPC', ph),
    ('scan{}_comx', com[0]),
    ('scan{}_comy', com[1]),
    ('scan{}_BF', bf),
    ('scan{}_ADF', adf),
):
    _target = fname.with_name(_template.format(scan_num)).with_suffix('.tif')
    imageio.imwrite(_target, _image.astype(np.float32))
