This notebook calculates the local lattice rotation, as described in Supplementary Note 1 of the AI-STEM manuscript (Leitherer et al., 2023).

In [None]:
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import cv2
from collections import defaultdict
from copy import deepcopy
from scipy import stats
from scipy.signal import convolve2d

# ASE
import ase
from ase import Atoms


# For point set registration
from pycpd import RigidRegistration

# ai4stem utilities
from ai4stem.utils.utils_data import get_data_filename, load_reference_lattices, load_class_dicts
from ai4stem.utils.utils_prediction import localwindow
from ai4stem.utils.utils_reconstruction import reconstruct_via_atomap, get_nn_distance, norm_window_lattice

# Define required inputs

In [None]:
# Please specify
# 1. save path (directory for optional outputs),
# 2. stride (in pixels),
# 3. pixel / Angstrom relation (Angstrom per pixel of the input image),
# 4. window size (in Angstrom, will be converted to pixels according to the specified pixel/Angstrom relation),
# 5. image path,
# 6. paths to mutual information and predictions (AI-STEM predictions).
# Here, example data contained in the github repository is employed.

save_path = '.'
stride_size = [6, 6]
pixel_to_angstrom = 0.12452489444788318
window_size_angstrom = 12.
image_path = get_data_filename('data/experimental_images/Fe_bcc_100.npy')  # consider Fe bcc as example
mutual_information_path = get_data_filename('data/nn_predictions/Fe_bcc_100_mutual_information.npy')
predictions_path = get_data_filename('data/nn_predictions/Fe_bcc_100_predictions.npy')

The following cells determine the window size (in pixels, adapted to the window size
employed by the pretrained model) as well as the symmetry label assigned to the bulk regions.

In [None]:
# Convert the window size from Angstrom to pixels, rounded to the nearest whole pixel.
pixels_per_angstrom = 1. / pixel_to_angstrom
window_size = int(round(window_size_angstrom * pixels_per_angstrom))
print('Window size [Angstrom] = {}, Window size [pixels] = {}'.format(window_size_angstrom, window_size))

In [None]:
# Load argmax predictions and take the most popular assignment as the symmetry
# for which the reference training image is loaded.
# This is all defined for a pretrained model and needs to be adapted if retraining is applied.

# Also load the relation between classes and integer labels
numerical_to_text_labels, text_to_numerical_labels = load_class_dicts()

argmax_predictions = np.argmax(np.load(predictions_path), axis=-1)
# Most frequent predicted class over all windows. np.bincount(...).argmax()
# returns the smallest label on ties (as scipy.stats.mode does) and, unlike
# `stats.mode(...)[0][0]`, does not depend on the scipy version (since
# scipy 1.11 stats.mode returns scalars by default).
most_common_class = int(np.bincount(argmax_predictions.ravel()).argmax())
assigned_label = numerical_to_text_labels[most_common_class]
print('Refernce training image: {}'.format(assigned_label))

In [None]:
# Reference lattices (images keyed by symmetry label), used below to
# calculate the local lattice rotation.
reference_dict = load_reference_lattices()

# Load image, extract local windows, reconstruct atomic columns

In [None]:
# Load the input image and cut it into local windows of `window_size` pixels
print('Load image from path {}'.format(image_path))
image = np.load(image_path)

print('Extract local fragments.')
sliced_images, spm_pos, ni, nj = localwindow(
    image,
    pixel_max=window_size,
    stride_size=stride_size,
)

In [None]:
# Pixels per Angstrom of the input image
separation = 1. / pixel_to_angstrom
print(separation)

# Half of that value (truncated to an int) is handed to atomap as `separation`.
# NOTE(review): presumably the minimum pixel distance between columns — confirm in reconstruct_via_atomap.
atomap_separation = int(separation / 2.)
atomic_columns = reconstruct_via_atomap(image, separation=atomap_separation, refine=False)

In [None]:
# Overlay the reconstructed atomic columns (red) on the input image
fig, axs = plt.subplots(figsize=(15, 15))
axs.imshow(image)
axs.scatter(atomic_columns[:, 0], atomic_columns[:, 1], s=4, c='r')
plt.show()

