In [None]:
import os, sys
# Put the directory two levels above the current working directory on the
# import path (presumably so that `config`, `utils` and `dataset`, imported
# below, can be resolved when the notebook is run from its own folder).
project_root_dir = os.path.join(os.getcwd(), '../..')
if sys.path.count(project_root_dir) == 0:
    sys.path.append(project_root_dir)

from matplotlib import pyplot as plt
import numpy as np
from torch import nn
import torch
import config

from utils import show_abundance, plot_endmembers
from dataset import JasperRidge
from HySpecLab.metrics import rmse

In [None]:
# Load the scene from the location configured in `config`.
dataset = JasperRidge(config.JasperRidge_PATH)
dataset_name = 'Jasper'
# Band wavelengths as unsigned integers; handed to `plot_endmembers` below.
wv = np.array(dataset.wv, dtype=np.uint)

# Every figure and metrics table produced by this notebook is written here.
result_path = os.path.join(config.RESULTS_PATH, 'jasperRidge/undip')
# Figures go to `<result_path>/imgs` and CSV files to `<result_path>`; create
# both up front so `savefig` / `to_csv` do not fail on a fresh checkout.
os.makedirs(os.path.join(result_path, 'imgs'), exist_ok=True)

In [None]:
from utils import plot_endmembers, show_abundance
# Reference endmember signatures of the dataset, plotted against wavelength.
fig = plot_endmembers(dataset.endmembers(), wv, ticks_range=(0, 1), n_ticks=5)
# `plt.show` takes no figure argument (its only parameter is the keyword-only
# `block`); passing `fig` positionally fails outside the inline backend.
plt.show()

# Reference abundance maps of the dataset.
fig = show_abundance(dataset.abundance())
plt.show()

In [None]:
from HySpecLab.metrics import sad

def sort_endmember(endmembers, gt):
    """Pair extracted endmembers with reference signatures by spectral angle.

    Args:
        endmembers: Candidate endmember signatures; first argument to `sad`.
        gt: Reference (ground-truth) signatures; second argument to `sad`.

    Returns:
        tuple: `(e_idx, sad_result)`, where `sad_result` is whatever
        `sad(endmembers, gt)` returns and `e_idx` is its argmin taken over the
        first dimension.

    Note:
        Callers in this notebook use `e_idx` to index the *extracted*
        endmembers (`endmembers[e_idx]`). Every entry is an independent
        argmin, so the same index can be selected more than once.
    """
    # NOTE(review): presumably rows are candidates and columns are references,
    # so the argmin picks the closest candidate per reference -- confirm
    # against the `sad` implementation.
    pairwise_sad = sad(endmembers, gt)
    closest = pairwise_sad.argmin(dim=0)
    return closest, pairwise_sad

In [None]:
from HySpecLab.eea import VCA

n_endmembers = dataset.n_endmembers

# Extract `n_endmembers` signatures with VCA; `random_state` is fixed so the
# extraction is repeatable across runs.
vca = VCA(n_endmembers, snr_input=-1, random_state=25)
vca.fit(dataset.X.numpy())
endmembers = torch.from_numpy(vca.endmembers()).float()
# Reorder the extracted signatures according to their spectral-angle match
# with the reference endmembers.
e_idx, sad_result = sort_endmember(endmembers, dataset.endmembers())

vca_endmember_init = endmembers[e_idx]
# Logit (inverse sigmoid) of the signatures; the 1e-12 terms guard against
# division by zero and log(0) at the ends of the [0, 1] range.
vca_logit_endmember_init = torch.log((vca_endmember_init + 1e-12) / ((1-vca_endmember_init) + 1e-12))

fig = plot_endmembers(vca_endmember_init, wv, ticks_range=(0, 1))
# `plt.show` takes no figure argument (its only parameter is the keyword-only
# `block`); passing `fig` positionally fails outside the inline backend.
plt.show()

In [None]:
from utils import plot_endmembers
from pysptools import eea
n_endmembers = dataset.n_endmembers

# Extract `n_endmembers` signatures from the image cube with N-FINDR.
ee = eea.NFINDR()
endmember = torch.from_numpy(ee.extract(dataset.image(), n_endmembers)).float()

