This notebook provides a step-by-step guide on how to use ai4stem for analyzing experimental images.

In [None]:
! pip install 'git+https://github.com/AndreasLeitherer/ai4stem.git'
! pip install tensorflow
! pip install opencv-python

# Quick start

After specifying an input image (here, Fe bcc [100]), the following code can be used to analyze it via ai4stem, employing a pretrained model (which is also employed in the ai4stem publication):

In [None]:
import os
# tensorflow info/warnings switched off
# NOTE(review): presumably this has to be set before the ai4stem imports
# below pull in TensorFlow, so keep it above them - confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from ai4stem.utils.utils_data import load_pretrained_model, load_example_image
from ai4stem.utils.utils_prediction import predict

# Set the level of the root logger to INFO.
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Load the example image and the pretrained model shipped with ai4stem.
input_image = load_example_image()
model = load_pretrained_model()

# Run the ai4stem prediction on the image. The four returned values are
# unpacked below; judging by the variable names they are the local image
# slices, their FFT descriptors, the model predictions and the uncertainty
# estimates (the step-by-step section further down computes these one by one).
sliced_images, fft_descriptors, prediction, uncertainty = predict(input_image, model)

*Note:*

The model is trained on a specific pixel/Angstrom relation. Specifically, the model is trained to classify local windows of size 12 Angstrom, which corresponds to 100 pixels in the simulation settings that we employed for creating the training set. If your image has a different resolution, we recommend adapting the window size: given the pixel-to-Angstrom relation, calculate how many pixels correspond to 12 Angstrom and use this as the window size. An example of this conversion is given in the more detailed explanations further below. Alternatively, you may rescale the whole image (up/downsampling) such that the resolution of your input image matches the training resolution and then simply use a window size of 100 pixels.

After this quickstart, more detailed explanations are provided in the following.

# Step-by-step explanations

First we import some packages:

In [None]:
import logging
import os
from collections import Counter
from collections import defaultdict
from copy import deepcopy

import cv2
import matplotlib.pyplot as plt
import numpy as np

from ai4stem.utils.utils_data import load_pretrained_model, load_example_image
from ai4stem.utils.utils_fft import calc_fft
from ai4stem.utils.utils_prediction import localwindow
from ai4stem.utils.utils_nn import predict_with_uncertainty

# Necessary specifications

In [None]:
# Directory into which the analysis writes its .npy result files.
results_folder = '.'

# Label for the image (used in output file names) and the image itself.
image_name = 'Fe_bcc'
input_image = load_example_image()

# Resolution of the image: the analysis divides the window size in Angstrom
# by this value to obtain the window size in pixels.
pixel_to_angstrom = 0.12452489444788318
# Edge length of the local window, in Angstrom.
window_size = 12.0
# Stride handed to the fragmentation step
# (presumably pixels along each image axis - TODO confirm).
stride_size = [36, 36]

In [None]:
# Preview the input image in grayscale.
fig, ax = plt.subplots()
ax.imshow(input_image, cmap='gray')
plt.show()

The following cells do not have to be changed:

In [None]:
# Images to analyze: label -> (image array or path to an image file, resolution).
filenames = {image_name: (input_image, pixel_to_angstrom)}


In [None]:
# Label under which results of this model are stored and saved.
model_name = 'pretrained_model'

# Fetch the pretrained model and display its summary.
model = load_pretrained_model()
model.summary()

# Analyze image

In [None]:
# FFT HAADF descriptor settings
# optional parameter for cutting low-frequency part
# by applying gaussian mask of width sigma; default setting
# is to not use this.
sigma = None
# Thresholding parameter; given FFT spectrum normalized
# to [0, 1], cut off at 0.1 to reduce low-frequency
# contributions; default is to use this setting.
thresholding = True # very important
n_iter = 100 # MC dropout samples