In [None]:
# Extract atomic columns for each local window (local windows calculated earlier from image).
# A square box of window_size x window_size pixels is moved over the reconstructed
# column positions in steps of stride_size. For every box position the columns
# inside it are collected into `all_boxes` (a list of [x, y] pairs per box) and the
# box's step indices into `all_stride_idx` ([x step index, y step index]); ase.Atoms
# objects per box are collected in `list_of_xy_boxes`, grouped by y step.
# NOTE(review): the outer loop runs nj times and advances y, the inner loop runs ni
# times and advances x, while later cells pair all_boxes[idx] with sliced_images[idx]
# and reshape per-box results to (ni, nj). This is only consistent if ni == nj or if
# localwindow counts x steps in ni — confirm for non-square images / strides.


start = [window_size, window_size]  # maximum (x, y) corner of the current box; box spans [start - window_size, start]
element_agnostic = False  # if True, each box is built as 'Fe<N>' instead of from the per-column symbols
coordinate_vectors = atomic_columns
# Placeholder symbol 'Ti' for every column; only used to build the ase.Atoms objects below
element_symbols = np.array(['Ti' for _ in range(len(atomic_columns[:, 0]))], dtype=object)

number_of_strides_vector = [ni, nj]

x_sliding_volume_edge_length = window_size
y_sliding_volume_edge_length = window_size
step_size_x = stride_size[0]
step_size_y = stride_size[1]
x_min = 0
y_min = 0  # NOTE(review): not used below

all_boxes = []
all_stride_idx = []

list_of_xy_boxes = []
number_of_atoms_xy = []

for i in range(number_of_strides_vector[1]):  # i: y step index

    list_of_x_boxes = []
    number_of_atoms_x = []

    for j in range(number_of_strides_vector[0]):  # j: x step index

        # Determine atoms within sliding box
        positionvectors_within_sliding_volume = []
        element_names_within_sliding_volume = ''


        # Box membership; both box edges are inclusive
        condition = (coordinate_vectors[:,0] <= start[0]) & (coordinate_vectors[:,1] <= start[1]) \
                    & (coordinate_vectors[:,0] >= (start[0]-x_sliding_volume_edge_length)) \
                    & (coordinate_vectors[:,1] >= (start[1]-y_sliding_volume_edge_length))
        positionvectors_within_sliding_volume = coordinate_vectors[condition]
        element_names_within_sliding_volume = element_symbols[condition]
        # Concatenated symbols, e.g. 'TiTiTi', parsed by ase.Atoms as a formula
        element_names_within_sliding_volume = ''.join(list(element_names_within_sliding_volume))

        # Embed the 2D column positions in 3D (z = 0) for ase
        positionvectors_within_sliding_volume = np.array([[_[0], 
                                                           _[1],
                                                          0.0] for _ in positionvectors_within_sliding_volume])

        if len(positionvectors_within_sliding_volume) == 0:
            # Empty box: record zero atoms and an empty ase.Atoms object
            number_of_atoms_x.append(0)

            element_name = element_names_within_sliding_volume  # should be ''
            # create ase Atoms object Atoms(symbols='',pbc=False)
            atoms_within_sliding_volume = ase.Atoms(element_name)
            # Optional: assign label
            # atoms_within_sliding_volume.info['label']='box_label_'+str(i)+str(j)+str(k)
            atoms_within_sliding_volume.set_pbc(False) # added for safety
            list_of_x_boxes.append(atoms_within_sliding_volume)

        else:
            number_of_atoms_x.append(len(positionvectors_within_sliding_volume))

            if element_agnostic:
                element_name = 'Fe'+str(len(positionvectors_within_sliding_volume))
            else:
                element_name = element_names_within_sliding_volume
            atoms_within_sliding_volume = ase.Atoms(element_name, positionvectors_within_sliding_volume)
            # Optional: assign label
            # atoms_within_sliding_volume.info['label']='box_label_'+str(i)+str(j)+str(k)
            atoms_within_sliding_volume.set_pbc(False) # added for safety
            list_of_x_boxes.append(atoms_within_sliding_volume)
        # Store the box's columns as [x, y] pairs (image coordinates) and its step indices
        all_boxes.append([[_[0], _[1]] for _ in positionvectors_within_sliding_volume])
        all_stride_idx.append([j, i])

        start[0] += step_size_x  # next x

    number_of_atoms_xy.append(number_of_atoms_x)
    list_of_xy_boxes.append(list_of_x_boxes)
    start[0] = x_min+x_sliding_volume_edge_length  # Reset x_value after most inner for loop finished
    start[1] += step_size_y  # next y

