In [15]:
import numpy as np
import matplotlib.pyplot as plt
import hyperspy.api as hs
import pyxem as pxm
import sys
import cv2
from tqdm import tqdm
from scipy.spatial import distance, distance_matrix
import random as rd

import matplotlib 
%matplotlib inline

# Drift and distortion correction using a perspective (homography) model

In [None]:
########################## build perspective matrix list for all images for drift&distortion correction
def get_perspective_matrix(img1, img2, MIN_MATCH_COUNT = 10, plot = False):
    """Estimate the homography that maps ``img1`` onto ``img2`` from SIFT matches.

    Parameters
    ----------
    img1 : ndarray
        Query image. When ``plot`` is true its shape is unpacked as ``h, w``,
        so it must be 2-D (grayscale) in that case.
    img2 : ndarray
        Train image. It is never modified; the overlay is drawn on a copy.
    MIN_MATCH_COUNT : int
        A homography is only estimated when MORE than this many matches
        survive Lowe's ratio test.
    plot : bool
        If true, show the inlier matches together with the outline of
        ``img1`` projected into ``img2``.

    Returns
    -------
    ndarray of shape (3, 3) or None
        ``M`` from ``cv2.findHomography(src_pts, dst_pts)`` with ``src_pts``
        taken from ``img1`` and ``dst_pts`` from ``img2``. Returns ``None`` when
        too few matches are found or RANSAC cannot fit a homography.
    """
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    good = []
    # detectAndCompute returns des=None when no keypoints are found, and
    # knnMatch with k=2 needs at least two train descriptors.
    if des1 is not None and des2 is not None and len(des2) >= 2:
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)
        # Lowe's ratio test. Skip query points that received fewer than two
        # neighbours; unpacking them as (m, n) would raise.
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance]

    M = None
    matchesMask = None
    if len(good) > MIN_MATCH_COUNT:
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if mask is not None:
            matchesMask = mask.ravel().tolist()
        else:
            print("Homography estimation failed for {} matches".format(len(good)))
    else:
        print( "Not enough matches are found - {}/{}".format(len(good), MIN_MATCH_COUNT) )

    if plot:
        # Draw on a copy so the caller's img2 is left untouched.
        canvas = img2.copy()
        if M is not None:
            h, w = img1.shape
            pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
            dst = cv2.perspectiveTransform(pts, M)
            canvas = cv2.polylines(canvas, [np.int32(dst)], True, 255, 3, cv2.LINE_AA)
        draw_params = dict(matchColor=(0, 255, 0),     # draw matches in green color
                           singlePointColor=None,
                           matchesMask=matchesMask,    # draw only inliers (all if None)
                           flags=2)
        img3 = cv2.drawMatches(img1, kp1, canvas, kp2, good, None, **draw_params)
        plt.figure()
        plt.imshow(img3, 'gray')
        plt.show()
    return M

def matrix_shift_after_perspective(shift_x, shift_y, perspective_matrix):
    """Append a translation after a perspective transform.

    Parameters
    ----------
    shift_x, shift_y : float
        Translation amounts. ``shift_y`` is placed in the first row's offset
        (OpenCV's horizontal / column coordinate). ``shift_x`` is placed in the
        second row's offset (vertical / row coordinate). Here x/y follow the
        row/column naming used elsewhere in this notebook.
        NOTE(review): confirm this is the intended convention.
    perspective_matrix : array_like, shape (3, 3)
        Homography to which the translation is appended.

    Returns
    -------
    ndarray, shape (3, 3)
        ``T @ perspective_matrix``: a point is first mapped by
        ``perspective_matrix`` and then translated.
    """
    translation_matrix = np.float32([[1, 0, shift_y], [0, 1, shift_x], [0, 0, 1]])
    # Use the argument, not a notebook-global homography.
    return np.dot(translation_matrix, perspective_matrix)

############################## find warp matrix for each image, apply to each images
# Pairwise homographies between consecutive frames are chained so that
# perspective_matrixes[i] maps frame i into the coordinates of frame 0.
N_PAIRS = 8  # frames 0..8 -> 8 consecutive pairs

