In [1]:
import snmfem.conf as conf
import json
import numpy as np
from pathlib import Path
import re
import snmfem.estimators as estimators
import snmfem.measures as measures
from sklearn.decomposition import NMF as sk_NMF
import warnings
import matplotlib.pyplot as plt

# Inputs

In [2]:
# Dataset description: the JSON config stored under snmfem's script config folder.
data_json = conf.SCRIPT_CONFIG_PATH / Path("dataset_Toy.json")
with open(data_json, "r") as config_file:
    data_dict = json.load(config_file)

# Sample numbers to evaluate; each maps to data_folder/sample_<n>.npz.
num_list = [1, 2, 3]
data_folder = conf.DATASETS_PATH / Path(data_dict["data_folder"])
samples = [data_folder / Path(f"sample_{num}.npz") for num in num_list]


# Default parameters

In [3]:
# Estimator settings shared by every run below; they are unpacked as **kwargs
# into both the snmfem estimators and scikit-learn's NMF.
default_params = dict(
    n_components=data_dict["model_parameters"]["params_dict"]["k"],
    tol=1e-4,
    max_iter=500000,
    init="nndsvda",
    random_state=1,
    verbose=0,
)

# NMF helper functions

In [4]:
def run_snmfem_NMF(Xflat, true_spectra, true_maps, G, dict, skip_G=False, force_simplex=True, mu=0, epsilon_reg=1.0, get_ind=False, u=False):
    """Fit snmfem's ``NMF`` estimator and score it against the ground truth.

    Parameters
    ----------
    Xflat : data matrix handed to ``estimator.fit``.
    true_spectra : ground-truth spectra, compared with ``(G_ @ P_).T`` via
        ``measures.find_min_angle``.
    true_maps : ground-truth maps, compared with ``A_`` via
        ``measures.find_min_MSE``.
    G : passed to ``fit`` as ``G=G`` unless ``skip_G`` is true.
    dict : keyword arguments forwarded to the estimator constructor.
    skip_G : when true, ``fit`` is called without ``G``.
    force_simplex, mu, epsilon_reg : forwarded to the estimator constructor.
    get_ind : forwarded positionally to both measures.
    u : forwarded to both measures as ``unique``.

    Returns
    -------
    (angle, mse, G_ @ P_, A_) taken from the fitted estimator.
    """
    model = estimators.NMF(**dict, force_simplex=force_simplex, mu=mu, epsilon_reg=epsilon_reg)
    fit_kwargs = {} if skip_G else {"G": G}
    model.fit(Xflat, **fit_kwargs)

    reconstructed = model.G_ @ model.P_
    maps = model.A_
    angle = measures.find_min_angle(true_spectra, reconstructed.T, get_ind, unique=u)
    mse = measures.find_min_MSE(true_maps, maps, get_ind, unique=u)
    return angle, mse, reconstructed, maps

def run_snmfem_SmoothNMF(Xflat, true_spectra, true_maps, G, dict, shape_2D, skip_G=False, force_simplex=True, mu=0, epsilon_reg=1.0, get_ind=False, lambda_L=0.0, u=False):
    """Fit snmfem's ``SmoothNMF`` estimator and score it against the ground truth.

    Parameters
    ----------
    Xflat : data matrix handed to ``estimator.fit``.
    true_spectra : ground-truth spectra, compared with ``(G_ @ P_).T`` via
        ``measures.find_min_angle``.
    true_maps : ground-truth maps, compared with ``A_`` via
        ``measures.find_min_MSE``.
    G : passed to ``fit`` as ``G=G`` unless ``skip_G`` is true.
    dict : keyword arguments forwarded to the estimator constructor.
    shape_2D : forwarded to the constructor as ``shape_2d``.
    skip_G : when true, ``fit`` is called without ``G``.
    force_simplex, mu, epsilon_reg, lambda_L : forwarded to the constructor.
    get_ind : forwarded positionally to both measures.
    u : forwarded to both measures as ``unique``.

    Returns
    -------
    (angle, mse, G_ @ P_, A_) taken from the fitted estimator.
    """
    model = estimators.SmoothNMF(
        **dict,
        shape_2d=shape_2D,
        lambda_L=lambda_L,
        force_simplex=force_simplex,
        mu=mu,
        epsilon_reg=epsilon_reg,
    )
    fit_kwargs = {} if skip_G else {"G": G}
    model.fit(Xflat, **fit_kwargs)

    reconstructed = model.G_ @ model.P_
    maps = model.A_
    angle = measures.find_min_angle(true_spectra, reconstructed.T, get_ind, unique=u)
    mse = measures.find_min_MSE(true_maps, maps, get_ind, unique=u)
    return angle, mse, reconstructed, maps

