In [None]:
import os, sys
# Put the parent of the current working directory on the import path
# (presumably so that the project-local imports below can be resolved).
# The path is normalised first: the raw "<cwd>/.." spelling would never match
# an entry that is already present in canonical form, which would defeat the
# membership test and let duplicates pile up across re-runs.
project_root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import config

plt.style.use('seaborn-v0_8')

In [None]:
# Root folder for every artefact written by this notebook.
result_path = os.path.join(config.RESULTS_PATH, 'syntheticData')

# Figures are saved under 'imgs' and model weights under 'weights'.
# Neither Figure.savefig nor torch.save creates missing directories, so make
# sure both exist before any cell tries to write into them.
for _subdir in ('imgs', 'weights'):
    os.makedirs(os.path.join(result_path, _subdir), exist_ok=True)

In [None]:
from torch.utils.data import Dataset
from dataset import HSIDataset
from scipy import io as sio

class SyntheticDataset(HSIDataset):
    """Synthetic hyperspectral scene read from MATLAB files in ``root_dir``.

    Files and keys that are read:
      * ``Y.mat`` -- ``Y`` (observations) plus the scalars ``nRow``, ``nCol``
        and ``nBand``.
      * ``M.mat`` -- ``M_avg`` (endmember signatures).
      * ``A.mat`` -- ``A`` (abundances).

    All three arrays are transposed after loading and stored as float32
    tensors in ``self.X``, ``self.E`` and ``self.A``.

    Parameters
    ----------
    root_dir : str
        Folder that contains the three ``.mat`` files.
    transform : callable, optional
        Applied to each sample returned by ``__getitem__`` when truthy.

    NOTE(review): ``__init__`` and ``abundance()`` reshape with NumPy's default
    (C) order whereas ``image()`` uses ``order='F'``; confirm that the spatial
    layouts of the returned arrays are meant to differ.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        observations = sio.loadmat(os.path.join(root_dir, "Y.mat"))

        self.n_row = observations['nRow'].item()
        self.n_col = observations['nCol'].item()
        self.n_bands = observations['nBand'].item()

        # Noise can push some observations below zero; abs() folds them back.
        self.X = np.abs(observations['Y'].T).reshape(self.n_row, self.n_col, -1)
        # Preprocess the spatial cube, then flatten back to one row per pixel.
        preprocessed = self.preprocessing(self.X, max_value=1)
        self.X = preprocessed.reshape(-1, self.X.shape[-1])  # (nRow*nCol, nBand)

        endmember_array = sio.loadmat(os.path.join(root_dir, "M.mat"))['M_avg'].T
        abundance_array = sio.loadmat(os.path.join(root_dir, "A.mat"))['A'].T

        self.X = torch.tensor(self.X, dtype=torch.float32)
        self.E = torch.tensor(endmember_array, dtype=torch.float32)
        self.A = torch.tensor(abundance_array, dtype=torch.float32)
        self.n_endmembers = self.E.shape[0]

        self.transform = transform

    def __len__(self):
        """Number of samples, computed as ``n_row * n_col``."""
        return self.n_row * self.n_col

    def __getitem__(self, idx):
        """Return ``self.X[idx]``, passed through ``transform`` when one is set."""
        sample = self.X[idx]
        return self.transform(sample) if self.transform else sample

    def endmembers(self):
        """Return the endmember tensor ``self.E``."""
        return self.E

    def abundance(self):
        """Return the abundances as a ``(n_row, n_col, -1)`` NumPy array (C order)."""
        spatial_shape = (self.n_row, self.n_col, -1)
        return self.A.numpy().reshape(spatial_shape)

    def image(self):
        """Return the samples as a ``(n_row, n_col, -1)`` NumPy array (Fortran order)."""
        spatial_shape = (self.n_row, self.n_col, -1)
        return self.X.numpy().reshape(spatial_shape, order='F')

In [None]:
from HySpecLab.metrics import UnmixingLoss, NormalizedEntropy
from HySpecLab.metrics.regularization import SimplexVolumeLoss, SimilarityLoss

from HySpecLab.unmixing import ContrastiveUnmixing
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch import sigmoid
from torch.utils.data import Dataset, DataLoader

def train(model:nn.Module, n_endmembers:int, dataset:Dataset, n_batchs:int = 64, n_epochs:int = 100, lr=1e-3, simplex_weight=1e-5, similarity_weight=5e-1):
    """Fit ``model`` on ``dataset`` with Adam and two endmember regularisers.

    The per-batch loss is::

        UnmixingLoss(model(x), x)
        + simplex_weight    * volume_reg(sigmoid(model.ebk)) / vol_reg_norm
        + similarity_weight * similarity_reg(model.ebk)

    where ``vol_reg_norm`` is the volume term evaluated once on the initial
    endmembers, before any optimisation step.

    Parameters
    ----------
    model : nn.Module
        Network to optimise; it must expose the attribute ``ebk``. It is moved
        to CUDA when available and left in training mode.
    n_endmembers : int
        Forwarded to ``SimplexVolumeLoss`` and ``SimilarityLoss``.
    dataset : Dataset
        Training samples. It must also support ``dataset[:]``, which is
        handed to ``SimplexVolumeLoss``.
    n_batchs : int
        Approximate number of mini-batches per epoch; the batch size is
        ``len(dataset) // n_batchs``, floored at 1.
    n_epochs : int
        Number of passes over the data.
    lr : float
        Adam learning rate.
    simplex_weight : float
        Weight of the normalised simplex-volume term.
    similarity_weight : float
        Weight of the similarity term (previously a hard-coded ``5e-1``).

    Returns
    -------
    None
        Nothing is returned; the caller keeps using the ``model`` object it
        passed in.

    NOTE(review): the similarity term is fed the raw ``model.ebk`` here, while
    the evaluation cell of this notebook feeds it ``sigmoid(model.ebk)`` --
    confirm which input the regulariser expects.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = UnmixingLoss()
    volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers).to(device)
    similarity_reg = SimilarityLoss(n_endmembers, temperature=.1, reduction='mean')
    # Reference value used to normalise the volume term during training.
    vol_reg_norm = volume_reg(torch.sigmoid(model.ebk.detach()))
    print(vol_reg_norm)

    # A batch size of 0 (n_batchs > len(dataset)) is rejected by DataLoader.
    batch_size = max(1, len(dataset) // n_batchs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % -1},
        )

    # Gradient scaling is only meaningful on CUDA; enabling it on a CPU-only
    # host just emits a warning before the scaler disables itself.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

    for _ in epoch_iterator:
        epoch_loss = 0.
        n_steps = 0
        for x in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            y = model(x)

            reconstruction_term = criterion(y, x)
            volume_term = volume_reg(sigmoid(model.ebk)) / vol_reg_norm
            similarity_term = similarity_reg(model.ebk)
            loss = reconstruction_term + simplex_weight*volume_term + similarity_weight*similarity_term
            epoch_loss += loss.item()
            n_steps += 1

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Mean batch loss of the epoch; guarded so an empty loader cannot
        # trigger a division by zero or a reference to an unset counter.
        epoch_iterator.set_postfix(tls="%.4f" % (epoch_loss / max(n_steps, 1)))

