# Demonstration of unsupervised learning on engineered features using Iterative NMF
## The code is implemented as a class that takes in the features previously extracted from the 4D-STEM datasets and performs unsupervised learning on them. We show the use of both PCA and Iterative NMF; PCA is used to select the initial number of components for the Iterative NMF process.

### Last modified September 25th, 2022

In [None]:
import py4DSTEM
from py4DSTEM.visualize import show_image_grid
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.ticker as plticker
from Featurization import Featurization

In [None]:
## Creating colormap
# Sample NUM_COLORS evenly spaced RGBA colours from 'gist_rainbow'. The plotting
# cells index into `cmap` so that every cluster gets its own hue.
import matplotlib.cm as mplcm
import matplotlib.colors as colors

NUM_COLORS = 200

rainbow = plt.get_cmap('gist_rainbow')
cmap = list(map(rainbow, (k / NUM_COLORS for k in range(NUM_COLORS))))

## Import Data

In [None]:
# Scan-grid size (R_*) used to shape the cluster-map images, and the detector
# size (Q_*) -- Q_* is not used in this notebook; confirm against the dataset.
R_Nx, R_Ny = 100, 100
Q_Nx, Q_Ny = 252, 252

In [None]:
# Select which dataset's pre-extracted features to load: 'Ag1', 'Ag2' or 'Ag3'.
# Changing this one name replaces commenting/uncommenting pairs of paths.
# Remember that the NMF hyperparameters (comps / _thresh) are tuned per dataset.
dataset = 'Ag3'

fp_BP_new = f'data/{dataset}_BP.npy'  # Bragg-disk (BP) features
fp_AA_new = f'data/{dataset}_AA.npy'  # angular-average (AA) features


In [None]:
# Load both feature arrays from the .npy files chosen above (BP first, then AA).
BP_new, aa_new = (np.load(path) for path in (fp_BP_new, fp_AA_new))

## Initialize classification class

In [None]:
# Feature names, in the same order as the arrays passed to Featurization.
keys = list(('BP', 'aa'))

In [None]:
# Pair keys[i] with the i-th raw feature array (order must match `keys`).
raw_features = [BP_new, aa_new]
classification = Featurization(keys, raw_features)

In [None]:
# Scale both feature sets. Later cells read the results as 'BP_mms' and
# 'aa_mms' (presumably "<key>_mms" -- confirm in Featurization.MinMaxScaler).
# The keys are spelled out because `keys` is reassigned further down.
classification.MinMaxScaler(
    keys=['BP', 'aa'],
)

## Bragg Disks

### Here, we will perform only 5 iterations to demonstrate the use of Iterative NMF and consensus clustering based on the parameters:
1. Initial number of components = 60 (comps = 60) for Ag1, Ag2, and Ag3
2. Iterations performed = 5 (iters_ = 5) -> number of times to run model with different random seeds
3. merge threshold = 0.20 (_thresh = 0.20) for Ag1, 0.25 for Ag2, and 0.25 for Ag3

In [None]:
# Iterative-NMF hyperparameters for the scaled Bragg-disk features,
# one entry per key.
keys = ['BP_mms']
n_keys = len(keys)
comps = n_keys * [60]       # ag1 60; ag2 60; ag3 60
iters_ = n_keys * [5]
_thresh = n_keys * [0.25]   # ag1 0.20; ag2 0.25; ag3 0.25
max_components = {key: value for key, value in zip(keys, comps)}
merge_thresh = {key: value for key, value in zip(keys, _thresh)}
iters = {key: value for key, value in zip(keys, iters_)}

In [None]:
# Run iterative NMF once per random seed. return_all=True so that
# classification.W keeps one entry per iteration.
nmf_kwargs = dict(
    keys=keys,
    max_components=max_components,
    merge_thresh=merge_thresh,
    iters=iters,
    return_all=True,
)
classification.NMF_iterative(**nmf_kwargs)

