This script provides an interactive tool for visualizing hyperspectral imaging (HSI) data from `.mat` or `.npy` files. It displays a two-panel window showing a specific spectral band of the HSI cube on one side and a plot for spectral signatures on the other. Users can left-click a single pixel or click-and-drag to select a region of interest (ROI) on the image to plot its corresponding spectrum. The tool also allows for right-clicking on the spectral plot to clear all selections and pressing the `s` key to save the currently displayed spectra to a CSV file.

In [None]:
import os 
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.io import loadmat

def load_mat_data(filepath, variable_name=None, transpose=False):
    """
    Load a numerical array from a .mat file, trying h5py first and then scipy.

    h5py is attempted first (for v7.3+ .mat files); whenever that attempt
    fails for any reason, ``scipy.io.loadmat`` is tried as a fallback for the
    legacy formats. The variable to read is either the one named by the
    caller or the first non-private variable found in the file.

    Args:
        filepath (str): The full path to the .mat file.

        variable_name (str, optional): The specific name of the variable to
            load from the .mat file. If None (or empty), the first key that
            does not start with "#" (h5py) or "__" (scipy) is used.
            Defaults to None.

        transpose (bool, optional): If True, the loaded array is transposed
            (``data.T``) before being returned. Defaults to False.
            NOTE(review): arrays read through h5py may come back with their
            axes in reverse order compared to MATLAB - confirm whether
            callers need ``transpose=True`` for v7.3 files.

    Returns:
        np.ndarray: The data loaded from the .mat file.

    Raises:
        IOError: If neither loader could produce the variable. The message
            lists the h5py failure as well as the scipy failure, so a wrong
            ``variable_name`` in a v7.3 file is not hidden behind scipy's
            unrelated error.
    """
    try:  # Method 1: Try h5py (for v7.3+ .mat files)
        with h5py.File(filepath, "r") as f:
            key = variable_name or next(
                (k for k in f.keys() if not k.startswith("#")), None
            )
            if key is None:
                raise KeyError("file contains no data variables")
            data = np.array(f[key])
        print(f"Loaded '{filepath}' using h5py.")
    except Exception as h5py_error:  # Method 2: Try scipy.io.loadmat
        try:
            mat_contents = loadmat(filepath)
            key = variable_name or next(
                (k for k in mat_contents if not k.startswith("__")), None
            )
            if key is None:
                raise KeyError("file contains no data variables")
            data = mat_contents[key]
            print(f"Loaded '{filepath}' using scipy.")
        except Exception as scipy_error:
            # Report both attempts: the first failure is often the informative one.
            raise IOError(
                f"Failed to load data from {filepath} "
                f"(h5py: {h5py_error!r}; scipy: {scipy_error!r})"
            ) from scipy_error

    if transpose:
        data = data.T
    print(f"  Final matrix shape: {data.shape}")
    return data


def _to_pixel_index(coord):
    """Return the index of the image pixel that contains data coordinate *coord*."""
    # imshow centres pixel i on the integer coordinate i, so the pixel covers
    # [i - 0.5, i + 0.5). Plain int() truncation would pick the wrong pixel for
    # the upper half of every cell.
    return int(np.floor(coord + 0.5))


def _csv_field(text):
    """Return *text* quoted for a CSV header if it contains a comma, quote or newline."""
    text = str(text)
    if any(ch in text for ch in ',"\n'):
        return '"' + text.replace('"', '""') + '"'
    return text


