# Imports

In [1]:
%matplotlib qt
import matplotlib.pyplot as plt
import numpy as np
from snmf import SNMF
import utils_V2 as u2
# hyperspy is the commonly used library to handle spectrum images in the electron microscopy community
import hyperspy.api as hs

# Load the data

In [2]:
# Common path prefix of the spectrum image and of its ground-truth files
filename = "Data/aspim036_N150_2ptcls_brstlg"

# Read the HyperSpy signal and keep its raw array for the decomposition
S = hs.load(filename + ".hspy")
X = S.data

# Sub-region of the spectrum image (from index 60 onward along the first
# navigation index) that contains only pure spectra from phase 0.
# Such an area is often available in experimental datasets.
X_part = S.inav[60:, :].data

# Performance assessment functions

These functions are used to compare the endmembers determined by SNMF and the ground truth

In [3]:
# This function will find the best matching endmember for each true spectrum. This is useful since the A and P matrices are initialized at random.

# This function works but can probably be greatly improved
def find_min_angle(list_true_vectors, list_algo_vectors):
    """Greedily pair each true spectrum with its closest endmember.

    The true spectra are processed in order. For each one, the spectral angle
    to every endmember is computed with ``u2.MetricsUtils.spectral_angle`` and
    the endmember with the smallest angle among those not paired yet is
    selected.

    Parameters
    ----------
    list_true_vectors : sequence
        Ground-truth spectra.
    list_algo_vectors : sequence
        Endmembers found by the algorithm. The sequence is only read; it is
        never modified.

    Returns
    -------
    list
        ``ordered_angles[i]`` is the angle between ``list_true_vectors[i]``
        and the endmember paired with it.

    Notes
    -----
    If there are more true spectra than endmembers, the surplus true spectra
    fall back to their unrestricted best match once every endmember has been
    paired. ``numpy.argmin`` raises ``ValueError`` when ``list_algo_vectors``
    is empty and ``list_true_vectors`` is not.
    """
    ordered_angles = []
    # taken[j] is True once endmember j has been paired with a true spectrum
    taken = np.zeros(len(list_algo_vectors), dtype=bool)
    for true_vector in list_true_vectors:
        raw_angles = [
            u2.MetricsUtils.spectral_angle(true_vector, algo_vector)
            for algo_vector in list_algo_vectors
        ]
        scores = np.asarray(raw_angles, dtype=float)
        # Mask by index rather than overwriting entries with a huge
        # placeholder vector: this leaves the caller's list intact and does
        # not rely on the metric growing with the magnitude of the vector.
        if not taken.all():
            scores = np.where(taken, np.inf, scores)
        ind_min = int(np.argmin(scores))
        taken[ind_min] = True
        # Reuse the value already computed instead of calling the metric again
        ordered_angles.append(raw_angles[ind_min])
    return ordered_angles

# This function works but can probably be greatly improved
def find_min_MSE(list_true_maps, list_algo_maps):
    """Greedily pair each true map with its closest abundance map.

    The true maps are processed in order. For each one, the MSE to every
    abundance map is computed with ``u2.MetricsUtils.MSE_map`` and the
    abundance map with the smallest MSE among those not paired yet is
    selected.

    Parameters
    ----------
    list_true_maps : sequence
        Ground-truth abundance maps.
    list_algo_maps : sequence
        Abundance maps found by the algorithm. The sequence is only read; it
        is never modified.

    Returns
    -------
    list
        ``ordered_maps[i]`` is the MSE between ``list_true_maps[i]`` and the
        abundance map paired with it.

    Notes
    -----
    If there are more true maps than abundance maps, the surplus true maps
    fall back to their unrestricted best match once every abundance map has
    been paired. ``numpy.argmin`` raises ``ValueError`` when
    ``list_algo_maps`` is empty and ``list_true_maps`` is not.
    """
    ordered_maps = []
    # taken[j] is True once abundance map j has been paired with a true map
    taken = np.zeros(len(list_algo_maps), dtype=bool)
    for true_map in list_true_maps:
        raw_mse = [
            u2.MetricsUtils.MSE_map(true_map, algo_map)
            for algo_map in list_algo_maps
        ]
        scores = np.asarray(raw_mse, dtype=float)
        # Mask by index rather than overwriting entries with a huge
        # placeholder array, so the caller's list is left intact.
        if not taken.all():
            scores = np.where(taken, np.inf, scores)
        ind_min = int(np.argmin(scores))
        taken[ind_min] = True
        # Reuse the value already computed instead of calling the metric again
        ordered_maps.append(raw_mse[ind_min])
    return ordered_maps

# This function gives the residuals between the model determined by snmf and the data that were fitted
def residuals(data, model, num_phases=3):
    """Return the residual spectrum between the data and the fitted model.

    Parameters
    ----------
    data : array
        Fitted data; it is summed over its first two axes.
    model : object
        Fitted model exposing ``get_phase_map(i)`` and
        ``get_phase_spectrum(i)``.
    num_phases : int, optional
        Number of phases to accumulate, indexed ``0 .. num_phases - 1``.
        Defaults to 3.

    Returns
    -------
    array
        ``data`` summed over its first two axes, minus the sum over phases of
        ``get_phase_map(i).sum() * get_phase_spectrum(i)``.
    """
    X_sum = data.sum(axis=0).sum(axis=0)
    # Each phase contributes its spectrum weighted by its total abundance
    model_sum = sum(
        model.get_phase_map(k).sum() * model.get_phase_spectrum(k)
        for k in range(num_phases)
    )
    return X_sum - model_sum

