In [None]:
import os, sys
# Make the repository root (two directories above this notebook) importable; presumably that is
# where the local `config`, `dataset`, `utils` and `HySpecLab` modules live. The path is normalised
# so the membership check also matches an equivalent entry already on sys.path.
project_root_dir = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
import torch
import config

In [None]:
# Directory that receives every figure, weight file and metrics table produced for Apex.
result_path = os.path.join(config.RESULTS_PATH, 'apex')

# Load the Apex hyperspectral scene from the configured data directory.
from dataset import Apex
dataset = Apex(config.Apex_PATH)

# Ground Truth

In [None]:
from utils import plot_endmembers, show_abundance

# Reference endmember spectra against wavelength.
fig = plot_endmembers(dataset.endmembers(), np.array(dataset.wv), ticks_range=(0, .5), n_ticks=5)
# pyplot.show() takes no figure argument (only `block`); it displays every open figure.
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/M_ref.pdf'), bbox_inches='tight')

# Reference abundance maps.
fig = show_abundance(dataset.abundance())
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/A_ref.png'), dpi=300, bbox_inches='tight')

In [None]:
# Sanity check: shapes of the reference endmember matrix and of the wavelength vector.
(dataset.endmembers().shape, dataset.wv.shape)

In [None]:
from HySpecLab.metrics import sad
from scipy.optimize import linear_sum_assignment

def sort_endmember(endmembers, gt):
    """Match estimated endmembers to reference endmembers by spectral angle.

    Parameters
    ----------
    endmembers : torch.Tensor
        Estimated endmember spectra, one per row.
    gt : torch.Tensor
        Reference endmember spectra, one per row.

    Returns
    -------
    e_idx : torch.LongTensor
        ``e_idx[j]`` is the row of ``endmembers`` assigned to reference ``j``, so
        ``endmembers[e_idx]`` is aligned row-by-row with ``gt``.
    sad_result : torch.Tensor
        The full ``sad(endmembers, gt)`` matrix (presumably estimates on rows and references
        on columns, as the reordering below assumes).
    """
    sad_result = sad(endmembers, gt)
    n_est, n_gt = sad_result.shape
    if n_est < n_gt:
        # A one-to-one matching is impossible; fall back to the closest estimate per reference.
        return torch.argmin(sad_result, dim=0), sad_result

    # A per-column argmin can pick the same estimate for two references (leaving another estimate
    # unused); the Hungarian assignment guarantees a one-to-one matching with minimal total SAD.
    cost = sad_result.detach().cpu().numpy()
    cost = np.nan_to_num(cost, nan=1e6, posinf=1e6)  # undefined angles count as the worst match
    # Rows of cost.T are references, returned in order 0..n_gt-1; columns are the matched estimates.
    _, est_idx = linear_sum_assignment(cost.T)
    return torch.as_tensor(est_idx, dtype=torch.long), sad_result

In [None]:
from HySpecLab.eea import VCA

n_endmembers = dataset.n_endmembers

# Vertex Component Analysis estimate (fixed random_state for reproducibility).
vca = VCA(n_endmembers, snr_input=20, random_state=42)
vca.fit(dataset.X.numpy())
endmembers = torch.from_numpy(vca.endmembers()).float()
e_idx, sad_result = sort_endmember(endmembers, dataset.endmembers())

# Reorder so that row j corresponds to reference endmember j.
vca_endmember_init = endmembers[e_idx]
# Inverse-sigmoid initialisation for the model's endmember logits. torch.logit clamps the input to
# [eps, 1 - eps], so entries at or outside [0, 1] give finite logits; the previous
# log(x / (1 - x)) formula produced NaN for them.
vca_logit_endmember_init = torch.logit(vca_endmember_init, eps=1e-6)

fig = plot_endmembers(vca_endmember_init, dataset.wv, ticks_range=(0, 1))
plt.show()

# fig.savefig(os.path.join(result_path, 'imgs/M_vca.pdf'), bbox_inches='tight')

In [None]:
from utils import plot_endmembers
from pysptools import eea
n_endmembers = dataset.n_endmembers

# N-FINDR estimate.
# NOTE(review): pysptools' NFINDR may start from a random simplex unless ATGP_init=True;
# confirm, and seed it if run-to-run reproducibility matters.
nfindr = eea.NFINDR()
nfindr_endmembers = torch.from_numpy(nfindr.extract(dataset.image(), n_endmembers)).float()