def interactive_hsi_plotter(
    hsi_data,
    reference_band,
    wavelengths=None,
    variable_name=None,
    save_dir="hsi_spectra",
):
    """
    Creates an advanced, interactive plot for exploring HSI data.

    Features:
    - Displays a reference band image with a selection box that appears as you drag.
    - LEFT-CLICK & DRAG to select a Region of Interest (ROI) and plot its average spectrum.
    - LEFT-CLICK plots a single pixel's spectrum.
    - RIGHT-CLICK on the spectrum plot to clear all plotted spectra.
    - Press the 's' key to SAVE the currently plotted spectra to a CSV file.
    - Displays live pixel coordinates in the status bar.

    Args:
        hsi_data (str or np.array): File path (.npy or .mat, matched
            case-insensitively) or pre-loaded 3D NumPy array
            [height, width, bands].

        reference_band (int): The index of the band to display as the main image.

        wavelengths (array-like, optional): Wavelength values for the x-axis;
            must hold exactly one value per band. When omitted, the band index
            is used instead.

        variable_name (str, optional): The variable name inside the .mat file.

        save_dir (str): Directory to save exported spectra into.

    Raises:
        TypeError: If ``hsi_data`` is neither a string nor a NumPy array.
        ValueError: If the file extension is unsupported, the data is not 3D,
            or ``wavelengths`` does not contain one value per band.
    """
    plot_title = "Interactive HSI Explorer"

    if isinstance(hsi_data, str):
        hsi_filepath = hsi_data
        print(f"\n--- Loading HSI data from: {hsi_filepath} ---")
        plot_title = f"Interactive HSI Explorer: {os.path.basename(hsi_filepath)}"
        # Compare case-insensitively so that e.g. "CUBE.MAT" is accepted too.
        lowered_path = hsi_filepath.lower()
        if lowered_path.endswith(".npy"):
            hsi_cube = np.load(hsi_filepath)
        elif lowered_path.endswith(".mat"):
            hsi_cube = load_mat_data(hsi_filepath, variable_name=variable_name)
        else:
            raise ValueError("File path must end in .npy or .mat")
    elif isinstance(hsi_data, np.ndarray):
        print("\n--- Using pre-loaded HSI data ---")
        hsi_cube = hsi_data
    else:
        raise TypeError("hsi_data must be a file path (str) or a NumPy array.")

    if hsi_cube.ndim != 3:
        raise ValueError(
            f"Data must be a 3D cube (height, width, bands). Got {hsi_cube.ndim} dimensions."
        )

    height, width, n_bands = hsi_cube.shape
    print(
        f"Data loaded successfully. Shape: (Height: {height}, Width: {width}, Bands: {n_bands})"
    )

    # Resolve the x-axis before any window is created: a length mismatch would
    # otherwise only show up as an exception inside the click callbacks.
    if wavelengths is not None:
        x_axis = np.ravel(wavelengths)
        if x_axis.size != n_bands:
            raise ValueError(
                f"wavelengths must contain one value per band ({n_bands}); "
                f"got {x_axis.size}."
            )
        x_label = "Wavelength"
    else:
        x_axis = np.arange(n_bands)
        x_label = "Band Number"

    # Plotting Setup
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, (ax_img, ax_spec) = plt.subplots(1, 2, figsize=(15, 7))
    fig.suptitle(plot_title, fontsize=16)

    # Display the reference band image
    ref_image = hsi_cube[:, :, reference_band]
    ax_img.imshow(ref_image, cmap="jet", aspect="auto")
    ax_img.set_title(f"Reference Image (Band {reference_band})")
    ax_img.set_xlabel("Width")
    ax_img.set_ylabel("Height")
    # Live coordinates, using the same rounding as the pixel selection below.
    ax_img.format_coord = (
        lambda x, y: f"x={_to_pixel_index(x)}, y={_to_pixel_index(y)}"
    )

    def reset_spectrum_axes():
        """Apply the idle title, labels and grid to the spectrum panel."""
        ax_spec.set_title(
            'Click/Drag on image. Right-click here to clear. Press "s" to save.'
        )
        ax_spec.set_xlabel(x_label)
        ax_spec.set_ylabel("Intensity")
        ax_spec.grid(True)

    # Prepare the spectrum plot
    reset_spectrum_axes()

    # Interactivity State and Callbacks
    # Store the state of the drag action
    drag_state = {"start_xy": None, "rect": None}

    def discard_rect():
        """Remove the selection rectangle from the image, if one is present."""
        rect = drag_state["rect"]
        if rect is not None and rect in ax_img.patches:
            rect.remove()
        drag_state["rect"] = None

    def on_press(event):
        """Callback for mouse button press: start a drag on the image."""
        if event.inaxes != ax_img or event.button != 1:
            return
        # A previous drag may never have received its release event; do not
        # leave its rectangle behind.
        discard_rect()
        drag_state["start_xy"] = (event.xdata, event.ydata)
        drag_state["rect"] = Rectangle(
            (event.xdata, event.ydata),
            0,
            0,
            facecolor="red",
            edgecolor="black",
            alpha=0.2,
            fill=True,
        )
        ax_img.add_patch(drag_state["rect"])
        fig.canvas.draw_idle()

    def on_motion(event):
        """Callback for mouse motion (dragging): resize the selection box."""
        if drag_state["start_xy"] is None or event.inaxes != ax_img:
            return
        x0, y0 = drag_state["start_xy"]
        drag_state["rect"].set_width(event.xdata - x0)
        drag_state["rect"].set_height(event.ydata - y0)
        drag_state["rect"].set_xy((x0, y0))
        fig.canvas.draw_idle()

    def on_release(event):
        """Callback for mouse button release: plot the selected pixel or ROI."""
        if drag_state["start_xy"] is None or event.button != 1:
            return

        # Remove the visual rectangle and reset the drag state
        discard_rect()
        x1, y1 = drag_state["start_xy"]
        drag_state["start_xy"] = None

        # A release outside the image (including over the spectrum panel,
        # whose xdata/ydata are in spectrum units) cannot be mapped to pixels.
        if event.inaxes != ax_img or event.xdata is None or event.ydata is None:
            fig.canvas.draw_idle()
            return

        x1, y1 = _to_pixel_index(x1), _to_pixel_index(y1)
        x2, y2 = _to_pixel_index(event.xdata), _to_pixel_index(event.ydata)

        if x1 == x2 and y1 == y2:  # Single pixel click
            if not (0 <= x1 < width and 0 <= y1 < height):
                fig.canvas.draw_idle()
                return
            spectrum = hsi_cube[y1, x1, :]
            label = f"Pixel ({x1}, {y1})"
        else:  # ROI selection
            x_start, x_end = sorted([x1, x2])
            y_start, y_end = sorted([y1, y2])
            x_start, x_end = max(0, x_start), min(width - 1, x_end)
            y_start, y_end = max(0, y_start), min(height - 1, y_end)
            if x_start > x_end or y_start > y_end:
                # Selection lies entirely outside the cube; nothing to average.
                fig.canvas.draw_idle()
                return
            roi = hsi_cube[y_start : y_end + 1, x_start : x_end + 1, :]
            spectrum = np.mean(roi, axis=(0, 1))
            label = f"ROI Avg ({x_start}:{x_end}, {y_start}:{y_end})"

        ax_spec.plot(x_axis, spectrum, label=label)
        ax_spec.set_title("Spectra Comparison")
        ax_spec.legend()
        fig.canvas.draw_idle()

    def on_key_press(event):
        """Callback for saving the plotted spectra to a timestamped CSV file."""
        # NOTE(review): matplotlib's default keymap also binds "s" to its own
        # save-figure action, so that dialog may open as well - confirm whether
        # "s" should be removed from rcParams["keymap.save"].
        if event.key != "s":
            return
        if not ax_spec.lines:
            print("No spectra to save.")
            return
        os.makedirs(save_dir, exist_ok=True)
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_filename = os.path.join(save_dir, f"spectra_{timestamp}.csv")
        # The labels contain commas ("Pixel (3, 4)"), so they must be quoted
        # or the header would have more columns than the data rows.
        header = ",".join(
            ["wavelength"] + [_csv_field(line.get_label()) for line in ax_spec.lines]
        )
        columns = [x_axis] + [line.get_ydata() for line in ax_spec.lines]
        np.savetxt(
            save_filename,
            np.vstack(columns).T,
            delimiter=",",
            header=header,
            comments="",
        )
        print(f"\nSpectra saved to: {save_filename}")

    def on_right_click(event):
        """Callback for clearing the spectrum plot."""
        if event.inaxes == ax_spec and event.button == 3:
            ax_spec.clear()
            reset_spectrum_axes()
            fig.canvas.draw_idle()

    # Connect Events
    fig.canvas.mpl_connect("key_press_event", on_key_press)
    fig.canvas.mpl_connect("button_press_event", on_press)
    fig.canvas.mpl_connect("motion_notify_event", on_motion)
    fig.canvas.mpl_connect("button_release_event", on_release)
    fig.canvas.mpl_connect("button_press_event", on_right_click)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    print("--- Interactive plot is now active. Close the window to continue. ---")
    plt.show()


if __name__ == "__main__":
    # Placeholder: replace with the path to a .npy or .mat HSI cube.
    path = r"path to your file"

    # Number of spectral bands in the data; the wavelength axis below is
    # spread evenly over 400-1000 across that many bands.
    bands = 993
    sample_wavelengths = np.linspace(400, 1000, num=bands)

    interactive_hsi_plotter(path, 450, wavelengths=sample_wavelengths)

This script provides a GUI built with OpenCV for manually
selecting and extracting fixed-size Regions of Interest (ROIs) from Hyperspectral
Imaging (HSI) data files. It is designed to process a directory of HSI files,
allowing the user to visually inspect each image and save multiple rectangular ROIs
of a predefined size.

HOW IT WORKS:
1.  Scanning: 
    The script recursively scans a specified input directory (`INPUT_DIR`)
    for HSI files with `.mat` or `.npy` extensions.

2.  Visualization: 
    For each HSI file found, it generates a single-channel greyscale
    image for display purposes. This greyscale image is created by averaging all the
    spectral bands of the HSI cube.

3.  Interactive ROI Selection: It opens an OpenCV window displaying the greyscale image.
     - A green rectangle, representing the ROI to be extracted, follows the user's mouse
       cursor. The size of this rectangle is defined by `ROI_HEIGHT` and `ROI_WIDTH`.
     - The box is centered on the cursor and is constrained to stay within the image
       boundaries.

4.  Saving ROIs:
     - When the user presses the SPACEBAR, the script extracts the corresponding region
       from the *original, full-depth HSI data cube* (not the greyscale preview).
     - This extracted ROI is then saved as a new file in the specified output directory
       (`OUTPUT_DIR`).
     - The script preserves the subfolder structure from the input directory in the
       output directory.
     - Saved ROIs are marked with a persistent blue rectangle on the image for the
       current session, so the user can see which regions have already been saved
       (re-selecting the same region is not blocked).
       
