In [None]:
import os, sys
# Make the repository root (parent of this notebook's directory) importable, so the
# local `config`, `dataset`, `utils` and `HySpecLab` modules resolve.
project_root_dir = os.path.join(os.getcwd(), os.pardir)
if project_root_dir not in sys.path:
    sys.path += [project_root_dir]

from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import config

# Global plot theme; 'seaborn-v0_8' is matplotlib's (>= 3.6) name for its bundled seaborn style.
plt.style.use(style='seaborn-v0_8')

In [None]:
from HySpecLab.metrics import UnmixingLoss, NormalizedEntropy
from HySpecLab.metrics.regularization import SimplexVolumeLoss, SimilarityLoss

from HySpecLab.unmixing import ContrastiveUnmixing
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch import sigmoid
from torch.utils.data import Dataset, DataLoader

def train(model:nn.Module, n_endmembers:int, dataset:Dataset, n_batchs:int = 64, n_epochs:int = 100, lr=1e-3, simplex_weight=1e-5,
          similarity_weight:float = 1e-1):
    '''
        Fit an unmixing model with Adam on reconstruction loss plus two endmember regularizers.

        The objective is ``UnmixingLoss`` between the model output and its input, plus a
        simplex-volume term on ``sigmoid(model.ebk)`` (normalised by its value at
        initialisation) and a similarity term on ``model.ebk``.

        Parameters
        ----------
            model : nn.Module
                Model exposing an ``ebk`` parameter (endmember logits). Moved to the
                available device and trained in place.
            n_endmembers : int
                Number of endmembers, forwarded to the regularizers.
            dataset : Dataset
                Pixel dataset; ``dataset[:]`` is passed to ``SimplexVolumeLoss``.
            n_batchs : int, optional
                Approximate number of mini-batches per epoch. The batch size is
                ``len(dataset) // n_batchs``, at least 1. Default is 64.
            n_epochs : int, optional
                Number of epochs. Default is 100.
            lr : float, optional
                Adam learning rate. Default is 1e-3.
            simplex_weight : float, optional
                Weight of the normalised simplex-volume term. Default is 1e-5.
            similarity_weight : float, optional
                Weight of the similarity term. Default is 1e-1 (previously hard-coded).
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = UnmixingLoss()
    volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers).to(device)
    similarity_reg = SimilarityLoss(n_endmembers, temperature=.1, reduction='mean')

    # Normalise the volume term by its value at the initial endmembers, so that
    # simplex_weight is a relative weight.
    vol_reg_norm = volume_reg(torch.sigmoid(model.ebk.detach()))
    print(vol_reg_norm)

    # Integer division can give 0 when len(dataset) < n_batchs, and DataLoader rejects batch_size=0.
    batch_size = max(1, len(dataset) // n_batchs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    epoch_iterator = tqdm(
            range(n_epochs),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % -1},
        )

    # No autocast is used, so the scaler is only needed on CUDA. Enabling it on a
    # CPU-only machine merely makes GradScaler warn and disable itself.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

    for epoch in epoch_iterator:
        epoch_loss = 0.
        n_steps = 0
        for x in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            y = model(x)
            # NOTE(review): the similarity term uses the raw logits model.ebk here, while
            # the evaluation cells use sigmoid(model.ebk). Confirm which one is intended.
            loss = (criterion(y, x)
                    + simplex_weight * (volume_reg(sigmoid(model.ebk)) / vol_reg_norm)
                    + similarity_weight * similarity_reg(model.ebk))
            epoch_loss += loss.detach().item()
            n_steps += 1

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Mean loss over the batches of this epoch (guarded against an empty loader).
        epoch_iterator.set_postfix(tls="%.4f" % (epoch_loss / max(n_steps, 1)))

# Samson

In [None]:
# Presumably the output directory for Samson figures.
# NOTE(review): samson_save_path is not used by any later cell; confirm whether figures were meant to be saved here.
samson_save_path = os.path.join(config.IMG_PATH, 'Samson')

In [None]:
# Load the Samson scene. The Urban and Cuprite sections later rebind `dataset`.
from dataset import Samson
dataset = Samson(config.Samson_PATH)

In [None]:
from HySpecLab.eea import VCA
# `+ 0` is presumably a knob for trying more endmembers than the ground truth has.
n_endmembers = dataset.n_endmembers + 0

# VCA initialisation with a fixed seed.
vca = VCA(n_endmembers=n_endmembers, snr_input=1, random_state=42)

E = vca.fit(dataset.X.numpy())
endmember_init = torch.from_numpy(vca.endmembers()).float()
# Optional (testing only): rescale each endmember so that its maximum value is 0.9.
# endmember_init = (endmember_init / endmember_init.max(dim=1, keepdim=True)[0]) * .9

# The model receives its initial endmembers as logits (later cells read them back via sigmoid(model.ebk)).
# torch.logit clamps to [eps, 1 - eps] first, so values <= 0 or >= 1 give finite logits.
# The previous log(p / (1 - p) + 1e-12) gave inf for p == 1 and NaN for p > 1 or p < 0.
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

In [None]:
from utils import plot_endmembers
fig = plot_endmembers(endmember_init)
# pyplot.show() takes no figure argument (documented as show(*, block=None)). Passing `fig`
# only worked because the inline backend treats it as its `close` flag; show() renders `fig`.
plt.show()

In [None]:
from pysptools import eea
# NFINDR initialisation; this overrides the VCA initialisation from the previous cell.
n_endmembers = dataset.n_endmembers

ee = eea.NFINDR()
endmember = ee.extract(dataset.image().numpy(), n_endmembers)
endmember_init = torch.from_numpy(endmember).float()
# Logits for the model. torch.logit clamps to [eps, 1 - eps], so reflectances <= 0 or >= 1
# stay finite (log(p / (1 - p) + 1e-12) gave inf/NaN there).
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

fig = plot_endmembers(endmember_init)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
# n_endmembers = dataset.n_endmembers + 0

# ee = eea.FIPPI()
# endmember = ee.extract(dataset.image().numpy(), n_endmembers-1)
# endmember_init = torch.from_numpy(endmember).float()
# endmember_init = endmember_init[1:]
# logit_endmember_init = torch.log((endmember_init / (1-endmember_init) + 1e-12))

# with plt.style.context(("seaborn-colorblind")):
#     plt.plot(endmember_init.T)
# plt.show()



In [None]:
# Build the model from the current logit initialisation and fit it on Samson.
n_bands = dataset.n_bands
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init)
train(
    model=model,
    n_endmembers=n_endmembers,
    dataset=dataset,
    n_batchs=50,
    n_epochs=50,
    lr=1e-3,
    simplex_weight=1e-2,
)

In [None]:
from HySpecLab.metrics import NormalizedEntropy

# Loss terms for reporting. criterion, volume_reg and similarity_reg mirror those in
# train(); entropy_reg is only used for evaluation. None are moved to a device here.
criterion = UnmixingLoss()
entropy_reg = NormalizedEntropy(S=n_endmembers)
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
similarity_reg = SimilarityLoss(
    n_endmembers,
    temperature=0.1,
    reduction='mean',
)

In [None]:
_X = dataset.X

model.eval()
model = model.cpu()
# The full-image reconstruction is only used for reporting, so run the forward pass
# under no_grad too; building an autograd graph for it only costs memory.
with torch.no_grad():
    reconstruc = model(_X)
    # Reconstruction loss, entropy of model.A, and volume / similarity of the learned endmembers.
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(), volume_reg(sigmoid(model.ebk)).cpu(),
          similarity_reg(sigmoid(model.ebk)).cpu())

In [None]:
# Volume regularizer on the initial endmembers (from the last initialisation cell run), for comparison with the trained value printed above.
volume_reg(endmember_init)

In [None]:
# Learned endmembers as spectra (sigmoid of the model's logits).
ebk = sigmoid(model.ebk).detach().cpu()
fig = plot_endmembers(ebk)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
# ordering the endmembers
endmembers = dataset.endmembers().detach().cpu()
from HySpecLab.metrics import sad
sad_result = sad(ebk, endmembers)
print(sad_result)
idx = torch.argmin(sad_result, dim=1) # Index for reordering the ground truth
print(idx)

# reorder the endmembers
endmembers = endmembers[idx]

In [None]:
# ordering the endmembers
endmembers = dataset.endmembers().detach().cpu()
from HySpecLab.metrics import sad
sad_result = sad(endmember_init, endmembers)
print(sad_result)
idx = torch.argmin(sad_result, dim=1) # Index for reordering the ground truth
print(idx)

# reorder the endmembers
endmembers = endmembers[idx]

In [None]:
# Ground-truth endmembers, in the order produced by the most recent alignment cell.
fig = plot_endmembers(endmembers)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
from torch.nn.functional import softmax
# Estimated abundance maps: softmax of model.A over dim 1, reshaped to (n_row, n_col, n_endmembers).
test = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, n_endmembers)
# Braces around the index so multi-digit indices render as E_{10} rather than E_1 followed by 0.
labels = [f'$E_{{{k}}}$' for k in range(1, n_endmembers+1)]

fig = plt.figure(figsize=(12,7))
# 4 x 5 grid: room for up to 20 endmembers.
for i in range(n_endmembers):
    plt.subplot(4,5,i+1)
    plt.imshow(test[:,:,i].T, cmap='viridis')
    plt.xticks([])
    plt.yticks([])
    plt.title(labels[i], fontsize='x-large')
    plt.colorbar()

plt.tight_layout()
plt.show()

In [None]:
# Fresh copies for the metrics below. Note that `endmembers` is rebound to the ground truth
# in its original (un-reordered) order.
endmembers = dataset.endmembers()
ebk = model.ebk.detach().sigmoid().cpu()

from torch.nn.functional import mse_loss
def rmse(x: torch.Tensor, y: torch.Tensor):
    '''Row-wise root-mean-square error: sqrt(mean((x - y)**2)) over dim 1, one value per row.'''
    squared_err = mse_loss(x, y, reduction='none')
    return squared_err.mean(dim=1).sqrt()

# Estimated abundances as (n_endmembers, n_row, n_col).
abundance = (
    softmax(model.A.detach(), dim=1)
    .cpu()
    .reshape(dataset.n_row, dataset.n_col, n_endmembers)
    .permute(2, 0, 1)
)
# Ground truth reordered with `idx` so that map k matches estimated endmember k.
abundance_gt = dataset.abundance()[:, :, idx].permute(2, 0, 1)
endmember_gt = dataset.endmembers()[idx, :]

# One RMSE per endmember over all pixels.
rmse_result = rmse(abundance.flatten(start_dim=1), abundance_gt.flatten(start_dim=1))
print(rmse_result)

In [None]:
# Ground-truth abundance maps (reordered with `idx`), in the same layout as the estimated maps.
fig = plt.figure(figsize=(12, 7))
for i in range(n_endmembers):
    gt_map = abundance_gt[i, :, :].T
    plt.subplot(4, 5, i + 1)
    plt.imshow(gt_map, cmap='viridis')
    plt.title(labels[i], fontsize='x-large')
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()

plt.tight_layout()
plt.show()

# Urban dataset

In [None]:
# Load the Urban scene (rebinds `dataset` from the Samson section).
from dataset import Urban
dataset = Urban(root_dir=config.Urban_PATH)

In [None]:
from HySpecLab.eea import VCA
n_endmembers = dataset.n_endmembers
endmembers = dataset.endmembers()

# Seed search: look for a VCA random_state whose endmembers match the ground truth
# one-to-one under argmin-SAD (kept for reference):
# from HySpecLab.metrics import sad
# import numpy as np
# for i in range(256):
#     vca = VCA(n_endmembers=n_endmembers, snr_input=1, random_state=i)

#     E = vca.fit(dataset.X.numpy())
#     endmember_init = torch.from_numpy(vca.endmembers()).float()

#     sad_result = sad(endmember_init, endmembers)
#     idx = torch.argmin(sad_result, dim=1) # Index for reordering the ground truth
#     if np.unique(idx).shape[0] == n_endmembers:
#         print(i)
#         break

vca = VCA(n_endmembers=n_endmembers, snr_input=1, random_state=42)

E = vca.fit(dataset.X.numpy())
endmember_init = torch.from_numpy(vca.endmembers()).float()

# Optional (testing only): rescale each endmember so that its maximum value is 0.9.
# endmember_init = (endmember_init / endmember_init.max(dim=1, keepdim=True)[0]) * .9
# Logits for the model. torch.logit clamps to [eps, 1 - eps], so reflectances <= 0 or >= 1
# stay finite (log(p / (1 - p) + 1e-12) gave inf/NaN there).
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

# 'seaborn-colorblind' was deprecated in matplotlib 3.6 (renamed 'seaborn-v0_8-colorblind')
# and removed in 3.8. The first cell already uses the 'seaborn-v0_8' naming.
with plt.style.context('seaborn-v0_8-colorblind'):
    plt.plot(endmember_init.T)
plt.show()

In [None]:
from pysptools import eea
# NFINDR initialisation; this overrides the VCA initialisation from the previous cell.
n_endmembers = dataset.n_endmembers

ee = eea.NFINDR()
endmember = ee.extract(dataset.image().numpy(), n_endmembers)
endmember_init = torch.from_numpy(endmember).float()
# Logits for the model. torch.logit clamps to [eps, 1 - eps], so reflectances <= 0 or >= 1
# stay finite (log(p / (1 - p) + 1e-12) gave inf/NaN there).
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

fig = plot_endmembers(endmember_init)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
# Build the model from the current logit initialisation and fit it on Urban.
n_bands = dataset.n_bands
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init)
train(
    model=model,
    n_endmembers=n_endmembers,
    dataset=dataset,
    n_batchs=50,
    n_epochs=50,
    lr=1e-3,
    simplex_weight=1e-2,
)

In [None]:
# Loss terms for reporting. criterion, volume_reg and similarity_reg mirror those in
# train(); entropy_reg is only used for evaluation.
criterion = UnmixingLoss()
entropy_reg = NormalizedEntropy(S=n_endmembers)
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
similarity_reg = SimilarityLoss(
    n_endmembers,
    temperature=0.1,
    reduction='mean',
)

In [None]:
_X = dataset.X

model.eval()
model = model.cpu()
# The full-image reconstruction is only used for reporting, so run the forward pass
# under no_grad too; building an autograd graph for it only costs memory.
with torch.no_grad():
    reconstruc = model(_X)
    # Reconstruction loss, entropy of model.A, and volume / similarity of the learned endmembers.
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(), volume_reg(sigmoid(model.ebk)).cpu(),
          similarity_reg(sigmoid(model.ebk)).cpu())

In [None]:
# Similarity and volume regularizers on the initial endmembers, for comparison with the trained values printed above.
similarity_reg(endmember_init), volume_reg(endmember_init)

In [None]:
# Braces around the index so multi-digit indices render as E_{10} rather than E_1 followed by 0.
labels = [f'$E_{{{k}}}$' for k in range(1, n_endmembers+1)]
ebk = sigmoid(model.ebk).detach().cpu()
# 'seaborn-colorblind' was removed in matplotlib 3.8; 'seaborn-v0_8-colorblind' is its replacement.
with plt.style.context('seaborn-v0_8-colorblind'):
    fig = plt.figure(figsize=(7, 5))
    plt.plot(ebk.T, label=labels)
    plt.ylabel('Reflectance', fontsize='x-large')
    plt.xlabel('Bands', fontsize='x-large')
    # legend background white (as in the ground-truth plot)
    plt.legend(fontsize='x-large', facecolor='white')
    plt.xticks(fontsize='x-large')
    plt.yticks(fontsize='x-large')
    plt.tight_layout()

plt.show()

In [None]:
# ordering the endmembers
# For each learned endmember, pick the ground-truth endmember with the smallest SAD.
endmembers = dataset.endmembers().detach().cpu()
from HySpecLab.metrics import sad
sad_result = sad(ebk, endmembers)
print(sad_result)
idx = torch.argmin(sad_result, dim=1) # Index for reordering the ground truth
print(idx)

# NOTE(review): manual override of the automatic matching. It hard-codes that the last
# learned endmember corresponds to ground-truth endmember 1, presumably tuned for one
# particular training run. Re-check the printed `idx` after every re-run before trusting it.
# idx[1] = 1
idx[-1] = 1
# idx[-2] = 1

# reorder the endmembers
endmembers = endmembers[idx]
print(idx)


In [None]:
# Braces around the index so multi-digit indices render as E_{10} rather than E_1 followed by 0.
labels = [f'$E_{{{k}}}$' for k in range(1, len(dataset.endmembers())+1)]
# 'seaborn-colorblind' was removed in matplotlib 3.8; 'seaborn-v0_8-colorblind' is its replacement.
with plt.style.context('seaborn-v0_8-colorblind'):
    # Ground-truth endmembers in dataset order.
    fig = plt.figure(figsize=(7, 5))
    plt.plot(dataset.endmembers().T, label=labels)
    plt.ylabel('Reflectance', fontsize='x-large')
    plt.xlabel('Bands', fontsize='x-large')
    plt.legend(fontsize='x-large', facecolor='white')
    plt.xticks(fontsize='x-large')
    plt.yticks(fontsize='x-large')
    plt.tight_layout()
    plt.show()

In [None]:
from torch.nn.functional import softmax
# Estimated abundance maps: softmax of model.A over dim 1, reshaped to (n_row, n_col, n_endmembers).
test = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, n_endmembers)
# Braces around the index so multi-digit indices render as E_{10} rather than E_1 followed by 0.
labels = [f'$E_{{{k}}}$' for k in range(1, n_endmembers+1)]

fig = plt.figure(figsize=(12,7))
# 4 x 5 grid: room for up to 20 endmembers.
for i in range(n_endmembers):
    plt.subplot(4,5,i+1)
    plt.imshow(test[:,:,i].T, cmap='viridis')
    plt.xticks([])
    plt.yticks([])
    plt.title(labels[i], fontsize='x-large')
    plt.colorbar()

plt.tight_layout()
plt.show()

In [None]:
# Fresh copies for the metrics below. `endmembers` is rebound to the un-reordered ground truth.
endmembers = dataset.endmembers()
ebk = model.ebk.detach().sigmoid().cpu()
print(endmembers.shape)

from torch.nn.functional import mse_loss
def rmse(x: torch.Tensor, y: torch.Tensor):
    '''Row-wise root-mean-square error (same definition as in the Samson section): one value per row of x / y.'''
    per_element = mse_loss(x, y, reduction='none')
    return per_element.mean(dim=1).sqrt()

# Estimated abundances as (n_endmembers, n_row, n_col).
abundance = (
    softmax(model.A.detach(), dim=1)
    .cpu()
    .reshape(dataset.n_row, dataset.n_col, n_endmembers)
    .permute(2, 0, 1)
)
# Ground truth reordered with `idx` so that map k matches estimated endmember k.
abundance_gt = dataset.abundance()[:, :, idx].permute(2, 0, 1)
print(abundance_gt.shape)
endmember_gt = dataset.endmembers()[idx, :]

# One RMSE per endmember over all pixels.
rmse_result = rmse(abundance.flatten(start_dim=1), abundance_gt.flatten(start_dim=1))
print(rmse_result)

In [None]:
# Ground-truth abundance maps (reordered with `idx`), in the same layout as the estimated maps.
fig = plt.figure(figsize=(12, 7))
for i in range(n_endmembers):
    gt_map = abundance_gt[i, :, :].T
    plt.subplot(4, 5, i + 1)
    plt.imshow(gt_map, cmap='viridis')
    plt.title(labels[i], fontsize='x-large')
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()

plt.tight_layout()
plt.show()

# Cuprite

In [None]:
# Presumably the output directory for Cuprite figures.
# NOTE(review): cuprite_save_path is not used by any later cell; the abundance figure is saved as 'abundance.png' in the working directory instead. Confirm the intended location.
cuprite_save_path = os.path.join(config.IMG_PATH, 'Cuprite')

In [None]:
from dataset import Cuprite
dataset = Cuprite(config.Cuprite_PATH)

# Quick look at the scene: first spectral band only.
first_band = dataset.image()[:, :, 0]
plt.imshow(first_band, cmap='viridis')
plt.axis('off')
plt.show()

In [None]:
from HySpecLab.eea import VCA
# Fixed at 12 here instead of dataset.n_endmembers. The NFINDR cell below resets
# n_endmembers to dataset.n_endmembers.
n_endmembers = 12

vca = VCA(n_endmembers=n_endmembers, snr_input=1, random_state=1024)

E = vca.fit(dataset.X.numpy())
endmember_init = torch.from_numpy(vca.endmembers()).float()

# Optional (testing only): rescale each endmember so that its maximum value is 0.9.
# endmember_init = (endmember_init / endmember_init.max(dim=1, keepdim=True)[0]) * .9
# Logits for the model. torch.logit clamps to [eps, 1 - eps], so reflectances <= 0 or >= 1
# stay finite (log(p / (1 - p) + 1e-12) gave inf/NaN there).
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

from utils import plot_endmembers
import numpy as np
def plot_endmembers(E: np.ndarray, wv:np.ndarray = None, labels:list = None, figsize:tuple = (7,5), ticks_range:tuple=(0, 1), n_ticks:int=5):
    '''
        Plot endmembers.

        Local variant of ``utils.plot_endmembers``; it shadows the version imported just above.

        Parameters
        ----------
            E : 2-D array, shape (n_endmembers, n_bands)
                Endmembers.
            wv : 1-D array, optional, shape (n_bands)
                Wavelengths in nm, used as the x axis when given. Default is None (band index).
            labels : list, optional
                Labels for endmembers. Default is None (``$E_{1}$``, ``$E_{2}$``, ...).
            figsize : tuple, optional
                Figure size. Default is (7,5).
            ticks_range : tuple, optional
                Range of yticks. Default is (0, 1).
            n_ticks : int, optional
                Number of yticks. Default is 5.

        Returns
        -------
            fig : matplotlib.figure.Figure
    '''
    ticks_formatter = plt.FormatStrFormatter('%.2f')

    n_endmembers, n_bands = E.shape
    if labels is None:
        labels = [r'$E_{{{}}}$'.format(k) for k in range(1, n_endmembers+1)]

    # 'seaborn-colorblind' was deprecated in matplotlib 3.6 (renamed 'seaborn-v0_8-colorblind')
    # and removed in 3.8.
    with plt.style.context('seaborn-v0_8-colorblind'):
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ticks = np.linspace(*ticks_range, n_ticks)
        if wv is None:
            ax.plot(E.T, label=labels)
            ax.set_xlabel('Bands', fontsize='x-large')
            x_min, x_max = 0, n_bands
        else:
            ax.plot(wv, E.T, label=labels)
            ax.set_xlabel('Wavelength (nm)', fontsize='x-large')
            # The x axis is in wavelength units here, so the limits must come from wv,
            # not from band indices.
            x_min, x_max = float(np.min(wv)), float(np.max(wv))

        ax.set_ylabel('Reflectance', fontsize='x-large')
        ax.set_yticks(ticks)
        ax.yaxis.set_major_formatter(ticks_formatter) # set format in y ticks labels
        ax.set_ylim(ticks_range[0] - 0.025, ticks_range[1] + 0.025)
        ax.set_xlim(x_min - 1.5, x_max + 1.5)
        ax.tick_params(axis='both', labelsize='large')

        handles, legend_labels = ax.get_legend_handles_labels()
        fig.legend(handles, legend_labels, loc='upper center', ncol=6, fontsize='large', borderpad=-.25)
        # The legend has 6 entries per row; add padding for each extra row.
        fig.tight_layout(pad=(((n_endmembers-1)//6)+1)*2)

    return fig

# Scale the y axis to the largest initial reflectance. `.item()` passes a Python float
# rather than a 0-d tensor to the tick/limit computations.
fig = plot_endmembers(endmember_init[:n_endmembers], ticks_range=(0, endmember_init.max().item()))
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
from pysptools import eea
# NFINDR initialisation. It overrides the VCA initialisation above and resets n_endmembers
# from the hard-coded 12 to dataset.n_endmembers.
n_endmembers = dataset.n_endmembers

ee = eea.NFINDR()
endmember = ee.extract(dataset.image().numpy(), n_endmembers)
endmember_init = torch.from_numpy(endmember).float()
# Logits for the model. torch.logit clamps to [eps, 1 - eps], so reflectances <= 0 or >= 1
# stay finite (log(p / (1 - p) + 1e-12) gave inf/NaN there).
logit_endmember_init = torch.logit(endmember_init, eps=1e-6)

fig = plot_endmembers(endmember_init)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
# Loss terms for inspecting the initialisation. criterion, volume_reg and similarity_reg
# mirror those in train(); entropy_reg is only used for evaluation.
criterion = UnmixingLoss()
entropy_reg = NormalizedEntropy(S=n_endmembers)
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
similarity_reg = SimilarityLoss(
    n_endmembers,
    temperature=0.1,
    reduction='mean',
)

In [None]:
# Volume and similarity regularizers on the initial endmembers, before training.
volume_reg(endmember_init), similarity_reg(endmember_init)

In [None]:
# Build the model from the current logit initialisation and fit it on Cuprite
# (much smaller simplex weight than for Samson / Urban).
n_bands = dataset.n_bands
model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init)
train(
    model=model,
    n_endmembers=n_endmembers,
    dataset=dataset,
    n_batchs=50,
    n_epochs=50,
    lr=1e-3,
    simplex_weight=1e-5,
)

In [None]:
# Re-create the evaluation loss terms (same construction as the cell before training).
criterion = UnmixingLoss()
entropy_reg = NormalizedEntropy(S=n_endmembers)
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
similarity_reg = SimilarityLoss(
    n_endmembers,
    temperature=0.1,
    reduction='mean',
)

In [None]:
_X = dataset.X

model.eval()
model = model.cpu()
# The full-image reconstruction is only used for reporting, so run the forward pass
# under no_grad too; building an autograd graph for it only costs memory.
with torch.no_grad():
    reconstruc = model(_X)
    # Reconstruction loss, entropy of model.A, and volume / similarity of the learned endmembers.
    print(criterion(reconstruc, _X).cpu(), entropy_reg(model.A).cpu(), volume_reg(sigmoid(model.ebk)).cpu(),
          similarity_reg(sigmoid(model.ebk)).cpu())

In [None]:
# Same regularizers on the initial endmembers again, for comparison with the trained values printed above.
volume_reg(endmember_init), similarity_reg(endmember_init)

In [None]:
ebk = sigmoid(model.ebk).detach().cpu()

# Learned endmembers as spectra (sigmoid of the model's logits).
fig = plot_endmembers(ebk)
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

# Ground-truth endmembers, for comparison.
fig = plot_endmembers(dataset.endmembers())
plt.show()

In [None]:
from torch.nn.functional import softmax

# Estimated abundance maps as a NumPy array of shape (n_row, n_col, n_endmembers).
abundance_probs = softmax(model.A.detach(), dim=1)
test = abundance_probs.cpu().numpy().reshape(dataset.n_row, dataset.n_col, n_endmembers)

def show_abundance(A, labels:list = None, figsize:tuple=(7,5), ncols:int = 4):
    '''
        Plot one abundance map per endmember in a grid, each with a horizontal colorbar.

        Parameters
        ----------
            A : 3-D array, shape (n_rows, n_cols, n_endmembers)
                Abundance maps; map ``i`` is ``A[:, :, i]``, shown transposed.
            labels : list, optional
                Subplot titles. Default is None (``$E_{1}$``, ``$E_{2}$``, ...).
            figsize : tuple, optional
                Figure size. Default is (7,5). (This was previously ignored.)
            ncols : int, optional
                Number of grid columns. Default is 4. The grid keeps at least 3 rows
                (the previous fixed 3x4 layout) and adds rows for more endmembers
                instead of failing beyond 12.

        Returns
        -------
            fig : matplotlib.figure.Figure
    '''
    n_endmembers = A.shape[2]

    if labels is None:
        labels = [r'$E_{{{}}}$'.format(k) for k in range(1, n_endmembers+1)]

    # Ceiling division, with at least 3 rows so layouts with <= 12 endmembers are unchanged.
    nrows = max(3, -(-n_endmembers // ncols))

    ticks_formatter = plt.FormatStrFormatter('%.1f')
    fig = plt.figure(figsize=figsize)
    for i in range(n_endmembers):
        data = A[:,:,i].T
        plt.subplot(nrows, ncols, i+1)
        plt.imshow(data, cmap='viridis')
        plt.axis('off')
        plt.title(labels[i], fontsize='x-large')
        # Two colorbar ticks just inside the map's value range.
        plt.colorbar(format=ticks_formatter, ticks=[data.min() + 1e-3, data.max() - 1e-3],
                     orientation='horizontal', fraction=0.1, pad=0.01)

    plt.tight_layout()
    return fig

fig = show_abundance(test)
# NOTE(review): this writes to the current working directory; cuprite_save_path is defined
# earlier but unused. Confirm where the figure should go.
fig.savefig('abundance.png', dpi=300, bbox_inches='tight')
# pyplot.show() takes no figure argument; it renders every open figure, including `fig`.
plt.show()

In [None]:
# NOTE(review): this cell looks pasted from another notebook:
# - it appends '../..' (not '..' as in the first cell) to sys.path;
# - JasperRidgeDataset, sio and nn are not used by the remaining visible cell;
# - `from utils import show_abundance, plot_endmembers` and
#   `from HySpecLab.metrics import rmse, sad` shadow the plot_endmembers, show_abundance
#   and rmse defined earlier in this notebook.
# Confirm whether it belongs here.
import os, sys
project_root_dir = os.path.join(os.getcwd(),'../..')
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
from torch import nn
import torch
import config

from utils import show_abundance, plot_endmembers
from dataset import JasperRidgeDataset
from HySpecLab.metrics import rmse, sad
from scipy import io as sio

In [None]:
endmembers = dataset.endmembers()
ebk = sigmoid(model.ebk).detach().cpu()

# For each learned endmember, pick the ground-truth endmember with the smallest SAD.
from HySpecLab.metrics import sad
sad_result = sad(ebk, endmembers)
print(sad_result)
# Stay in torch, as the other alignment cells do. np.argmin on a torch tensor may return
# either a tensor or an ndarray, depending on how NumPy dispatches to Tensor.argmin.
idx = torch.argmin(sad_result, dim=1)
print(idx)

# Fewer unique entries than learned endmembers means the matching is not one-to-one.
print(torch.unique(idx))