In [None]:
import config
import numpy as np
import os
import time
import torch
import utils

from abc import ABC, abstractmethod
from matplotlib import pyplot as plt
from scipy import io
from torch import optim

# Custom Dataset

In [None]:
class IDNetData(ABC):
    '''
        Base class for a hyperspectral dataset prepared for IDNet.

        Subclasses only implement ``read_data``; the constructor calls
        ``generate_dataset``, which builds:
            Y: (L, N) float32 tensor whose columns are the observed pixels.
            data_sup: list of Nlib * P synthetic supervised samples (y, M, a),
                one per library index and endmember, with a one-hot abundance a.
            data_unsup: list of the N pixel spectra (the columns of Y).
        The ground-truth attributes (A_gt, A_cube_gt, Mavg_th, Mn_th) are only
        constant placeholders (-1 / 1); no ground truth is loaded here.

        Class attributes:
            snr_db: float
                Controls the std of the Gaussian noise added to the supervised
                samples: sigma = mean(Y) * 10**(-snr_db / 10). Subclasses may
                override it; float('inf') disables the noise.
    '''
    snr_db = 40

    def __init__(self) -> None:
        self.generate_dataset()

    @abstractmethod
    def read_data(self):
        '''
            Load the raw data and return ``(input, m0, lib)``:
                input: mapping with scalar entries 'nRow', 'nCol' and the pixel
                    array 'X' (reshaped to nRow * nCol pixels of L bands).
                m0: array whose shape is (P, L).
                lib: 2-D object array; lib[0, j] is an (L, n_j) array of
                    candidate spectra for endmember j.
        '''
        pass

    def generate_dataset(self):
        '''Build Y, the supervised / unsupervised sample lists and the placeholders.'''
        mat_input, m0, lib = self.read_data()  # not `input`: avoid shadowing the builtin
        self.P, self.L = m0.shape[0], m0.shape[1]
        self.nr, self.nc = mat_input['nRow'].item(), mat_input['nCol'].item()
        self.N = self.nr * self.nc
        self.Nlib = lib[0, 0].shape[1]
        self.Y = torch.from_numpy(mat_input['X']).type(torch.float32).reshape(self.N, self.L).T

        # NOTE(review): the std is mean(Y) * 10**(-SNR/10) rather than a
        # power-based SNR definition; kept exactly as in the original code.
        ssigma = self.Y.mean() * (10 ** (-self.snr_db / 10))

        # Supervised samples: for library index i, endmember j takes column
        # i mod n_j of its own library (libraries with fewer candidates wrap
        # around); each of the P one-hot abundances yields one noisy sample.
        self.data_sup = []
        for i in range(self.Nlib):
            M_s = torch.zeros((self.L, self.P))
            for j in range(self.P):
                lib_j = lib[0, j]
                M_s[:, j] = torch.from_numpy(lib_j[:, i % lib_j.shape[1]])
            for k in range(self.P):
                a_s = torch.zeros((self.P,))
                a_s[k] = 1.0
                y_s = torch.mv(M_s, a_s) + ssigma * torch.randn((self.L,))
                self.data_sup.append((y_s, M_s, a_s))

        # Unsupervised samples: one (L,) view per pixel column of Y.
        self.data_unsup = list(self.Y.unbind(dim=1))

        # Placeholders only: these real datasets come without ground truth here.
        self.A_gt = -torch.ones((self.P, self.N))
        self.A_cube_gt = -torch.ones((self.nr, self.nc, self.P))
        self.Mavg_th = torch.ones(self.L)
        self.Mn_th = -torch.ones((self.L, self.P, self.N))

    def getdata(self):
        '''Return ``(Y, data_sup, data_unsup)``.'''
        return self.Y, self.data_sup, self.data_unsup

    def get_gt(self):
        '''Return the placeholder ground truth ``(A_gt, A_cube_gt, Mavg_th, Mn_th)``.'''
        return self.A_gt, self.A_cube_gt, self.Mavg_th, self.Mn_th
    