In [None]:
# Sanity check: W['BP_mms'] should have one entry per iteration performed.
# Print its length, the shape of the first entry, and the container type.
bp_W = classification.W['BP_mms']
for info in (len(bp_W), bp_W[0].shape, type(bp_W)):
    print(info)

In [None]:
# Build per-iteration cluster images from the NMF results on the R_Nx x R_Ny grid.
classification.get_class_ims(
    keys,
    R_Nx=R_Nx,
    R_Ny=R_Ny,
    classification_method=['nmf'],
)

In [None]:
## This cell will show the raw individual clusters associated with each iteration
# for j in range(len(classification.W['BP_mms'])):
#     fig, ax = show_image_grid(lambda i:classification.class_ims['BP_mms'][j][i]**0.5, 4,10, returnfig = True, cmap = 'inferno')

In [None]:
# Plot raw cluster maps, no post-processing.
# One figure per NMF iteration: each cluster image is overlaid on a black
# background with its own hue, and pixels below `thresh` are masked out.
thresh = 0.01
raw_iterations = classification.class_ims['BP_mms']
for j in range(len(raw_iterations)):
    cluster_ims = raw_iterations[j]
    n_clusters = len(cluster_ims)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
    ax.axis('off')

    # Spread the clusters evenly across the NUM_COLORS-entry colour list.
    step = NUM_COLORS / n_clusters if n_clusters > 0 else 1

    for i in range(n_clusters):
        bright = cmap[np.floor(step * i).astype(int)]
        dark = tuple(channel * 0.35 for channel in bright[:3]) + (1,)
        cluster_cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, bright], N=10)
        im = cluster_ims[i]
        ax.matshow(np.ma.array(im, mask=im < thresh), cmap=cluster_cmap)
    plt.show()

In [None]:
# Post-process each raw cluster map (presumably splitting it into spatially
# separated regions; method='yen' selects the thresholding scheme), using the
# same `thresh` as the raw plots.
classification.spatial_separation(
    keys,
    threshold=thresh,
    size=25,
    method='yen',
    clean=True,
)

In [None]:
## This cell will show the spatially separated and filtered individual clusters associated with each iteration
# for j in range(len(classification.spatially_separated_ims['BP_mms'])):
#     fig, ax = show_image_grid(lambda i:classification.spatially_separated_ims['BP_mms'][j][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')
#     plt.show()

In [None]:
# Plot spatially separated and filtered cluster maps.
# One figure per iteration: each cluster is overlaid on a black background with
# its own hue, and pixels below `thresh` are masked out.
separated_iterations = classification.spatially_separated_ims['BP_mms']
for j in range(len(separated_iterations)):
    cluster_ims = separated_iterations[j]
    n_clusters = len(cluster_ims)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
    ax.axis('off')

    # Spread the clusters evenly across the NUM_COLORS-entry colour list.
    step = NUM_COLORS / n_clusters if n_clusters > 0 else 1

    for i in range(n_clusters):
        bright = cmap[np.floor(step * i).astype(int)]
        dark = tuple(channel * 0.35 for channel in bright[:3]) + (1,)
        cluster_cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, bright], N=10)
        im = cluster_ims[i]
        ax.matshow(np.ma.array(im, mask=im < thresh), cmap=cluster_cmap)
    plt.show()

In [None]:
# Match clusters across iterations (label correspondence) and combine each
# matched group with method='mean'. See Featurization.consensus for what
# drop=1 controls.
classification.consensus(
    keys=keys,
    location='spatially_separated_ims',
    method='mean',
    threshold=thresh,
    drop=1,
)

In [None]:
## This cell shows the first 8 matched clusters after performing label correspondence
# One 1 x 10 image grid per consensus bin, showing the square root of each member.
bp_bins = classification.consensus_dict['BP_mms']
consensus_bins = list(bp_bins.keys())
for j, bin_key in enumerate(consensus_bins[:8]):
    fig, ax = show_image_grid(lambda i, b=bin_key: bp_bins[b][i]**0.5,
                              1, 10, returnfig=True, cmap='inferno')
    plt.show()

