In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from snmfem.experiments import perform_simulations, load_samples, print_results, load_data, run_experiment, build_exp, fill_exp_dict
import re

In [None]:
import warnings
# Ignore every FutureWarning from here on (the filter is process-wide),
# so such warnings do not clutter the outputs of the experiment cells below.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Running NMF

## Inputs

In [None]:
# Load the sample descriptions of the dataset. `k` is used below as the number
# of phases/components; `g` and `mod` are forwarded as `g_pars` / `mod_pars`.
dataset = "dataset_3rdn_phases.json"
samples, k, g, mod = load_samples(dataset)

# Offset of the first sample to keep (0 keeps every sample).
# NOTE: despite its name, `n_samples` is a starting index, not a count.
n_samples = 0
samples = samples[slice(n_samples, None)]





In [None]:
# Evaluation options forwarded to `perform_simulations` / `run_experiment`.
params_evalution = {
    "u" : True,
}

# All parameters are contained here. Each entry is a pair
# (experiment description, estimator parameters): the description carries the
# "name" and "method" keys, the estimator parameters go through `fill_exp_dict`.
# Every entry (active or commented) ends with a trailing comma so that toggling
# a line never glues two tuples together into an accidental call.
exp_list_input = [
#     ({"name": "scikit L2", "method": "SKNMF"},{"max_iter" : 20}),
    # NOTE(review): the name says "smooth3" but lambda_L is 2 -- confirm which is intended.
    ({"name": "smooth3", "method": "SmoothNMF"}, {"max_iter" : 100000, "tol" : 1e-9,"lambda_L" : 2}),
    # ({"name": "smooth05", "method": "NMF"}, {"max_iter" : 100000, "tol" : 1e-9,"lambda_L" : 0.5}),
    ({"name": "smooth1", "method": "SmoothNMF"}, {"max_iter" : 100000, "tol" : 1e-9,"lambda_L" : 1}),
    # ({"name": "smooth01", "method": "NMF"}, {"max_iter" : 100000, "tol" : 1e-9,"lambda_L" : 0.01}),
    ({"name": "smooth1mu1", "method": "SmoothNMF"}, {"max_iter" : 100000, "tol" : 1e-9,"lambda_L" : 1, "mu" : 1}),
    # ({"name": "scikit KL", "method": "SKNMF"},{"beta_loss" : "kullback-leibler","max_iter" : 10000, "tol" : 1e-9}),
#     ({"name": "snmfem_noG_noS", "method": "NMF"},{"skip_G" : True, "force_simplex" : False,"max_iter" : 20}),
#     ({"name": "snmfem no G", "method": "NMF"}, {"skip_G" : True,"max_iter" : 20}),
#     ({"name": "snmfem L2", "method": "NMF"}, {"l2": True,"max_iter" : 20}),
#     ({"name": "snmfem smooth 3", "method": "SmoothNMF"},{"lambda_L" : 3.0,"max_iter" : 20}),
#     ({"name": "snmfem smooth 30", "method": "SmoothNMF"},{"lambda_L" : 30.0,"max_iter" : 20}),
#     ({"name": "snmfem smooth 300", "method": "SmoothNMF"},{"lambda_L" : 300.0,"max_iter" : 20}),
#     ({"name": "snmfem L2 smooth 30", "method": "SmoothNMF"},{"l2": True, "lambda_L" : 30.0,"max_iter" : 20}),
]

# Build one experiment per configuration, labelled with the description's "name".
exp_list = [
    build_exp(k, exp_desc, fill_exp_dict(est_params), name=exp_desc["name"])
    for exp_desc, est_params in exp_list_input
]

In [None]:
# Run every configured experiment on every loaded sample and collect the metrics.
metrics_all = perform_simulations(
    samples,
    exp_list,
    params_evalution,
    G_func=True,
    g_pars=g,
    mod_pars=mod,
)

In [None]:
# The value returned by `print_results` is printed here; presumably it is a
# formatted summary of `metrics_all` -- TODO confirm.
results_summary = print_results(exp_list, metrics_all)
print(results_summary)

# Running a single experiment and plotting the results

In [None]:
# Re-run every experiment on a single sample, keeping the factors, metrics,
# matching orders and loss histories for the plotting cells below.
sample_num = 0

s = samples[sample_num]
Xflat, true_spectra, true_maps, G, shape_2d = load_data(s, G_func=True)

Gs, Ps, As = [], [], []
metrics, orders, losses = [], [], []

for exp in exp_list:
    print(exp["name"])
    # Unpack the estimates into their own names: writing them back into `G`
    # would replace the matrix returned by `load_data`, so every experiment
    # after the first would be fed the previous experiment's estimate.
    m, (G_est, P_est, A_est), loss = run_experiment(
        Xflat, true_spectra, true_maps, G, exp, params_evalution, shape_2d,
        g_pars=g, mod_pars=mod,
    )
    # The last entry of `m` is the component order used by the plots below to
    # match estimated to true components; the remaining entries are the metrics.
    metrics.append(m[:-1])
    orders.append(m[-1])
    Gs.append(G_est)
    Ps.append(P_est)
    As.append(A_est)
    losses.append(loss)
