In [None]:
import os, sys
# Make the project root (two directories above the working directory) importable.
# The path is normalised first: the raw '<cwd>/../..' string never compares equal to
# an equivalent canonical entry already on sys.path (e.g. from PYTHONPATH), so the
# same directory could otherwise be added twice under two spellings.
project_root_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
from torch import nn
import torch
import config

from utils import show_abundance, plot_endmembers
from HySpecLab.metrics import rmse, sad
from scipy import io as sio

# Dataset 

In [None]:
from dataset import NerveFat
# Directory under which all figures / weights for this dataset are written.
result_path = os.path.join(config.RESULTS_PATH, 'nerveFat')

# NerveFat dataset, read from the location configured in `config`.
dataset = NerveFat(root_dir=config.NerveFat_PATH)

In [None]:
def preprocessing(X: np.ndarray) -> np.ndarray:
    '''
        Preprocess the HSI cube before unmixing.

        Currently a pass-through: the cube is returned unchanged and no
        normalization is applied. Two denoising filters were tried and are
        kept below, disabled, for reference:
            1. Median filter in the spatial domain
               (``skimage.filters.median`` with a 3x3x1 footprint).
            2. Moving-average filter in the spectral domain
               (``utils.moving_average``, window 5) -- rejected.

        Parameters
        ----------
            X : np.ndarray, shape (nRow, nCol, nBand)
                HSI Cube.

        Returns
        -------
            np.ndarray
                ``X`` itself (same object, unmodified).
    '''
    # The filter imports are commented out together with the filters: importing
    # them while both steps are disabled only adds a hard scikit-image dependency.
    # from skimage.filters import median
    # from utils import moving_average
    # X = median(X, footprint=np.ones((3,3,1)))
    # X = moving_average(X.reshape(-1, X.shape[-1]), 5, padding_size=4).reshape(X.shape[0], X.shape[1], -1)
    return X

# Run preprocessing on a (rows, cols, bands) view of the pixels, then store the
# result back on the dataset as a flat (pixels, bands) float32 tensor.
# NOTE(review): the cube is built with the default C-order reshape, whereas the
# abundance maps later in the notebook are unflattened with order='F'. Harmless
# while `preprocessing` is a pass-through, but a spatial filter would then see the
# wrong neighbourhoods -- confirm the dataset's pixel order before enabling one.
X_filtered = preprocessing(dataset.X.reshape(dataset.n_row, dataset.n_col, -1))
# torch.as_tensor accepts either a NumPy array or an existing tensor; torch.tensor
# emits a copy-construct warning when handed a tensor.
dataset.X = torch.as_tensor(X_filtered.reshape(-1, X_filtered.shape[-1])).float()

In [None]:
# matlab_data = {
#     'X': dataset.image(),
#     'n_endmembers': dataset.n_endmembers,
#     'nRow': dataset.n_row,
#     'nCol': dataset.n_col,
#     'nBand': dataset.n_bands
# }

# sio.savemat(os.path.join(result_path, 'matlab/input.mat'), matlab_data)

In [None]:
# Scratch inspection of the dataset's image form (the one passed to N-FINDR below).
# NOTE(review): the name `test` is reassigned by later plotting cells.
test = dataset.image()

In [None]:
# Number of endmembers used by every method below (VCA, N-FINDR, CLHU).
n_endmembers = 3

# NOTE(review): `sad` is already imported in the first cell; this re-import is redundant but harmless.
from HySpecLab.metrics import sad

