In [1]:
%matplotlib widget

# base python modules (standard library)
import os
import time
from multiprocessing import cpu_count

# third-party scientific stack
import matplotlib.pyplot as plt
import numpy as np
import pyfftw
from matplotlib.colors import LogNorm

# NCEM + Molecular Foundry modules
import stempy
from stempy.io import sparse_array

# our module
import stemh_tools as st

In [2]:
# we are making use of the algorithm in Offelli and Petri, IEEE TIM 39, 363-368 (1990)
# to accurately acquire the phase of the wave the FFT gives us
# w/o error due to leakage and other discrete artifacts

def make_window(arr_size, a0, a1, a2, a3=0, d=2):
    """Build a cosine-sum window centred on the middle of the array.

    The 1D profile is
        a0 + a1*cos(2*pi*n/N) + a2*cos(4*pi*n/N) + a3*cos(6*pi*n/N)
    evaluated for n = -N/2 ... N/2 - 1 with N = arr_size.

    Parameters
    ----------
    arr_size : int
        Number of samples along each axis of the window.
    a0, a1, a2, a3 : float
        Weights of the constant term and of the first three cosine harmonics.
    d : int
        When equal to 2 the separable 2D window (outer product of the 1D
        profile with itself) is returned; any other value yields the 1D profile.

    Returns
    -------
    numpy.ndarray
        Shape (arr_size, arr_size) when d == 2, otherwise (arr_size,).
    """
    sample_pos = np.arange(-arr_size / 2, arr_size / 2)

    # one cosine term per harmonic order 1, 2, 3
    harmonics = [
        weight * np.cos(order * 2 * np.pi * sample_pos / arr_size)
        for order, weight in enumerate((a1, a2, a3), start=1)
    ]
    profile = np.sum([a0 * np.ones(arr_size), *harmonics], axis=0)

    return np.outer(profile, profile) if d == 2 else profile

def calc_energy(arr):
    """Return the signal energy of `arr`: the sum of squared magnitudes of
    all its elements (real or complex)."""
    return np.sum(np.abs(arr) ** 2)

# 2D FFT window, 576 samples per side, built with the coefficients that the
# cited paper optimized for this method
enrgy_win = make_window(576, a0=.350139, a1=.485260, a2=.149889, a3=.014712)

def circle_thickness(radius, x_dist):
    """Chord length of a circle at a given distance from its centre.

    Parameters
    ----------
    radius : float
        Radius of the circle.
    x_dist : float or array_like
        Signed distance(s) from the centre at which the chord is measured.

    Returns
    -------
    numpy.ndarray
        2 * sqrt(radius**2 - x_dist**2) where |x_dist| < radius, and 0
        everywhere on or outside the circle (on either side of the centre).
    """
    x_dist = np.asarray(x_dist)

    # Clamp at zero before the square root: np.where evaluates both branches,
    # so points outside the circle would otherwise produce NaNs and warnings.
    chord_sq = np.clip(radius ** 2 - x_dist ** 2, 0, None)

    # Compare against |x_dist| so that points far to the negative side are
    # treated as outside the circle as well.
    y = np.where(np.abs(x_dist) < radius, 2 * np.sqrt(chord_sq), 0)

    return y

In [57]:
# paths hard coded from your own device -- edit both before running elsewhere

# HDF5 file that is loaded into a stempy SparseArray further down
dataPath = '/Users/andrewducharme/Documents/Data/4D_ISTEM/philipp_211004/data_scan112_th4.5_electrons.h5'
# output location used when the phase map is written with np.save
savePath = '/Users/andrewducharme/Documents/Data/4D_ISTEM/OP_reanalysis/philipp_211004'

In [58]:
# open 4dstem data from h5 file into stempy SparseArray format
sa = sparse_array.SparseArray.from_hdf5(dataPath)
# drop the last entry along the second axis (the scan-column axis): those are the flyback frames
sa = sa[:, :-1, :, :]  # cut off flyback frames
# optional scan binning, left disabled:
# sa = sa.bin_scans(2)  # binning may not be necessary, but if you get no signal in the phase, try it

