In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from esmpy.datasets.generate_weights import generate_weights
from esmpy.datasets.generate_EDXS_phases import generate_random_phases, unique_elts, generate_modular_phases
from esmpy.datasets import generate_dataset
from esmpy.utils import arg_helper
import matplotlib.pyplot as plt
import numpy as np
from esmpy.conf import DEFAULT_SYNTHETIC_DATA_DICT



# Generate small but realistic data

In [None]:
# Synthetic-data configuration: number of phases, N and a fixed seed.
seed = 0
n_phases = 3
N = 200

# Abundance maps: "sphere" weights on a 30 x 30 grid.
weights_dict = dict(
    weight_type="sphere",
    shape_2d=[30, 30],
    weights_params={"radius": 1.2},
)

# Elemental make-up of each of the three phases (element symbol -> value).
elts_dicts = [
    {"Mg": 0.245, "Fe": 0.035, "Ca": 0.031, "Si": 0.219, "Al": 0.024,
     "O": 0.446, "Cu": 0.05, "Hf": 0.01},
    {"Mg": 0.522, "Fe": 0.104, "O": 0.374, "Cu": 0.05},
    {"Mg": 0.020, "Fe": 0.018, "Ca": 0.188, "Si": 0.173, "Al": 0.010,
     "O": 0.591, "Ti": 0.004, "Cu": 0.05, "Sm": 0.007, "Lu": 0.006, "Nd": 0.006},
]

# One (b0, b1) pair per phase, presumably bremsstrahlung background parameters.
brstlg_pars = [
    {"b0": 0.0003458, "b1": 0.0006268},
    {"b0": 0.0001629, "b1": 0.0009812},
    {"b0": 0.0007853, "b1": 0.0003658},
]

# Spectral model settings forwarded to generate_modular_phases.
model_params = {
    "e_offset": 0.3,
    "e_size": 1000,
    "e_scale": 0.01,
    "width_slope": 0.01,
    "width_intercept": 0.065,
    "db_name": "default_xrays.json",
    "E0": 200,
    "params_dict": {
        "Abs": {
            "thickness": 100.0e-7,
            "toa": 35,
            "density": 4.5,
            "atomic_fraction": False,
        },
        "Det": "SDD_efficiency.txt",
    },
}

phases, full_dict = generate_modular_phases(
    elts_dicts=elts_dicts,
    brstlg_pars=brstlg_pars,
    scales=[1, 1, 1],
    model_params=model_params,
    seed=seed,
)

np.random.seed(seed)
name = "small_FpBrgCaPv_N30"
# Later entries override earlier ones, exactly as successive dict.update calls would.
data_dict = {
    "N": N,
    "densities": [1.0, 0.8, 1.2],
    "data_folder": name,
    "seed": seed,
    **weights_dict,
    **full_dict,
}

# NOTE(review): input_dict is never used; generate_dataset receives data_dict.
# Confirm whether input_dict (data_dict merged with the defaults) was meant to be passed instead.
input_dict = arg_helper(data_dict, DEFAULT_SYNTHETIC_DATA_DICT)

generate_dataset(**data_dict, seeds_range=5)

In [None]:
from esmpy.conf import DATASETS_PATH
from pathlib import Path
import hyperspy.api as hs

# Reload one generated sample from disk as a HyperSpy signal.
name = "small_FpBrgCaPv_N30"
number = 0
p = DATASETS_PATH / Path(name) / f"sample_{number}.hspy"
spim = hs.load(str(p))


In [None]:
from esmpy.estimators import SmoothNMF

# Regularised 3-component NMF (lambda_L=10, mu=5) on the spectrum image, using the
# G matrix built from the sample. An unregularised baseline uses lambda_L=0, mu=0.
est = SmoothNMF(
    n_components=3,
    lambda_L=10,
    mu=5,
    epsilon_reg=0.1,
    hspy_comp=True,
    init="nndsvdar",
    G=spim.build_G(),
    tol=1e-6,
    force_simplex=True,
    accelerate=False,
)
spim.decomposition(algorithm=est)

In [None]:
# Simplex check: total |(sum over components) - 1| across all pixels. This is near 0 when each
# pixel's loadings sum to one, which force_simplex=True presumably enforces (confirm).
# The loadings are read here because H0 is only assigned in a later cell, so relying on it
# raised NameError on a fresh kernel.
loadings_flat = spim.get_decomposition_loadings().data
loadings_flat = loadings_flat.reshape(loadings_flat.shape[0], -1)
np.sum(np.abs(np.sum(loadings_flat, 0) - 1))

