This notebook demonstrates how to use UMAP (Uniform Manifold Approximation and Projection), a manifold-learning algorithm, to inspect the internal neural-network representations that AI-STEM learns during training. As an example, we consider an experimental image (Fe bcc [100]).

# Import packages

In [None]:
! pip install umap-learn

In [None]:
import os
# tensorflow info/warnings switched off
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow.keras.models import Model

from ai4stem.utils.utils_data import load_pretrained_model, load_example_image
from ai4stem.utils.utils_prediction import predict

from ai4stem.utils.utils_fft import calc_fft
from ai4stem.utils.utils_prediction import localwindow
from ai4stem.utils.utils_nn import decode_preds, predict_with_uncertainty

import numpy as np

import umap

import matplotlib.pyplot as plt

import pandas as pd

import logging
# Root logger used for the progress messages printed throughout the notebook.
# Without a handler on the root logger, INFO records fall through to
# logging's last-resort handler, which only emits WARNING and above, so the
# progress messages would never appear. basicConfig attaches a stderr handler
# (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Load example image and specify AI-STEM parameters

In [None]:
# Output directory for any saved results (the current working directory).
results_folder = '.'

# Example experimental image, plus a short label for it.
image_name = 'Fe_bcc'
input_image = load_example_image()

# Calibration of the image: size of one pixel, in Angstrom.
pixel_to_angstrom = 0.12452489444788318

# AI-STEM sliding-window parameters.
#   window_size - window size in Angstrom
#   stride_size - step between neighbouring windows (presumably in pixels,
#                 since the window is converted to pixels below -- TODO confirm)
window_size = 12.
stride_size = [36, 36]

# Express the window size in pixels (truncated to an integer).
angstrom_to_pixel = 1. / pixel_to_angstrom
adapted_window_size = int(window_size * angstrom_to_pixel)
print(adapted_window_size)

# Load pretrained model

In [None]:
# Label for the network that is loaded in this cell.
model_name = 'pretrained_model'
model = load_pretrained_model()

# Print the architecture; the layer names listed here are the ones that can be
# passed to `model.get_layer`.
model.summary()

# Segment image and calculate FFT-HAADF descriptor (i.e., the neural-network input)

In [None]:
# Cut the image into local windows of `adapted_window_size` pixels, shifted by
# `stride_size`; `spm_pos`, `ni` and `nj` are returned alongside the windows.
logger.info('Segment image into local windows.')
sliced_images, spm_pos, ni, nj = localwindow(input_image, stride_size=stride_size, pixel_max=adapted_window_size)

# Compute the FFT-HAADF descriptor (the neural-network input) of every window.
logger.info('Calculate FFT-HAADF descriptor.')
fft_descriptors = [calc_fft(im, sigma=None, thresholding=True) for im in sliced_images]

# Stack all descriptors into one batch and append a trailing channel axis of
# size 1 so the array matches the model input shape:
# (n_windows, *descriptor_shape) -> (n_windows, *descriptor_shape, 1).
data = np.stack(fft_descriptors)[..., np.newaxis]
data.shape

# Extract internal neural-network representations

In [None]:
# Build a feature extractor: it takes the same input as `model`, but its output
# is the activation of an intermediate layer instead of the final
# classification output.
inputs = model.input
# Layer used as the new final layer. 'Dense_1' is the layer right before the
# last classification layer of the pretrained model -- re-check against
# model.summary() if the architecture changes.
output_layer_name = 'Dense_1'
outputs = model.get_layer(output_layer_name).output
intermediate_layer_model = Model(inputs=inputs,
                                 outputs=outputs)
intermediate_layer_model.summary()

In [None]:
# Number of iterations passed to both calls below. Since an uncertainty
# (mutual information) is derived from them, these are presumably repeated
# stochastic forward passes (e.g. Monte Carlo dropout) -- TODO confirm.
n_iter = 100

# Activations of the intermediate layer for every window.
nn_representations = decode_preds(data, intermediate_layer_model, n_iter=n_iter)

# Per-window predictions of the full classifier and their uncertainty estimates.
prediction, uncertainty = predict_with_uncertainty(
    data,
    model,
    model_type='classification',
    n_iter=n_iter,
)

# Apply UMAP and visualize results

In [None]:
# Dictionaries used for visualizing.
# `layer_activations`: key -> data that UMAP is fitted on.
# `targets`: same key -> {target name: per-point values used for colouring}.
#   'argmax'   - index of the largest prediction entry (last axis) per window
#   'mut_info' - mutual information from the uncertainty estimate
layer_activations = {'nn_rep': nn_representations}

predicted_class = prediction.argmax(axis=-1)
mutual_information = uncertainty['mutual_information']
targets = {
    'nn_rep': {
        'argmax': predicted_class,
        'mut_info': mutual_information,
    },
}

In [None]:
# Apply UMAP to each set of activations and plot the 2D embedding once per
# target quantity (colouring the points by that quantity).

# Most important parameter: number of neighbors employed for calculating the
# low-dimensional (here, 2D) embedding.
n_neighbors_list = [5, 50, 200]
# Metric for measuring the distance between data points.
metric = 'euclidean'
# Embedding dimension.
n_components = 2
# Seed for UMAP. None keeps UMAP's default (non-reproducible) behaviour; set an
# int for reproducible embeddings (UMAP then disables parallelism).
random_state = None
# Plotting parameters.
s = 2.5
edgecolors = 'face'
# Set to True to also write every figure into `results_folder`.
save_figures = False

for n_neighbors in n_neighbors_list:
    logger.info('Apply UMAP for number of neighbors = {}'.format(n_neighbors))

    for key, data_for_fitting in layer_activations.items():
        reducer = umap.UMAP(n_neighbors=n_neighbors,
                            metric=metric,
                            n_components=n_components,
                            random_state=random_state)
        # fit_transform returns the embedding of the fitted data directly, so a
        # separate transform() call on the same data is unnecessary.
        embedding = reducer.fit_transform(data_for_fitting)

        for target, values in targets[key].items():
            # Continuous uncertainty -> sequential colour map;
            # class indices -> qualitative colour map.
            cmap = 'hot' if target == 'mut_info' else 'tab10'

            fig, axs = plt.subplots(facecolor='white', figsize=(10, 10))
            im = axs.scatter(embedding[:, 0], embedding[:, 1], c=values,
                             cmap=cmap, s=s, edgecolors=edgecolors)
            axs.set_aspect('equal')
            axs.set_title('{} coloured by {} (n_neighbors = {})'.format(key, target, n_neighbors))
            axs.set_xlabel('UMAP 1')
            axs.set_ylabel('UMAP 2')
            fig.colorbar(im, ax=axs, label=target)
            fig.tight_layout()

            if save_figures:
                filename = '{}_{}_nn_{}_embedding.png'.format(key, target, n_neighbors)
                fig.savefig(os.path.join(results_folder, filename), dpi=200)
            plt.show()
            # Release the figure explicitly: this loop creates many figures.
            plt.close(fig)