# A Pipeline to extract relevant features for analysis of 4D-STEM datasets

## Subsections
#### 1. Bragg Peaks - Analyze Position Correlations
#### 2. Radial Integral - Identify valuable information
#### 3. Polar Score, RMSD, Annular Mean
#### 4. Radial Profile Coefficient Fits
#### 5. Symmetry Extraction

In [1]:
import py4DSTEM
from py4DSTEM.visualize import show_image_grid
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import featurization

## Import Feature Arrays

In [6]:
#dc_1.set_scan_shape(30,15)
# Hard-coded dimensions, standing in for the four axes of dc_1.data.shape.
# R_Nx / R_Ny are used further down to reshape feature columns into 2D images.
R_Nx, R_Ny = 511, 511  # dc_1.data.shape[0], dc_1.data.shape[1]
Q_Nx, Q_Ny = 256, 256  # dc_1.data.shape[2], dc_1.data.shape[3]

In [3]:
# Directory shared by all of the .npy feature files referenced below.
_data_dir = '/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/PaperDataRepo/AlexWorkflow/HMA'

path_BP = _data_dir + '/BP_centered_EC_Bin4x4_2022.02.16.npy'
path_var = _data_dir + '/ptransform_var_2022.02.16.npy'
# Four annular files, distinguished by the r1..r4 suffix in the file name.
path_aa1 = _data_dir + '/annularDF_5deg_r1_2022.02.16.npy'
path_aa2 = _data_dir + '/annularDF_5deg_r2_2022.02.16.npy'
path_aa3 = _data_dir + '/annularDF_5deg_r3_2022.02.16.npy'
path_aa4 = _data_dir + '/annularDF_5deg_r4_2022.02.16.npy'

In [4]:
# Load every feature array from its .npy file, in the same order as the
# path variables were defined.
BP, var, aa1, aa2, aa3, aa4 = [
    np.load(npy_path)
    for npy_path in (path_BP, path_var, path_aa1, path_aa2, path_aa3, path_aa4)
]

## Clean Angular Average Features

In [7]:
# Only aa1 is placed in the list, so only its columns are previewed.
aa = [aa1]
# One (R_Nx, R_Ny) image per column; the column index is the outer loop and
# the array the inner loop.
ims = [
    arr[:, col].reshape(R_Nx, R_Ny)
    for col in range(aa1.shape[1])
    for arr in aa
]

In [None]:
def _sqrt_scaled(idx):
    """Return image `idx` from `ims` raised to the power 0.5 for display."""
    return ims[idx] ** 0.5


# Show the images on a grid (grid arguments 8, 10) with the 'inferno' colormap.
fig, ax = show_image_grid(_sqrt_scaled, 8, 10, returnfig=True, cmap='inferno')

In [None]:
# Observed that positions 60 - 65 are inhibited by beamstop, remove these positions in all AA arrays
# NOTE(review): the index list below removes columns 60..64 only — confirm
# whether column 65 should be removed as well.
import copy

beamstop_positions = [64, 63, 62, 61, 60]

# np.delete returns a new array and leaves its input untouched, so the loaded
# arrays stay intact without an explicit deep copy. Passing the whole index
# list removes all listed columns (axis 1) in one call per array instead of
# re-copying each array once per index.
aa1_clean, aa2_clean, aa3_clean, aa4_clean = [
    np.delete(arr, beamstop_positions, axis=1)
    for arr in (aa1, aa2, aa3, aa4)
]

In [None]:
# Only aa1_clean is placed in the list, so only its columns are previewed.
aa_clean = [aa1_clean]
# One (R_Nx, R_Ny) image per remaining column; the column index is the outer
# loop and the array the inner loop.
ims_clean = [
    arr[:, col].reshape(R_Nx, R_Ny)
    for col in range(aa1_clean.shape[1])
    for arr in aa_clean
]

In [None]:
# Sanity check: display the shape of one of the reshaped images.
ims_clean[1].shape

