In [None]:
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy.matlib
from scipy import ndimage
from collections import Counter
import itertools
from scipy.signal import get_window
import cv2
from collections import defaultdict
from scipy.stats import mode
from scipy import stats
from copy import deepcopy
from tensorflow.keras.models import load_model

from ai4stem.utils.utils_data import load_pretrained_model, load_example_image

# Required settings

In [None]:
# Folder in which all result files (.npy) are written:
results_folder = '.'

# Input image and its calibration (the pixel size must match the image).
image_name = 'Fe_bcc'
input_image = load_example_image()
pixel_to_angstrom = 0.12452489444788318  # pixel size in Angstrom per pixel

# Local-window (segmentation) settings
window_size = 12.0  # side length of each square local window, in Angstrom
stride_size = [36, 36]  # step between neighbouring windows (rows, columns), in pixels

# If you want to visualize local windows, set to True.
# NOTE(review): these two settings are not referenced by any cell in this notebook.
save_local_windows = False
local_windows_path = '.'

In [None]:
# Quick look at the raw input image.
fig, ax = plt.subplots()
ax.imshow(input_image, cmap='gray')
plt.show()

The following cells do not need to be changed:

In [None]:
# Map each image name to a tuple (image, pixel size in Angstrom).
# The image entry may be an array or a path (str) to an image file.
filenames = dict()
filenames[image_name] = (input_image, pixel_to_angstrom)


In [None]:
# Load the pretrained model provided by ai4stem.utils.utils_data.
model_name = 'pretrained_model'  # tag used in the names of the saved result files
model = load_pretrained_model()
model.summary()

# Functions for FFT calculation, segmentation into local windows, and prediction

