# Hyperspectral Data Analysis with Convolutional Autoencoders

This notebook demonstrates how to use the hyperspectral convolutional autoencoder modules to:
1. Load and preprocess hyperspectral data
2. Handle variable emission band lengths and masked (NaN) values
3. Train and evaluate a convolutional autoencoder with sigmoid activations
4. Visualize results with various methods

In [None]:
# Import standard libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
from pathlib import Path

# Import our custom modules
from hyperspectral_dataset import HyperspectralDataset, load_hyperspectral_data
from hyperspectral_models import HyperspectralCAEVariable
from hyperspectral_training import train_variable_cae, evaluate_model
from hyperspectral_visualization import (
    visualize_training_loss,
    visualize_emission_spectrum,
    visualize_multiple_spectra,
    visualize_spatial_slice,
    visualize_feature_maps,
    visualize_reconstruction_comparison,
    visualize_multiple_excitations,
    visualize_all_spectral_bands
)

# Fix the RNG state of NumPy and PyTorch so repeated runs give the same results
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Load Hyperspectral Data

First, we load the data from a pickle file created by the HyperspectralDataLoader.

In [None]:
# Location of the input pickle; point this at your own data file if needed
data_path = "../Data/Kiwi Experiment/pickles/masked_KiwiData.pkl"

# Read the file; the result is handed to HyperspectralDataset in the next step
data_dict = load_hyperspectral_data(data_path)

## 2. Create Dataset with Global Normalization and NaN Handling

Create a dataset that will handle NaN values and apply global normalization to the [0,1] range.

In [None]:
# Dataset options:
#   normalize        -> apply global normalization to [0,1]
#   downscale_factor -> 1 keeps full resolution (raise it if memory is tight)
dataset_options = {
    "normalize": True,
    "downscale_factor": 1,
}
dataset = HyperspectralDataset(data_dict, **dataset_options)

# Pull out the processed data and the spatial size it was processed to
all_data = dataset.get_all_data()
spatial_height, spatial_width = dataset.get_spatial_dimensions()

print(f"Processed data dimensions: {spatial_height}x{spatial_width}")

# Report the normalization range, but only when the dataset exposes it
if hasattr(dataset, 'normalization_params'):
    print(
        "Global normalization range: "
        f"[{dataset.normalization_params['min']:.4f}, {dataset.normalization_params['max']:.4f}]"
    )

## 3. Create and Train the Convolutional Autoencoder Model

Now we'll create our model and train it using chunking for memory efficiency.

In [None]:
# Hand the model each excitation's data as a NumPy array
excitation_arrays = {ex: data.numpy() for ex, data in all_data.items()}

# Build the autoencoder
model = HyperspectralCAEVariable(
    excitations_data=excitation_arrays,
    k1=20,                # filters in the first layer
    k3=20,                # filters in the third layer
    filter_size=5,
    sparsity_target=0.1,  # kept low because the activations are sigmoids
    sparsity_weight=1.0,
    dropout_rate=0.5,
    debug=False,
)

# Count every element of every parameter tensor
parameter_count = sum(p.numel() for p in model.parameters())
print(f"Model created with {parameter_count} parameters")

In [None]:
# Switch between training from scratch (True) and restoring saved weights (False)
train_new_model = True

if not train_new_model:
    # Restore weights written by an earlier run
    model_path = "best_hyperspectral_model.pth"
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Model loaded from {model_path}")
else:
    # Hyperparameters for the training run
    training_options = {
        "num_epochs": 20,
        "learning_rate": 0.01,
        "chunk_size": 32,
        "chunk_overlap": 8,
        "early_stopping_patience": 5,  # give up after 5 epochs without improvement
    }
    model, losses = train_variable_cae(model, dataset, **training_options)

    # Persist the weights as they are at the end of training
    # (per the original note, the best model is saved during training itself)
    torch.save(model.state_dict(), "hyperspectral_cae_final_model.pth")
    print("Final model saved to hyperspectral_cae_final_model.pth")

    # Plot the loss history returned by the training routine
    visualize_training_loss(losses)

## 4. Evaluate the Model

Generate reconstructions and evaluate the model performance.

In [None]:
# Run the evaluation routine over the dataset
evaluation_results = evaluate_model(
    model,
    dataset,
    chunk_size=64,
    chunk_overlap=8,
    device=device,
)

# Show the aggregate metrics when the evaluation produced an 'overall' entry
reported_metrics = evaluation_results['metrics']
if 'overall' in reported_metrics:
    print(f"\nOverall Metrics:")
    for metric, value in reported_metrics['overall'].items():
        print(f"  {metric.upper()}: {value:.4f}")

## 5. Generate Reconstructions for Visualization

For visualization purposes, we'll generate reconstructions for all excitations.

In [None]:
# Generate reconstructions for all excitations.
#
# The inputs below are sent to `device`, so the model must live there too.
# Loading weights with load_state_dict(..., map_location=device) does not move
# the module itself, so without this call a CUDA run can fail with a
# device-mismatch error. Moving a module that is already on `device` is a no-op.
model.to(device)
model.eval()
reconstructions = {}

# Inference only: no gradients are needed for any of the forward passes
with torch.no_grad():
    for ex, data in all_data.items():
        # The model is called with a dict keyed by excitation; add a batch dimension
        data_batch = {ex: data.unsqueeze(0).to(device)}

        # Generate reconstruction
        output = model(data_batch)

        # Report a missing excitation instead of skipping it silently; later
        # cells index `reconstructions` by excitation and would raise KeyError
        if ex not in output:
            print(f"Warning: model returned no reconstruction for excitation {ex}nm; skipping.")
            continue

        # Store reconstruction on the CPU with the batch dimension removed
        reconstructions[ex] = output[ex][0].cpu()
        print(f"Generated reconstruction for excitation {ex}nm. "
              f"Shape: {reconstructions[ex].shape}, "
              f"Range: [{reconstructions[ex].min().item():.4f}, {reconstructions[ex].max().item():.4f}]")

## 6. Visualizations

Now we'll create various visualizations to analyze the results.

### 6.1 Emission Spectrum Analysis

First, let's visualize the emission spectrum at different spatial locations.

In [None]:
# Select a specific excitation wavelength to analyze (the first key of all_data)
excitation_to_analyze = next(iter(all_data))
print(f"Analyzing excitation wavelength: {excitation_to_analyze}nm")

# Get original and reconstructed data for that excitation
original_data = all_data[excitation_to_analyze]
reconstructed_data = reconstructions[excitation_to_analyze]

# Get emission wavelengths if available (None when the dataset has no entry)
emission_wavelengths = dataset.emission_wavelengths.get(excitation_to_analyze, None)

# Test explicitly for "present and non-empty". A plain truthiness test raises
# ValueError when the wavelengths are a NumPy array with more than one element,
# while this form behaves the same for lists and arrays.
if emission_wavelengths is not None and len(emission_wavelengths) > 0:
    print(f"Emission wavelength range: {min(emission_wavelengths)}nm - {max(emission_wavelengths)}nm")

# Visualize spectrum at the center pixel of the image
center_y, center_x = spatial_height // 2, spatial_width // 2
rmse = visualize_emission_spectrum(
    original_data,
    reconstructed_data,
    excitation_to_analyze,
    y=center_y,
    x=center_x,
    wavelengths=emission_wavelengths
)
print(f"Center pixel RMSE: {rmse:.4f}")

In [None]:
# Compare spectra at four points, each a quarter of the image size in from a corner
h_quarter = spatial_height // 4
w_quarter = spatial_width // 4
positions = [
    (row, col)
    for row in (h_quarter, spatial_height - h_quarter)
    for col in (w_quarter, spatial_width - w_quarter)
]

rmse_values = visualize_multiple_spectra(
    original_data,
    reconstructed_data,
    excitation_to_analyze,
    positions=positions,
    wavelengths=emission_wavelengths,
)

# Format each RMSE to four decimals before printing the list
formatted_rmse = [format(rmse, '.4f') for rmse in rmse_values]
print(f"RMSE at different positions: {formatted_rmse}")

### 6.2 Spatial Slice Analysis

Now let's look at spatial slices for specific emission bands.

In [None]:
# Choose the band halfway along the third axis of the data; it is passed
# below as the emission index to display
band_count = original_data.shape[2]
middle_band = band_count // 2

# Show original vs. reconstruction for that single band
metrics = visualize_spatial_slice(
    original_data,
    reconstructed_data,
    excitation_to_analyze,
    emission_idx=middle_band,
    cmap='viridis',
)

# List every metric the visualization returned
print(f"Spatial metrics for emission band {middle_band}:")
for metric, value in metrics.items():
    print(f"  {metric}: {value:.4f}")

### 6.3 Feature Maps Visualization

Let's examine what features the model is learning.

In [None]:
# Feed only the excitation under analysis to the feature-map visualization
single_excitation_input = {excitation_to_analyze: all_data[excitation_to_analyze]}
feature_stats = visualize_feature_maps(
    model,
    single_excitation_input,
    num_maps=16,
    cmap='viridis',
    device=device,
)

# Summarize the returned activation statistics
lowest_mean = np.min(feature_stats['means'])
highest_mean = np.max(feature_stats['means'])
average_std = np.mean(feature_stats['stds'])

print("\nFeature Map Statistics:")
print(f"  Mean activation range: [{lowest_mean:.4f}, {highest_mean:.4f}]")
print(f"  Mean standard deviation: {average_std:.4f}")

### 6.4 RGB False Color Visualization

Create false-color visualizations to compare the original and reconstructed data.

In [None]:
# False-color (RGB) comparison of original and reconstruction for one excitation
visualization_results = visualize_reconstruction_comparison(
    original_data,
    reconstructed_data,
    emission_wavelengths,
    excitation_to_analyze,
    use_consistent_normalization=True,
)

# Report the per-channel RMSE returned under the 'rmse' key
channel_rmse = visualization_results['rmse']
print("\nRGB Channel RMSE Values:")
for channel, value in channel_rmse.items():
    print(f"  {channel.upper()}: {value:.4f}")

In [None]:
# Run the multi-excitation visualization; it returns one RMSE per excitation
rmse_values = visualize_multiple_excitations(
    model,
    all_data,
    dataset.emission_wavelengths,
    use_consistent_normalization=False,
    device=device,
)

# List the RMSE values in ascending order of excitation
print("\nRMSE for each excitation wavelength:")
for ex in sorted(rmse_values):
    rmse = rmse_values[ex]
    print(f"  Ex={ex}nm: {rmse:.4f}")

### 6.5 Detailed Spectral Band Analysis

Visualize all spectral bands for a specific excitation.

In [None]:
# Visualize all spectral bands; the call returns a mapping of band -> RMSE
band_rmse = visualize_all_spectral_bands(
    original_data,
    reconstructed_data,
    excitation_to_analyze,
    emission_wavelengths=emission_wavelengths,
    grid_size=4,  # Adjust based on number of bands
    cmap='viridis'
)

# Test explicitly for "present and non-empty". A plain truthiness test raises
# ValueError when the wavelengths are a NumPy array with more than one element.
wavelengths_available = emission_wavelengths is not None and len(emission_wavelengths) > 0

if band_rmse:
    # Find best and worst reconstructed bands (lowest / highest RMSE)
    best_band = min(band_rmse.items(), key=lambda x: x[1])
    worst_band = max(band_rmse.items(), key=lambda x: x[1])

    # Look up the wavelength for each band when wavelengths are known;
    # otherwise fall back to the band key itself
    best_wavelength = emission_wavelengths[best_band[0]] if wavelengths_available else best_band[0]
    worst_wavelength = emission_wavelengths[worst_band[0]] if wavelengths_available else worst_band[0]

    # Only attach the "nm" unit when the value really is a wavelength
    if wavelengths_available:
        best_label, worst_label = f"{best_wavelength}nm", f"{worst_wavelength}nm"
    else:
        best_label, worst_label = f"index {best_wavelength}", f"index {worst_wavelength}"

    print(f"\nBest reconstructed band: {best_label} (RMSE: {best_band[1]:.4f})")
    print(f"Worst reconstructed band: {worst_label} (RMSE: {worst_band[1]:.4f})")
else:
    # min()/max() raise ValueError on an empty mapping, so report it instead
    print("\nNo per-band RMSE values were returned; skipping best/worst band report.")

## 7. Analysis Summary

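Summarize model performance: compute the reconstruction error (MSE, RMSE, MAE) for each excitation wavelength, then report the average RMSE and the best and worst excitations.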

In [None]:
# Calculate error metrics (MSE, RMSE, MAE) for every reconstructed excitation
overall_metrics = {}
all_rmse = []

for ex, recon in reconstructions.items():
    # Only excitations that also have original data can be compared
    if ex not in all_data:
        continue
    orig = all_data[ex]

    # Element-wise reconstruction error, computed once and reused below
    error = orig - recon

    mse = torch.mean(error ** 2).item()
    rmse = np.sqrt(mse)
    mae = torch.mean(torch.abs(error)).item()

    all_rmse.append(rmse)

    # Store metrics per excitation
    overall_metrics[ex] = {
        'mse': mse,
        'rmse': rmse,
        'mae': mae
    }

# Print summary
print("\nModel Performance Summary:")
print(f"Number of excitation wavelengths: {len(reconstructions)}")

if overall_metrics:
    # Excitations with the lowest and highest RMSE
    best_excitation = min(overall_metrics, key=lambda ex: overall_metrics[ex]['rmse'])
    worst_excitation = max(overall_metrics, key=lambda ex: overall_metrics[ex]['rmse'])

    print(f"Average RMSE across all excitations: {np.mean(all_rmse):.4f}")
    print(f"Best excitation: {best_excitation}")
    print(f"Worst excitation: {worst_excitation}")
else:
    # With nothing to compare, np.mean([]) is NaN and min()/max() raise ValueError
    print("No reconstructions available; skipping RMSE summary.")