def sort_endmember(endmembers, gt, distance=None):
    '''
    Reorder ``endmembers`` so they line up one-to-one with the references ``gt``.

    Parameters
    ----------
    endmembers : torch.Tensor
        Candidate endmembers, indexed along dim 0.
    gt : torch.Tensor
        Reference endmembers that define the target order.
    distance : callable, optional
        ``distance(endmembers, gt)`` returning a 2-D tensor whose entry ``[i, j]``
        is the dissimilarity between ``endmembers[i]`` and ``gt[j]``.
        Defaults to ``HySpecLab.metrics.sad``.

    Returns
    -------
    sorted_endmembers : torch.Tensor
        ``endmembers[e_idx]``.
    e_idx : torch.LongTensor
        ``e_idx[j]`` is the index of the endmember matched to ``gt[j]``.
    sad_result : torch.Tensor
        The raw distance matrix.

    Notes
    -----
    Matching is an optimal one-to-one assignment (Hungarian algorithm), which
    minimises the total distance. A plain column-wise argmin can pick the same
    endmember for two references, duplicating one and dropping another. When a
    one-to-one matching is impossible (fewer candidates than references, or
    non-finite distances), it falls back to the nearest candidate per reference.
    '''
    from scipy.optimize import linear_sum_assignment

    if distance is None:
        distance = sad
    sad_result = distance(endmembers, gt)

    cost = sad_result.detach().cpu().numpy()
    one_to_one = (
        cost.ndim == 2
        and cost.shape[0] >= cost.shape[1]
        and np.isfinite(cost).all()
    )
    if one_to_one:
        row_ind, col_ind = linear_sum_assignment(cost)
        # Pairs come back sorted by row; reorder by column so that e_idx[j]
        # is the endmember assigned to gt[j].
        e_idx = torch.as_tensor(row_ind[np.argsort(col_ind)], dtype=torch.long, device=sad_result.device)
    else:
        e_idx = torch.argmin(sad_result, dim=0)
    return endmembers[e_idx], e_idx, sad_result

In [None]:
from HySpecLab.eea import VCA
   
# VCA endmember extraction; random_state pins the result across re-runs.
# NOTE(review): snr_input=1 is passed straight to HySpecLab's VCA -- confirm its meaning there.
vca = VCA(n_endmembers, snr_input=1, random_state=42)
vca.fit(dataset.X.numpy())
endmembers = torch.from_numpy(vca.endmembers()).float()

vca_endmember_init = endmembers
# Logit (inverse sigmoid) of the endmembers: the model later maps its endmember
# parameters back through sigmoid. The 1e-12 terms keep values of exactly 0 or 1 finite.
vca_logit_endmember_init = torch.log((vca_endmember_init + 1e-12) / ((1-vca_endmember_init) + 1e-12))

fig = plot_endmembers(vca_endmember_init, ticks_range=(0, 1))
# Save before showing. plt.show() takes no figure argument: some backends reject a
# positional argument, and the inline backend interprets it as `close`.
fig.savefig(os.path.join(result_path, 'imgs/M_vca.pdf'), bbox_inches='tight')
plt.show()

In [None]:
from utils import plot_endmembers
from pysptools import eea

# N-FINDR endmember extraction (pysptools) on the dataset's image form.
ee = eea.NFINDR()
endmember = torch.from_numpy(ee.extract(dataset.image(), n_endmembers)).float()

# Reorder the N-FINDR endmembers to line up with the VCA ones, so the two
# estimates can be compared endmember-by-endmember in the plots.
nfindr_endmember_init, _, _ = sort_endmember(endmember, vca_endmember_init)
# Logit (inverse sigmoid) of the *reordered* endmembers. This is computed once,
# after sorting; the earlier pre-sort computation was immediately overwritten.
nfindr_logit_endmember_init = torch.log((nfindr_endmember_init + 1e-12) / ((1-nfindr_endmember_init) + 1e-12))

