The purpose of this notebook is to refactor pre-existing code to make it generalizable to any dataset.

*(c) Amir Reza Vazifeh and Jatearoon (Keene) Boondicharern*

In [None]:
import sys
# Record the interpreter version this notebook ran with.
print(f"{sys.version}")

In [None]:
import matplotlib.pyplot as plt
from scipy.io import loadmat
import numpy as np
import earthpy.plot as ep
import cv2
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from collections import OrderedDict
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from scipy.signal import wiener
import math
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import pickle

# Load Data

In [None]:
def load_simple_data(parent, image=("PaviaU",)):
    """
    Load a hyperspectral cube and its ground-truth map from MATLAB ``.mat`` files.

    For each dataset name ``p`` in ``image`` the files ``<parent>/<p>/<p>.mat`` and
    ``<parent>/<p>/<p>_gt.mat`` are read. Inside each file, the array is looked up under
    ``p`` with its first letter lower-cased (e.g. ``"PaviaU"`` -> ``"paviaU"`` and
    ``"paviaU_gt"``). Salinas data is noted as currently misformatted.

    Parameters:
    - parent: Directory containing one sub-directory per dataset.
    - image: A single dataset name, or an iterable of names (default: ``("PaviaU",)``).

    Returns:
    - (x, y): data cube and ground truth of the LAST dataset in ``image``.

    Raises:
    - ValueError: if ``image`` names no dataset.
    """
    # A bare string is one dataset name; do not iterate over its characters.
    names = [image] if isinstance(image, str) else list(image)
    if not names:
        raise ValueError("image must name at least one dataset")

    x = y = None
    for name in names:
        key = name[0].lower() + name[1:]
        base = os.path.join(parent, name, name)
        x = loadmat(base + ".mat")[key]
        y = loadmat(base + "_gt.mat")[key + "_gt"]
    return x, y

In [None]:
data, gt = load_simple_data("/home/bigkeenus/Desktop/JASON/Datasets/")
# Report the spatial size (every axis but the last) and the band count (last axis).
print("Data Shape: {}\nNumber of Bands: {}".format(data.shape[:-1], data.shape[-1]))
plt.show()
plt.imshow(data[..., 2])

In [None]:
def normalize(image, a=2, b=-1, h = 1/1024):
    """
    Min-max normalize ``image``, then apply the affine map ``a * X + b``.

    The image is first mapped to ``X = (image - min) / (max - min + h)``. For
    ``h > 0`` this lies in ``[0, 1)``: the small ``h`` keeps the division finite
    for a constant image, at the cost of the maximum landing slightly below 1.
    With the defaults ``a=2, b=-1`` the result therefore lies in ``[-1, 1)``.
    Pass ``a=1, b=0`` for a plain ``[0, 1)`` normalization.

    Parameters:
    - image: array supporting ``min``/``max`` and arithmetic (e.g. a NumPy array).
      The min/max are taken over ALL elements, not per band.
    - a, b: scale and offset applied after the min-max step.
    - h: non-negative stabilizer added to the denominator.

    Returns:
    - The rescaled image, with the same shape as ``image``.
    """
    lo = image.min()
    hi = image.max()
    unit = (image - lo) / (hi - lo + h)
    return a * unit + b