metrics = np.array(metrics)
orders = np.array(orders)

In [None]:
# Plotting parameters shared by the figure cells below.
fontsize = 30            # font size of the row / column headers
aspect_ratio = 3 / 4     # subplot height relative to its width
scale = 20               # total figure width, in inches
marker_list = ["-o", "-s", "->", "-<", "-^", "-v", "-d"]  # cycled across curves
mark_space = 20          # draw a marker every `mark_space` points (markevery)



In [None]:
# Compare each true spectrum with the matched reconstructed spectrum (columns of
# G @ P), one row per experiment and one column per phase.
n_exp = len(exp_list)
# squeeze=False keeps `axes` 2-D even with a single experiment or a single phase.
fig, axes = plt.subplots(n_exp, k, figsize=(scale, scale / k * n_exp * aspect_ratio), squeeze=False)

cols = ['Phase {}'.format(col) for col in range(k)]
rows = ['{}'.format(exp["name"]) for exp in exp_list]

for i, exp in enumerate(exp_list):
    # The reconstruction does not depend on the phase: compute it once per experiment.
    reconstruction = Gs[i] @ Ps[i]
    for j in range(k):
        ind = np.arange(k)[orders[i, 0, j]]
        axes[i, j].plot(true_spectra[j], 'bo', label='truth', linewidth=4)
        axes[i, j].plot(reconstruction[:, ind], 'r-', label='reconstructed', markersize=3.5)
        angle_title = "{:.2f} deg".format(metrics[i, 0, j])
        if i == 0:
            # The first row also carries the column header; combine both instead
            # of letting the header overwrite this experiment's angle.
            axes[i, j].set_title("{}\n{}".format(cols[j], angle_title), fontsize=fontsize)
        else:
            axes[i, j].set_title(angle_title)

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=90, fontsize=fontsize)

fig.tight_layout()

plt.show()


In [None]:
# Show the true abundance maps (first row) above the matched maps estimated by
# each experiment, all on the same colour scale.
n_exp = len(exp_list)
n_rows = n_exp + 1  # ground-truth row + one row per experiment
# squeeze=False keeps `axes` 2-D even when there is a single phase (k == 1).
fig, axes = plt.subplots(n_rows, k, figsize=(scale, scale / k * n_rows * aspect_ratio), squeeze=False)
vmin = 0
vmax = np.max(true_maps)
cmap = plt.cm.gist_heat_r
for j in range(k):
    axes[0, j].imshow(true_maps[j].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
    for i, exp in enumerate(exp_list):
        ind = np.arange(k)[orders[i, 1, j]]
        axes[i + 1, j].imshow(As[i][ind].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
        axes[i + 1, j].set_title("Mse: {:.2f}".format(metrics[i, 1, j]))

cols = ['MAP {}'.format(col) for col in range(k)]
rows = ["Ground truth"] + ['{}'.format(exp["name"]) for exp in exp_list]

for ax, col in zip(axes[0], cols):
    ax.set_title(col, fontsize=fontsize)

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=90, fontsize=fontsize)

fig.tight_layout()

plt.show()

# Inspecting the evolution of the losses and metrics

In [None]:
# Plot, for every experiment that returned a loss history, the evolution of each
# recorded field, dispatched to one of four columns by its name.
indexes = np.array([i for i, l in enumerate(losses) if l is not None])
n_plot = len(indexes)

# (regex, column) pairs tried in order; a field goes to the first matching
# column and fields matching none of them are not plotted.
loss_columns = [
    (r".*(loss)", 0),
    (r"^(rel)", 1),
    (r"^(ang)", 2),
    (r"^(mse)", 3),
]
cols = ["Losses", "Evolution of A and P", "Angles", "MSE"]
rows = ['{}'.format(exp_list[i]["name"]) for i in indexes]

if n_plot == 0:
    # plt.subplots(0, ...) would raise; report instead of crashing.
    print("No experiment recorded a loss history: nothing to plot.")
else:
    # squeeze=False keeps `axes` 2-D even when a single experiment is plotted.
    fig, axes = plt.subplots(n_plot, len(cols), figsize=(20, 4 * n_plot), squeeze=False)

    for it, i in enumerate(indexes):
        used_cols = set()
        for j, name in enumerate(losses[i].dtype.names):
            col = next((c for pattern, c in loss_columns if re.match(pattern, name)), None)
            if col is None:
                continue
            axes[it, col].plot(
                losses[i][name],
                marker_list[j % len(marker_list)],
                markersize=3.5,
                label=name,
                markevery=mark_space,
                linewidth=2,
            )
            used_cols.add(col)

        # Decorate only the axes that received curves (a legend on an empty
        # axis would only emit a "no artists" warning).
        for col in used_cols:
            axes[it, col].legend()
            axes[it, col].set_xlabel("number of iterations")
        if 0 in used_cols:
            axes[it, 0].set_yscale("log")

    for ax, col in zip(axes[0], cols):
        ax.set_title(col, fontsize=fontsize)

    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, fontsize=fontsize)

    fig.tight_layout()