In [None]:
%matplotlib widget

# base python modules
import numpy as np
import time
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# NCEM + Molecular Foundry modules
from stempy.io import sparse_array
import stempy

# our module
import stemh_tools as st
import data_selection as ds


# Input HDF5 file holding the 4D-STEM scan, and the directory (trailing slash included)
# where reconstruction results are written. Edit these for a different machine/dataset.
dataPath, savePath = (
    '/Users/andrewducharme/Documents/Data/phillip_210524/data_scan22_th4.0_electrons.h5',
    '/Users/andrewducharme/Documents/Sparse Code/anlyzd_philipp_210524/',
)

In [None]:
# Load the 4D-STEM data into a stempy SparseArray and immediately drop the last
# scan column (axis 1), which holds the flyback frames.
sa = sparse_array.SparseArray.from_hdf5(dataPath)[:, :-1, :, :]

In [None]:
# Scan and detector dimensions.
# numpy orders shapes as (# of rows, # of columns); stempy expects (# of columns, # of rows).
scan_row_num, scan_col_num = sa.shape[:2]     # rows == Ny in the HDF5 metadata; cols == Nx - 1 after the flyback cut
frame_row_num, frame_col_num = sa.shape[2:4]

# Bright-field image: integrate each frame between these inner/outer radii
# (presumably detector pixels — confirm against stempy docs).
bf_inner_radius = 0
bf_outer_radius = 35
bf = stempy.image.create_stem_images(
    sa.data,
    bf_inner_radius,
    bf_outer_radius,
    (scan_col_num, scan_row_num),
    frame_dimensions=(frame_col_num, frame_row_num),
)[0]
ds.quick_plot(bf)

In [None]:
# Fourier transform of a frame in which all three probes pass through vacuum.
# - The conjugated square box around its 0th order becomes the integration kernel
#   for the phase-reconstruction loop.
# - The same transform fixes the first-order peak location and the box size,
#   which are shared by every frame.
vac_scan_pos = (100, 1000)  # (scan row, scan col) of a frame that isn't obviously garbage
vac_frame = sa[vac_scan_pos]
vac_fft = st.fftw2D(vac_frame)

ds.quick_plot(vac_frame)
ds.quick_plot(vac_fft, log_norm=True)

# The two highest-magnitude peaks of vac_fft; row 1 holds the first-order peak.
fft_peaks = st.fft_find_peaks(vac_fft, 2)
first_order = fft_peaks[1, 1:]  # location of the first-order peak
selection_size = st.calc_box_size(fft_peaks)

# Kernel = complex conjugate of the vacuum FFT. Its magnitude plot looks identical
# to vac_fft's, since plotting takes the absolute value.
vacuum_kernel = np.conj(vac_fft)
kernel_peak = st.grab_square_box(vacuum_kernel, selection_size)

ds.quick_plot(kernel_peak, log_norm=True)

In [None]:
# Reconstruct one phase value per scan position inside a rectangular region of interest (ROI).
# For each frame, the square box around the first-order Fourier peak is multiplied
# element-wise with the conjugated vacuum kernel and summed. The angle of that
# complex sum is stored as the phase.

# ROI in scan coordinates. Set both to slice(None) to process the full scan.
ROI_ROWS = slice(None, 800)
ROI_COLS = slice(None, 800)

crop_sa = sa[ROI_ROWS, ROI_COLS]  # sparse array restricted to the ROI
crop_row_num = crop_sa.scan_shape[0]
crop_col_num = crop_sa.scan_shape[1]

# Flatten the scan axes: (scan_row, scan_col, :, :) -> (scan_row * scan_col, :, :)
rsa = crop_sa.ravel_scans()
n_frames = rsa.scan_shape[0]

# One phase per flattened scan position; frames with no counts keep phase 0.
phaseMap = np.zeros(n_frames, dtype=np.float64)

start = time.time()

# The forward Fourier transform is the vast majority of the work/computation time here.
for i, frame in enumerate(rsa):
    if i % 100000 == 0:
        print(i)  # coarse progress indicator
    if not frame.any():
        continue  # empty frame: leave the pre-initialised 0

    ft = st.fftw2D(frame)  # Fourier transform of the full frame

    # Select the square box around the first-order peak found on the vacuum frame.
    fourier_space_peak = st.grab_square_box(ft, selection_size, first_order)

    # Sum of the element-wise product with the conjugated vacuum kernel, i.e. an inner
    # product (zero-lag cross-correlation); keep only its angle in the complex plane.
    t_temp = np.sum(kernel_peak * fourier_space_peak)
    phaseMap[i] = np.angle(t_temp)

end = time.time()

phaseMap = phaseMap.reshape(crop_row_num, crop_col_num)

# Timing is normalised by the frames actually processed (the ROI), not the full scan.
elapsed = end - start
per_frame = elapsed / max(n_frames, 1)
print("Total time (s): " + str(elapsed))
print("Per frame time (ms): " + str(per_frame * 1000))
print('1024 x 1024 time (min): ' + str(per_frame * 1024 * 1024 / 60))

np.save(savePath + 'phase', phaseMap)  # np.save appends the .npy extension