# Reorder so that row j corresponds to reference endmember j.
e_idx, _ = sort_endmember(nfindr_endmembers, dataset.endmembers())
nfindr_endmember_init = nfindr_endmembers[e_idx]
# Inverse-sigmoid initialisation; torch.logit clamps to [eps, 1 - eps] so out-of-range values
# yield finite logits instead of NaN.
nfindr_logit_endmember_init = torch.logit(nfindr_endmember_init, eps=1e-6)

fig = plot_endmembers(nfindr_endmember_init, dataset.wv, ticks_range=(0, 1))
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/M_nfindr.pdf'), bbox_inches='tight')

In [None]:
# Reference spectra (divided by their global maximum to share the [0, 1] axis with the estimates)
# overlaid with the N-FINDR and VCA initialisations.
fig = plot_endmembers(
    dataset.endmembers() / dataset.endmembers().max(),
    dataset.wv,
    ticks_range=(0, 1),
    endmember_estimation=[nfindr_endmember_init, vca_endmember_init],
    ee_labels=['Ground Truth', 'N-FINDR', 'VCA'],
)
plt.show()
# fig.savefig(os.path.join(result_path, 'imgs/M_estimation.pdf'), bbox_inches='tight')

In [None]:
# Which extraction seeds the model: 'nfindr' or 'vca' (anything other than 'vca' selects N-FINDR).
endmember_init_method = 'nfindr'
if endmember_init_method == 'vca':
    endmember_init, logit_endmember_init = vca_endmember_init, vca_logit_endmember_init
else:
    endmember_init, logit_endmember_init = nfindr_endmember_init, nfindr_logit_endmember_init

# Training

In [None]:
from HySpecLab.unmixing import ContrastiveUnmixing
from utils import train

# Contrastive unmixing model whose endmember logits start from the selected initialisation.
n_bands = dataset.n_bands
model = ContrastiveUnmixing(
    n_bands,
    n_endmembers,
    endmember_init=logit_endmember_init,
    sigma_sparsity=.5,
)

# One forward pass so the sparse gate has state to report before training
# (presumably regularize() reads values produced by forward — TODO confirm).
_ = model(dataset.X)
print(model.sparse_gate.regularize())

# An earlier configuration used sparse_weight=.1.
train(
    model, n_endmembers, dataset,
    n_batchs=100, n_epochs=100, lr=1e-3,
    similarity_weight=1, sparse_weight=1e-3,
)

In [None]:
# model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init, sigma_sparsity=.1)
# Mean sparse-gate activation over the whole scene, evaluated on whichever device the model is on
# rather than assuming CUDA is available.
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    z = model.encoder(dataset.X.to(model_device))
    gate_mean = model.sparse_gate(z).mean()
gate_mean

In [None]:
# Inspect the sparse gate after a forward pass over the whole scene, on the model's own device.
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    _ = model(dataset.X.to(model_device))
    # Computed once so the printed mean/min describe exactly the printed tensor.
    gate_param = model.sparse_gate.variational_parameter().flatten()
    print(gate_param)
    print(gate_param.mean())
    print(gate_param.min())
    print(model.sparse_gate.regularize())

# Save Model

In [None]:
# Persist the trained weights, creating the output directory on first use.
weights_path = os.path.join(result_path, 'clhu/weights/clhu.pth')
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)

# Testing model

In [None]:
# Evaluation below uses the model trained above. To evaluate the saved checkpoint instead,
# rebuild and reload it first:
#     from HySpecLab.unmixing import ContrastiveUnmixing
#     model = ContrastiveUnmixing(dataset.n_bands, dataset.n_endmembers)
#     model.load_state_dict(torch.load(os.path.join(result_path, 'clhu/weights/clhu.pth')))
model = model.eval()  # nn.Module.eval() returns the module itself

In [None]:
from HySpecLab.metrics import UnmixingLoss, NormalizedEntropy
from HySpecLab.metrics.regularization import SimplexVolumeLoss, SimilarityLoss