def run_scikit_NMF(Xflat, true_spectra, true_maps, G, dict, loss="frobenius", alpha=0.0, l1_ratio=1.0, regularization="components", get_ind=False, u=False):
    """Fit scikit-learn's NMF (multiplicative-update solver) and score it.

    Parameters
    ----------
    Xflat : data matrix passed to ``fit_transform``.
    true_spectra : ground-truth spectra, compared with ``W.T`` via
        ``measures.find_min_angle``.
    true_maps : ground-truth maps, compared with ``H`` via
        ``measures.find_min_MSE``.
    G : unused. It is accepted only so the signature lines up with the snmfem
        runners (the previous body immediately overwrote it with ``None``).
    dict : keyword arguments forwarded to ``sklearn.decomposition.NMF``.
    loss : forwarded as ``beta_loss``.
    alpha, l1_ratio, regularization : forwarded unchanged to the constructor.
    get_ind : forwarded positionally to both measures.
    u : forwarded to both measures as ``unique``.

    Returns
    -------
    (angle, mse, W, H) after ``rescaled_DA`` rescales the two factors.
    """
    # NOTE(review): ``alpha`` and ``regularization`` were deprecated in
    # scikit-learn 1.0 (replaced by ``alpha_W`` / ``alpha_H``) and removed in
    # 1.2, so this call only works on older releases.
    # TODO: migrate once the scikit-learn version is pinned.
    estimator = sk_NMF(**dict, beta_loss=loss, alpha=alpha, l1_ratio=l1_ratio, regularization=regularization, solver="mu")

    W = estimator.fit_transform(Xflat)
    H = estimator.components_
    # Move scale between the factors (W @ H is unchanged) so that the columns
    # of H sum to one in the least-squares sense before comparing with truth.
    W, H = rescaled_DA(W, H)
    angle = measures.find_min_angle(true_spectra, W.T, get_ind, unique=u)
    mse = measures.find_min_MSE(true_maps, H, get_ind, unique=u)
    return angle, mse, W, H

def load_data(sample):
    """Load one synthetic sample and flatten it for NMF.

    Parameters
    ----------
    sample : path to an ``.npz`` archive containing the arrays ``X``,
        ``densities``, ``phases``, ``N``, ``weights`` and ``G``.

    Returns
    -------
    Xflat : ``X`` of shape (nx, ny, ns) reordered to (ns, nx*ny).
    true_spectra_flat : ``densities[:, None] * phases * N``.
    true_maps_flat : ``weights`` of shape (nx, ny, k) reordered to (k, nx*ny).
    G : the archive's ``G`` array, unchanged.
    """
    # ``np.load`` on an .npz returns an NpzFile that keeps the archive open
    # until it is closed. Read every array inside the context manager so the
    # file handle is released instead of leaking.
    with np.load(sample) as data:
        X = data["X"]
        densities = data["densities"]
        phases = data["phases"]
        N = data["N"]
        weights = data["weights"]
        G = data["G"]

    nx, ny, ns = X.shape
    Xflat = X.transpose([2, 0, 1]).reshape(ns, nx * ny)
    true_spectra_flat = np.expand_dims(densities, axis=1) * phases * N
    k = weights.shape[2]
    true_maps_flat = weights.transpose([2, 0, 1]).reshape(k, nx * ny)
    return Xflat, true_spectra_flat, true_maps_flat, G