In [None]:
# Learned loadings (flattened to 3 x pixels) and factors of the decomposition.
loadings_signal = spim.get_decomposition_loadings()
W0 = spim.get_decomposition_factors().data
H0 = loadings_signal.data.reshape(3, -1)

In [None]:
# spim.plot_decomposition_loadings(3)

In [None]:
# Top row: spim.maps_2d (presumably the true abundance maps); bottom row: NMF loadings.
# Both rows share a [0, 1] colour scale. NMF components may come out in a different order than the truth.
loadings_maps = spim.get_decomposition_loadings().data  # hoisted: was recomputed every iteration
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for i in range(3):
    im = axes[0, i].imshow(spim.maps_2d[:, :, i], vmin=0, vmax=1, cmap=plt.cm.hot_r)
    axes[0, i].set_title(f"maps_2d[:, :, {i}]")
    fig.colorbar(im, ax=axes[0, i])
    im = axes[1, i].imshow(loadings_maps[i], vmin=0, vmax=1, cmap=plt.cm.hot_r)
    axes[1, i].set_title(f"NMF loading {i}")
    fig.colorbar(im, ax=axes[1, i])


In [None]:
# Data and ground-truth quantities stored with the synthetic sample.
X, Xdot, Hdot = spim.X, spim.Xdot, spim.maps
densities = spim.metadata.Truth.Params.densities

# Phase spectra scaled column-wise by the phase densities.
Wdot = spim.phases @ np.diag(densities)
# NOTE(review): build_G() appears to return a callable here; it is called again to get the
# matrix whose .shape is used below. Confirm against the spim API.
G = spim.build_G()()

In [None]:
# Total of the noisy data next to the total of Xdot divided by N.
(np.sum(X), np.sum(Xdot) / N)

In [None]:
# Residual between the model rebuilt from the ground truth (Hdot.T @ Wdot.T) and Xdot / N.
# Show its largest absolute entry instead of dumping the full matrix.
residual_truth = Hdot.T @ Wdot.T - Xdot.T / N
np.abs(residual_truth).max()

In [None]:
# Unconstrained least-squares estimate of W given the true maps Hdot.
# NOTE(review): the factor 30 is unexplained; confirm. W is overwritten by the next cell.
lstsq_solution = np.linalg.lstsq(Hdot.T, X.T * 30, rcond=None)[0]
W = lstsq_solution.T

In [None]:
from esmpy.updates import multiplicative_step_w

# Starting from all ones, apply repeated multiplicative W steps with H0 passed unchanged.
# Then map the result through G.
n_w_updates = 100
W = np.ones((G.shape[1], 3))
for i in range(n_w_updates):
    W = multiplicative_step_w(X, G, W, H0)
W = G @ W

In [None]:
# Recovered spectra |W| (solid) against the ground-truth Wdot (dashed), one panel per component.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, ax in enumerate(axes):
    ax.plot(np.abs(W[:, i]), label="|W| (recovered)")
    ax.plot(Wdot[:, i], "--", label="Wdot (truth)")
    ax.set_title(f"Component {i}")
    ax.legend()



In [None]:
def special_angle(vec_gd, vec_algo):
    """Angle between two spectra after a per-channel weighting derived from the reference.

    Both vectors are multiplied element-wise by ``1 / (mean(vec_gd) + vec_gd)``. This
    down-weights channels where the reference ``vec_gd`` is large. The angle between the
    two weighted vectors is then returned.

    Parameters
    ----------
    vec_gd : array_like, 1-D
        Reference (ground-truth) spectrum; it defines the weighting.
    vec_algo : array_like, 1-D, same length as ``vec_gd``
        Spectrum compared against the reference.

    Returns
    -------
    float
        Angle in degrees, in [0, 180].
    """
    vec_gd = np.asarray(vec_gd, dtype=float)
    vec_algo = np.asarray(vec_algo, dtype=float)
    weights = 1 / (np.mean(vec_gd) + vec_gd)
    a = vec_gd * weights
    b = vec_algo * weights
    # The helper `angle` previously called here is not defined anywhere in this notebook
    # (NameError), so compute the angle directly.
    cos_theta = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    # Rounding can push |cos| slightly above 1, which would make arccos return nan.
    return float(np.degrees(np.arccos(np.clip(cos_theta, -1.0, 1.0))))


In [None]:
# NMF factors |W0| (solid, rescaled for display) against the ground-truth Wdot (dashed).
# NOTE(review): the display scale 5 is ad hoc; confirm or derive it from the data.
w0_display_scale = 5
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, ax in enumerate(axes):
    ax.plot(np.abs(W0.T[:, i]) * w0_display_scale, label=f"|W0| x {w0_display_scale}")
    ax.plot(Wdot[:, i], "--", label="Wdot (truth)")
    ax.set_title(f"Component {i}")
    ax.legend()

