In [None]:
# Automatically reload edited Python modules (such as the snmfem package) before
# each cell executes, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline


In [None]:
import re

import numpy as np
import matplotlib.pyplot as plt

from snmfem.experiments import load_samples, print_results, load_data, run_experiment
from snmfem.measures import KL, KLdiv

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Running NMF

## Inputs

In [None]:
# Dataset description file and the index of the sample to analyse in it.
dataset, n_sample = "dataset_EDXS.json", 0

In [None]:
# Load the samples described by the dataset file; `k` is used further down as the
# number of NMF components (n_components) and of phases to plot.
samples, k = load_samples(dataset)
# Pick the sample selected by `n_sample` in the inputs cell.
sample = samples[n_sample]

In [None]:
# Choices for the "init" parameter:
#   'random'   : non-negative random matrices, scaled with: sqrt(X.mean() / n_components)
#   'nndsvd'   : Nonnegative Double Singular Value Decomposition (NNDSVD) initialization (better for sparseness)
#   'nndsvda'  : NNDSVD with zeros filled with the average of X (better when sparsity is not desired)
#   'nndsvdar' : NNDSVD with zeros filled with small random values (generally faster, less accurate
#                alternative to NNDSVDa for when sparsity is not desired)
# Author's note: 'random' gave the best initialization in practice.

# Generic NMF solver settings (k comes from the dataset).
default_params = dict(
    n_components=k,
    tol=1e-6,
    max_iter=10000,
    init="random",
    random_state=1,
    verbose=1,
)

# SmoothNMF-specific settings.
# NOTE(review): `mu` has exactly 3 entries, presumably one per component; this
# only lines up when k == 3 — confirm before switching datasets.
params_snmf = dict(
    force_simplex=True,
    skip_G=False,
    mu=np.array([0, 1, 1]),
)

# Options forwarded to run_experiment for the evaluation step.
params_evalution = dict(u=True)

# Complete experiment description; later keys override earlier ones when merging.
# NOTE(review): the name says "smooth 30" while lambda_L is 100.0 — confirm which is intended.
exp = dict(
    name="snmfem smooth 30",
    method="SmoothNMF",
    params=dict(default_params, **params_snmf, lambda_L=100.0),
)

In [None]:
# No separate load here: the experiment cell below calls load_data(sample) itself,
# which keeps that cell self-contained and re-runnable on its own.


# Run a single experiment and plot the results

In [None]:
# (Re)load the data so this cell can be re-run on its own with fresh inputs.
Xflat, true_spectra, true_maps, G, shape_2d = load_data(sample)

# Run the experiment configured in `exp`. Judging from how the results are used in
# later cells: `m` holds per-phase metrics followed by the phase matching order,
# (GP, A) are the estimated spectra (columns of GP) and maps (rows of A), and
# `loss` is a structured array with one named field per tracked quantity.
# NOTE(review): this reading is inferred from usage, not from run_experiment itself.
m, (GP, A), loss  = run_experiment(Xflat, true_spectra, true_maps, G, exp, params_evalution,shape_2d)


In [None]:
# Reference data matrix rebuilt from the ground-truth factors (true spectra transposed
# times true maps); presumably shaped (spectrum length, n_pixels) — confirm vs load_data.
Xtrue = np.matmul(true_spectra.T, true_maps)

In [None]:
# Pairwise averaged KL measure between the observed data (Xflat), the matrix
# rebuilt from the ground-truth factors (Xtrue) and the NMF reconstruction.
# KL is imported together with the other snmfem imports at the top of the notebook.
X_recon = GP @ A  # reconstruction, computed once and reused below
kl_data_truth = KL(Xflat, Xtrue, average=True)
kl_data_recon = KL(Xflat, X_recon, average=True)
kl_truth_recon = KL(Xtrue, X_recon, average=True)

# Label each value: the three bare numbers were otherwise indistinguishable.
print("KL(Xflat, Xtrue)  =", kl_data_truth)
print("KL(Xflat, GP @ A) =", kl_data_recon)
print("KL(Xtrue, GP @ A) =", kl_truth_recon)


In [None]:
# Plotting parameters shared by the figures below.
fontsize, aspect_ratio, scale = 15, 3/4, 20
cmap = plt.cm.gist_heat_r
# Colour limits are taken from the ground-truth maps so true and reconstructed maps
# share one colour scale; values outside [vmin, vmax] are drawn with the end colours.
vmin, vmax = 0, np.max(true_maps)