In [None]:
# Visualize some of the local windows (every 500th) with the atomic columns that
# fall inside them, shifted from image coordinates into window-local pixels.
for idx in range(len(sliced_images))[::500]:
    img = sliced_images[idx]
    # reshape(-1, 2) keeps a window without columns as a (0, 2) array, so the
    # column indexing below does not raise IndexError for empty windows
    pos = np.array(all_boxes[idx], dtype=float).reshape(-1, 2)
    pos[:, 0] -= all_stride_idx[idx][0] * stride_size[0]
    pos[:, 1] -= all_stride_idx[idx][1] * stride_size[1]

    fig, axs = plt.subplots()
    axs.imshow(img)
    if len(pos) > 0:
        axs.scatter(pos[:, 0], pos[:, 1], c='r')

# Given the predicted symmetry, define the reference lattice

In [None]:
# Angstrom per pixel of the simulated reference images; fixed by the simulation settings.
reference_pixeltoangstrom = 0.12

In [None]:
# Show every reference image, titled with its symmetry label
for label, ref_image in reference_dict.items():
    fig, axs = plt.subplots()
    axs.imshow(ref_image, cmap='gray')
    axs.set_title(label)

Load the reference training image for the currently assigned label

In [None]:
# Reference (training) image matching the symmetry label assigned above
tr_img = reference_dict[assigned_label]
print('Reference image selected for current input: {}'.format(assigned_label))
fig, axs = plt.subplots()
axs.imshow(tr_img, cmap='gray')
plt.show()

In [None]:
# Reconstruct the real-space lattice (atomic column positions) of the reference image.
# As for the input image, atomap gets half the pixels-per-Angstrom value, truncated to int.
pixels_per_angstrom_ref = 1. / reference_pixeltoangstrom
ref_separation = int(pixels_per_angstrom_ref / 2.)
print('Separation employed for reference image = {}'.format(ref_separation))

atom_positions = reconstruct_via_atomap(tr_img, separation=ref_separation, refine=False)

fig, axs = plt.subplots(figsize=(10, 10))
axs.set_aspect('equal')
axs.scatter(atom_positions[:, 0], atom_positions[:, 1], s=3)

In [None]:
# Window size in Angstrom after rounding it to whole pixels
pixel_to_angstrom * window_size

In [None]:
# Express the (pixel-rounded) input window size in pixels of the reference image,
# so the same physical area is cut from the training image:
# input pixels -> Angstrom -> reference pixels.
window_input_angstrom = float(window_size) * pixel_to_angstrom
reference_pixels_per_angstrom = 1. / float(reference_pixeltoangstrom)
window_reference = window_input_angstrom * reference_pixels_per_angstrom
print(round(window_reference))

In [None]:
# Keep only the reference columns inside a square of side `window_reference`
# (reference pixels) centred on the mean column position.
center = np.mean(atom_positions, axis=0)
half_window = float(window_reference) / 2.

lower_x, upper_x = center[0] - half_window, center[0] + half_window
lower_y, upper_y = center[1] - half_window, center[1] + half_window
col_x, col_y = atom_positions[:, 0], atom_positions[:, 1]

mask_x = (col_x >= lower_x) & (col_x <= upper_x)
mask_y = (col_y >= lower_y) & (col_y <= upper_y)
mask = mask_x & mask_y
filtered_columns = atom_positions[mask]