class _MatlabIDNetData(IDNetData):
    '''
        Shared loader for the real datasets stored as three MATLAB files in
        ``<config.DATASET_PATH>/<subdir>/matlab/``: ``input.mat``,
        ``endmember_estimation.mat`` and ``extracted_bundles.mat``.

        Params:
            vca: bool
                Indicate the reference used for generate the endmember library.
                If it is true, VCA is used. In other case, NFINDR.
    '''
    subdir = None  # dataset folder name, set by each concrete subclass

    def __init__(self, vca=True) -> None:
        if vca:
            self.eea, self.lib_type = 'VCA', 'Lib_vca'
        else:
            self.eea, self.lib_type = 'NFINDR', 'Lib_nfindr'

        super().__init__()

    def _mat_path(self, filename):
        '''Full path of one of this dataset's MATLAB files.'''
        return os.path.join(config.DATASET_PATH, f'{self.subdir}/matlab/{filename}')

    def read_data(self):
        '''Load the pixels, the reference endmembers (``self.eea``) and the bundle library (``self.lib_type``).'''
        mat_input = io.loadmat(self._mat_path('input.mat'))
        m0 = io.loadmat(self._mat_path('endmember_estimation.mat'))[self.eea]
        lib = io.loadmat(self._mat_path('extracted_bundles.mat'))[self.lib_type]
        return mat_input, m0, lib


class JasperRidgeData(_MatlabIDNetData):
    '''
        Jasper Ridge Data (files under ``jasperRidge/matlab/``).

        Params:
            vca: bool
                Indicate the reference used for generate the endmember library.
                If it is true, VCA is used. In other case, NFINDR.
    '''
    subdir = 'jasperRidge'


class SamsonData(_MatlabIDNetData):
    '''
        Samson Data (files under ``samson/matlab/``).

        Params:
            vca: bool
                Indicate the reference used for generate the endmember library.
                If it is true, VCA is used. In other case, NFINDR.
    '''
    subdir = 'samson'


class ApexData(_MatlabIDNetData):
    '''
        Apex Data (files under ``apex/matlab/``).

        Params:
            vca: bool
                Indicate the reference used for generate the endmember library.
                If it is true, VCA is used. In other case, NFINDR.
    '''
    subdir = 'apex'
    

In [None]:
class dataset_maker(torch.utils.data.Dataset):
    '''
        PyTorch Dataset for IDNet pairing unsupervised pixels with supervised
        (y, M, a) samples.

        Params:
            data: IDNetData
                Prepared dataset; only ``getdata()``, ``get_gt()`` and the
                ``L``, ``P``, ``N`` attributes are used.

        Each item is ``(y_unsup, (y_sup, M_sup, a_sup))``. The length is the
        larger of the two sample lists; the smaller list is indexed
        proportionally so both are traversed once per epoch.
    '''
    def __init__(self, data: 'IDNetData'):
        '''Copy the sample lists and the (placeholder) ground truth out of ``data``.'''
        self.Y, self.data_sup, self.data_unsup = data.getdata()
        self.A_u, self.A_u_cube, self.M_u_avg, self.M_u_ppx = data.get_gt()
        self.L, self.P, self.N = data.L, data.P, data.N

        self.flag_unsup_is_bigger = len(self.data_sup) < len(self.data_unsup)

    def __len__(self):
        # take the maximum length between the supervised and unsupervised datasets
        return max(len(self.data_sup), len(self.data_unsup))

    def __getitem__(self, idx):
        # idx indexes the larger list; the index into the smaller one is
        # floor(idx * len_small / len_large), computed with exact integer math.
        n_sup, n_unsup = len(self.data_sup), len(self.data_unsup)
        if self.flag_unsup_is_bigger:
            idx_sup, idx_unsup = (idx * n_sup) // n_unsup, idx
        else:
            idx_sup, idx_unsup = idx, (idx * n_unsup) // n_sup

        return self.data_unsup[idx_unsup], self.data_sup[idx_sup]

    def plot_training_EMs(self, EM_idx=-1):
        '''
            Plot the endmember spectra of every supervised training sample.

            Params:
                EM_idx: int
                    Endmember to plot, or -1 to draw every endmember in its own subplot.
        '''
        # (L, P, Nsamp): the M matrix of each supervised sample, stacked on the last axis
        M_train = torch.stack([M for _, M, _ in self.data_sup], dim=2)
        P = M_train.shape[1]
        if EM_idx == -1:
            # squeeze=False keeps a 2-D axes array, so P == 1 also works
            fig, axs = plt.subplots(1, P, squeeze=False)
            for i in range(P):
                axs[0, i].plot(torch.squeeze(M_train[:, i, :]))
        else:
            plt.figure()
            plt.plot(torch.squeeze(M_train[:, EM_idx, :]))
        plt.show()

