In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

# Imports and utility functions

In [None]:

import matplotlib.pyplot as plt
import numpy as np
from snmfem.estimator.snmf import SNMF
from snmfem import EDXS_model
from snmfem.measures import find_min_angle, find_min_MSE, residuals
from snmfem.conf import DATASETS_PATH
from pathlib import Path

# hyperspy is the commonly used library to handle spectrum images in the electron microscopy community
import hyperspy.api as hs


# Data Loading

In [None]:
# Folder holding the simulated samples (one ".npz" archive per sample).
folder = DATASETS_PATH / Path("aspim037_N100_2ptcls_brstlg")
# Index, in the sorted file list below, of the sample analysed in this notebook.
sample_num = 1

# Path.glob yields entries in an arbitrary, filesystem-dependent order; sorting
# makes `sample_num` select the same file on every run and on every machine.
sample_filenames = sorted(folder.glob("*.npz"))
print("Found {} samples".format(len(sample_filenames)))
if sample_num >= len(sample_filenames):
    # Same exception type as plain indexing, but with an actionable message.
    raise IndexError(
        f"sample_num={sample_num} but only {len(sample_filenames)} '*.npz' files were found in {folder}"
    )
sample_filename = sample_filenames[sample_num]

In [None]:
# Legacy loader kept for reference: S=hs.load(filename+".hspy")
# Open the selected archive and unpack the arrays it stores.
dat = np.load(sample_filename)
X, Xdot = dat["X"], dat["Xdot"]
phases, densities, weights = dat["phases"], dat["densities"], dat["weights"]
N = dat["N"]
# One density value per phase, so this is the number of phases.
k = len(densities)

# Wrap the data cube in a hyperspy signal and describe the axis found at
# index 2 of its axes manager (labelled "Energy", in keV).
S = hs.signals.Signal1D(X)
energy_axis = S.axes_manager[2]
energy_axis.name = "Energy"
energy_axis.scale = 0.01
energy_axis.offset = 0.20805000000000007
energy_axis.unit = "keV"

# Author's note: this part of the spectrum image contains only pure spectra
# from phase 0; this kind of area is often available in experimental datasets.
X_part = S.inav[60:, :].data

In [None]:
# Display the dimensions of the full data cube X.
X.shape

In [None]:
# Display the dimensions of X_part, the sub-region sliced from S with inav[60:,:].
X_part.shape

In [None]:
# Image of X averaged over its axis 2, with a colorbar.
mean_image = np.mean(X, axis=2)
fig_mean, ax_mean = plt.subplots()
im_mean = ax_mean.imshow(mean_image)
fig_mean.colorbar(im_mean, ax=ax_mean)


In [None]:
# Histogram (100 bins) of every value in X, drawn with a logarithmic count axis.
fig_hist, ax_hist = plt.subplots()
ax_hist.hist(X.flatten(), bins=100)
ax_hist.set_yscale("log")

# Interactive plotting of the data

In [None]:
# Rectangular region of interest. The four numbers are presumably
# (left, top, right, bottom) in navigation-axis units — TODO confirm against hyperspy.
roi=hs.roi.RectangularROI(12,12,24,24)
S.plot()
# Attach the ROI to S; spim_ROI is the signal returned by roi.interactive(S).
spim_ROI=roi.interactive(S)
    
# sum_ROI is built by hs.interactive from spim_ROI.sum, registered on the
# any_axis_changed event of spim_ROI's axes manager (recompute_out_event disabled).
# Presumably this keeps the summed spectrum in sync while the ROI is moved — TODO confirm.
sum_ROI=hs.interactive(spim_ROI.sum,
               event=spim_ROI.axes_manager.events.any_axis_changed,
               recompute_out_event=None)
sum_ROI.plot()

# Testing an NMF algorithm

In [None]:
# Sizes are read from the loaded arrays instead of being hard-coded (3 phases,
# 1980 channels), so the cell keeps working for other samples.
n_phases = len(densities)
n_channels = Xdot.shape[-1]

