## Machine learning session 2: Decomposition (unfilled version)
**4D-STEM data analysis workshop**
**NTNU, Trondheim, June 11, 2024** - by Tina Bergh


In [None]:
%matplotlib qt

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import hyperspy.api as hs
import pyxem as pxm

In [None]:
folder = '.\\'
file = 'SPED_Ag'
file_ending = '.hspy'

# Integrate the signal

## Decompose the integrated signal

Perform singular value decomposition (SVD), which factorises the data matrix into two new matrices. If the first argument (normalize_poissonian_noise) is True, the data will be normalised for Poissonian noise prior to the decomposition.

For more information, read the documentation here: 
- http://hyperspy.org/hyperspy-doc/current/user_guide/mva/decomposition.html
- https://scikit-learn.org/stable/modules/decomposition.html
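
As a rough illustration of what the SVD does, the sketch below factorises a small synthetic data matrix with plain NumPy and rebuilds it from the two strongest components. The toy data, the variable names and the choice of two components are assumptions made for this example; it is not how HyperSpy runs the decomposition internally.

```python
import numpy as np

# Hypothetical toy data: 6 navigation positions x 50 signal channels,
# built from two underlying profiles plus a little noise.
rng = np.random.default_rng(0)
channels = np.linspace(0, 1, 50)
profile_a = np.exp(-((channels - 0.3) / 0.05) ** 2)
profile_b = np.exp(-((channels - 0.7) / 0.05) ** 2)
weights = rng.random((6, 2))
data = weights @ np.vstack([profile_a, profile_b])
data += 0.01 * rng.standard_normal(data.shape)

# SVD factorises the data matrix as U @ diag(S) @ Vt.
U, S, Vt = np.linalg.svd(data, full_matrices=False)

# Keep only the two strongest components.
n_components = 2
loadings = U[:, :n_components] * S[:n_components]  # one row per position
factors = Vt[:n_components]                        # one row per component

# Low-rank reconstruction from the retained components.
reconstruction = loadings @ factors
relative_error = np.linalg.norm(data - reconstruction) / np.linalg.norm(data)
print(S.round(3))
print(relative_error)
```

The singular values in S drop sharply after the second one, which is the kind of behaviour a scree plot is used to spot.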

Define a function that will create markers based on the reconstructed dataset. We also need the calibration and offset value for the integrated dataset (these are not the same as for the original dataset with 2D patterns).

In [None]:
scale = s_rad.axes_manager[2].scale
offset = s_rad.axes_manager[2].offset
s_rad.axes_manager

In [None]:
def get_marker_offset_signal(sig):
    '''
    Get the signal positions needed to define markers based on the signal.
    
    Parameters
    ----------
    sig: numpy.ndarray
        The signal in one navigation position, as signal.data

    Returns
    ----------
    xys: numpy.ndarray
        The signal positions as [[xi,yi]] in calibrated x-units.
    
    '''
    xys = np.zeros((len(sig), 2))
    for i, sd in zip(range(len(sig)), sig):
        xys[i] = [i*scale+offset, sd]
    return np.array(xys)
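
The loop above pairs each channel index, converted to calibrated units, with the intensity in that channel. A vectorised sketch of the same idea is shown below. The function name marker_positions is hypothetical, and scale and offset are passed as arguments here instead of being read from the global variables.

```python
import numpy as np

def marker_positions(sig, scale, offset):
    """Return [[x_i, y_i]] with x_i = i*scale + offset and y_i = sig[i]."""
    x = np.arange(len(sig)) * scale + offset
    return np.column_stack((x, sig))

# Small made-up profile with three channels.
example = marker_positions(np.array([5.0, 7.0, 6.0]), scale=0.5, offset=1.0)
print(example)
```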

Iterate over the reconstructed dataset using the map function to extract marker positions for each navigation position. Create markers based on these offset positions. Then plot the integrated dataset with the reconstructed dataset as markers. Inspect how well the reconstructed dataset corresponds to the original dataset.

In [None]:
offsets = s_rad_decomp.map(get_marker_offset_signal, inplace=False, ragged=True)
mark = hs.plot.markers.Points(offsets.data.T, color='blue') 
s_rad.plot()
s_rad.add_marker(mark)

## Cluster based on the decomposition results

The decomposition results can further be used for clustering, i.e. finding groups of similar datapoints. The default clustering method is k-means clustering. We have to specify how many components in the decomposition results to use, and also the number of clusters that we want in the results. 
We can also use cluster_source='decomposition' if we skip the above BSS step. This will also give good results.

For more information: 
- http://hyperspy.org/hyperspy-doc/current/user_guide/mva/clustering.html
- https://scikit-learn.org/stable/modules/clustering.html
- https://scikit-learn.org/stable/modules/clustering.html#k-means

We want to separate the nanocrystalline region from the silver region and therefore specify two components.
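
To give a feeling for what k-means does, here is a minimal NumPy sketch that alternates between assigning points to the nearest centre and moving each centre to the mean of its points. The synthetic two-dimensional points stand in for the loadings of two decomposition components; the function kmeans and all of its parameters are assumptions made for this example, and the actual clustering in HyperSpy is delegated to scikit-learn.

```python
import numpy as np

def kmeans(points, n_clusters, n_iter=50, seed=0):
    """Plain k-means: returns (labels, centres)."""
    rng = np.random.default_rng(seed)
    # Start from randomly chosen data points.
    centres = points[rng.choice(len(points), n_clusters, replace=False)]
    for _ in range(n_iter):
        # Assign every point to its nearest centre.
        distances = np.linalg.norm(points[:, None, :] - centres[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Move each centre to the mean of its assigned points.
        new_centres = np.array([
            points[labels == k].mean(axis=0) if np.any(labels == k) else centres[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centres, centres):
            break
        centres = new_centres
    return labels, centres

# Two well-separated synthetic groups of 40 points each.
rng = np.random.default_rng(1)
group_a = rng.normal(loc=[0.0, 0.0], scale=0.1, size=(40, 2))
group_b = rng.normal(loc=[3.0, 3.0], scale=0.1, size=(40, 2))
points = np.vstack([group_a, group_b])

labels, centres = kmeans(points, n_clusters=2)
print(labels)
```

Which group ends up with label 0 or 1 depends on the random initialisation, so it is worth checking the label plot before building a mask from a particular index.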

In [None]:
s_rad.plot_cluster_labels(axes_decor="off");
s_rad.plot_cluster_signals();

Extract the labels and use them to create a navigation mask for the nanocrystalline region.

In [None]:
labels = s_rad.get_cluster_labels()
nav_mask = labels.inav[1].T
nav_mask.change_dtype('bool')
nav_mask.plot()

# Decompose the dataset

Now we will go back to working with the original dataset containing 2D patterns. 

From experience, this does not run if you are using the bundle with the full dataset. So, if you use the bundle, bin the dataset further to reduce the data size before running the decomposition.

You also want to close all other processes on your PC to free as much RAM as possible!

In [None]:
# s = s.rebin(scale=(2,2,2,2))
# nav_mask = nav_mask.rebin(scale=(2,2))
# nav_mask.change_dtype('bool')
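
To see why binning helps so much with memory, the NumPy sketch below sums non-overlapping blocks of 2 pixels along every axis of a small 4D array. The helper rebin_sum and the array shape are assumptions for this illustration; it mimics the effect of rebinning with a scale of 2 on all four axes, not the HyperSpy implementation.

```python
import numpy as np

def rebin_sum(data, factor):
    """Sum non-overlapping blocks of `factor` pixels along every axis."""
    new_shape = []
    for n in data.shape:
        if n % factor:
            raise ValueError("each axis length must be divisible by factor")
        new_shape += [n // factor, factor]
    blocks = data.reshape(new_shape)
    # After the reshape, the block axes are the odd-numbered ones.
    return blocks.sum(axis=tuple(range(1, 2 * data.ndim, 2)))

# Made-up 4D dataset: (y, x, ky, kx).
data4d = np.ones((4, 6, 8, 8), dtype=np.uint16)
binned = rebin_sum(data4d, 2)
print(data4d.shape, "->", binned.shape)
print("size reduction factor:", data4d.size // binned.size)
```

Binning by 2 along all four axes reduces the number of elements by a factor of 16 while the total intensity is preserved.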

First, look at the maximum of the signal to inspect which regions we are interested in. 

We will create a signal mask that excludes the direct beam region, since this typically gives many components in the decomposition later. We will also exclude the region at higher scattering angles, since we have enough information contained within a smaller region. In addition, the regions at higher angles show some deviations from zone axis that would give components that we are not interested in here.

In [None]:
direct_beam_mask0 = s.get_direct_beam_mask(radius=10.) # 5 if bundle
direct_beam_mask1 = s.get_direct_beam_mask(radius=35.) # 17 if bundle
direct_beam_mask1.data = ~direct_beam_mask1.data
signal_mask = direct_beam_mask0+direct_beam_mask1
signal_mask.plot()
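
The combined mask is True both inside the small radius and outside the large radius, so only an annulus between the two is kept. A NumPy sketch of the same geometry follows. The function annular_signal_mask is hypothetical, and it assumes that the direct beam sits at the geometric centre of the pattern and that True marks pixels to exclude.

```python
import numpy as np

def annular_signal_mask(shape, inner_radius, outer_radius):
    """True where pixels are excluded: inside inner or outside outer radius."""
    ny, nx = shape
    cy, cx = (ny - 1) / 2, (nx - 1) / 2  # assumed pattern centre
    y, x = np.ogrid[:ny, :nx]
    r = np.hypot(y - cy, x - cx)
    return (r < inner_radius) | (r > outer_radius)

mask = annular_signal_mask((128, 128), inner_radius=10, outer_radius=35)
print(mask.sum(), "of", mask.size, "pixels are masked out")
```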

### Sum Friedel pairs

The next step is optional. It is possible to perform the decomposition on the original dataset. Since we only have zone axis patterns that obey Friedel's law, we can sum the left and right hand sides of the patterns to get a smaller dataset prior to the decomposition. This will speed things up and possibly give fewer components associated with slight deviations from zone axis, so it can help to give decomposition results that are easier to interpret.

We can then rotate the left side by 180 degrees and check that the right and left sides now look almost identical.

In [None]:
s_left_rot = s_left.rotate_diffraction(180)
hs.plot.plot_signals([s_left_rot, s_right], norm='log')

We can then sum the two sides and inspect the sum. The signal-to-noise ratio is better, and the data size is smaller!

In [None]:
s_rot = s_right + s_left_rot
s_rot.plot(norm='log')
s_rot
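
The idea behind summing Friedel pairs can be checked on a synthetic pattern. The sketch below builds an ideal centrosymmetric array, splits it into halves, rotates the left half by 180 degrees and adds it to the right half. It assumes an even pattern width and a pattern centre lying exactly between the two halves; real data will only match approximately.

```python
import numpy as np

rng = np.random.default_rng(0)
half_random = rng.random((8, 8))
# Adding the 180-degree rotated copy makes the pattern centrosymmetric.
pattern = half_random + half_random[::-1, ::-1]

left = pattern[:, :4]    # left half of the pattern
right = pattern[:, 4:]   # right half of the pattern

# Rotate the left half by 180 degrees so it overlays the right half.
left_rot = np.rot90(left, 2)
summed = right + left_rot

print(np.allclose(left_rot, right))
print(pattern.shape, "->", summed.shape)
```

For this ideal pattern the rotated left half equals the right half exactly, and the summed array has half as many pixels as the original.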

We extract only the right side of the signal mask. 

In [None]:
# Use integer division: a float inside isig is interpreted as an axis value, not a pixel index
signal_mask_rot = signal_mask.isig[signal_mask.data.shape[0]//2:, :]
signal_mask_rot.plot()

## Decompose the summed signal

The components from SVD contain both negative and positive values, which makes them hard to interpret. We can instead use a matrix decomposition method called non-negative matrix factorisation (NMF), where we use the constraint that all values must be non-negative. This is consistent with the intuitive notion that you add parts together to form a whole, and it typically gives components that resemble the original signal (diffraction patterns and virtual images) to a much higher degree. It is an iterative method, so we have to supply an initial guess. The default is to base the initial guess on an SVD, which from experience gives the best results.

For more information: 
- http://hyperspy.org/hyperspy-doc/current/user_guide/mva/decomposition.html#non-negative-matrix-factorization-nmf
- https://scikit-learn.org/stable/modules/decomposition.html#non-negative-matrix-factorization-nmf-or-nnmf
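
To illustrate the non-negativity constraint and the iterative nature of NMF, here is a minimal NumPy sketch using the classic multiplicative update rules for the Frobenius norm. It is a simplified stand-in: the function nmf is hypothetical, it starts from a random initial guess instead of the SVD-based one mentioned above, and it is not the solver that scikit-learn or HyperSpy use.

```python
import numpy as np

def nmf(X, n_components, n_iter=200, seed=0, eps=1e-9):
    """Factorise non-negative X as W @ H with W >= 0 and H >= 0."""
    rng = np.random.default_rng(seed)
    n_rows, n_cols = X.shape
    # Random non-negative initial guess (an assumption of this sketch).
    W = rng.random((n_rows, n_components)) + eps
    H = rng.random((n_components, n_cols)) + eps
    errors = []
    for _ in range(n_iter):
        # Multiplicative updates keep every entry non-negative.
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
        errors.append(np.linalg.norm(X - W @ H))
    return W, H, errors

# Synthetic non-negative data with three underlying components.
rng = np.random.default_rng(1)
X = rng.random((30, 3)) @ rng.random((3, 40))

W, H, errors = nmf(X, n_components=3)
print("first error:", errors[0])
print("last error: ", errors[-1])
```

Because the result depends on the starting point, different initial guesses can give different components, which is one reason the choice of initialisation matters.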

Extract the factors and loadings and plot them with the same navigator. The loadings resemble virtual images of the silver grains, while the factors resemble the average silver pattern within those grains. 

In [None]:
factors = s_rot.get_decomposition_factors()
loadings = s_rot.get_decomposition_loadings()
hs.plot.plot_signals([factors, loadings], cmap='magma')

## Cluster based on NMF

We cluster the NMF results to get binary masks covering each of the silver grains. We specify five clusters corresponding to the five unique silver grains. We use the same signal and navigation masks as before.

Plot the clustering results. First, the labels.

In [None]:
labels = s_rot.get_cluster_labels()
hs.plot.plot_images([labels.inav[i] for i in range(5)], overlay=True, axes_decor='off', label=None);

Then plot the mean signal within each cluster.

In [None]:
cluster_sigs = s_rot.get_cluster_signals()
cluster_sigs.data = np.nan_to_num(cluster_sigs.data)
hs.plot.plot_images([cluster_sigs.inav[i] for i in range(5)], overlay=True, axes_decor='off', label=None);

Save the labels.

In [None]:
labels.save(folder+file+'_labels.hspy', overwrite=True)

Sum the original signal within the labels. We will do this using the map function and therefore define our own function for summing.

In [None]:
def sum_signal_in_label(label):
    return np.sum(s.data[label], axis=(0))
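
The function above relies on NumPy boolean indexing: indexing a 4D array with a 2D boolean mask over the navigation axes selects the patterns at the True positions and stacks them along a new first axis. A small sketch with made-up numbers is given below; the array shapes are assumptions for the example.

```python
import numpy as np

# Made-up 4D dataset: 2 x 3 navigation positions, 2 x 2 pixel patterns.
data = np.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)

# Boolean label selecting two navigation positions.
label = np.array([[True, False, False],
                  [False, True, False]])

selected = data[label]          # shape (2, 2, 2): one pattern per True position
summed = selected.sum(axis=0)   # sum of the selected patterns
print(selected.shape)
print(summed)
```

Note that this only works as intended when the label array has boolean dtype; an integer array of zeros and ones would be treated as indices instead.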

In [None]:
labels_sum_signal = labels.map(sum_signal_in_label, inplace=False)
labels_sum_signal = pxm.signals.ElectronDiffraction2D(labels_sum_signal)
labels_sum_signal

We set the diffraction calibration based on the original dataset.

In [None]:
labels_sum_signal.set_diffraction_calibration(s.axes_manager[3].scale)

In [None]:
labels_sum_signal.axes_manager

In [None]:
labels_sum_signal.plot()

Save the summed signal within each cluster. We will use these as a starting point for the next session!

In [None]:
labels_sum_signal.save(folder+file+'_labels_sum_signal.hspy', overwrite=True)

## Perform NMF with more components
If we inspect the dataset, we see that we have many more spots than those picked up by the NMF using only 6 components. In the overlap regions, we get double diffraction giving these extra spots. Where did they go? If a signal in your dataset does not account for a large portion of the variance in the whole dataset, this signal can go unnoticed in the decomposition. It is not trivial to find the best number of components for your dataset, and great care must be taken!
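
The point about variance can be made concrete with a toy example. In the sketch below, a strong profile is present at every position while a weak profile appears in only 8 of 400 positions, and the fraction of the total variance carried by each SVD component is computed from the singular values. All numbers are made up for illustration, and the fraction is computed on uncentred data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_positions, n_channels = 400, 64
strong = rng.random(n_channels)
weak = rng.random(n_channels)

# Strong profile everywhere, weak profile in only 2 % of the positions.
data = np.outer(rng.random(n_positions) + 1.0, strong)
data[:8] += 0.5 * weak
data += 0.01 * rng.standard_normal(data.shape)

# Fraction of the total (uncentred) variance carried by each component.
S = np.linalg.svd(data, compute_uv=False)
explained = S**2 / np.sum(S**2)
print(explained[:4])
```

The second component is clearly above the noise level, yet it carries only a tiny fraction of the total variance, so it is easily lost if too few components are kept.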

In [None]:
from time import time  # needed for the timing below

t0 = time()
s_rot.decomposition(True, 'NMF', 20, navigation_mask=nav_mask, signal_mask=signal_mask_rot)
tf = time()
print('NMF done in ' + str((tf-t0)) + ' s')

Plot the results and see if you can find the double diffraction spots now...

In [None]:
factors_rot_NMF20 = s_rot.get_decomposition_factors()
loadings_rot_NMF20 = s_rot.get_decomposition_loadings()
hs.plot.plot_signals([factors_rot_NMF20, loadings_rot_NMF20], cmap='viridis')