In [None]:
def _sqrt_scaled_clean(idx):
    """Return image `idx` from `ims_clean` raised to the power 0.5 for display."""
    return ims_clean[idx] ** 0.5


# Show the cleaned images on a grid (grid arguments 8, 10) with 'inferno'.
fig, ax = show_image_grid(_sqrt_scaled_clean, 8, 10, returnfig=True, cmap='inferno')

## Learn

In [None]:
## Creating colormap
import matplotlib.cm as mplcm
import matplotlib.colors as colors

# Number of samples taken from the colormap.
NUM_COLORS = 200

# Sample 'gist_rainbow' at idx / NUM_COLORS for idx = 0 .. NUM_COLORS - 1.
# `cmap` is a plain list of the sampled colours, not a Colormap object.
cm = plt.get_cmap('gist_rainbow')
cmap = [cm(idx / NUM_COLORS) for idx in range(NUM_COLORS)]

#cmap = mpl.colors.ListedColormap(cmap)

In [None]:
# Names given to the feature arrays when they are passed to
# featurization.clustering.
keys = [
    'BP',
    'var',
    'aa1',
    'aa2',
    'aa3',
    'aa4',
]

In [None]:
# Feature arrays in the same order as `keys`; the angular-average entries are
# the beamstop-cleaned versions.
features = [
    BP,
    var,
    aa1_clean,
    aa2_clean,
    aa3_clean,
    aa4_clean,
]

In [None]:
# Build the clustering object from the feature names and arrays.
# NOTE(review): keys and features are presumably paired by position — confirm
# against featurization.clustering.
classification = featurization.clustering(keys, features)

In [None]:
# Combine the four angular-average features into a single feature stored
# under the key 'aa'.
# NOTE(review): the concatenation axis is chosen inside featurization — confirm.
classification.concatenate_features(keys = ['aa1','aa2','aa3', 'aa4'], output_key = 'aa')

In [None]:
# Scale the listed features. Later cells read classification.features under
# 'BP_mms', 'var_mms' and 'aa_mms', so the scaled results are presumably stored
# with an '_mms' suffix — confirm in featurization.
classification.MinMaxScaler(keys = ['BP', 'var','aa'])

In [None]:
# List and print every feature name currently held by the clustering object.
keys_all = list(classification.features.keys())
print(keys_all)

## (Radial) Variance

In [None]:
# Import the submodule explicitly: a bare `import sklearn` is not guaranteed to
# expose `sklearn.decomposition` (on releases without lazy submodule loading it
# only works if some other module already imported it). This form still binds
# the name `sklearn`, so later `sklearn.decomposition.PCA` calls keep working.
import sklearn.decomposition

# PCA object configured for 70 components; it is fitted in the next cell.
pca = sklearn.decomposition.PCA(n_components = 70)

In [None]:
# Fit the PCA on the 'var_mms' feature and keep the transformed data.
var_pca = pca.fit_transform(classification.features['var_mms'])

In [None]:
# Plot the explained-variance ratio of the first 10 principal components.
plt.plot(pca.explained_variance_ratio_[:10])
plt.show()

In [None]:
# Settings for the NMF_iterative call on the 'var_mms' feature. Each setting
# is a dict keyed by feature name; every key receives the same value.
keys = ['var_mms']
comps = [25 for _ in keys]
iters_ = [5 for _ in keys]
_thresh = [0.40 for _ in keys]
max_components = {name: value for name, value in zip(keys, comps)}
merge_thresh = {name: value for name, value in zip(keys, _thresh)}
iters = {name: value for name, value in zip(keys, iters_)}

In [None]:
# Run featurization's iterative NMF on the selected feature(s) with the
# per-feature settings dicts. Later cells read the result from
# classification.W['var_mms'].
classification.NMF_iterative(keys = keys, max_components = max_components, merge_thresh = merge_thresh ,iters = iters)

In [None]:
# Print the shape of the NMF W array for 'var_mms'; later cells reshape each
# of its columns to (R_Nx, R_Ny).
print(classification.W['var_mms'].shape)