perspective_0 = np.float32(np.identity(3))  # frame 0 is the reference
perspective_matrixes = [perspective_0]
for num_mapping in range(N_PAIRS):
    path1 = f'data/index_mapping_confid_origin_{num_mapping}.jpeg'
    path2 = f'data/index_mapping_confid_origin_{num_mapping+1}.jpeg'
    img1 = cv2.imread(path1, 0)          # queryImage is img2 below; flag 0 = grayscale
    img2 = cv2.imread(path2, 0)
    # cv2.imread returns None (instead of raising) when a file is missing or unreadable.
    if img1 is None:
        raise FileNotFoundError(f"could not read {path1}")
    if img2 is None:
        raise FileNotFoundError(f"could not read {path2}")
    # img2 (frame n+1) is the query and img1 (frame n) the train image,
    # so M maps frame n+1 coordinates onto frame n.
    M = get_perspective_matrix(img2, img1)
    if M is None:
        raise RuntimeError(
            f"no homography found between frames {num_mapping} and {num_mapping+1}")
    perspective_matrixes.append(M)

# Chain: P[i] = P[i-1] @ M_i = M_1 @ ... @ M_i, which maps frame i -> frame 0.
for i in range(1, len(perspective_matrixes)):
    perspective_matrixes[i] = np.dot(perspective_matrixes[i-1], perspective_matrixes[i])

############################## save warped confid images in a stack
# Warp every frame into frame-0 coordinates and write them as one multipage TIFF.
image = []
for num_mapping in range(8):
    img_path = f'data/index_mapping_confid_origin_{num_mapping}.jpeg'
    img = cv2.imread(img_path, 0)
    if img is None:  # cv2.imread signals a missing/unreadable file with None
        raise FileNotFoundError(f"could not read {img_path}")
    if num_mapping == 3:
        # NOTE(review): frame 3 is only padded (bottom/right by 100 px), not warped.
        # This assumes the source frames are 400x400 so the result matches the
        # 500x500 warp output below -- confirm why frame 3 is special-cased.
        img = cv2.copyMakeBorder(img, 0, 100, 0, 100, cv2.BORDER_CONSTANT, value=0)
    else:
        img = cv2.warpPerspective(img, perspective_matrixes[num_mapping], (500, 500))
    image.append(img)
image = np.array(image)
# `tif` (tifffile) is never imported in this notebook; write the multipage TIFF
# with OpenCV, which is already a dependency. Pass a list of 2-D pages so that
# OpenCV does not interpret the 3-D array as a single multi-channel image.
if not cv2.imwritemulti('data/index_mapping_confid_warp.tif', list(image)):
    raise IOError("failed to write data/index_mapping_confid_warp.tif")

################################  persist the chained homographies
# Stack the per-frame 3x3 matrices into one (n_frames, 3, 3) array, keep it
# under the same name, and write it next to the other outputs.
perspective_matrixes = np.stack(perspective_matrixes)
matrices_path = 'data/perspective_matrixes.npy'
np.save(matrices_path, perspective_matrixes)

# Distortion correction using a non-linear elastic model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2



def _capture_section(lines, start_marker, stop_marker=None):
    """Join the lines that follow the first line containing ``start_marker``.

    Capture stops before the first line containing ``stop_marker``; with no
    stop marker it runs to the end. Lines containing ``start_marker`` are never
    captured. A ``stop_marker`` line seen before ``start_marker`` ends the scan
    with an empty result.
    """
    captured = []
    inside = False
    for line in lines:
        if start_marker in line:
            inside = True
            continue
        if stop_marker is not None and stop_marker in line:
            break
        if inside:
            captured.append(line)
    return ''.join(captured)


def read_xy(file_path):
    """Read the X and Y coordinate sections from a registration text file.

    The file is expected to contain a line with ``"X Trans ---..."`` followed
    by the X values, then a line with ``"Y Trans ---..."`` followed by the
    Y values up to end of file. NOTE(review): this layout looks like a
    bUnwarpJ "raw" transformation export -- confirm.

    Parameters
    ----------
    file_path : str or path-like
        Path of the text file.

    Returns
    -------
    (str, str)
        Raw text (newlines preserved) of the X section and of the Y section.
        A section is ``''`` when its marker is absent.
    """
    x_marker = "X Trans -----------------------------------"
    y_marker = "Y Trans -----------------------------------"

    # Read once and scan the in-memory lines for both sections.
    with open(file_path, "r") as file:
        lines = file.readlines()

    content_str_x = _capture_section(lines, x_marker, y_marker)
    content_str_y = _capture_section(lines, y_marker)
    return content_str_x, content_str_y

