In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell execution, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from espm.conf import log_shift, dicotomy_tol, sigmaL
from espm.utils import create_laplacian_matrix
from espm.estimators import SmoothNMF
from espm.models import ToyModel
from copy import deepcopy
from espm.weights import generate_weights as gw
from espm.datasets.base import generate_spim_sample
from espm.estimators.updates import initialize_algorithms



In [None]:
# Experiment flags; in this cell they are only used to build the name of the
# results file to load.
laplacian = True
noise = True
force_simplex = True

# Load the saved results.
filename = f"losses_{laplacian}_{noise}_{force_simplex}.npz"

# NOTE(security): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code. Only open result files that you generated yourself.
#
# np.load on an .npz archive returns a lazy NpzFile that holds an open file
# handle; the context manager closes it once every array has been read into
# memory (indexing the archive materialises the array).
with np.load(filename, allow_pickle=True) as data:
    losses = data["losses"]
    l_infty = data["l_infty"]
    params = data["params"]
    captions = data["captions"]
    true_D = data["true_D"]
    true_H = data["true_H"]
    X = data["X"]
    Xdot = data["Xdot"]
    W = data["W"]
    H = data["H"]
    gammas = data["gammas"]

In [None]:
# l_infty is reshaped to (n, 1, 1) so that it broadcasts over the two trailing
# axes of `losses`; the shifted losses are then averaged over the first axis.
losses_mean = (
    losses - l_infty.reshape(-1, 1, 1)
).mean(axis=0)


In [None]:
# Convergence plot: one curve per row of `losses_mean`, labelled with the
# matching entry of `captions`, on log-log axes.
fig, ax = plt.subplots(figsize=(10, 6))

for loss, caption in zip(losses_mean, captions):
    iterations = np.arange(len(loss)) + 1
    # Curves with 10 points or fewer are not drawn.
    if len(iterations) > 10:
        ax.plot(iterations, loss, ".-", label=caption)

# Switch to log scales *before* fixing the limits so the limits are validated
# against the log transform.
ax.set_xscale("log")
ax.set_yscale("log")

# losses_mean is a difference (losses - l_infty), so it may contain zeros or
# negative values, which are not representable on a log axis. Use only the
# positive, finite values to set the y-range; if there are none, keep
# matplotlib's automatic limits.
loss_values = np.asarray(losses_mean, dtype=float)
positive_values = loss_values[np.isfinite(loss_values) & (loss_values > 0)]
if positive_values.size:
    ax.set_ylim(positive_values.min(), positive_values.max())
ax.set_xlim(1, losses_mean.shape[1])

ax.set_xlabel("Iterations")
ax.set_ylabel("Mean (loss - l_infty)")
ax.set_title("Mean loss gap per iteration")
ax.legend()
plt.show()

In [None]:
# Disabled diagnostic: plots each `gammas` sequence against the iteration number.
# for gamma, caption in zip(gammas, captions):
#     iterations = np.arange(len(gamma))+1

#     plt.plot(iterations, gamma, ".", label=caption)
# plt.yscale("log")
# plt.legend()


In [None]:
# Side-by-side comparison of the estimated maps (top row) and the ground-truth
# maps (bottom row), one column per component.

# `k` and `shape_2d` are not defined anywhere earlier in this notebook, so they
# are derived from the loaded arrays instead of relying on leftover kernel state.
k = H.shape[0]  # assumes H is stored as (n_components, n_pixels) — TODO confirm
shape_2d = None  # set to (n_rows, n_cols) here when the maps are not square
if shape_2d is None:
    n_pixels = H.size // k
    side = int(round(np.sqrt(n_pixels)))
    if side * side != n_pixels:
        raise ValueError(
            f"Cannot infer a square map from {n_pixels} pixels per component; "
            "set shape_2d = (n_rows, n_cols) explicitly."
        )
    shape_2d = (side, side)  # assumes square maps — TODO confirm

Hmat = H.reshape(k, shape_2d[0], shape_2d[1])
Hmat_true = true_H.reshape(k, shape_2d[0], shape_2d[1])

scale = 4  # size, in inches, of each subplot
cmap = plt.cm.viridis

# squeeze=False keeps `axes` two-dimensional even when k == 1.
fig, axes = plt.subplots(2, k, figsize=(scale * k, 2 * scale), squeeze=False)
for row, (maps, label) in enumerate(((Hmat, "Estimated"), (Hmat_true, "True"))):
    for i in range(k):
        ax = axes[row, i]
        # Fixed colour range [0, 1] so all panels share the same scale.
        im = ax.imshow(maps[i], cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"{label} H {i}")
        ax.axis("off")
        fig.colorbar(im, ax=ax)

fig.tight_layout()