In [None]:
## This cell will show the averaged consensus clusters from the bins in the cell above
# fig, ax = show_image_grid(lambda i:classification.consensus_clusters['BP_mms'][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')

In [None]:
# Overlay the averaged consensus clusters for the Bragg-disk features on a black
# background, one hue per cluster; pixels below `thresh` are masked out.
consensus_ims = classification.consensus_clusters['BP_mms']
n_clusters = len(consensus_ims)

fig, ax = plt.subplots(figsize=(6, 6))
ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
ax.axis('off')

# Spread the clusters evenly over the NUM_COLORS-entry colour list. Guard the
# empty case: dividing by zero clusters raised ZeroDivisionError instead of
# showing an empty map.
ival_1 = NUM_COLORS / n_clusters if n_clusters > 0 else 1

for i in range(n_clusters):
    iterval_1 = np.floor(ival_1 * i).astype(int)
    c1 = cmap[iterval_1]
    c0 = (c1[0] * 0.35, c1[1] * 0.35, c1[2] * 0.35, 1)  # darker shade of the same hue
    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0, c1], N=10)
    cluster = consensus_ims[i]
    ax.matshow(np.ma.array(cluster, mask=cluster < thresh), cmap=cm)
plt.show()

## Angular Average

### Here, we will perform only 5 iterations to demonstrate the use of Iterative NMF and consensus clustering based on the parameters:
1. Initial number of components = 50 (comps = 50) for Ag1, Ag2, and Ag3
2. Iterations performed = 5 (iters_ = 5) -> number of times to run model with different random seeds
3. merge threshold = 0.45 (_thresh = 0.45) for Ag1, 0.40 for Ag2, 0.40 for Ag3

In [None]:
# Iterative-NMF hyperparameters for the scaled angular-average features,
# one entry per key.
keys = ['aa_mms']
n_keys = len(keys)
comps = n_keys * [50]       # ag1 50; ag2 50; ag3 50
iters_ = n_keys * [5]
_thresh = n_keys * [0.40]   # ag1 0.45; ag2 0.40; ag3 0.40
max_components = {key: value for key, value in zip(keys, comps)}
merge_thresh = {key: value for key, value in zip(keys, _thresh)}
iters = {key: value for key, value in zip(keys, iters_)}

In [None]:
# Run iterative NMF on the angular-average features once per random seed;
# return_all=True keeps the result of every iteration.
nmf_kwargs = dict(
    keys=keys,
    max_components=max_components,
    merge_thresh=merge_thresh,
    iters=iters,
    return_all=True,
)
classification.NMF_iterative(**nmf_kwargs)

In [None]:
# Build per-iteration cluster images from the NMF results on the R_Nx x R_Ny grid.
classification.get_class_ims(
    keys,
    R_Nx=R_Nx,
    R_Ny=R_Ny,
    classification_method=['nmf'],
)

In [None]:
## This cell will show the raw individual clusters associated with each iteration
# for j in range(len(classification.W['aa_mms'])):
#     fig, ax = show_image_grid(lambda i:classification.class_ims['aa_mms'][j][i]**0.5, 4,10, returnfig = True, cmap = 'inferno')

In [None]:
# Plot raw cluster maps, no post-processing.
# One figure per NMF iteration: each cluster image is overlaid on a black
# background with its own hue, and pixels below `thresh` are masked out.
thresh = 0.01
raw_iterations = classification.class_ims['aa_mms']
for j in range(len(raw_iterations)):
    cluster_ims = raw_iterations[j]
    n_clusters = len(cluster_ims)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
    ax.axis('off')

    # Spread the clusters evenly across the NUM_COLORS-entry colour list.
    step = NUM_COLORS / n_clusters if n_clusters > 0 else 1

    for i in range(n_clusters):
        bright = cmap[np.floor(step * i).astype(int)]
        dark = tuple(channel * 0.35 for channel in bright[:3]) + (1,)
        cluster_cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, bright], N=10)
        im = cluster_ims[i]
        ax.matshow(np.ma.array(im, mask=im < thresh), cmap=cluster_cmap)
    plt.show()