In [None]:
def create_hyperspectral_video(image, video_name, colormap='jet', fps=10):
    """
    Create a video from a hyperspectral image with a fixed colorbar.

    Each band becomes one frame, drawn with the same color limits (the global
    min/max of ``image``) so intensities are comparable across frames. Frames
    are written as PNGs to a private temporary directory. That directory is
    removed afterwards, even if rendering or encoding fails.

    Parameters:
    - image: Hyperspectral image (Height x Width x Bands).
    - video_name: Name of the output video file (encoded with the 'mp4v' FourCC).
    - colormap: Colormap to apply (default: 'jet').
    - fps: Frames per second for the video (default: 10).

    Raises:
    - ValueError: if ``image`` has no bands.
    """
    import tempfile

    _, _, bands = image.shape
    if bands == 0:
        raise ValueError("image must contain at least one band")

    # Computed once; the same limits are reused for every frame.
    vmin, vmax = np.min(image), np.max(image)

    with tempfile.TemporaryDirectory(prefix="frames_temp_") as temp_dir:
        frame_paths = []
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            fig.subplots_adjust(left=0.05, right=1.00, top=0.95, bottom=0.05)
            for band in range(bands):
                ax.clear()
                im = ax.imshow(image[:, :, band], cmap=colormap, vmin=vmin, vmax=vmax)
                ax.set_title(f"Band {band + 1}", fontsize=14)
                ax.axis('off')

                # The colorbar sits on its own axes (not cleared by ax.clear()),
                # so adding it on the first frame keeps it for all later frames.
                if band == 0:
                    colorbar = fig.colorbar(im)
                    colorbar.set_label('Intensity', fontsize=12)

                frame_path = os.path.join(temp_dir, f"frame_{band:03d}.png")
                fig.savefig(frame_path, dpi=200, bbox_inches='tight')
                frame_paths.append(frame_path)
        finally:
            plt.close(fig)

        first_frame = cv2.imread(frame_paths[0])
        video_height, video_width = first_frame.shape[:2]
        video_writer = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (video_width, video_height))
        try:
            for frame_path in frame_paths:
                frame = cv2.imread(frame_path)
                # bbox_inches='tight' can change the PNG size slightly between frames.
                # VideoWriter may silently drop frames that do not match its size.
                if frame.shape[:2] != (video_height, video_width):
                    frame = cv2.resize(frame, (video_width, video_height))
                video_writer.write(frame)
        finally:
            video_writer.release()

# Implicit Neural Network

In [None]:
def make_grid(w, h, dim=2, low=-1, hi=1):
    """
    Create a flattened 2-D coordinate grid with values in ``[low, hi]``.

    The first coordinate takes ``w`` evenly spaced values and the second ``h``
    values. The grid uses matrix ("ij") indexing, so row ``i * h + j`` of the
    result is ``(x[i], y[j])``. ``ImageFittingWithMask`` passes the image's
    row count as ``w`` and its column count as ``h``.

    Parameters:
    - w, h: number of samples along the first and second axis.
    - dim: number of coordinates per point; only 2 is supported.
    - low, hi: inclusive range of each coordinate.

    Returns:
    - Tensor of shape ``(w * h, 2)``.

    Raises:
    - ValueError: if ``dim`` is not 2.
    """
    if dim != 2:
        # Only two axes are generated. Any other dim would make reshape either
        # fail or silently regroup the coordinates into meaningless rows.
        raise ValueError(f"make_grid only supports dim=2, got {dim}")

    x_coords = torch.linspace(low, hi, steps=w)
    y_coords = torch.linspace(low, hi, steps=h)

    # Explicit "ij" matches torch.meshgrid's legacy default and silences its deprecation warning.
    grid = torch.meshgrid(x_coords, y_coords, indexing="ij")

    return torch.stack(grid, dim=-1).reshape(-1, dim)

In [None]:
# The SIREN and positional-encoding MLP models now live in separate modules (`siren` and `PosEnc`)
from siren import SineLayer, Siren
from PosEnc import PosEncMLP

In [None]:
# Combined Dataset class for both grayscale and RGB images
class ImageFittingWithMask(Dataset):
    """
    Single-item dataset of the visible (unmasked) pixel coordinates and values of an image.

    Works for any channel count (grayscale, RGB, or a single hyperspectral band).

    Parameters:
    - img_tensor: tensor of shape (C, H, W).
    - mask_ratio: fraction of pixels to HIDE when no ``mask`` is given. A random
      ``1 - mask_ratio`` fraction of the pixels is kept visible.
    - mask: optional 2-D array whose non-zero entries mark VISIBLE pixels. It is
      resized to (H, W) with nearest-neighbour interpolation if its size differs.

    Attributes:
    - img_shape: (H, W).
    - pixels: (H*W, C) pixel values in raster order.
    - coords: (H*W, 2) grid coordinates from ``make_grid``.
    - mask: (H*W,) boolean, True for visible pixels.
    - visible_coords / visible_pixels: ``coords`` / ``pixels`` restricted to ``mask``.
    """

    def __init__(self, img_tensor, mask_ratio=0.6, mask=None):
        super().__init__()

        self.img_shape = img_tensor.shape[1:]
        num_channels = img_tensor.shape[0]
        # (C, H, W) -> (H*W, C): one row of channel values per pixel.
        self.pixels = img_tensor.permute(1, 2, 0).reshape(-1, num_channels)
        self.coords = make_grid(self.img_shape[0], self.img_shape[1], 2)

        if mask is not None:
            self.mask = self._prepare_mask(mask)
        else:
            num_pixels = self.pixels.shape[0]
            keep = torch.randperm(num_pixels)[:int(num_pixels * (1 - mask_ratio))]
            self.mask = torch.zeros(num_pixels, dtype=torch.bool)
            self.mask[keep] = True

        # Only the visible pixels are exposed for training.
        self.visible_coords = self.coords[self.mask]
        self.visible_pixels = self.pixels[self.mask]

    def _prepare_mask(self, mask):
        """Return ``mask`` as a flat boolean tensor matching ``self.img_shape``."""
        rows, cols = int(self.img_shape[0]), int(self.img_shape[1])
        mask = np.asarray(mask)
        if mask.dtype == bool:
            # cv2.resize does not accept boolean arrays.
            mask = mask.astype(np.uint8)
        if mask.shape[:2] != (rows, cols):
            # Nearest-neighbour keeps the mask binary. The default bilinear
            # interpolation would blend values at visible/hidden boundaries.
            mask = cv2.resize(mask, (cols, rows), interpolation=cv2.INTER_NEAREST)
        return torch.as_tensor(mask != 0).reshape(-1)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError(f"ImageFittingWithMask holds a single item; got index {idx}")
        return self.visible_coords, self.visible_pixels