# Reorder the extracted signatures according to their spectral-angle match
# with the reference endmembers.
e_idx, _ = sort_endmember(endmember, dataset.endmembers())
nfindr_endmember_init = endmember[e_idx]
# Logit (inverse sigmoid) of the N-FINDR signatures. Numerator and denominator
# must both use the N-FINDR tensor; the denominator previously used the VCA
# signatures, which produced a meaningless mix of the two extractions.
nfindr_logit_endmember_init = torch.log((nfindr_endmember_init + 1e-12) / ((1-nfindr_endmember_init) + 1e-12))

fig = plot_endmembers(nfindr_endmember_init, wv, ticks_range=(0, 1))
# `plt.show` takes no figure argument (its only parameter is the keyword-only
# `block`); passing `fig` positionally fails outside the inline backend.
plt.show()

In [None]:
# Overlay the max-normalised reference endmembers with both initialisations.
fig = plot_endmembers(dataset.endmembers() / dataset.endmembers().max(), wv, ticks_range=(0, 1), endmember_estimation=[nfindr_endmember_init, vca_endmember_init], ee_labels=['Ground Truth', 'N-FINDR', 'VCA'])
# Save before showing: on non-inline backends `plt.show()` may release the
# figure, after which `savefig` can write an empty page.
fig.savefig(os.path.join(result_path, 'imgs/M_init.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# Choose which extraction seeds the unmixing network: 'nfindr' or 'vca'.
# The method name is also embedded in the output file names further down.
endmember_init_method = 'nfindr'
endmember_init = {
    'nfindr': nfindr_endmember_init,
    'vca': vca_endmember_init,
}[endmember_init_method]

In [None]:
from torchvision.transforms import Resize

# Move the last image axis to the front (presumably rows, cols, bands ->
# bands, rows, cols) and prepend a batch axis of size one.
Y = np.transpose(dataset.image(), (2, 0, 1))
X_tensor = torch.from_numpy(Y).float()[None]

from HySpecLab.unmixing import get_noise, NOISE_TYPE
# Fixed uniform-noise input for the network; requested with the per-sample
# shape and batch size of the hypercube tensor.
noisy_input = get_noise(
    X_tensor.shape[1:],
    batch_size=X_tensor.shape[0],
    noise_type=NOISE_TYPE.uniform,
)
# Transposed endmember matrix with a leading batch axis.
U_tensor = endmember_init.T.unsqueeze(0).float()

print('Z shape: {}'.format(noisy_input.shape))
print('HyperCube shape: {}'.format(X_tensor.shape))
print('Endmember shape: {}'.format(U_tensor.shape))

In [None]:
from HySpecLab.unmixing import UnDIP

# Sizes come from the tensors prepared earlier: second axis of the hypercube
# tensor and last axis of the endmember tensor.
n_bands, n_endmembers = X_tensor.shape[1], U_tensor.shape[-1]

# Architecture settings for this single run.
# NOTE(review): the meaning of `dims` / `skip_connection` entries is defined
# by `UnDIP`; check its documentation before changing them.
out_channels = 64
skip_connection = [6, 6]
dims = [n_bands, 256, 256]

model = UnDIP(
    n_endmembers,
    out_channels,
    dims,
    skip_connection,
    dropout=True,
    batch_norm=True,
    activation_func=nn.LeakyReLU(),
)

In [None]:
from tqdm import tqdm
from HySpecLab.unmixing.utils import restoration
from HySpecLab.metrics import UnmixingLoss

# Train on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

n_epoch = 5000
# Adam with a small weight decay; the learning rate is multiplied by 0.999
# after every epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
criterion = UnmixingLoss()

model.train()
model = model.to(device)

epoch_iterator = tqdm(
        range(n_epoch),
        leave=True,
        unit="epoch",
        postfix={"tls": "%.4f" % 1},
    )

noisy_input = noisy_input.to(device)
# The endmembers and the target cube never change during training, so they
# are transferred to the device once instead of on every epoch. Separate
# names are used so that `U_tensor` / `X_tensor` stay on the CPU for the
# evaluation cells further down.
endmembers_on_device = U_tensor.to(device)
target_on_device = X_tensor.float().to(device)

for epoch in epoch_iterator:
    # Predict abundances from the fixed noise input and rebuild the scene
    # from them together with the fixed endmembers.
    abundance = model(noisy_input)

    output = restoration(endmembers_on_device, abundance)

    batch_loss = criterion(output, target_on_device)

    # `.item()` already yields a plain Python number; no averaging is needed.
    epoch_iterator.set_postfix(tls="%.4f" % batch_loss.item())
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    scheduler.step()

In [None]:
# Evaluation mode (affects layers such as dropout / batch norm), then a
# gradient-free forward pass on the same fixed noise input.
model.eval()
with torch.no_grad():
    abundance = model(noisy_input)
# Keep a detached CPU copy for plotting and metrics.
abundance = abundance.detach().cpu()

In [None]:
# First batch element, rearranged so the leading (channel) axis becomes last.
_abundance = abundance[0].detach().cpu().numpy().transpose(1, 2, 0)
fig = show_abundance(_abundance, transpose=False)
fig.savefig(os.path.join(result_path, f'imgs/A_estimation_{endmember_init_method}.pdf'), bbox_inches='tight')
# `plt.show` takes no figure argument (its only parameter is the keyword-only
# `block`); passing `fig` positionally fails outside the inline backend.
plt.show()

# Metrics

In [None]:
from torchvision.transforms import Resize, InterpolationMode

# Reference abundances with the last axis moved to the front, matching the
# channel-first layout of the network output.
A = torch.tensor(dataset.abundance().numpy().transpose(2, 0, 1)).cpu()
# Estimated abundances: first (only) element of the batch.
A_hat = abundance[0].cpu()

In [None]:
# Estimated endmembers are the fixed initialisation (the loss is only ever
# optimised over the model parameters); the reference comes from the dataset.
M_hat, M = endmember_init, dataset.endmembers()

In [None]:
# Reference reconstruction: product of `dataset.A` and the reference
# endmembers (presumably flattened abundances times signatures -- confirm
# against the dataset class).
X_true = dataset.A @ dataset.endmembers()
# Reshape to the per-sample layout of the hypercube tensor instead of the
# hard-coded (198, 100, 100), so the cell follows whatever scene was loaded.
X_true = X_true.T.reshape(X_tensor.shape[1:])

# Reconstruction from the fixed endmembers and the estimated abundances;
# keep the first (only) batch element.
X_hat = restoration(U_tensor, abundance)[0]

In [None]:
import pandas as pd
# One-row summary of the single training run above.
df = pd.DataFrame(columns=['Method', 'RMSE_X', 'RMSE_A', 'SAD_M'])
# This notebook trains UnDIP and writes under `jasperRidge/undip`; the row was
# previously labelled 'CLHU', which looks like a copy-paste leftover.
# NOTE(review): confirm no downstream script filters on the old label.
df['Method'] = ['UnDIP']
df['RMSE_X'] = [rmse(X_true, X_hat, dim=None).numpy()]
df['RMSE_A'] = [rmse(A, A_hat, dim=None).numpy()]

# Mean spectral angle between matching estimated / reference endmembers
# (diagonal of the pairwise result).
sad_result = sad(M_hat, dataset.endmembers()).numpy()
df['SAD_M'] = np.diagonal(sad_result).mean()

df.to_csv(os.path.join(result_path, f'metrics_{endmember_init_method}.csv'), index=False)
df

## Batch test

In [None]:
def test(model, dataset, noisy_input, U):
    """Evaluate a trained model against the dataset's reference data.

    Args:
        model: Trained network; switched to eval mode and moved to the CPU.
        dataset: Dataset object providing `abundance()`, `endmembers()` and `A`.
        noisy_input: Fixed network input used during training.
        U: Endmember tensor with a leading batch axis, as passed to
            `restoration`; `U[0].T` is taken as the estimated endmembers.

    Returns:
        tuple: `(rmse_x, rmse_a, sad_m)` -- RMSE of the reconstruction, RMSE
        of the abundances, and the mean of the diagonal of the pairwise
        spectral angle between reference and estimated endmembers.
    """
    model.eval()
    model = model.cpu()
    with torch.no_grad():
        abundance = model(noisy_input.cpu()).detach().cpu()

    # Move the last axis of the reference abundances to the front so they line
    # up with the channel-first network output. `as_tensor` + `permute` works
    # whether `abundance()` returns a tensor or an array, and avoids calling
    # `np.transpose` / `torch.tensor` on an existing tensor.
    A = torch.as_tensor(dataset.abundance()).permute(2, 0, 1).contiguous().cpu()
    A_hat = abundance[0].cpu()

    M_hat = U[0].T
    M = dataset.endmembers()

    X_true = dataset.A @ M
    # Take the spatial size from the predicted abundance maps rather than the
    # hard-coded (198, 100, 100), so other scene sizes work unchanged.
    X_true = X_true.T.reshape(-1, *A_hat.shape[-2:])

    X_hat = restoration(U, abundance)[0]

    rmse_x = rmse(X_true, X_hat, dim=None).numpy()
    rmse_a = rmse(A, A_hat, dim=None).numpy()
    sad_m = np.diagonal(sad(M, M_hat).numpy()).mean()
    return rmse_x.item(), rmse_a.item(), sad_m

In [None]:
from tqdm import tqdm
from HySpecLab.unmixing.utils import restoration
from HySpecLab.metrics import UnmixingLoss

def train(model, X_tensor, U_tensor, noisy_input, n_epoch=5000, lr=1e-3):
    """Fit `model` so that its predicted abundances reconstruct `X_tensor`.

    On every epoch the model is fed `noisy_input`, the scene is rebuilt with
    `restoration(U_tensor, abundance)`, and `UnmixingLoss` between that
    reconstruction and `X_tensor` is minimised with Adam. The learning rate is
    multiplied by 0.999 after each epoch.

    Args:
        model: Network to optimise. Moved to the GPU when one is available.
        X_tensor: Target hypercube tensor.
        U_tensor: Endmember tensor passed to `restoration`; it is not part of
            the optimised parameters.
        noisy_input: Fixed network input.
        n_epoch: Number of optimisation steps. Defaults to 5000.
        lr: Initial Adam learning rate. Defaults to 1e-3.

    Returns:
        The trained model, moved back to the CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
    criterion = UnmixingLoss()

    model.train()
    model = model.to(device)

    # None of these change during training, so transfer them once up front
    # instead of on every epoch. `Tensor.to` returns a new tensor when the
    # device differs, so the caller's tensors are left where they are (which
    # is also why no "move back to CPU" step is needed for `noisy_input`).
    noisy_input = noisy_input.to(device)
    endmembers = U_tensor.to(device)
    target = X_tensor.float().to(device)

    epoch_iterator = tqdm(
            range(n_epoch),
            leave=True,
            unit="epoch",
            postfix={"tls": "%.4f" % 1},
        )

    for _ in epoch_iterator:
        abundance = model(noisy_input)
        output = restoration(endmembers, abundance)
        batch_loss = criterion(output, target)

        # `.item()` already yields a plain Python number.
        epoch_iterator.set_postfix(tls="%.4f" % batch_loss.item())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    return model.cpu()

In [None]:
from HySpecLab.unmixing import UnDIP

# Sizes come from the tensors prepared earlier in the notebook.
n_bands, n_endmembers = X_tensor.shape[1], U_tensor.shape[-1]

# Architecture settings for the repeated runs.
# NOTE(review): these differ from the single run above ([6, 6] skip
# connections and 64 output channels) -- confirm that this is intended.
out_channels = 256
skip_connection = [4, 4]
dims = [n_bands, 256, 256]

# One entry per run for each metric.
batch_rmse_x, batch_rmse_a, batch_sad_m = [], [], []
for run_idx in range(10):
    # Fresh, randomly initialised network for every repetition.
    model = UnDIP(
        n_endmembers,
        out_channels,
        dims,
        skip_connection,
        dropout=True,
        batch_norm=True,
        activation_func=nn.LeakyReLU(),
    )
    model = train(model, X_tensor, U_tensor, noisy_input)

    rmse_x, rmse_a, sad_m = test(model, dataset, noisy_input, U_tensor)
    for history, value in zip((batch_rmse_x, batch_rmse_a, batch_sad_m), (rmse_x, rmse_a, sad_m)):
        history.append(value)

    print(rmse_x, rmse_a, sad_m)

In [None]:
import pandas as pd
# One row per repeated run, one column per metric.
df = pd.DataFrame({
    'RMSE_X': batch_rmse_x,
    'RMSE_A': batch_rmse_a,
    'SAD_M': batch_sad_m,
})

# Persist the per-run values, then report mean and standard deviation of each
# metric in the order RMSE_X, RMSE_A, SAD_M.
df.to_csv(os.path.join(result_path, 'metrics_{}_batch.csv'.format(endmember_init_method)), index=False)
summary = [
    stat
    for column in ('RMSE_X', 'RMSE_A', 'SAD_M')
    for stat in (df[column].mean(), df[column].std())
]
print(*summary)