This notebook calculates the local lattice rotation, as described in Supplementary Note 1 of the ai4stem manuscript (Leitherer et al., 2023).

In [None]:
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy.matlib
from scipy import ndimage
from collections import Counter
import itertools
from scipy.signal import get_window
import cv2
from collections import defaultdict
from scipy.stats import mode
from scipy import stats

from PIL import Image
import matplotlib.pyplot as plt
import hyperspy
import hyperspy.api as hs
import atomap.api as am
from ase import Atoms
from collections import defaultdict
from copy import deepcopy

import ase
from ase import Atoms


from ase.io import read
from pycpd import RigidRegistration, AffineRegistration, DeformableRegistration
from functools import partial
from copy import deepcopy
from scipy import stats

from ase import Atoms
from ase.neighborlist import NeighborList

# Define required input

In [None]:
# Please specify
# 1. save path, 2. stride, 3. window size,
# 4. image path, 5. pixel / angstrom relation,
# 6. paths to the mutual information and the predictions (AI-STEM predictions).
# This information can be found via the Zenodo link provided in the manuscript.

save_path = '.'

stride_size = [12, 12]

window_size = 136  # Cu example

image_path = 'Cu_fcc_111.npy'
# Angstrom per pixel of the input image (window_size * pixel_to_angstrom is
# used as a length in angstrom further below).
pixel_to_angstrom = 0.088052397290637275
mutual_information_path = 'Cu_fcc_111_mutual_information.npy'
# Original (misspelled) name, kept so that existing cells keep working.
mutual_information_pathh = mutual_information_path
predictions_path = 'Cu_fcc_111_probabilities.npy'

In [None]:
# Load the AI-STEM class probabilities, take the argmax per local window and
# use the most frequent predicted class as the symmetry for which the
# reference training image is loaded below.
# NOTE(review): the index -> label mapping assumes that the model's class
# indices follow this (alphabetical) order, i.e. the order of the keys of
# `reference_dict` defined further down -- TODO confirm against the model.
class_labels = ['BCC_Fe_100', 'BCC_Fe_110', 'BCC_Fe_111',
                'FCC_Cu_100', 'FCC_Cu_110', 'FCC_Cu_111', 'FCC_Cu_211',
                'HCP_Ti_0001', 'HCP_Ti_10m10', 'HCP_Ti_2m1m10']
argmax_predictions = np.argmax(np.load(predictions_path), axis=-1)
# Count over all windows at once. scipy.stats.mode returns a ModeResult
# (computed along axis 0 by default), which is not a label string and thus
# cannot be used to look up reference_dict / rotations_dict.
most_frequent_class = int(np.bincount(np.ravel(argmax_predictions)).argmax())
assigned_label = class_labels[most_frequent_class]
print(assigned_label)

# Load image, extract local windows, reconstruct atomic columns

In [None]:
# Function for calculating local windows

