In [None]:
import os, sys
# Add the directory two levels above the current working directory to
# sys.path. NOTE(review): presumably this is what makes the project-local
# imports below (config, utils, dataset) resolvable -- confirm.
project_root_dir = os.path.join(os.getcwd(),'../..')
# The membership check keeps a re-run of this cell from appending the same
# entry a second time.
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
from torch import nn
import torch
import config

from utils import show_abundance, plot_endmembers
# from dataset import JasperRidgeDataset
from HySpecLab.metrics import rmse, sad
from scipy import io as sio

In [None]:
from dataset import JasperRidge

# Build the dataset object from the location stored in config.JasperRidge_PATH.
dataset = JasperRidge(config.JasperRidge_PATH)
dataset_name = 'JasperRidge'
# dataset.wv converted to an unsigned-integer NumPy array; it is handed to
# plot_endmembers in the plotting cells (presumably band wavelengths -- confirm).
wv = np.array(dataset.wv, dtype=np.uint)

# Output directory referenced by the (currently commented-out) savefig calls.
result_path = os.path.join(config.RESULTS_PATH, 'jasperRidge')

# Ground Truth

In [None]:
from utils import plot_endmembers, show_abundance
# Plot the reference endmembers returned by dataset.endmembers() against wv.
fig = plot_endmembers(dataset.endmembers(), wv, ticks_range=(0, 1), n_ticks=5)
plt.show(fig)
# fig.savefig(os.path.join(result_path, 'imgs/M_ref.pdf'), bbox_inches='tight')

# Show the reference abundances returned by dataset.abundance().
fig = show_abundance(dataset.abundance())
plt.show(fig)
# fig.savefig(os.path.join(result_path, 'imgs/A_ref.png'), dpi=300, bbox_inches='tight')

In [None]:
from HySpecLab.metrics import sad

def sort_endmember(endmembers, gt):
    """Match estimated endmembers against a reference set using SAD.

    Args:
        endmembers: estimated endmember signatures, passed as the first
            argument to ``sad``.
        gt: reference (ground-truth) endmember signatures, passed as the
            second argument to ``sad``.

    Returns:
        A tuple ``(order, distances)`` where ``distances`` is the raw result
        of ``sad(endmembers, gt)`` and ``order`` holds, for every position
        along the remaining dimension(s), the index along dimension 0 at
        which ``distances`` is smallest. Callers in this notebook index the
        estimated endmembers with ``order``.

    NOTE(review): this assumes dimension 0 of the ``sad`` result enumerates
    the estimated endmembers -- confirm against HySpecLab.metrics.sad.
    Nothing here prevents the same index from being selected twice.
    """
    distances = sad(endmembers, gt)
    order = torch.argmin(distances, dim=0)
    return order, distances

In [None]:
from HySpecLab.eea import VCA

n_endmembers = dataset.n_endmembers
   
# Extract n_endmembers signatures with VCA; random_state is fixed at 25.
vca = VCA(n_endmembers, snr_input=-1, random_state=25)
vca.fit(dataset.X.numpy())
endmembers = torch.from_numpy(vca.endmembers()).float()
# Match the VCA output against the reference endmembers via SAD.
e_idx, sad_result = sort_endmember(endmembers, dataset.endmembers())

# Reorder the VCA estimates with the indices returned by sort_endmember.
vca_endmember_init = endmembers[e_idx]
# Element-wise logit, log(p / (1 - p)); the 1e-12 terms keep the numerator and
# denominator away from exactly zero when p is 0 or 1.
vca_logit_endmember_init = torch.log((vca_endmember_init + 1e-12) / ((1-vca_endmember_init) + 1e-12))

fig = plot_endmembers(vca_endmember_init, wv, ticks_range=(0, 1))
plt.show(fig)

# fig.savefig(os.path.join(result_path, 'imgs/M_vca.pdf'), bbox_inches='tight')

In [None]:
from utils import plot_endmembers
from pysptools import eea
n_endmembers = dataset.n_endmembers

# Extract n_endmembers signatures from the image with pysptools' N-FINDR.
ee = eea.NFINDR()
endmember = torch.from_numpy(ee.extract(dataset.image(), n_endmembers)).float()