# Model

In [None]:
from main_IDNet import IDNet

# Training

In [None]:
def loss_function(model, log_probs_unsup, log_probs_sup, omegas_nrmlzd, log_omegas, alphas_all, 
                     llambda, bbeta, tau, lamb_We, lamb_Wd):
    '''
    Negative IDNet objective for one mini-batch (returned so it can be minimized).

    Params:
        model: module with ``fcy_Ma_nlin`` (nonlinear decoder) and
            ``fca_My_alphas`` (nonlinear encoder) submodules whose weights are penalized.
        log_probs_unsup, log_probs_sup: dicts of 2-D tensors indexed [batch, draw]
            with keys 'log_py_Ma', 'log_pa', 'log_pM_Z', 'log_pZ' (numerator) and
            'log_qa_My', 'log_qM_Z', 'log_qZ_y' (denominator). The second axis is
            presumably the Monte-Carlo sample axis (TODO confirm).
        omegas_nrmlzd: (batch, K2) normalized importance weights of the supervised terms.
        log_omegas: (batch, K2) log importance weights.
        alphas_all: dict with 'alphas_unsup' / 'alphas_sup' tensors indexed [:, batch, draw].
        llambda: weight between the supervised and unsupervised parts.
        bbeta: extra weight on the likelihood of the supervised endmembers/abundances.
        tau: weight of the l_{1/2} sparsity penalty on the alphas.
        lamb_We, lamb_Wd: Frobenius-norm penalties on encoder / decoder weights.

    Returns:
        0-dim tensor equal to -cost.
    '''
    K1 = log_probs_unsup['log_py_Ma'].shape[1]
    K2 = omegas_nrmlzd.shape[1]
    batch_size = omegas_nrmlzd.shape[0]

    numerator_keys = ('log_py_Ma', 'log_pa', 'log_pM_Z', 'log_pZ')
    denominator_keys = ('log_qa_My', 'log_qM_Z', 'log_qZ_y')

    def _log_ratio(log_probs, K):
        # element-wise log[p(...) / q(...)] over the first batch_size x K entries
        num = sum(log_probs[key][:batch_size, :K] for key in numerator_keys)
        den = sum(log_probs[key][:batch_size, :K] for key in denominator_keys)
        return num - den

    def _sparsity(alphas, K):
        # l_{1/2} quasi-norm of every alpha vector (first axis), averaged over the K draws
        return torch.linalg.vector_norm(alphas[:, :batch_size, :K], ord=0.5, dim=0).sum() / K

    # unsupervised part of the cost function
    cost_unsup = _log_ratio(log_probs_unsup, K1).sum()

    # supervised part; the omegas are already normalized importance weights
    cost_sup = (omegas_nrmlzd[:batch_size, :K2] * _log_ratio(log_probs_sup, K2)).sum()

    # regularization term
    cost_reg = (log_probs_sup['log_qa_My'][:batch_size, :K2] + log_omegas[:batch_size, :K2]).sum()

    # sparsity on alphas, over the unsupervised and the supervised data
    cost_reg_sprs = _sparsity(alphas_all['alphas_unsup'], K1) + _sparsity(alphas_all['alphas_sup'], K2)

    # penalties on the nonlinear decoder / encoder weights
    reg_nlin_d_weights = sum(torch.norm(param, p="fro") for param in model.fcy_Ma_nlin.parameters())
    reg_nlin_e_weights = sum(torch.norm(param, p="fro") for param in model.fca_My_alphas.parameters())

    # now the total cost function
    cost = cost_unsup/K1 + llambda * cost_sup/K2 + (1+bbeta) * cost_reg/K2 \
           - tau * cost_reg_sprs - lamb_We * reg_nlin_e_weights - lamb_Wd * reg_nlin_d_weights
    return -cost # maximize cost