# NOTE: `phases` must still be the array loaded from the archive here; the
# parameter cell further down rebinds `phases` to an integer.
D = phases.T                              # one column per phase
A = weights.reshape(-1, n_phases).T       # one row per phase, one column per pixel
w = densities
W = np.diag(w)
Dp = N * D @ W                            # phase spectra scaled by N and by the densities
Xdot_m = Xdot.reshape(-1, n_channels).T   # one column per pixel
X_m = X.reshape(-1, X.shape[-1]).T

# The noiseless data must factorise exactly as Dp @ A ...
np.testing.assert_allclose(Xdot_m, Dp @ A)
# ... and the entries of every column of A must sum to one.
np.testing.assert_allclose(np.sum(A, axis=0), np.ones(A.shape[1]))


In [None]:
from snmfem.estimator.nmf import NMF

In [None]:
from sklearn.decomposition import TruncatedSVD
# NOTE(review): this rebinds the name `NMF`, shadowing the
# `snmfem.estimator.nmf.NMF` imported in the previous cell. Neither class is
# used in the visible part of the notebook; alias or drop one of them.
from sklearn.decomposition import NMF
# NOTE(review): `_initialize_nmf` is a private (underscore-prefixed) sklearn helper
# and is only referenced from commented-out code in the visible notebook.
from sklearn.decomposition._nmf import _initialize_nmf as initialize_nmf 

In [None]:
# U,V = initialize_nmf(Xdot_m, k)

In [None]:
def truncated_SVD(X, k, algorithm='randomized', n_iter=5, **kwargs):
    """Run scikit-learn's ``TruncatedSVD`` with ``k`` components on ``X``.

    Parameters
    ----------
    X : matrix to decompose; handed unchanged to ``TruncatedSVD.fit_transform``.
    k : number of components (``n_components``).
    algorithm, n_iter, **kwargs : forwarded to the ``TruncatedSVD`` constructor.

    Returns
    -------
    (U, V, e), with X approximated by ``U @ V``:
        U -- the result of ``fit_transform(X)``,
        V -- the fitted ``components_``,
        e -- the fitted ``explained_variance_ratio_``.
    """
    decomposer = TruncatedSVD(n_components=k, algorithm=algorithm, n_iter=n_iter, **kwargs)
    # fit_transform must run first: it fits the attributes read on the next line.
    projected = decomposer.fit_transform(X)
    return projected, decomposer.components_, decomposer.explained_variance_ratio_


In [None]:
# Collapse the leading axes of Xdot into rows; the column count is the length of its axis 2.
Xr = np.reshape(Xdot, (-1, Xdot.shape[2]))
Xr.shape

In [None]:
# Rank-k truncated SVD of the flattened noiseless data.
U, V, e = truncated_SVD(Xr, k=k, n_iter=20)
# The k components must carry all of the variance ...
np.testing.assert_allclose(e.sum(), 1)
# ... and their product must rebuild the input.
np.testing.assert_allclose(np.matmul(U, V), Xr)

# Parameters initialisation

In [None]:
# Bremsstrahlung parameters (the true values, per the original author's note)
brstlg_pars = {
    "c0": 4.8935e-05,
    "c1": 1464.19810,
    "c2": 0.04216872,
    "b0": 0.15910789,
    "b1": -0.00773158,
    "b2": 8.7417e-04,
}

# Settings handed to the SNMF constructor further down
tol = 1e-4
max_iter = 5000
b_tol = 1e-1
mu_sparse = 0.0
eps_sparse = 1.0
# NOTE(review): this rebinds `phases`, which until here held the array loaded
# from the sample archive. Cells that need that array must run before this one.
phases = 3
debug = True

# Author's note: if mu_sparse != 0 a good initialization of the first phase is
# required; the average spectrum of the X_part region can be used for that.
init_matrix = np.average(X_part, axis=(0, 1))

## Creating the EDXS model

In [None]:
# The energy calibration (offset, scale, size) is read from axis 2 of the signal S.
# NOTE(review): the JSON path is relative, so it depends on the working directory.
energy_axis = S.axes_manager[2]
em = EDXS_model.EDXS_Model(
    "Data/simple_xrays_threshold.json",
    brstlg_pars=brstlg_pars,
    e_offset=energy_axis.offset,
    e_scale=energy_axis.scale,
    e_size=energy_axis.size,
)