In [None]:
# Post-process each raw cluster map (presumably splitting it into spatially
# separated regions; method='yen' selects the thresholding scheme), using the
# same `thresh` as the raw plots.
classification.spatial_separation(
    keys,
    threshold=thresh,
    size=25,
    method='yen',
    clean=True,
)

In [None]:
## This cell will show the spatially separated and filtered individual clusters associated with each iteration
# for j in range(len(classification.spatially_separated_ims['aa_mms'])):
#     fig, ax = show_image_grid(lambda i:classification.spatially_separated_ims['aa_mms'][j][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')
#     plt.show()

In [None]:
# Plot spatially separated and filtered cluster maps.
# One figure per iteration: each cluster is overlaid on a black background with
# its own hue, and pixels below `thresh` are masked out.
separated_iterations = classification.spatially_separated_ims['aa_mms']
for j in range(len(separated_iterations)):
    cluster_ims = separated_iterations[j]
    n_clusters = len(cluster_ims)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
    ax.axis('off')

    # Spread the clusters evenly across the NUM_COLORS-entry colour list.
    step = NUM_COLORS / n_clusters if n_clusters > 0 else 1

    for i in range(n_clusters):
        bright = cmap[np.floor(step * i).astype(int)]
        dark = tuple(channel * 0.35 for channel in bright[:3]) + (1,)
        cluster_cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, bright], N=10)
        im = cluster_ims[i]
        ax.matshow(np.ma.array(im, mask=im < thresh), cmap=cluster_cmap)
    plt.show()

In [None]:
# Match clusters across iterations (label correspondence) and combine each
# matched group with method='mean'. See Featurization.consensus for what
# drop=1 controls.
classification.consensus(
    keys=keys,
    location='spatially_separated_ims',
    method='mean',
    threshold=thresh,
    drop=1,
)

In [None]:
## This cell shows the first 8 matched clusters after performing label correspondence
# One 1 x 10 image grid per consensus bin, showing the square root of each member.
aa_bins = classification.consensus_dict['aa_mms']
consensus_bins = list(aa_bins.keys())
for j, bin_key in enumerate(consensus_bins[:8]):
    fig, ax = show_image_grid(lambda i, b=bin_key: aa_bins[b][i]**0.5,
                              1, 10, returnfig=True, cmap='inferno')
    plt.show()

In [None]:
## This cell will show the averaged consensus clusters from the bins in the cell above
# fig, ax = show_image_grid(lambda i:classification.consensus_clusters['aa_mms'][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')

In [None]:
# Overlay the averaged consensus clusters for the angular-average features on a
# black background, one hue per cluster; pixels below `thresh` are masked out.
consensus_ims = classification.consensus_clusters['aa_mms']
n_clusters = len(consensus_ims)

fig, ax = plt.subplots(figsize=(6, 6))
ax.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
ax.axis('off')

# Spread the clusters evenly over the NUM_COLORS-entry colour list. Guard the
# empty case: dividing by zero clusters raised ZeroDivisionError instead of
# showing an empty map.
ival_1 = NUM_COLORS / n_clusters if n_clusters > 0 else 1

for i in range(n_clusters):
    iterval_1 = np.floor(ival_1 * i).astype(int)
    c1 = cmap[iterval_1]
    c0 = (c1[0] * 0.35, c1[1] * 0.35, c1[2] * 0.35, 1)  # darker shade of the same hue
    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0, c1], N=10)
    cluster = consensus_ims[i]
    ax.matshow(np.ma.array(cluster, mask=cluster < thresh), cmap=cm)
plt.show()