5.  File Handling:
     - The script automatically handles both `.mat` (MATLAB) and `.npy` (NumPy) file formats.
     - For `.mat` files, it intelligently tries to find the variable containing the HSI
       data cube by looking for the largest 3D NumPy array in the file.
     - When saving, it preserves the original file format.

User Control:
 - Mouse Movement: Positions the green selection box.
 - SPACEBAR: Saves the HSI data within the current green box to a new file.
 - 'n' key: Skips the current image and moves to the next one.
 - 'q' key: Quits the program.

Configuration:
The user needs to set the `INPUT_DIR`, `OUTPUT_DIR`, `ROI_HEIGHT`, and `ROI_WIDTH` constants at the top of the script before running it.



In [None]:
import os
import cv2
import numpy as np
import scipy.io
from tqdm import tqdm

# Configuration
# 1. Directory containing your HSI files (.mat / .npy). It is scanned
#    recursively, so the files may live in subfolders.
INPUT_DIR = "Original Data"

# 2. Directory where the new, fixed-size ROIs will be saved. The subfolder
#    layout found under INPUT_DIR is recreated here.
OUTPUT_DIR = "Output"

# 3. The dimensions (height, width), in pixels, of the rectangular ROI.
ROI_HEIGHT = 20
ROI_WIDTH = 20

# Last (x, y) cursor position reported by the OpenCV mouse callback; read by
# the display loop to place the selection box.
mouse_coords = (0, 0)


def find_hsi_variable_in_mat(mat_contents):
    """Pick the largest 3D NumPy array out of a loaded .mat dictionary.

    Keys starting with "__" are ignored. Returns a ``(name, array)`` pair, or
    ``(None, None)`` when the dictionary holds no 3D array. When several
    arrays share the largest size, the first one encountered wins.
    """
    cubes = [
        (name, array)
        for name, array in mat_contents.items()
        if not name.startswith("__")
        and isinstance(array, np.ndarray)
        and array.ndim == 3
    ]
    if not cubes:
        return None, None
    # max() keeps the first of several equally large candidates.
    return max(cubes, key=lambda entry: entry[1].size)


def load_hsi_file(file_path):
    """Read an HSI cube from a .mat or .npy file (extension matched case-insensitively).

    Returns ``(data, mat_var_name)``. ``mat_var_name`` is the variable chosen
    inside a .mat file and ``None`` for .npy files. Any other extension
    yields ``(None, None)``.
    """
    lowered = file_path.lower()

    if lowered.endswith(".npy"):
        # A .npy file holds a single array and has no variable name.
        return np.load(file_path), None

    if lowered.endswith(".mat"):
        # Read the whole file, then let the helper pick the most likely cube.
        var_name, cube = find_hsi_variable_in_mat(scipy.io.loadmat(file_path))
        return cube, var_name

    # Unsupported format.
    return None, None


def save_hsi_file(file_path, hsi_data, mat_var_name=None):
    """Save HSI data to a .mat or .npy file, chosen by the file extension.

    Args:
        file_path (str): Destination path; must end in .mat or .npy
            (matched case-insensitively).
        hsi_data (np.ndarray): The array to write.
        mat_var_name (str, optional): Variable name to use inside a .mat
            file. Falls back to "hsi_data" when None or empty, because
            savemat needs a string key. Ignored for .npy files.

    Raises:
        ValueError: If the extension is neither .mat nor .npy. Previously
            such a call returned without writing anything.
    """
    lowered = file_path.lower()
    if lowered.endswith(".mat"):
        # The path already carries its extension, so stop scipy from
        # appending another ".mat" (relevant for e.g. an upper-case ".MAT").
        scipy.io.savemat(
            file_path, {mat_var_name or "hsi_data": hsi_data}, appendmat=False
        )
    elif lowered.endswith(".npy"):
        np.save(file_path, hsi_data)
    else:
        raise ValueError(
            f"Unsupported output format for '{file_path}': expected .mat or .npy"
        )


def create_greyscale_from_hsi(hsi_data):
    """Build a displayable single-channel uint8 image from an HSI cube.

    The cube is averaged along axis 2, min-max normalised to the 0-255 range
    with OpenCV, and cast to ``np.uint8``.
    """
    # Collapse the third axis into one mean value per pixel.
    band_average = np.mean(hsi_data, axis=2)
    # Stretch the averaged values over the full 0-255 display range.
    stretched = cv2.normalize(
        band_average, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX
    )
    # Cast to 8-bit unsigned integers for display.
    return stretched.astype(np.uint8)


def mouse_move_callback(event, x, y, flags, param):
    """OpenCV mouse callback that records the latest cursor position.

    Only mouse-move events are acted on; ``flags`` and ``param`` are accepted
    because the callback signature requires them, but they are not used.
    """
    global mouse_coords
    if event != cv2.EVENT_MOUSEMOVE:
        return
    # Store the position for the display loop to read.
    mouse_coords = (x, y)


def _find_hsi_files(input_dir):
    """Return the sorted paths of every .mat/.npy file below *input_dir*."""
    # Sorted so that files are presented in the same order on every run.
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(input_dir)
        for name in names
        if name.lower().endswith((".mat", ".npy"))
    )