In [59]:
# sa is 4D: (scan rows, scan columns, frame rows, frame columns).
# scan_row_num is the same as scan_positions.attrs['Ny'] in the hdf5 metadata;
# scan_col_num is the same as sp.attrs['Nx'] - 1, since flyback is removed.
scan_row_num, scan_col_num, frame_row_num, frame_col_num = sa.shape

# virtual detector image between radii 0 and 30; dimensions are handed to
# stempy in (columns, rows) order, and only the first returned image is kept
bf = stempy.image.create_stem_images(
    sa.data,
    0,
    30,
    scan_dimensions=(scan_col_num, scan_row_num),
    frame_dimensions=(frame_col_num, frame_row_num),
)[0]

fig, ax = plt.subplots()
ax.imshow(bf)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Find the location of the peaks in the Fourier transform of an interference pattern. The peak location is essentially constant throughout the scan.

The code tries to find the correct value, but double-check that it isn't picking up the 0th order.

In [None]:
# Pick one scan position, Fourier transform its frame, and locate the first-order peak
# that the main loop will sample for every frame.
# Remember stempy uses (col #, row #) unlike numpy's (# of rows, # of columns) structure

test_loc = [100,100]
test_frame = sa[test_loc]

# find the first order index by computing the real fft to match what we use in the loop
test_rfft = np.fft.rfft2(test_frame)

# ask the helper for 4 peaks in test_rfft; the printed rows hold three integers each
# NOTE(review): the last two columns are used below as the peak's index into the rFFT;
# the meaning of the first column is not visible here -- confirm in stemh_tools
rfft_peaks = st.fft_find_peaks(test_rfft, 4)  # find the 4 peaks reported by the helper in test_rfft
print(rfft_peaks)

first_order = rfft_peaks[0, 1:]  # location of first order peak (first row, last two columns)
# disabled alternative: derive the box size from the peak spacing instead of hard coding it
# NOTE(review): the divisor here is 4 although the old note said "eighth" -- confirm which is meant
# selection_size = st.calc_box_size(rfft_peaks) / 4 # fraction of the distance between individual fft peaks

print(first_order)

# size of the square box grabbed around the first-order peak in the main loop
selection_size = 10

# Check the frame isn't obviously garbage.
# The code looks at the raw rFFT, but it's easier for us humans to look at the fftshifted rFFT
# Just remember Fourier peak locations in the fftshifted data are not the locations in the actual analysis
fig, ax = plt.subplots()

# alternative views, left disabled: the raw frame, or the fftshifted rFFT
# ax.imshow(test_frame, cmap='binary')
ax.imshow(np.abs(test_rfft), norm=LogNorm())
# ax.imshow(np.abs(np.fft.fftshift(test_rfft)), norm=LogNorm())

ax.xaxis.tick_top()
ax.tick_params(labelsize = 12)
fig.tight_layout()

[[404 539 126]
 [177 502 257]
 [175 506 167]
 [173 398 145]]
[539 126]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [62]:
# Compute the phase map: for every scan position, window the frame, Fourier
# transform it, and estimate the phase of the first-order peak from the ratio
# of energies around it (Offelli & Petri method referenced above).

# setting up pyfftw numpy interface
pyfftw.config.NUM_THREADS = cpu_count()
pyfftw.config.PLANNER_EFFORT = 'FFTW_ESTIMATE'
pyfftw.interfaces.cache.enable()

# Reusable scratch buffer for the windowed frame. It has to be floating point:
# the window weights are fractional, so an integer buffer would silently
# truncate frame * enrgy_win on assignment and throw the windowing away.
base = np.empty(sa.frame_shape, dtype=np.float64)

# sparse array shape changed from (scan_row, scan_col, : ,:) to (scan_row * scan_col, :,:)
rsa = sa.ravel_scans()

