In [None]:
import os

# PyQt4-era shim: select the v2 (Python-native) API for these Qt types. It only has an
# effect if it runs before any Qt import, and it is not needed with the `qt5` matplotlib
# backend used below (PyQt5 supports only the v2 API). Newer PyQt5 builds ship sip as
# `PyQt5.sip`, so a top-level `import sip` may fail -- treat the shim as best-effort
# instead of letting it break a fresh "Run All".
try:
    import sip
except ImportError:
    sip = None
else:
    for qt_type in ["QDate", "QVariant", "QDateTime", "QTextStream", "QString", "QTime", "QUrl"]:
        sip.setapi(qt_type, 2)

import numpy as np
import hyperspy.api as hs
import matplotlib.pyplot as plt
from scipy.linalg import svd

In [None]:
# Render figures in separate interactive Qt5 windows (used by the hyperspy `.plot()`
# calls below); requires a PyQt5 installation.
%matplotlib qt5

In [None]:
def NNRPCA_2d(Y, lambda_val=1.0, mu=1.0, max_iter=100, rho=1.6, tol=None):
    """Decompose a 2-D matrix into a low-rank part and a nonnegative sparse part.

    Each iteration performs three ADMM-style Robust-PCA updates:

    * ``L``: singular-value thresholding of ``Y - S + E / rho`` at ``1 / rho``;
    * ``S``: one-sided soft thresholding of ``Y - L + E / rho`` at
      ``lambda_val / rho`` (entries are clipped at 0, so ``S >= 0``);
    * ``E``: dual ascent ``E += rho * (Y - L - S)``.

    After the last iteration ``L`` is shifted down by ``mu`` and clipped at 0.

    Parameters
    ----------
    Y : array_like, shape (m, n)
        Matrix to decompose (in this notebook: pixels x spectral channels).
    lambda_val : float, default 1.0
        Sparsity weight; sets the soft threshold ``lambda_val / rho`` of ``S``.
    mu : float, default 1.0
        Offset subtracted from the final ``L`` before clipping it to be nonnegative.
    max_iter : int, default 100
        Maximum number of iterations. With ``max_iter <= 0`` both outputs are all zeros.
    rho : float, default 1.6
        Augmented-Lagrangian penalty; also sets the singular-value threshold ``1 / rho``.
    tol : float or None, default None
        Optional early stop: finish once ``||Y - L - S||_F <= tol * ||Y||_F``.
        ``None`` always runs exactly ``max_iter`` iterations.

    Returns
    -------
    L : ndarray, shape (m, n)
        Nonnegative low-rank component.
    S : ndarray, shape (m, n)
        Nonnegative sparse component.
    """
    Y = np.asarray(Y, dtype=float)
    m, n = Y.shape

    L = np.zeros((m, n))  # low-rank component
    S = np.zeros((m, n))  # sparse component
    E = np.zeros((m, n))  # dual variable (multiplier for the constraint Y = L + S)
    y_norm = np.linalg.norm(Y)

    for _ in range(max_iter):
        # Singular-value thresholding. Scaling U's columns by the shrunk singular
        # values equals U @ diag(shrunk) without building the k x k diagonal matrix.
        U, sigma, Vt = svd(Y - S + E / rho, full_matrices=False)
        L = (U * np.maximum(sigma - 1.0 / rho, 0)) @ Vt

        # One-sided soft thresholding keeps S nonnegative.
        S = np.maximum(Y - L + E / rho - lambda_val / rho, 0)

        residual = Y - L - S
        E = E + rho * residual

        if tol is not None and np.linalg.norm(residual) <= tol * y_norm:
            break

    # Nonnegativity with offset mu. L is rebuilt from scratch by the SVD step of every
    # iteration, so clipping once here is equivalent to clipping at the end of each
    # iteration; the guard keeps the all-zeros result when no iteration ran.
    if max_iter > 0:
        L = np.maximum(L - mu, 0)

    return L, S

In [None]:
# Path to the spectrum image. Set the NNRPCA_DATA environment variable to point at your
# file instead of editing this hardcoded placeholder.
DATA_PATH = os.environ.get("NNRPCA_DATA", '/your/path/try.dm3')
s = hs.load(DATA_PATH)

In [None]:
s.plot()
# Expected to be a 3-D spectrum image: two image axes followed by the spectral axis.
# The tuple unpacking below raises ValueError for any other dimensionality.
shape = s.data.shape
print(shape)
image_x, image_y, spec_len = shape

In [None]:
data = np.array(s)
data -= Y.min()

In [None]:
data = s.data.reshape(image_x*image_y, spec_len)

In [None]:
# Parameters for NNRPCA_2d on the (pixels x channels) matrix `data`:
#   lambda_val - sparsity weight, 1 / sqrt(max(m, n))
#   mu         - NNRPCA_2d subtracts this from L before clipping at 0.
#                NOTE(review): 1.25 / ||data||_2 is the usual RPCA augmented-Lagrangian
#                penalty, but in NNRPCA_2d the penalty is `rho` (default 1.6) -- confirm intent.
#   max_iter   - NOTE(review): only 2 iterations; confirm this is enough to converge.
lambda_val = 1 / np.sqrt(max(data.shape))
mu = 1.25 / np.linalg.norm(data, ord=2)
max_iter = 2
result_L, result_S = NNRPCA_2d(
    data,
    lambda_val=lambda_val,
    mu=mu,
    max_iter=max_iter,
)

In [None]:
# Separate copies of the loaded signal that will hold the low-rank (denoised) and
# sparse (noise) NNRPCA components.
denoised_s, noise_s = s.deepcopy(), s.deepcopy()

In [None]:
# Low-rank component, folded back from (pixels, channels) into the spectrum-image layout.
denoised_s.data = np.reshape(result_L, (image_x, image_y, spec_len))
denoised_s.plot()

In [None]:
# Sparse component, folded back from (pixels, channels) into the spectrum-image layout.
noise_s.data = np.reshape(result_S, (image_x, image_y, spec_len))
noise_s.plot()

In [None]:
N_COMPONENTS = 9  # number of NMF components; pick it from the scree plot

s_decomp = denoised_s.deepcopy()

# The scree plot needs an SVD decomposition: NMF does not compute an explained
# variance ratio. hyperspy selects the algorithm with `algorithm=` (not `method=`).
s_decomp.decomposition(algorithm="SVD")
s_decomp.plot_explained_variance_ratio(n=10)

# NMF requires an explicit number of components. The input is nonnegative because
# NNRPCA_2d clips the low-rank part at 0.
s_decomp.decomposition(algorithm="NMF", output_dimension=N_COMPONENTS)
s_decomp.plot_decomposition_loadings(comp_ids=N_COMPONENTS, axes_decor="off")