In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from snmfem.experiments import perform_simulations, load_samples, print_results, load_data, run_experiment, build_exp, fill_exp_dict
import json
import re
import warnings
import snmfem.conf as c
from pathlib import Path
warnings.simplefilter(action='ignore', category=FutureWarning)

# I. Designing an experiment with synthetic data
## JSON file describing the input data

In [None]:
# Name of the JSON file describing the synthetic dataset; load_samples resolves it
# and returns the samples together with `k`, the number of phases (one plot panel
# per phase in the figures below).
dataset = "EDXS_data_rand_phases.json"
samples, k = load_samples(dataset)

## Parameters of the experiment

In [None]:
# Algorithm to run, plus the label under which the experiment is stored/reported.
method = dict(
    name="A_cool_name",
    method="SmoothNMF",
)

# Hyper-parameters for the algorithm. fill_exp_dict presumably completes them
# with defaults for anything not listed here — TODO confirm.
# NOTE(review): max_iter is very small, presumably to keep this demo fast.
params = dict(
    lambda_L=1.0,
    tol=1e-8,
    max_iter=5,
)

exp = build_exp(k, method, fill_exp_dict(params), name=method["name"])

## Saving the parameters to a JSON file

In [None]:
# Store the experiment together with the dataset it targets, so Part II can
# rebuild the run from this file alone.
experiment_filename = method["name"] + "_experiment.json"
output_file = c.SCRIPT_CONFIG_PATH / Path(experiment_filename)
json_dict = {"experiment": exp, "dataset": dataset}

with open(output_file, "w") as f:
    json.dump(json_dict, f, indent=4)




# II. Performing the experiment from the JSON file
## Load the data

In [None]:
# Part II only depends on the experiment file on disk, not on Part I's variables.
json_filename = c.SCRIPT_CONFIG_PATH / Path("A_cool_name_experiment.json")
with open(json_filename) as f:
    exp_dict = json.load(f)

# Reload the samples named inside the experiment file.
samples, k = load_samples(exp_dict["dataset"])

## Perform the experiment on all samples

In [None]:
# Run the stored experiment over every sample; the result feeds print_results below.
experiments = [exp_dict["experiment"]]
metrics = perform_simulations(samples, experiments)

In [None]:
# print_results returns the formatted results, which are then printed as text.
results_report = print_results([exp_dict["experiment"]], metrics)
print(results_report)

## Perform the experiment on a single sample

In [None]:
# Index of the sample to inspect in detail.
sample_num = 4
s = samples[sample_num]
Xflat, true_spectra, true_maps, G, shape_2d = load_data(s)

m, (GP, A), loss = run_experiment(
    Xflat,
    true_spectra,
    true_maps,
    G,
    exp_dict["experiment"],
    shape_2d=shape_2d,
)
# The last entry of `m` is the phase matching order used by the plots below;
# every entry before it is a metric.
order = m[-1]
metric = m[:-1]

### Plot parameters

In [None]:
# Ploting parameters
fontsize = 15
aspect_ratio = 3/4
scale = 20
marker_list = ["-o","-s","->","-<","-^","-v","-d"]
mark_space = 20

## Plot the resulting spectra

In [None]:
# One panel per phase: the ground-truth spectrum against the reconstructed
# spectrum (a column of GP) it was matched with via order[0].
# squeeze=False keeps a 2-D grid even when k == 1 (plain subplots would return a
# single Axes and break `axes[j]`); keep its only row.
fig, axes = plt.subplots(1, k, figsize=(scale, scale / k * aspect_ratio), squeeze=False)
axes = axes[0]
for j in range(k):
    # NOTE(review): 'bo' draws markers only, so linewidth has no visible effect,
    # and 'r-' draws no markers, so markersize has none either — confirm intent.
    axes[j].plot(true_spectra[j], 'bo', label='truth', linewidth=4)
    axes[j].plot(GP[:, order[0][j]], 'r-', label='reconstructed', markersize=3.5)

cols = ['Phase {}'.format(col) for col in range(k)]
row = exp_dict["experiment"]["name"]

# Put the phase label and its angle metric in one title, so the phase label does
# not overwrite the angle.
for j, (ax, col) in enumerate(zip(axes, cols)):
    ax.set_title("{}\n{:.2f} deg".format(col, metric[0][j]), fontsize=fontsize)

axes[0].set_ylabel(row, rotation=90, fontsize=fontsize)
# The curves carry labels; show them once rather than on every panel.
axes[0].legend()

fig.tight_layout()

plt.show()

## Plot the resulting maps

In [None]:
# Top row: ground-truth maps; bottom row: the reconstructed maps (rows of A)
# matched to them via order[1].
# squeeze=False keeps `axes` 2-D even when k == 1, so axes[row, col] always works.
fig, axes = plt.subplots(2, k, figsize=(scale, scale / k * 2 * aspect_ratio), squeeze=False)
# Both rows share one colour scale fixed by the ground truth, so they are
# directly comparable; reconstructed values above vmax saturate at the top colour.
vmin = 0
vmax = np.max(true_maps)
cmap = plt.cm.gist_heat_r
for j in range(k):
    # Each flattened map is reshaped back to the 2-D image grid.
    axes[0, j].imshow(true_maps[j].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
    axes[1, j].imshow(A[order[1][j], :].reshape(*shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
    axes[1, j].set_title("Mse: {:.2f}".format(metric[1][j]))

cols = ['MAP {}'.format(col) for col in range(k)]
rows = ["Ground truth"] + [exp_dict["experiment"]["name"]]

for ax, col in zip(axes[0], cols):
    ax.set_title(col, fontsize=fontsize)

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, rotation=90, fontsize=fontsize)

fig.tight_layout()

plt.show()

## Plot the losses

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 4))

# Route each field of the structured `loss` record to one panel by its name.
# Patterns are tried in order and the first match wins (like an if/elif chain);
# fields matching no pattern are not plotted.
panel_patterns = [
    r".*(loss)",  # panel 0: names containing "loss"
    r"^(rel)",    # panel 1: names starting with "rel"
    r"^(ang)",    # panel 2: names starting with "ang"
    r"^(mse)",    # panel 3: names starting with "mse"
]
used_panels = set()

names = list(loss.dtype.names)
for j, name in enumerate(names):
    panel = next(
        (p for p, pattern in enumerate(panel_patterns) if re.match(pattern, name)),
        None,
    )
    if panel is None:
        continue
    # Markers cycle over marker_list using the field's position in the record.
    axes[panel].plot(
        loss[name],
        marker_list[j % len(marker_list)],
        markersize=3.5,
        label=name,
        markevery=mark_space,
        linewidth=2,
    )
    used_panels.add(panel)

# Decorate each panel once, after all its curves are drawn.
for panel in sorted(used_panels):
    axes[panel].legend()
    axes[panel].set_xlabel("number of iterations")
if 0 in used_panels:
    # The loss panel uses a logarithmic y-axis.
    axes[0].set_yscale("log")

cols = ["Losses", "Evolution of A and P", "Angles", "MSE"]
row = exp_dict["experiment"]["name"]

for ax, col in zip(axes, cols):
    ax.set_title(col, fontsize=fontsize)

axes[0].set_ylabel(row, rotation=90, fontsize=fontsize)

fig.tight_layout()