def _roi_origin(cursor_xy, img_width, img_height):
    """Return the top-left (x, y) of the ROI centred on the cursor, kept inside the image.

    Requires the image to be at least ROI_WIDTH x ROI_HEIGHT pixels.
    """
    # Clamp the top-left corner rather than the centre: clamping the centre to
    # ``img - half - 1`` left the last row/column unreachable for even sizes.
    x = min(max(cursor_xy[0] - ROI_WIDTH // 2, 0), img_width - ROI_WIDTH)
    y = min(max(cursor_xy[1] - ROI_HEIGHT // 2, 0), img_height - ROI_HEIGHT)
    # Plain Python ints, as expected for slicing and for OpenCV point tuples.
    return int(x), int(y)


def process_files_with_sliding_window():
    """
    Main function to loop through files and allow user to select ROIs.

    For every .mat/.npy file found under INPUT_DIR an OpenCV window shows a
    greyscale preview with a green ROI_WIDTH x ROI_HEIGHT box following the
    mouse. SPACEBAR crops that box from the full HSI cube and saves it below
    OUTPUT_DIR (mirroring the input subfolders), 'n' moves on to the next
    file and 'q' quits. Files that cannot be loaded, or that are smaller than
    the ROI, are skipped with a warning; any other error while handling a
    file is printed and processing continues with the next file.

    NOTE(review): ROI numbering restarts at 1 for every file on every run, so
    a later session writes to the same output names as an earlier one -
    confirm this is intended.
    """
    # 1. Find all files recursively
    all_files = _find_hsi_files(INPUT_DIR)

    # Check if any files were found.
    if not all_files:
        print(f"Error: No .mat or .npy files found in '{INPUT_DIR}' or its subfolders.")
        return

    # Print instructions for the user.
    print("--- Starting Fixed-Size ROI Selector ---")
    print(f"ROI size is set to: {ROI_WIDTH}x{ROI_HEIGHT} pixels (Width x Height).")
    print("Instructions:")
    print(" - Move your mouse to position the selection box.")
    print(" - Press SPACEBAR to SAVE the ROI inside the box (it will turn blue).")
    print(" - Press 'n' to SKIP to the next image.")
    print(" - Press 'q' to QUIT the program.")
    print("-------------------------------------\n")

    # Loop through each found file, with a progress bar from tqdm.
    for file_path in tqdm(all_files, desc="Processing files"):
        try:
            # Load the HSI data from the current file.
            hsi_data, mat_var_name = load_hsi_file(file_path)
            # If loading failed, print a warning and skip to the next file.
            if hsi_data is None:
                print(
                    f"Warning: Could not load {os.path.basename(file_path)}. Skipping."
                )
                continue

            # Get the dimensions (height, width, number of bands) of the HSI cube.
            img_height, img_width, _ = hsi_data.shape

            # An ROI larger than the image cannot be placed; the clamping would
            # otherwise yield negative slice starts and a wrong crop.
            if img_height < ROI_HEIGHT or img_width < ROI_WIDTH:
                print(
                    f"Warning: {os.path.basename(file_path)} ({img_width}x{img_height}) "
                    f"is smaller than the {ROI_WIDTH}x{ROI_HEIGHT} ROI. Skipping."
                )
                continue

            # Create a displayable 8-bit greyscale image and convert it to
            # 3-channel BGR so coloured rectangles can be drawn on it.
            base_img = create_greyscale_from_hsi(hsi_data)
            base_img_bgr = cv2.cvtColor(base_img, cv2.COLOR_GRAY2BGR)

            # Setup OpenCV window and link the mouse callback to it.
            window_name = f"ROI Selector - {os.path.basename(file_path)}"
            cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)
            cv2.setMouseCallback(window_name, mouse_move_callback)

            # Number of ROIs saved from this image, and their crop bounds
            # (start_x, start_y, end_x, end_y) with exclusive ends.
            roi_counter = 0
            saved_rois_for_current_image = []

            # Main interactive loop for the current image
            while True:
                # Draw on a fresh copy each frame so old boxes do not linger.
                display_img = base_img_bgr.copy()

                # Previously saved ROIs in blue. Crop ends are exclusive, so
                # the outline's far corner is the last pixel actually saved.
                for sx, sy, ex, ey in saved_rois_for_current_image:
                    cv2.rectangle(
                        display_img, (sx, sy), (ex - 1, ey - 1), (255, 0, 0), 2
                    )  # Blue color, 2px thickness

                # Current ROI box, centred on the cursor and kept in bounds.
                start_x, start_y = _roi_origin(mouse_coords, img_width, img_height)
                end_x = start_x + ROI_WIDTH
                end_y = start_y + ROI_HEIGHT

                # Draw the active ROI selection box in green.
                cv2.rectangle(
                    display_img,
                    (start_x, start_y),
                    (end_x - 1, end_y - 1),
                    (0, 255, 0),
                    2,
                )  # Green color, 2px thickness

                # Display the image and poll the keyboard for 1 millisecond.
                cv2.imshow(window_name, display_img)
                key = cv2.waitKey(1) & 0xFF

                # Handle user input
                if key == ord("q"):
                    print("\nQuitting program.")
                    cv2.destroyAllWindows()
                    return

                elif key == ord("n"):
                    break

                elif key == 32:  # SPACEBAR
                    roi_counter += 1

                    # Crop the region from the original, full-depth HSI data cube.
                    cropped_hsi = hsi_data[start_y:end_y, start_x:end_x, :]

                    # Remember the box so it is redrawn in blue.
                    saved_rois_for_current_image.append(
                        (start_x, start_y, end_x, end_y)
                    )

                    # Mirror the file's subfolder (relative to INPUT_DIR)
                    # below OUTPUT_DIR, creating it if needed.
                    relative_path = os.path.relpath(
                        os.path.dirname(file_path), INPUT_DIR
                    )
                    final_output_dir = os.path.join(OUTPUT_DIR, relative_path)
                    os.makedirs(final_output_dir, exist_ok=True)

                    # "<original name>_roi_<n><original extension>"
                    base_filename = os.path.splitext(os.path.basename(file_path))[0]
                    file_ext = os.path.splitext(file_path)[1]
                    output_filename = f"{base_filename}_roi_{roi_counter}{file_ext}"
                    output_path = os.path.join(final_output_dir, output_filename)

                    # Save the cropped HSI data to the new file.
                    save_hsi_file(output_path, cropped_hsi, mat_var_name)
                    print(
                        f"  -> Saved ROI #{roi_counter} to {os.path.join(os.path.basename(final_output_dir), output_filename)}"
                    )

            # After breaking the loop (by pressing 'n'), destroy the current image window.
            cv2.destroyWindow(window_name)

        # Catch any exceptions that occur during file processing.
        except Exception as e:
            print(
                f"\nAn error occurred while processing {os.path.basename(file_path)}: {e}"
            )
            # Clean up any open windows in case of an error.
            cv2.destroyAllWindows()
            # Continue to the next file in the main loop.
            continue

    print("\n--- All files processed. ---")


if __name__ == "__main__":
    # Only start the selector when the configured input directory exists.
    if os.path.exists(INPUT_DIR):
        process_files_with_sliding_window()
    else:
        print(
            f"Input directory '{INPUT_DIR}' not found. Please create it and add your files."
        )


This script is designed to train and evaluate a 2D Convolutional Neural Network (CNN) for
classifying HSI data. It automates the entire pipeline, from
loading pre-processed data to performing robust evaluation using stratified k-fold
cross-validation.

HOW IT WORKS:
1.  Data Loading: 
    The script expects a main data folder (`main_data_folder`) that contains
    subfolders for each class. Each subfolder should contain the pre-processed HSI samples
    (e.g., fixed-size ROIs) as `.npy` files. It loads all these samples into memory.

2.  Preprocessing:
     - The loaded data and corresponding labels are converted into NumPy arrays.
     - The labels (class names) are numerically encoded (e.g., 'ClassA' -> 0) and then
       converted into a one-hot categorical format suitable for the model's output layer.
     - The script also truncates the spectral bands to a specific number (985 in this case).

3.  Model Definition:
     - A user-defined 2D-CNN model is created using the Keras Functional API. This
       architecture can be easily modified by the user to experiment with different layers.
     - The architecture uses Depthwise Separable Convolutions. This is a highly efficient
       method for HSI data that works in two steps:
       a) The `DepthwiseConv2D` layer first processes each spectral band (channel)
          independently, learning spatial features like edges and textures within that band.
       b) The following 1x1 `Conv2D` (Pointwise Convolution) then combines the outputs from
          all bands, allowing the model to learn the crucial spectral relationships and
          patterns across the entire spectrum.
          
       This separation of concerns is computationally cheaper and often more effective than
       a standard convolution for high-dimensional data.
     - The model includes standard components like Batch Normalization, ReLU activation,
       and MaxPooling.
     - It ends with a series of Dense (fully connected) layers for classification,
       using L2 regularization to prevent overfitting.
     - The final output layer uses a 'softmax' activation function to produce class
       probabilities.
     - The model is compiled with the Adam optimizer and 'categorical_crossentropy' loss,
       which is standard for multi-class classification.

4.  Cross-Validation & Training:
     - To ensure the model's performance is robust and not dependent on a single random
       train-test split, the script uses Stratified K-Fold cross-validation. This technique
       divides the data into 'k' folds (sets), ensuring that each fold has the same
       proportion of class labels as the original dataset.
     - The script iterates 'k' times. In each iteration (fold), it uses one fold as the
       test set and the remaining k-1 folds as the training set.
     - The model is trained on the training set and evaluated on the test set for each fold.
     - The model from each fold is saved for later use.

5.  Evaluation:
     - After all folds are completed, the script calculates the average accuracy and
       standard deviation across all folds, providing a reliable measure of the model's
       overall performance.
     - It also generates a comprehensive classification report (including precision, recall,
       and F1-score for each class) based on the combined predictions from all folds.

Configuration:
The user needs to set the `main_data_folder`, `n_splits`, `epochs`, and `batch_size` variables at the top of the script before running it.


In [None]:
import os
import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.layers import (
    Conv2D,
    DepthwiseConv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Input,
    BatchNormalization,
    ReLU,
)
from keras.regularizers import l2
from keras.optimizers import Adam
from keras.utils import to_categorical
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

# Configuration
# 1. Path to the folder of small, pre-resized data.
# This folder should contain subdirectories, where each subdirectory name is a class label.
main_data_folder = "fixed_size_rois"

# 2. Set the training parameters
n_splits = 5  # Number of stratified cross-validation folds.
epochs = 100  # Training epochs per fold.
batch_size = 8  # Samples per gradient update.


# A safety check to ensure the user has updated the default path.
if (
    not os.path.exists(main_data_folder)
    or main_data_folder == "path/to/your/main_data_folder"
):
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print("!!! IMPORTANT: Please update the main_data_folder variable !!!")
    print("!!! with the correct path to your data before running.   !!!")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
else:
    print("--- Starting HSI Model Training and Evaluation ---")

    # 1. Load All Pre-processed Data and Labels
    print(f"\n[Phase 1] Loading pre-processed data from '{main_data_folder}'...")

    # NOTE: fatal errors below use `raise SystemExit(1)` rather than `exit()`.
    # `exit()` is an interactive helper that terminates a notebook kernel and
    # reports success (status 0); SystemExit stops the script with a failure
    # status and is not swallowed by the `except FileNotFoundError` below.
    try:
        # Automatically find class names by listing the subdirectories in the main data folder.
        class_names = sorted(
            d
            for d in os.listdir(main_data_folder)
            if os.path.isdir(os.path.join(main_data_folder, d))
        )
        # Abort if no class subfolders are found.
        if not class_names:
            print(f"Error: No subfolders found in '{main_data_folder}'.")
            raise SystemExit(1)

        print(f"Found {len(class_names)} classes: {class_names}")

        # Parallel lists: all_data[i] is a loaded sample, all_labels[i] its class name.
        all_data = []
        all_labels = []

        # Iterate through each class folder.
        for class_name in class_names:
            class_folder = os.path.join(main_data_folder, class_name)
            # Find all .npy files within the class folder.
            npy_files = [f for f in os.listdir(class_folder) if f.endswith(".npy")]
            # Load each .npy file; a file that fails to load is reported and skipped.
            for file_name in npy_files:
                file_path = os.path.join(class_folder, file_name)
                try:
                    hsi_data = np.load(file_path)
                except Exception as e:
                    print(f"Warning: Could not load file {file_path}. Error: {e}")
                else:
                    all_data.append(hsi_data)
                    all_labels.append(class_name)

        # Abort if no data could be loaded at all.
        if not all_data:
            print("Error: Failed to load any data. Aborting.")
            raise SystemExit(1)

    except FileNotFoundError:
        print(f"Error: The main data folder was not found at '{main_data_folder}'.")
        raise SystemExit(1)

    print(f"Successfully loaded {len(all_data)} total samples into memory.")

    # 2. Preprocess Data
    print("\n[Phase 2] Preprocessing data...")

    # All samples must share one 3-D shape, otherwise they cannot be stacked
    # into the single 4-D array that is sliced and fed to the model below.
    sample_shapes = {np.shape(sample) for sample in all_data}
    if len(sample_shapes) != 1 or len(next(iter(sample_shapes))) != 3:
        print(
            "Error: All samples must be 3-D arrays of identical shape, "
            f"but found shapes: {sorted(sample_shapes)}."
        )
        raise SystemExit(1)

    # Convert the lists of data and labels into NumPy arrays for efficient processing.
    X = np.array(all_data)
    y = np.array(all_labels)
    # Truncate the spectral dimension to the first 985 bands.
    X = X[:, :, :, :985]

    # Label Encoding
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    # Derive the class list from the labels that were actually loaded. A class
    # folder without any loadable sample never reaches the encoder, so using
    # the folder listing here would give the softmax layer and the final
    # report more classes than the one-hot targets contain.
    encoded_class_names = [str(name) for name in label_encoder.classes_]
    num_classes = len(encoded_class_names)
    y_categorical = to_categorical(y_encoded, num_classes=num_classes)

    classes_without_samples = [
        name for name in class_names if name not in encoded_class_names
    ]
    if classes_without_samples:
        print(
            "Warning: No samples were loaded for these classes; they are "
            f"excluded from training: {classes_without_samples}"
        )

    input_shape = X.shape[1:]
    print(f"Input shape for model: {input_shape}")
    print(f"Number of classes: {num_classes}")

    # 3. Define the 2D-CNN Model
    def create_2d_cnn_model(input_shape, num_classes):
        """
        Defines and compiles the 2D-CNN model architecture.

        Args:
            input_shape (tuple): Shape of one input sample, without the batch axis.
            num_classes (int): Number of units in the final softmax layer.

        Returns:
            A compiled Keras Model (Adam optimizer, categorical cross-entropy
            loss, accuracy metric).
        """
        inputs = Input(shape=input_shape)

        # A Depthwise Convolution applies a single filter per input channel (band).
        x = DepthwiseConv2D(kernel_size=(3, 3), padding="same")(inputs)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        # A 1x1 Convolution (Pointwise Convolution) is used to combine the features
        # across all the bands.
        x = Conv2D(32, kernel_size=(1, 1), padding="same")(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        # MaxPooling reduces the spatial dimensions (height, width) of the feature maps.
        x = MaxPooling2D(pool_size=(2, 2))(x)

        # Classification head: three L2-regularized Dense layers, then softmax.
        x = Flatten()(x)
        x = Dense(64, activation="relu", kernel_regularizer=l2(0.001))(x)
        x = Dense(128, activation="relu", kernel_regularizer=l2(0.001))(x)
        x = Dense(256, activation="relu", kernel_regularizer=l2(0.001))(x)
        outputs = Dense(num_classes, activation="softmax")(x)

        model = Model(inputs=inputs, outputs=outputs)
        model.compile(
            optimizer=Adam(learning_rate=0.00001),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    # 4. Cross-Validation and Training
    print(f"\n[Phase 3] Starting {n_splits}-Fold Cross-Validation...")

    # Initialize StratifiedKFold to ensure class distribution is preserved in each fold.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Lists to store results from each fold.
    fold_accuracies = []
    all_true_labels = []
    all_pred_labels = []

    # Loop through each fold created by StratifiedKFold.
    for fold, (train_index, test_index) in enumerate(skf.split(X, y_encoded)):
        print(f"\n--- Processing Fold {fold + 1}/{n_splits} ---")

        # Split the data into training and testing sets for the current fold.
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y_categorical[train_index], y_categorical[test_index]
        # Integer labels of the test samples, kept for the final report.
        y_test_encoded = y_encoded[test_index]

        # Create a new, fresh model for each fold.
        model = create_2d_cnn_model(input_shape, num_classes)

        # Train the model on the training data for the current fold.
        # The fold's test set is passed as validation data, so its metrics are
        # reported after every epoch.
        model.fit(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, y_test),
            verbose=1,
        )

        # Evaluate the final trained model on the test set for this fold.
        _, accuracy = model.evaluate(X_test, y_test, verbose=0)
        fold_accuracies.append(accuracy)
        print(f"Fold {fold + 1} Accuracy: {accuracy:.4f}")

        # Save the trained model from the fold.
        model_save_path = f"trained_model_fold_{fold+1}.keras"
        print(f"Saving model from fold {fold+1} to: {model_save_path}")
        model.save(model_save_path)

        # Collect this fold's predictions for the combined classification report.
        y_pred_probs = model.predict(X_test)
        y_pred_encoded = np.argmax(y_pred_probs, axis=1)
        all_true_labels.extend(y_test_encoded)
        all_pred_labels.extend(y_pred_encoded)

    # 5. Final Evaluation
    print("\n\n[Phase 4] --- Final Evaluation Results ---")

    mean_accuracy = np.mean(fold_accuracies)
    std_accuracy = np.std(fold_accuracies)
    print(
        f"\nAverage Cross-Validation Accuracy: {mean_accuracy:.4f} (+/- {std_accuracy:.4f})"
    )

    print("\nOverall Classification Report:")
    # `labels` pins the report rows to the encoder's class indices so that
    # `target_names` always lines up with them.
    report = classification_report(
        all_true_labels,
        all_pred_labels,
        labels=list(range(num_classes)),
        target_names=encoded_class_names,
    )
    print(report)

This script performs a post-training analysis on a set of HSI classification models
(trained via cross-validation) to understand which spectral features are most important
for identifying each class. It uses a technique called "saliency mapping" to visualize
how the model makes a decision. The final output is a single plot
comparing the average "salient" spectral signatures for each class, averaged across all
cross-validation folds for a robust result.

HOW IT WORKS:
1.  Loading Models & Data:
    The script loads all the trained `.keras` models from the
    specified folder. It also prepares to load the original HSI data samples for analysis.

2.  Iteration Across Folds & Classes:
     - The outer loop iterates through each saved model (representing each fold of CV).
     - The inner loop iterates through each class defined by the user.

3.  Saliency Map Generation: For each class, the script processes every
     corresponding HSI sample to determine its "salient spectrum".
     a) A saliency map is generated for the sample. This map is a single-channel
        (greyscale) image with the same height and width as the input, where brighter
        pixels indicate locations that were more influential in the model's score for
        the sample's true class.
     b) A threshold is calculated (e.g., the 99th percentile) to identify only the most
        significant pixels from the saliency map.
     c) This threshold creates a mask, highlighting the most important spatial locations.