fig = plot_endmembers(nfindr_endmember_init, ticks_range=(0, 1))
# Save before showing; plt.show() takes no figure argument.
fig.savefig(os.path.join(result_path, 'imgs/M_nfindr.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# Which endmember estimate initialises CLHU: 'vca' or 'nfindr'.
endmember_init_method = 'vca'

if endmember_init_method == 'nfindr':
    endmember_init, logit_endmember_init = nfindr_endmember_init, nfindr_logit_endmember_init
else:
    endmember_init, logit_endmember_init = vca_endmember_init, vca_logit_endmember_init

In [None]:
from utils import train 
from HySpecLab.unmixing import ContrastiveUnmixing

n_bands = dataset.n_bands
# Pass a copy of the initial logits. If ContrastiveUnmixing wraps `endmember_init`
# in an nn.Parameter without cloning, that Parameter shares storage with the tensor,
# and the optimizer would overwrite `logit_endmember_init` (and the VCA/N-FINDR
# logits it aliases) in place. Those tensors are reused later in the notebook.
# Cloning changes nothing if the model already copies its input.
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init.clone(), sigma_sparsity=.5)
# NOTE(review): `train` presumably updates `model` in place (later cells use it
# directly) -- 50 batches, 100 epochs.
train(model, n_endmembers, dataset, n_batchs=50, n_epochs=100, lr=1e-3, similarity_weight=1, sparse_weight=1)

In [None]:
model.eval()
# Send the data to whichever device the model's parameters live on, instead of
# assuming CUDA. The forward pass is only needed for its effect on the sparse gate
# (presumably it refreshes the per-pixel state read below), so gradients are off.
device = next(model.parameters()).device
with torch.no_grad():
    _ = model(dataset.X.to(device))
    # Compute the per-pixel variational parameter once, then summarise it.
    rho = model.sparse_gate.variational_parameter().flatten()
    print(rho)
    print(rho.mean())
    print(rho.min())
    print(model.sparse_gate.regularize())

# Save Model

In [None]:
# torch.save(model.state_dict(), os.path.join(result_path, 'clhu/weights/clhu.pth'))

# Testing

In [None]:
# from HySpecLab.unmixing import ContrastiveUnmixing

# # load model
# model = ContrastiveUnmixing(dataset.n_bands, n_endmembers)

# model.load_state_dict(torch.load(os.path.join(result_path, 'clhu/weights/clhu.pth')))
# Switch to evaluation mode (whether freshly trained above or reloaded via the commented-out code).
model = model.eval()

In [None]:
model.eval()
from HySpecLab.metrics.regularization import SimplexVolumeLoss, SimilarityLoss
from HySpecLab.metrics import UnmixingLoss, NormalizedEntropy

# Loss / regularisation terms used below to report the trained model's objective.
criterion = UnmixingLoss()
entropy_reg  = NormalizedEntropy(S=n_endmembers)
# NOTE(review): built from all samples (dataset[:]) -- confirm what SimplexVolumeLoss precomputes from them.
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
similarity_reg = SimilarityLoss(n_endmembers, temperature=.1, reduction='mean')

In [None]:
from torch import sigmoid 
# Evaluate the trained model's loss terms on the full dataset, on the CPU.
_X = dataset.X

model.eval()
model = model.cpu()
with torch.no_grad():
    # Run the forward pass inside no_grad too: only the values are reported, so
    # there is no reason to keep an autograd graph over every pixel alive.
    reconstruc = model(_X)
    # reconstruction error, abundance entropy, simplex volume of sigmoid(ebk),
    # similarity term on the raw logits
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(), volume_reg(sigmoid(model.ebk)).cpu(),
         similarity_reg(model.ebk).cpu())

In [None]:
# Same volume / similarity terms evaluated on the initial (pre-training) endmembers, for comparison with the trained values above.
volume_reg(endmember_init), similarity_reg(logit_endmember_init)

In [None]:
# Learned CLHU endmembers (sigmoid of the logit parameters) plotted against the
# VCA and N-FINDR estimates.
ebk = torch.sigmoid(model.ebk).detach()
fig = plot_endmembers(ebk, wv=dataset.wv, figsize=(6,4), ticks_range=(0, .92), endmember_estimation=[vca_endmember_init, nfindr_endmember_init], ee_labels=['CLHU', 'VCA', 'N-FINDR'])

# fig.savefig(os.path.join(result_path, 'imgs/M_clhu.pdf'), bbox_inches='tight')
# plt.show() takes no figure argument; it displays the open pyplot figures.
plt.show()

In [None]:
from torch.nn.functional import softmax
# Per-pixel abundance maps: a softmax over dim 1 turns the model's raw abundance
# scores (model.A, one row per pixel) into proportions that sum to one, which are
# then unflattened to (rows, cols, endmembers) in column-major (order='F') pixel order.
test = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1, order='F')

fig = show_abundance(test)

# fig.savefig(os.path.join(result_path, 'imgs/A_clhu.png'), dpi=300, bbox_inches='tight')
# fig.savefig(os.path.join(result_path, 'imgs/A_clhu.pdf'), bbox_inches='tight')
# plt.show() takes no figure argument; it displays the open pyplot figures.
plt.show()

In [None]:
# Per-pixel variational parameter rho of the sparse gate, shown on a log scale.
# Unflatten with order='F' so pixels land where they do in the abundance maps
# (which are reshaped with order='F'). A default C-order reshape puts pixel p at
# (p // n_col, p % n_col) instead of (p % n_row, p // n_row), which transposes
# or scrambles this map relative to the abundances.
sparse_rho = model.sparse_gate.variational_parameter().detach().cpu().numpy()
test = np.log(sparse_rho.reshape(dataset.n_row, dataset.n_col, order='F'))