In [None]:
# In every row of W keep only the entries equal to the row maximum and zero
# all the others.
_W_var = classification.W['var_mms']
var_mms_labels = (_W_var == _W_var.max(axis=1, keepdims=True)) * _W_var
# One (R_Nx, R_Ny) image per column of the masked array.
var_mms_ims = [
    var_mms_labels[:, comp].reshape(R_Nx, R_Ny)
    for comp in range(var_mms_labels.shape[1])
]


def _sqrt_var_component(comp):
    """Return column `comp` of the unmasked W as a (R_Nx, R_Ny) image ** 0.5."""
    return classification.W['var_mms'][:, comp].reshape(R_Nx, R_Ny) ** 0.5


# Preview the unmasked components (grid arguments 2, 5).
fig, ax = show_image_grid(_sqrt_var_component, 2, 5, returnfig=True, cmap='inferno')

In [None]:
def _sqrt_var_label(idx):
    """Return masked component image `idx` raised to the power 0.5 for display."""
    return var_mms_ims[idx] ** 0.5


# Preview the masked component images (grid arguments 2, 5).
fig, ax = show_image_grid(_sqrt_var_label, 2, 5, returnfig=True, cmap='inferno')

In [None]:
# Overlay all masked 'var_mms' component images on one axis, each drawn with
# its own two-colour ramp.
thresh = 0.001
fig, ax2 = plt.subplots(1, 1, figsize=(10, 10))
#ax1.matshow(im, cmap = 'inferno')
# All-zero image drawn first as the background layer.
ax2.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
#ax1.axis('off')
ax2.axis('off')

# Stride through the sampled colour list so the components are spread over it.
ival_1 = NUM_COLORS / len(var_mms_ims)

for comp_idx, comp_im in enumerate(var_mms_ims):
    base = cmap[int(np.floor(ival_1 * comp_idx))]
    # Ramp from the base colour darkened to 35 % (alpha 1) up to the base colour.
    dark = (base[0] * 0.35, base[1] * 0.35, base[2] * 0.35, 1)
    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, base], N=10)
    # Pixels below the threshold are masked out of this layer.
    ax2.matshow(np.ma.array(comp_im, mask=comp_im < thresh), cmap=cm)

In [None]:
#fig.savefig('sim_bpWrxry_125comp_25iter_0.23thresh_thresh0.001.svg', format='svg', dpi=1200)

In [None]:
# fp_var_W = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/Ag1p0_1t100_var_mms_nmfW_25comp_25iter_0.45thresh.npy'
# fp_var_H = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/Ag1p0_1t100_var_mms_nmfH_25comp_25iter_0.45thresh.npy'
# np.save(fp_var_W, classification.W['var_mms'])
# np.save(fp_var_H, classification.H['var_mms'])

# 25comp_25iter_0.45thresh_19.37.02min.sec - sigma2

## 25comp_25iter_0.45thresh_18.39.06min.sec

## BP

In [None]:
# Import the submodule explicitly so this cell does not depend on another
# cell or module having loaded `sklearn.decomposition` beforehand.
import sklearn.decomposition

# PCA with 100 components, fitted on the 'BP_mms' feature; bp_pca holds the
# transformed data.
pca = sklearn.decomposition.PCA(n_components = 100)
bp_pca = pca.fit_transform(classification.features['BP_mms'])

In [None]:
# Plot the explained-variance ratio of the first 30 principal components.
plt.plot(pca.explained_variance_ratio_[:30])
plt.show()

In [None]:
# Settings for the NMF_iterative call on the 'BP_mms' feature. Each setting
# is a dict keyed by feature name; every key receives the same value.
keys = ['BP_mms']
comps = [40 for _ in keys]
iters_ = [5 for _ in keys]
_thresh = [0.10 for _ in keys]
max_components = {name: value for name, value in zip(keys, comps)}
merge_thresh = {name: value for name, value in zip(keys, _thresh)}
iters = {name: value for name, value in zip(keys, iters_)}

