In [None]:
import os
import matplotlib.pyplot as plt
import random
from utils import *
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [None]:
# Example image used throughout the notebook: hsi_np is the hyperspectral cube (indexed as
# (bands, height, width) by the functions below) and wlens holds the wavelength (nm) of each band.
hsi_np, wlens = LoadHSI('../Data/cropped_hdf5/FX10_09SEPT2023_3D1_0.hdf5', return_wlens=True)

In [None]:
# Segmentation mask of the same example image. As used below, 0 marks background and label 2 is
# the part controlled by mask_method (0 = drop it and keep only the leaf, 1 = keep leaf+stem).
mask_np = read_mask('../Data/cropped_masks/FX10_09SEPT2023_3D1_0.png')

Define the image paths for healthy leaves and for early-, mid- and late-stage diseased leaves

In [None]:
def filter_filenames(folder_path, camera_id, date_stamps, tray_ids):
    """
    Filters filenames in a folder based on camera ID, date stamps, and tray IDs.

    Filenames are expected to look like "<camera>_<date>_<tray...>_<idx>.<ext>"
    (e.g. "FX10_09SEPT2023_3D1_0.hdf5"). A file is kept when
    - it starts with camera_id + "_",
    - any of date_stamps occurs anywhere in its name, and
    - its third underscore-separated field starts with any of tray_ids.
    Files with fewer than three underscore-separated fields are skipped instead of raising IndexError.

    Parameters:
    - folder_path (str): Path to the folder containing the files.
    - camera_id (str): Camera ID.
    - date_stamps (list): List of selected date stamps.
    - tray_ids (list): List of selected tray IDs.

    Returns:
    - list: Filtered list of full file paths that match the given criteria, sorted by filename.
      os.listdir() returns entries in an arbitrary, filesystem-dependent order; sorting makes
      downstream seeded random.sample() calls reproducible.
    """
    prefix = camera_id + "_"
    filtered_files = []

    for f in sorted(os.listdir(folder_path)):
        if not f.startswith(prefix):
            continue

        parts = f.split("_")
        if len(parts) < 3:
            # Not in the expected <camera>_<date>_<tray>... format; parts[2] would not exist
            continue

        if any(date in f for date in date_stamps) and any(parts[2].startswith(tray) for tray in tray_ids):
            filtered_files.append(os.path.join(folder_path, f))

    return filtered_files

In [None]:
# FX10 camera
IMG_DIR = '../Data/cropped_hdf5'
CAMERA = 'FX10'

# Healthy leaves: trays 3D, 4C and 4D on all listed dates. '2D' is included too, because some files
# are mistakenly named 2D instead of 4D (originally noted for files from the FX17 camera).
DATES = [
    '07SEPT2023', '08SEPT2023', '09SEPT2023', '10SEPT2023', '11SEPT2023', '12SEPT2023',
    '13SEPT2023', '14SEPT2023', '15SEPT2023', '18SEPT2023', '19SEPT2023',
]
TRAYS = ['3D', '4C', '4D', '2D']
healthy_FX10 = filter_filenames(IMG_DIR, CAMERA, DATES, TRAYS)

# Diseased leaves all come from tray 3C; the disease stage is determined by the imaging date
TRAYS = ['3C']

# Early stage
DATES = ['07SEPT2023']
early_diseased_FX10 = filter_filenames(IMG_DIR, CAMERA, DATES, TRAYS)

# Mid stage
DATES = ['08SEPT2023', '09SEPT2023']
mid_diseased_FX10 = filter_filenames(IMG_DIR, CAMERA, DATES, TRAYS)

# Late stage
DATES = ['10SEPT2023', '11SEPT2023', '12SEPT2023', '13SEPT2023', '14SEPT2023', '15SEPT2023']
late_diseased_FX10 = filter_filenames(IMG_DIR, CAMERA, DATES, TRAYS)

# How many images ended up in each category
print(f"Healthy: {len(healthy_FX10)}")
print(f"Early diseased: {len(early_diseased_FX10)}")
print(f"Mid diseased: {len(mid_diseased_FX10)}")
print(f"Late diseased: {len(late_diseased_FX10)}")

In [None]:
# Draw 12 random images each from the healthy, mid- and late-diseased groups; the early-diseased
# group is used in full. The fixed seed makes the draw repeatable for a given input file order.
random.seed(10)
healthy_sample, mid_diseased_sample, late_diseased_sample = (
    random.sample(group, 12) for group in (healthy_FX10, mid_diseased_FX10, late_diseased_FX10)
)

print(healthy_sample)
print(mid_diseased_sample)
print(late_diseased_sample)

# PCA training set: healthy, early, mid and late images, in this order
pca_data = [*healthy_sample, *early_diseased_FX10, *mid_diseased_sample, *late_diseased_sample]
print(len(pca_data))

In [None]:
def preprocess(hsi_np, wlens, min_wavelength=0, normalize=True):
    """
    Removes spectral bands with wavelengths below min_wavelength from the hyperspectral image.
    Also replaces negative values with 0 and applies optional normalization.

    Parameters:
    - hsi_np (numpy array): Hyperspectral image with shape (bands, height, width).
    - wlens (array-like): Wavelength values with shape (bands, ) corresponding to bands in hsi_np.
      Plain lists are accepted too (converted with np.asarray).
    - min_wavelength (int or float): Minimum wavelength (in nm) to keep in the hyperspectral data
    - normalize (bool): Whether to divide the image by its maximum value. Skipped when that maximum
      is 0 (avoids division by zero) or when no band is left after filtering.

    Returns:
    - hsi_np_filtered (numpy array): Hyperspectral image with selected spectral bands of shape (filtered bands, height, width)
    - wlens_filtered (numpy array): Updated wavelengths array of shape (filtered bands, )
    """
    # A list would make `wlens >= min_wavelength` raise TypeError, so work on an array
    wlens = np.asarray(wlens)

    # Determine wavelengths to keep
    valid_bands = wlens >= min_wavelength

    # Filter hsi_np and wlens to keep only relevant bands
    hsi_np_filtered = hsi_np[valid_bands, :, :]
    wlens_filtered = wlens[valid_bands]

    # Set all negative values to 0 (these are noise)
    hsi_np_filtered = np.maximum(hsi_np_filtered, 0)

    # Normalize the data if required. np.max raises on an empty array, hence the size check.
    if normalize and hsi_np_filtered.size > 0:
        max_value = np.max(hsi_np_filtered)
        if max_value > 0:    # Avoid division by zero
            hsi_np_filtered = hsi_np_filtered / max_value

    return hsi_np_filtered, wlens_filtered

In [None]:
def load_and_flatten_hsi(img_paths, mask_dir=None, individual_normalize=False, apply_mask=False, mask_method=1, min_wavelength=0):
    """
    Transforms the 3D hyperspectral images into a 2D array by flattening the spatial dimensions.
    The resulting "rows" are the pixels and the "columns" store their values for the different spectral bands.
    Can be used with a single or multiple HSIs. If multiple HSIs are provided, their pixels are stacked together.
    Every image is passed through preprocess() (band filtering, clipping of negative values to 0 and,
    optionally, normalization by the image's own maximum).

    Parameters:
    - img_paths (list of str): Paths to the HSI files to be loaded and flattened. Must not be empty.
    - mask_dir (str): Path to the folder containing the masks for the HSIs. The mask of "<name>.hdf5"
      is expected at "<mask_dir>/<name>.png".
    - individual_normalize (bool): Whether to normalize each HSI individually before flattening and stacking
    - apply_mask (bool): Whether to apply the mask to the HSIs. Masks are only loaded when mask_dir is
      also given, but all-zero pixels are dropped whenever apply_mask is True.
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask
      (mask label 2 is replaced by this value)
    - min_wavelength (int or float): Minimum wavelength (in nm) to keep in the hyperspectral data

    Returns:
    - numpy array, shape (n_pixels, n_bands), the flattened and stacked HSI pixels

    Raises:
    - ValueError: if img_paths is empty.
    """
    if len(img_paths) == 0:
        raise ValueError("img_paths must contain at least one HSI file path")

    all_pixels = []

    for idx, file in enumerate(img_paths):
        if idx == 0:
            # All images share the same spectral bands, so the wavelengths are read once, together with
            # the first image's data (instead of loading that file a second time just for the wavelengths).
            hsi_np, wlens = LoadHSI(file, return_wlens=True)
        else:
            hsi_np = LoadHSI(file)

        # Load and apply mask if required (same base name as the HSI, .png extension)
        if apply_mask and mask_dir:
            mask_name = os.path.splitext(os.path.basename(file))[0] + '.png'
            mask_np = read_mask(os.path.join(mask_dir, mask_name))
            hsi_np = hsi_np * np.where(mask_np == 2, mask_method, mask_np)

        # Preprocess the hyperspectral image
        hsi_np, _ = preprocess(hsi_np, wlens, min_wavelength=min_wavelength, normalize=individual_normalize)

        # Flatten: (bands, height, width) -> (height*width, bands)
        hsi_np = hsi_np.reshape(hsi_np.shape[0], -1).T

        # Remove background (all-zero) pixels, i.e. those that were masked out.
        # NOTE(review): in-leaf pixels that are 0 in every band after clipping are dropped as well.
        if apply_mask:
            hsi_np = hsi_np[~np.all(hsi_np == 0, axis=1)]

        all_pixels.append(hsi_np)

    # Stack all pixels together
    return np.vstack(all_pixels)

Conduct PCA

In [None]:
# Build the training matrix: leaf pixels only (mask_method=0), bands >= 430 nm,
# every image normalized by its own maximum
X = load_and_flatten_hsi(
    img_paths=pca_data,
    mask_dir='../Data/cropped_masks',
    apply_mask=True,
    individual_normalize=True,
    mask_method=0,
    min_wavelength=430,
)
print(f"Data shape before PCA: {X.shape}")

# Standardize every band to zero mean / unit variance before PCA
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Fit a 30-component PCA and project all pixels onto it
pca = PCA(n_components=30)
X_pca = pca.fit_transform(X_scaled)

# Report the variance captured by the retained components (last value of the running total)
explained_variance_ratio = pca.explained_variance_ratio_
cumulative_variance_ratio = explained_variance_ratio.cumsum()
print("---------------------------------")
print(f"Number of components chosen: {pca.n_components_}")
print(f"Explained variance ratio: {cumulative_variance_ratio[-1]:.4f}")
print(f"Data shape after PCA: {X_pca.shape}")

In [None]:
# Scree plot: explained variance of each component and the running total
plt.figure(figsize=(8, 6))
plt.plot(
    range(1, len(explained_variance_ratio) + 1),
    explained_variance_ratio,
    marker='o',
    label='Individual Variance',
)
plt.plot(
    range(1, len(cumulative_variance_ratio) + 1),
    cumulative_variance_ratio,
    marker='s',
    label='Cumulative Variance',
)
# Reference line: number of components needed to reach 95% of the variance
plt.axhline(y=0.95, color='r', linestyle='--', label='95% Variance Explained')
plt.title('Scree Plot')
plt.xlabel('Principal Component Number')
plt.ylabel('Explained Variance Ratio')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Loadings: principal axes scaled by the square root of the variance each component explains
# -> shape (n_features, n_components)
loadings = pca.components_.T * np.sqrt(pca.explained_variance_)

# The PCA features are the bands kept by load_and_flatten_hsi(..., min_wavelength=430), i.e. only the
# wavelengths >= 430 nm. Plotting against the full `wlens` fails with an x/y length mismatch as soon as
# the camera records any band below 430 nm, so select the matching wavelengths.
# Keep the 430 in sync with min_wavelength in the PCA cell.
wlens_pca = np.asarray(wlens)[np.asarray(wlens) >= 430]
assert len(wlens_pca) == loadings.shape[0], (
    f"{len(wlens_pca)} wavelengths vs {loadings.shape[0]} PCA features - check min_wavelength"
)

# Plot PC loadings
plt.figure(figsize=(10, 6))
for i in range(pca.n_components_):
    plt.plot(wlens_pca, loadings[:, i], label=f'PC {i+1}')

plt.xlabel('Wavelength (nm)')
plt.ylabel('Loading Value')
plt.title('Principal Component Loadings')
plt.legend()
plt.grid()
plt.show()

In [None]:
import matplotlib.ticker as ticker

def plot_spectra(X, wlens, x_major_step=100, x_minor_step=10, y_major_step=5000, y_minor_step=1000, title=None):
    """
    Plot pixel spectra: every spectrum as a faint light-blue line, plus the mean spectrum (thick dark
    blue line) and a shaded ±1 standard deviation band around it.

    Parameters:
    - X (array-like): Pixel spectra with shape (n_pixels, n_bands).
    - wlens (array-like): Wavelengths (nm) of the n_bands columns of X.
    - x_major_step (float or None): Spacing of the major x-axis ticks (None = matplotlib default).
    - x_minor_step (float or None): Spacing of the minor x-axis ticks (None = matplotlib default).
    - y_major_step (float or None): Spacing of the major y-axis ticks. The default (5000) suits raw
      sensor values. For data normalized to [0, 1], pass a small value (e.g. 0.1) or None, because
      5000 would leave the axis with a single tick.
    - y_minor_step (float or None): Spacing of the minor y-axis ticks (None = matplotlib default).
    - title (str or None): Optional plot title.

    Raises:
    - ValueError: if X is not a non-empty 2-D array or its number of columns differs from len(wlens).
    """
    spectra = np.asarray(X)
    wlens = np.asarray(wlens)
    if spectra.ndim != 2 or spectra.shape[0] == 0:
        raise ValueError(f"X must be a non-empty 2-D array of shape (n_pixels, n_bands), got shape {spectra.shape}")
    if spectra.shape[1] != wlens.shape[0]:
        raise ValueError(f"X has {spectra.shape[1]} bands but {wlens.shape[0]} wavelengths were given")

    plt.figure(figsize=(12, 6))

    # All individual spectra in one call (one line per pixel) instead of a Python loop over pixels
    plt.plot(wlens, spectra.T, color='lightblue', linewidth=0.4, alpha=0.3)

    # Mean spectrum and ±1 standard deviation band
    mean_spectral_values = spectra.mean(axis=0)
    std_spectral_values = spectra.std(axis=0)
    plt.plot(wlens, mean_spectral_values, color='darkblue', linewidth=3, label='Mean Spectral Value')
    plt.fill_between(
        wlens,
        mean_spectral_values - std_spectral_values,
        mean_spectral_values + std_spectral_values,
        color='darkblue', alpha=0.5, label='±1 Std Dev'
    )

    # Tick spacing (defaults: x major every 100 nm, minor every 10 nm)
    ax = plt.gca()
    if x_major_step is not None:
        ax.xaxis.set_major_locator(ticker.MultipleLocator(x_major_step))
    if x_minor_step is not None:
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(x_minor_step))
    if y_major_step is not None:
        ax.yaxis.set_major_locator(ticker.MultipleLocator(y_major_step))
    if y_minor_step is not None:
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(y_minor_step))

    plt.xlabel('Wavelength (nm)')
    plt.ylabel('Reflectance')
    if title is not None:
        plt.title(title)
    plt.legend()
    plt.show()

Reconstruct

In [None]:
# Map the PCA scores back to the standardized band space, then undo the standardization
# so the reconstructed spectra are comparable with X
X_reconstructed = scaler.inverse_transform(pca.inverse_transform(X_pca))

In [None]:
# Shapes of the PCA scores and of the reconstruction, one per line
print(X_pca.shape, X_reconstructed.shape, sep='\n')

In [None]:
# Original pixel spectra (rows 5000-9999 of X).
# X only contains bands >= 430 nm (min_wavelength=430 in the PCA cell), so plot it against the matching
# subset of wavelengths; the full `wlens` would not match X's number of columns.
plot_spectra(X[5000:10000], np.asarray(wlens)[np.asarray(wlens) >= 430])

In [None]:
# Reconstructed pixel spectra (same rows as the original plot above).
# X_reconstructed has the same bands as X (>= 430 nm, min_wavelength=430 in the PCA cell), so plot it
# against the matching subset of wavelengths rather than the full `wlens`.
plot_spectra(X_reconstructed[5000:10000], np.asarray(wlens)[np.asarray(wlens) >= 430])

In [None]:
def hsi_transform_to_pca_space(hsi_np, pca, scaler, mask_np=None, mask_method=1):
    """
    Project a hyperspectral image onto a pre-fitted PCA basis, i.e. reduce its spectral bands.

    Every pixel spectrum is standardized with `scaler` and projected with `pca`. When a mask is given,
    only pixels whose (remapped) mask value is non-zero are projected; all other (background) pixels
    get 0 for every component.

    Negative (noise) values are deliberately NOT clipped here. If they should be clipped, apply
    np.maximum(hsi_np, 0) before calling, so the reconstruction error is later computed against the
    same clipped image. Clipping only inside this function made the reconstruction-error maps show
    vertical lines, because the comparison image was left unclipped.

    Parameters:
    - hsi_np (numpy array): Hyperspectral image with shape (bands, height, width).
    - pca: Pre-fitted PCA model (uses .transform and .n_components_).
    - scaler: Pre-fitted StandardScaler model (uses .transform).
    - mask_np (numpy array): Mask with height*width elements to apply on the HSI, or None.
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask
      (mask label 2 is replaced by this value).

    Returns:
    - numpy array of PCA-transformed HSI with shape (pca_components, height, width).
    """
    height, width = hsi_np.shape[1], hsi_np.shape[2]

    # (bands, height, width) -> (height*width, bands): one row per pixel
    pixels = hsi_np.reshape(hsi_np.shape[0], -1).T

    if mask_np is None:
        # No mask: every pixel is projected
        scores = pca.transform(scaler.transform(pixels))
    else:
        keep = np.where(mask_np == 2, mask_method, mask_np).flatten() != 0
        # Background pixels stay 0 in PCA space; only the kept pixels are projected
        scores = np.zeros((pixels.shape[0], pca.n_components_))
        scores[keep] = pca.transform(scaler.transform(pixels[keep]))

    # (height*width, pca_components) -> (pca_components, height, width)
    return scores.T.reshape(pca.n_components_, height, width)

In [None]:
def reconstruct_hsi_from_pca_space(hsi_pca, pca, scaler, mask_np=None, mask_method=1):
    """
    Map a PCA-space image back to the original spectral space.

    The PCA scores of each pixel are inverse-transformed with `pca` and then de-standardized with
    `scaler`. When a mask is given, only pixels whose (remapped) mask value is non-zero are
    reconstructed; all other (background) pixels are 0 in every band.

    Parameters:
    - hsi_pca (numpy array): PCA-transformed HSI with shape (pca_components, height, width).
    - pca: Pre-fitted PCA model (uses .inverse_transform and .n_components_).
    - scaler: Pre-fitted StandardScaler model (uses .inverse_transform).
    - mask_np (numpy array): Mask with height*width elements to apply on the HSI, or None.
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask
      (mask label 2 is replaced by this value).

    Returns:
    - numpy array of back-transformed (reconstructed) HSI with shape (bands, height, width).
    """
    height, width = hsi_pca.shape[1], hsi_pca.shape[2]

    # (pca_components, height, width) -> (height*width, pca_components): one row per pixel
    scores = hsi_pca.reshape(pca.n_components_, -1).T

    if mask_np is None:
        # No mask: every pixel is reconstructed
        pixels = scaler.inverse_transform(pca.inverse_transform(scores))
    else:
        keep = np.where(mask_np == 2, mask_method, mask_np).flatten() != 0
        restored = scaler.inverse_transform(pca.inverse_transform(scores[keep]))
        # Background pixels stay 0 in every band; kept pixels receive their reconstructed spectrum
        pixels = np.zeros((scores.shape[0], restored.shape[1]))
        pixels[keep] = restored

    # (height*width, bands) -> (bands, height, width)
    return pixels.T.reshape(-1, height, width)

In [None]:
def compress_and_reconstruct_hsi_pca(hsi_np, pca, scaler, mask_np=None, mask_method=1):
    '''
    Perform PCA compression of an HSI and immediately reconstruct it from the PCA space.

    Parameters:
    - hsi_np (numpy array): Hyperspectral image with shape (bands, height, width).
    - pca: Pre-fitted PCA model.
    - scaler: Pre-fitted StandardScaler model.
    - mask_np (numpy array): Mask to apply on the HSI, or None.
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask.

    Returns:
    - numpy array with the reconstructed HSI, shape (bands, height, width). With a mask, pixels
      outside the mask are 0.
    '''
    # Forward projection into PCA space, then straight back to the band space
    return reconstruct_hsi_from_pca_space(
        hsi_transform_to_pca_space(hsi_np, pca, scaler, mask_np, mask_method),
        pca, scaler, mask_np, mask_method,
    )

In [None]:
def get_pca_reconstruction_error(hsi_np, pca, scaler, mask_np=None, mask_method=1, show_plot=True):
    """
    Round-trip a single HSI through the pre-trained PCA and return the absolute reconstruction error
    per band and pixel; optionally show a heatmap of the error summed over all bands.

    Parameters:
    - hsi_np (numpy array): Hyperspectral image with shape (bands, height, width)
    - pca: pre-fitted PCA model
    - scaler: pre-fitted StandardScaler model
    - mask_np (numpy array): Mask to apply on the HSI, or None
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask
    - show_plot (bool): Whether to show the plot of the sum of reconstruction errors across the bands

    Returns:
    - reconstruction_error (numpy array): |hsi_np - reconstruction|, shape (bands, height, width).
      With a mask, it is multiplied by the remapped mask, so pixels outside the mask are 0.
    - Plot (optional): "Heatmap" of the sum of reconstruction errors across the bands
    """
    # Round trip through the PCA space
    hsi_reconstructed = reconstruct_hsi_from_pca_space(
        hsi_transform_to_pca_space(hsi_np, pca, scaler, mask_np, mask_method),
        pca, scaler, mask_np, mask_method,
    )

    abs_error = np.abs(hsi_np - hsi_reconstructed)
    if mask_np is None:
        reconstruction_error = abs_error
    else:
        # Weight by the remapped mask (label 2 -> mask_method) so masked-out pixels contribute nothing
        reconstruction_error = abs_error * np.where(mask_np == 2, mask_method, mask_np)

    if show_plot:
        plt.figure(figsize=(8, 6))
        plt.imshow(reconstruction_error.sum(axis=0), cmap='hot')
        plt.colorbar(label='Total Reconstruction Error')
        plt.title('Total Reconstruction Error per Pixel')
        plt.axis('off')
        plt.show()

    return reconstruction_error

In [None]:
def plot_pca_reconstruction_error_dir(hsi_path, pca, scaler, mask_dir=None, apply_mask=False, mask_method=1):
    """
    Load a single HSI from disk, reconstruct it with the pre-trained PCA and plot the reconstruction error.
    Since the HSI is loaded inside this function, negative (noise) values are not clipped with
    np.maximum(hsi_np, 0) before the PCA round trip, which is a bit undesirable. This function is
    therefore meant mainly for quick experimentation.
    (Clipping could be done here, but then the same would be needed in hsi_transform_to_pca_space().)

    Parameters:
    - hsi_path (str): Path to the HSI file.
    - pca: Pre-fitted PCA model.
    - scaler: Pre-fitted StandardScaler model.
    - mask_dir (str): Path to the folder containing the masks for the HSIs. The mask of "<name>.hdf5"
      is expected at "<mask_dir>/<name>.png".
    - apply_mask (bool): Whether to apply the mask to the HSI (only effective when mask_dir is also given).
    - mask_method (int): 0 for keeping only leaf, 1 for keeping leaf+stem after applying the mask.

    Returns:
    - None. Shows a plot ("heatmap") of the sum of reconstruction errors across the bands.
    """
    # Load a HSI
    hsi_np = LoadHSI(hsi_path)

    # Must be initialized: without it, apply_mask=False (or mask_dir=None) raised UnboundLocalError below
    mask_np = None
    if apply_mask and mask_dir:
        mask_file = os.path.join(mask_dir, os.path.splitext(os.path.basename(hsi_path))[0] + ".png")    # Same name as the HSI
        mask_np = read_mask(mask_file)
        mask_np = np.where(mask_np == 2, mask_method, mask_np)

    # Apply PCA and reconstruct
    hsi_reconstructed = compress_and_reconstruct_hsi_pca(hsi_np, pca, scaler, mask_np, mask_method)

    # Compute reconstruction error (zeroed outside the mask if one was applied)
    reconstruction_error = np.abs(hsi_np - hsi_reconstructed)
    if mask_np is not None:
        reconstruction_error = reconstruction_error * mask_np

    # Sum reconstruction error across the bands
    pixel_errors = np.sum(reconstruction_error, axis=0)

    # Plot pixel errors
    plt.figure(figsize=(8, 6))
    plt.imshow(pixel_errors, cmap='hot')
    plt.colorbar(label='Total Reconstruction Error')
    plt.title('Total Reconstruction Error per Pixel')
    plt.axis('off')
    plt.show()

In [None]:
# Per-band absolute reconstruction error of the example image (hsi_np, mask_np) plus a heatmap of its sum
# over bands; mask label 2 is kept (mask_method=1).
# NOTE(review): the scaler/PCA were fitted on bands >= 430 nm (min_wavelength=430) of per-image normalized
# data (individual_normalize=True), while hsi_np here still holds all bands of the raw image. If the camera
# has any band below 430 nm, scaler.transform will likely reject the feature count. TODO: align the
# preprocessing, e.g. via preprocess(hsi_np, wlens, min_wavelength=430, normalize=True), and adapt the
# band indexing in the cells below accordingly.
reconstruction_error = get_pca_reconstruction_error(hsi_np, pca, scaler, mask_np, mask_method=1)

In [None]:
# Quick check on an image from a different folder (HDF5_FILES/train) than the cropped images used above
plot_pca_reconstruction_error_dir(
    '../Data/HDF5_FILES/train/FX10_07SEPT2023_1B1.hdf5',
    pca,
    scaler,
    mask_dir='../Data/MASKS/train',
    apply_mask=True,
    mask_method=1,
)

Visualize the original image (3 selected "RGB" bands), the image in PCA space (first 3 PCs) and the reconstructed image (the same 3 "RGB" bands)

In [None]:
# Clip negative (noise) values to 0 before the PCA round trip, so that the reconstruction error is
# computed against the same clipped image that is transformed.
hsi_np = np.maximum(hsi_np, 0)
# NOTE(review): the PCA cell already trains with individual_normalize=True (each image divided by its own
# max), so normalizing here as well would match the training setup (the band filter does not match yet either):
# hsi_np = hsi_np  / np.max(hsi_np)

In [None]:
# Project the example image once and reuse the scores for the reconstruction
# (compress_and_reconstruct_hsi_pca would repeat the same forward projection internally).
hsi_pca = hsi_transform_to_pca_space(hsi_np, pca, scaler, mask_np=mask_np, mask_method=1)
hsi_reconstructed = reconstruct_hsi_from_pca_space(hsi_pca, pca, scaler, mask_np=mask_np, mask_method=1)

In [None]:
# Band indices closest to the chosen "RGB" wavelengths. imshow reads the last axis of an (H, W, 3) array
# as (R, G, B), so the wavelengths are listed red -> green -> blue. Listing them in ascending order
# (445, 535, 575) would put the ~445 nm (blue) band into the red channel and swap the colors.
RGB_wlens = (575, 535, 445)
RGB_bands = np.argmin(np.abs(np.array(wlens)[:, np.newaxis] - RGB_wlens), axis=0)
print(f'RGB indices in hsi -> {RGB_bands}')

# Scale the HSI to [0, 1] for display (negative noise values are expected to be clipped to 0 in an earlier cell)
hsi_np = hsi_np / np.max(hsi_np)

# For VISUALIZATION only: PCA scores can be negative or much larger than 1, and the reconstruction is not
# perfect, so both are clipped at 0 and divided by their maximum to land in [0, 1].
# Never do this when computing the reconstruction error.
hsi_pca = np.maximum(hsi_pca, 0)
hsi_pca = hsi_pca / np.max(hsi_pca)

hsi_reconstructed = np.maximum(hsi_reconstructed, 0)
hsi_reconstructed = hsi_reconstructed / np.max(hsi_reconstructed)

# Select the display bands: RGB bands of the original/reconstruction, first 3 PCs of the PCA image
rgb_img = hsi_np[RGB_bands]
hsi_pca_img = hsi_pca[:3]
rgb_img_reconstructed = hsi_reconstructed[RGB_bands]
print(f'shape, {rgb_img.shape}, but for other libraries usually the bands is the last dimension, so we change the order and get:')

# (bands, height, width) -> (height, width, bands), as expected by imshow
rgb_img = rgb_img.transpose((1, 2, 0))
hsi_pca_img = hsi_pca_img.transpose((1, 2, 0))
rgb_img_reconstructed = rgb_img_reconstructed.transpose((1, 2, 0))
print(rgb_img.shape)

# Side-by-side comparison
plt.figure(figsize=(16, 6))
plt.subplot(1, 3, 1)
plt.imshow(rgb_img)
plt.title('RGB')

plt.subplot(1, 3, 2)
plt.imshow(hsi_pca_img)
plt.title('PCA reduced HSI')

plt.subplot(1, 3, 3)
plt.imshow(rgb_img_reconstructed)
plt.title('Reconstructed RGB')
plt.show()

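The "PCA reduced HSI" panel shows the scores of the first three principal components. Because the pixels are standardized (mean-centered) before PCA, these scores describe deviations from the mean spectrum along the directions of highest variance, not the reflectance itself. This is probably why the leaf outline is not clearly visible. Note also that the visualization clips all negative scores to 0. Adding back the mean (and possibly rescaling) might give a more meaningful image — TODO: verify.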

In [None]:
# (min, max) of the HSI, its PCA scores and its reconstruction after the display scaling above
(
    hsi_np.min(), hsi_np.max(),
    hsi_pca.min(), hsi_pca.max(),
    hsi_reconstructed.min(), hsi_reconstructed.max(),
)

In [None]:
# (min, max) of the three images passed to imshow
(
    rgb_img.min(), rgb_img.max(),
    hsi_pca_img.min(), hsi_pca_img.max(),
    rgb_img_reconstructed.min(), rgb_img_reconstructed.max(),
)

As we see above, the small "dirt" spot on the healthy leaf is reproduced well in the reconstructed false-RGB image, even though that area showed higher reconstruction errors. Probably other (non-RGB) bands cause the higher reconstruction errors there.

To confirm this, we plot the reconstruction error summed over the RGB bands only, and then the reconstruction error of each band separately.
Indeed, the RGB bands do not seem to produce high reconstruction errors at the "dirt" spot.
The per-band reconstruction errors suggest that some bands are more important than others (they produce higher reconstruction errors), which points us towards band selection.

In [None]:
# Band indices closest to the "RGB" wavelengths (recomputed so this cell can run on its own;
# the order does not matter here because the errors are summed)
RGB_wlens = (445, 535, 575)
RGB_bands = np.argmin(np.abs(np.array(wlens)[:, np.newaxis] - RGB_wlens), axis=0)
print(f'RGB indices in hsi -> {RGB_bands}')

# Per-band reconstruction errors of the three RGB bands only
reconstruction_error_rgb = reconstruction_error[RGB_bands]
print(reconstruction_error_rgb.shape)

# Sum over the three RGB bands -> one error value per pixel
pixel_errors = np.sum(reconstruction_error_rgb, axis=0)

# Label the plot as RGB-only so it is not confused with the all-band error map
plt.figure(figsize=(8, 6))
plt.imshow(pixel_errors, cmap='hot')
plt.colorbar(label='Reconstruction Error (sum over RGB bands)')
plt.title('Total Reconstruction Error per Pixel (RGB bands only)')
plt.axis('off')
plt.show()

In [None]:
# One heatmap per band (channel): where in the image each wavelength is reconstructed worst
for band_idx, band_error in enumerate(reconstruction_error):
    plt.imshow(band_error, cmap='hot')
    plt.colorbar(label='Reconstruction Error')
    plt.title(f'Reconstruction Error for wavelength {wlens[band_idx]} nm')
    plt.axis('off')
    plt.show()