fig, ax = plt.subplots(figsize=(6,4))
im = ax.imshow(test, cmap='jet')
cbar = fig.colorbar(im, ax=ax)
cbar.set_label(r'$\log(\rho)$', labelpad=2, fontsize=14)
ax.axis('off')
plt.show()

In [None]:
# Push the learned endmember spectra themselves through the model and compare
# the reconstruction with M.
# NOTE(review): this forward pass presumably overwrites the model's per-input state
# (model.A afterwards holds the abundances of these n_endmembers inputs, not of the
# image pixels). Re-run model(dataset.X) before reusing model.A for the image.
M = sigmoid(model.ebk).detach()
with torch.no_grad():
    M_hat = model(M).detach().cpu()

fig = plot_endmembers(M_hat, wv=dataset.wv, ticks_range=(0, .9), endmember_estimation=[M], ee_labels=['Reconstructed', 'M'])
# plt.show() takes no figure argument; it displays the open pyplot figures.
plt.show()

# Abundances assigned to each endmember spectrum (one row per input endmember).
torch.softmax(model.A, dim=1).detach()

# Multiple Configurations

In [59]:
from torch.nn.functional import softmax
def save_result(model, dataset, result_dir):
    '''
    Save the CLHU result figures of a trained model into ``result_dir``.

    Writes three PDFs:
        - ``M_clhu.pdf``: learned endmembers (sigmoid of ``model.ebk``) plotted
          against the VCA and N-FINDR estimates;
        - ``A_clhu.pdf``: per-pixel abundance maps (softmax of ``model.A``);
        - ``sparse_clhu.pdf``: log of the sparse gate's per-pixel parameter rho.

    The model is put in eval mode and moved to the CPU. ``nn.Module.eval``/``cpu``
    act in place, so the caller's module is affected too.

    Parameters
    ----------
    model : ContrastiveUnmixing
        Trained model.
    dataset :
        Dataset exposing ``X``, ``wv``, ``n_row`` and ``n_col``.
    result_dir : str
        Existing directory to write the figures into.

    Notes
    -----
    Reads the notebook globals ``vca_endmember_init`` and
    ``nfindr_endmember_init`` for the comparison curves.
    '''
    model = model.eval()
    model = model.cpu()

    # Forward pass over every pixel, only for its effect on model.A and the sparse
    # gate; no gradients are needed.
    with torch.no_grad():
        model(dataset.X)

    M = torch.sigmoid(model.ebk).detach()
    fig = plot_endmembers(M, wv=dataset.wv, figsize=(6,4), ticks_range=(0, .92), endmember_estimation=[vca_endmember_init, nfindr_endmember_init], ee_labels=['CLHU', 'VCA', 'N-FINDR'])
    fig.savefig(os.path.join(result_dir, 'M_clhu.pdf'), bbox_inches='tight')
    plt.close(fig)

    # Pixels are unflattened in column-major (order='F') order.
    A = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1, order='F')
    fig = show_abundance(A)
    fig.savefig(os.path.join(result_dir, 'A_clhu.pdf'), bbox_inches='tight')
    plt.close(fig)

    # Same order='F' pixel layout as the abundance maps, so the two figures line
    # up. A C-order reshape would transpose or scramble this map relative to them.
    rho = model.sparse_gate.variational_parameter().detach().cpu().numpy()
    log_rho = np.log(rho.reshape(dataset.n_row, dataset.n_col, order='F'))

    fig, ax = plt.subplots(figsize=(6,4))
    im = ax.imshow(log_rho, cmap='jet')
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.tick_params(labelsize=12)
    cbar.set_label(r'$\log(\rho)$', labelpad=2, fontsize=14)
    ax.axis('off')
    fig.savefig(os.path.join(result_dir, 'sparse_clhu.pdf'), bbox_inches='tight')
    plt.close(fig)


In [60]:
import os
from sklearn.model_selection import ParameterGrid
from utils import train 
from HySpecLab.unmixing import ContrastiveUnmixing


# Sweep `sigma` (passed as `sigma_sparsity` to the model) and `lambda` (passed as
# `sparse_weight` to train): 4 x 6 = 24 configurations, one result directory each.
param_grid = {
    'sigma': [.05, .1, .25, .5],
    'lambda': [0, .05, .1, .25, .5, 1],
}