# Parameters 

In [4]:
# Ground-truth bremsstrahlung coefficients
c0 = 3.943136127751902
c1 = 3.9446849862408535
c2 = 0.027663073842682524
b0 = 0.1414560446115408
b1 = -0.1057210517202927
b2 = 0.026461615841445782

# Settings handed to the SNMF constructor
brstlg_pars = [b1, b2, c0, c1, c2]
tol = 1e-4
max_iter = 50000
b_tol = 0.1
mu_sparse = 0.0
eps_sparse = 1.0
phases = 3

# Ground truth: one reference spectrum and one abundance map per phase
true_spectra = [
    np.genfromtxt(filename + suffix)
    for suffix in ("spectrum_p0", "spectrum_p1", "spectrum_p2")
]
true_maps = [
    np.load(filename + suffix)
    for suffix in ("map_p0.npy", "map_p1.npy", "map_p2.npy")
]

# A precomputed bremsstrahlung can be supplied to bypass the b_matr
# optimization when required
x_scale = u2.Gaussians().x
brstlg = u2.Distributions.simplified_brstlg_2(x_scale, b0, b1, b2, c0, c1, c2)

# When mu_sparse != 0 the first phase needs a good initialization; the mean
# spectrum of the pure phase-0 region can be used for that
init_matrix = np.average(X_part, axis=(0, 1))

# SNMF

In [5]:
# Build the SNMF model from the settings defined in the parameters cell
mdl = SNMF(
    max_iter=max_iter,
    tol=tol,
    b_tol=b_tol,
    mu_sparse=mu_sparse,
    eps_sparse=eps_sparse,
    num_phases=phases,
    bremsstrahlung=None,
    brstlg_pars=brstlg_pars,
    init_spectrum=None,
)

In [None]:
# Fit the SNMF model on X, the raw array of the loaded spectrum image
mdl.fit(X)

# Results

In [None]:
# Angles between each ground-truth spectrum and its matched SNMF endmember
angles = find_min_angle(
    true_spectra,
    [mdl.get_phase_spectrum(k) for k in range(3)],
)

# MSE between each ground-truth map and its matched SNMF abundance map
maps = find_min_MSE(
    true_maps,
    [mdl.get_phase_map(k) for k in range(3)],
)

angle_labels = ("Angle phase 0 :", "Angle phase 1 :", "Angle phase 2 :")
for label, value in zip(angle_labels, angles):
    print(label, value)

mse_labels = ("MSE phase 0 :", "MSE phase 1 :", "MSE phase 2 :")
for label, value in zip(mse_labels, maps):
    print(label, value)

### Visualisation of the results

In [None]:
# `switch` is the index of the SNMF endmember to display.
# `true` is the index of the ground-truth spectrum it is compared with.
# The two should be changed independently until a match is found.
switch = 2
true = 2

plt.rcParams.update({'font.size': 22})
fig1 = plt.figure(figsize=(20, 12))

# Left panel: ground-truth and reconstructed spectra, each divided by its
# own maximum so that they can be compared on the same scale
plt.subplot(121)
truth_spectrum = true_spectra[true]
algo_spectrum = mdl.get_phase_spectrum(switch)
plt.plot(x_scale, truth_spectrum / np.max(truth_spectrum), 'bo', label='truth', linewidth=4)
plt.plot(x_scale, algo_spectrum / np.max(algo_spectrum), 'r-', label='reconstructed', markersize=3.5)
plt.legend(loc='best')
plt.xlim(0, 10)
plt.ylabel("Intensity")

# Right panel: abundance map of the selected endmember
plt.subplot(122)
plt.imshow(mdl.get_phase_map(switch), cmap="viridis")
# Pass the on/off flag positionally: the `b` keyword was renamed `visible`
# in Matplotlib 3.5, whereas the positional form works with both spellings.
plt.grid(True)
plt.title("Activations of first spectrum")
plt.colorbar()
plt.clim(0, 1)

fig1.tight_layout()

# Cross validation

In [None]:
# NOTE(review): the notebook defines no `dl`, `Distributions` or `Gaussians`.
# `mdl` (the fitted SNMF object) and the `u2.` prefix are used here instead;
# this assumes SNMF exposes `g_matr` and `p_` and that `u2.Distributions`
# provides `simplified_brstlg` — confirm against snmf.py and utils_V2.py.
g_matr = mdl.g_matr

# Least-squares projector onto the columns of g_matr (normal equations),
# computed once and shared by the three phases
projector = np.linalg.inv(g_matr.T @ g_matr) @ g_matr.T
brstlg_shape = u2.Distributions.simplified_brstlg(u2.Gaussians().x, b0, b1, b2, c0)

# Start from random values, then overwrite the columns of the three phases
# with the projection of each true spectrum minus a scaled bremsstrahlung
init_p_matr = np.random.rand(g_matr.shape[1], mdl.p_)
for k, brstlg_scale in enumerate((1.9, 2.0, 1.7)):
    corrected_spectrum = true_spectra[k] - brstlg_scale * brstlg_shape
    # Clip to keep every coefficient strictly positive
    init_p_matr[:, k] = (projector @ corrected_spectrum).clip(min=1e-5)

# One row per phase, one column per pixel of the spectrum image
init_a_matr = np.random.rand(mdl.p_, X.shape[0] * X.shape[1])
for k in range(3):
    # Flatten with -1 instead of a hard-coded pixel count
    init_a_matr[k, :] = true_maps[k].reshape(-1)