def localwindow(image_in, stride_size, pixel_max=100,
                normalize_before_fft=False, normalize_after_window=False):
    """Cut ``image_in`` into square local windows on a regular grid.

    The top-left corner ``(i, j)`` of each window advances by
    ``stride_size[0]`` along the first image axis and by ``stride_size[1]``
    along the second axis, as long as ``i < image_in.shape[0] - pixel_max``
    and ``j < image_in.shape[1] - pixel_max``.

    Parameters
    ----------
    image_in : 2D array
        Input image.
    stride_size : sequence of two ints
        Step between neighbouring windows along the first / second axis.
    pixel_max : int
        Edge length of each (square) window in pixels.
    normalize_before_fft : bool
        If True, each window is min-max normalized to [0, 1] (float32) with
        ``cv2.normalize``.
    normalize_after_window : bool
        Not used; kept for interface compatibility.

    Returns
    -------
    images : list of 2D float arrays
        The windows in row-major order (second axis varies fastest).
    spm_pos : ndarray
        Top-left corner ``[i, j]`` of each window, same order as ``images``.
    ni, nj : int
        Number of windows along the first and second axis (0 if the image
        is not larger than one window).
    """
    x_max, y_max = image_in.shape[0], image_in.shape[1]

    images = []
    spm_pos = []
    ni = 0
    nj = 0  # defined even when no window fits into the image

    for i in range(0, x_max - pixel_max, stride_size[0]):
        ni += 1
        nj = 0
        for j in range(0, y_max - pixel_max, stride_size[1]):
            nj += 1
            # Slicing replaces the per-pixel copy; dtype=float yields a
            # float64 copy like the former np.zeros buffer.
            image = np.array(image_in[i:i + pixel_max, j:j + pixel_max], dtype=float)
            if normalize_before_fft:
                image = cv2.normalize(image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            images.append(image)
            # Corner of the window that was just extracted.
            spm_pos.append([i, j])
    return images, np.asarray(spm_pos), ni, nj

In [None]:
# Cut the input image into overlapping local windows.
print(image_path)
image = np.load(image_path)
window_grid = localwindow(image, stride_size=stride_size, pixel_max=window_size)
sliced_images, spm_pos, ni, nj = window_grid

In [None]:
# Function for reconstructing real-space lattice
# from atomic columns via atomap
def reconstruct_via_atomap(image, separation, refine=True):
    """Locate atomic columns in an image with atomap.

    Parameters
    ----------
    image : hyperspy signal or 2D array
        Image to analyse; plain arrays are wrapped into a ``Signal2D``.
    separation : int
        Minimum separation between columns in pixels, passed to
        ``am.get_atom_positions``.
    refine : bool
        If True, refine the positions first by center of mass and then by 2D
        Gaussian fits on an ``am.Sublattice``.

    Returns
    -------
    ndarray of shape (n_columns, 2)
        (x, y) pixel positions of the columns.
    """
    # isinstance instead of duck typing on ``.data``: numpy arrays also have
    # a ``.data`` attribute (their raw buffer).
    if isinstance(image, hs.signals.BaseSignal):
        signal = image
    else:
        signal = hs.signals.Signal2D(np.asarray(image))

    atom_positions = am.get_atom_positions(signal, separation=separation)
    if not refine:
        return atom_positions

    # Refine the initial positions using center of mass and 2D Gaussians
    # based on the nearest-neighbour distances.
    sublattice = am.Sublattice(atom_positions, image=signal.data)
    sublattice.find_nearest_neighbors()
    sublattice.refine_atom_positions_using_center_of_mass()
    sublattice.refine_atom_positions_using_2d_gaussian()

    refined = [[atom.pixel_x, atom.pixel_y] for atom in sublattice.atom_list]
    return np.array(refined, dtype=float).reshape(-1, 2)

In [None]:
# Pixels per angstrom of the input image; half of it is used as the minimum
# column separation for atomap.
separation = 1. / pixel_to_angstrom
print(separation)

half_separation = int(separation / 2.)
atomic_columns = reconstruct_via_atomap(hs.signals.Signal2D(image),
                                        separation=half_separation,
                                        refine=False)

In [None]:
# Overlay the reconstructed columns (red) on the input image.
fig, axs = plt.subplots(figsize=(15, 15))
plt.imshow(image)
column_x, column_y = atomic_columns[:, 0], atomic_columns[:, 1]
plt.scatter(column_x, column_y, s=4, c='r')
plt.show()

In [None]:
# Extract the atomic columns of each local window (the local windows were
# cut from the image by `localwindow` earlier). Windows are visited in the
# same order as in `localwindow` -- rows (y) in the outer loop, columns (x)
# in the inner loop -- so that all_boxes[k] belongs to sliced_images[k] and
# per-window results can be reshaped to (ni, nj), also for non-square images.

element_agnostic = False
coordinate_vectors = atomic_columns  # (x, y) pixel position per column
element_symbols = np.array(['Ti' for _ in range(len(atomic_columns[:, 0]))], dtype=object)

number_of_strides_vector = [ni, nj]  # [rows, columns]

x_sliding_volume_edge_length = window_size
y_sliding_volume_edge_length = window_size
# `localwindow` advances rows (y, first image axis) by stride_size[0] and
# columns (x, second image axis) by stride_size[1].
step_size_x = stride_size[1]
step_size_y = stride_size[0]
x_min = 0
y_min = 0

# Upper corner (x, y) of the current box; the box spans
# [start - edge_length, start] along both axes (inclusive bounds).
start = [x_min + x_sliding_volume_edge_length, y_min + y_sliding_volume_edge_length]

all_boxes = []       # per window: list of [x, y] column positions
all_stride_idx = []  # per window: [column index, row index]

list_of_xy_boxes = []    # per row: one ase.Atoms per window
number_of_atoms_xy = []  # per row: number of columns in each window

for i in range(number_of_strides_vector[0]):  # rows (y)

    list_of_x_boxes = []
    number_of_atoms_x = []

    for j in range(number_of_strides_vector[1]):  # columns (x)

        condition = ((coordinate_vectors[:, 0] <= start[0])
                     & (coordinate_vectors[:, 1] <= start[1])
                     & (coordinate_vectors[:, 0] >= (start[0] - x_sliding_volume_edge_length))
                     & (coordinate_vectors[:, 1] >= (start[1] - y_sliding_volume_edge_length)))
        columns_in_box = coordinate_vectors[condition][:, :2]
        n_columns = len(columns_in_box)
        number_of_atoms_x.append(n_columns)

        if n_columns == 0:
            # Empty structure: Atoms(symbols='', pbc=False)
            atoms_within_sliding_volume = ase.Atoms('')
        else:
            # Embed the 2D positions in 3D (z = 0) for ase.
            positions = np.column_stack([columns_in_box, np.zeros(n_columns)])
            if element_agnostic:
                element_name = 'Fe' + str(n_columns)
            else:
                element_name = ''.join(element_symbols[condition])
            atoms_within_sliding_volume = ase.Atoms(element_name, positions)
        atoms_within_sliding_volume.set_pbc(False)  # added for safety
        list_of_x_boxes.append(atoms_within_sliding_volume)

        all_boxes.append(columns_in_box.tolist())
        all_stride_idx.append([j, i])

        start[0] += step_size_x

    number_of_atoms_xy.append(number_of_atoms_x)
    list_of_xy_boxes.append(list_of_x_boxes)
    start[0] = x_min + x_sliding_volume_edge_length  # reset x for the next row
    start[1] += step_size_y  # next row

In [None]:
# Visualize some of the local windows (every 500th) together with the
# atomic columns assigned to them.
for idx in range(0, len(sliced_images), 500):
    img = sliced_images[idx]
    # reshape(-1, 2) keeps empty windows as a (0, 2) array so that the
    # offset subtraction below does not fail.
    pos = np.array(all_boxes[idx], dtype=float).reshape(-1, 2)
    # Top-left corner of window idx as cut by `localwindow` (row-major with
    # nj windows per row; rows advance by stride_size[0], columns by
    # stride_size[1]); convert absolute (x, y) to window coordinates.
    row_idx, col_idx = divmod(idx, nj)
    pos[:, 0] -= col_idx * stride_size[1]
    pos[:, 1] -= row_idx * stride_size[0]

    fig, axs = plt.subplots()
    axs.imshow(img)
    if len(pos) > 0:
        axs.scatter(pos[:, 0], pos[:, 1], c='r')

# Given the predicted symmetry, define the reference lattice

In [None]:
# Paths to the simulated reference (training) images of all classes.
reference_data_path = '/home/leitherer/Real_space_AI_STEM/ai4stem/data/AISTEM/Convolution/sim_test'

reference_pixeltoangstrom = 0.12

file_ending = 'sampling_0.12_convolved_STEM_image.hdf5'

# Lattice parameter (in angstrom) that appears in each reference file name.
_reference_lattice_parameters = {
    "BCC_Fe_100": 2.87,
    "BCC_Fe_110": 2.87,
    "BCC_Fe_111": 2.87,
    "FCC_Cu_100": 3.63,
    "FCC_Cu_110": 3.63,
    "FCC_Cu_111": 3.63,
    "FCC_Cu_211": 3.63,
    "HCP_Ti_0001": 2.95,
    "HCP_Ti_10m10": 2.95,
    "HCP_Ti_2m1m10": 2.95,
}

reference_dict = {
    label: os.path.join(reference_data_path,
                        '{}_LatPar_{}A_{}'.format(label, lattice_parameter, file_ending))
    for label, lattice_parameter in _reference_lattice_parameters.items()
}

Load the training image for the currently assigned label

In [None]:
# Load the simulated training image of the assigned class.
training_path = reference_dict[assigned_label]
print(training_path)

# The context manager closes the file even if reading fails. Indexing
# (instead of .get) raises a KeyError if the dataset is missing rather than
# silently producing np.array(None).
with h5py.File(training_path, 'r') as training_file:
    tr_img = np.array(training_file['convolved_stem_image/conv_stem'])
print(tr_img.shape)

plt.imshow(tr_img)

In [None]:
# Reconstruct the real-space lattice of the reference (training) image; the
# atomap separation is half an angstrom in reference-image pixels.
reference_pixels_per_angstrom = 1. / reference_pixeltoangstrom
ref_separation = int(reference_pixels_per_angstrom / 2.)
print(ref_separation)

reference_signal = hs.signals.Signal2D(tr_img)
atom_positions = reconstruct_via_atomap(reference_signal,
                                        separation=ref_separation,
                                        refine=False)
fig, axs = plt.subplots(figsize=(20, 20))
axs.set_aspect('equal')
plt.scatter(atom_positions[:, 0], atom_positions[:, 1], s=1)

In [None]:
# Express the input window size in pixels of the reference image, so that a
# window of the same physical size can be cut from the training image
# (the window used for the mask below is adjusted accordingly):
# input pixels -> angstrom -> reference pixels.
window_input_angstrom = float(window_size) * pixel_to_angstrom
reference_pixels_per_angstrom_window = 1. / float(reference_pixeltoangstrom)
window_reference = window_input_angstrom * reference_pixels_per_angstrom_window
print(round(window_reference))

In [None]:
# Keep only the reference columns inside a square of edge length
# window_reference around the centroid of all reference columns.
center = np.mean(atom_positions, axis=0)
print(center)

half_width = float(window_reference) / 2.
x_low, x_high = center[0] - half_width, center[0] + half_width
y_low, y_high = center[1] - half_width, center[1] + half_width

mask_x = (atom_positions[:, 0] >= x_low) & (atom_positions[:, 0] <= x_high)
mask_y = (atom_positions[:, 1] >= y_low) & (atom_positions[:, 1] <= y_high)
mask = mask_x & mask_y
filtered_columns = atom_positions[mask]

fig, axs = plt.subplots()
axs.set_aspect('equal')
plt.scatter(filtered_columns[:, 0], filtered_columns[:, 1], s=1)

In [None]:
# Center the extracted reference lattice on the column closest to its
# centroid, then shrink a radial mask until fewer than Nat columns remain
# (a few nearest-neighbour shells).
reference_lattice = deepcopy(filtered_columns)

x_shift = np.mean(reference_lattice[:, 0])
y_shift = np.mean(reference_lattice[:, 1])
dist = np.sqrt((reference_lattice[:, 0] - x_shift)**2
               + (reference_lattice[:, 1] - y_shift)**2)[:, None]

reference_lattice_cent = reference_lattice[np.argmin(dist)]
reference_lattice[:, 0] = reference_lattice[:, 0] - reference_lattice_cent[0]
reference_lattice[:, 1] = reference_lattice[:, 1] - reference_lattice_cent[1]

Nat = 20
radius = float(window_reference) / 2.
# NOTE(review): the shrink step uses the *input* image sampling
# (pixel_to_angstrom) while the lattice is in reference-image pixels
# (reference_pixeltoangstrom) -- confirm this is intended.
delta = (1. / pixel_to_angstrom) / 8.
reference_lattice_tmp = deepcopy(reference_lattice)
while len(reference_lattice_tmp) >= Nat:
    inside_radius = np.sqrt(reference_lattice_tmp[:, 0]**2 + reference_lattice_tmp[:, 1]**2) < radius
    reference_lattice_tmp = reference_lattice_tmp[inside_radius]
    radius -= delta
reference_lattice = reference_lattice_tmp
print(len(reference_lattice))

fig, axs = plt.subplots()
axs.set_aspect('equal')
plt.scatter(reference_lattice[:, 0], reference_lattice[:, 1], s=1)



In [None]:
# Define function for normalizing the lattice of a local window

def norm_window_lattice(atomic_columns, max_atoms=None, radius=None, delta=None):
    """Center the columns of a local window and keep a small neighbourhood.

    The columns are shifted so that the column closest to their centroid
    lies at the origin. A radial mask is then shrunk step by step until at
    most ``max_atoms`` columns remain.

    Parameters
    ----------
    atomic_columns : array-like of shape (n, 2)
        (x, y) positions of the columns in the window (not modified).
    max_atoms : int, optional
        Target number of columns; defaults to the number of columns in the
        global ``reference_lattice``.
    radius : float, optional
        Initial mask radius in pixels; defaults to ``window_size / 2``.
    delta : float, optional
        Radius decrement per step in pixels; defaults to
        ``(1 / pixel_to_angstrom) / 8``, i.e. 1/8 angstrom in input pixels.

    Returns
    -------
    ndarray of shape (m, 2) with m <= max_atoms; a (0, 2) array if
    ``atomic_columns`` is empty.

    Raises
    ------
    ValueError
        If ``delta`` is not positive (the mask would never shrink).
    """
    if max_atoms is None:
        max_atoms = reference_lattice.shape[0]
    if radius is None:
        radius = float(window_size) / 2.
    if delta is None:
        delta = (1. / pixel_to_angstrom) / 8.
    if delta <= 0:
        raise ValueError("delta must be positive, got {}".format(delta))

    peaks_box = np.array(atomic_columns, dtype=float)
    if peaks_box.size == 0:
        return np.empty((0, 2))

    # Shift to origin: the column closest to the centroid becomes (0, 0).
    x_shift = np.mean(peaks_box[:, 0])
    y_shift = np.mean(peaks_box[:, 1])
    dist = np.sqrt((peaks_box[:, 0] - x_shift)**2 + (peaks_box[:, 1] - y_shift)**2)
    # copy(): a row view would change while the array is shifted in place.
    center = peaks_box[np.argmin(dist)].copy()
    peaks_box[:, 0] -= center[0]
    peaks_box[:, 1] -= center[1]

    # Apply a shrinking radial mask until at most max_atoms columns remain.
    while peaks_box.shape[0] > max_atoms:
        inside = np.sqrt(peaks_box[:, 0]**2 + peaks_box[:, 1]**2) < radius
        peaks_box = peaks_box[inside]
        radius -= delta

    return peaks_box

In [None]:
# Define the criterion according to which the local rotation angle is
# calculated. Reason: the registration may rotate clockwise or
# counter-clockwise, so the raw angle has to be folded.

# Rotational symmetry angle (degrees) of each class; mismatch angles are
# only meaningful modulo this angle.
rotations_dict = {"BCC_Fe_100": 90.,
                  "BCC_Fe_110": 180.,
                  "BCC_Fe_111": 60.,
                  "FCC_Cu_100": 90.,
                  "FCC_Cu_110": 60.,
                  "FCC_Cu_111": 60.,
                  "FCC_Cu_211": 180.,
                  "HCP_Ti_0001": 60.,
                  "HCP_Ti_10m10": 180.,
                  "HCP_Ti_2m1m10": 60.}


def criterion(assigned_label, mismatch_angle, rotations=None):
    """Return the smallest rotation (degrees) matching the two lattices.

    The absolute mismatch angle is folded into ``[0, symmetry_angle / 2]``.
    For example, with a symmetry angle of 60 degrees a mismatch of 50 degrees
    becomes 10 degrees. The folding works for angles of any magnitude, not
    only within one symmetry period.

    Parameters
    ----------
    assigned_label : str
        Class label; key into ``rotations`` (KeyError if unknown).
    mismatch_angle : float
        Raw mismatch angle in degrees (sign is ignored).
    rotations : dict, optional
        Mapping label -> symmetry angle; defaults to ``rotations_dict``.

    Returns
    -------
    float
        Folded mismatch angle in ``[0, symmetry_angle / 2]``.
    """
    if rotations is None:
        rotations = rotations_dict
    symmetry_angle = rotations[assigned_label]
    folded = np.abs(mismatch_angle) % symmetry_angle
    return min(folded, symmetry_angle - folded)
        
    

In [None]:
# Function for scaling lattice isotropically: nearest-neighbour length scale
def get_nn_distance(atoms, distribution='quantile_nn', cutoff=20.0,
                    min_nb_nn=1,
                    pbc=True, plot_histogram=False, bins=100,
                    constrain_nn_distances=False, nn_distances_cutoff=0.9,
                    element_sensitive=False, central_atom_species=26, neighbor_atoms_species=26,
                    return_more_nn_distances=False, return_histogram=False):
    """Estimate the nearest-neighbour (NN) distance of a structure.

    For every (selected) atom, the distance to its closest neighbour is
    collected. The collected distances are then reduced to one length scale.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse. If ``pbc`` is False, its periodic boundary
        conditions are switched off in place.
    distribution : {'quantile_nn', 'avg_nn', 'min_nn'}
        'quantile_nn': centre of the histogram bin with the largest count
        divided by (lower bin edge)**2; 'avg_nn': mean NN distance;
        'min_nn': smallest NN distance.
    cutoff : float
        Per-atom cutoff radius for ``ase.neighborlist.NeighborList``.
    min_nb_nn : int
        Atoms with fewer neighbours are skipped (with a message).
    pbc : bool
        If False, periodic boundary conditions are disabled on ``atoms``.
    plot_histogram : bool
        Show the raw (unscaled) histogram ('quantile_nn' only).
    bins : int
        Number of histogram bins ('quantile_nn').
    constrain_nn_distances : bool
        Drop NN distances that are not larger than ``nn_distances_cutoff``.
    nn_distances_cutoff : float
        Threshold used by ``constrain_nn_distances``.
    element_sensitive : bool
        Only use central atoms with atomic number ``central_atom_species``
        and neighbours with atomic number ``neighbor_atoms_species``.
    central_atom_species, neighbor_atoms_species : int
        Atomic numbers used when ``element_sensitive`` is True.
    return_more_nn_distances : bool
        With 'quantile_nn', also return the centres of the 2nd and 3rd
        highest scaled bins.
    return_histogram : bool
        Also return the scaled histogram (None unless 'quantile_nn') and
        the NN distances.

    Returns
    -------
    float, or a tuple as described above, or None if an atom has no
    neighbour of the requested species or no NN distance was collected.

    Raises
    ------
    ValueError
        If ``distribution`` is not one of the supported options.
    """
    if distribution not in ('min_nn', 'avg_nn', 'quantile_nn'):
        raise ValueError("Not recognized option for atoms_scaling. "
                         "Possible values are: 'min_nn', 'avg_nn', or 'quantile_nn'.")

    if not pbc:
        atoms.set_pbc((False, False, False))

    # len(atoms) instead of the deprecated Atoms.get_number_of_atoms().
    nb_atoms = len(atoms)
    cutoffs = np.ones(nb_atoms) * cutoff
    # bothways=True: if b is a neighbour of a, a is also reported as a
    # neighbour of b.
    nl = NeighborList(cutoffs, skin=0.1, self_interaction=False, bothways=True)
    nl.update(atoms)

    atomic_numbers = atoms.get_atomic_numbers()
    cell = atoms.get_cell()
    nn_dist = []

    for idx in range(nb_atoms):
        # Element-sensitive: only atoms of the requested species are centres.
        if element_sensitive and atomic_numbers[idx] != central_atom_species:
            continue

        indices, offsets = nl.get_neighbors(idx)
        if len(indices) < min_nb_nn:
            print("Atom {} has less than {} neighbours. Skipping.".format(idx, min_nb_nn))
            continue

        coord_central_atom = atoms.positions[idx]
        dist_list = []
        for i, offset in zip(indices, offsets):
            if element_sensitive and atomic_numbers[i] != neighbor_atoms_species:
                continue
            # Periodic image of the neighbour, relative to the central atom.
            coord_neighbor = atoms.positions[i] + np.dot(offset, cell)
            dist_list.append(np.linalg.norm(coord_neighbor - coord_central_atom))

        if not dist_list:
            print("List of neighbors is empty for some atom. Cutoff must be increased.")
            return None
        nn_dist.append(min(dist_list))

    if constrain_nn_distances:
        original_length = len(nn_dist)
        # Keep only nearest-neighbour distances larger than nn_distances_cutoff.
        threshold_indices = np.array(nn_dist) > nn_distances_cutoff
        nn_dist = np.extract(threshold_indices, nn_dist)
        if len(nn_dist) < original_length:
            print("Number of nn distances has been reduced from {} to {}.".format(original_length, len(nn_dist)))

    if len(nn_dist) == 0:
        # Nothing to reduce; a histogram of no data would give a meaningless value.
        print("No nearest-neighbor distances could be determined. Cutoff must be increased.")
        return None

    hist_scaled = None  # only computed for 'quantile_nn'
    if distribution == 'avg_nn':
        length_scale = np.mean(nn_dist)
    elif distribution == 'min_nn':
        length_scale = np.min(nn_dist)
    else:  # 'quantile_nn': centre of the maximally populated (scaled) bin
        hist, bin_edges = np.histogram(nn_dist, bins=bins, density=False)
        # Scale by 1 / r**2 (lower bin edge), as in a radial distribution:
        # the area of the shells grows like r**2.
        hist_scaled = [float(count) / (edge ** 2) for count, edge in zip(hist, bin_edges[:-1])]
        best_bin = int(np.argmax(hist_scaled))
        length_scale = (bin_edges[best_bin] + bin_edges[best_bin + 1]) / 2.0

        if plot_histogram:
            # This histogram is not scaled by r**2, it is only the count.
            plt.hist(nn_dist, bins=bins)
            plt.title("Histogram")
            plt.show()

    if return_more_nn_distances and distribution == 'quantile_nn':
        top_three = np.argsort(hist_scaled)[-3:]
        length_scale_3 = (bin_edges[top_three[0]] + bin_edges[top_three[0] + 1]) / 2.0
        length_scale_2 = (bin_edges[top_three[1]] + bin_edges[top_three[1] + 1]) / 2.0
        return length_scale, length_scale_2, length_scale_3
    elif return_histogram:
        return length_scale, hist_scaled, nn_dist
    else:
        return length_scale

In [None]:
def scale_lattice(lattice, window_size):
    """Divide 2D column positions by their nearest-neighbour distance.

    Parameters
    ----------
    lattice : array-like of shape (n, 2)
        (x, y) column positions.
    window_size : float
        Used as the per-atom neighbour cutoff in ``get_nn_distance``.

    Returns
    -------
    ndarray
        ``lattice`` divided by its 'quantile_nn' nearest-neighbour distance.

    Raises
    ------
    ValueError
        If ``lattice`` is empty or no nearest-neighbour distance could be
        determined (instead of an opaque TypeError from dividing by None).
    """
    lattice = np.asarray(lattice, dtype=float)
    if lattice.size == 0:
        raise ValueError("Cannot scale an empty lattice.")

    # Embed the 2D positions in 3D (z = 0) for ase.
    atoms = Atoms(['Fe'] * len(lattice),
                  positions=[[row[0], row[1], 0.0] for row in lattice])

    nn_distance = get_nn_distance(atoms, cutoff=window_size)
    if nn_distance is None:
        raise ValueError("Could not determine a nearest-neighbor distance "
                         "for a lattice with {} columns.".format(len(lattice)))

    return lattice / nn_distance
    

In [None]:
# Reference lattice divided by its nearest-neighbour distance.
reference_lattice_scaled = scale_lattice(reference_lattice, window_reference)

# Calculate mismatch angles

In [None]:
# For every local window, rigidly register the scaled reference lattice
# (source Y) onto the normalized and scaled window columns (target X), and
# convert the fitted rotation into a folded mismatch angle.
mismatch_angles = []

input_lattice = []
transformed_reflattice = []
regs = []

n_boxes = len(all_boxes)
for idx, pos in enumerate(all_boxes):
    if idx % 1000 == 0:
        print('Process box {} / {}'.format(idx, n_boxes))

    Y = reference_lattice_scaled
    X = scale_lattice(norm_window_lattice(pos), window_size)

    reg = RigidRegistration(X=X, Y=Y, max_iterations=100)
    reg.register()

    # Angle from element [0][1] of the fitted rotation matrix (arcsin, so
    # within [-90, 90] degrees), then folded by the class symmetry.
    rotation_matrix = reg.get_registration_parameters()[1]
    mismatch_angle = -np.arcsin(rotation_matrix[0][1]) * 180 / np.pi
    mismatch_angle = criterion(assigned_label, mismatch_angle)

    mismatch_angles.append(mismatch_angle)
    transformed_reflattice.append(reg.transform_point_cloud(Y))
    regs.append(reg)
    input_lattice.append(X)

In [None]:
# Summary statistics of the mismatch angles; the 5% / 95% quantiles are
# reused as colour limits of the map below.
angle_min = np.min(mismatch_angles)
angle_max = np.max(mismatch_angles)
vmin = np.quantile(mismatch_angles, 0.05)
vmax = np.quantile(mismatch_angles, 0.95)
print('Min: {}, Max: {}, 5% quantile: {}, 95% quantile: {}'.format(angle_min,
                                                                   angle_max,
                                                                   vmin,
                                                                   vmax))

In [None]:
# Map of mismatch angles over the grid of local windows.
cmap = 'plasma'
mismatch_map = np.reshape(mismatch_angles, (ni, nj))
fig, axs = plt.subplots(figsize=(15, 15))
plt.imshow(mismatch_map, vmin=vmin, vmax=vmax, cmap=cmap)
plt.colorbar()

# Apply smoothing

In [None]:
from scipy.signal import convolve2d

# Smooth the mismatch-angle map with box kernels of increasing size. Windows
# whose mutual information is at or above the threshold are set to zero
# before smoothing and shown in gray afterwards.
nber_nn = [1, 2, 4, 8, 16]
mutual_information_threshold = 0.1
data = np.reshape(mismatch_angles, (ni, nj))

# Load once (path from the input cell); ravel so the mask matches the
# flattened map whether the file stores a (ni, nj) map or a flat vector.
mutinfo = np.load(mutual_information_pathh)
mask = (np.asarray(mutinfo) < mutual_information_threshold).ravel()
if mask.size != data.size:
    raise ValueError('Mutual information has {} entries, but the mismatch map has {}.'.format(
        mask.size, data.size))

# Copy so that the globally registered colormap is not modified.
cmap = plt.get_cmap('plasma').copy()
cmap.set_bad('gray')

results_dir = os.path.join(save_path, 'mismatch_results')
os.makedirs(results_dir, exist_ok=True)

results = {}
# squeeze=False keeps a 2D axes array even for a single kernel size.
fig, ax = plt.subplots(1, len(nber_nn), facecolor='white', figsize=(25, 5), squeeze=False)
for idx, n in enumerate(nber_nn):
    kernel = np.ones((n, n))

    # Remove parts of the image where the mutual information is above the threshold.
    data_filtered = data.flatten()
    data_filtered[~mask] = 0.0
    data_filtered = np.reshape(data_filtered, (ni, nj))

    smoothed_data = convolve2d(data_filtered, kernel, boundary='symm', mode='same') / float(n * n)

    vmin = np.quantile(smoothed_data.flatten()[mask], 0.05)
    vmax = np.quantile(smoothed_data.flatten()[mask], 0.95)
    print(vmin, vmax)

    axs = ax[0, idx]

    smoothed_data = smoothed_data.flatten()
    smoothed_data[~mask] = np.nan
    smoothed_data = np.reshape(smoothed_data, (ni, nj))

    im = axs.imshow(smoothed_data, vmin=vmin, vmax=vmax, cmap=cmap)
    axs.set_title('NN = {}'.format(n))
    fig.colorbar(im, ax=axs)
    results[n] = smoothed_data

    np.save(os.path.join(results_dir, '{}_mismatch_w_smoothing_nn_{}.npy'.format(assigned_label, n)),
            smoothed_data)
plt.savefig(os.path.join(results_dir, '{}_mismatch_w_smoothing.svg'.format(assigned_label)))
plt.close()

# Visualize, for a few example windows, the local lattice reconstructions together with the fit

In [None]:
# For a few windows, show the window image with its reconstructed columns
# (left) and the normalized input columns together with the registered
# reference lattice (right).
for idx in range(len(sliced_images))[510:520]:
    img = sliced_images[idx]
    # reshape(-1, 2) keeps empty windows as a (0, 2) array so that the
    # offset subtraction below does not fail.
    pos = np.array(all_boxes[idx], dtype=float).reshape(-1, 2)
    # Top-left corner of window idx as cut by `localwindow` (row-major with
    # nj windows per row; rows advance by stride_size[0], columns by
    # stride_size[1]); convert absolute (x, y) to window coordinates.
    row_idx, col_idx = divmod(idx, nj)
    pos[:, 0] -= col_idx * stride_size[1]
    pos[:, 1] -= row_idx * stride_size[0]

    mismatch_angle = mismatch_angles[idx]

    ref_lattice = transformed_reflattice[idx]
    in_lattice = np.array(input_lattice[idx])

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    axs[0].imshow(img)
    if len(pos) > 0:
        axs[0].scatter(pos[:, 0], pos[:, 1], c='r')
        axs[0].set_title('Pos {}, Mismatch angle = {},\n DNat = {}'.format(spm_pos[idx],
                                                                           mismatch_angle,
                                                                           (len(in_lattice) - len(ref_lattice))))

        axs[1].scatter(in_lattice[:, 0], in_lattice[:, 1], marker='o', s=35, label='Reconstructed columns')
        axs[1].scatter(ref_lattice[:, 0], ref_lattice[:, 1], marker='x', s=35, label='Fit')
        axs[1].legend()

        axs[0].set_aspect('equal')
        axs[1].set_aspect('equal')