In [None]:
from scipy import io as sio

# Folder holding Y.mat, M.mat and A.mat. Set the SYNTHETIC_DATA_PATH
# environment variable to point elsewhere; the fallback is the original
# hard-coded location, so existing setups keep working unchanged.
synthetic_data_path = os.environ.get("SYNTHETIC_DATA_PATH", "/home/abian/Data/Dataset/HSI/SyntheticData/")

# os.path.join (rather than string concatenation) keeps the file paths valid
# whether or not the folder is given with a trailing separator.
data = sio.loadmat(os.path.join(synthetic_data_path, "Y.mat"))
M = sio.loadmat(os.path.join(synthetic_data_path, "M.mat"))['M_avg']
A = sio.loadmat(os.path.join(synthetic_data_path, "A.mat"))['A'].T

# Fold the first axis of the transposed arrays back into (nRow, nCol, ...)
# using column-major (Fortran) order.
X = data['Y'].T
X = X.reshape(data['nRow'].item(), data['nCol'].item(), -1, order='F')
A = A.reshape(data['nRow'].item(), data['nCol'].item(), -1, order='F')

# Image to RGB

In [None]:
# Build the dataset from the synthetic scene folder; no per-sample transform is applied.
dataset = SyntheticDataset(synthetic_data_path, transform=None)

# Disabled: export of the dataset image and its dimensions to a MATLAB file.
# matlab_data = {
#     'X': dataset.image(),
#     'n_endmembers': dataset.n_endmembers,
#     'nRow': dataset.n_row,
#     'nCol': dataset.n_col,
#     'nBand': dataset.n_bands
# }

# sio.savemat(os.path.join(result_path, 'data/input.mat'), matlab_data)

In [None]:
# Evenly spaced wavelength grid from which the retained bands are picked below.
jasper_wv = np.linspace(380, 2500, 224) # 224 bands from 380 to 2500 nm