def train(epoch, model, optimizer, train_loader, llambda, bbeta, tau, lamb_We, lamb_Wd):
    '''
    Run one training epoch and return the average per-sample loss.

    Params:
        epoch: int, only used in the log messages.
        model: IDNet model; ``model(alldata)`` must return the five tensors/dicts
            that ``loss_function`` takes after ``model``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader over a ``dataset_maker``; each batch is the
            collated pair (unsupervised pixels, supervised (y, M, a) samples).
        llambda, bbeta, tau, lamb_We, lamb_Wd: weights forwarded to ``loss_function``.

    Returns:
        float: sum of the batch losses divided by the number of dataset items.
    '''
    log_interval = 100 # how many batches to wait before logging training status    
    model.train()
    train_loss = 0
    n_seen = 0  # samples processed before the current batch (progress log only)

    # get one batch from supervised data and from unsupervised data
    for batch_idx, alldata in enumerate(train_loader):
        # alldata is the collated (unsup, sup) pair, so len(alldata) is always 2;
        # the actual batch size is the leading dimension of the unsupervised part.
        batch_len = len(alldata[0])

        optimizer.zero_grad()
        log_probs_unsup, log_probs_sup, omegas_nrmlzd, log_omegas, alphas_all = model(alldata)
        loss = loss_function(model, log_probs_unsup, log_probs_sup, omegas_nrmlzd, log_omegas, alphas_all,
                             llambda, bbeta, tau, lamb_We, lamb_Wd)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, n_seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / batch_len))
        n_seen += batch_len

    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

def test(epoch, model, train_loader):
    model.eval()
    L, N = train_loader.dataset.L, train_loader.dataset.N
    with torch.no_grad():
        Y = torch.zeros((L,N))
        for i in range(N):
            Y[:,i] = train_loader.dataset.data_unsup[i]
        A_est, Mn_est, M_avg, Y_rec, a_nlin_deg = model.unmix(Y)
        
    return A_est, Mn_est, Y_rec, a_nlin_deg, M_avg

# Experiments

In [None]:
# IDNet is run on the CPU: the implementation does not work on GPU.
device = torch.device("cpu")

## Samson

In [None]:
# Configuration proposed by the Authors
my_llambda = 1      # regularization between supervised and unsupervised part
my_bbeta = 10       # extra regularization (high likelihood of endmembers and abundances training data in the posterior)
my_tau     = 0.005  # extra extra regularization (sparsity)
my_lamb_We = 0.01   # penalizes network weights of nonlinear encoder
my_lamb_Wd = 0.1    # penalizes network weights of nonlinear decoder

    # llambda = my_llambda; 
    # bbeta   = 10; # extra regularization (high likelihood of endmembers and abundances training data in the posterior)
    # tau     = my_tau; # extra extra regularization (sparsity)
    # lamb_We = my_lamb_We; # penalizes network weights of nonlinear encoder
    # lamb_Wd = my_lamb_Wd; # penalizes network weights of nonlinear decoder

vca = True
dataset_name = 'Samson_'+('VCA' if vca else 'NFINDR')

dataset = dataset_maker(SamsonData(vca=vca))

In [None]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