# Make a Mask

In [None]:
def generate_hyperspectral_masks(height=610, width=340, bands=103, num_masks=10, mask_size_ratio=0.1):
    """
    Assign random sets of pixel coordinates to hyperspectral bands.

    Parameters:
    - height, width: spatial size of the image.
    - bands: total band count; only used when ``num_masks`` is an integer.
    - num_masks: either an integer (that many distinct bands are drawn from
      ``range(bands)`` after seeding NumPy's global RNG with 42) or an
      array-like of band indices to use directly.
    - mask_size_ratio: fraction of the ``height * width`` pixels in each mask.

    Returns:
    - dict mapping band index -> list of ``[x, y]`` coordinate arrays
      (``x`` is the column, ``y`` the row).

    When ``mask_size_ratio`` equals ``1 / num_masks``, the shuffled pixels are
    split into ``num_masks`` disjoint groups that together cover the image.
    Otherwise each band independently gets ``int(mask_size_ratio * height * width)``
    distinct pixels, so masks of different bands may overlap.

    Raises:
    - ValueError: if no masks are requested.
    """
    if isinstance(num_masks, (int, np.integer)):
        np.random.seed(42)  # For reproducibility
        selected_bands = np.random.choice(bands, num_masks, replace=False)
    else:
        selected_bands = np.asarray(num_masks)
        num_masks = selected_bands.shape[0]  # channel dimension
    if num_masks <= 0:
        raise ValueError("at least one mask is required")

    # Every (x, y) pixel coordinate, in random order.
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    coords = np.stack((cols, rows), axis=-1).reshape(-1, 2)
    np.random.shuffle(coords)

    masks = {}
    # isclose: an exact float == would miss ratios such as 1/3 computed another way.
    if math.isclose(1 / num_masks, mask_size_ratio):
        groupings = np.array_split(coords, num_masks)
        for band, group in zip(selected_bands, groupings):
            masks[band] = list(group)
    else:
        print(f"WARNING: number of masks {num_masks} is not the same as the mask size ratio {mask_size_ratio}")
        pixels_per_band = min(int(mask_size_ratio * height * width), coords.shape[0])
        for band in selected_bands:
            # replace=False: a mask of N pixels must contain N distinct pixels.
            choice = np.random.choice(coords.shape[0], pixels_per_band, replace=False)
            masks[band] = list(coords[choice])
    return masks

# Generate masks
# Assign 10 randomly chosen bands one mask each; with ratio 0.1 == 1/10 the masks split the image evenly.
h, w, c = np.shape(data)
hyperspectral_masks = generate_hyperspectral_masks(height=h, width=w, bands=c, num_masks=10, mask_size_ratio=0.1)

# Preview every mask: its first five (x, y) coordinates and its total pixel count.
print(f"There are {len(hyperspectral_masks)} masks:")
for i, mask in enumerate(hyperspectral_masks.values(), start=1):
    print(f"Mask {i}: {mask[:5]} ... {len(mask)} pixels")