In [None]:
# Run featurization's iterative NMF on the selected feature(s) with the
# per-feature settings dicts. Later cells read the result from
# classification.W['BP_mms'].
classification.NMF_iterative(keys = keys, max_components = max_components, merge_thresh = merge_thresh ,iters = iters)

In [None]:
# Print the shape of the NMF W array for 'BP_mms'; later cells reshape each
# of its columns to (R_Nx, R_Ny).
print(classification.W['BP_mms'].shape)

In [None]:
# In every row of W keep only the entries equal to the row maximum and zero
# all the others.
_W_bp = classification.W['BP_mms']
BP_mms_labels = (_W_bp == _W_bp.max(axis=1, keepdims=True)) * _W_bp
# One (R_Nx, R_Ny) image per column of the masked array.
BP_mms_ims = [
    BP_mms_labels[:, comp].reshape(R_Nx, R_Ny)
    for comp in range(BP_mms_labels.shape[1])
]


def _sqrt_bp_component(comp):
    """Return column `comp` of the unmasked W as a (R_Nx, R_Ny) image ** 0.5."""
    return classification.W['BP_mms'][:, comp].reshape(R_Nx, R_Ny) ** 0.5


# Preview the unmasked components (grid arguments 4, 5).
fig, ax = show_image_grid(_sqrt_bp_component, 4, 5, returnfig=True, cmap='inferno')

In [None]:
def _sqrt_bp_label(idx):
    """Return masked component image `idx` raised to the power 0.5 for display."""
    return BP_mms_ims[idx] ** 0.5


# Preview the masked component images (grid arguments 4, 5).
fig, ax = show_image_grid(_sqrt_bp_label, 4, 5, returnfig=True, cmap='inferno')

In [None]:
# Overlay all masked 'BP_mms' component images on one axis, each drawn with
# its own two-colour ramp.
thresh = 0.01
fig, ax2 = plt.subplots(1, 1, figsize=(10, 10))
#ax1.matshow(im, cmap = 'inferno')
# All-zero image drawn first as the background layer.
ax2.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
#ax1.axis('off')
ax2.axis('off')

# Stride through the sampled colour list so the components are spread over it.
ival_1 = NUM_COLORS / len(BP_mms_ims)

for comp_idx, comp_im in enumerate(BP_mms_ims):
    base = cmap[int(np.floor(ival_1 * comp_idx))]
    # Ramp from the base colour darkened to 35 % (alpha 1) up to the base colour.
    dark = (base[0] * 0.35, base[1] * 0.35, base[2] * 0.35, 1)
    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, base], N=10)
    # Pixels below the threshold are masked out of this layer.
    ax2.matshow(np.ma.array(comp_im, mask=comp_im < thresh), cmap=cm)

In [None]:
#fig.savefig('sim_bpWrxry_125comp_25iter_0.23thresh_thresh0.001.svg', format='svg', dpi=1200)

In [None]:
# fp_bp_W = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/Ag1p0_1t100_bp_mms_nmfW_60comp_25iter_0.25thresh.npy'
# fp_bp_H = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/Ag1p0_1t100_bp_mms_nmfH_60comp_25iter_0.25thresh.npy'
# np.save(fp_bp_W, classification.W['BP_mms'])
# np.save(fp_bp_H, classification.H['BP_mms'])

#60comp_25iter_0.25thresh_88.18.03min.sec - sigma2

## Angular Averaging

In [None]:
# Import the submodule explicitly so this cell does not depend on another
# cell or module having loaded `sklearn.decomposition` beforehand.
import sklearn.decomposition

# PCA with 100 components, fitted on the 'aa_mms' feature; aa_pca holds the
# transformed data.
pca = sklearn.decomposition.PCA(n_components = 100)
aa_pca = pca.fit_transform(classification.features['aa_mms'])

In [None]:
# Plot the explained-variance ratio of the first 25 principal components.
plt.plot(pca.explained_variance_ratio_[:25])
plt.show()