fig, axs = plt.subplots()
axs.set_aspect('equal')
axs.scatter(filtered_columns[:, 0], filtered_columns[:, 1], s=1)
plt.show()

In [None]:
# Normalize the extracted reference lattice, then keep only a small local
# region (fewer than Nat columns, i.e., a few nearest neighbors).

reference_lattice = filtered_columns.copy()

# Centroid of the cut-out
x_shift = np.mean(reference_lattice[:, 0])
y_shift = np.mean(reference_lattice[:, 1])

# The column closest to the centroid becomes the origin. It is copied so that
# the shift below does not also modify the stored origin.
dist = np.sqrt((reference_lattice[:, 0] - x_shift) ** 2 + (reference_lattice[:, 1] - y_shift) ** 2)
reference_lattice_cent = reference_lattice[np.argmin(dist)].copy()

reference_lattice[:, 0] -= reference_lattice_cent[0]
reference_lattice[:, 1] -= reference_lattice_cent[1]

# Shrink a circular cut-out around the origin until fewer than Nat columns remain.
# NOTE(review): `radius` is in reference-image pixels, but the step `delta` is
# derived from the input image's pixel_to_angstrom rather than
# reference_pixeltoangstrom — confirm this is intended.
Nat = 20
radius = float(window_reference) / 2.
delta = (1. / pixel_to_angstrom) / 8.
reference_lattice_tmp = reference_lattice.copy()
while len(reference_lattice_tmp) >= Nat:
    inside = np.sqrt(reference_lattice_tmp[:, 0] ** 2 + reference_lattice_tmp[:, 1] ** 2) < radius
    reference_lattice_tmp = reference_lattice_tmp[inside]
    radius -= delta
reference_lattice = reference_lattice_tmp

fig, axs = plt.subplots()
axs.set_aspect('equal')
axs.scatter(reference_lattice[:, 0], reference_lattice[:, 1], s=1)
plt.show()

In [None]:
# Define the criterion according to which the local rotation angle is reported.
# Reason: the registration may rotate clockwise or counter-clockwise, and each
# lattice maps onto itself under its rotation symmetry, so the raw angle is
# folded to the smallest equivalent angle.


# Rotation symmetry angle (degrees) of each reference lattice / projection
rotations_dict = {"BCC_Fe_100": 90.,
                  "BCC_Fe_110": 180., 
                  "BCC_Fe_111": 60., 
                  "FCC_Cu_100": 90., 
                  "FCC_Cu_110": 60., 
                  "FCC_Cu_111": 60., 
                  "FCC_Cu_211": 180., 
                  "HCP_Ti_0001": 60., 
                  "HCP_Ti_10m10": 180., 
                  "HCP_Ti_2m1m10": 60.}


def criterion(assigned_label, mismatch_angle):
    """
    Fold a mismatch angle into the smallest angle that matches the reference lattice.

    Parameters
    ----------
    assigned_label : str
        Key of `rotations_dict` (symmetry label of the reference lattice).
    mismatch_angle : float or array-like
        Rotation angle(s) in degrees, of any sign and magnitude.

    Returns
    -------
    float or numpy.ndarray
        Angle(s) in [0, symmetry_angle / 2]; e.g. for Ti (60-degree symmetry)
        a raw angle of 50 degrees becomes 10 degrees.

    Raises
    ------
    KeyError
        If `assigned_label` is not in `rotations_dict`.
    """
    symmetry_angle = rotations_dict[assigned_label]
    # Rotations by multiples of the symmetry angle are equivalent, so reduce
    # into [0, symmetry_angle). This also covers angles beyond 1.5 * symmetry_angle.
    reduced = np.mod(np.abs(mismatch_angle), symmetry_angle)
    # Rotating forward by `reduced` or backward by `symmetry_angle - reduced`
    # matches the same lattice; report the smaller one.
    return np.minimum(reduced, symmetry_angle - reduced)
        
    