4.  Salient Spectrum Extraction:
     - The mask is applied to the original HSI data cube.
     - The script then calculates the average spectral signature of *only* the pixels
       that were identified as significant by the mask. This gives the "salient spectrum"
       for that one sample.

5.  Multi-Level Averaging:
     - The salient spectrums of all samples within a class are averaged together to get a
       single, representative spectrum for that class *for that specific model/fold*.
     - After processing all models, these per-fold class averages are themselves averaged
       together. This final "grand average" is a robust representation of the important
       spectral features for each class, as it's validated across all training folds.

6.  Plotting: The final averaged salient spectrum for each class is plotted on a single
     graph against a calculated wavelength axis (400-1000 nm). This allows for direct
     visual comparison of the key spectral features that differentiate the classes.

Configuration:
The user must set the paths, the list of class names (in the correct order), and the
significance percentile in the "Configuration" section before running the script.

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from tf_keras_vis.utils.scores import CategoricalScore
from tqdm import tqdm

# Configuration
# 1. Path to the folder containing all your saved .keras models from the CV folds.
#    Built with os.path.join so the separator is correct on every platform
#    (a literal backslash is only a path separator on Windows).
MODELS_FOLDER_PATH = os.path.join("Results", "Full Class")

# 2. Path to the main data folder containing the class subfolders.
MAIN_DATA_FOLDER = "fixed_size_rois"