L = train_loader.dataset.data_sup[0][1].shape[0]
P = train_loader.dataset.data_sup[0][1].shape[1]
N = len(train_loader.dataset.data_unsup)
nr, nc = train_loader.dataset.A_u_cube.shape[0], train_loader.dataset.A_u_cube.shape[1]
H = 2 # dimension of the latent EM space


model = IDNet(P, L, H=H).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The lr is multiplied by 0.9 after each of the first 10 epochs:
# 1e-3 -> 1e-3 * 0.9**10 ~= 3.5e-4 (reaching 1e-4 would need gamma ~= 0.8).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Logs, figures and .mat files below are written to ./results.
os.makedirs('results', exist_ok=True)

start_time = time.time()
num_epochs = 30 # maximum number of epochs to train
loss_old = 1e30
for epoch in range(1, num_epochs + 1):
    loss_t = train(epoch, model, optimizer, train_loader, my_llambda, my_bbeta, my_tau, my_lamb_We, my_lamb_Wd)
    A_est, Mn_est, Y_rec, a_nlin_deg, M_avg = test(epoch, model, train_loader)

    # compute metrics -----------------------
    # NOTE: A_u / M_u_ppx are constant placeholders for this dataset (no ground
    # truth is loaded), so the abundance/endmember NRMSE are not true errors.
    RMSE_A, NRMSE_A = utils.compute_metrics(train_loader.dataset.A_u, A_est)
    RMSE_M, NRMSE_M = utils.compute_metrics(train_loader.dataset.M_u_ppx, Mn_est)
    RMSE_Y, NRMSE_Y = utils.compute_metrics(train_loader.dataset.Y, Y_rec)

    metrics_str = '====> EPOCH: {:d}, Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(epoch, NRMSE_A, NRMSE_M)
    with open(os.path.join('results', f'{dataset_name}.txt'), "a") as text_file:
        print(metrics_str, file=text_file)
    print(metrics_str) # print to console too

    if epoch <= 10:
        scheduler.step()

    # Stop once the relative change of the average loss is below 1%; written
    # without a division so that a zero loss cannot raise ZeroDivisionError.
    if abs(loss_t - loss_old) < 1e-2 * abs(loss_t):
        break
    loss_old = loss_t
elapsed_time = time.time()-start_time


# plot abundances and average EMs ----------------------------------
utils.plotAbunds(A_est, nr=nr, nc=nc, 
                    thetitle='learned abundances',
                    savepath=os.path.join('results', f'{dataset_name}.pdf'))
utils.plotEMs(M_avg, thetitle='learned avg EMs')

# compare results to ground truth abundances if available
# utils.show_ground_truth(A_true=train_loader.dataset.A_u, Mgt_avg=train_loader.dataset.M_u_avg, nr=nr, nc=nc)


print('====> FINAL: Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(NRMSE_A, NRMSE_M))
print('====> Elapsed time: {:.6f}'.format(elapsed_time))

# save .mat file with the results
io.savemat(os.path.join('results', f'{dataset_name}.mat'), 
        {'A_est':A_est.numpy(),
            'Mn_est':Mn_est.numpy(),
            'Y_rec':Y_rec.numpy(),
            'a_nlin_deg':a_nlin_deg.numpy()})

In [None]:
# Re-print the final metrics of the last training run and re-save its results
# (this repeats the end of the previous cell).
print('====> FINAL: Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(
    NRMSE_A, NRMSE_M))
print('====> Elapsed time: {:.6f}'.format(elapsed_time))

io.savemat(
    os.path.join('results', f'{dataset_name}.mat'),
    dict(
        A_est=A_est.numpy(),
        Mn_est=Mn_est.numpy(),
        Y_rec=Y_rec.numpy(),
        a_nlin_deg=a_nlin_deg.numpy(),
    ),
)

## Jasper Ridge

