In [1]:
%matplotlib widget

# standard library
import os
import time
from multiprocessing import cpu_count

# third-party scientific stack
import matplotlib.pyplot as plt
import numpy as np
import pyfftw
from matplotlib.colors import LogNorm

# NCEM + Molecular Foundry modules
import stempy
from stempy.io import sparse_array

# our module
import stemh_tools as st

In [3]:
# Machine-specific locations. Edit DATA_ROOT (and the dataset-specific parts
# below) when running this notebook on another device.
DATA_ROOT = '/Users/andrewducharme/Documents/Data/4D_ISTEM'

# Raw 4D-STEM scan stored as a stempy sparse HDF5 file.
dataPath = os.path.join(DATA_ROOT, 'philipp_211004', 'data_scan110_th4.5_electrons.h5')
# Directory where results (the phase map) are written.
savePath = os.path.join(DATA_ROOT, 'OP_reanalysis', 'philipp_211004')

In [None]:
# Load the 4D-STEM scan into a stempy SparseArray and, in the same step, drop the
# last scan column: it contains flyback frames, not real probe positions.
# Optional: binning the scan (sa = sa.bin_scans(2)) may help if the phase shows no signal.
sa = sparse_array.SparseArray.from_hdf5(dataPath)[:, :-1, :, :]

In [5]:
# Scan and detector geometry; `sa` is indexed numpy-style as
# (scan_row, scan_col, frame_row, frame_col).
scan_row_num = sa.shape[0]  # same as scan_positions.attrs['Ny'] in hdf5 file metadata
scan_col_num = sa.shape[1]  # same as scan_positions.attrs['Nx'] - 1, since flyback is removed
frame_row_num = sa.shape[2]
frame_col_num = sa.shape[3]

# Virtual bright-field detector radii, in detector pixels.
# NOTE(review): no center is passed, so stempy's default center is used — confirm it suits this data.
BF_INNER_RADIUS = 0
BF_OUTER_RADIUS = 30

# stempy takes dimensions as (columns, rows), the reverse of numpy's (rows, columns).
# create_stem_images returns a sequence of images; [0] takes the first one.
bf = stempy.image.create_stem_images(
    sa.data,
    BF_INNER_RADIUS,
    BF_OUTER_RADIUS,
    scan_dimensions=(scan_col_num, scan_row_num),
    frame_dimensions=(frame_col_num, frame_row_num),
)[0]

fig, ax = plt.subplots()
ax.imshow(bf)
ax.set(title='Virtual bright-field image', xlabel='scan column', ylabel='scan row')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
# Locate the first-order Fourier peak once, on a single reference frame.
# The peak position and the selection box size are treated as identical for every
# frame, so the main loop reuses `first_order` and `selection_size`.

# Reference scan position (scan_row, scan_col), indexed numpy-style into `sa`.
# A tuple is used deliberately: with numpy indexing semantics a list such as
# [100, 100] is fancy indexing along the first axis (two scan rows), not one frame.
vac_loc = (100, 100)
vac_frame = sa[vac_loc]

# Use the real FFT so peak indices match the transform used in the main loop.
vac_rfft = np.fft.rfft2(vac_frame)

# Find the N_PEAKS highest-magnitude peaks in vac_rfft.
N_PEAKS = 4
rfft_peaks = st.fft_find_peaks(vac_rfft, N_PEAKS)
print(rfft_peaks)

# First-order peak location: columns 1: of the first returned row.
first_order = rfft_peaks[0, 1:]
print(first_order)

# Size parameter of the square box st.grab_square_box selects around first_order.
# (An automatic alternative that was tried: st.calc_box_size(rfft_peaks) / 4.)
selection_size = 3

# Visual sanity check of the reference spectrum on a log scale.
# rfft2 keeps only non-negative frequencies along the last axis, so only axis 0
# is shifted to centre zero frequency; shifting the last axis too would wrap the
# half-spectrum and misplace the peaks. Positions in this shifted view are NOT the
# indices used by the analysis.
fig, ax = plt.subplots()
ax.imshow(np.abs(np.fft.fftshift(vac_rfft, axes=0)), norm=LogNorm())
ax.xaxis.tick_top()
ax.tick_params(labelsize=12)
ax.set_title('|rFFT| of reference frame (axis 0 shifted)')
fig.tight_layout()