# Match the N-FINDR output against the reference endmembers via SAD and
# reorder the estimates with the returned indices.
e_idx, _ = sort_endmember(endmember, dataset.endmembers())
nfindr_endmember_init = endmember[e_idx]
# Element-wise logit, log(p / (1 - p)), of the N-FINDR estimates. Both the
# numerator and the denominator must use nfindr_endmember_init; the previous
# version divided by (1 - vca_endmember_init), mixing in the VCA estimates.
# The 1e-12 terms keep both sides away from exactly zero when p is 0 or 1.
nfindr_logit_endmember_init = torch.log((nfindr_endmember_init + 1e-12) / ((1 - nfindr_endmember_init) + 1e-12))

fig = plot_endmembers(nfindr_endmember_init, wv, ticks_range=(0, 1))
plt.show(fig)
# fig.savefig(os.path.join(result_path, 'imgs/M_nfindr.pdf'), bbox_inches='tight')

In [None]:
# Overlay the reference endmembers (divided by their global maximum) with the
# N-FINDR and VCA estimates; ee_labels names the three groups in that order.
fig = plot_endmembers(dataset.endmembers() / dataset.endmembers().max(), wv, ticks_range=(0, 1), endmember_estimation=[nfindr_endmember_init, vca_endmember_init], ee_labels=['Ground Truth', 'N-FINDR', 'VCA'])
plt.show(fig)

In [None]:
# Choose which endmember estimate the following cells use as the fixed
# endmember matrix. The N-FINDR alternative is kept commented out; note that
# logit_endmember_init is only assigned in that commented alternative.
# endmember_init_method = 'nfindr'
# endmember_init = nfindr_endmember_init
# logit_endmember_init = nfindr_logit_endmember_init

endmember_init_method = 'vca'
endmember_init = vca_endmember_init

In [None]:
# Tensors fed to UnDIP: the target image, a noise input and the endmembers.
# The transposed image gets a leading batch dimension of size 1 and is cast
# to float32.
X_tensor = torch.tensor(dataset.image().T).unsqueeze(0).float()

from HySpecLab.unmixing import get_noise, NOISE_TYPE
# Uniform noise requested with the per-sample shape of the target image.
noisy_input = get_noise(X_tensor.shape[1:], batch_size=1, noise_type=NOISE_TYPE.uniform)
# Transposed initial endmembers with a leading batch dimension of size 1.
U_tensor = endmember_init.T.unsqueeze(0).float()

In [None]:
from HySpecLab.unmixing import UnDIP

n_bands = dataset.n_bands

# Hyper-parameters forwarded to the UnDIP constructor below.
# NOTE(review): their exact meaning is defined by HySpecLab.unmixing.UnDIP,
# which is not visible here -- confirm before tuning.
dims = [n_bands, 256, 256]
skip_connection = [4, 4]
out_channels = 60

# Batch normalisation enabled, dropout disabled, LeakyReLU activations.
model = UnDIP(n_endmembers, out_channels, dims, skip_connection, dropout=False, batch_norm=True, activation_func=nn.LeakyReLU())

In [None]:
# Quick visual check of one band of the target tensor.
# `img` is assigned here but not used by the rest of this cell.
img = dataset.image().T
plt.subplot(1, 2, 1)
# Index 25 along dimension 1 of the first (only) batch element.
plt.imshow(X_tensor[0,25])
# The second panel is created but left empty while the line below stays
# commented out.
plt.subplot(1, 2, 2)
# plt.imshow(X_tensor.T[:,:,0,25])
plt.show()

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from HySpecLab.utils import fig_to_image
from HySpecLab.unmixing.utils import restoration
from HySpecLab.metrics import UnmixingLoss

# Sizes read back from the tensors prepared earlier; n_bands, w and h are
# reused by the reconstruction cells further down.
batch_size = X_tensor.shape[0]
n_bands = X_tensor.shape[1]
w, h = noisy_input.shape[2:]

n_epoch = 3000
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# The scheduler is created, but its step() call inside the loop is commented
# out, so the learning rate stays at its initial value.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
# criterion = nn.MSELoss()
criterion = UnmixingLoss()

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Progress bar over epochs; "tls" is a placeholder until the first loss is
# written by set_postfix inside the loop.
epoch_iterator = tqdm(
        range(n_epoch),
        leave=True,
        unit="epoch",
        postfix={"tls": "%.4f" % 1},
    )

# tb_writer = SummaryWriter('logs/test')

# #Endmember signal image
# fig, ax = plt.subplots(1,1, figsize=(16,9))

# labels = list(map(lambda x: 'Endmember {}'.format(x), range(len(U))))
# ax.plot(U.T, label=labels)
# ax.set_ylabel('Reflectance')