In [None]:
def build_problematic_dataset(shape_2d, k, l, n_poisson=200):
    """Build a small random NMF test problem ``Xdot = W @ H`` plus a Poisson-noisy copy.

    Draws from NumPy's global random state (seed it beforehand for reproducibility).

    Parameters
    ----------
    shape_2d : sequence of two ints
        Spatial shape; the number of pixels is ``shape_2d[0] * shape_2d[1]``.
    k : int
        Number of components.
    l : int
        Number of channels; must be greater than ``k``.
    n_poisson : float, default 200
        Poisson scale: ``X = poisson(n_poisson * Xdot) / n_poisson``.

    Returns
    -------
    Xdot : ndarray, shape (l, n_pixels)
        Noiseless data ``W @ H``.
    X : ndarray, shape (l, n_pixels)
        Noisy data.
    W : ndarray, shape (l, k)
        Random non-negative matrix, with roughly 70 % of entries zeroed and ``W[c, c] = 1``.
    H : ndarray, shape (k, n_pixels)
        Non-negative abundances whose columns sum to 1.
    """
    assert l > k, "need more channels (l) than components (k)"

    W = np.random.rand(l, k)
    W[np.random.rand(l, k) > 0.3] = 0
    for c in range(k):
        W[c, c] = 1

    # Components 1..k-1 start in [0, 1/k), so their per-pixel sum is < 1 and component 0 stays >= 0.
    H = np.random.rand(k, *shape_2d) / k
    if k > 1:  # randint(k - 1) would raise for k == 1, and there is nothing to zero out then
        for _ in range(H.size // 2):
            row, col = np.random.randint(shape_2d[0]), np.random.randint(shape_2d[1])
            comp = np.random.randint(k - 1)
            H[comp + 1, row, col] = 0
    # Component 0 absorbs the remainder so every pixel's abundances sum to 1.
    H[0] = 1 - np.sum(H[1:], axis=0)

    H = H.reshape(k, -1)
    Xdot = W @ H

    X = 1 / n_poisson * np.random.poisson(n_poisson * Xdot)

    return Xdot, X, W, H

# Toy problem: 5 x 5 pixels, l = 10 channels, k = 3 components, Poisson scale 200.
shape_2d = [5, 5]
l = 10
k = 3
n_poisson = 200

# build_problematic_dataset draws from NumPy's global RNG; seed it so the toy data is reproducible.
toy_seed = 0
np.random.seed(toy_seed)
Xdot, X, W, H = build_problematic_dataset(shape_2d, k, l, n_poisson)

In [None]:
from esmpy.estimators import SmoothNMF

# Fit k components to the toy data without Laplacian smoothing (lambda_L=0).
# mu and epsilon_reg are set as for log_reg(H, mu=..., epsilon=...) used further down.
mu = 10
epsilon = 0.1
est = SmoothNMF(
    n_components=k,
    lambda_L=0,
    mu=mu,
    epsilon_reg=epsilon,
    init="nndsvdar",
    tol=1e-6,
    force_simplex=True,
    accelerate=False,
    debug=True,
)
W0 = est.fit_transform(X)
H0 = est.H_

In [None]:
# Shapes of the fitted factors.
(W0.shape, H0.shape)


In [None]:
from esmpy.measures import KL_loss_surrogate, KLdiv_loss, log_reg, log_surrogate

# Consistency checks on random data: with both H arguments equal, each surrogate must match
# the corresponding exact objective.
l = 10
k = 3
p = 25
mu = 10
epsilon = 0.1

# Seed so the random test matrices, and the comparison displayed at the end, are reproducible.
np.random.seed(0)
W0 = np.random.rand(l, k)
H0 = np.random.rand(k, p)
H0T = np.random.rand(k, p)
X = np.random.rand(l, p)

eps = 0
np.testing.assert_allclose(
    KL_loss_surrogate(X, W0, H0, H0, eps=eps),
    KLdiv_loss(X, W0, H0, eps=eps))
np.testing.assert_allclose(
    log_surrogate(H0, H0, mu=mu, epsilon=epsilon),
    log_reg(H0, mu=mu, epsilon=epsilon))

# Surrogate with a different second H argument (H0T), next to the value with both equal to H0.
KL_loss_surrogate(X, W0, H0, H0T, eps=eps), KL_loss_surrogate(X, W0, H0, H0, eps=eps)