In [None]:
def scale_lattice(lattice, window_size):
    """
    Normalize a (reconstructed) 2D lattice by its nearest-neighbour distance.

    The points are placed in the z = 0 plane of an ase ``Atoms`` object (all
    labelled 'Fe' as a placeholder) so that ``get_nn_distance`` can be evaluated
    with ``window_size`` as cutoff. Returns ``lattice`` (as a numpy array)
    divided by that distance.
    """
    positions_3d = [(point[0], point[1], 0.0) for point in lattice]
    symbols = ['Fe'] * len(lattice)
    atoms = Atoms(symbols, positions=positions_3d)
    return np.array(lattice) / get_nn_distance(atoms, cutoff=window_size)
    

In [None]:
# Reference lattice in units of its nearest-neighbour distance (cutoff in reference pixels)
reference_lattice_scaled = scale_lattice(lattice=reference_lattice, window_size=window_reference)

# Calculate mismatch angles

In [None]:
# For every local window, rigidly register the scaled reference lattice onto the
# window's scaled columns and convert the fitted rotation into a mismatch angle.
mismatch_angles = []

input_lattice = []
transformed_reflattice = []
regs = []

# Source point set (Y) is the same for every window
Y = reference_lattice_scaled

for idx, pos in enumerate(all_boxes):

    if idx % 1000 == 0:
        print('Process box {} / {}'.format(idx, len(all_boxes)))

    # X: target (columns of the current window), Y: source (reference lattice)
    X = norm_window_lattice(pos, reference_lattice,
                            window_size, pixel_to_angstrom)
    X = scale_lattice(X, window_size)

    # Rigid coherent point drift registration of Y onto X
    reg = RigidRegistration(**{'X': X, 'Y': Y, 'max_iterations': 100})
    reg.register()

    # Element [1] of the registration parameters is used as the rotation matrix
    # (presumably (scale, R, t) per pycpd — confirm). Round-off can push |R[0][1]|
    # marginally above 1, which would make arcsin return NaN, hence the clip.
    sin_angle = np.clip(reg.get_registration_parameters()[1][0][1], -1.0, 1.0)
    mismatch_angle = (-1) * np.arcsin(sin_angle) * 180 / np.pi

    mismatch_angle = criterion(assigned_label, mismatch_angle)
    mismatch_angles.append(mismatch_angle)

    transformed_reflattice.append(reg.transform_point_cloud(Y))
    regs.append(reg)
    input_lattice.append(X)

In [None]:
# Statistics of the calculated mismatch angles; the 5% / 95% quantiles are also
# kept as color limits for the map below.
vmin = np.quantile(mismatch_angles, 0.05)
vmax = np.quantile(mismatch_angles, 0.95)
print('Min: {}, Max: {}, 5% quantile: {}, 95% quantile: {}'.format(np.min(mismatch_angles),
                                                                   np.max(mismatch_angles),
                                                                   vmin,
                                                                   vmax))

In [None]:
# Map of the (unsmoothed) mismatch angles over the grid of local windows
cmap = 'plasma'
fig, axs = plt.subplots(figsize=(15, 15))
mismatch_map = np.reshape(mismatch_angles, (ni, nj))
im = axs.imshow(mismatch_map, vmin=vmin, vmax=vmax, cmap=cmap)
fig.colorbar(im, ax=axs)

# Apply smoothing

In [None]:
# Load the mutual-information map here so this cell does not depend on the
# smoothing cell below having been run first.
mutinfo = np.load(mutual_information_path)
# One value per local window is required: it is used as a boolean mask on the
# flattened (ni, nj) mismatch-angle map in the smoothing step.
assert mutinfo.size == ni * nj, 'mutual information does not match the number of local windows'
mutinfo.shape

In [None]:
# NOTE(review): leftover scratch arithmetic whose purpose is not documented
# (possibly a size estimate for the smoothing kernel) — confirm or remove.
12 * 16

In [None]:
# Average over neighboring strides:
nber_nn = [16, 20]
# above value of 16 has been employed for
# AI-STEM manuscript results (corresponds to averaging over areas of 10-20 Angstrom, depending on stride)
# larger range of values: [1, 2, 4, 8, 16]

# Windows whose mutual information is at or above this threshold are excluded
mutinfo_threshold = 0.1