# 3. Path for the final output plot.
OUTPUT_PLOT_PATH = "cross_validated_class_spectrum_overlay.png"

# 4. The list of your class names, in the exact same sorted order as in your training script.
#    The position of a name in this list is used as the class index passed to the
#    saliency tool, and the name itself is used as the class subfolder name.
CLASS_NAMES = [
    "Class 1",
    "Class 2",
    "Class 3",
    "Class 4",
    "Class 5",
    "Class 6",
    "Class 7",
]

# 5. Set the percentile for pixels to be considered "significant".
#    For example, 99 means only pixels in the top 1% of saliency scores are used.
SIGNIFICANT_PIXEL_PERCENTILE = 99


def analyze_and_average_spectrums_across_folds():
    """
    Generates a class-average spectrum plot by averaging the results
    from all trained models in the cross-validation folds.

    Reads the module-level configuration (MODELS_FOLDER_PATH, MAIN_DATA_FOLDER,
    CLASS_NAMES, SIGNIFICANT_PIXEL_PERCENTILE, OUTPUT_PLOT_PATH). For every
    `.keras` model and every class, each `.npy` sample is reduced to the mean
    spectrum of its most salient pixels; these are averaged per class within a
    fold and then across folds. The result is saved to OUTPUT_PLOT_PATH and shown.

    Returns:
        None. Problems are reported with printed messages; the function returns
        early when the folders are missing, no model files are found, or no
        spectrum could be computed.
    """
    print("--- Starting Cross-Validated Global Salient Spectrum Analysis ---")

    # 1. Validate Configuration
    # Check if the necessary input folders exist before starting.
    if not os.path.exists(MODELS_FOLDER_PATH):
        print(f"Error: Models folder not found at '{MODELS_FOLDER_PATH}'.")
        return
    if not os.path.exists(MAIN_DATA_FOLDER):
        print(f"Error: Main data folder not found at '{MAIN_DATA_FOLDER}'.")
        return

    # Find all saved model files in the specified folder.
    model_files = sorted(
        f for f in os.listdir(MODELS_FOLDER_PATH) if f.endswith(".keras")
    )
    if not model_files:
        print(f"Error: No .keras model files found in '{MODELS_FOLDER_PATH}'.")
        return

    # 2. Data structure to hold results from all folds:
    #    class name -> list with one class-average spectrum per fold.
    fold_results = {name: [] for name in CLASS_NAMES}

    # 3. Outer Loop: Iterate Through Each Model/Fold
    for model_file in tqdm(model_files, desc="Processing Folds"):
        try:
            # Load the pre-trained Keras model.
            model_path = os.path.join(MODELS_FOLDER_PATH, model_file)
            model = tf.keras.models.load_model(model_path)
            # Prepare the model for saliency analysis by replacing the final activation
            saliency_tool = Saliency(
                model, model_modifier=ReplaceToLinear(), clone=False
            )

            # 4. Inner Loop: Iterate Through Each Class
            for class_index, class_name in enumerate(CLASS_NAMES):
                class_folder_path = os.path.join(MAIN_DATA_FOLDER, class_name)
                if not os.path.exists(class_folder_path):
                    # Surface a likely CLASS_NAMES misconfiguration instead of
                    # skipping the class without any trace.
                    print(
                        f"\nWarning: Class folder not found at '{class_folder_path}'. Skipping."
                    )
                    continue

                npy_files = sorted(
                    f for f in os.listdir(class_folder_path) if f.endswith(".npy")
                )
                if not npy_files:
                    print(
                        f"\nWarning: No .npy files found in '{class_folder_path}'. Skipping."
                    )
                    continue

                # List to hold the average spectrum from each sample of this class.
                sample_average_spectrums = []

                # Process each sample in the class
                for file_name in npy_files:
                    sample_path = os.path.join(class_folder_path, file_name)
                    try:
                        # Load a single HSI data sample.
                        sample_hsi = np.load(sample_path)
                        # Ensure the band count matches what the model was trained on.
                        sample_hsi = sample_hsi[:, :, :985]
                        # Add a batch dimension to the sample for model prediction.
                        sample_batch = np.expand_dims(sample_hsi, axis=0)

                        # Define the target for the saliency tool
                        score = CategoricalScore([class_index])
                        # Generate the saliency map. This calculates the gradient of the output
                        # with respect to the input, indicating pixel importance.
                        saliency_map = saliency_tool(
                            score, sample_batch, smooth_samples=20, smooth_noise=0.05
                        )
                        # Remove the batch dimension from the resulting map.
                        avg_spec_for_sample = _mean_spectrum_of_salient_pixels(
                            sample_hsi, saliency_map[0], SIGNIFICANT_PIXEL_PERCENTILE
                        )
                    except Exception as e:
                        # Best effort: one bad sample must not abort the fold,
                        # but the failure is reported rather than hidden.
                        print(
                            f"\nWarning: Skipping sample '{sample_path}' for model {model_file}: {e}"
                        )
                        continue

                    if avg_spec_for_sample is not None:
                        sample_average_spectrums.append(avg_spec_for_sample)

                # Calculate the average spectrum for this entire class for the current fold.
                if sample_average_spectrums:
                    class_avg_for_fold = np.mean(sample_average_spectrums, axis=0)
                    fold_results[class_name].append(class_avg_for_fold)

        except Exception as e:
            print(f"\nAn error occurred while processing model {model_file}: {e}")
            continue

    # 5. Calculate the Final Grand Average Across All Folds
    # Average the per-fold results to get the final spectrum for each class;
    # classes without any result are left out of the plot.
    final_class_spectrums = {
        class_name: np.mean(spectrum_list, axis=0)
        for class_name, spectrum_list in fold_results.items()
        if spectrum_list
    }

    if not final_class_spectrums:
        print("\nError: No data was processed successfully. Cannot generate plot.")
        return

    # 6. Plot the Final Overlay Graph
    # The wavelength axis is spread evenly over 400-1000 for the band count
    # of the first class spectrum.
    num_bands = next(iter(final_class_spectrums.values())).shape[0]
    wavelengths = np.linspace(400, 1000, num_bands)

    print("\n--- Generating Final Comparison Plot ---")
    fig, ax = plt.subplots(figsize=(15, 8))
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9; it accepts the same (name, lut) arguments.
    colors = plt.get_cmap("tab10", len(final_class_spectrums))

    for i, (class_name, spectrum) in enumerate(final_class_spectrums.items()):
        ax.plot(wavelengths, spectrum, label=class_name, color=colors(i))

    ax.set_title("Cross-Validated Comparison of Average Salient Spectrums", fontsize=16)
    ax.set_xlabel("Wavelength (nm)", fontsize=12)
    ax.set_ylabel("Average Intensity / Reflectance", fontsize=12)
    ax.legend()
    ax.grid(True, linestyle="--", alpha=0.6)

    plt.savefig(OUTPUT_PLOT_PATH, bbox_inches="tight")
    plt.show()

    print(f"\n--- Analysis complete. Final plot saved to '{OUTPUT_PLOT_PATH}' ---")


