In [21]:
import phase_STEM as ps
import os

import numpy as np
import matplotlib.pyplot as plt
import h5py

In [2]:
%matplotlib qt5

In [3]:
path_file = "C:/Users/yuxi8598/Documents/Data_SU/Y Xia/20240913/ZSM-5(L+crushed)/"

In [4]:
file = path_file + '1440 1.30 Mx DF4.emd'

### Using "hyperspy" to load raw dataset

In [5]:
import hyperspy.api as hs

In [6]:
s = hs.load(file)

In [7]:
# we can quick check the data in the loaded dataset
s

[<Signal2D, title: DF4-A, dimensions: (|2048, 2048)>,
 <Signal2D, title: DF4-C, dimensions: (|2048, 2048)>,
 <Signal2D, title: Unrecognized_image_signal, dimensions: (|2048, 2048)>,
 <Signal2D, title: iDPC, dimensions: (|2048, 2048)>,
 <Signal2D, title: dDPC, dimensions: (|2048, 2048)>,
 <Signal2D, title: A-C, dimensions: (|2048, 2048)>,
 <Signal2D, title: DF4-D, dimensions: (|2048, 2048)>,
 <Signal2D, title: DF4-B, dimensions: (|2048, 2048)>,
 <Signal2D, title: B-D, dimensions: (|2048, 2048)>]

In [8]:
# Here, we can find the acquisition conditions

info = ps.tools.information_data(s)
for mem, value in info.items():
    print(f"{mem} --> {value}")

original_filename --> 1440 1.30 Mx DF4.emd
data --> 2024-09-13/14:40:39
beam_voltage (kV) --> 300.0
camera_length(mm) --> 228.3
resolution --> 0.03716882341069085
unit --> nm
size in pixel --> 2048
semi_convergence_angle (rad) --> 0.015
DF4_collection_angle(rad) --> [0.006 0.034]
HAADF_collection_angle(rad) --> [0. 0.]
defocus (nm) --> -352.24
Last measured beam_current (pA) --> 3.451
dwell_time(us) --> 5.0
magnification --> 1.3_Mx
dose_rate(eÅ-2) --> 779.66
tilt_alpha(deg) --> 1.976
tilt_beta(deg) --> 7.379
stage_x (um) --> -775.414
stage_y(um) --> 85.262
stage_z(um) --> 79.733


In [9]:
# We'll need "resolution" and "wavelength" for image displaying

resolution = info['resolution']
print(f'The resolution of images is: {round(resolution, 5)} nm per pixel'+'\n')
wavelength = ps.analysis.wavelength_beam(300)
print(f'The wavelength of beam is: {round(wavelength, 5)} nm ')

The resolution of images is: 0.03717 nm per pixel

The wavelength of beam is: 0.00197 nm 


In [10]:
# We need to extract the right segment images from the loaded dataset

DPC_imgs, titles = ps.tools.extract_segmented_image(s, order_map=None) # the defaulted "order_map" is "DF4-A, DF4-B, DF4-C, DF4-D"

In [156]:
crop = np.zeros((4, 512, 512))
for i in range(4):
    crop[i] = ps.tools.crop_matrix(images_OBF[i], [500,1400], [512,512])

In [160]:
fig, axes = plt.subplots(2, 2, figsize =(6, 6), sharex=True, sharey=True)
mask = ps.tools.circle_mask((2048,2048), (1024,1024), (0,550))
for i, ax in enumerate(axes.ravel()):
    if i < 4:
        pcm = ax.imshow(aligned_images[i], cmap = 'viridis')
        #pcm = ax.imshow(np.log(np.abs(OBF_Q[i]*mask)+1), cmap = 'viridis')
        ax.set_title(f'DPC image. {i+1}-th')
        fig.colorbar(pcm, ax=ax, shrink=0.4, extend = 'both')
    else:
        ax.axis('off')  # Turn off the axis for empty plots
plt.tight_layout()
#plt.savefig(f"{path_file}/LCoMxy.tiff")
plt.show()