# The Jasper Ridge file is opened only to read its band-selection vector.
# This rebinds the notebook-level name `data`.
data = sio.loadmat(os.path.join(config.JasperRidge_PATH, 'jasperRidge2_R198.mat'))
# NOTE(review): the stored values are used directly as 0-based NumPy indices;
# if the file holds 1-based (MATLAB-style) indices every wavelength is shifted
# by one band -- confirm.
selected_bands = data['SlectBands'].squeeze()
selected_jasper_wv = jasper_wv[selected_bands].tolist()

In [None]:
from HSI2RGB import HSI2RGB

# Use the D65 illuminant
illuminant = 65

# Do minor thresholding
threshold = 0.02
# Spatial cube returned by the dataset; this rebinds the notebook-level name X.
X = dataset.image()
(ydim, xdim, zdim) = X.shape

# Flatten the two spatial axes so that each row holds the spectrum of one pixel
HSI_data = np.reshape(X, [-1,zdim])
# NOTE(review): the spatial sizes are passed as (xdim, ydim); check the order
# HSI2RGB expects -- a swap would go unnoticed on a square image.
# NOTE(review): the wavelength list comes from the Jasper Ridge band selection;
# this presumes the synthetic bands lie on the same grid -- confirm.
rgb = HSI2RGB(selected_jasper_wv, HSI_data, xdim, ydim, illuminant, threshold)

fig = plt.figure(figsize=(7,5))
# The image is drawn in the first slot of a 3x4 subplot grid.
plt.subplot(3,4,1)
plt.imshow(rgb)
plt.axis('off')
plt.show()

fig.savefig(os.path.join(result_path, 'imgs/synthetic_rgb.png'), dpi=300, bbox_inches='tight')

In [None]:
from utils import plot_endmembers, show_abundance
# Reference endmember signatures stored in the dataset.
fig = plot_endmembers(dataset.endmembers(), ticks_range=(0, .7))
# pyplot.show() does not take a figure: a positional argument is rejected by
# most backends and is only tolerated by the inline one, which reads it as
# its `close` flag.
plt.show()
fig.savefig(os.path.join(result_path, 'imgs/E_ref.pdf'), bbox_inches='tight')

# Reference abundance maps stored in the dataset.
fig = show_abundance(dataset.abundance())
fig.savefig(os.path.join(result_path, 'imgs/A_ref.png'), dpi=300, bbox_inches='tight')

## Endmember estimation

## VCA

In [None]:
from HySpecLab.eea import VCA
# Flatten the image cube to one row per pixel (the last axis is kept as-is).
_X = X.reshape(-1, X.shape[-1])
# NOTE(review): the endmember count is hard-coded here, whereas the N-FINDR
# cell takes it from dataset.n_endmembers -- confirm both are meant to agree.
n_endmembers = 3
vca = VCA(n_endmembers, snr_input=30, random_state=25)
vca.fit(_X)
endmembers = vca.endmembers()

# Plot the VCA estimate on the same y-range as the other endmember figures.
plot_endmembers(endmembers, ticks_range=(0, .7))
plt.show()

## NFINDR

In [None]:
from utils import plot_endmembers
from pysptools import eea
# Number of endmembers to extract, taken from the dataset
# (this overrides the value set in the VCA cell).
n_endmembers = dataset.n_endmembers

# Run N-FINDR on X, the array produced by dataset.image() in an earlier cell.
ee = eea.NFINDR()
endmember = ee.extract(X, n_endmembers)
# Converted to a float32 tensor; the cells below use it as the model's initial endmembers.
endmember_init = torch.from_numpy(endmember).float()

In [None]:
from HySpecLab.metrics import sad

sad_result = sad(endmember_init, dataset.endmembers())
# Reorder the N-FINDR estimates (not the ground truth): position j receives
# the estimate with the smallest SAD in column j.
# Presumably rows index the estimates and columns the reference endmembers --
# TODO confirm against `sad`. argmin does not guarantee a one-to-one matching,
# so the same estimate can be selected for more than one position.
e_idx = torch.argmin(sad_result, dim=0)
endmember_init = endmember_init[e_idx]

# Map the initial endmembers to logit space. Values are clamped into
# (0, 1) first: the previous form added its epsilon outside the division, so
# a value of exactly 1 produced an infinite logit.
_logit_eps = 1e-6
_clamped_init = endmember_init.clamp(_logit_eps, 1 - _logit_eps)
logit_endmember_init = torch.log(_clamped_init) - torch.log1p(-_clamped_init)

fig = plot_endmembers(endmember_init, ticks_range=(0, .7))
# pyplot.show() does not take a figure argument.
plt.show()
fig.savefig(os.path.join(result_path, 'imgs/E_nfindr_est.pdf'), bbox_inches='tight')

# Train

In [None]:
# Build the model with the logit-space N-FINDR endmembers as initial values.
n_bands = dataset.n_bands
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init)
# 50 epochs, batch size derived from n_batchs=50, with a simplex weight of 1e-2.
train(model, n_endmembers, dataset, n_batchs=50, n_epochs=50, lr=1e-3, simplex_weight=1e-2)