# (Re)create the root logger here so that this cell also runs when the
# quick-start cell at the top of the notebook was skipped.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Results are collected as results_dict[model_name][image_name][quantity].
results_dict = defaultdict(dict)
for key in filenames:
    logger.info('Convert window size in Angstrom to window size in pixels.')
    # Each entry is (image array or path to an image file, resolution).
    filename, dx_origin = filenames[key]
    name = key

    if isinstance(filename, str):
        # A path was given: read the file and keep only the first channel.
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals an unreadable file by returning None
            # instead of raising; fail here with a clear message.
            raise OSError('Could not read image file: {}'.format(filename))
        image = deepcopy(img[:, :, 0])
    else:
        # The image was passed directly as an array.
        image = filename

    # Window size in pixels = window size in Angstrom / resolution.
    adapted_window_size = int(round(window_size * (1. / dx_origin)))
    logger.info('For image called {}, window {} [Angstrom] corresponds to {} pixels'.format(
        key, window_size, adapted_window_size))

    logger.info('Perform fragmentation.')
    sliced_images, spm_pos, ni, nj = localwindow(image,
                                                 stride_size=stride_size,
                                                 pixel_max=adapted_window_size)
    np.save(os.path.join(results_folder,
                         '{}_{}_images.npy'.format(name, model_name)),
            sliced_images)

    logger.info('Calculate FFT-HAADF descriptor.')
    fft_descriptors = [calc_fft(im, sigma=sigma, thresholding=thresholding)
                       for im in sliced_images]
    fft_descriptor_array = np.asarray(fft_descriptors)
    np.save(os.path.join(results_folder, '{}_fft_desc.npy'.format(name)),
            fft_descriptor_array)

    # Append a trailing axis of length one to every descriptor
    # (presumably the channel axis expected by the model - TODO confirm).
    repeated_images = np.expand_dims(fft_descriptor_array, axis=-1)

    logger.info('Calculate neural-network predictions and uncertainty.')
    prediction, uncertainty = predict_with_uncertainty(repeated_images,
                                                       model=model,
                                                       model_type='classification',
                                                       n_iter=n_iter)
    np.save(os.path.join(results_folder,
                         '{}_{}_predictions.npy'.format(name, model_name)),
            prediction)
    # One file per uncertainty quantity. A dedicated loop variable is used so
    # the outer loop variable `key` is not overwritten.
    for uncertainty_key in uncertainty:
        np.save(os.path.join(results_folder,
                             '{}_{}_{}.npy'.format(name, model_name, uncertainty_key)),
                uncertainty[uncertainty_key])

    # Reshape the per-window results onto the (ni, nj) grid returned by
    # the fragmentation step.
    argmax_pred = np.reshape(prediction.argmax(axis=-1), (ni, nj))
    mutinfo = np.reshape(uncertainty['mutual_information'], (ni, nj))

    results_dict[model_name][name] = {
        'Prediction': argmax_pred,
        'Mutual information': mutinfo,
        'Input Image': image,
    }
    logger.info('Calculation for image {} finished.'.format(name))

# Visualize predictions

In [None]:
import matplotlib

matplotlib.rcParams.update({'font.size': 10})

# One figure per analyzed image: input, assigned label, and uncertainty.
for name in filenames:
    entry = results_dict[model_name][name]
    argmax_pred = entry['Prediction']
    mutinfo = entry['Mutual information']
    image = entry['Input Image']

    fig, axs = plt.subplots(1, 3, figsize=(10, 10))

    # (data, panel title, extra imshow options) for each of the three panels.
    panels = [
        (image, 'Input image', {'cmap': 'gray'}),
        (argmax_pred, 'Assigned label', {'cmap': 'tab10'}),
        (mutinfo, 'Bayesian uncertainty \n (mutual information)', {'cmap': 'hot', 'vmin': 0.0}),
    ]
    for ax, (data, title, imshow_options) in zip(axs, panels):
        im = ax.imshow(data, **imshow_options)
        fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.05)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()

    plt.show()