# Reconstruction loss plus the three regularisers, so their values can be reported for the
# trained model and for the initial endmembers.
criterion = UnmixingLoss()
similarity_reg = SimilarityLoss(n_endmembers, temperature=.1, reduction='mean')
volume_reg = SimplexVolumeLoss(dataset[:], n_endmembers)
entropy_reg = NormalizedEntropy(S=n_endmembers)

model.eval()

In [None]:
from torch import sigmoid

# Reconstruction loss and regularisation terms of the trained model, evaluated on CPU.
_X = dataset.X
model.eval()
model = model.cpu()
with torch.no_grad():
    # The forward pass runs under no_grad too (no autograd graph is needed for reporting);
    # presumably it also refreshes model.A, which the entropy term reads.
    reconstruction = model(_X)
    print(criterion(reconstruction, _X).cpu(),
          entropy_reg(model.A).cpu(),
          volume_reg(sigmoid(model.ebk)).cpu(),
          similarity_reg(model.ebk).cpu())  # similarity on raw logits; sigmoid(model.ebk) is the alternative
    

In [None]:
# The same volume and similarity regularisers evaluated on the initial endmembers, for comparison
# with the trained model's values.
(volume_reg(endmember_init), similarity_reg(logit_endmember_init))

In [None]:
# Trained CLHU endmembers (sigmoid of the learned logits) overlaid on their initialisation.
ebk = torch.sigmoid(model.ebk).detach().cpu()
label = 'VCA' if endmember_init_method == 'vca' else 'N-FINDR'
# Pass the wavelength axis like every other plot_endmembers call in this notebook.
fig = plot_endmembers(ebk, dataset.wv, ticks_range=(0, 1), endmember_estimation=[endmember_init], ee_labels=['CLHU', label])
plt.show()

fig_path = os.path.join(result_path, f'clhu/imgs/M_clhu_{endmember_init_method}.pdf')
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path, bbox_inches='tight')

In [None]:
from torch.nn.functional import softmax

# CLHU abundance maps: softmax over the endmember axis, reshaped to (rows, cols, endmembers).
# Kept under the name `test` because a later cell reads it.
test = softmax(model.A.detach(), dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1)

fig = show_abundance(test)
plt.show()

fig_path = os.path.join(result_path, f'clhu/imgs/A_clhu_{endmember_init_method}.pdf')
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path, bbox_inches='tight')

# # imshow bigger test[:,:,0]
# test1 = test[:,:,1]
# # test1[test1>.3] = 1
# fig = plt.figure(figsize=(10,10))
# plt.imshow(test1, cmap='viridis')
# plt.axis('off')
# plt.show(fig)

In [None]:
# Log of the sparse gate's variational parameter, one value per pixel. Stored under its own names
# so the abundance cube held in `test` (read again further down) is not clobbered.
gate_map = model.sparse_gate.variational_parameter().detach().cpu().numpy().reshape(dataset.n_row, dataset.n_col)
log_gate_map = np.log(gate_map)
fig, ax = plt.subplots()
# NOTE(review): transposed (.T) unlike the abundance maps — confirm the intended orientation.
im = ax.imshow(log_gate_map.T, cmap='jet')
fig.colorbar(im, ax=ax)
ax.set_title('log variational parameter of the sparse gate')
plt.show()

In [None]:
from HySpecLab.metrics import rmse, sad
import pandas as pd

# Reference data reconstructed from the ground-truth abundances and endmembers.
X_true = dataset.A @ dataset.endmembers()
# X_true = dataset.X
with torch.no_grad():
    X_hat = model(dataset.X).detach().cpu()
    A_hat = torch.softmax(model.A.detach().cpu(), dim=1)
    M_hat = torch.sigmoid(model.ebk.detach().cpu())
    # Pass the estimated endmembers back through the model and score the resulting spectra.
    # NOTE(review): this forward presumably overwrites model.A with the abundances of M_hat —
    # re-run model(dataset.X) before any cell that reads model.A again.
    _M_hat = model(M_hat).detach().cpu()

# The diagonal assumes estimated endmember i is already aligned with reference i.
# sad_result = sad(M_hat, dataset.endmembers()).numpy()
sad_result = sad(_M_hat, dataset.endmembers()).numpy()

df = pd.DataFrame([{
    'Method': 'CLHU',
    'RMSE_X': rmse(X_true, X_hat, dim=None).item(),
    'RMSE_A': rmse(dataset.A, A_hat, dim=None).item(),
    'SAD_M': np.diagonal(sad_result).mean(),
}])