## Save model

In [None]:
# Persist the trained parameters (state dict only) under result_path/weights.
torch.save(model.state_dict(), os.path.join(result_path, 'weights/clhu.pth'))

# Testing

In [None]:
# Rebuild the network and restore the saved parameters.
model = ContrastiveUnmixing(dataset.n_bands, dataset.n_endmembers)

# map_location='cpu' lets a checkpoint written from a CUDA run be read on a
# machine without a GPU; load_state_dict then copies the values into the
# freshly built model.
state_dict = torch.load(os.path.join(result_path, 'weights/clhu.pth'), map_location='cpu')
model.load_state_dict(state_dict)
model = model.eval()

In [None]:
# Loss and regularisation terms that the evaluation cells below report.
criterion = UnmixingLoss()
entropy_reg  = NormalizedEntropy(S=dataset.n_endmembers)
# Built from the whole dataset (dataset[:]), the same way train() builds it.
volume_reg = SimplexVolumeLoss(dataset[:], dataset.n_endmembers)
similarity_reg = SimilarityLoss(dataset.n_endmembers, temperature=.1, reduction='mean')

In [None]:
# Evaluate on every sample of the dataset at once, on the CPU.
_X = dataset.X

model.eval()
model = model.cpu()
# The forward pass is kept inside no_grad so that no autograd graph is built
# for the full dataset; the printed values are unchanged.
with torch.no_grad():
    reconstruc = model(_X)
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(), volume_reg(sigmoid(model.ebk)).cpu(),
         similarity_reg(sigmoid(model.ebk)).cpu())

In [None]:
# The same two regularisers evaluated on the N-FINDR initial endmembers, for comparison.
volume_reg(endmember_init), similarity_reg(endmember_init)

In [None]:
# Endmembers held by the model, mapped through a sigmoid and detached for plotting.
ebk = sigmoid(model.ebk).detach().cpu()

fig = plot_endmembers(ebk, ticks_range=(0, .7))
# pyplot.show() does not take a figure argument.
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/E_clhu_est.pdf'), bbox_inches='tight')

In [None]:
from torch.nn.functional import softmax
from utils import show_abundance

# Softmax over axis 1 of model.A, reshaped to the image grid.
# NOTE(review): model.A is presumably set by the model's most recent forward
# pass -- confirm this cell always runs after one on the full dataset.
test = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1)

fig = show_abundance(test)
# pyplot.show() does not take a figure argument.
plt.show()

# fig.savefig(os.path.join(result_path, 'imgs/A_clhu_est.png'), dpi=300, bbox_inches='tight')


In [None]:
# Reconstruct every sample without building an autograd graph.
with torch.no_grad():
    y = model(_X).cpu()

from HySpecLab.metrics import rmse
# Reconstruction error between the model output and its input.
rmse(y, _X, dim=None)

In [None]:
from HySpecLab.metrics import sad

# SAD between the endmembers learned by the model and the reference ones.
sad_result = sad(ebk, dataset.endmembers())
print(sad_result)
# Row index of the smallest angle in each column. torch.argmin is used, as in
# the matching cell earlier, instead of relying on NumPy converting the tensor.
torch.argmin(sad_result, dim=0)

In [None]:
# SAD between the N-FINDR initial endmembers and the reference ones.
sad_result = sad(endmember_init, dataset.endmembers())
print(sad_result)
# Row index of the smallest angle in each column. torch.argmin is used, as in
# the matching cell earlier, instead of relying on NumPy converting the tensor.
torch.argmin(sad_result, dim=0)

In [None]:
# Abundances stored in A_est/FCLS.mat, transposed after loading.
test = sio.loadmat(os.path.join(synthetic_data_path, 'A_est', 'FCLS.mat'))['A'].T
# Take the spatial size from the dataset instead of a hard-coded 50x50 grid.
A = test.reshape(dataset.n_row, dataset.n_col, -1, order='C')
fig = show_abundance(A)
# pyplot.show() does not take a figure argument.
plt.show()

In [None]:
# Load the array stored in M_est/MESMA.mat; this rebinds the notebook-level
# names `data` and `M`.
data = sio.loadmat(os.path.join(synthetic_data_path, 'M_est', 'MESMA.mat'))
M = data['M']
# Plot the 2-D slice at index 2000 of the third axis.
# NOTE(review): presumably the third axis indexes pixels -- confirm.
plt.plot(M[:,:,2000])
plt.show()

In [None]:
# Same plot for the slice at index 1000 of the third axis of M.
plt.plot(M[:,:,1000])
plt.show()