def _mean_spectrum_of_salient_pixels(sample_hsi, saliency_map_2d, percentile):
    """
    Averages the spectra of the pixels whose saliency reaches a percentile.

    Args:
        sample_hsi (np.ndarray): HSI sample; indexed by the boolean mask built
            from `saliency_map_2d`.
        saliency_map_2d (np.ndarray): Saliency scores for the sample.
            Assumed to match the sample's leading (height, width) axes --
            TODO confirm against the saliency tool's output shape.
        percentile (float): Percentile of the saliency scores used as the
            threshold (e.g. 99 keeps roughly the top 1% of pixels).

    Returns:
        np.ndarray or None: Mean over the selected pixels (axis 0), or None
        when no pixel passes the threshold.
    """
    # Calculate the threshold value for identifying significant pixels.
    threshold = np.percentile(saliency_map_2d, percentile)
    # Create a boolean mask where True indicates a significant pixel.
    significant_mask = saliency_map_2d >= threshold
    if not np.any(significant_mask):
        return None
    # Mean spectrum of only those significant pixels.
    return np.mean(sample_hsi[significant_mask], axis=0)

# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    analyze_and_average_spectrums_across_folds()

This script analyzes a set of trained HSI classification models (those
using a depthwise separable convolution structure) to determine the overall importance
of each spectral band for the classification task. Instead of looking at individual
samples, it inspects the model's internal weights to find global patterns. The final
output is a plot showing the calculated importance score for each spectral band,
highlighting the bands the models collectively found most useful.

HOW IT WORKS:
1.  Loading Models: The script scans a specified folder and loads all trained `.keras`
    models, assuming each model corresponds to a fold from a cross-validation process.

2.  Weight Extraction:
     - For each loaded model, the script searches for the first 1x1 `Conv2D` layer. In a
       depthwise separable convolution block, this "pointwise" layer is responsible for
       combining the features learned from each individual spectral band.
     - The script extracts the weights of this layer. The magnitude (absolute value) of
       these weights can be interpreted as a proxy for the importance the model has
       assigned to each input band. A larger weight suggests the model relies more heavily
       on the features from that corresponding band to make its decisions.
     - It calculates the mean absolute weight for each band across all output filters.

3.  Averaging Across Folds:
     - The band importance vector calculated from each model (fold) is stored.
     - After processing all models, these vectors are averaged together. This step ensures
       the final result is robust and represents a consensus of importance across all
       cross-validation folds, rather than being skewed by a single training run.