In [None]:
# Hyper-parameters for Jasper Ridge:
#   my_llambda - regularization between supervised and unsupervised part
#   my_bbeta   - extra regularization (high likelihood of endmembers and abundances training data in the posterior)
#   my_tau     - sparsity regularization (l_{1/2} penalty on the alphas); disabled here
#   my_lamb_We - penalizes network weights of the nonlinear encoder
#   my_lamb_Wd - penalizes network weights of the nonlinear decoder
my_llambda, my_bbeta, my_tau, my_lamb_We, my_lamb_Wd = 1, 10, 0, 0.01, 0.1

# Endmember-extraction reference behind the bundle library: VCA if True, else N-FINDR.
vca = True
dataset_name = 'JasperRidge_' + ('VCA' if vca else 'NFINDR')

dataset = dataset_maker(JasperRidgeData(vca=vca))

In [None]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

L = train_loader.dataset.data_sup[0][1].shape[0]
P = train_loader.dataset.data_sup[0][1].shape[1]
N = len(train_loader.dataset.data_unsup)
nr, nc = train_loader.dataset.A_u_cube.shape[0], train_loader.dataset.A_u_cube.shape[1]
H = 2 # dimension of the latent EM space


model = IDNet(P, L, H=H).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The lr is multiplied by 0.9 after each of the first 10 epochs:
# 1e-3 -> 1e-3 * 0.9**10 ~= 3.5e-4 (reaching 1e-4 would need gamma ~= 0.8).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Logs, figures and .mat files below are written to ./results.
os.makedirs('results', exist_ok=True)

start_time = time.time()
num_epochs = 30 # maximum number of epochs to train
loss_old = 1e30
for epoch in range(1, num_epochs + 1):
    loss_t = train(epoch, model, optimizer, train_loader, my_llambda, my_bbeta, my_tau, my_lamb_We, my_lamb_Wd)
    A_est, Mn_est, Y_rec, a_nlin_deg, M_avg = test(epoch, model, train_loader)

    # compute metrics -----------------------
    # NOTE: A_u / M_u_ppx are constant placeholders for this dataset (no ground
    # truth is loaded), so the abundance/endmember NRMSE are not true errors.
    RMSE_A, NRMSE_A = utils.compute_metrics(train_loader.dataset.A_u, A_est)
    RMSE_M, NRMSE_M = utils.compute_metrics(train_loader.dataset.M_u_ppx, Mn_est)
    RMSE_Y, NRMSE_Y = utils.compute_metrics(train_loader.dataset.Y, Y_rec)

    metrics_str = '====> EPOCH: {:d}, Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(epoch, NRMSE_A, NRMSE_M)
    with open(os.path.join('results', f'{dataset_name}.txt'), "a") as text_file:
        print(metrics_str, file=text_file)
    print(metrics_str) # print to console too

    if epoch <= 10:
        scheduler.step()

    # Stop once the relative change of the average loss is below 1%; written
    # without a division so that a zero loss cannot raise ZeroDivisionError.
    if abs(loss_t - loss_old) < 1e-2 * abs(loss_t):
        break
    loss_old = loss_t
elapsed_time = time.time()-start_time


# plot abundances and average EMs ----------------------------------
utils.plotAbunds(A_est, nr=nr, nc=nc, 
                    thetitle='learned abundances',
                    savepath=os.path.join('results', f'{dataset_name}.pdf'))
utils.plotEMs(M_avg, thetitle='learned avg EMs')

# compare results to ground truth abundances if available
# utils.show_ground_truth(A_true=train_loader.dataset.A_u, Mgt_avg=train_loader.dataset.M_u_avg, nr=nr, nc=nc)