# df.to_csv(os.path.join(result_path, 'clhu/metrics.csv'), index=False)
df

In [None]:
# Abundance map of the first endmember. Recomputed from a fresh forward pass instead of reading the
# global `test` (another cell rebinds it to a 2-D array) or a model.A that the metrics cell may
# have overwritten.
model_device = next(model.parameters()).device
with torch.no_grad():
    model(dataset.X.to(model_device))
    A_clhu = torch.softmax(model.A, dim=1).cpu().numpy().reshape(dataset.n_row, dataset.n_col, -1)
A_first = A_clhu[:, :, 0]

fig, ax = plt.subplots()
ax.imshow(A_first, cmap='viridis')
ax.axis('off')
plt.show()



In [None]:
from HySpecLab.metrics import sad

N_TOP_PIXELS = 50      # pixels with the highest reference abundance taken per material
BAND_FOR_DISPLAY = 10  # band shown as the grayscale background


def _top_abundance_pixels(abundance_map, k):
    """Return (row_indices, col_indices) of the k largest values of a 2-D map."""
    flat_order = np.argsort(abundance_map.ravel())[-k:]
    return np.unravel_index(flat_order, abundance_map.shape)


def _show_pixels(rows, cols, title):
    """Overlay pixel coordinates in red on one band of the scene."""
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(dataset.image()[:, :, BAND_FOR_DISPLAY], cmap='gray')
    ax.scatter(cols, rows, s=10, c='r')
    ax.set_title(title)
    ax.axis('off')
    plt.show()


# NOTE(review): assumes the first reference abundance is road and the last is water — confirm
# against the dataset's endmember order.
A_road = dataset.abundance()[:, :, 0]
A_water = dataset.abundance()[:, :, -1]
idx = _top_abundance_pixels(A_road, N_TOP_PIXELS)
idx_water = _top_abundance_pixels(A_water, N_TOP_PIXELS)
_show_pixels(*idx, f'top {N_TOP_PIXELS} road pixels')
_show_pixels(*idx_water, f'top {N_TOP_PIXELS} water pixels')

# Mean spectrum of each selected pixel set.
signal_road = dataset.image()[idx[0], idx[1], :]
signal_water = dataset.image()[idx_water[0], idx_water[1], :]
print(signal_road.shape, signal_water.shape)

fig, ax = plt.subplots()
ax.plot(signal_road.mean(axis=0), label='road')
ax.plot(signal_water.mean(axis=0), label='water')
ax.set(xlabel='band index', ylabel='mean value')
ax.legend()
plt.show()

# Spectral angle between the two mean spectra.
sad(torch.tensor(signal_road.mean(axis=0)).reshape(1, -1), torch.tensor(signal_water.mean(axis=0)).reshape(1, -1))


In [None]:
# Spectrum of the top-left pixel (0, 0) against wavelength.
# NOTE(review): originally captioned as the water signal — confirm that pixel (0, 0) is water.
pixel_spectrum = dataset.image()[0, 0, :].T
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(dataset.wv, pixel_spectrum)
ax.set_xlabel('Wavelength')
ax.set_ylabel('Reflectance')
plt.show()