4.  Plotting & Reporting:
     - A wavelength axis is generated based on the specified spectral range (e.g., 400-1000 nm).
     - The final average importance scores are plotted against their corresponding wavelengths.
     - The script identifies the top N most important bands, prints their details (band index,
       wavelength, and importance score) to the console, and highlights them on the plot
       with vertical lines and text labels for easy identification.

Configuration:
The user must set the paths, the number of top bands to highlight, and the wavelength
range in the "Configuration" section before running the script.

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configuration
# 1. Path to the folder containing all saved .keras models from the CV folds.
#    Built with os.path.join so the separator is correct on every platform
#    (a literal backslash is only a path separator on Windows).
MODELS_FOLDER_PATH = os.path.join("Results", "4 Class (1,3,5,7)", "depth wise")

# 2. Path for the final output plot.
OUTPUT_PLOT_PATH = "overall_band_importance.png"

# 3. Number of top bands to highlight on the plot.
NUM_TOP_BANDS_TO_HIGHLIGHT = 10

# 4. Define the spectral range of your data.
#    The band indices are mapped linearly onto this range for the plot's x-axis.
START_WAVELENGTH = 400
END_WAVELENGTH = 1000


def analyze_band_importance():
    """
    Loads all trained models from a folder, calculates the average importance
    of each spectral band, and plots the final result against wavelength.

    Reads the module-level configuration (MODELS_FOLDER_PATH, OUTPUT_PLOT_PATH,
    NUM_TOP_BANDS_TO_HIGHLIGHT, START_WAVELENGTH, END_WAVELENGTH). The
    importance of a band is the mean absolute weight that the first 1x1
    Conv2D layer of each model assigns to it, averaged over all models.

    Returns:
        None. Problems are reported with printed messages; the function returns
        early when the folder is missing, contains no `.keras` files, no
        weights could be extracted, or the models disagree on the band count.
    """
    print("--- Starting Overall Band Importance Analysis ---")

    # 1. Find all saved .keras models
    if not os.path.exists(MODELS_FOLDER_PATH):
        print(f"Error: Models folder not found at '{MODELS_FOLDER_PATH}'.")
        return

    model_files = sorted(
        f for f in os.listdir(MODELS_FOLDER_PATH) if f.endswith(".keras")
    )
    if not model_files:
        print(f"Error: No .keras model files found in '{MODELS_FOLDER_PATH}'.")
        return

    print(f"Found {len(model_files)} models to analyze.")

    # List to store the importance vector from each model/fold.
    all_fold_importances = []

    # 2. Loop Through Each Model
    for model_file in tqdm(model_files, desc="Analyzing models"):
        try:
            # Load one of the trained models.
            model_path = os.path.join(MODELS_FOLDER_PATH, model_file)
            model = tf.keras.models.load_model(model_path)

            # Find the target pointwise convolution layer
            target_layer = _find_pointwise_conv_layer(model)
            if target_layer is None:
                print(
                    f"\nWarning: Could not find a pointwise layer in {model_file}. Skipping."
                )
                continue

            # The kernel is always the first weight; taking it by index (rather
            # than unpacking a kernel/bias pair) also works when the layer was
            # built without a bias term.
            kernel = target_layer.get_weights()[0]
            all_fold_importances.append(_band_importance_from_kernel(kernel))

        except Exception as e:
            print(f"\nAn error occurred while processing {model_file}: {e}")
            continue

    if not all_fold_importances:
        print("\nError: Could not extract weights from any model. Aborting.")
        return

    # Vectors of different length cannot be averaged band by band.
    band_counts = sorted({len(importance) for importance in all_fold_importances})
    if len(band_counts) != 1:
        print(
            f"\nError: Models disagree on the number of input bands {band_counts}. Aborting."
        )
        return

    # 3. Calculate the Final Average Importance
    # Average the importance vectors from all folds to get a final, robust result.
    final_average_importance = np.mean(all_fold_importances, axis=0)

    # 4. Create the Wavelength Axis
    num_bands = len(final_average_importance)
    wavelengths = np.linspace(START_WAVELENGTH, END_WAVELENGTH, num_bands)

    print("\n--- Generating Final Band Importance Plot ---")

    # 5. Find and Print the Top N Most Important Bands
    # Get the indices of the bands sorted by importance in descending order.
    top_band_indices = np.argsort(final_average_importance)[::-1][
        :NUM_TOP_BANDS_TO_HIGHLIGHT
    ]

    # Report the number of bands actually listed, which is smaller than the
    # configured N when the model has fewer bands than that.
    print(f"\nTop {len(top_band_indices)} most important bands:")

    for rank, band_index in enumerate(top_band_indices, start=1):
        importance_score = final_average_importance[band_index]
        wavelength = wavelengths[band_index]
        print(
            f"  {rank}. Band {band_index} (~{wavelength:.1f} nm) (Importance: {importance_score:.4f})"
        )

    # 6. Plot the Results with Wavelength Axis
    fig, ax = plt.subplots(figsize=(15, 8))

    ax.plot(
        wavelengths, final_average_importance, color="navy", label="Average Importance"
    )

    # Mark every top band with a dashed vertical line and a wavelength label.
    for band_index in top_band_indices:
        importance_score = final_average_importance[band_index]
        wavelength = wavelengths[band_index]
        ax.axvline(x=wavelength, color="red", linestyle="--", alpha=0.7)
        ax.text(
            wavelength + 5,  # small offset so the label sits beside the line
            importance_score,
            f"{wavelength:.0f} nm",
            color="red",
            va="center",
        )

    ax.set_title(
        "Overall Importance of Each Spectral Band Across All CV Folds", fontsize=16
    )
    ax.set_xlabel("Wavelength (nm)", fontsize=12)
    ax.set_ylabel("Average Absolute Weight (Importance)", fontsize=12)
    ax.grid(True, linestyle="--", alpha=0.6)
    ax.set_xlim(START_WAVELENGTH, END_WAVELENGTH)
    ax.legend()

    plt.savefig(OUTPUT_PLOT_PATH, bbox_inches="tight")
    plt.show()

    print(f"\n--- Analysis complete. Final plot saved to '{OUTPUT_PLOT_PATH}' ---")


def _find_pointwise_conv_layer(model):
    """Returns the first Conv2D layer of `model` with a 1x1 kernel, or None."""
    for layer in model.layers:
        # Some Keras versions implement DepthwiseConv2D as a Conv2D subclass;
        # a depthwise layer is never the pointwise layer looked for here.
        if isinstance(layer, tf.keras.layers.DepthwiseConv2D):
            continue
        if isinstance(layer, tf.keras.layers.Conv2D) and tuple(
            layer.kernel_size
        ) == (1, 1):
            return layer
    return None


def _band_importance_from_kernel(kernel):
    """
    Reduces a 1x1 convolution kernel to one importance value per input band.

    Args:
        kernel (np.ndarray): Kernel of shape (1, 1, num_bands, num_filters).

    Returns:
        np.ndarray: Shape (num_bands,), the mean absolute weight of each band
        across all filters.
    """
    # Index the two size-1 spatial axes explicitly. np.squeeze would also drop
    # a band or filter axis of size 1 and break the per-band mean below.
    per_band_weights = np.asarray(kernel)[0, 0]
    return np.mean(np.abs(per_band_weights), axis=1)


# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    analyze_band_importance()