In [None]:
def _mask_from_coords(coords, height, width):
    """Rasterize an iterable of (x, y) pixel coordinates into a (height, width) uint8 0/1 image."""
    mask_img = np.zeros((height, width), dtype=np.uint8)
    points = np.asarray(coords, dtype=int).reshape(-1, 2)
    mask_img[points[:, 1], points[:, 0]] = 1  # x is the column, y the row
    return mask_img


def _plot_masks(mask_images, per_row=5):
    """Show mask images in a grid with ``per_row`` panels per row; axes are hidden on every panel."""
    num_rows = (len(mask_images) + per_row - 1) // per_row
    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_rows, per_row, figsize=(35, 10 * num_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < len(mask_images):
            ax.imshow(mask_images[i], cmap='gray')
            ax.set_title(f'Mask {i+1}', fontsize=24)
        ax.axis('off')
    plt.show()


def visualize_masks(masks, height=610, width=340, vis = False):
    """
    Convert coordinate masks into binary images, optionally plotting them.

    Parameters:
    - masks: dict mapping band -> iterable of (x, y) pixel coordinates, as
      produced by ``generate_hyperspectral_masks`` (x = column, y = row).
    - height, width: size of each output mask image.
    - vis: if True, draw the masks in a grid with 5 panels per row and show it.
      No figure is created when False.

    Returns:
    - list of (height, width) uint8 arrays (1 = pixel in the mask), in the
      iteration order of ``masks``.
    """
    all_masks = [_mask_from_coords(coords, height, width) for coords in masks.values()]
    if vis and all_masks:
        _plot_masks(all_masks)
    return all_masks

# Visualize masks
all_masks = visualize_masks(masks=hyperspectral_masks, height=h, width=w, vis=False)

In [None]:
# Sanity check: count how many masks cover each pixel. For a disjoint split covering
# the whole image, every pixel is covered exactly once, so (sum - number of pixels) is 0
# and max == min == 1. The sum uses int64, and the subtraction uses Python ints, because
# uint8 masks would overflow past 255 masks and unsigned arithmetic wraps instead of going negative.
running_matrix_sum = np.sum(np.stack(all_masks), axis=0, dtype=np.int64)
coverage_gap = int(running_matrix_sum.sum()) - running_matrix_sum.size
print(f"sum of elements - matrix size (should be 0): {coverage_gap}, max: {running_matrix_sum.max()} min: {running_matrix_sum.min()}")
# plt.imshow(running_matrix_sum)

# Generalized Training Methods

This section defines generalized data-loading and training routines intended to work with any hyperspectral dataset.

In [None]:
from skimage.metrics import structural_similarity as ssim1

def mseloss(X, Y):
    """Mean squared error between ``X`` and ``Y`` (works for NumPy arrays and torch tensors)."""
    return ((X - Y) ** 2).mean()


def psnr(X, Y):
    """
    Peak signal-to-noise ratio of ``X`` against reference ``Y``, in dB.

    The peak value is ``max(Y)``. For NumPy inputs that are exactly equal, the
    MSE is 0 and the result is ``inf`` (with a divide-by-zero warning).
    """
    return 10 * np.log10(np.max(Y) ** 2 / mseloss(X, Y))


def ssim(X, Y, multichannel=True, data_range=1.0):
    """
    Structural similarity between images ``X`` and ``Y``.

    Parameters:
    - X, Y: images of the same shape. With ``multichannel=True`` they are
      (H, W, C) and the result is the mean of the per-channel SSIMs. Otherwise
      they are 2-D images, or 3-D with channels on the last axis.
    - data_range: dynamic range of the inputs. The metrics in this notebook are
      computed on images clipped to [0, 1], hence the default of 1.0.
      NOTE: earlier versions used 2 in the multichannel branch, which inflates
      SSIM. Pass ``data_range=2`` to reproduce those numbers.
    """
    if multichannel:
        return np.mean([ssim1(Y[:, :, c], X[:, :, c], data_range=data_range) for c in range(X.shape[-1])])
    # ``multichannel=`` is deprecated/removed in scikit-image; ``channel_axis`` replaces it.
    extra = {"channel_axis": -1} if np.ndim(X) == 3 else {}
    return ssim1(X, Y, data_range=data_range, **extra)

In [None]:
def compute_metrics(X, Y):
    """Return the ``(MSE, PSNR, SSIM)`` triple comparing ``X`` against reference ``Y``."""
    mse_value = mseloss(X, Y)
    psnr_value = psnr(X, Y)
    ssim_value = ssim(X, Y)
    return mse_value, psnr_value, ssim_value
    
def _build_inr_model(siren_model, out_features):
    """Construct the coordinate network selected by ``siren_model`` (None, "RELUPE" or "RELU")."""
    if siren_model is None:
        return Siren(
            in_features=2,
            out_features=out_features,
            hidden_features=256,
            hidden_layers=3,
            outermost_linear=True,
            activation="SINE")
    if siren_model in ("RELUPE", "RELU"):
        # The two PosEncMLP variants differ only in num_encoding_freqs (15 vs 0).
        return PosEncMLP(
            in_features=2,
            out_features=out_features,
            hidden_features=256,
            hidden_layers=3,
            num_encoding_freqs=15 if siren_model == "RELUPE" else 0,
            include_input=True,
            outermost_linear=True,
        )
    raise ValueError(f"Unknown siren_model {siren_model!r}; expected None, 'RELUPE' or 'RELU'")


def fit_image(hsi, masks, train_steps=2501, siren_model=None, lr = 1e-4):
    """
    Inpaint each band of an image by training one coordinate network per band.

    For every band, a fresh model is fit to the pixels its mask marks as
    visible. The hidden pixels are then filled with the model's predictions.

    Parameters:
    - hsi: normalized image of shape (C, H, W), bands first (as produced by
      ``prepare_img``).
    - masks: array of shape (C, H, W); non-zero entries mark each band's visible pixels.
    - train_steps: Adam steps per band.
    - siren_model: ``None`` for a Siren with sine activation, ``"RELUPE"`` or
      ``"RELU"`` for a PosEncMLP with 15 or 0 encoding frequencies.
    - lr: Adam learning rate.

    Returns:
    - (reconstructed, gt): lists with one (H, W) array per band, both mapped
      with ``(x + 1) / 2`` and clipped to [0, 1] (i.e. inputs assumed in [-1, 1]).

    Raises:
    - ValueError: for an unknown ``siren_model``.
    """
    # Fall back to CPU so the notebook also runs without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rows, cols = hsi.shape[1], hsi.shape[-1]

    reconstructed, gt = [], []
    for idx, band in enumerate(hsi):
        channel_tensor = torch.from_numpy(band).float().unsqueeze(0)  # (1, H, W)

        target_image = ImageFittingWithMask(
            img_tensor=channel_tensor,
            mask_ratio=0.9,  # ignored because an explicit mask is provided
            mask=masks[idx, :, :])
        dataloader = DataLoader(target_image, batch_size=1, pin_memory=device.type == "cuda", num_workers=0)

        siren = _build_inr_model(siren_model, channel_tensor.size()[0]).to(device)
        optim = torch.optim.Adam(lr=lr, params=siren.parameters())

        original_pixels = target_image.pixels.view(rows, cols).numpy()
        hidden = ~target_image.mask.view(rows, cols).numpy()

        model_input, ground_truth = next(iter(dataloader))
        model_input, ground_truth = model_input.to(device), ground_truth.to(device)
        pbar = tqdm(range(train_steps))
        for step in pbar:
            output, _ = siren(model_input)
            loss = mseloss(output, ground_truth)
            pbar.set_description(
                f"Step {step+1}/{train_steps}. Total loss = {loss}")
            optim.zero_grad()
            loss.backward()
            optim.step()

        with torch.no_grad():
            full_output, _ = siren(target_image.coords.to(device))
            full_output = full_output.cpu().view(rows, cols).numpy()

        # Keep observed pixels as-is; replace only the hidden ones with predictions.
        inpainted_image = original_pixels.copy()
        inpainted_image[hidden] = full_output[hidden]

        reconstructed.append(np.clip((inpainted_image + 1) / 2, 0, 1))
        gt.append(np.clip((original_pixels + 1) / 2, 0, 1))
    return reconstructed, gt


In [None]:
def prepare_img(img, num_channels=3, mask_p=1/3, create_mask=True, band_selection=None, visualize=True):
    """
    Select bands from an (H, W, C) image, normalize them, and optionally build masks.

    Parameters:
    - img: image of shape (H, W, C).
    - num_channels: number of bands drawn at random (without replacement) when
      ``band_selection`` is None.
    - mask_p: fraction of pixels per mask, forwarded to
      ``generate_hyperspectral_masks`` as ``mask_size_ratio``.
    - create_mask: whether to generate masks at all.
    - band_selection: explicit band indices to use instead of a random draw.
    - visualize: whether to plot the generated masks (default True, as before).

    Returns:
    - img_selected: the selected bands, shape (K, H, W), min-max normalized
      jointly by ``normalize`` with its default range.
    - masks: uint8 array (one (H, W) mask per distinct selected band, 1 = pixel
      in the mask), or None when ``create_mask`` is False.
    """
    image_transposed = np.transpose(img, (2, 0, 1))  # (C, H, W)
    h, w = img.shape[0], img.shape[1]

    if band_selection is None:
        selected = np.random.choice(image_transposed.shape[0], num_channels, replace=False)
    else:
        selected = np.array(band_selection)

    all_masks_np_array = None
    if create_mask:
        hsmask = generate_hyperspectral_masks(height=h, width=w, num_masks=selected, mask_size_ratio=mask_p)
        all_masks = visualize_masks(hsmask, height=h, width=w, vis=visualize)
        # Stack the per-band (H, W) masks along a new leading axis.
        all_masks_np_array = np.array(all_masks)

    img_selected = normalize(image_transposed[selected, :, :])
    return img_selected, all_masks_np_array

def generate_hsi_masks(datacube, number_masks=10, randomized=False):
    """
    Pick ``number_masks`` bands from an (H, W, C) datacube and build one mask per band.

    Parameters:
    - datacube: array of shape (H, W, C).
    - number_masks: number of bands (and masks) to select.
    - randomized: if True, draw the bands at random without replacement.
      Otherwise take ``number_masks`` evenly spaced bands from first to last.

    Returns:
    - band_selection: indices of the selected bands in ``datacube``.
    - hyperspectral_masks: (number_masks, H, W) uint8 mask array from ``prepare_img``.
    - data_selection: ``datacube[:, :, band_selection]``, shape (H, W, number_masks).

    Raises:
    - ValueError: if ``number_masks`` is not between 1 and C.
    """
    num_bands = datacube.shape[-1]
    if not 1 <= number_masks <= num_bands:
        # Evenly spaced integer indices would silently repeat bands if more were requested than exist.
        raise ValueError(f"number_masks must be between 1 and {num_bands}, got {number_masks}")

    if randomized:
        band_selection = np.random.choice(num_bands, number_masks, replace=False)
    else:
        band_selection = np.linspace(0, num_bands - 1, number_masks, dtype=int)
    data_selection = datacube[:, :, band_selection]

    _, hyperspectral_masks = prepare_img(data_selection, num_channels=number_masks, mask_p=1. / number_masks)
    print(f"selected bands: {band_selection} final data shape: {hyperspectral_masks.shape}")
    return band_selection, hyperspectral_masks, data_selection

def compute_dataset_statistics(datacube, dataset_name, number_masks=10, train_steps=2501, siren_model=None, lr=1e-4):
    """
    Run the inpainting experiment on growing subsets of bands and collect metrics.

    ``number_masks`` evenly spaced bands are taken from ``datacube``. For each
    ``i`` from 3 to ``number_masks``, the first ``i`` of those bands get fresh
    masks (each covering ``1 / i`` of the pixels). They are then reconstructed
    with ``fit_image`` and scored with ``compute_metrics``.

    All results are pickled to ``<dataset_name>_checkpoint.pickle``. Videos of
    the final (``i == number_masks``) reconstruction and ground truth are also written.

    Parameters:
    - datacube: array of shape (H, W, C).
    - dataset_name: prefix for the output files and key of the saved dict.
    - number_masks: total number of bands to use.
    - train_steps, siren_model, lr: forwarded to ``fit_image`` (defaults match its defaults).

    Returns:
    - (MSE_stat, PSNR_stat, SSIM_stat): lists with one entry per band count ``i``.
    """
    # Only the band indices and the selected data are used; the masks generated here are discarded.
    bands_selected, _, data_selection = generate_hsi_masks(datacube, number_masks=number_masks)
    # Indices into data_selection (not into datacube).
    selected = np.arange(number_masks)

    save_data = {f"{dataset_name}": {}}
    MSE_stat, PSNR_stat, SSIM_stat = ([], [], [])
    for i in range(3, number_masks + 1):
        dtemp = np.array(data_selection[..., :i])
        img_selected, hsi_masks = prepare_img(dtemp, mask_p=1. / i, band_selection=selected[:i], create_mask=True)
        print(f"{i}/{number_masks}")

        recon, gt = fit_image(img_selected, hsi_masks, train_steps=train_steps, siren_model=siren_model, lr=lr)
        full_recon = np.transpose(np.array(recon), (1, 2, 0))  # (H, W, i)
        full_ground = np.transpose(np.array(gt), (1, 2, 0))

        mse2, psnr2, ssim2 = compute_metrics(full_recon, full_ground)
        MSE_stat.append(mse2)
        PSNR_stat.append(psnr2)
        SSIM_stat.append(ssim2)

        # Keyed by the number of bands used in this run.
        save_data[dataset_name][f"{i}"] = {
            "reconstructions": full_recon,
            "ground_truth": full_ground,
            "bands": bands_selected,  # all selected bands; this run used the first i
            "masks": hsi_masks,
        }

        if i == number_masks:
            create_hyperspectral_video(full_recon, dataset_name + '_reconstruction.mp4')
            create_hyperspectral_video(full_ground, dataset_name + '_truth.mp4')

    with open(f'{dataset_name}_checkpoint.pickle', 'wb') as file:
        pickle.dump(save_data, file)
    return MSE_stat, PSNR_stat, SSIM_stat
    
def create_fig(mse_stat, psnr_stat, ssim_stat, name, start_channels=3, skip=2):
    """
    Plot SSIM (left axis) and PSNR (right axis) against the number of channels.

    Parameters:
    - mse_stat: accepted for symmetry with ``compute_dataset_statistics``; not plotted.
    - psnr_stat, ssim_stat: one value per channel count, starting at ``start_channels``.
    - name: dataset label used in the title.
    - start_channels: channel count of the first entry (3 in ``compute_dataset_statistics``).
    - skip: number of leading entries left out of the plot (2, as before).
    """
    # One x value per statistic, instead of assuming exactly 8 entries.
    X = list(range(start_channels, start_channels + len(ssim_stat)))
    fig, ax1 = plt.subplots()

    ax1.plot(X[skip:], ssim_stat[skip:], 'b-', label="SSIM")
    ax1.set_xlabel("Number of Channels")
    ax1.set_ylabel("SSIM", color='b')

    # PSNR has a different scale, so it gets its own right-hand axis.
    ax2 = ax1.twinx()
    ax2.plot(X[skip:], psnr_stat[skip:], 'r--', label="PSNR")
    ax2.set_ylabel("PSNR (dB)", color='r')

    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")

    ax1.set_title(f"{name} Reconstruction Performance")
    plt.show()

## PaviaU

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=data, dataset_name='pavia_u_relu', train_steps=2501
)