In [None]:
# Reload the scene from the configured data path (not a user-specific absolute path) and show
# every reference abundance map in a two-column grid.
dataset = Apex(config.Apex_PATH)
A = dataset.abundance()
n_maps = A.shape[-1]
n_cols = 2
n_rows = -(-n_maps // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_maps:
        ax.imshow(A[:, :, i], cmap='viridis')
    ax.axis('off')
plt.show()



# ...

In [None]:
from HySpecLab.metrics import rmse, sad
from torch import sigmoid


def test(model, dataset):
    """Evaluate a trained unmixing model against the dataset's ground truth.

    The model is switched to eval mode and moved to CPU (``nn.Module.cpu`` acts in place).

    Returns
    -------
    tuple
        ``(rmse_x, rmse_a, sad_m, sad_m_2)``:
        RMSE of the reconstruction against ``dataset.A @ dataset.endmembers()``;
        RMSE of the softmax abundances against ``dataset.A``;
        mean diagonal SAD of the sigmoid endmembers against the references; and the same
        for those endmembers after a pass back through the model.

    NOTE: this function name shadows the ``test`` array created by earlier cells.
    """
    model.eval()
    model = model.cpu()

    X_true = dataset.A @ dataset.endmembers()
    with torch.no_grad():
        X_hat = model(dataset.X)
        # Read the abundances before the next forward; model(M_hat) presumably overwrites model.A.
        A_hat = torch.softmax(model.A, dim=1)
        M_hat = sigmoid(model.ebk)
        _M_hat = model(M_hat).cpu()

    rmse_x = rmse(X_true, X_hat, dim=None).numpy()
    rmse_a = rmse(dataset.A, A_hat, dim=None).numpy()
    # The diagonal assumes estimate i is aligned with reference i (the model is initialised from
    # endmembers reordered to the reference order).
    sad_m = np.diagonal(sad(M_hat, dataset.endmembers()).numpy()).mean()
    sad_m_2 = np.diagonal(sad(_M_hat, dataset.endmembers()).numpy()).mean()
    return rmse_x.item(), rmse_a.item(), sad_m, sad_m_2

In [None]:
from HySpecLab.unmixing import ContrastiveUnmixing

N_RUNS = 10  # independent trainings summarised below
n_bands = dataset.n_bands

batch_rmse_x = []
batch_rmse_a = []
batch_sad_m = []
for run in range(N_RUNS):
    # Seed each run so the whole batch is reproducible while the runs still differ from each other.
    torch.manual_seed(run)
    np.random.seed(run)
    model = ContrastiveUnmixing(n_bands, n_endmembers, endmember_init=logit_endmember_init, sigma_sparsity=.5)
    train(model, n_endmembers, dataset, n_batchs=100, n_epochs=100, lr=1e-3, similarity_weight=1, sparse_weight=1e-3)

    rmse_x, rmse_a, sad_m, sad_m_reconstructed = test(model, dataset)
    batch_rmse_x.append(rmse_x)
    batch_rmse_a.append(rmse_a)
    batch_sad_m.append(sad_m)

    print(rmse_x, rmse_a, sad_m, sad_m_reconstructed)

In [None]:
# generate dataframe
import pandas as pd
df = pd.DataFrame(columns=['RMSE_X', 'RMSE_A', 'SAD_M'])
df['RMSE_X'] = batch_rmse_x
df['RMSE_A'] = batch_rmse_a
df['SAD_M'] = batch_sad_m

# extract mean and std
df['RMSE_X'].mean(), df['RMSE_X'].std(), df['RMSE_A'].mean(), df['RMSE_A'].std(), df['SAD_M'].mean(), df['SAD_M'].std()

In [None]:
# Write the per-run metrics, creating the output directory on first use.
metrics_path = os.path.join(result_path, 'clhu/metrics_{}_batch.csv'.format(endmember_init_method))
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
df.to_csv(metrics_path, index=False)

In [None]:
# Location of the per-run metrics CSV.
os.path.join(result_path, f'clhu/metrics_{endmember_init_method}_batch.csv')

In [None]:
# Mean diagonal SAD against the references: initialisation first, then the trained endmembers.
for estimate in (endmember_init, torch.sigmoid(model.ebk.detach()).cpu()):
    print(np.diagonal(sad(estimate, dataset.endmembers()).numpy()).mean())

In [None]:
# Full SAD matrices against the references: trained endmembers first, then the initialisation.
for estimate in (torch.sigmoid(model.ebk.detach()).cpu(), endmember_init):
    print(sad(estimate, dataset.endmembers()))




In [None]:
# Trained CLHU endmembers overlaid on the initialisation ($M_0$), against wavelength.
fig = plot_endmembers(torch.sigmoid(model.ebk.detach()).cpu(), dataset.wv, ticks_range=(0, 1), endmember_estimation=[endmember_init], ee_labels=['CLHU', '$M_0$'])
# pyplot.show() takes no figure argument; it displays every open figure.
plt.show()

In [None]:
# Sanity check: SAD of the initial endmembers against themselves; presumably the diagonal should be ~0.
sad(endmember_init, endmember_init)

In [None]:
# Sanity check: SAD of the trained endmembers against themselves; presumably the diagonal should be ~0.
sad(sigmoid(model.ebk.detach()).cpu(), sigmoid(model.ebk.detach()).cpu())