# Presumably atomic numbers of the elements to include — TODO confirm.
element_list = [8, 13, 14, 12, 26, 29, 31, 72, 71, 62, 60, 92, 20]
# Author's note: with brstlg=False the continuum model is learned; with
# brstlg=True a column built from the parameters is added to the G matrix.
em.generate_g_matr(element_list, brstlg=False)

In [None]:
# Display the scale of axis 2 of S. The stray trailing comma is removed: it
# turned the cell output into a 1-tuple.
S.axes_manager[2].scale

In [None]:
# Display the offset of axis 2 of S (the value passed as e_offset to the EDXS model).
S.axes_manager[2].offset

In [None]:
# Display the bremsstrahlung parameter dictionary.
brstlg_pars

In [None]:
# Display the dimensions of the model's G matrix (populated by em.generate_g_matr).
em.g_matr.shape

In [None]:
# Display em.x, the abscissa the spectra are plotted against later in the notebook.
em.x

## Loading the ground truth

In [None]:
# The previous version read "spectrum_p*" / "map_p*.npy" files through a
# `filename` variable that is never assigned in this notebook, so the cell
# raised NameError on a fresh kernel.
#
# The ground truth is taken from the matrices built from the loaded sample
# instead: the noiseless data satisfy Xdot_m == Dp @ A (asserted where Dp and
# A are built), with one column of Dp per phase and one row of A per phase.
# NOTE(review): assumes the legacy files held this same factorisation, with
# this scaling — confirm.

# One row per phase, one column per energy channel.
true_spectra = np.array(Dp.T)
# One image per phase, with the spatial dimensions of X.
true_maps = np.array(A.reshape(A.shape[0], *X.shape[:2]))


In [None]:
# Display the dimensions of the ground-truth maps and spectra side by side.
true_maps.shape, true_spectra.shape

In [None]:
plt.figure(figsize=(10, 10))

# Top-left panel: every ground-truth spectrum against em.x, logarithmic y axis.
plt.subplot(2, 2, 1)
plt.plot(em.x, true_spectra.T)
plt.yscale("log")

# Remaining three panels: one ground-truth map each.
for panel, phase_index in zip((2, 3, 4), (0, 1, 2)):
    plt.subplot(2, 2, panel)
    plt.imshow(true_maps[phase_index])

In [None]:
# How far is the data X from the product of the ground-truth maps and spectra?
n_components = true_maps.shape[0]
im_size = true_maps.shape[1:]

# Flatten each map to one row, then put pixels on rows and components on columns.
flat_maps = true_maps.reshape(n_components, -1).T
perfect_reconstruction = (flat_maps @ true_spectra).reshape(im_size + (-1,))
diff = perfect_reconstruction - X
diff.std()

In [None]:
plt.figure(figsize=(10, 5))

# Left: ground-truth reconstruction averaged over axis 2.
plt.subplot(1, 2, 1)
plt.imshow(perfect_reconstruction.mean(axis=2))
plt.colorbar()

# Right: the data X averaged over the same axis.
plt.subplot(1, 2, 2)
plt.imshow(X.mean(axis=2))
plt.colorbar()


# SNMF

## Initialize the algorithm

In [None]:
# Build the SNMF estimator from the settings of the parameter cell and the EDXS model.
mdl = SNMF(
    max_iter=max_iter,
    tol=tol,
    b_tol=b_tol,
    mu_sparse=mu_sparse,
    eps_sparse=eps_sparse,
    num_phases=phases,
    edxs_model=em,
    brstlg_pars=brstlg_pars,
    init_spectrum=init_matrix,
    debug=debug,
)

## Running the algorithm

In [None]:
# Fit the SNMF model on the full data cube X.
# NOTE(review): eval_print=50 and flush=False presumably control how often and
# how progress is printed — confirm against SNMF.fit.
mdl.fit(X,eval_print=50, flush=False)

In [None]:
# from time import sleep
# for i in range(10):
#     print(f"\r {i}", end="", flush=True)
#     sleep(0.1)


# Results

## Comparison with ground truth

In [None]:
# Number of phases compared against the ground truth.
n_compared = 3