In [None]:
create_fig(mse_stat, psnr_stat, ssim_stat, name='PaviaUReLU')

## Real Images

In [None]:
real_path = '/home/bigkeenus/Desktop/JASON/Datasets/CZ_hsdb'
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort the
# listing so that indexing into it (e.g. real_dir[2]) picks the same file everywhere.
real_dir = sorted(os.listdir(real_path))

In [None]:
real_image = loadmat(file_name=os.path.join(real_path, real_dir[2]))['ref']

In [None]:
# generate masks
# Draw 10 random bands from the real image and preview the masks generated for them.
num_masks = 10

band_selection = np.random.choice(real_image.shape[-1], num_masks, replace=False)
data_selection = real_image[:, :, band_selection]
_, hyperspectral_masks = prepare_img(data_selection, num_channels=num_masks)
print(f"selected data shape: {data_selection.shape} mask array shape: {hyperspectral_masks.shape}")

In [None]:
plt.imshow(X=data_selection[:, :, 1])

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=data_selection, dataset_name='CZ_hsdb_relu', train_steps=2501
)

In [None]:
create_fig(mse_stat, psnr_stat, ssim_stat, name='CZ_hsdbReLU')

## Skin Cancer

In [None]:
from PIL import Image
import io

skin_cancer_path = '/home/bigkeenus/Desktop/JASON/Datasets/PaperData/PD1C1/Image.bin'
# Raw float32 samples. The first axis of the 100 x 1000 x 1000 reshape is moved last,
# giving a (1000, 1000, 100) cube (presumably rows, cols, bands — confirm against the data spec).
skin_cancer_image = np.fromfile(skin_cancer_path, dtype=np.float32)
skin_cancer_cube = np.moveaxis(skin_cancer_image.reshape(100, 1000, 1000), 0, -1)