In [None]:
def calc_fft(img, padding=(0, 0), power=2,
             sigma=None, r_cut=None,
             thresholding=False, apply_window=True, output_size=None,
             output_shape=(64, 64)):
    """Given a HAADF image, calculate the HAADF-FFT descriptor.

    Processing steps:
    1. Min-max normalize the image to [0, 1].
    2. Optionally apply a 2D Hann window.
    3. Fourier transform and raise the amplitude to ``power``.
    4. Center the spectrum, optionally mask or threshold it.
    5. Crop a window of ``output_shape`` around the center and min-max
       normalize the crop.

    Parameters
    ----------
    img : np.ndarray
        2D HAADF input image.
    padding : tuple
        Currently unused; kept for backward compatibility.
    power : int
        Exponent applied to the FFT amplitude in order to suppress small
        fluctuations and emphasize peaks (2 gives the power spectrum).
    sigma : float or None
        Standard deviation (in pixels) of a Gaussian window used to damp the
        central part of the FFT (the spectrum is multiplied by 1 - window).
        If None (default), no Gaussian damping is applied.
    r_cut : float or None
        Radius (in pixels) of a circular region around the FFT center that
        is set to zero. If None (default), no circular cut is applied.
    thresholding : bool
        [incompletely implemented] If True, saturate everything above 10 %
        of the maximum at 1 and renormalize, to mitigate the influence of
        the central peak.
    apply_window : bool
        If True (default), multiply the image by a 2D Hann window before the
        FFT to reduce edge artifacts.
    output_size : tuple or None
        Passed as ``s`` to ``np.fft.fft2``. If None, the FFT has the size of
        ``img``. Along each axis, a value larger than the image zero-pads it
        and a smaller value crops it.
    output_shape : tuple of int
        Shape of the window cropped around the FFT center and returned.

    Returns
    -------
    np.ndarray
        float32 array of shape ``output_shape`` with values in [0, 1].

    Raises
    ------
    ValueError
        If the FFT is smaller than ``output_shape`` along any axis.
    """
    # First step: min-max normalize the image to [0, 1].
    img = cv2.normalize(img, None,
                        alpha=0, beta=1,
                        norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    if apply_window:
        # 2D window: sqrt of the outer product of 1D Hann windows along rows
        # and columns. 'hann' is used because the legacy alias 'hanning' is
        # not accepted by recent SciPy versions of get_window.
        hann_rows = get_window('hann', img.shape[0])
        hann_cols = get_window('hann', img.shape[1])
        img_windowed = img * np.sqrt(np.outer(hann_rows, hann_cols))
    else:
        img_windowed = img

    # FFT, zero-padded or cropped to output_size if one is given.
    f = np.fft.fft2(img_windowed, s=output_size)

    # |F|**power with the zero frequency shifted to the center.
    fshift = np.fft.fftshift(np.power(np.abs(f), power))

    fshift = cv2.normalize(fshift, None,
                           alpha=0, beta=1,
                           norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Optional circular cut: zero every pixel within r_cut of the center.
    if r_cut is not None:
        xc = (fshift.shape[0] - 1.0) / 2.0
        yc = (fshift.shape[1] - 1.0) / 2.0
        x, y = np.ogrid[-xc:fshift.shape[0] - xc, -yc:fshift.shape[1] - yc]
        fshift[x * x + y * y <= r_cut * r_cut] = 0.0

    # Optional Gaussian damping of the center. The column window is built
    # from shape[1], so non-square spectra are handled correctly.
    if sigma is not None:
        gauss_rows = get_window(('gaussian', sigma), fshift.shape[0])
        gauss_cols = get_window(('gaussian', sigma), fshift.shape[1])
        fshift = fshift * (1 - np.sqrt(np.outer(gauss_rows, gauss_cols)))

    if thresholding:
        fshift = cv2.normalize(fshift, None,
                               alpha=0, beta=1,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        # After dividing by 0.1, everything above 10 % of the maximum
        # exceeds 1 and is saturated to 1.
        fshift = fshift / .1
        fshift[fshift > 1] = 1
        fshift = cv2.normalize(fshift, None,
                               alpha=0, beta=1,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Cut out an output_shape window around the center of the FFT.
    out_h, out_w = output_shape
    row0 = fshift.shape[0] // 2 - out_h // 2
    col0 = fshift.shape[1] // 2 - out_w // 2
    if (row0 < 0 or col0 < 0
            or row0 + out_h > fshift.shape[0]
            or col0 + out_w > fshift.shape[1]):
        raise ValueError(
            "FFT of shape {} is smaller than the requested output_shape {}; "
            "use a larger window or pass output_size.".format(fshift.shape, tuple(output_shape)))
    output2 = np.array(fshift[row0:row0 + out_h, col0:col0 + out_w], dtype=np.float64)

    output2 = cv2.normalize(output2, None,
                            alpha=0, beta=1,
                            norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    return output2

In [None]:
def localwindow(image_in, stride_size, pixel_max=100,
                normalize_before_fft=False, normalize_after_window=False):
    """Slice an image into square local windows on a regular grid.

    Each window is ``pixel_max`` x ``pixel_max`` pixels, with its top-left
    corner at (i, j). The row index i steps by ``stride_size[0]`` and the
    column index j by ``stride_size[1]``. Every window that lies fully
    inside the image is kept, including the one ending exactly at the
    image border.

    Parameters
    ----------
    image_in : np.ndarray
        Input image; the first two axes are treated as rows and columns.
    stride_size : sequence of two ints
        Step (in pixels) between consecutive windows along rows and columns.
        Both entries must be positive.
    pixel_max : int
        Side length (in pixels) of each square window.
    normalize_before_fft : bool
        If True, min-max normalize each window to [0, 1] (float32) with cv2.
    normalize_after_window : bool
        Currently unused; kept for backward compatibility.

    Returns
    -------
    images : list of np.ndarray
        Copies of the windows in row-major grid order (float64 unless
        normalized).
    spm_pos : np.ndarray
        Top-left (row, column) pixel of each window, shape (len(images), 2).
    ni, nj : int
        Number of grid positions along rows and columns, so that
        ``len(images) == ni * nj``. Both are 0 if no window fits.

    Raises
    ------
    ValueError
        If a stride is not positive (this would otherwise loop forever).
    """
    if stride_size[0] <= 0 or stride_size[1] <= 0:
        raise ValueError("stride_size entries must be positive, got {}".format(stride_size))

    x_max = image_in.shape[0]
    y_max = image_in.shape[1]

    # The last valid start position along each axis is (size - pixel_max).
    row_starts = range(0, x_max - pixel_max + 1, stride_size[0])
    col_starts = range(0, y_max - pixel_max + 1, stride_size[1])

    images = []
    spm_pos = []
    for i in row_starts:
        for j in col_starts:
            # Copy as float64 so later in-place edits cannot touch image_in.
            image = np.array(image_in[i:i + pixel_max, j:j + pixel_max], dtype=np.float64)
            if normalize_before_fft:
                image = cv2.normalize(image, None, alpha=0, beta=1,
                                      norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            images.append(image)
            # Record the position of THIS window (before any stride advance).
            spm_pos.append([i, j])

    return images, np.asarray(spm_pos), len(row_starts), len(col_starts)

In [None]:
def predict_with_uncertainty(data, model=None, model_type='classification', n_iter=1000):
    """Estimate the predictive uncertainty of a neural network with Monte Carlo dropout.

    This follows Chap. 3 in Yarin Gal's PhD thesis:
    http://mlg.eng.cam.ac.uk/yarin/thesis/thesis.pdf

    ``model.predict(data)`` is called ``n_iter`` times and the outputs are
    aggregated. For classification, the uncertainty is calculated in the
    three ways proposed in Gal's thesis (pp. 51-54):
    - variation_ratio: defined in Eq. 3.19
    - predictive_entropy: defined in Eq. 3.20 (in bits, log base 2)
    - mutual_information: defined on p. 53 (no Eq. number)
    For regression, the predictive variance across passes is returned.

    NOTE(review): the passes only differ if the model's dropout layers are
    active inside ``model.predict`` (e.g. built with ``training=True``);
    this cannot be confirmed from this notebook.

    Parameters
    ----------
    data : array-like
        Input passed unchanged to ``model.predict``.
    model : object
        Model exposing ``predict(data)``. For classification the output is
        assumed to be class probabilities with classes on the last axis.
    model_type : str
        'classification' or 'regression'.
    n_iter : int
        Number of stochastic forward passes.

    Returns
    -------
    prediction : np.ndarray
        Mean of the model outputs over all passes (not clipped).
    uncertainty : dict
        'variation_ratio', 'predictive_entropy' and 'mutual_information'
        (one value per sample) for classification;
        'predictive_variance' for regression.

    Raises
    ------
    ValueError
        If ``model`` is None or ``model_type`` is not supported.

    .. codeauthors:: Angelo Ziletti <angelo.ziletti@gmail.com>, Andreas Leitherer <andreas.leitherer@gmail.com

    """
    # Validate up front so that no forward passes are wasted on bad arguments.
    if model_type not in ('classification', 'regression'):
        raise ValueError("Supported model types are 'classification' or 'regression'. "
                         "model_type={} is not accepted.".format(model_type))
    if model is None:
        raise ValueError("A model with a `predict` method must be provided.")

    labels = []
    results = []
    report_every = int(n_iter) // 10 + 1
    for idx_iter in range(n_iter):
        if idx_iter % report_every == 0:
            print("Performing forward pass: {0}/{1}".format(idx_iter + 1, n_iter))

        result = model.predict(data)
        labels.append(result.argmax(axis=-1))
        results.append(result)

    results = np.asarray(results)
    prediction = results.mean(axis=0)

    if model_type == 'regression':
        predictive_variance = results.var(axis=0)
        uncertainty = dict(predictive_variance=predictive_variance)
    else:
        # Variation ratio: 1 - (count of the most frequent label) / n_iter,
        # per sample. Computed with bincount instead of scipy.stats.mode,
        # whose output shape depends on the SciPy version (keepdims default).
        labels = np.asarray(labels)  # (n_iter, ...) integer class indices
        mode_count = np.apply_along_axis(lambda col: np.bincount(col).max(), 0, labels)
        variation_ratio = 1. - mode_count / float(n_iter)

        # Predictive entropy. Clip to 1e-12 to avoid log(0), using a separate
        # array so the returned `prediction` is not modified.
        p_mean = np.clip(prediction, a_min=1e-12, a_max=None)
        predictive_entropy = np.sum(-np.multiply(p_mean, np.log2(p_mean)), axis=-1)

        # Mutual information = entropy of the mean minus the mean entropy
        # of the individual passes.
        p_iter = np.clip(results, a_min=1e-12, a_max=None)
        p_log_p_all = np.multiply(np.log2(p_iter), p_iter)
        exp_p_omega = np.sum(np.sum(p_log_p_all, axis=0), axis=-1)
        mutual_information = predictive_entropy + 1. / float(n_iter) * exp_p_omega

        uncertainty = dict(variation_ratio=variation_ratio, predictive_entropy=predictive_entropy,
                           mutual_information=mutual_information)

    return prediction, uncertainty

# Analyze image

In [None]:
# FFT HAADF descriptor settings
sigma = None  # optional Gaussian damping of the FFT center (None = disabled)
thresholding = True  # very important: saturate and renormalize the FFT spectrum
n_iter = 100  # number of Monte Carlo dropout forward passes

results_dict = defaultdict(dict)
for name, (filename, dx_origin) in filenames.items():
    # An entry holds either a path to an image file or an already loaded array.
    if isinstance(filename, str):
        img = cv2.imread(filename)
        # cv2.imread returns None (instead of raising) if the file cannot be read.
        if img is None:
            raise FileNotFoundError('Could not read image file: {}'.format(filename))
        image = img[:, :, 0].copy()
    else:
        image = filename

    # Convert the window size from Angstrom to pixels for this image.
    adapted_window_size = int(round(window_size * (1. / dx_origin)))
    print('For image called {}, window {} [Angstrom] corresponds to {} pixels'.format(name, window_size, adapted_window_size))

    sliced_images, spm_pos, ni, nj = localwindow(image, stride_size=stride_size,
                                                 pixel_max=adapted_window_size)
    if len(sliced_images) == 0:
        raise ValueError('No {0}x{0} pixel window fits into image {1} of shape {2}; '
                         'reduce window_size.'.format(adapted_window_size, name, image.shape))
    np.save(os.path.join(results_folder, '{}_{}_images.npy'.format(name, model_name)), sliced_images)

    fft_descriptors = [calc_fft(im, sigma=sigma, thresholding=thresholding) for im in sliced_images]
    np.save(os.path.join(results_folder, '{}_fft_desc.npy'.format(name)), np.asarray(fft_descriptors))

    # Add a trailing channel axis: (n_windows, height, width, 1).
    model_input = np.asarray(fft_descriptors)[..., np.newaxis]

    prediction, uncertainty = predict_with_uncertainty(model_input,
                                                       model=model,
                                                       model_type='classification',
                                                       n_iter=n_iter)
    np.save(os.path.join(results_folder, '{}_{}_predictions.npy'.format(name, model_name)), prediction)
    # Separate loop variable so the outer loop's names are not shadowed.
    for unc_name, unc_values in uncertainty.items():
        np.save(os.path.join(results_folder, '{}_{}_{}.npy'.format(name, model_name, unc_name)), unc_values)

    # Arrange the per-window results on the (ni, nj) grid of window positions.
    results_dict[model_name][name] = {
        'Prediction': np.reshape(prediction.argmax(axis=-1), (ni, nj)),
        'Mutual information': np.reshape(uncertainty['mutual_information'], (ni, nj)),
        'Input Image': image,
    }

# Visualize predictions

In [None]:
import matplotlib

matplotlib.rcParams.update({'font.size': 10})

# Panel titles, in the order in which the panels are drawn below.
panel_titles = ('Input image', 'Predicted class (argmax)', 'Mutual information (uncertainty)')

for name in filenames:
    result = results_dict[model_name][name]

    # One row of three panels, so use a wide rather than a square figure.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    im1 = axs[0].imshow(result['Input Image'], cmap='gray')
    fig.colorbar(im1, ax=axs[0], orientation='vertical')

    # Qualitative colormap so that class indices show up as distinct colours.
    im2 = axs[1].imshow(result['Prediction'], cmap='tab20')
    fig.colorbar(im2, ax=axs[1], orientation='vertical')

    im3 = axs[2].imshow(result['Mutual information'], cmap='hot', vmin=0.0)
    fig.colorbar(im3, ax=axs[2], orientation='vertical')

    for ax, title in zip(axs, panel_titles):
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(name)
    fig.tight_layout()

    plt.show()
    # Release the figure so that looping over many images does not keep them all open.
    plt.close(fig)