# Angles between the ground-truth spectra and the endmembers found using SNMF.
estimated_spectra = [mdl.get_phase_spectrum(idx) for idx in range(n_compared)]
angles = find_min_angle(true_spectra, estimated_spectra)

# Result of find_min_MSE for the ground-truth maps and the maps found by SNMF.
estimated_maps = [mdl.get_phase_map(idx) for idx in range(n_compared)]
maps = find_min_MSE(true_maps, estimated_maps)

for idx in range(n_compared):
    print(f"Angle phase {idx} :", angles[idx])
for idx in range(n_compared):
    print(f"MSE phase {idx} :", maps[idx])

## Visualisation of the results 

### Matching endmembers and ground truth

In [None]:
# switch: index (0 to 2) of the SNMF endmember to display; it is passed to
# mdl.get_phase_spectrum and mdl.get_phase_map in the plot below.
switch = 1
# true: index (0 to 2) of the ground-truth spectrum plotted against that endmember.
# Both indices are set by hand; presumably they are picked from the
# angle / MSE report of the previous section — TODO confirm.
true=1

## Endmember + abundance plot

In [None]:
# Larger font for this wide figure. This changes the global rcParams, so it
# also affects every figure created afterwards.
plt.rcParams.update({'font.size': 22})
fig1 = plt.figure(figsize=(20, 12))

# Left panel: ground-truth spectrum (normalised to a sum of 100) against the
# reconstructed endmember. The marker size now goes with the marker style and
# the line width with the line style; they were attached to the wrong plots,
# where they had no effect.
plt.subplot(121)
plt.plot(em.x, 100*true_spectra[true]/np.sum(true_spectra[true]), 'bo', label='truth', markersize=3.5)
plt.plot(em.x, mdl.get_phase_spectrum(switch), 'r-', label='reconstructed', linewidth=4)
plt.legend(loc='best')
plt.xlim(0, 10)
plt.ylim(0, 1)
plt.ylabel("Intensity")

# Right panel: abundance map of the selected endmember.
plt.subplot(122)
plt.imshow(mdl.get_phase_map(switch), cmap="viridis")
# The `b` keyword of grid() was renamed to `visible` and later removed; passing
# the flag positionally works on old and new Matplotlib (b=30 was simply truthy).
plt.grid(True)
# The title now names the endmember actually shown instead of "first spectrum".
plt.title(f"Activations of spectrum {switch}")
plt.colorbar()
plt.clim(0, 1)

fig1.tight_layout()

## Plotting the convergence

### Losses

In [None]:
# Two stacked panels: the base loss above, the full loss below.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
loss_panels = (
    (ax1, mdl.base_losses, "base loss"),
    (ax2, mdl.losses, "full loss"),
)
for axis, history, heading in loss_panels:
    axis.plot(history)
    axis.set_title(heading, fontsize=24)
# Only the bottom panel keeps its x ticks and carries the shared x label.
ax1.set_xticks([])
ax2.set_xlabel("number of iterations", fontsize=20)

### B parameters step sizes

In [None]:
# One panel per column of mdl.eta_list (converted to an array once).
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1)
eta_history = np.array(mdl.eta_list)
step_axes = (ax1, ax2, ax3, ax4, ax5)
step_titles = (
    "step size b1",
    "step size b2",
    "step size c0",
    "step size c1",
    "step size c2",
)
for column, (axis, heading) in enumerate(zip(step_axes, step_titles)):
    axis.plot(eta_history[:, column])
    axis.set_title(heading, fontsize=18)
    # Every panel except the bottom one hides its x ticks.
    if axis is not ax5:
        axis.set_xticks([])
ax5.set_xlabel("number of iterations", fontsize=16)

### A, P and B norms

In [None]:
# Three stacked panels with the A, P and B norm histories recorded by the model.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
norm_panels = (
    (ax1, mdl.a_norm, "A norm"),
    (ax2, mdl.p_norm, "P norm"),
    (ax3, mdl.b_norm, "B norm"),
)
for axis, history, heading in norm_panels:
    axis.plot(history)
    axis.set_title(heading, fontsize=18)
# Only the bottom panel keeps its x ticks and carries the shared x label.
ax1.set_xticks([])
ax2.set_xticks([])
ax3.set_xlabel("number of iterations", fontsize=16)