print('====> FINAL: Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(NRMSE_A, NRMSE_M))
print('====> Elapsed time: {:.6f}'.format(elapsed_time))

# save .mat file with the results
io.savemat(os.path.join('results', f'{dataset_name}.mat'), 
        {'A_est':A_est.numpy(),
            'Mn_est':Mn_est.numpy(),
            'Y_rec':Y_rec.numpy(),
            'a_nlin_deg':a_nlin_deg.numpy()})

## Apex

In [None]:
my_llambda = 1      # regularization between supervised and unsupervised part
my_bbeta = 10       # extra regularization (high likelihood of endmembers and abundances training data in the posterior)
my_tau     = 0.01   # sparsity regularization (l_{1/2} penalty on the alphas)
my_lamb_We = 0.005  # penalizes network weights of nonlinear encoder
my_lamb_Wd = 0.1    # penalizes network weights of nonlinear decoder

# Endmember-extraction reference behind the bundle library: VCA if True, else N-FINDR.
vca = True
# Same '<Dataset>_<EEA>' naming scheme as the Samson and Jasper Ridge outputs.
dataset_name = 'Apex_'+('VCA' if vca else 'NFINDR')

dataset = dataset_maker(ApexData(vca=vca))

In [None]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

L = train_loader.dataset.data_sup[0][1].shape[0]
P = train_loader.dataset.data_sup[0][1].shape[1]
N = len(train_loader.dataset.data_unsup)
nr, nc = train_loader.dataset.A_u_cube.shape[0], train_loader.dataset.A_u_cube.shape[1]
H = 2 # dimension of the latent EM space


model = IDNet(P, L, H=H).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The lr is multiplied by 0.9 after each of the first 10 epochs:
# 1e-3 -> 1e-3 * 0.9**10 ~= 3.5e-4 (reaching 1e-4 would need gamma ~= 0.8).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Logs, figures and .mat files below are written to ./results.
os.makedirs('results', exist_ok=True)

start_time = time.time()
num_epochs = 30 # maximum number of epochs to train
loss_old = 1e30
for epoch in range(1, num_epochs + 1):
    loss_t = train(epoch, model, optimizer, train_loader, my_llambda, my_bbeta, my_tau, my_lamb_We, my_lamb_Wd)
    A_est, Mn_est, Y_rec, a_nlin_deg, M_avg = test(epoch, model, train_loader)

    # compute metrics -----------------------
    # NOTE: A_u / M_u_ppx are constant placeholders for this dataset (no ground
    # truth is loaded), so the abundance/endmember NRMSE are not true errors.
    RMSE_A, NRMSE_A = utils.compute_metrics(train_loader.dataset.A_u, A_est)
    RMSE_M, NRMSE_M = utils.compute_metrics(train_loader.dataset.M_u_ppx, Mn_est)
    RMSE_Y, NRMSE_Y = utils.compute_metrics(train_loader.dataset.Y, Y_rec)

    metrics_str = '====> EPOCH: {:d}, Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(epoch, NRMSE_A, NRMSE_M)
    with open(os.path.join('results', f'{dataset_name}.txt'), "a") as text_file:
        print(metrics_str, file=text_file)
    print(metrics_str) # print to console too

    if epoch <= 10:
        scheduler.step()

    # Stop once the relative change of the average loss is below 1%; written
    # without a division so that a zero loss cannot raise ZeroDivisionError.
    if abs(loss_t - loss_old) < 1e-2 * abs(loss_t):
        break
    loss_old = loss_t
elapsed_time = time.time()-start_time


# plot abundances and average EMs ----------------------------------
utils.plotAbunds(A_est, nr=nr, nc=nc, 
                    thetitle='learned abundances',
                    savepath=os.path.join('results', f'{dataset_name}.pdf'))
utils.plotEMs(M_avg, thetitle='learned avg EMs')

# compare results to ground truth abundances if available
# utils.show_ground_truth(A_true=train_loader.dataset.A_u, Mgt_avg=train_loader.dataset.M_u_avg, nr=nr, nc=nc)


print('====> FINAL: Abundance NRMSE: {:.6f}, Endmember NRMSE: {:.6f}'.format(NRMSE_A, NRMSE_M))
print('====> Elapsed time: {:.6f}'.format(elapsed_time))

# save .mat file with the results
io.savemat(os.path.join('results', f'{dataset_name}.mat'), 
        {'A_est':A_est.numpy(),
            'Mn_est':Mn_est.numpy(),
            'Y_rec':Y_rec.numpy(),
            'a_nlin_deg':a_nlin_deg.numpy()})