[[793 516 201]
 [540 119 175]
 [299 448 102]
 [293 508 260]]
[516 201]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Configure pyfftw's numpy-compatible interface (used by the main loop):
# plan quickly with FFTW_ESTIMATE rather than benchmarking, use every CPU core,
# and enable the interface cache so FFTW objects can be reused across calls.
fftw_cfg = pyfftw.config
fftw_cfg.PLANNER_EFFORT = 'FFTW_ESTIMATE'
fftw_cfg.NUM_THREADS = cpu_count()
pyfftw.interfaces.cache.enable()

In [None]:
# For every scan position, sum the rFFT values in a box around the first-order
# peak; the angle of that complex sum is the phase map. The forward FFT dominates
# the run time of this cell.
PROGRESS_EVERY = 20_000  # print a frame counter every this many frames

# Reusable FFT input buffer.
# NOTE(review): uint16 silently wraps pixel counts above 65535 — confirm counts stay below that.
base = np.empty(sa.frame_shape, dtype='uint16')

# Ravel the scan axes: (scan_row, scan_col, frame_row, frame_col)
# -> (scan_row * scan_col, frame_row, frame_col).
rsa = sa.ravel_scans()
n_frames = rsa.scan_shape[0]

# One complex first-order coefficient per scan position; empty frames stay 0.
peaks = np.zeros(n_frames, dtype=complex)

start = time.perf_counter()

for i, frame in enumerate(rsa):
    if i % PROGRESS_EVERY == 0:
        print(i)
    if not frame.any():
        # No counts in this frame: keep its coefficient at 0 and skip the FFT.
        continue

    base[:] = frame
    ft = pyfftw.interfaces.numpy_fft.rfft2(base)  # real FFT of the (unwindowed) frame

    fourier_space_peak = st.grab_square_box(ft, selection_size, first_order)  # box around the first-order peak
    peaks[i] = np.sum(fourier_space_peak)

# np.angle returns radians in (-pi, pi].
phaseMap = np.angle(peaks).reshape(scan_row_num, scan_col_num)

elapsed = time.perf_counter() - start

print("Total time (s): " + str(elapsed))
print("Per frame time (ms): " + str(elapsed / n_frames * 1000))
print('1024 x 1024 time (min): ' + str(elapsed / n_frames * 1024 * 1024 / 60))

fig, ax = plt.subplots()
# Phase wraps around, so a cyclic colormap avoids false edges at +/- pi.
im = ax.imshow(phaseMap, cmap='twilight')
fig.colorbar(im, ax=ax, label='phase (rad)')
ax.set(title='First-order Fourier peak phase', xlabel='scan column', ylabel='scan row')
plt.show()

# Keep `sa` in its 4-D (scan_row, scan_col, frame_row, frame_col) shape for later
# cells. NOTE(review): if ravel_scans returns a new array this is a no-op; assigning
# the result (instead of a bare expression) also stops the repr being echoed.
sa = sa.reshape((scan_row_num, scan_col_num, frame_row_num, frame_col_num))

0
20000
40000
60000
Total time (s): 89.26563310623169
Per frame time (ms): 1.362085466098506
1024 x 1024 time (min): 23.80416882832845


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<stempy.io.sparse_array.SparseArray at 0x7fb8d938ff70>

In [65]:
# Save the phase map as <savePath>/112Phase_256_ord1.npy. os.path.join supplies the
# directory separator: savePath has no trailing slash, so plain concatenation
# wrote the file into savePath's parent directory under a mangled name.
# NOTE(review): dataPath is scan 110 but this filename says 112 — confirm the name.
os.makedirs(savePath, exist_ok=True)
np.save(os.path.join(savePath, '112Phase_256_ord1'), phaseMap)

In [108]:
# Close every open matplotlib figure to release its resources.
plt.close(fig='all')