In [52]:
def align_images(images):
    """
    Align images using cross-correlation.
    
    :param images: NumPy array of shape (4, x, y) containing 4 images
    :return: Aligned images as a NumPy array of shape (4, x, y)
    """
    num_images, height, width = images.shape
    aligned_images = np.zeros_like(images)
    aligned_images[0] = images[0]  # Use the first image as reference
    
    for i in range(1, num_images):
        # Compute cross-correlation
        correlation = signal.correlate2d(images[0], images[i], mode='same')
        
        # Find the shift that gives maximum correlation
        y_shift, x_shift = np.unravel_index(np.argmax(correlation), correlation.shape)
        y_shift -= height // 2
        x_shift -= width // 2
        
        # Apply the shift
        if x_shift > 0:
            aligned_images[i, :, x_shift:] = images[i, :, :-x_shift]
        elif x_shift < 0:
            aligned_images[i, :, :x_shift] = images[i, :, -x_shift:]
        else:
            aligned_images[i, :, :] = images[i, :, :]
        
        if y_shift > 0:
            aligned_images[i, y_shift:, :] = aligned_images[i, :-y_shift, :]
        elif y_shift < 0:
            aligned_images[i, :y_shift, :] = aligned_images[i, -y_shift:, :]
    
    return aligned_images

In [54]:
from scipy import signal

In [158]:
start = time.time()
aligned_images = align_images(crop)
print(time.time() - start)

292.5711476802826


In [159]:
plt.imshow(np.sum(aligned_images, axis = 0))

<matplotlib.image.AxesImage at 0x210cd3a6900>

In [57]:
DPCs = [aligned_images[0], aligned_images[1], aligned_images[2], aligned_images[3]]

In [189]:
#generating the DPCx and DPCy

CoMx = DPC_imgs[0] - DPC_imgs[2]
CoMy = DPC_imgs[1] - DPC_imgs[3]

#CoMx = images_OBF[0] - images_OBF[2]
#CoMy = images_OBF[1] - images_OBF[3]

In [190]:
CoM = [CoMx, CoMy]
fig, axes = plt.subplots(1, 2, figsize =(6, 6), sharex=True, sharey=True)

for i, ax in enumerate(axes.ravel()):
    if i < 2:
        pcm = ax.imshow(CoM[i], cmap = 'viridis')
        ax.set_title(f'DPC image. {i+1}-th')
        fig.colorbar(pcm, ax=ax, shrink=0.4, extend = 'both')
    else:
        ax.axis('off')  # Turn off the axis for empty plots
plt.tight_layout()
#plt.savefig(f"{path_file}/LCoMxy.tiff")
plt.show()

In [191]:
ps.tools.plot_vector_image(CoMx + 1j*CoMy, title=['1', '2', '3', '4'], imgsize = 6, storing = [False, 'path_save'])

In [192]:
#method from py4DSTEM

crop_CoMx = ps.tools.crop_matrix(CoMx, [1024,1024], [512,512])
crop_CoMy = ps.tools.crop_matrix(CoMy, [1024,1024], [512,512])

theta, flip = ps.analysis.get_rotation_and_flip_maxcontrast(crop_CoMx, crop_CoMy, 360, paddingfactor=1,
                                      regLowPass=0, regHighPass=0.01, stepsize=1,
                                      n_iter=1)
print(f'The theta is {np.degrees(theta):.2f}')
print(f'Does it need to be fliped: {flip}')

Building:   0%|           [ time left: ? ]

The theta is 80.22
Does it need to be fliped: True


In [193]:
idpc, errors = ps.analysis.get_phase_from_CoM(CoMx, CoMy, theta, flip,
                                 regLowPass=0, regHighPass=0.0001,
                                   paddingfactor=2, stepsize=1, n_iter=20)

Building:   0%|           [ time left: ? ]

In [264]:
def fourier_transform_rotate_crop(image, rotation_angle):

    # Step 2: Perform Fourier Transform
    f_transform = np.fft.fftshift(np.fft.fft2(image))

    # Step 3: Rotate the Fourier transformed image
    rotated_image = ndimage.rotate(image, rotation_angle, reshape=True, mode='reflect')

    # Step 5: Inverse Fourier Transform to get the final image
    
    inverse_transform = np.fft.fftshift(np.fft.fft2(rotated_image))
    h, w = image.shape
    y, x = inverse_transform.shape
    startx = x//2 - w//2
    starty = y//2 - h//2
    reversed_image = np.fft.ifft2(np.fft.fftshift(inverse_transform[starty:starty+h, startx:startx+w]))
    # Normalize the image for display
    final_image = np.real(reversed_image)

    fig, axes = plt.subplots(2,2)
    axes[0,0].imshow(image)
    axes[0,1].imshow(np.log(np.abs(f_transform)+1))
    axes[1,0].imshow(rotated_image)
    axes[1,1].imshow(final_image)
    plt.show()
    return final_image