In [None]:
plt.imshow(X=skin_cancer_cube[:, :, 30])

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=skin_cancer_cube, dataset_name='dermatology_relu', train_steps=2501
)

In [None]:
create_fig(mse_stat, psnr_stat, ssim_stat, name='Dermatology')

## HS-SOD

In [None]:
import h5py

hssod_path = '/home/bigkeenus/Desktop/JASON/Datasets/HS-SOD/hyperspectral/0006.mat'
with h5py.File(hssod_path, 'r') as f:
    # Move the first axis of the stored 'hypercube' dataset to the end.
    hssod_image = np.moveaxis(np.array(f['hypercube']), 0, -1)
# NOTE(review): h5py exposes MATLAB v7.3 arrays with their axes reversed; moving only the
# first axis last may leave the two spatial axes swapped relative to MATLAB — confirm.

In [None]:
plt.imshow(X=hssod_image[:, :, 80])
print(hssod_image.shape)

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=hssod_image, dataset_name='HS-SOD_relu', train_steps=2501
)

In [None]:
create_fig(mse_stat, psnr_stat, ssim_stat, name='HS-SODReLU')

## Crops

In [None]:
import tifffile as tiff

folder_selection = 5
subimage_selection = 3

crops_path = '/home/bigkeenus/Desktop/JASON/Datasets/hyspecnet-11k/patches'
# Sort every listing: os.listdir order is arbitrary, so indexing into an unsorted
# listing would select different patches on different machines.
patches = sorted(os.listdir(crops_path))
subpatch_path = os.path.join(crops_path, patches[folder_selection])
subpatches = sorted(os.listdir(subpatch_path))
image_directory = os.path.join(subpatch_path, subpatches[subimage_selection])
image_bands = sorted(os.listdir(image_directory))

