In [None]:
import os, sys
# Make the repository root (parent of the notebook's working directory) importable,
# so that local modules such as `config`, `dataset` and `utils` resolve.
project_root_dir = os.path.join(os.getcwd(), os.pardir)
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import config

# Global matplotlib style applied to every figure produced below.
plt.style.use(style='seaborn-v0_8')

In [None]:
# Output folder for this part of the notebook: figures go to `imgs/`, model weights to `weights/`.
result_path = os.path.join(
    config.RESULTS_PATH,
    'syntheticData',
)

In [None]:
from HySpecLab.metrics import UnmixingLoss, NormalizedEntropy
from HySpecLab.metrics.regularization import SimplexVolumeLoss, SimilarityLoss

from HySpecLab.unmixing import ContrastiveUnmixing
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch import sigmoid
from torch.utils.data import Dataset, DataLoader

def train(model:nn.Module, n_endmembers:int, dataset:Dataset, n_batchs:int = 64, n_epochs:int = 100, lr=1e-3, simplex_weight=1e-5,
          similarity_weight:float = 5e-1):
    """Train ``model`` in place with Adam on the samples of ``dataset``.

    For every mini-batch ``x`` the optimised loss is::

        UnmixingLoss(model(x), x)
        + simplex_weight * SimplexVolumeLoss(sigmoid(model.ebk)) / vol_reg_norm
        + similarity_weight * SimilarityLoss(model.ebk)

    where ``vol_reg_norm`` is the volume term evaluated once on the initial
    endmembers, so ``simplex_weight`` acts on a volume relative to the start.

    Args:
        model: module exposing an ``ebk`` attribute holding the endmembers before
            the sigmoid (``sigmoid`` is applied to it for the volume term).
        n_endmembers: number of endmembers, forwarded to both regularisers.
        dataset: dataset yielding tensors; ``dataset[:]`` is also passed to
            ``SimplexVolumeLoss``.
        n_batchs: target number of mini-batches per epoch; the batch size is
            ``len(dataset) // n_batchs``, clamped to at least 1.
        n_epochs: number of passes over the data.
        lr: Adam learning rate.
        simplex_weight: weight of the normalised simplex-volume term.
        similarity_weight: weight of the endmember-similarity term (was a
            hard-coded 0.5, which remains the default).

    Returns:
        None. ``model`` is updated in place and left on the selected device.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = UnmixingLoss()
    volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers).to(device)
    similarity_reg = SimilarityLoss(n_endmembers, temperature=.1, reduction='mean')

    # Volume of the initial endmembers; the volume term below is divided by it.
    vol_reg_norm = volume_reg(torch.sigmoid(model.ebk.detach()))
    print(vol_reg_norm)

    # int(len / n_batchs) is 0 when the dataset has fewer samples than n_batchs,
    # which DataLoader rejects; otherwise this is the same batch size as before.
    batch_size = max(1, len(dataset) // n_batchs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % -1},
        )

    # No autocast is used, so loss scaling only matters on CUDA. Disabling it explicitly
    # on CPU avoids GradScaler's "CUDA is not available" warning (it disables itself anyway).
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

    for _ in epoch_iterator:
        epoch_loss = 0.
        n_steps = 0
        for x in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            y = model(x)
            # NOTE(review): the similarity term sees the raw `model.ebk` here, while the
            # evaluation cell applies it to sigmoid(model.ebk) — confirm which is intended.
            loss = (criterion(y, x)
                    + simplex_weight*(volume_reg(sigmoid(model.ebk))/vol_reg_norm)
                    + similarity_weight*similarity_reg(model.ebk))
            epoch_loss += loss.detach().item()
            n_steps += 1

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Mean loss per batch; max() guards an empty loader (the old loop index was unbound then).
        epoch_iterator.set_postfix(tls="%.4f" % (epoch_loss / max(n_steps, 1)))

In [None]:
from scipy import io as sio

# Directory with the raw synthetic-scene .mat files. It can be overridden with the
# SYNTHETIC_DATA_PATH environment variable; the default is the original local path.
synthetic_data_path = os.environ.get("SYNTHETIC_DATA_PATH", "/home/abian/Data/Dataset/HSI/SyntheticData/")

# os.path.join (rather than `+`) works whether or not the directory ends with a separator.
data = sio.loadmat(os.path.join(synthetic_data_path, "Y.mat"))
M = sio.loadmat(os.path.join(synthetic_data_path, "M.mat"))['M_avg']
A = sio.loadmat(os.path.join(synthetic_data_path, "A.mat"))['A'].T

# Image and abundances are stored flattened; restore (nRow, nCol, -1) in column-major order.
n_row, n_col = data['nRow'].item(), data['nCol'].item()
X = data['Y'].T
X = X.reshape(n_row, n_col, -1, order='F')
A = A.reshape(n_row, n_col, -1, order='F')

# Image to RGB

In [None]:
from dataset import SyntheticDataset
dataset = SyntheticDataset(config.Synthetic_PATH, transform=None)

# Set to True to (re)write `data/input.mat` under `result_path` — the MATLAB-format
# copy of the image that the LMM/ELMM section later loads. Disabled by default.
EXPORT_MATLAB_INPUT = False
if EXPORT_MATLAB_INPUT:
    matlab_data = {
        'X': dataset.image(),
        'n_endmembers': dataset.n_endmembers,
        'nRow': dataset.n_row,
        'nCol': dataset.n_col,
        'nBand': dataset.n_bands
    }
    sio.savemat(os.path.join(result_path, 'data/input.mat'), matlab_data)

In [None]:
jasper_wv = np.linspace(380, 2500, 224) # 224 bands from 380 to 2500 nm

jasper_data = sio.loadmat(os.path.join(config.JasperRidge_PATH, 'jasperRidge2_R198.mat'))
# `SlectBands` is a MATLAB band list and therefore 1-based; convert it to 0-based
# integer indices before indexing the numpy wavelength grid (otherwise every
# wavelength is shifted by one band).
selected_bands = jasper_data['SlectBands'].squeeze().astype(int) - 1
selected_jasper_wv = jasper_wv[selected_bands].tolist()

In [None]:
from HSI2RGB import HSI2RGB

# HSI2RGB settings: D65 illuminant and minor thresholding.
illuminant = 65
threshold = 0.02

X = dataset.image()
ydim, xdim, zdim = X.shape

# Flatten to (pixels, bands): each ROW of HSI_data holds the spectrum of one pixel.
HSI_data = np.reshape(X, [-1, zdim])
# NOTE(review): xdim/ydim are passed in this order — check against HSI2RGB's
# signature; the order is irrelevant only for square images.
rgb = HSI2RGB(selected_jasper_wv, HSI_data, xdim, ydim, illuminant, threshold)

fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(3, 4, 1)
ax.imshow(rgb)
ax.axis('off')
plt.show()

fig.savefig(os.path.join(result_path, 'imgs/synthetic_rgb.png'), dpi=300, bbox_inches='tight')

In [None]:
from utils import plot_endmembers, show_abundance
fig = plot_endmembers(dataset.endmembers(), ticks_range=(0, .7))
# plt.show() takes no figure argument; passing `fig` positionally raises on
# non-inline backends and is misread as `close` by the inline one.
plt.show()
fig.savefig(os.path.join(result_path, 'imgs/E_ref.pdf'), bbox_inches='tight')

fig = show_abundance(dataset.abundance())
# fig.savefig(os.path.join(result_path, 'imgs/A_ref.png'), dpi=300, bbox_inches='tight')

## Endmember estimation

## VCA

In [None]:
from HySpecLab.eea import VCA
# Flatten the cube to (pixels, bands) for VCA.
_X = X.reshape(-1, X.shape[-1])
# Use the dataset's endmember count instead of a hard-coded 3, consistent with the N-FINDR cell.
n_endmembers = dataset.n_endmembers
vca = VCA(n_endmembers, snr_input=30, random_state=25)
vca.fit(_X)
endmembers = vca.endmembers()

plot_endmembers(endmembers, ticks_range=(0, .7))
plt.show()

## NFINDR

In [None]:
from utils import plot_endmembers
from pysptools import eea
# N-FINDR endmember extraction on the image cube; the result is turned into a
# float32 tensor that later seeds the model's endmembers.
n_endmembers = dataset.n_endmembers

nfindr = eea.NFINDR()
endmember = nfindr.extract(X, n_endmembers)
endmember_init = torch.from_numpy(endmember).float()

In [None]:
from HySpecLab.metrics import sad

from scipy.optimize import linear_sum_assignment

sad_result = sad(endmember_init, dataset.endmembers())
# e_idx[j] is the estimated endmember matched to ground-truth endmember j (rows of
# sad_result are estimates, as the indexing below shows). A per-column argmin can pick
# the same estimate twice and drop another; the Hungarian assignment is one-to-one and
# coincides with the argmin whenever that already was one-to-one (up to ties).
_, e_idx = linear_sum_assignment(torch.as_tensor(sad_result).detach().cpu().numpy().T)
e_idx = torch.as_tensor(e_idx)
endmember_init = endmember_init[e_idx]
# The model applies sigmoid to its endmembers, so initialise with their logit.
# Clamping to [eps, 1 - eps] keeps the logit finite for values at or beyond 0 / 1.
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

fig = plot_endmembers(endmember_init, ticks_range=(0, .7))
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/E_nfindr_est.pdf'), bbox_inches='tight')

# Train

In [None]:
# Seed torch's global RNG so any random model initialisation and the DataLoader
# shuffling inside `train` are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

n_bands = dataset.n_bands
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init)
train(model, n_endmembers, dataset, n_batchs=50, n_epochs=50, lr=1e-3, simplex_weight=1e-2)

## Save model weights

In [None]:
# torch.save does not create missing directories, so make sure `weights/` exists first.
weights_path = os.path.join(result_path, 'weights/clhu.pth')
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)

# Testing

In [None]:
# load model
model = ContrastiveUnmixing(dataset.n_bands, dataset.n_endmembers)

# map_location='cpu' lets weights saved from a CUDA run load on a CPU-only machine;
# the evaluation below moves the model to CPU anyway.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
weights_path = os.path.join(result_path, 'weights/clhu.pth')
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
model = model.eval()

In [None]:
# Reconstruction loss and regularisers used to score the trained model.
n_end = dataset.n_endmembers
criterion = UnmixingLoss()
entropy_reg = NormalizedEntropy(S=n_end)
volume_reg = SimplexVolumeLoss(dataset[:], n_end)
similarity_reg = SimilarityLoss(n_end, temperature=.1, reduction='mean')

In [None]:
_X = dataset.X

model = model.cpu().eval()
# Evaluation only: run the forward pass under no_grad too, so no autograd graph is
# built over the whole image.
with torch.no_grad():
    reconstruc = model(_X)
    endmember_probs = sigmoid(model.ebk)
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(),
          volume_reg(endmember_probs).cpu(), similarity_reg(endmember_probs).cpu())

In [None]:
# Regulariser values of the reordered N-FINDR initialisation, for comparison with
# the trained endmembers.
init_volume = volume_reg(endmember_init)
init_similarity = similarity_reg(endmember_init)
init_volume, init_similarity

In [None]:
from HySpecLab.metrics import rmse

# Reconstruction RMSE of the trained model; dim=None presumably aggregates over all elements.
rmse_result = rmse(
    reconstruc,
    _X,
    dim=None,
)
print(rmse_result)

In [None]:
# Learned endmembers, mapped through the same sigmoid the model applies.
ebk = sigmoid(model.ebk).detach().cpu()

fig = plot_endmembers(ebk, ticks_range=(0, .7))
# plt.show() takes no figure argument (passing `fig` positionally is a misuse).
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/E_clhu_est.pdf'), bbox_inches='tight')

In [None]:
from torch.nn.functional import softmax
from utils import show_abundance

# Softmax over dim 1 turns model.A into per-pixel fractions summing to 1, which are
# then laid out as an (n_row, n_col, n_endmembers) image.
clhu_abundance = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1)

fig = show_abundance(clhu_abundance)
# plt.show() takes no figure argument (passing `fig` positionally is a misuse).
plt.show()

# fig.savefig(os.path.join(result_path, 'imgs/A_clhu_est.png'), dpi=300, bbox_inches='tight')


In [None]:
from HySpecLab.metrics import rmse

# Inference only: no_grad avoids building (and then discarding) an autograd graph.
with torch.no_grad():
    y = model(_X).cpu()

rmse(y, _X, dim=None)

In [None]:
from HySpecLab.metrics import sad

# SAD between the learned and ground-truth endmembers; the last line shows, for each
# column, the row with the smallest angle.
reference_endmembers = dataset.endmembers()
sad_result = sad(ebk, reference_endmembers)
print(sad_result)
np.argmin(sad_result, axis=0)

In [None]:
# Same comparison for the reordered N-FINDR initialisation.
reference_endmembers = dataset.endmembers()
sad_result = sad(endmember_init, reference_endmembers)
print(sad_result)
np.argmin(sad_result, axis=0)

In [None]:
fcls_abundance = sio.loadmat(os.path.join(synthetic_data_path, 'A_est', 'FCLS.mat'))['A'].T
# from utils import plot_abundance
# Use the dataset's image size instead of a hard-coded 50x50.
# NOTE(review): the ground-truth A was reshaped with order='F' when loaded; confirm
# that order='C' matches how FCLS.mat was written.
A = fcls_abundance.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A)
# plt.show() takes no figure argument (passing `fig` positionally is a misuse).
plt.show()

In [None]:
mesma = sio.loadmat(os.path.join(synthetic_data_path, 'M_est', 'MESMA.mat'))
M = mesma['M']

# Endmember spectra stored at index 2000 of M's last axis.
plt.plot(M[:, :, 2000])
plt.show()

In [None]:
# Same view at index 1000 of M's last axis.
slice_idx = 1000
plt.plot(M[:, :, slice_idx])
plt.show()

# LMM and ELMM solution

In [None]:
from dataset import SyntheticDataset

# From here on `result_path` points at `syntheticData/data` (the MATLAB inputs/outputs),
# not at `syntheticData` as in the sections above.
result_path = os.path.join(config.RESULTS_PATH, 'syntheticData/data')
dataset = SyntheticDataset(config.Synthetic_PATH)

In [None]:
from scipy import io as sio
from HySpecLab.metrics import sad

from scipy.optimize import linear_sum_assignment

endmember_estimation = sio.loadmat(os.path.join(result_path, 'endmember_estimation.mat'))
vca_ee = torch.tensor(endmember_estimation['VCA'])
nfindr_ee = torch.tensor(endmember_estimation['NFINDR'])

# For each method, idx[j] is the estimated endmember matched to ground-truth endmember j.
# A per-column argmin of the SAD matrix can select the same estimate twice (and drop
# another); the Hungarian assignment is one-to-one and coincides with the argmin
# whenever that already was one-to-one (up to ties).
sad_result = sad(vca_ee, dataset.endmembers())
_, vca_idx = linear_sum_assignment(torch.as_tensor(sad_result).detach().cpu().numpy().T)
vca_idx = torch.as_tensor(vca_idx)

sad_result = sad(nfindr_ee, dataset.endmembers())
_, nfindr_idx = linear_sum_assignment(torch.as_tensor(sad_result).detach().cpu().numpy().T)
nfindr_idx = torch.as_tensor(nfindr_idx)

vca_ee = vca_ee[vca_idx]
nfindr_ee = nfindr_ee[nfindr_idx]

In [None]:
# Named `input_data` rather than `input` to avoid shadowing the built-in input().
input_data = sio.loadmat(os.path.join(result_path, 'input.mat'))
# Flatten the image to (pixels, bands) in column-major (MATLAB-style) order.
X = torch.tensor(input_data['X'].reshape(-1, input_data['X'].shape[-1], order='F'))

## Endmembers: ground truth vs. NFINDR and VCA estimates

In [None]:
from utils import plot_endmembers
fig = plot_endmembers(dataset.endmembers(), ticks_range=(0, .7), 
                      endmember_estimation=[nfindr_ee, vca_ee],
                      ee_labels=['NFINDR', 'VCA'])
# plt.show() takes no figure argument (passing `fig` positionally is a misuse).
plt.show()
fig.savefig(os.path.join(result_path, 'imgs/endmembers_estimation.pdf'), bbox_inches='tight')


## LMM

In [None]:
from HySpecLab.metrics import rmse

FCLS = sio.loadmat(os.path.join(result_path, 'FCLS.mat'))
save_dir = os.path.join(result_path, 'imgs/LMM')

nfindr_reconstruct = torch.tensor(FCLS['Xhat_NFINDR'].T)
vca_reconstruct = torch.tensor(FCLS['Xhat_VCA'].T)

# X is already a tensor here; torch.as_tensor avoids the copy-construct warning that
# torch.tensor(tensor) emits (and still accepts a numpy array).
print(rmse(torch.as_tensor(X), nfindr_reconstruct, dim=None), rmse(torch.as_tensor(X), vca_reconstruct, dim=None))

In [None]:
from utils import show_abundance
# FCLS abundance maps for each extraction method, with abundance channels reordered
# using that method's SAD matching indices; image size taken from the dataset.
for method, idx in (('NFINDR', nfindr_idx), ('VCA', vca_idx)):
    A = FCLS[f'A_{method}'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
    fig = show_abundance(A[:, :, idx])
    # plt.show() takes no figure argument (passing `fig` positionally is a misuse).
    plt.show()
    fig.savefig(os.path.join(save_dir, f'A_{method}.png'), dpi=300, bbox_inches='tight')

## ELMM

In [None]:
from HySpecLab.metrics import rmse

ELMM = sio.loadmat(os.path.join(result_path, 'ELMM.mat'))
save_dir = os.path.join(result_path, 'imgs/ELMM')

nfindr_reconstruct = torch.tensor(ELMM['Xhat_NFINDR'].T)
vca_reconstruct = torch.tensor(ELMM['Xhat_VCA'].T)

# X is already a tensor here; torch.as_tensor avoids the copy-construct warning that
# torch.tensor(tensor) emits (and still accepts a numpy array).
print(rmse(torch.as_tensor(X), nfindr_reconstruct, dim=None), rmse(torch.as_tensor(X), vca_reconstruct, dim=None))

In [None]:
from utils import show_abundance
# ELMM abundance maps for each extraction method, with abundance channels reordered
# using that method's SAD matching indices; image size taken from the dataset.
for method, idx in (('NFINDR', nfindr_idx), ('VCA', vca_idx)):
    A = ELMM[f'A_{method}'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
    fig = show_abundance(A[:, :, idx])
    # plt.show() takes no figure argument (passing `fig` positionally is a misuse).
    plt.show()
    fig.savefig(os.path.join(save_dir, f'A_{method}.png'), dpi=300, bbox_inches='tight')

In [None]:
from scipy import io as sio

from HySpecLab.metrics import rmse,sad

path = os.path.join(config.RESULTS_PATH, 'syntheticData/data')

# Named `input_data` rather than `input` to avoid shadowing the built-in input().
input_data = sio.loadmat(os.path.join(path, 'input.mat'))
# Flatten the image to (pixels, bands) in column-major (MATLAB-style) order.
X = input_data['X'].reshape(-1, input_data['X'].shape[-1], order='F')

FCLS = sio.loadmat(os.path.join(path, 'FCLS.mat'))
ELMM = sio.loadmat(os.path.join(path, 'ELMM.mat'))

Xhat = FCLS['Xhat_NFINDR'].T

rmse(torch.tensor(X), torch.tensor(Xhat), dim=None)


In [None]:
from utils import show_abundance
# Image size comes from the dataset instead of a hard-coded 50x50.
A = FCLS['A_NFINDR'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A[:, :, nfindr_idx])
# plt.show() takes no figure argument (passing `fig` positionally is a misuse).
plt.show()

A = FCLS['A_VCA'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A[:, :, vca_idx])
plt.show()

In [None]:
# Image size comes from the dataset instead of a hard-coded 50x50.
A = ELMM['A_VCA'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A[:,:,vca_idx])

A = ELMM['A_NFINDR'].T.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A[:,:,nfindr_idx])

In [None]:
# Shapes of the FCLS/NFINDR reconstruction and the flattened input, as a quick consistency check.
tuple(arr.shape for arr in (Xhat, X))