def get_locs(content_str_x, content_str_y, shape=(400, 400)):
    """Parse whitespace-separated X/Y coordinate text into a (rows, cols, 2) grid.

    Parameters
    ----------
    content_str_x, content_str_y : str
        Whitespace-separated numbers, e.g. as returned by ``read_xy``.
    shape : (int, int), optional
        Grid shape ``(rows, cols)`` the values are reshaped to (row-major).
        Defaults to the 400x400 grid this notebook was written for.

    Returns
    -------
    ndarray of float, shape (rows, cols, 2)
        ``[..., 0]`` holds the Y values and ``[..., 1]`` the X values. Callers
        in this notebook use them as (row, column) indices.

    Raises
    ------
    ValueError
        If the two sections do not each contain ``rows * cols`` values.
    """
    array_y = np.fromstring(content_str_y, dtype=float, sep=' ')
    array_x = np.fromstring(content_str_x, dtype=float, sep=' ')
    expected = int(np.prod(shape))
    if array_x.size != expected or array_y.size != expected:
        raise ValueError(
            f"expected {expected} values per section for shape {tuple(shape)}, "
            f"got x={array_x.size}, y={array_y.size}")
    # Y first so that index 0 is the row coordinate.
    new_loc = np.vstack((array_y, array_x)).T
    return new_loc.reshape(*shape, 2)

def image_deform(img, new_loc):
    """Resample ``img`` through a per-pixel coordinate map (nearest neighbour).

    Parameters
    ----------
    img : ndarray, shape (H, W) or (H, W, C) with C in {1, 3}
        Source image.
    new_loc : ndarray, shape (>=H, >=W, 2)
        ``new_loc[i, j] = (source_row, source_col)`` for output pixel (i, j).
        Coordinates are truncated toward zero, as ``int()`` does.

    Returns
    -------
    ndarray of uint8, shape (H, W, 3)
        Output image. A grayscale source is replicated into all three
        channels. Pixels whose source falls outside ``img`` stay 0.
    """
    img = np.asarray(img)
    new_loc = np.asarray(new_loc)
    h, w = img.shape[:2]
    if new_loc.shape[0] < h or new_loc.shape[1] < w:
        raise ValueError(
            f"new_loc grid {new_loc.shape[:2]} is smaller than image {(h, w)}")

    # astype(int) truncates toward zero, matching the original int() lookup.
    rows = new_loc[:h, :w, 0].astype(int)
    cols = new_loc[:h, :w, 1].astype(int)
    # Bounds come from the image itself rather than a hard-coded 400.
    valid = (rows >= 0) & (rows < h) & (cols >= 0) & (cols < w)

    new_image = np.zeros((h, w, 3), np.uint8)
    values = img[rows[valid], cols[valid]]
    if values.ndim == 1:
        # Grayscale: broadcast each sample across the 3 output channels.
        values = values[:, None]
    new_image[valid] = values
    return new_image

# register 4d-stem data to frame 5
from tqdm import tqdm
import hyperspy.api as hs

frame_1 = 5  # reference frame (first number in the registration file name)
frame_2 = 9  # frame whose 4D-STEM data is corrected

data_file = hs.load(f"data/{frame_2}.hspy")
raw = data_file.data
nav_rows, nav_cols = raw.shape[:2]

file_path = f'registration/{frame_1}-{frame_2}_raw.txt'
content_str_x, content_str_y = read_xy(file_path)
new_loc = get_locs(content_str_x, content_str_y)
print("shift", new_loc[0][0])

if new_loc.shape[0] < nav_rows or new_loc.shape[1] < nav_cols:
    raise ValueError(
        f"deformation grid {new_loc.shape[:2]} smaller than scan {(nav_rows, nav_cols)}")

# Same truncating nearest-neighbour lookup as image_deform:
# new_loc[i, j] = (source row, source column) for scan position (i, j).
src_rows = new_loc[:nav_rows, :nav_cols, 0].astype(int)
src_cols = new_loc[:nav_rows, :nav_cols, 1].astype(int)
valid = (src_rows >= 0) & (src_rows < nav_rows) & (src_cols >= 0) & (src_cols < nav_cols)

# Keep the source dtype and signal shape; the old code hard-coded 120x120 float64.
data_corrected = np.zeros(raw.shape, dtype=raw.dtype)
for i in tqdm(range(nav_rows)):
    keep = valid[i]
    # Index the numpy array directly. hyperspy's `inav[a, b]` is (column, row)
    # order, i.e. data[b, a], so `inav[row, col]` would transpose the lookup
    # relative to image_deform. NOTE(review): confirm against a registered image.
    data_corrected[i, keep] = raw[src_rows[i, keep], src_cols[i, keep]]

data_corrected = hs.signals.Signal2D(data_corrected)
data_corrected.save(f"data/{frame_2}_corrected.hspy")