n_bands = dataset.n_bands
for params in ParameterGrid(param_grid):
    sigma = params['sigma']
    lambda_ = params['lambda']

    result_dir = os.path.join(result_path, f'clhu/sparse_no_normalized/sigma_{sigma}/lambda_{lambda_}')
    os.makedirs(result_dir, exist_ok=True)

    # Each configuration gets a fresh model built from a *copy* of the same initial
    # logits. If the model wraps `endmember_init` in an nn.Parameter without cloning,
    # training would otherwise update `logit_endmember_init` in place, and every
    # later configuration would start from the previous run's endmembers.
    # NOTE(review): no torch seed is set per configuration, so results also depend
    # on how much RNG state earlier configurations consumed.
    model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init.clone(), sigma_sparsity=sigma)
    # 50 epochs here, vs. 100 in the single-run cell.
    train(model, n_endmembers, dataset, n_batchs=50, n_epochs=50, lr=1e-3, similarity_weight=1, sparse_weight=lambda_)
    save_result(model, dataset, result_dir)

100%|██████████| 50/50 [00:16<00:00,  3.05epoch/s, tls=0.0535]
100%|██████████| 50/50 [00:16<00:00,  2.99epoch/s, tls=0.0614]
100%|██████████| 50/50 [00:17<00:00,  2.94epoch/s, tls=0.0699]
100%|██████████| 50/50 [00:17<00:00,  2.91epoch/s, tls=0.1275]
100%|██████████| 50/50 [00:16<00:00,  2.99epoch/s, tls=0.0752]
100%|██████████| 50/50 [00:16<00:00,  3.05epoch/s, tls=0.0767]
100%|██████████| 50/50 [00:16<00:00,  3.06epoch/s, tls=0.0903]
100%|██████████| 50/50 [00:16<00:00,  2.98epoch/s, tls=0.1291]
100%|██████████| 50/50 [00:16<00:00,  3.05epoch/s, tls=0.1139]
100%|██████████| 50/50 [00:16<00:00,  3.00epoch/s, tls=0.1087]
100%|██████████| 50/50 [00:17<00:00,  2.93epoch/s, tls=0.1108]
100%|██████████| 50/50 [00:16<00:00,  3.05epoch/s, tls=0.1490]
100%|██████████| 50/50 [00:16<00:00,  2.95epoch/s, tls=0.5171]
100%|██████████| 50/50 [00:16<00:00,  2.98epoch/s, tls=0.1599]
100%|██████████| 50/50 [00:16<00:00,  2.95epoch/s, tls=0.1689]
100%|██████████| 50/50 [00:17<00:00,  2.90epoch/s, tls=

In [None]:
# Scratch check of the (n_endmembers, n_endmembers-1) matrix [0; I] used as B further down.
# NOTE(review): re-assigns n_endmembers (same value, 3, as earlier).
import torch 
n_endmembers=3
torch.vstack((torch.zeros((n_endmembers-1,)), torch.eye(n_endmembers-1)))

In [None]:
from sklearn.decomposition import PCA
# PCA subspace of the pixel spectra: mean spectrum mu (as a column vector) and the
# top n_endmembers-1 principal directions U.
X = dataset.X
# NOTE(review): X is presumably the torch tensor stored on dataset.X; this relies on the
# NumPy-style `axis`/`keepdims` keyword aliases being accepted -- confirm for the installed torch.
mu = X.reshape(-1, X.shape[-1]).mean(axis=0, keepdims=True).T
U = PCA(n_components=n_endmembers-1, random_state=42).fit(X.reshape(-1, X.shape[-1])).components_.T # (bands, endmembers-1)
U = torch.tensor(U, dtype=torch.float32)
mu.shape, U.shape

In [None]:
# Scratch: shape of the PCA basis (bands, n_endmembers-1).
print(U.shape)

In [None]:
# Scratch: flat (pixels, bands) shape set by the preprocessing cell.
dataset.X.shape

In [None]:
# Lift the VCA endmembers into the PCA subspace, in homogeneous coordinates.
# Row 0 of B is zero and row 0 of C is one, so row 0 of Z is all ones; the
# remaining rows hold U.T @ (M.T - mu), the PCA coordinates of each endmember.
# NOTE(review): assumes vca_endmember_init is (n_endmembers, bands) -- confirm.
B = torch.cat((torch.zeros(1, n_endmembers - 1), torch.eye(n_endmembers - 1)), dim=0)
C = torch.zeros((n_endmembers, n_endmembers))
C[0] = 1
print(B)
print(C)
Z = C + B @ U.T @ (vca_endmember_init.T - mu)
print(Z)

In [None]:
# Scratch: current value of n_endmembers.
n_endmembers