In [None]:
# Settings for the NMF_iterative call on the 'aa_mms' feature. Each setting
# is a dict keyed by feature name; every key receives the same value.
keys = ['aa_mms']
comps = [25 for _ in keys]
threshs = [0.35 for _ in keys]
iters_ = [10 for _ in keys]
max_components = {name: value for name, value in zip(keys, comps)}
merge_thresh = {name: value for name, value in zip(keys, threshs)}
iters = {name: value for name, value in zip(keys, iters_)}

In [None]:
# Run featurization's iterative NMF on the selected feature(s) with the
# per-feature settings dicts. Later cells read the result from
# classification.W['aa_mms'].
classification.NMF_iterative(keys = keys, max_components = max_components, merge_thresh = merge_thresh, iters = iters)

In [None]:
# Display the shape of the NMF W array for 'aa_mms'; later cells reshape each
# of its columns to (R_Nx, R_Ny).
classification.W['aa_mms'].shape

In [None]:
# In every row of W keep only the entries equal to the row maximum and zero
# all the others.
_W_aa = classification.W['aa_mms']
aa_mms_labels = (_W_aa == _W_aa.max(axis=1, keepdims=True)) * _W_aa
# One (R_Nx, R_Ny) image per column of the masked array.
aa_mms_ims = [
    aa_mms_labels[:, comp].reshape(R_Nx, R_Ny)
    for comp in range(aa_mms_labels.shape[1])
]

In [None]:
def _sqrt_aa_component(comp):
    """Return column `comp` of the unmasked W as a (R_Nx, R_Ny) image ** 0.5."""
    return classification.W['aa_mms'][:, comp].reshape(R_Nx, R_Ny) ** 0.5


# Preview the unmasked components (grid arguments 4, 5).
fig, ax = show_image_grid(_sqrt_aa_component, 4, 5, returnfig=True, cmap='inferno')

In [None]:
def _sqrt_aa_label(idx):
    """Return masked component image `idx` raised to the power 0.5 for display."""
    return aa_mms_ims[idx] ** 0.5


# Preview the masked component images (grid arguments 4, 5).
fig, ax = show_image_grid(_sqrt_aa_label, 4, 5, returnfig=True, cmap='inferno')

In [None]:
# Overlay all masked 'aa_mms' component images on one axis, each drawn with
# its own two-colour ramp.
thresh = 0.01
fig, ax2 = plt.subplots(1, 1, figsize=(10, 10))
#ax1.matshow(im, cmap = 'inferno')
# All-zero image drawn first as the background layer.
ax2.matshow(np.zeros((R_Nx, R_Ny)), cmap='gray')
#ax1.axis('off')
ax2.axis('off')

# Stride through the sampled colour list so the components are spread over it.
ival_1 = NUM_COLORS / len(aa_mms_ims)

for comp_idx, comp_im in enumerate(aa_mms_ims):
    base = cmap[int(np.floor(ival_1 * comp_idx))]
    # Ramp from the base colour darkened to 35 % (alpha 1) up to the base colour.
    dark = (base[0] * 0.35, base[1] * 0.35, base[2] * 0.35, 1)
    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [dark, base], N=10)
    # Pixels below the threshold are masked out of this layer.
    ax2.matshow(np.ma.array(comp_im, mask=comp_im < thresh), cmap=cm)

In [None]:
# fp_aa_W = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/sim_aa_mms_nmfW_50comp_25iter_0.40thresh.npy'
# fp_aa_H = '/Volumes/LaCie/4DSTEM/AuFilm_sim/Ag100/size0t100/sim_aa_mms_nmfH_50comp_25iter_0.40thresh.npy'
# np.save(fp_aa_W, classification.W['aa_mms'])
# np.save(fp_aa_H, classification.H['aa_mms'])

#50comp_25iter_0.40thresh_64.45.02min.sec - sigma2

#100comp25iter0.45thresh1246.38.07min.sec