In [6]:
%matplotlib widget

# base python modules
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import pyfftw
from multiprocessing import cpu_count
from time import time

# NCEM/Molecular Foundry modules
from stempy.io import sparse_array
from stempy.image import create_stem_images

# our module
import stemh_tools as st

# Input file: electron-counted 4D-STEM scan in hdf5 format, read by the loading cell below.
# NOTE(review): both paths are absolute and machine-specific — edit them before running elsewhere.
dataPath = '/Users/andrewducharme/Documents/Data/philipp_211004/data_scan114_th4.5_electrons.h5'
# Directory for processed output; in this notebook it is only referenced by a commented-out np.save.
savePath = '/Users/andrewducharme/Documents/Data/philipp_211004/STEMH Processed/'

In [18]:
# Read the 4D-STEM data from the h5 file as a stempy SparseArray and, in the same step,
# drop the final index along axis 1 (the flyback frames); every other axis is kept whole.
sa = sparse_array.SparseArray.from_hdf5(dataPath)[:, :-1, :, :]

In [3]:
# Scan and detector-frame sizes of the flyback-trimmed array.
#   scan_row_num: same as scan_positions.attrs['Ny'] in the hdf5 file metadata
#   scan_col_num: same as attrs['Nx'] - 1, since the flyback column was removed
scan_row_num, scan_col_num, frame_row_num, frame_col_num = sa.shape

# Dimensions are handed to stempy in (columns, rows) order.
scan_dims = (scan_col_num, scan_row_num)
frame_dims = (frame_col_num, frame_row_num)

# Build the STEM image for the 0-35 range and keep the first (only requested) image.
# NOTE(review): 0 and 35 are presumably the inner/outer detector radii in pixels,
# i.e. a bright-field image — confirm against the stempy documentation.
stem_images = create_stem_images(
    sa.data,
    0,
    35,
    scan_dimensions=scan_dims,
    frame_dimensions=frame_dims,
)
bf = stem_images[0]

fig, ax = plt.subplots()
ax.imshow(bf)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f8f917be610>

In [16]:
# Vacuum reference: Fourier transform one frame in which all three probes pass through vacuum.
# Three things computed here are reused by the phase reconstruction loop later:
#   * first_order    - location of the first-order Fourier peak
#   * selection_size - size of the square selected around a peak
#   * kernel_peak    - square around the 0th order of the conjugated vacuum FFT, used as the integral kernel
# The peak location and box size are computed once because they are the same for every frame.

# Pick a frame that isn't obviously garbage. Remember stempy uses (col #, row #),
# unlike numpy's (# of rows, # of columns) structure.
# NOTE(review): confirm which of 100 / 220 is the scan row when indexing a SparseArray directly.

vac_frame = sa[100, 220]

# Locate the first-order peak using the shifted real FFT, so the indices match the
# rfft2 + fftshift that the reconstruction loop applies to every frame.
vac_rfft = np.fft.fftshift(np.fft.rfft2(vac_frame))

rfft_peaks = st.fft_find_peaks(vac_rfft, 2)  # the two highest-magnitude peaks in vac_rfft

# Second peak returned, all entries after the first: used as the first-order peak location.
first_order = rfft_peaks[1, 1:]
# NOTE(review): the original comment called this "eighth the distance between individual
# fft peaks", but the code divides by 4 — confirm what st.calc_box_size returns.
# The division also makes selection_size a float.
selection_size = st.calc_box_size(rfft_peaks) / 4

# Build the kernel from the neighbourhood of the 0th-order peak. The complete (not real) FFT
# is used here so that the kernel is symmetric.
vac_fft = np.fft.fftshift(np.fft.fft2(vac_frame))
# The conjugate looks identical to vac_fft when plotted, because plotting requires the absolute value.
vacuum_kernel = np.conj(vac_fft)
# No centre is passed here, unlike the call in the loop that passes first_order.
kernel_peak = st.grab_square_box(vacuum_kernel, selection_size)

# Side-by-side check: the raw vacuum frame and the log-scaled power spectrum of its shifted real FFT.
# NOTE(review): the panels share axes although rfft2 output is narrower than the frame along its last axis.
fig, (ax0, ax1) = plt.subplots(1,2, sharex=True, sharey=True)
ax0.imshow(vac_frame)
ax0.set_title('Vacuum Frame')
ax1.imshow(np.abs(vac_rfft)**2, norm = LogNorm())
ax1.set_title('FFT of Vac. Frame')

fig.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# Phase reconstruction: for every scan position, Fourier transform the detector frame,
# select the square around the first-order peak, integrate it against the vacuum kernel,
# and store the angle of that complex sum as the phase at that position.

# Set up pyfftw's numpy-style interface: one thread per core, cheapest planning effort,
# and the interface cache enabled so repeated same-shape transforms can reuse their setup.
pyfftw.config.NUM_THREADS = cpu_count()
pyfftw.config.PLANNER_EFFORT = 'FFTW_ESTIMATE'
pyfftw.interfaces.cache.enable()

# Reusable float32 buffer that each frame is copied into before the transform.
# (Deliberately not called `input`: that name would shadow the Python builtin
# for the rest of the kernel session.)
frame_buffer = np.empty(sa.frame_shape, dtype='float32')

# Sparse array shape changes from (scan_row, scan_col, :, :) to (scan_row * scan_col, :, :).
# NOTE(review): ravel_scans presumably reshapes `sa` itself rather than copying, which is
# why the 4D shape is restored after the loop — confirm against the stempy documentation.
rsa = sa.ravel_scans()

# One phase value per scan position; positions with an empty frame stay at 0.
phaseMap = np.zeros(rsa.scan_shape[0], dtype=np.float64)

start = time()

# The forward Fourier transform is the vast majority of the work and computation time here.
try:
    for i, frame in enumerate(rsa):
        if i % 50_000 == 0:
            print(i)  # coarse progress indicator
        if not frame.any():
            continue  # empty frame: leave the pre-filled 0 in phaseMap

        frame_buffer[:] = frame

        ft = pyfftw.interfaces.numpy_fft.rfft2(frame_buffer)  # real Fourier transform of the full frame
        ft = pyfftw.interfaces.numpy_fft.fftshift(ft)  # shift so the layout matches the vacuum rfft used to find first_order

        # select the area around the first-order peak
        fourier_space_peak = st.grab_square_box(ft, selection_size, first_order)

        # phase computation: integrate kernel against the first-order peak, keep the angle
        t_temp = np.sum(kernel_peak * fourier_space_peak)
        phaseMap[i] = np.angle(t_temp)
finally:
    # Restore the 4D shape even when the loop is interrupted or raises, so that cells
    # relying on sa.shape having four entries keep working on a re-run.
    sa.reshape((scan_row_num, scan_col_num, frame_row_num, frame_col_num))

phaseMap = phaseMap.reshape(scan_row_num, scan_col_num)

end = time()

elapsed = end - start
per_frame_s = elapsed / scan_row_num / scan_col_num

print(f"Total time (s): {elapsed}")
print(f"Per frame time (ms): {per_frame_s * 1000}")
print(f"1024 x 1024 time (min): {per_frame_s * 1024 * 1024 / 60}")

fig, ax = plt.subplots()
ax.imshow(phaseMap)
ax.set_title('Reconstructed Phase')

# np.save(savePath + 'Phase', phaseMap)

fig.tight_layout()

0
50000
Total time (s): 63.351468086242676
Per frame time (ms): 0.9666666883276775
1024 x 1024 time (min): 16.893724822998045


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<stempy.io.sparse_array.SparseArray at 0x7f8f799e2eb0>