# initialize arrays to store values through loop
arguments = np.zeros(rsa.scan_shape[0], dtype=np.float64)  # energy ratio per frame
peaks = np.zeros(rsa.scan_shape[0], dtype=complex)         # complex FFT value at the peak

start = time.time()

# the forward Fourier transform is the vast majority of the work+computation time here
for i, frame in enumerate(rsa):
    if i % 20_000 == 0:
        print(i)
    if not frame.any():
        # empty frame: leave the ratio at 0 and the peak at 0
        arguments[i] = 0
        continue

    base[:] = frame * enrgy_win

    ft = pyfftw.interfaces.numpy_fft.rfft2(base)  # take Fourier transform of the windowed frame

    fourier_space_peak = st.grab_square_box(ft, selection_size, first_order)  # select the area around the first peak
    peaks[i] = ft[first_order[0], first_order[1]]  # grab actual value at FFT peak

    Ehat_x = calc_energy(fourier_space_peak)       # total energy in the box
    Ehat_c = calc_energy(fourier_space_peak.real)  # energy of the real part only

    # skip the division when the box holds no energy, so 0/0 never puts a NaN
    # into the map (the ratio then stays at its initial 0)
    if Ehat_x > 0:
        arguments[i] = Ehat_c / Ehat_x

argSigns = np.sign(peaks.real)
phaseSigns = np.sign(peaks.imag)

# the ratio is <= 1 mathematically; clip so rounding can never push the
# arccos argument outside [-1, 1]
cosMagnitude = np.sqrt(np.clip(arguments, 0.0, 1.0))

phaseMap = phaseSigns * np.arccos(argSigns * cosMagnitude)
phaseMap = phaseMap.reshape(scan_row_num, scan_col_num)

end = time.time()

print("Total time (s): " + str(end - start))
print("Per frame time (ms): " + str((end - start) / scan_row_num / scan_col_num * 1000))
print('1024 x 1024 time (min): ' + str((end - start) / scan_row_num / scan_col_num * 1024 * 1024 / 60))

fig, ax = plt.subplots()
ax.imshow(phaseMap)
# shows where the test frame selected for peak finding sits in the overall scan
ax.plot(test_loc[0], test_loc[1], 'ro')

# NOTE(review): kept from the original, presumably to restore the 4D shape in
# case ravel_scans() changed `sa` in place. If ravel_scans()/reshape() return
# new objects this is a no-op that only displays a repr -- confirm for the
# installed stempy version.
sa.reshape((scan_row_num, scan_col_num, frame_row_num, frame_col_num))

0
20000
40000
60000
Total time (s): 76.76662087440491
Per frame time (ms): 1.1713656749634538
1024 x 1024 time (min): 20.47109889984131


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<stempy.io.sparse_array.SparseArray at 0x7fd1626cf790>

In [43]:
# Preview the plane-subtracted phase map to check whether `loc` is a good
# location for the plane subtraction.
loc = (100, 20)

fig, ax = plt.subplots()
ax.xaxis.tick_top()
image = ax.imshow(st.plane_subtract(phaseMap, loc, 10, 15), cmap='seismic')
fig.colorbar(image, ax=ax)

# red dot marks `loc` on the map
ax.plot(*loc, 'ro')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [63]:
# Final view: plane-subtracted phase map, shifted so its minimum sits at zero.
psed = st.plane_subtract(phaseMap, (100, 20), 10, 15)
psed = psed - psed.min()

fig, ax = plt.subplots()
ax.xaxis.tick_top()
image = ax.imshow(psed, cmap='seismic')
fig.colorbar(image, ax=ax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7fd164fe4d60>

In [65]:
# Save the raw (not plane-subtracted) phase map inside the savePath folder.
# Join with a path separator: plain string concatenation would fuse the folder
# name and the file name and write the file next to the folder instead.
os.makedirs(savePath, exist_ok=True)
np.save(os.path.join(savePath, '112Phase_256_ord1'), phaseMap)

In [65]:
# close every open matplotlib figure created by the cells above
plt.close('all')