# Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# import hyperspy.api as hs
from snmfem.estimators import NMF
import snmfem.datasets as ds

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Overlay the two Lifshin bremsstrahlung terms computed on a common x grid.
# NOTE(review): `lifshin_bremsstrahlung_b0` / `lifshin_bremsstrahlung_b1` are not
# imported anywhere in this notebook, so this cell raises NameError on a fresh
# kernel — TODO import them from the module that defines them.
E0 = 200  # passed as the third argument of both functions (presumably beam energy — confirm units)
x = np.linspace(0.2, 20, num=2000)
b0 = 1
b1 = 1
y0 = lifshin_bremsstrahlung_b0(x, b0, E0)
y1 = lifshin_bremsstrahlung_b1(x, b1, E0)

fig, ax = plt.subplots()
ax.plot(x, y0, label="y0")
ax.plot(x, y1, label="y1")
ax.set(xlabel="x", ylabel="value", title="Lifshin bremsstrahlung terms (b0 = b1 = 1)")
# Trailing ';' suppresses the Legend repr in the cell output.
ax.legend();

# Generating artificial datasets and loading them

If the datasets have already been generated, they are not regenerated.

In [None]:
# Create the built-in synthetic datasets (existing ones are kept as-is),
# then load sample 0 of the "particules" dataset.
ds.generate_built_in_datasets()
spim = ds.load_particules(sample=0)

In [None]:
# Display the metadata attached to the loaded `spim` object.
spim.metadata

In [None]:
# For each phase, list the elements whose entry is set to 0.0 in the fixed-P
# specification (presumably these elements are excluded from that phase — confirm
# against `set_fixed_P`).
# NOTE(review): `P` is not passed to the estimator below (`fixed_P=None`).
P = spim.set_fixed_P(
    {
        "p1": dict.fromkeys(["N", "Yb", "Pt", "Al", "Ti", "La"], 0.0),
        "p2": dict.fromkeys(["V", "Rb", "W", "Al", "Ti", "La"], 0.0),
        "p3": dict.fromkeys(["N", "Yb", "Pt", "V", "Rb", "W"], 0.0),
    }
)

In [None]:
from snmfem.models.EDXS_function import G_bremsstrahlung, continuum_xrays, gaussian, read_lines_db, read_compact_db, update_bremsstrahlung, elts_dict_from_dict_list


In [None]:
# NOTE(review): `build_fixed_P` is neither defined nor imported in this notebook,
# so this cell raises NameError on a fresh kernel; its return value is also
# discarded. TODO import/define it, or drop this cell if it is superseded by
# the `spim.set_fixed_P` call earlier in the notebook.
build_fixed_P(spim,col1=True)

# Problem solving

Using the full HyperSpy syntax (`spim.decomposition` with a custom estimator).

## Loading analysis parameters

In [None]:
# Normalised bremsstrahlung G matrix plus the quantities the estimator needs.
G = spim.build_G("bremsstrahlung", norm=True)
shape_2d, phases, weights = spim.shape_2d, spim.phases, spim.weights

# `true_D` / `true_A` receive the dataset's ground-truth phases and weights
# (presumably used to monitor convergence against the truth — confirm).
# NOTE(review): the `P` from `set_fixed_P` is not used here (`fixed_P=None`);
# confirm that is intended.
est = NMF(
    n_components=3,
    tol=1e-6,
    max_iter=1000,
    G=G,
    shape_2d=shape_2d,
    true_D=phases,
    true_A=weights,
    fixed_P=None,
    hspy_comp=True,
)

## Calculating the decomposition

**Warning:** depending on the chosen parameters and the size of the data, this step may take a while.

In [None]:
# Run HyperSpy's decomposition with the custom estimator. With return_info=True
# the extra information from the algorithm is returned — here presumably the
# fitted estimator, since `out.G_` / `out.P_` / `out.A_` are read later.
out = spim.decomposition(
    algorithm=est,
    return_info=True,
)

## Getting the losses and the results of the decomposition

In [None]:
# Loss history recorded by the estimator and its fitted P matrix.
losses, Pr = est.get_losses(), est.P_

In [None]:
# Sums of the fitted factors: G and P over axis 0, A over axis 1.
# A cell renders only its last expression, so the three sums are gathered
# into a single tuple to display all of them.
(
    out.G_.sum(axis=0),
    out.P_.sum(axis=0),
    out.A_.sum(axis=1),
)

In [None]:
# Overlay column 0 of the reconstruction G @ P with the first true phase, each
# normalised to unit sum so that only the spectral shapes are compared.
# numpy / matplotlib are already imported (and inline plotting enabled) in the
# imports cell at the top of the notebook.
# NOTE(review): NMF does not guarantee that estimated components follow the
# ordering of the true phases — confirm column 0 matches phase 0.
res = (out.G_ @ out.P_)[:, 0]
ph = phases[:, 0]

fig, ax = plt.subplots()
ax.plot(res / np.sum(res), label="estimated (G @ P)[:, 0]")
ax.plot(ph / np.sum(ph), label="true phases[:, 0]")
ax.set(
    xlabel="channel index",
    ylabel="normalised value (unit sum)",
    title="Estimated vs. true spectrum, component 0",
)
ax.legend();

## Plotting the results

In [None]:
# Plot the decomposition factors; 3 matches `n_components` of the estimator
# (presumably components 0..2 — confirm against the HyperSpy signature).
spim.plot_decomposition_factors(3)

In [None]:
# Plot the decomposition loadings; 3 matches `n_components` of the estimator
# (presumably components 0..2 — confirm against the HyperSpy signature).
spim.plot_decomposition_loadings(3)

# Problem solving

Using the usual scikit-learn estimator interface (`fit_transform`).

In [None]:
# G is built here without `norm=True`, unlike in the HyperSpy section above
# (the default of `norm` is not visible here — confirm this is intended).
G = spim.build_G("bremsstrahlung")
shape_2d, phases, weights, X = spim.shape_2d, spim.phases, spim.weights, spim.X

# NOTE(review): `SmoothNMF` is never imported in this notebook, so this line
# raises NameError on a fresh kernel — TODO import it (possibly from
# snmfem.estimators alongside `NMF`; confirm).
est = SmoothNMF(
    n_components=3,
    tol=0.1,
    G=G,
    shape_2d=shape_2d,
    lambda_L=2,
    true_D=phases,
    true_A=weights,
    hspy_comp=False,
)



In [None]:
# Fit on the data matrix X with the scikit-learn style API; the returned matrix
# is stored as D (presumably the spectral components, cf. `true_D` — confirm).
D = est.fit_transform(X)