def rescaled_DA(D, A):
    """Rebalance the factors of ``D @ A`` so each column of ``A`` sums to ~1.

    Solves ``A.T @ s ≈ 1`` in the least-squares sense for a per-component scale
    ``s`` (one entry per row of ``A``). Row ``j`` of ``A`` is multiplied by
    ``s[j]`` and column ``j`` of ``D`` is divided by it, so the product
    ``D @ A`` is unchanged.

    Parameters
    ----------
    D : array whose last axis has one entry per component (per row of ``A``).
    A : 2-D array of shape (k, p).

    Returns
    -------
    (D_rescale, A_rescale) with ``D_rescale @ A_rescale == D @ A`` up to
    rounding. A zero entry in ``s`` yields inf/NaN, as before.
    """
    _, p = A.shape
    ones = np.ones(p)
    # Pass rcond=None explicitly. Omitting it emits a FutureWarning on older
    # NumPy releases, whose default singular-value cutoff differs.
    s = np.linalg.lstsq(A.T, ones, rcond=None)[0]
    # Broadcasting is equivalent to D @ diag(1/s) and diag(s) @ A without
    # building the diagonal matrices.
    D_rescale = D * (1.0 / s)
    A_rescale = s[:, np.newaxis] * A
    return D_rescale, A_rescale


# Running NMF on every sample

In [5]:
# Silence FutureWarnings raised during the runs below (presumably library
# deprecation notices).
warnings.simplefilter(action='ignore', category=FutureWarning)

# Experiment settings shared by every sample.
force_simplex = True
skip_G = False
alpha = 1.0
lambda_L = 1.0
unique = True
shape_2D = data_dict["weights_parameters"]["shape_2D"]

# One row per method, in this order: snmfem, smooth snmfem, scikit KL,
# scikit Frobenius, scikit KL with alpha. One column per component.
result_angles = np.zeros((5, data_dict["model_parameters"]["params_dict"]["k"]))
result_mse = np.zeros_like(result_angles)

for sample_path in samples:
    Xflat, true_spectra, true_maps, G = load_data(sample_path)
    common_args = (Xflat, true_spectra, true_maps, G, default_params)
    print("NOT SMOOTH")
    snmfem = run_snmfem_NMF(*common_args, force_simplex=force_simplex, skip_G=skip_G, u=unique)
    print("SMOOTH")
    smooth_snmfem = run_snmfem_SmoothNMF(*common_args, shape_2D=shape_2D, force_simplex=force_simplex,
                                         skip_G=skip_G, lambda_L=lambda_L, u=unique)
    scikit_KL = run_scikit_NMF(*common_args, loss="kullback-leibler", u=unique)
    scikit_Fro = run_scikit_NMF(*common_args, loss="frobenius", u=unique)
    scikit_alpha = run_scikit_NMF(*common_args, loss="kullback-leibler", alpha=alpha, u=unique)
    # Sum the per-component scores; they are averaged over samples after the loop.
    for row, outcome in enumerate((snmfem, smooth_snmfem, scikit_KL, scikit_Fro, scikit_alpha)):
        result_angles[row] += np.array(outcome[0])
        result_mse[row] += np.array(outcome[1])

result_angles /= len(samples)
result_mse /= len(samples)


NOT SMOOTH
exit because of negative decrease
Stopped after 1853 iterations in 0.0 minutes and 2.0 seconds.
SMOOTH
exit because of negative decrease
Stopped after 2416 iterations in 0.0 minutes and 3.0 seconds.
NOT SMOOTH
exit because of negative decrease
Stopped after 1931 iterations in 0.0 minutes and 3.0 seconds.
SMOOTH
exit because of negative decrease
Stopped after 3000 iterations in 0.0 minutes and 4.0 seconds.
NOT SMOOTH
exit because of negative decrease
Stopped after 1417 iterations in 0.0 minutes and 2.0 seconds.
SMOOTH
exit because of negative decrease
Stopped after 2112 iterations in 0.0 minutes and 3.0 seconds.


## Printing statistics

In [None]:
# Report the per-component angles averaged over all samples, one line per method.
print(f"average angles of {len(samples)} samples with params [force simplex : {force_simplex}, skip_G : {skip_G}, alpha : {alpha}]", flush=True)
for label, row in zip(("snmfem             : ",
                       "smooth snmfem      : ",
                       "scikit KL          : ",
                       "scikit frobenius   : ",
                       "scikit KL alpha*l1 : "), result_angles):
    print(label, row)