In [None]:
# One row per phase, with three columns:
#   column 0 overlays the true and reconstructed spectra,
#   column 1 shows the reconstructed map,
#   column 2 shows the true map.
# `m` (from run_experiment) holds the per-phase metrics in all entries but the
# last; row 0 is displayed as an angle in degrees and row 1 as an MSE.
# The last entry gives, for each metric, which reconstructed component is matched
# to true phase j. NOTE(review): layout inferred from the indexing below.
metric = np.array(m[:-1])
order = np.array(m[-1])

# squeeze=False keeps `axes` two-dimensional, so axes[j, c] also works when k == 1.
fig, axes = plt.subplots(k, 3, figsize=(scale, scale/3*k * aspect_ratio), squeeze=False)

for j in range(k):
    # Reconstructed component matched to true phase j by the spectral metric.
    spec_ind = np.arange(k)[order[0, j]]
    # 'bo' draws markers only and 'r-' a line only. The previous linewidth on 'bo'
    # and markersize on 'r-' had no effect, so they are dropped.
    axes[j, 0].plot(true_spectra[j], 'bo', label='truth')
    axes[j, 0].plot(GP[:, spec_ind], 'r-', label='reconstructed')
    axes[j, 0].set_title("{:.2f} deg".format(metric[0, j]))

    # Reconstructed component matched to true phase j by the map metric.
    map_ind = np.arange(k)[order[1, j]]
    axes[j, 1].imshow(A[map_ind].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
    axes[j, 1].set_title("Mse: {:.2f}".format(metric[1, j]))
    axes[j, 2].imshow(true_maps[j].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)

# The series labels above are only visible if a legend is drawn; one suffices.
axes[0, 0].legend()

rows = ['Phase {}'.format(col) for col in range(k)]
cols = ["Phase", "Map", "Real map"]

# Column headers live on the first row. Prepend them to the existing metric title
# instead of replacing it; otherwise row 0 loses its angle / MSE value.
for ax, col in zip(axes[0], cols):
    metric_title = ax.get_title()
    header = "{}\n{}".format(col, metric_title) if metric_title else col
    ax.set_title(header, fontsize=fontsize)

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=90, fontsize=fontsize)

fig.tight_layout()

plt.show()


# Quick-and-dirty look at the loss curves

In [None]:
# NOTE(review): superseded by the next cell, which plots every `loss` field by
# name; this commented-out version is kept only for reference and can be deleted.
# fig, axes = plt.subplots(1, 2, figsize=(10, 3))

# names = list(loss.dtype.names)
# values = np.array([list(e) for e in loss])

# # axes[0].plot(values[:,1:-2], markersize=3.5)
# axes[0].plot(values[:,1],'b',markersize=3.5)
# axes[0].plot(values[:,0],'r--',markersize=3.5)
# axes[0].set_yscale("log")
# axes[0].set_xlabel("number of iterations")
# # axes[0].legend(names[1:-2] + [names[0]])
# axes[0].legend([names[1]] + [names[0]])
# axes[0].set_title("Losses")

# axes[1].plot(values[:,-2:], markersize=3.5)
# axes[1].legend(names[-2:])
# axes[1].set_xlabel("number of iterations")
# axes[1].set_title("Evolution of A and P")
# axes[1].set_yscale("log")

# fig.tight_layout()

In [None]:
# Plot every recorded training curve. Each field of the `loss` structured array is
# routed to one of four panels according to its name.
mark_space = 20
marker_list = ["-o", "-s", "->", "-<", "-^", "-v", "-d"]

fig, axes = plt.subplots(1, 4, figsize=(15, 3))

# (regex, panel index) pairs, tried in order; the first matching pattern wins.
panel_rules = [
    (r".*(loss)", 0),
    (r"^(rel)", 1),
    (r"^(ang)", 2),
    (r"^(mse)", 3),
]

names = list(loss.dtype.names)
for j, name in enumerate(names):
    panel = next((idx for pattern, idx in panel_rules if re.match(pattern, name)), None)
    if panel is None:
        continue  # field does not belong to any panel
    ax = axes[panel]
    ax.plot(
        loss[name],
        marker_list[j % len(marker_list)],
        markersize=3.5,
        label=name,
        markevery=mark_space,
        linewidth=2,
    )
    if panel == 0:
        # Only the loss panel uses a logarithmic y-axis.
        ax.set_yscale("log")
    ax.legend()
    ax.set_xlabel("number of iterations")

cols = ["Losses", "Evolution of A and P", "Angles", "MSE"]

for ax, col in zip(axes, cols):
    ax.set_title(col, fontsize=fontsize)

fig.tight_layout()