In [265]:
rotated_idpc = fourier_transform_rotate_crop(cp.asnumpy(idpc), 45)

In [250]:
rotated_idpc = ndimage.rotate(cp.asnumpy(idpc), 45, reshape=True)
plt.imshow(rotated_idpc)

<matplotlib.image.AxesImage at 0x210c6722e40>

In [195]:
import cupy as cp

In [197]:
plt.scatter(np.arange(1,21,1), cp.asnumpy(1/errors))

<matplotlib.collections.PathCollection at 0x210cea8df40>

In [162]:
# As the "idpc" is cp.array, if you want to plot is using 'matplotlib', then you should convert it into np.array
#We can use the following commander to plot/save the 'idpc'

ps.tools.plot_image(DPC_imgs, properties= {
             'resolution': resolution,
              'unit': 'nm',
              'bar location':'',
             'image titles': "iDPC",
             'figsize':6,
              'cmap':'',
              'dpi': 600,
              'image format':'.jpeg',
              'showing titles': True,
             'cropping image': [True, [500,1400], [256,256]],
             'saving image': False,   # if you want to save it, then 'True'
             'saving path':path_file+ "/1440iDPC"  # choose the saving folder address and the name of this figure
              })

### Reconstructing OBF image

In [58]:
# Firstly, we need the aberrations measured by probe Cs-corrector
#Adopting the notations supplied by Uhlemann and Haider

ab = {        'C1': 0,                                   #defocus (over focus positive), a real value
              'A1': 1.8+ 1j*np.radians(-56.31),     #Two-fold astigamastism
              'B2': 47.2 + 1j*np.tan(34.1/32.7),    #Axial coma, a complex, here using the result measured by Sherpa
              'A2': 30.7 + 1j*np.tan(-30.7/2.1),         #Three-fold astigmatism
              'C3': 884.4,                                #Spherical aberration
              'A3': 289 + 1j*np.radians(88.5),          #Four-fold astigmatism
              'S3': 132.7 + 1j*np.radians(-97.2),          #Axial star aberration
              'A4': 7246 + 1j*np.radians(84.2)}

In [116]:
# Firstly, we need the aberrations measured by probe Cs-corrector
#Adopting the notations supplied by Uhlemann and Haider

ab = {        'C1': 0,                                   #defocus (over focus positive), a real value
              'A1': 0 + 1j*np.radians(0),     #Two-fold astigamastism
              'B2': 0 + 1j*np.radians(0),    #Axial coma, a complex, here using the result measured by Sherpa
              'A2': 0 + 1j*np.radians(0),         #Three-fold astigmatism
              'C3': 0,                                #Spherical aberration
              'A3': 0 + 1j*np.radians(0),          #Four-fold astigmatism
              'S3': 0 + 1j*np.radians(0),          #Axial star aberration
              'A4': 0 + 1j*np.radians(0)}

In [59]:
wavelength = ps.analysis.wavelength_beam(300)

In [183]:
#describing the geometric planes of DF4 detectors
c = 0
segments = np.array(([-45+c,45+c], [45+c,135+c], [135+c,225+c], [225+c-360,315+c-360]), dtype = np.float64)
print(segments)

[[ -45.   45.]
 [  45.  135.]
 [ 135.  225.]
 [-135.  -45.]]


In [201]:

parameters = {"sample thickness(nm)": 10,
             "wavelength(nm)": wavelength,
             "resolution(nm)": resolution,
              "dose rate(e-Å-2)": 1,
             "collection angles(rad)": [0.010, 0.030], 
             "semi_convergence angle(rad)": 0.01,
             "pixelsize of filters": DPC_imgs[0].shape[0]/8, #controlling the pixel size of the filter
             "virtual grids in one segment detector": 20}

In [143]:
import time

In [202]:
#built_WPO = ps.analysis.phase_filters_GPU(ab, segments, parameters, slices=[False, 20], single_side_band=False, process=True)

built_WPO2, num = ps.analysis.phase_filters_GPU(ab, segments, parameters, slices=1, process=True)

Building:   0%|           [ time left: ? ]

--- 12 seconds left!
The whole process takes 46 seconds.


#### Directly use built phase filters

In [72]:
phase_filters = h5py.File('C:/Users/yuxi8598/Documents/Data_SU/google colab/1.3Mx/1.30 Mx_GPU_2048px.h5', 'r')

