In [None]:
# Automatically reload edited modules (e.g. snmfem) before each cell executes.
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inline below the cells that create them.
%matplotlib inline

In [None]:
import numpy as np
import multiprocessing
from snmfem.updates import dichotomy_simplex
import tqdm

In [None]:
# Parameters of the dichotomy_simplex stress test below.
TIMEOUT = 5  # seconds per call for the (currently disabled) timeout harness
k, p = 5, 6400  # shape (k, p) of each random num / denum matrix
# Candidate magnitudes for the entries: the 17 powers of ten 1e-8 ... 1e8.
span = np.logspace(-8, 8, 17)
# Number of random trials. NOTE(review): this name shadows the builtin `iter`;
# it is kept because the loop cell reads it.
iter = 100

In [None]:
# Stress test: call dichotomy_simplex on random num / denum matrices whose
# entries span many orders of magnitude (magnitudes drawn from `span`, times a
# uniform factor in [0, 1)). The RNG is seeded so any failing draw can be replayed.
rng = np.random.default_rng(0)
for _ in tqdm.tqdm(range(iter)):
    scale_num = rng.choice(span, size=(k, p))
    num = scale_num * rng.random((k, p))
    scale_denum = rng.choice(span, size=(k, p))
    denum = scale_denum * rng.random((k, p))
    dichotomy_simplex(num, denum)
# If a call hangs, interrupt the kernel: `num` and `denum` still hold the inputs
# of the offending call. For an automatic cut-off, run each call in a
# `multiprocessing.Process`, `join(TIMEOUT)`, and terminate it if still alive.

# Attempt at an interactive plot of the SpIm

In [None]:
import hyperspy.api as hs 
import numpy as np
import snmfem.conf as conf
from pathlib import Path
import json

# Read the JSON description of the dataset used in the rest of this notebook.
data_json = conf.SCRIPT_CONFIG_PATH / Path("dataset_EDXS_small.json")
with open(data_json, "r") as f:
    data_dict = json.load(f)

# Indices of the samples to load; each one is
# <DATASETS_PATH>/<data_folder>/sample_<index>.npz.
num_list = [0]
data_folder = conf.DATASETS_PATH / Path(data_dict["data_folder"])
samples = [data_folder / Path(f"sample_{index}.npz") for index in num_list]

def load_data(sample):
    """Load one generated sample from an ``.npz`` archive and flatten it.

    NOTE: a later cell imports ``load_data`` from ``snmfem.experiments``, which
    shadows this definition (that version returns five values, not four).

    Parameters
    ----------
    sample : str or path-like
        Path to an ``.npz`` archive containing the arrays ``X`` (3-D,
        (nx, ny, ns)), ``densities``, ``phases``, ``N``, ``weights`` (3-D,
        (nx, ny, k)) and ``G``.

    Returns
    -------
    Xflat : ndarray, shape (ns, nx*ny)
        ``X`` with one column per pixel.
    true_spectra_flat : ndarray
        ``densities[:, None] * phases * N`` (presumably (k, ns); TODO confirm).
    true_maps_flat : ndarray, shape (k, nx*ny)
        ``weights`` with one row per map.
    G : ndarray
        The ``G`` array from the archive, returned unchanged.
    """
    # Use the archive as a context manager so its file handle is closed;
    # every array is read into memory by ``data[key]`` before the block exits.
    with np.load(sample) as data:
        X = data["X"]
        densities = data["densities"]
        phases = data["phases"]
        N = data["N"]
        true_maps = data["weights"]
        G = data["G"]

    nx, ny, ns = X.shape
    # (nx, ny, ns) -> (ns, nx, ny) -> (ns, nx*ny): one spectrum per column.
    Xflat = X.transpose([2, 0, 1]).reshape(ns, nx * ny)
    true_spectra_flat = np.expand_dims(densities, axis=1) * phases * N
    k = true_maps.shape[2]
    # (nx, ny, k) -> (k, nx, ny) -> (k, nx*ny): one map per row.
    true_maps_flat = true_maps.transpose([2, 0, 1]).reshape(k, nx * ny)
    return Xflat, true_spectra_flat, true_maps_flat, G



In [None]:
# Rebuild the 3-D (nx, ny, e_size) array from the flattened data of the first
# sample (transpose undoes the one-pixel-per-column layout), then wrap it as a
# hyperspy 1D signal.
e_size = data_dict["model_parameters"]["e_size"]
nx, ny = data_dict["weights_parameters"]["shape_2D"]
X = np.transpose(load_data(samples[0])[0]).reshape((nx, ny, e_size))
S = hs.signals.Signal1D(X)

