In [None]:
! pip install umap-learn

In [None]:
import os
# tensorflow info/warnings switched off
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow.keras.models import Model

from ai4stem.utils.utils_data import load_pretrained_model, load_example_image
from ai4stem.utils.utils_prediction import predict

from ai4stem.utils.utils_fft import calc_fft
from ai4stem.utils.utils_prediction import localwindow
from ai4stem.utils.utils_nn import decode_preds, predict_with_uncertainty

import numpy as np

import umap

import matplotlib.pyplot as plt

import pandas as pd

import logging
# Make INFO-level progress messages visible in the notebook. Setting the level on the
# root logger is not enough by itself: with no handler attached, Python falls back to
# its last-resort handler, which only emits WARNING and above, so logger.info(...)
# calls would be dropped. basicConfig adds a stream handler only if none exists yet.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Specify path where to save the results:
results_folder = '.'

input_image = load_example_image()
image_name = 'Fe_bcc'
pixel_to_angstrom = 0.12452489444788318  # calibration: angstrom per pixel
window_size = 12.  # side length of the local analysis window, in angstrom
stride_size = [36, 36]  # sliding-window stride passed to localwindow (assumed pixels per axis — confirm)

# Convert the window size from angstrom to pixels (it is passed as `pixel_max` to
# localwindow). Dividing directly rounds once; multiplying by a precomputed
# reciprocal rounds twice and can land just below an exact integer before int()
# truncates it (e.g. 49 * (1. / 49.) == 0.9999999999999999).
adapted_window_size = int(window_size / pixel_to_angstrom)
adapted_window_size

In [None]:
# Load the pretrained classifier provided by ai4stem and display its architecture.
model_name = 'pretrained_model'
model = load_pretrained_model()
model.summary()

In [None]:
# Cut the image into overlapping local windows, then compute one FFT-HAADF
# descriptor per window.
sliced_images, spm_pos, ni, nj = localwindow(input_image, stride_size=stride_size, pixel_max=adapted_window_size)

logger.info('Calculate FFT-HAADF descriptor.')
fft_descriptors = [calc_fft(window, sigma=None, thresholding=True) for window in sliced_images]

# Stack all descriptors along a new leading axis and append a trailing singleton
# channel axis: (n_windows, *descriptor_shape) -> (n_windows, *descriptor_shape, 1).
data = np.stack(fft_descriptors)[..., np.newaxis]

In [None]:
# Build a sub-model that reuses the pretrained network's layers but ends at the
# hidden layer 'Dense_1', so its output is that layer's activation for each input.
hidden_layer_name = 'Dense_1'
intermediate_layer_model = Model(
    inputs=model.input,
    outputs=model.get_layer(hidden_layer_name).output,
)
intermediate_layer_model.summary()

In [None]:
# Number of repeated stochastic passes (presumably Monte Carlo dropout samples —
# confirm in ai4stem.utils.utils_nn). One shared constant keeps the hidden
# representations and the uncertainty-aware predictions computed consistently.
n_iter = 10

# Hidden-layer ('Dense_1') representations of every window descriptor in `data`.
nn_representations = decode_preds(data, intermediate_layer_model, n_iter=n_iter)
# Class scores plus an uncertainty dict (its 'mutual_information' entry is used below).
prediction, uncertainty = predict_with_uncertainty(data, model,
                                                   model_type='classification', n_iter=n_iter)

In [None]:
# Sanity check: shape of the hidden representations that are embedded with UMAP
# below (presumably (n_windows, n_hidden_units) — confirm).
nn_representations.shape

In [None]:
# Data to embed, keyed by representation name, and — under the same key — the
# per-window quantities used to colour the resulting embedding.
layer_activations = dict(nn_rep=nn_representations)
targets = dict(
    nn_rep=dict(
        argmax=prediction.argmax(axis=-1),           # index of the highest-scoring class
        mut_info=uncertainty['mutual_information'],  # uncertainty estimate per window
    ),
)

In [None]:
# Apply unsupervised analysis: embed each set of representations with UMAP for
# several neighbourhood sizes and colour the 2-D embedding by every target.

n_neighbors_list = [5, 50, 200]  # small values emphasise local, large values global structure
metric = 'euclidean'
n_components = 2  # only the first two embedding axes are plotted below
s = 2.5  # scatter marker size
edgecolors = 'face'
umap_random_state = None  # set an int for reproducible embeddings (umap-learn then disables parallelism)
save_figures = False  # set True to also write every figure into `results_folder`

for n_neighbors in n_neighbors_list:
    for key, data_for_fitting in layer_activations.items():
        mapper = umap.UMAP(n_neighbors=n_neighbors,
                           metric=metric,
                           n_components=n_components,
                           random_state=umap_random_state)
        # fit_transform returns the embedding learned during fitting directly; a
        # separate transform() on the same training data is redundant work (and in
        # older umap-learn versions re-projects the points only approximately).
        embedding = mapper.fit_transform(data_for_fitting)

        for target, target_values in targets[key].items():
            # Mutual information is continuous -> sequential colormap;
            # argmax class indices are categorical -> qualitative colormap.
            cmap = 'hot' if target == 'mut_info' else 'tab10'

            fig, axs = plt.subplots(facecolor='white', figsize=(10, 10))
            im = axs.scatter(embedding[:, 0], embedding[:, 1], c=target_values,
                             cmap=cmap, s=s, edgecolors=edgecolors)
            axs.set_aspect('equal')
            axs.set_title('{} coloured by {} (n_neighbors={})'.format(key, target, n_neighbors))
            axs.set_xlabel('UMAP 1')
            axs.set_ylabel('UMAP 2')
            fig.colorbar(im, ax=axs, label=target)
            fig.tight_layout()
            if save_figures:
                # Save before show(): once shown, the figure may already be released.
                fig.savefig(os.path.join(results_folder,
                                         '{}_{}_nn_{}_embedding.png'.format(key, target, n_neighbors)),
                            dpi=200)
            plt.show()
            # Release the figure so memory does not accumulate across the loop.
            plt.close(fig)