In [None]:
# Report the per-component map MSE averaged over all samples, one line per method.
print(f"average MSE of {len(samples)} samples with params [force simplex : {force_simplex}, skip_G : {skip_G}, alpha : {alpha}]", flush=True)
for label, row in zip(("snmfem             : ",
                       "smooth snmfem      : ",
                       "scikit KL          : ",
                       "scikit frobenius   : ",
                       "scikit KL alpha*l1 : "), result_mse):
    print(label, row)

# Plotting results

# NMF on a single sample

In [None]:
# Re-run the non-smooth snmfem and the three scikit-learn variants on one
# sample with get_ind=True. The plotting cells use the first returned element
# as indices into the estimated components.
warnings.simplefilter(action='ignore', category=FutureWarning)

force_simplex, skip_G, alpha = True, False, 1.0
index = 2  # position in `samples`, i.e. the third sample file
nx, ny = data_dict["weights_parameters"]["shape_2D"]
k = default_params["n_components"]
Xflat, true_spectra, true_maps, G = load_data(samples[index])

common_args = (Xflat, true_spectra, true_maps, G, default_params)
snmfem = run_snmfem_NMF(*common_args, force_simplex=force_simplex, skip_G=skip_G, get_ind=True)
scikit_KL = run_scikit_NMF(*common_args, loss="kullback-leibler", get_ind=True)
scikit_Fro = run_scikit_NMF(*common_args, loss="frobenius", get_ind=True)
scikit_alpha = run_scikit_NMF(*common_args, loss="kullback-leibler", alpha=alpha, get_ind=True)

## Spectra visualisation

In [None]:
# Overlay ground-truth spectrum `true_ind` with the matched estimated spectrum
# from each method. For every result tuple, element [0] holds the indices
# returned with get_ind=True and element [2] holds the spectra, one per column.
true_ind = 0
plt.rcParams.update({'font.size': 22})
spectra_panels = [
    ("SNMFEM", snmfem),
    ("Scikit_KL", scikit_KL),
    ("Scikit_Fro", scikit_Fro),
    ("Scikit_alpha", scikit_alpha),
]
fig1, axes = plt.subplots(1, len(spectra_panels), figsize=(50, 12))
for ax, (title, result) in zip(axes, spectra_panels):
    ax.set_title(title)
    # markersize only affects the marker-only 'bo' series, and linewidth only
    # affects the line-only 'r-' series. They were previously swapped, so
    # both had no effect.
    ax.plot(true_spectra[true_ind], 'bo', label='truth', markersize=3.5)
    ax.plot(result[2][:, result[0][true_ind]], 'r-', label='reconstructed', linewidth=4)
    ax.set_xlabel("Energy")
axes[0].set_ylabel("Intensity")
axes[2].legend(loc='best')

fig1.tight_layout()

## Maps visualisation

In [None]:
# Show ground-truth map `true_ind` next to the matched estimated map from each
# method. Element [0] of each result tuple holds the matching indices and
# element [3] holds the maps, one flattened map per row.
true_ind = 0
plt.rcParams.update({'font.size': 22})
map_panels = [
    ("SNMFEM", snmfem[3][snmfem[0][true_ind]]),
    ("Scikit_KL", scikit_KL[3][scikit_KL[0][true_ind]]),
    ("Scikit_Fro", scikit_Fro[3][scikit_Fro[0][true_ind]]),
    ("Scikit_alpha", scikit_alpha[3][scikit_alpha[0][true_ind]]),
    ("Truth", true_maps[true_ind]),
]
fig1, axes = plt.subplots(1, len(map_panels), figsize=(50, 12))
for ax, (title, flat_map) in zip(axes, map_panels):
    im = ax.imshow(flat_map.reshape(nx, ny), cmap="viridis")
    # The old `plt.grid(b=30)` relied on the `b` keyword, which was deprecated
    # in Matplotlib 3.5 and removed in 3.7. Pass the visibility flag positionally.
    ax.grid(True)
    ax.set_title(title)
    fig1.colorbar(im, ax=ax)

In [None]:
# Shape of the flattened ground-truth maps returned by load_data for the sample above.
np.shape(true_maps)