data = np.reshape(mismatch_angles, (ni, nj))

# The mask and the masked input do not depend on the kernel size, so they are
# computed once instead of reloading the mutual information for every kernel.
mutinfo = np.load(mutual_information_path).flatten()
mask = mutinfo < mutinfo_threshold
data_filtered = data.flatten()
# NOTE(review): excluded windows enter the moving average as 0.0, which pulls
# smoothed values next to excluded regions towards 0 — consider normalizing by
# the convolved mask if this bias matters.
data_filtered[~mask] = 0.0
data_filtered = np.reshape(data_filtered, (ni, nj))

# Use a copy of the colormap so set_bad does not modify matplotlib's shared 'plasma'
cmap = deepcopy(plt.get_cmap('plasma'))
cmap.set_bad('gray')

# save results in dictionary
results = {}

# plot (and optionally save results below); squeeze=False keeps `ax` 2D so the
# indexing also works when nber_nn has a single entry
fig, ax = plt.subplots(1, len(nber_nn), facecolor='white', figsize=(25, 5), squeeze=False)
for idx, n in enumerate(nber_nn):
    kernel = np.ones((n, n))

    # apply smoothing (moving average over an n x n neighbourhood of windows)
    smoothed_data = convolve2d(data_filtered, kernel, boundary='symm', mode='same') / float(n * n)

    # for visualization only, focus on quantile values of the kept windows, to mitigate extreme values
    vmin = np.quantile(smoothed_data.flatten()[mask], 0.05)
    vmax = np.quantile(smoothed_data.flatten()[mask], 0.95)

    axs = ax[0, idx]

    # excluded windows are set to NaN and drawn in gray via cmap.set_bad
    smoothed_data = smoothed_data.flatten()
    smoothed_data[~mask] = np.nan
    smoothed_data = np.reshape(smoothed_data, (ni, nj))

    im = axs.imshow(smoothed_data, vmin=vmin, vmax=vmax, cmap=cmap)
    axs.set_title('NN = {}'.format(n))
    fig.colorbar(im, ax=axs)
    results[n] = smoothed_data

    # uncomment to save
    #np.save(os.path.join(save_path, '{}_mismatch_w_smoothing_nn_{}.npy'.format(assigned_label, n)),
    #        smoothed_data)
# uncomment to save
#plt.savefig(os.path.join(save_path, '{}_mismatch_w_smoothing.svg'.format(assigned_label)))
#plt.close()

# Visualize exemplary local lattice reconstructions together with the fit

In [None]:
# Show a few local windows with their reconstructed columns (left) and the scaled
# input lattice together with the registered reference lattice (right).
for idx in range(len(sliced_images))[510:520]:
    img = sliced_images[idx]
    # reshape(-1, 2) keeps a window without columns as a (0, 2) array, so the
    # shift below does not raise IndexError for empty windows
    pos = np.array(all_boxes[idx], dtype=float).reshape(-1, 2)
    pos[:, 0] -= all_stride_idx[idx][0] * stride_size[0]
    pos[:, 1] -= all_stride_idx[idx][1] * stride_size[1]

    mismatch_angle = mismatch_angles[idx]

    ref_lattice = transformed_reflattice[idx]
    in_lattice = np.array(input_lattice[idx])

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    axs[0].imshow(img)
    if len(pos) > 0:
        axs[0].scatter(pos[:, 0], pos[:, 1], c='r')
        # delta Nat: number of input columns minus number of reference columns
        axs[0].set_title('Pos {}, Mismatch angle = {},\n delta Nat = {}'.format(spm_pos[idx],
                                                                           mismatch_angle,
                                                                           (len(in_lattice) - len(ref_lattice))))

        axs[1].scatter(in_lattice[:, 0], in_lattice[:, 1], marker='o', s=35, label='Reconstructed columns')
        axs[1].scatter(ref_lattice[:, 0], ref_lattice[:, 1], marker='x', s=35, label='Fit')
        axs[1].legend()

        axs[0].set_aspect('equal')
        axs[1].set_aspect('equal')