In [None]:
# Interactive exploration of S: a rectangular region of interest (ROI) selects
# pixels, and the spectrum summed over those pixels is plotted and kept updated.
# Positional args are presumably (left, top, right, bottom) in navigation-axis
# units -- TODO confirm against the hyperspy RectangularROI signature.
roi=hs.roi.RectangularROI(3,3,18,20)
S.plot(navigator="auto")
# Attach the ROI to S; spim_ROI is presumably the part of S inside the
# rectangle. (Assumes S is plotted first so the ROI widget shows on the navigator.)
spim_ROI=roi.interactive(S)
    
# Re-run spim_ROI.sum whenever any axis of spim_ROI changes, so the summed
# spectrum follows the rectangle as it is moved or resized.
sum_ROI=hs.interactive(spim_ROI.sum,
               event=spim_ROI.axes_manager.events.any_axis_changed,
               recompute_out_event=None)
sum_ROI.plot()

# Plot results

In [None]:
from snmfem.conf import RESULTS_PATH, DATASETS_PATH
from snmfem.experiments import load_data
from pathlib import Path
import matplotlib.pyplot as plt
# Saved results of the best-lambda run, plus the dataset sample it was fitted on.
file = "best_lambda.npz"
dataset_path = DATASETS_PATH / Path("aspim037_N100_2ptcls_brstlg/sample_4.npz")
results_file = Path(file)
path = RESULTS_PATH / results_file
# NOTE(review): this NpzFile stays open for the rest of the session; the next
# cells read arrays from it by key.
data = np.load(path)
# This `load_data` is the snmfem.experiments version, which returns five values
# including the 2D map shape.
X, true_spectra, true_maps, G, shape_2d = load_data(dataset_path)

In [None]:
# Pull the fitted arrays and the recorded metrics / losses out of the results archive.
GP, A, metrics, losses = (data[key] for key in ("GP", "A", "metrics", "losses"))

In [None]:

# One row per component, with three panels: true vs reconstructed spectrum,
# reconstructed map, and true map. The row count is taken from the data instead
# of a hard-coded 3, capped by the smallest available count so no index overruns.
n_components = min(A.shape[0], GP.shape[1], len(true_spectra), len(true_maps))
# squeeze=False keeps `axes` 2-D even when there is a single component.
fig, axes = plt.subplots(n_components, 3, figsize=(50, 50 * n_components / 3), squeeze=False)
for i in range(n_components):
    axes[i, 0].plot(true_spectra[i], 'bo', label='truth', markersize=5)
    axes[i, 0].plot(GP[:, i], 'r-', label='reconstructed', linewidth=2)
    axes[i, 1].imshow(A[i].reshape(shape_2d))
    axes[i, 2].imshow(true_maps[i].reshape(shape_2d))
axes[0, 0].legend(fontsize=22)
axes[0, 0].set_title("True vs reconstructed spectra", fontsize=22)
axes[0, 1].set_title("Reconstructed maps", fontsize=22)
# Trailing ';' suppresses the Text repr output.
axes[0, 2].set_title("True maps", fontsize=22);
    

In [None]:
# One panel per field of the structured `losses` array, titled with the field name.
loss_names = losses.dtype.names
# squeeze=False keeps `axes` 2-D, so a single loss field still works.
fig, axes = plt.subplots(1, len(loss_names), figsize=(50, 10), squeeze=False)
for ax, key in zip(axes[0], loss_names):
    ax.plot(losses[key])
    ax.set_title(key, fontsize=22)

In [None]:
import snmfem.measures as m
import numpy as np 
import matplotlib.pyplot as plt 

# Compare m.square_distance(a, c) with m.square_distance(a, 0) on random data.
# The second is presumably the squared norms of a's rows; TODO confirm.
# Seeded so that the figures below are reproducible.
rng = np.random.default_rng(0)
a = 130 * rng.random((3, 256))
c = 23 * rng.random((3, 256))
ac = m.square_distance(a, c)
b = m.square_distance(a, np.zeros_like(a))
# Element-wise ratio; assumes b has no zero entries.
x = ac / b

In [None]:
# Distance matrix between a and c.
plt.imshow(ac)
plt.colorbar()
plt.title("square_distance(a, c)");

In [None]:
# Distance matrix between a and the zero array.
plt.imshow(b)
plt.colorbar()
plt.title("square_distance(a, 0)");

In [None]:
# Element-wise ratio of the two distance matrices.
plt.imshow(x)
plt.colorbar()
plt.title("ac / b");

In [None]:
# Overwrite b[i, i] with 1 for every row index i, in place.
# NOTE(review): this runs after `x = ac / b` was computed, so x is unaffected;
# if the intent was to avoid dividing by the diagonal, it belongs before that line.
b[np.arange(b.shape[0]), np.arange(b.shape[0])] = 1