# Load the first file whose name contains 'SPECTRAL_IMAGE'; fail clearly if there is none.
spectral_name = next((t for t in image_bands if 'SPECTRAL_IMAGE' in t), None)
if spectral_name is None:
    raise FileNotFoundError(f"No 'SPECTRAL_IMAGE' file found in {image_directory}")
target_image = tiff.imread(os.path.join(image_directory, spectral_name))
# Move the first axis (presumably bands — confirm) to the end: (H, W, C) as the pipeline expects.
target_image = np.transpose(target_image, [1, 2, 0])
print(f"Final selected image shape {target_image.shape}")

In [None]:
plt.imshow(X=target_image[:, :, 0])

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=target_image, dataset_name='HySpecNet-11K_relu', train_steps=2501
)

In [None]:
create_fig(mse_stat, psnr_stat, ssim_stat, name='HySpecNet-11KReLU')

## Kodak

In [None]:
def read_kodak(m, kodak_path='/home/bigkeenus/Desktop/JASON/Datasets/archive/'):
    """
    Load the ``m``-th image of the Kodak directory as an RGB array.

    Files are indexed in sorted filename order, so ``m`` refers to the same image
    on every machine (``os.listdir`` order is arbitrary).

    Parameters:
    - m: index into the sorted directory listing.
    - kodak_path: directory holding the Kodak images.

    Returns:
    - the image converted from OpenCV's BGR channel order to RGB.

    Raises:
    - FileNotFoundError: if OpenCV cannot read the selected file.
    """
    kodak_img = sorted(os.listdir(kodak_path))
    file_path = os.path.join(kodak_path, kodak_img[m])
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image {file_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [None]:
kodak1 = read_kodak(m=1)

In [None]:
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=kodak1, dataset_name='KODAK1ReLU', number_masks=3, train_steps=2501
)

In [None]:
kodak2 = read_kodak(m=2)
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=kodak2, dataset_name='KODAK2ReLU', number_masks=3, train_steps=2501
)

In [None]:
kodak3 = read_kodak(m=4)
plt.imshow(X=kodak3)
mse_stat, psnr_stat, ssim_stat = compute_dataset_statistics(
    datacube=kodak3, dataset_name='KODAK3ReLU', number_masks=3, train_steps=2501
)