# image = ToTensor()(fig_to_image(fig)).unsqueeze(0)
# tb_writer.add_image('Endmembers', image, dataformats='NCHW')


# Target Image
# show_band_idx = np.linspace(0, X_tensor.shape[1]-1, num=16, dtype=np.int64)
# for i in range(4):
#     target_imgs = torch.unsqueeze(X_tensor[i, show_band_idx], dim=1)
#     img_grid = make_grid(target_imgs)
    # tb_writer.add_image('Target/{}'.format(i), img_grid, 0)


# The noise input is moved to the device once; the same tensor is fed to the
# model on every epoch.
noisy_input = noisy_input.to(device)

for epoch in epoch_iterator:
    # The model output is used as the abundance estimate.
    abundance = model(noisy_input)

    # Rebuild the image from the fixed endmembers (U_tensor is not passed to
    # the optimizer) and the predicted abundances.
    output = restoration(U_tensor.to(device), abundance)

    # if epoch % 100 == 0: # Every 100 epochs
    #     for i in range(4):
    #         rest_imgs = torch.unsqueeze(output[i, show_band_idx], dim=1)
    #         img_grid = make_grid(rest_imgs)
    #         tb_writer.add_image('Output/{}'.format(i), img_grid, epoch)

    #         abundance_imgs = torch.unsqueeze(abundance[i], dim=1)
    #         img_grid = make_grid(abundance_imgs)
    #         tb_writer.add_image('Abundance/{}'.format(i), img_grid, epoch)

    # Loss between the reconstruction and the target image.
    batch_loss = criterion(output, X_tensor.float().to(device))

    # .item() already yields a single Python number, so np.mean is applied to
    # one value here.
    epoch_iterator.set_postfix(tls="%.4f" % np.mean(batch_loss.detach().item()))
    # tb_writer.add_scalar('Loss', batch_loss.detach().item(), epoch)

    # Standard update: clear old gradients, back-propagate, apply the step.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    # scheduler.step()

In [None]:
from torchvision.transforms import ToPILImage
from matplotlib import pyplot as plt

# Run the model once more on the noise input and bring the result to the CPU.
# NOTE(review): no model.eval() or torch.no_grad() is used in this notebook;
# with batch_norm=True this forward pass presumably runs in training mode --
# confirm that is intended.
A = model(noisy_input).cpu().detach()
fig, ax = plt.subplots(1,4, figsize=(6,4))
# One panel per channel of the first batch element, capped at the four axes
# created above.
for i in range(len(A[0])):
    if i >= 4:
        break
    
    cax = ax[i].imshow(A[0,i].T, cmap='viridis')
    fig.colorbar(cax, ax=ax[i])
    ax[i].grid(False)
    # for j in range(len(A[0])):
    #     ax[i, j].imshow(A[i,j])
    
plt.show()

In [None]:
# Model output for the first (only) batch element, on the CPU.
A = (model(noisy_input).cpu().detach()[0])
# Reconstruction as a matrix product of the transposed initial endmembers and
# A with every dimension after the first flattened into one.
Y = torch.matmul(endmember_init.T, A.flatten(1)).numpy()
# Show the first channel of A.
plt.imshow(A[0], cmap='viridis')
plt.show()

In [None]:
# Reshape the flat reconstruction to (n_bands, w, h) using Fortran
# (column-major) element order; n_bands, w and h come from the training cell.
# NOTE(review): the 'F' ordering has to agree with how A.flatten(1) laid out
# the pixels -- confirm the displayed band looks spatially correct.
_Y = Y.reshape(n_bands, w, h, order='F')
plt.imshow(_Y[25], cmap='viridis')
plt.show()

In [None]:
n_bands, n_row, n_col = dataset.n_bands, dataset.n_row, dataset.n_col
# Reference image built as dataset.A @ dataset.endmembers(), reshaped to
# (n_row, n_col, n_bands) and then fully transposed, which reverses the axis
# order to (n_bands, n_col, n_row).
Y_true =  (dataset.A @ dataset.endmembers()).numpy().reshape(n_row, n_col, n_bands).T
plt.imshow(Y_true[25], cmap='viridis')
plt.show()

In [None]:
# RMSE between the model reconstruction and the reference reconstruction.
# NOTE(review): dim=None presumably reduces over all elements -- confirm
# against HySpecLab.metrics.rmse.
rmse_y = rmse(torch.tensor(_Y), torch.tensor(Y_true), dim=None)
# Bare expression so the notebook displays the value.
rmse_y