In [73]:
ps.tools.data_tree(phase_filters)

  |>--[92mDataset: /seg_1[0m
  |>---Here is an array with a shape of (2048, 2048)
  |>--[92mDataset: /seg_2[0m
  |>---Here is an array with a shape of (2048, 2048)
  |>--[92mDataset: /seg_3[0m
  |>---Here is an array with a shape of (2048, 2048)
  |>--[92mDataset: /seg_4[0m
  |>---Here is an array with a shape of (2048, 2048)


In [75]:
phase_filters.keys()

<KeysViewHDF5 ['seg_1', 'seg_2', 'seg_3', 'seg_4']>

In [76]:
built_WPO = []
for ft in phase_filters:
    built_WPO.append(phase_filters[ft][()])

In [314]:
ps.tools.plot_vector_image(built_WPO2, title=['1', '2', '3', '4'], imgsize = 6, storing = [False, 'path_save'])

In [207]:
OBF_R, OBF_Q = reconstruct_OBF(DPC_imgs, built_WPO2, parameters)

In [205]:
images_OBF = np.zeros((4, 2048,2048))
for i in range(4):
    images_OBF[i] = np.real(pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(OBF_Q[i]*mask)))


In [111]:
ddd = [DPC_imgs[2], DPC_imgs[3], DPC_imgs[0], DPC_imgs[1]]

In [206]:
fig, axes = plt.subplots(2, 2, figsize =(6, 6), sharex=True, sharey=True)
mask = ps.tools.circle_mask((2048,2048), (1024,1024), (0,550))
for i, ax in enumerate(axes.ravel()):
    if i < 4:
        pcm = ax.imshow(images_OBF[i], cmap = 'viridis')
        #pcm = ax.imshow(np.log(np.abs(OBF_Q[i]*mask)+1), cmap = 'viridis')
        ax.set_title(f'DPC image. {i+1}-th')
        fig.colorbar(pcm, ax=ax, shrink=0.4, extend = 'both')
    else:
        ax.axis('off')  # Turn off the axis for empty plots
plt.tight_layout()
#plt.savefig(f"{path_file}/LCoMxy.tiff")
plt.show()

In [86]:
import scipy.ndimage as ndimage

In [92]:
def reconstruct_OBF(DPC_imgs, PCTFs, parameters):
    """
    It is used to reconstruct the OBF image based on segmented images.
    For enhancing the calculating speed, pyFFTW is utilized.
    
    Args:
        DPC_imgs (list): A list of segmented images (numpy.ndarray).
        PCTFs (list): A list of corresponding phase filters (numpy.ndarray).
        parameters: dictionary, recording the key information of experiments for the calculation.
    Returns:
        All the returned data is numpy.ndarray for the convenience of plotting.
        1. OBF_image: Reconstructed OBF in real space.
        2. OBF_Q: reconstructed OBF in Fourier domain
   
    """

    wavelength = parameters["wavelength(nm)"]
    collection_angle = parameters["collection angles(rad)"]
    resolution = parameters["resolution(nm)"]

    sizeX_img, sizeY_img = DPC_imgs[0].shape
    sizeX, sizeY = PCTFs[0].shape
    ratioX = (sizeX_img / sizeX)
    ratioY = (sizeY_img / sizeY)
    num = len(DPC_imgs)
    phase_filters = []
    #normalizing the PCTFs 
    for p in PCTFs:
        if ratioX != 1 or ratioY != 1:
            temp_filter = ndimage.zoom(1j * p, (ratioX, ratioY))
        else:
            temp_filter = 1j * p
        phase_filters.append(temp_filter)
        
    d_Q = np.zeros(num)
    for i in range(num):
        bkg = min(abs(np.min(DPC_imgs[i])), abs(np.max(DPC_imgs[i])))
        if bkg != 0:
            d_Q[i] = bkg
        else: d_Q[i] = 1
    KQ_squared = np.zeros((sizeX_img, sizeY_img), dtype=np.float64)
    for i in range(num):
        KQ_squared += np.real(np.square(np.abs(phase_filters[i]))) / d_Q[i]
    # Calculate the weighting of filters
    weighting = np.sqrt(KQ_squared)
    weighting[weighting == 0] = np.inf
    weightingInv = np.reciprocal(weighting)

    OBF_Q = pyfftw.empty_aligned((num, sizeX_img, sizeY_img), dtype=np.complex128)    
    OBF_Q.fill(0)
    # Create aligned arrays for FFT input and output

    for n in range (num):
        wq = np.conj(phase_filters[n]) * weightingInv /d_Q[n]     
        dft_DPCs = pyfftw.empty_aligned((sizeX_img, sizeY_img), dtype=np.complex128)
        dft_DPCs[:] = pyfftw.interfaces.numpy_fft.fft2(DPC_imgs[n])
        dft_DPCs = pyfftw.interfaces.numpy_fft.fftshift(dft_DPCs)  
        OBF_Q[n] = dft_DPCs * wq    
    # There are usually shifts existing in the segmented images
    # It is helpful to improve the quality of the reconstructed OBF image by doing shift correction
    # Use the first image as the reference
    reference = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(OBF_Q[0]))  
    aligned_images = [reference.real]
    for i in range(1, num):
        # Compute the shift between the reference and the current image
        compare = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(OBF_Q[i]))
        moving, error, diffphase = phase_cross_correlation(reference.real, compare.real, upsample_factor=200)
        # Apply the shift to the current image
        corrected_image = ndimage.shift(compare.real, moving)
        aligned_images.append(corrected_image)
        
    OBF_image = np.sum(np.array(aligned_images), axis=0)
    #plotting the reconstructed OBF in Fourier domain 
    reciprocal_res = 1/resolution
    extend_edge = 0.5*reciprocal_res
    extend = [-extend_edge, extend_edge, -extend_edge, extend_edge]
    summation = np.sum(OBF_Q, axis = 0)
    summation[sizeX_img//2, sizeY_img//2] = 0
    dft_amp = np.log(np.abs(summation/sizeX_img)+1)
    fig, (ax0, ax1) = plt.subplots(1, 2)
    ax0.imshow(dft_amp, extent = extend, interpolation='hanning', vmax = np.max(dft_amp)*1.1)
    ax0.set_title('Amplitude of OBF in Fourier domain')
    ax0.set_xlabel('Length (1/Å)')
    ax0.set_ylabel('Length (1/Å)')
    length = [-resolution * sizeX_img/2, resolution * sizeX_img/2, -resolution * sizeY_img/2, resolution * sizeY_img/2]
    ax1.imshow(OBF_image, extent = length, interpolation='gaussian')
    ax1.set_title('Reconstructed OBF')
    ax1.set_xlabel('Length (nm)')
    ax1.set_ylabel('Length (nm)')
    plt.tight_layout()
    plt.show()
    return OBF_image, OBF_Q

In [441]:
from skimage import draw

In [437]:

class OBFReconstructor:
    def __init__(self, DPC_imgs, PCTFs, parameters):
        """
        Class to reconstruct the OBF image based on segmented images using pyFFTW for enhanced performance.

        Args:
            DPC_imgs (list): A list of segmented images (numpy.ndarray).
            PCTFs (list): A list of corresponding phase filters (numpy.ndarray).
            parameters (dict): Dictionary containing key information for the experiment such as:
                - 'wavelength(nm)': Wavelength of light used.
                - 'collection angles(rad)': Collection angles of the imaging system.
                - 'resolution(nm)': Resolution of the imaging system.
        """
        self.DPC_imgs = DPC_imgs
        self.PCTFs = PCTFs
        self.resolution = parameters["resolution(nm)"]
        self.num = len(DPC_imgs)
        self.sizeX_img, self.sizeY_img = DPC_imgs[0].shape
        self.sizeX, self.sizeY = PCTFs[0].shape
        self.ratioX = self.sizeX_img / self.sizeX
        self.ratioY = self.sizeY_img / self.sizeY

    def normalize_PCTFs(self):
        """ Normalize the phase filters based on the size ratios. """
        phase_filters = []
        for p in self.PCTFs:
            if self.ratioX != 1 or self.ratioY != 1:
                temp_filter = ndimage.zoom(1j * p, (self.ratioX, self.ratioY))
            else:
                temp_filter = 1j * p
            phase_filters.append(temp_filter)
        return phase_filters

    def compute_dQ(self):
        """ Compute the d_Q array for background correction. """
        d_Q = np.zeros(self.num)
        for i in range(self.num):
            bkg = min(abs(np.min(self.DPC_imgs[i])), abs(np.max(self.DPC_imgs[i])))
            d_Q[i] = bkg if bkg != 0 else 1
        return d_Q

    def compute_weighting(self, phase_filters, d_Q):
        """ Compute the weighting array using the phase filters and d_Q. """
        KQ_squared = np.zeros((self.sizeX_img, self.sizeY_img), dtype=np.float64)
        for i in range(self.num):
            KQ_squared += np.real(np.square(np.abs(phase_filters[i]))) / d_Q[i]
        
        weighting = np.sqrt(KQ_squared)
        weighting[weighting == 0] = np.inf  # Avoid division by zero
        return np.reciprocal(weighting)

    def compute_OBF_Q(self, phase_filters, d_Q, weightingInv):
        """ Compute the OBF in the Fourier domain. """
        OBF_Q = pyfftw.empty_aligned((self.num, self.sizeX_img, self.sizeY_img), dtype=np.complex128)
        OBF_Q.fill(0)
        for n in range(self.num):
            wq = np.conj(phase_filters[n]) * weightingInv / d_Q[n]
            dft_DPCs = pyfftw.empty_aligned((self.sizeX_img, self.sizeY_img), dtype=np.complex128)
            dft_DPCs[:] = pyfftw.interfaces.numpy_fft.fft2(self.DPC_imgs[n])
            dft_DPCs = pyfftw.interfaces.numpy_fft.fftshift(dft_DPCs)
            OBF_Q[n] = dft_DPCs * wq
        return OBF_Q

    def align_images(self, OBF_Q):
        """ Align segmented images using the phase cross-correlation method. """
        values = np.zeros(self.num)
        for i in range(self.num):
            values[i] = np.max(OBF_Q[i].real)
        max_value = np.max(values)
        index = np.argmax(values)
        print(index)
        reference = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(OBF_Q[index]))
        self.aligned_images = [reference.real]
        for i in range(0, self.num):
            if i != index:
                print(max_value/values[i])
                compare = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(OBF_Q[i]))
                moving, _, _ = phase_cross_correlation(reference.real, compare.real, upsample_factor=200, normalization = "phase")
                print(i, moving)
                corrected_image = ndimage.shift(compare.real, moving)
                self.aligned_images.append(corrected_image)
        return np.sum(np.array(self.aligned_images), axis=0)

    def plot_results(self, OBF_image, OBF_Q):
        """ Plot the results of the reconstruction in real and Fourier space. """
        reciprocal_res = 1 / self.resolution
        extend_edge = 0.5 * reciprocal_res
        extend = [-extend_edge, extend_edge, -extend_edge, extend_edge]

        summation = np.sum(OBF_Q, axis=0)
        summation[self.sizeX_img // 2, self.sizeY_img // 2] = 0
        dft_amp = np.log(np.abs(summation / self.sizeX_img) + 1)

        fig, (ax0, ax1) = plt.subplots(1, 2)
        ax0.imshow(dft_amp, extent=extend, interpolation='hanning', vmax=np.max(dft_amp) * 1.1)
        ax0.set_title('Amplitude of OBF in Fourier domain')
        ax0.set_xlabel('Length (1/Å)')
        ax0.set_ylabel('Length (1/Å)')

        length = [-self.resolution * self.sizeX_img / 2, self.resolution * self.sizeX_img / 2, 
                  -self.resolution * self.sizeY_img / 2, self.resolution * self.sizeY_img / 2]
        ax1.imshow(OBF_image, extent=length, interpolation='gaussian')
        ax1.set_title('Reconstructed OBF')
        ax1.set_xlabel('Length (nm)')
        ax1.set_ylabel('Length (nm)')
        plt.tight_layout()
        plt.show()

    def reconstruct_OBF(self):
        """ Main function to reconstruct the OBF image and Fourier domain representation. """
        phase_filters = self.normalize_PCTFs()
        d_Q = self.compute_dQ()
        weightingInv = self.compute_weighting(phase_filters, d_Q)
        OBF_Q = self.compute_OBF_Q(phase_filters, d_Q, weightingInv)
        OBF_image = self.align_images(OBF_Q)
        self.plot_results(OBF_image, OBF_Q)
        return OBF_image, OBF_Q


In [400]:
DPC = [DPC_imgs[0] ,DPC_imgs[1], DPC_imgs[2], DPC_imgs[3]]

In [438]:
obf_reconstructor = OBFReconstructor(DPC, built_WPO2, parameters)

In [439]:
OBF_image, OBF_Q = obf_reconstructor.reconstruct_OBF()

0
1.2750704183684871
1 [0.815 1.125]
1.2679577129021908
2 [0.02 1.84]
1.8680854512926826
3 [-0.87  1.04]


In [399]:
plt.imshow(OBF_image)

<matplotlib.image.AxesImage at 0x211230d71d0>

In [488]:


def align_images(images, space='real', mask_center=(300, 1450), mask_size=(256, 256)):
    """
    Perform image registration using skimage.registration.phase_cross_correlation.
    
    Args:
        images: np.ndarray, with shape (num, pixel_x, pixel_y)
        space: string, if 'images' are in real space, use 'real'; 
               if they are Fourier transformed, use 'Fourier'.
        mask_center: tuple, the coordinate of the mask center in pixels, like (300, 1450)
        mask_size: tuple, the height and width of the mask, like (256, 256)
    
    Returns:
        aligned_images: np.ndarray of aligned images.
    """
    
    num, px, py = images.shape
    values = np.zeros(num)
    
    # Create the mask
    mask = np.zeros((px, py), dtype=np.uint8)
    half_mask_size = (mask_size[0] // 2, mask_size[1] // 2)
    
    row1 = max(0, mask_center[0] - half_mask_size[0])
    row2 = min(px, mask_center[0] + half_mask_size[0])
    col1 = max(0, mask_center[1] - half_mask_size[1])
    col2 = min(py, mask_center[1] + half_mask_size[1])
    
    if max(mask_center) >= max(px, py) or row2 <= row1 or col2 <= col1:
        raise ValueError("Invalid mask parameters. Check mask center or mask size.")
    
    mask[row1:row2, col1:col2] = 1
    
    # Calculate the standard deviation of the real part for each image
    for i in range(num):
        values[i] = np.std(images[i].real)
    
    # Choose the image with the highest variance as the reference
    index = np.argmax(values)
    aligned_images = []
    
    # Real-space alignment
    if space == 'real':
        reference = images[index]
        aligned_images.append(reference)
        
        for i in range(num):
            if i != index:
                shift_vector, _, _ = phase_cross_correlation(reference * mask, images[i] * mask, upsample_factor=200)
                aligned_image = shift(images[i], shift_vector)
                aligned_images.append(aligned_image)
    
    # Fourier-space alignment
    elif space in ['Fourier', 'Fourier space', 'Reciprocal space', 'Reciprocal']:
        reference = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(images[index]))
        aligned_images.append(reference.real)
        plt.imshow(reference.real*mask)
        for i in range(num):
            if i != index:
                compare = pyfftw.interfaces.numpy_fft.ifft2(pyfftw.interfaces.numpy_fft.ifftshift(images[i]))
                shift_vector, _, _ = phase_cross_correlation(reference.real * mask, compare.real * mask, upsample_factor=200)
                aligned_image = shift(compare.real, shift_vector)
                aligned_images.append(aligned_image)
    
    else:
        raise ValueError("Please choose the correct space: 'real' or 'Fourier'.")
    
    return np.array(aligned_images)


In [489]:
# Assuming your images are stored in a variable called 'images'
aligned_images = align_images(OBF_Q, space='Fourier',mask_center=(600,1400),mask_size=(600,600))

In [490]:
plt.imshow(np.sum(aligned_images, axis =0))

<matplotlib.image.AxesImage at 0x210f3a5d3a0>

In [462]:
fig, axes = plt.subplots(2, 2, figsize =(6, 6), sharex=True, sharey=True)

for i, ax in enumerate(axes.ravel()):
    if i < 4:
        pcm = ax.imshow(aligned_images[i], cmap = 'viridis')
        ax.set_title(f'DPC image. {i+1}-th')
        fig.colorbar(pcm, ax=ax, shrink=0.4, extend = 'both')
    else:
        ax.axis('off')  # Turn off the axis for empty plots
plt.tight_layout()
#plt.savefig(f"{path_file}/LCoMxy.tiff")
plt.show()

In [267]:
from scipy.ndimage import shift

In [266]:
from skimage.registration import phase_cross_correlation

In [65]:
import pyfftw

In [293]:
ps.tools.plot_image([OBF_R,np.sum(aligned_images, axis =0)], properties= {
             'resolution': resolution,
              'unit': 'nm',
              'bar location':'',
             'image titles': 'reconstructed iDPC',
             'figsize':6,
              'cmap':'',
              'dpi': 600,
              'image format':'.jpeg',
              'showing titles': False,
             'cropping image': [True, [310,1485], [256,256]],
             'saving image': False,
             'saving path':path_file+ "OBF"
              })

In [377]:
mssim = ps.tools.get_mssim(ps.tools.linscale(OBF_R), ps.tools.linscale(OBF_image))
print(mssim)

0.99999994


In [300]:
ps.tools.contrast_ratio(OBF_R)

The minimum intensity is : -167.25
The maximum intensity is : 162.5
The mean intensity is : -0.0
The contrast of image is evaluated by: 
--> Max./Min. ratio: -0.97
--> Luminance contrast: 1.97
--> Weber contrast: -1.0
--> Michelson contrast: -0.9447
--> Peak SNR: 13.9224 dB


In [440]:
ps.tools.contrast_ratio(OBF_image)

The minimum intensity is : -168.28
The maximum intensity is : 175.72
The mean intensity is : -0.01
The contrast of image is evaluated by: 
--> Max./Min. ratio: -1.04
--> Luminance contrast: 2.04
--> Weber contrast: -1.0
--> Michelson contrast: -1.0925
--> Peak SNR: 14.7146 dB


In [473]:
ps.tools.contrast_ratio(np.sum(aligned_images, axis =0))

The minimum intensity is : -167.67
The maximum intensity is : 169.56
The mean intensity is : -0.01
The contrast of image is evaluated by: 
--> Max./Min. ratio: -1.01
--> Luminance contrast: 2.01
--> Weber contrast: -1.0
--> Michelson contrast: -1.0228
--> Peak SNR: 14.55 dB


In [355]:
from scipy.ndimage import rotate

In [416]:
rotated_img = rotate(OBF_image, -25)
plt.imshow(rotated_img)

<matplotlib.image.AxesImage at 0x2112ed83800>

In [417]:
matcher = ps.tools.TemplateMatcher(rotated_img)

In [418]:
matcher.select_template({'top_left': [923,2347], 'height':128, 'width':128})

In [419]:
matcher.search_template()

In [433]:
finds = matcher.get_matches(threshold=0.67)

In [434]:
stacks = matcher.stack_matches()

In [430]:
# Display the results
matcher.display()

In [435]:
stacks.shape

(26, 128, 128)

In [436]:
plt.imshow(np.sum(stacks, axis =0))
plt.axis('off')

(-0.5, 127.5, 127.5, -0.5)

#### There are several filters can denoise the reconstructed images

In [217]:
filtered,_ = ps.EMFilters.wiener_filter(OBF_R, delta=5, lowpass=True, lowpass_cutoff=0.3, lowpass_order=2)

In [219]:
filtered2,_ = ps.EMFilters.abs_filter(OBF_R, delta=5, lowpass=True, lowpass_cutoff=0.3, lowpass_order=2)

In [221]:
filtered3,_ = ps.EMFilters.nonlinear_filter(OBF_R, space ='real', N=5, mode='wiener', delta=5, lowpass_cutoff=0.3, lowpass = True, lowpass_order=2)

Building:   0%|           [ time left: ? ]

In [222]:
plt.imshow(filtered3[0])

<matplotlib.image.AxesImage at 0x210ca7e1ac0>

## The following codes can reconstruct First-momentum STEM image from the segmented images

In [35]:
FM = ps.analysis.FMSTEMReconstructor(CoMx, CoMy, stepsize=1, n_iter=10, regHighpass=1e-4)
optimal_theta, optimal_flip = FM.optimize_rotation(CoMx, CoMy, thetas = np.linspace(0, 90, 90))

Calculating:   0%|          | 0/180 [00:00<?, ?iteration/s]

Does it need to flip: True
The rotation angle is: 80.90 degrees


In [36]:
FM_image, _ = FM.run(optimal_theta, optimal_flip, point = [1024,1024], crop_size =1024, process=True)

Building:   0%|          | 0/10 [00:00<?, ?iteration/s]

Iteration: 3: break
You can change the iterations through <epsilon>


In [37]:
FM.display()

# Following codes can save results using "h5py"

In [None]:
saving_name = "simu_5nm16mrad80pm_DPC_STEM.h5"

images_for_saving = {'OBF':OBF_R, 'FM_STEM':FM_image}

data = h5py.File(f"{path_file}/saving_name", 'w')
infomation = data.create_group("Experimental_parameters")
string_dtype = h5py.special_dtype(vlen=str)  # Using special_dtype for variable-length strings
  for i in range(len(info)):
      element = (list(info.values())[i])
      if type(element)==str:
          info_input = infomation.create_dataset(list(info.keys())[i], shape=(1,), dtype = string_dtype)
          info_input[0] = element
      else:
          infomation.create_dataset(list(info.keys())[i], data = element, dtype = 'f')
          
reconstruct = data.create_group('reconstructed_images') 
for i in range(len(images_for_saving)):
    reconstruct.create_dataset(list(images_for_saving.keys())[i], data= list(images_for_saving.values())[i])
data.close()