In [1]:
import os
import time
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from sklearn.cluster import MiniBatchKMeans

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

Using device: cuda


In [2]:
# File paths
# Both inputs are relative paths, so they resolve against the kernel's working directory.
data_path = "../Data/Lime Experiment/processed/masked_data_cutoff_30nm_exposure_max_power_min.pkl"
mask_path = "../Data/Lime Experiment/processed/mask.npy"
output_dir = "hyperspectral_results_lime"
os.makedirs(output_dir, exist_ok=True)

# Load data
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only open
# pickles produced by a trusted preprocessing step.
print("Loading hyperspectral data...")
with open(data_path, 'rb') as f:
    data_dict = pickle.load(f)

# The pickle is expected to be a dict exposing an 'excitation_wavelengths' entry.
print("Data Summary:")
print(f"Number of excitation wavelengths: {len(data_dict['excitation_wavelengths'])}")
print(f"Excitation wavelengths: {data_dict['excitation_wavelengths']}")

# Load mask
print("Loading mask...")
mask = np.load(mask_path)
# np.sum equals the number of valid pixels only if the mask holds 0/1 (or boolean)
# values — TODO confirm against the file that produced mask.npy.
valid_pixels = np.sum(mask)
total_pixels = mask.size
print(f"Mask loaded: {valid_pixels}/{total_pixels} valid pixels ({valid_pixels/total_pixels*100:.2f}%)")

# Create dataset
# MaskedHyperspectralDataset comes from the project-local AutoencoderPipeline module.
from AutoencoderPipeline import MaskedHyperspectralDataset

dataset = MaskedHyperspectralDataset(
    data_dict=data_dict,
    mask=mask,
    normalize=True,
    downscale_factor=1
)

# Get spatial dimensions
# `height` and `width` are notebook-level globals set here.
height, width = dataset.get_spatial_dimensions()
print(f"Data dimensions after processing: {height}x{width}")

# Get all data
# Later cells iterate this as a mapping keyed by excitation and call .numpy() on the
# values, so the values are presumably CPU torch tensors — verify in AutoencoderPipeline.
all_data = dataset.get_all_data()

Loading hyperspectral data...
Data Summary:
Number of excitation wavelengths: 21
Excitation wavelengths: [300.0, 310.0, 320.0, 330.0, 340.0, 350.0, 360.0, 370.0, 380.0, 390.0, 400.0, 410.0, 420.0, 430.0, 440.0, 450.0, 460.0, 470.0, 480.0, 490.0, 500.0]
Loading mask...
Mask loaded: 62953/89088 valid pixels (70.66%)
Preparing data for 21 excitation wavelengths...
Emission band lengths for each excitation wavelength:
  - Excitation 300.0 nm: 24 bands
  - Excitation 310.0 nm: 24 bands
  - Excitation 320.0 nm: 24 bands
  - Excitation 330.0 nm: 24 bands
  - Excitation 340.0 nm: 24 bands
  - Excitation 350.0 nm: 25 bands
  - Excitation 360.0 nm: 27 bands
  - Excitation 370.0 nm: 29 bands
  - Excitation 380.0 nm: 31 bands
  - Excitation 390.0 nm: 31 bands
  - Excitation 400.0 nm: 30 bands
  - Excitation 410.0 nm: 29 bands
  - Excitation 420.0 nm: 28 bands
  - Excitation 430.0 nm: 27 bands
  - Excitation 440.0 nm: 26 bands
  - Excitation 450.0 nm: 25 bands
  - Excitation 460.0 nm: 24 bands
  - 

In [3]:
# HyperspectralCAEWithMasking is defined in the project-local AutoencoderPipeline module.
from AutoencoderPipeline import HyperspectralCAEWithMasking

# Build the autoencoder. The per-excitation data is handed over as NumPy arrays; the
# remaining keyword arguments are hyperparameters whose meaning is defined by the
# constructor in AutoencoderPipeline (not visible in this notebook).
model = HyperspectralCAEWithMasking(
    excitations_data={ex: data.numpy() for ex, data in all_data.items()},
    k1=20,
    k3=20,
    filter_size=5,
    sparsity_target=0.1,
    sparsity_weight=1.0,
    dropout_rate=0.5
)

# Report the total parameter count, then move the model to the selected device.
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
model = model.to(device)

Model created with 342498 parameters


In [8]:
# Helper for splitting a spatial data cube into overlapping chunks.
# Note on parameter names: the overlap argument is called `overlap`
# (not `chunk_overlap`), so callers must pass it under that keyword.

def create_chunks(data_tensor, chunk_size=64, overlap=16):
    """
    Split a large spatial hyperspectral tensor into overlapping chunks.

    Chunks are laid out on a regular grid with stride ``chunk_size - overlap``.
    The last chunk along each axis is shifted back so that it ends exactly at the
    image border; it therefore keeps the full ``chunk_size`` whenever the image is
    at least that large, and may overlap its neighbour by more than ``overlap``.
    Images smaller than ``chunk_size`` along an axis yield a single, smaller chunk
    along that axis.

    Args:
        data_tensor: Array/tensor of shape [height, width, emission_bands] or
            [num_excitations, height, width, emission_bands]. Anything exposing
            ``.shape`` and basic slicing works (NumPy array or torch tensor).
        chunk_size: Side length of each spatial chunk (must be positive).
        overlap: Overlap between adjacent chunks (0 <= overlap < chunk_size).

    Returns:
        (chunks, positions): ``chunks`` is a list of slices of ``data_tensor`` in
        row-major order; ``positions`` is the matching list of
        ``(y_start, y_end, x_start, x_end)`` tuples with exclusive end indices.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is outside
            ``[0, chunk_size)``; such values give a zero or negative stride.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap} with chunk_size={chunk_size}"
        )

    # A 4-D input carries a leading excitation axis in front of the spatial axes
    has_excitation_axis = len(data_tensor.shape) == 4
    if has_excitation_axis:
        height, width = data_tensor.shape[1], data_tensor.shape[2]
    else:
        height, width = data_tensor.shape[0], data_tensor.shape[1]

    stride = chunk_size - overlap

    def _axis_spans(length):
        """Return the (start, end) span of every chunk along one spatial axis."""
        count = max(1, (length - overlap) // stride)
        # Add one more chunk when the regular grid stops short of the border
        if stride * count + overlap < length:
            count += 1

        spans = []
        for k in range(count):
            start = k * stride
            end = min(start + chunk_size, length)
            # A chunk touching the border is shifted back to keep its full size
            if end == length:
                start = max(0, length - chunk_size)
            spans.append((start, end))
        return spans

    row_spans = _axis_spans(height)
    col_spans = _axis_spans(width)

    chunks = []
    positions = []
    for y_start, y_end in row_spans:
        for x_start, x_end in col_spans:
            if has_excitation_axis:
                chunk = data_tensor[:, y_start:y_end, x_start:x_end, :]
            else:
                chunk = data_tensor[y_start:y_end, x_start:x_end, :]
            chunks.append(chunk)
            positions.append((y_start, y_end, x_start, x_end))

    print(f"Created {len(chunks)} chunks of size up to {chunk_size}x{chunk_size} with {overlap} overlap")
    return chunks, positions

In [9]:
# Custom training loop built on create_chunks: the model is trained on one spatial chunk at a time.

def train_model(
    model,
    dataset,
    num_epochs=30,
    learning_rate=0.001,
    chunk_size=64,
    overlap=8,
    device='cuda'
):
    """
    Train the autoencoder on overlapping spatial chunks of the dataset.

    Each chunk position forms one batch (batch size 1) holding that chunk for
    every excitation. The loss is the model's masked reconstruction loss plus
    ``model.sparsity_weight`` times its sparsity loss. Checkpoints are written
    to ``<output_dir>/model`` (``output_dir`` is the notebook-level variable):
    ``best_model.pth`` whenever the average epoch loss improves, and
    ``final_model.pth`` at the end. Training stops early after 5 consecutive
    epochs without improvement.

    Args:
        model: Model exposing ``__call__``, ``encode``, ``compute_masked_loss``,
            ``compute_sparsity_loss`` and ``sparsity_weight``.
        dataset: Object exposing ``get_all_data()`` and ``processed_mask``.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial Adam learning rate.
        chunk_size: Spatial chunk side length passed to ``create_chunks``.
        overlap: Chunk overlap passed to ``create_chunks``.
        device: Device the chunk tensors are moved to. The model is expected to
            live on the same device already.

    Returns:
        (model, train_losses): the model with the best checkpoint of this run
        loaded (if one was written) and the list of average losses per epoch.

    Raises:
        ValueError: If the dataset has no excitations, or if the data chunks do
            not line up with the mask chunks.
    """
    # Create output directory for models
    model_dir = os.path.join(output_dir, "model")
    os.makedirs(model_dir, exist_ok=True)
    best_model_path = os.path.join(model_dir, "best_model.pth")

    # Setup optimizer. The scheduler's `verbose` flag is deprecated in recent
    # PyTorch releases, so learning-rate drops are reported manually below.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    all_data = dataset.get_all_data()
    if len(all_data) == 0:
        raise ValueError("dataset.get_all_data() returned no excitations to train on")

    # Create chunks for mask (a dummy band axis lets create_chunks slice it)
    mask_expanded = dataset.processed_mask[..., np.newaxis]
    mask_chunks, mask_positions = create_chunks(mask_expanded, chunk_size, overlap)
    mask_chunks = [torch.tensor(chunk[..., 0], dtype=torch.float32).to(device) for chunk in mask_chunks]

    # Create chunks for each excitation
    print("Creating chunks for each excitation wavelength...")
    chunks_dict = {}
    for ex in all_data:
        chunks, positions = create_chunks(all_data[ex].numpy(), chunk_size, overlap)
        # Batches pair data chunk i with mask chunk i, so the grids must match
        if positions != mask_positions:
            raise ValueError(
                f"Chunk layout for excitation {ex} does not match the mask layout; "
                f"data and mask must share the same spatial dimensions"
            )
        chunks_dict[ex] = chunks

    num_chunks = len(mask_chunks)

    # Create batches: one dict of per-excitation tensors per chunk position
    batches = []
    mask_batches = []
    for i in range(num_chunks):
        batch = {
            ex: torch.tensor(chunks_dict[ex][i], dtype=torch.float32).unsqueeze(0).to(device)
            for ex in chunks_dict
        }
        batches.append(batch)
        mask_batches.append(mask_chunks[i].unsqueeze(0))  # Add batch dimension

    # Training loop
    print(f"Training for {num_epochs} epochs with {len(batches)} batches...")
    train_losses = []
    best_loss = float('inf')
    best_epoch = 0
    no_improvement_count = 0
    best_model_saved = False

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_sparsity_loss = 0.0

        for i, (batch, mask_batch) in enumerate(zip(batches, mask_batches)):
            # Forward pass
            output = model(batch)

            # Compute masked reconstruction loss
            recon_loss = model.compute_masked_loss(
                output_dict=output,
                target_dict=batch,
                spatial_mask=mask_batch
            )

            # Compute sparsity loss.
            # NOTE(review): this runs the encoder a second time; if the encoder is
            # stochastic in train mode (e.g. dropout) the code penalised here is not
            # the one used for the reconstruction above — confirm this is intended.
            encoded = model.encode(batch)
            sparsity_loss = model.compute_sparsity_loss(encoded)

            # Total loss
            loss = recon_loss + model.sparsity_weight * sparsity_loss

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_sparsity_loss += sparsity_loss.item()

            # Print progress
            if (i + 1) % 10 == 0 or i == len(batches) - 1:
                print(f"  Batch {i+1}/{len(batches)}", end="\r")

        # Record average loss
        avg_loss = epoch_loss / len(batches)
        avg_recon_loss = epoch_recon_loss / len(batches)
        avg_sparsity_loss = epoch_sparsity_loss / len(batches)
        train_losses.append(avg_loss)

        # Update scheduler and report any learning-rate reduction
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f"  Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}")

        # Report progress
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f} "
              f"(Recon: {avg_recon_loss:.4f}, Sparsity: {avg_sparsity_loss:.4f}), "
              f"Time: {epoch_time:.2f}s")

        # Check if this is the best epoch
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch
            no_improvement_count = 0

            # Save best model
            torch.save(model.state_dict(), best_model_path)
            best_model_saved = True
            print(f"  New best model saved to {best_model_path}")
        else:
            no_improvement_count += 1
            print(f"  No improvement for {no_improvement_count} epochs (best: {best_loss:.4f} at epoch {best_epoch+1})")

            # Early stopping
            if no_improvement_count >= 5:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    # Save final model
    model_path = os.path.join(model_dir, "final_model.pth")
    torch.save(model.state_dict(), model_path)
    print(f"Final model saved to {model_path}")

    # Load best model, but only a checkpoint written by this run: with zero epochs
    # or a loss that never improved (e.g. NaN) the file is missing or stale.
    if best_model_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("  No best-model checkpoint was written in this run; returning the model as trained")

    return model, train_losses

# Train with the notebook's default settings; `model` is rebound to the returned model
model, losses = train_model(
    model=model, dataset=dataset,
    num_epochs=30, learning_rate=0.001,
    chunk_size=64, overlap=8,
    device=device
)

# Save the per-epoch average loss curve on a logarithmic y-axis
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(losses, marker='o')
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.grid(True, alpha=0.3)
loss_ax.set_yscale('log')
loss_fig.savefig(os.path.join(output_dir, "training_loss.png"))
plt.close(loss_fig)



Created 35 chunks of size up to 64x64 with 8 overlap
Creating chunks for each excitation wavelength...
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 overlap
Created 35 chunks of size up to 64x64 with 8 over

In [10]:
def merge_chunk_reconstructions(chunks, positions, full_height, full_width):
    """
    Merge the reconstructed chunks back into a full image.

    Every chunk is added into a zero-initialised canvas at its position and the
    canvas is divided by the number of chunks covering each pixel, so overlapping
    regions hold the mean of all contributing chunks. Pixels covered by no chunk
    stay zero.

    Args:
        chunks: Non-empty list of tensors of shape [batch, height, width, emission_bands].
        positions: List of ``(y_start, y_end, x_start, x_end)`` tuples (exclusive
            ends), one per chunk and in the same order.
        full_height: Height of the merged image.
        full_width: Width of the merged image.

    Returns:
        Tensor of shape [batch, full_height, full_width, emission_bands] on the
        device of the first chunk.

    Raises:
        ValueError: If ``chunks`` is empty, if ``chunks`` and ``positions`` differ
            in length, or if the first chunk is not 4-dimensional.
    """
    if len(chunks) == 0:
        raise ValueError("No chunks to merge")
    if len(chunks) != len(positions):
        # zip() would silently drop the surplus and place chunks at wrong positions
        raise ValueError(
            f"Got {len(chunks)} chunks but {len(positions)} positions; they must match one-to-one"
        )

    first_chunk = chunks[0]
    if len(first_chunk.shape) != 4:  # expected [batch, height, width, emission_bands]
        raise ValueError(f"Unexpected chunk shape: {first_chunk.shape}")

    batch_size, _, _, num_bands = first_chunk.shape
    merged = torch.zeros((batch_size, full_height, full_width, num_bands),
                         device=first_chunk.device)
    # Coverage does not depend on the band, so one channel broadcast over bands suffices
    weights = torch.zeros((batch_size, full_height, full_width, 1),
                          device=first_chunk.device)

    for chunk, (y_start, y_end, x_start, x_end) in zip(chunks, positions):
        merged[:, y_start:y_end, x_start:x_end, :] += chunk
        weights[:, y_start:y_end, x_start:x_end, :] += 1

    # Average overlapping regions; the clamp avoids 0/0 on uncovered pixels
    return merged / torch.clamp(weights, min=1.0)

def evaluate_model(model, dataset, chunk_size=64, overlap=8, device='cuda'):
    """
    Evaluate the model by generating reconstructions and calculating metrics.

    Each excitation is reconstructed chunk by chunk (the model is called with a
    dict holding only that excitation), the chunks are merged back into a full
    image, and MSE / MAE / PSNR are computed against the input. When the dataset
    has a mask, errors are weighted by the mask and averaged over the mask total.
    PSNR assumes a peak value of 1.0.

    Args:
        model: Model whose call maps ``{excitation: tensor}`` to a dict of
            reconstructions keyed the same way.
        dataset: Object exposing ``get_all_data()`` and ``processed_mask``.
        chunk_size: Spatial chunk side length passed to ``create_chunks``.
        overlap: Chunk overlap passed to ``create_chunks``.
        device: Device used for inference and metric computation.

    Returns:
        dict with ``'reconstructions'`` (excitation -> [height, width, bands] tensor
        on ``device``) and ``'metrics'`` (excitation -> metric dict, plus an
        ``'overall'`` entry averaging MSE/MAE over the evaluated excitations).
        The per-excitation ``'valid_pixels'`` value is the sum of the mask after it
        has been expanded over the band axis, i.e. it counts pixel-band elements.
    """
    # Create output directory (relies on the notebook-level `output_dir`)
    eval_dir = os.path.join(output_dir, "evaluation")
    os.makedirs(eval_dir, exist_ok=True)

    model.eval()

    all_data = dataset.get_all_data()
    mask = dataset.processed_mask

    # The mask is the same for every excitation, so build its tensor once
    mask_tensor = None
    if mask is not None:
        mask_tensor = torch.tensor(mask, dtype=torch.float32, device=device)

    results = {
        'metrics': {},
        'reconstructions': {}
    }

    print("Evaluating model...")
    with torch.no_grad():
        overall_mse = 0.0
        overall_mae = 0.0
        num_excitations = 0

        for ex in all_data:
            data = all_data[ex]
            target = data.to(device)
            # Take the canvas size from the data itself, not from notebook globals
            full_height, full_width = data.shape[0], data.shape[1]

            chunks, positions = create_chunks(data.numpy(), chunk_size, overlap)

            # Keep each reconstruction paired with its own position so that a chunk
            # missing from the model output cannot shift the remaining ones
            reconstructed_chunks = []
            reconstructed_positions = []
            for i, chunk in enumerate(chunks):
                chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(device)
                output = model({ex: chunk_tensor})

                if ex in output:
                    reconstructed_chunks.append(output[ex])
                    reconstructed_positions.append(positions[i])

                # Print progress
                if (i + 1) % 20 == 0 or i == len(chunks) - 1:
                    print(f"  Chunk {i+1}/{len(chunks)} for Ex={ex}nm", end="\r")

            # Skip if no valid reconstructions
            if not reconstructed_chunks:
                print(f"Warning: No valid reconstructions for excitation {ex}")
                continue

            full_reconstruction = merge_chunk_reconstructions(
                reconstructed_chunks, reconstructed_positions, full_height, full_width
            )
            full_reconstruction = full_reconstruction[0]  # Remove batch dimension
            results['reconstructions'][ex] = full_reconstruction

            if mask_tensor is not None:
                mask_expanded = mask_tensor.unsqueeze(-1).expand_as(target)
                valid_pixels = mask_expanded.sum().item()
                if valid_pixels <= 0:
                    # Nothing to score for this excitation
                    continue

                # Zero the error outside the mask first so NaN/inf values in
                # masked-out pixels cannot leak into the sums (NaN * 0 is NaN)
                error = full_reconstruction - target
                error = torch.where(mask_expanded > 0, error, torch.zeros_like(error))

                mse = ((error ** 2) * mask_expanded).sum().item() / valid_pixels
                mae = (torch.abs(error) * mask_expanded).sum().item() / valid_pixels
                psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

                results['metrics'][ex] = {
                    'mse': mse,
                    'mae': mae,
                    'psnr': psnr,
                    'valid_pixels': valid_pixels
                }
            else:
                # If no mask, use all pixels
                mse = F.mse_loss(full_reconstruction, target).item()
                mae = torch.mean(torch.abs(full_reconstruction - target)).item()
                psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')

                results['metrics'][ex] = {
                    'mse': mse,
                    'mae': mae,
                    'psnr': psnr,
                }

            overall_mse += mse
            overall_mae += mae
            num_excitations += 1

            print(f"Excitation {ex}nm - MSE: {mse:.4f}, PSNR: {psnr:.2f} dB")

        # Calculate overall metrics
        if num_excitations > 0:
            mean_mse = overall_mse / num_excitations
            results['metrics']['overall'] = {
                'mse': mean_mse,
                'mae': overall_mae / num_excitations,
                # Same zero-MSE guard as the per-excitation PSNR
                'psnr': 10 * np.log10(1.0 / mean_mse) if mean_mse > 0 else float('inf')
            }

            print(f"Overall - MSE: {results['metrics']['overall']['mse']:.4f}, "
                  f"PSNR: {results['metrics']['overall']['psnr']:.2f} dB")

    return results

# Reconstruct every excitation with the trained model and collect the metrics
evaluation_results = evaluate_model(
    model=model,
    dataset=dataset,
    chunk_size=64,
    overlap=8,
    device=device,
)

Evaluating model...
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 300.0nm - MSE: 0.0001, PSNR: 40.47 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 310.0nm - MSE: 0.0001, PSNR: 42.09 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 320.0nm - MSE: 0.0000, PSNR: 43.08 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 330.0nm - MSE: 0.0000, PSNR: 43.81 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 340.0nm - MSE: 0.0000, PSNR: 44.27 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 350.0nm - MSE: 0.0000, PSNR: 44.51 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 360.0nm - MSE: 0.0000, PSNR: 44.36 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 370.0nm - MSE: 0.0000, PSNR: 44.36 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitation 380.0nm - MSE: 0.0001, PSNR: 42.84 dB
Created 35 chunks of size up to 64x64 with 8 overlap
Excitatio

In [11]:
def create_rgb_visualization(data_dict, emission_wavelengths, mask=None, output_dir=None):
    """
    Create RGB visualizations from hyperspectral data.

    For the first three excitations of ``data_dict`` a false-colour image is built
    from the emission bands closest to 650 / 550 / 450 nm (or, when no wavelengths
    are known for an excitation, from the bands at 80% / 50% / 20% of the band
    axis). All images share one min/max normalisation computed over the selected
    bands (restricted to ``mask > 0`` when a mask is given). NaNs end up as 0.

    Args:
        data_dict: Mapping excitation -> [height, width, bands] array or torch tensor.
        emission_wavelengths: Mapping excitation -> sequence of emission wavelengths.
        mask: Optional [height, width] array; images are multiplied by it.
        output_dir: Directory under which ``visualizations/rgb_comparison.png`` and
            ``visualizations/rgb_ex<excitation>.png`` are written. When None, no
            files are written and only the images are returned.

    Returns:
        dict excitation -> [height, width, 3] float array with values in [0, 1].
        Empty if ``data_dict`` is empty.
    """
    # Create output directory (only when saving was requested)
    vis_dir = None
    if output_dir is not None:
        vis_dir = os.path.join(output_dir, "visualizations")
        os.makedirs(vis_dir, exist_ok=True)

    # Default RGB bands (adjust as needed)
    r_band, g_band, b_band = 650, 550, 450

    # Choose excitations to visualize (first 3 for example)
    excitations = list(data_dict.keys())[:3]
    if not excitations:
        # plt.subplots cannot create a grid with zero columns
        return {}

    def _as_numpy(cube):
        """Return the cube as a NumPy array, moving tensors to the CPU first."""
        if isinstance(cube, torch.Tensor):
            return cube.detach().cpu().numpy()
        return cube

    def _rgb_indices(ex, num_bands):
        """Return the (r, g, b) band indices to use for one excitation."""
        if ex in emission_wavelengths:
            wavelengths = np.array(emission_wavelengths[ex])
            return tuple(
                np.argmin(np.abs(wavelengths - target)) for target in (r_band, g_band, b_band)
            )
        # Use indices proportionally if wavelengths not available
        return int(num_bands * 0.8), int(num_bands * 0.5), int(num_bands * 0.2)

    # Pass 1: build the unnormalised RGB stacks and find the global min/max
    rgb_stacks = {}
    global_min, global_max = float('inf'), float('-inf')
    mask_flat = mask.flatten() if mask is not None else None

    for ex in excitations:
        data = _as_numpy(data_dict[ex])
        r_idx, g_idx, b_idx = _rgb_indices(ex, data.shape[2])

        rgb = np.stack([
            data[:, :, r_idx],  # R channel
            data[:, :, g_idx],  # G channel
            data[:, :, b_idx]   # B channel
        ], axis=2)
        rgb_stacks[ex] = rgb

        # Range statistics use valid (mask > 0) pixels only
        channel_values = [rgb[:, :, c].flatten() for c in range(3)]
        if mask_flat is not None:
            channel_values = [values[mask_flat > 0] for values in channel_values]

        # Skip excitations where a channel has no usable value at all
        if all(not np.all(np.isnan(values)) for values in channel_values):
            global_min = min(global_min, *(np.nanmin(values) for values in channel_values))
            global_max = max(global_max, *(np.nanmax(values) for values in channel_values))

    if not (np.isfinite(global_min) and np.isfinite(global_max)):
        # No usable pixel was found; fall back to a unit range instead of
        # normalising with infinities
        global_min, global_max = 0.0, 1.0

    # Pass 2: normalise and plot
    fig, axes = plt.subplots(1, len(excitations), figsize=(len(excitations) * 6, 5))
    if len(excitations) == 1:
        axes = [axes]

    rgb_dict = {}
    for i, ex in enumerate(excitations):
        rgb = rgb_stacks[ex]

        # Apply mask if provided
        if mask is not None:
            mask_rgb = np.stack([mask, mask, mask], axis=2)
            rgb = rgb * mask_rgb

        # Normalize to [0,1] range, then replace NaNs with zeros
        rgb_normalized = np.clip((rgb - global_min) / (global_max - global_min + 1e-8), 0, 1)
        rgb_normalized = np.nan_to_num(rgb_normalized, nan=0.0)

        rgb_dict[ex] = rgb_normalized
        axes[i].imshow(rgb_normalized)
        axes[i].set_title(f'Excitation {ex}nm')
        axes[i].axis('off')

    fig.tight_layout()
    if vis_dir is not None:
        fig.savefig(os.path.join(vis_dir, "rgb_comparison.png"), dpi=300)
    plt.close(fig)

    # Save individual RGB images
    if vis_dir is not None:
        for ex, rgb in rgb_dict.items():
            single_fig = plt.figure(figsize=(8, 8))
            plt.imshow(rgb)
            plt.title(f'Excitation {ex}nm')
            plt.axis('off')
            plt.savefig(os.path.join(vis_dir, f"rgb_ex{ex}.png"), dpi=300)
            plt.close(single_fig)

    return rgb_dict

# Create visualizations for original and reconstructed data
original_rgb = create_rgb_visualization(
    all_data,
    dataset.emission_wavelengths,
    mask=dataset.processed_mask,
    output_dir=output_dir
)

# create_rgb_visualization always writes the same file names
# (rgb_comparison.png, rgb_ex<excitation>.png) under <output_dir>/visualizations,
# so the reconstructed images are sent to their own sub-directory; otherwise they
# would overwrite the images of the original data saved just above.
reconstructions = evaluation_results['reconstructions']
recon_rgb = create_rgb_visualization(
    reconstructions,
    dataset.emission_wavelengths,
    mask=dataset.processed_mask,
    output_dir=os.path.join(output_dir, "reconstructed")
)

# Create comparison for a specific excitation
def create_comparison(original_data, reconstructed_data, excitation, emission_wavelengths=None, mask=None):
    """
    Create comparison of original vs reconstructed data.

    Builds false-colour RGB images of both cubes (bands closest to 650 / 550 /
    450 nm, or the bands at 80% / 50% / 20% of the band axis when no wavelengths
    are given), normalises them with a shared min/max, and saves a three-panel
    figure (original, reconstruction, 5x-enhanced absolute difference) to
    ``<output_dir>/visualizations/comparison_ex<excitation>.png`` where
    ``output_dir`` is the notebook-level variable.

    Args:
        original_data: [height, width, bands] array or torch tensor.
        reconstructed_data: Array or tensor with the same layout.
        excitation: Excitation label used in the title and file name.
        emission_wavelengths: Optional sequence of emission wavelengths for this excitation.
        mask: Optional [height, width] array; both images are multiplied by it.

    Returns:
        dict with the normalised ``'original'`` and ``'reconstructed'`` images and
        the enhanced ``'difference'`` image, each [height, width, 3] in [0, 1].
    """
    # Make sure the target directory exists instead of relying on another
    # function having created it earlier
    vis_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(vis_dir, exist_ok=True)

    # Convert tensors to numpy
    if isinstance(original_data, torch.Tensor):
        original_data = original_data.detach().cpu().numpy()
    if isinstance(reconstructed_data, torch.Tensor):
        reconstructed_data = reconstructed_data.detach().cpu().numpy()

    num_bands = original_data.shape[2]

    # Determine RGB indices
    if emission_wavelengths is not None:
        r_band, g_band, b_band = 650, 550, 450
        wavelengths = np.array(emission_wavelengths)
        r_idx = np.argmin(np.abs(wavelengths - r_band))
        g_idx = np.argmin(np.abs(wavelengths - g_band))
        b_idx = np.argmin(np.abs(wavelengths - b_band))
    else:
        r_idx = int(num_bands * 0.8)
        g_idx = int(num_bands * 0.5)
        b_idx = int(num_bands * 0.2)

    # Create RGB images
    band_order = (r_idx, g_idx, b_idx)
    rgb_original = np.stack([original_data[:, :, idx] for idx in band_order], axis=2)
    rgb_recon = np.stack([reconstructed_data[:, :, idx] for idx in band_order], axis=2)

    # Apply mask
    if mask is not None:
        mask_rgb = np.stack([mask, mask, mask], axis=2)
        rgb_original = rgb_original * mask_rgb
        rgb_recon = rgb_recon * mask_rgb

    # Normalize both images with one shared range so they are directly comparable
    min_val = min(np.nanmin(rgb_original), np.nanmin(rgb_recon))
    max_val = max(np.nanmax(rgb_original), np.nanmax(rgb_recon))

    rgb_original_norm = np.clip((rgb_original - min_val) / (max_val - min_val + 1e-8), 0, 1)
    rgb_recon_norm = np.clip((rgb_recon - min_val) / (max_val - min_val + 1e-8), 0, 1)

    # Replace NaNs with zeros, matching create_rgb_visualization; np.clip lets NaN through
    rgb_original_norm = np.nan_to_num(rgb_original_norm, nan=0.0)
    rgb_recon_norm = np.nan_to_num(rgb_recon_norm, nan=0.0)

    # Calculate difference
    diff = np.abs(rgb_original_norm - rgb_recon_norm)
    diff_enhanced = np.clip(diff * 5, 0, 1)  # Enhance for visibility

    # Plot (the figure is created only after all arrays were computed successfully)
    fig = plt.figure(figsize=(18, 6))
    panels = (
        (rgb_original_norm, 'Original'),
        (rgb_recon_norm, 'Reconstructed'),
        (diff_enhanced, 'Difference (enhanced 5x)'),
    )
    for panel_idx, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, panel_idx)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')

    plt.suptitle(f'Reconstruction Comparison - Excitation {excitation}nm')
    plt.tight_layout()
    plt.savefig(os.path.join(vis_dir, f"comparison_ex{excitation}.png"), dpi=300)
    plt.close(fig)

    return {
        'original': rgb_original_norm,
        'reconstructed': rgb_recon_norm,
        'difference': diff_enhanced
    }

# Compare original and reconstruction for the first three excitations,
# skipping any excitation that has no reconstruction
for ex in list(all_data)[:3]:
    if ex not in reconstructions:
        continue
    create_comparison(
        all_data[ex],
        reconstructions[ex],
        ex,
        emission_wavelengths=dataset.emission_wavelengths.get(ex),
        mask=dataset.processed_mask,
    )

In [12]:
def extract_encoded_features(model, data_dict, mask=None, chunk_size=64, overlap=8, device='cuda'):
    """
    Extract encoded features from the model for all excitations.

    Each excitation is encoded chunk by chunk (``model.encode`` is called with a
    dict holding only that excitation) and the per-chunk feature maps are
    assembled into full-size maps. Where chunks overlap, the result is the mean
    over all chunks covering a pixel.

    Args:
        model: Model exposing ``encode``; its output for a single chunk is indexed
            as [batch, features, 1, height, width] (axis 1 of the un-batched output
            must have size 1, and the spatial size must equal the chunk's).
        data_dict: Mapping excitation -> [height, width, bands] torch tensor or
            NumPy array. All entries are assumed to share the spatial size of the
            first one.
        mask: Accepted for interface compatibility; not used by this function.
        chunk_size: Spatial chunk side length passed to ``create_chunks``.
        overlap: Chunk overlap passed to ``create_chunks``.
        device: Device the chunk tensors are moved to for encoding.

    Returns:
        (encoded_features, spatial_shapes): ``encoded_features`` maps excitation ->
        float64 array [features, height, width]; ``spatial_shapes`` maps
        excitation -> ``(height, width)``.
    """
    # Set model to evaluation mode
    model.eval()

    # Get dimensions from first excitation
    first_ex = next(iter(data_dict.keys()))
    height, width = data_dict[first_ex].shape[:2]

    encoded_features = {}
    spatial_shapes = {}

    with torch.no_grad():
        for ex, data in data_dict.items():
            print(f"Extracting features for excitation {ex}...")

            if isinstance(data, torch.Tensor):
                data_np = data.detach().cpu().numpy()
            else:
                data_np = np.asarray(data)
            chunks, positions = create_chunks(data_np, chunk_size, overlap)

            # Running sum of features and number of chunks covering each pixel.
            # Tracking coverage explicitly (instead of testing the accumulated value
            # for != 0) keeps genuine zero activations intact and yields a true mean
            # where three or more chunks overlap.
            feature_sum = None
            coverage = np.zeros((height, width))

            for i, chunk in enumerate(chunks):
                # Convert to tensor and add batch dimension
                chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(device)

                # Extract encoded representation for this excitation only
                encoded = model.encode({ex: chunk_tensor})
                features = encoded.cpu().numpy()[0]  # Remove batch dimension

                # Remove emission dimension (which is 1)
                spatial_features = features.squeeze(1)

                if feature_sum is None:
                    feature_sum = np.zeros((spatial_features.shape[0], height, width))

                y_start, y_end, x_start, x_end = positions[i]
                feature_sum[:, y_start:y_end, x_start:x_end] += spatial_features
                coverage[y_start:y_end, x_start:x_end] += 1

                # Print progress
                if (i + 1) % 20 == 0 or i == len(chunks) - 1:
                    print(f"  Processed {i+1}/{len(chunks)} chunks", end="\r")

            # Average overlapping regions; uncovered pixels (coverage 0) stay 0
            features_array = feature_sum / np.maximum(coverage, 1.0)

            # Store results
            encoded_features[ex] = features_array
            spatial_shapes[ex] = (height, width)

            print(f"\nExtracted {features_array.shape[0]} features for excitation {ex}")

    return encoded_features, spatial_shapes

def run_kmeans_clustering(features, n_clusters=10, random_state=42):
    """
    Run K-means clustering on the extracted features.

    Every pixel is treated as one sample described by its feature vector. When
    there are more than 50 features, PCA first reduces them to 50 components.
    Clustering uses MiniBatchKMeans.

    Args:
        features: Array of shape [n_features, height, width].
        n_clusters: Number of clusters to fit.
        random_state: Seed used for both the PCA step and MiniBatchKMeans so that
            repeated runs give the same result.

    Returns:
        (cluster_map, kmeans): ``cluster_map`` is a [height, width] array of cluster
        labels; ``kmeans`` is the fitted MiniBatchKMeans estimator.
    """
    print(f"Running K-means clustering with {n_clusters} clusters...")

    # Get shape of features
    n_features, height, width = features.shape

    # Reshape to [pixels, features]
    features_reshaped = features.reshape(n_features, -1).T
    print(f"Feature matrix shape: {features_reshaped.shape}")

    # Apply PCA if dimensionality is very high
    if n_features > 50:
        from sklearn.decomposition import PCA
        print("Applying PCA to reduce dimensions...")
        # Seed PCA as well: its solver can be a randomized one, which would make
        # the clustering input differ between runs
        pca = PCA(n_components=min(50, n_features-1), random_state=random_state)
        features_reshaped = pca.fit_transform(features_reshaped)
        print(f"Reduced features shape: {features_reshaped.shape}")

    # Use MiniBatchKMeans for better performance
    kmeans = MiniBatchKMeans(
        n_clusters=n_clusters,
        batch_size=1000,
        max_iter=300,
        random_state=random_state
    )

    print("Fitting K-means model...")
    labels = kmeans.fit_predict(features_reshaped)

    # Reshape labels back to spatial dimensions
    cluster_map = labels.reshape(height, width)
    print(f"Clustering complete. Found {len(np.unique(labels))} unique clusters")

    return cluster_map, kmeans

# Directory that receives every clustering artefact
cluster_dir = os.path.join(output_dir, "clustering")
os.makedirs(cluster_dir, exist_ok=True)

# Encode every excitation into full-size feature maps
print("Starting feature extraction...")
encoded_features, spatial_shapes = extract_encoded_features(
    model=model, data_dict=all_data, mask=dataset.processed_mask,
    chunk_size=64, overlap=8, device=device
)

# Cluster on the first excitation by default
excitation_to_use = list(encoded_features)[0]
print(f"Using excitation {excitation_to_use} for clustering")

cluster_labels, clustering_model = run_kmeans_clustering(
    features=encoded_features[excitation_to_use], n_clusters=10
)

# Pixels outside the mask get the sentinel label -1
if dataset.processed_mask is not None:
    cluster_labels[dataset.processed_mask == 0] = -1

# Render and save the cluster map
cluster_fig, cluster_ax = plt.subplots(figsize=(10, 8))
cluster_image = cluster_ax.imshow(cluster_labels, cmap='tab10', interpolation='nearest')
cluster_fig.colorbar(cluster_image, ax=cluster_ax, label='Cluster ID')
cluster_ax.set_title(f'Pixel-wise Clustering (Ex={excitation_to_use}nm, K=10)')
cluster_ax.axis('off')
cluster_fig.savefig(os.path.join(cluster_dir, f"cluster_map_ex{excitation_to_use}.png"), dpi=300)
plt.close(cluster_fig)

# Persist the label map next to the figure
np.save(os.path.join(cluster_dir, f"cluster_labels_ex{excitation_to_use}.npy"), cluster_labels)

# Visualize cluster overlay on RGB image
def create_cluster_overlay(cluster_labels, rgb_image, alpha=0.5, output_path=None):
    """
    Create overlay of cluster labels on RGB image.

    Every pixel with a non-negative label is painted with the tab10 colour at
    index ``label % 10`` and opacity ``alpha``; pixels with a negative label
    stay fully transparent. The overlay is drawn on top of ``rgb_image`` with
    a discrete colorbar that has one labelled segment per cluster.

    Args:
        cluster_labels: 2-D array of cluster labels; negative values mark
            masked pixels.
        rgb_image: Image passed to ``imshow`` as the background
            (presumably the same height/width as ``cluster_labels`` - verify
            against the caller).
        alpha: Opacity written into the alpha channel of clustered pixels.
        output_path: If truthy, the figure is saved there (dpi=300).

    Returns:
        Float RGBA array of shape ``cluster_labels.shape + (4,)``.
    """
    # Local import: only the colorbar needs it, and matplotlib is already a
    # dependency of this notebook.
    from matplotlib.colors import ListedColormap

    cluster_labels = np.asarray(cluster_labels)

    # Get unique clusters (excluding negative labels, used for masked areas)
    unique_clusters = sorted(c for c in np.unique(cluster_labels) if c >= 0)
    n_clusters = len(unique_clusters)

    # Colour by cluster id (not by position in the sorted list) so a cluster
    # keeps the same colour as in the other cluster plots even when some ids
    # are absent from this label map.
    base_cmap = plt.colormaps['tab10']
    cluster_colors = [base_cmap(int(c) % 10)[:3] for c in unique_clusters]

    # RGBA overlay, initialised fully transparent; masked pixels are never
    # written to and therefore stay (0, 0, 0, 0).
    overlay = np.zeros((*cluster_labels.shape, 4))
    for cluster_id, rgb in zip(unique_clusters, cluster_colors):
        overlay[cluster_labels == cluster_id] = (*rgb, alpha)

    fig, ax = plt.subplots(figsize=(12, 10))

    # Show RGB image, then the overlay on top. The RGBA array already carries
    # the per-pixel opacity, so no separate alpha= argument is passed
    # (passing it again is redundant and may be applied a second time,
    # depending on the Matplotlib version).
    ax.imshow(rgb_image)
    ax.imshow(overlay)

    if n_clusters > 0:
        # Discrete colorbar: exactly one colour per cluster, with the value
        # range set so that integer position i falls in the middle of
        # segment i. With the default 0..1 range the ticks at 0..n-1 would
        # fall almost entirely outside the bar.
        sm = plt.cm.ScalarMappable(cmap=ListedColormap(cluster_colors))
        sm.set_array([])
        sm.set_clim(-0.5, n_clusters - 0.5)
        cbar = fig.colorbar(sm, ax=ax, ticks=np.arange(n_clusters))
        cbar.set_ticklabels([f'Cluster {c}' for c in unique_clusters])

    ax.set_title('Cluster Overlay on RGB Image')
    ax.axis('off')

    # Save figure
    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    plt.close(fig)

    return overlay
# Overlay the cluster map on the RGB rendering of the clustered excitation,
# provided an RGB image exists for that excitation.
if excitation_to_use in original_rgb:
    overlay = create_cluster_overlay(
        cluster_labels,
        original_rgb[excitation_to_use],
        alpha=0.5,
        output_path=os.path.join(cluster_dir, "cluster_overlay.png"),
    )

Starting feature extraction...
Extracting features for excitation 300.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 features for excitation 300.0
Extracting features for excitation 310.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 features for excitation 310.0
Extracting features for excitation 320.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 features for excitation 320.0
Extracting features for excitation 330.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 features for excitation 330.0
Extracting features for excitation 340.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 features for excitation 340.0
Extracting features for excitation 350.0...
Created 35 chunks of size up to 64x64 with 8 overlap
  Processed 35/35 chunks
Extracted 20 fea

In [13]:
def analyze_cluster_profiles(cluster_labels, all_data, emission_wavelengths):
    """
    Analyze spectral profiles for each cluster.

    For every excitation in ``all_data`` and every non-negative label in
    ``cluster_labels``, the NaN-aware mean and standard deviation are taken
    over the cluster's pixels. The mean spectra are drawn into one figure and
    the cluster sizes into a bar chart; both are saved under
    ``<output_dir>/clustering`` (``output_dir`` is read from module scope).

    Args:
        cluster_labels: Array of cluster labels; negative values mark masked
            pixels and are ignored.
        all_data: Mapping excitation -> ``torch.Tensor`` or array. Each entry
            is indexed with a boolean mask of ``cluster_labels``' shape
            (presumably laid out as height x width x bands - the fallback
            x-axis uses ``shape[2]``).
        emission_wavelengths: Mapping excitation -> x-axis values for that
            excitation. May be ``None`` or empty, in which case band indices
            are used.

    Returns:
        dict: ``{cluster_id: {excitation: {'mean', 'std', 'count'}}}``.
    """
    cluster_dir = os.path.join(output_dir, "clustering")
    # Do not rely on another cell having created the directory already.
    os.makedirs(cluster_dir, exist_ok=True)

    if emission_wavelengths is None:
        emission_wavelengths = {}

    # Get unique clusters (excluding negative labels, used for masked areas)
    unique_clusters = sorted(c for c in np.unique(cluster_labels) if c >= 0)
    print(f"Analyzing profiles for {len(unique_clusters)} clusters")

    # The pixel mask and size of a cluster do not depend on the excitation,
    # so build them once instead of once per excitation.
    cluster_masks = {c: cluster_labels == c for c in unique_clusters}
    cluster_sizes = {c: int(np.count_nonzero(m)) for c, m in cluster_masks.items()}

    # Markers cycle per excitation so overlapping curves can be told apart
    markers = ['o', 's', '^', 'D', 'v']

    # Store stats for each cluster
    cluster_stats = {}

    # Figure for the mean spectra
    fig, ax = plt.subplots(figsize=(12, 8))

    # Process each excitation wavelength
    for i, ex in enumerate(all_data.keys()):
        # Get data for this excitation as a NumPy array
        data = all_data[ex]
        if isinstance(data, torch.Tensor):
            # detach() so tensors that still track gradients convert cleanly
            data = data.detach().cpu().numpy()

        # Emission wavelengths when known, band indices otherwise
        if ex in emission_wavelengths:
            wavelengths = emission_wavelengths[ex]
        else:
            wavelengths = np.arange(data.shape[2])

        marker = markers[i % len(markers)]

        for cluster_id in unique_clusters:
            # Pixels of this cluster
            cluster_data = data[cluster_masks[cluster_id]]

            # Mean and spread per band, ignoring NaNs
            mean_spectrum = np.nanmean(cluster_data, axis=0)
            std_spectrum = np.nanstd(cluster_data, axis=0)

            cluster_stats.setdefault(cluster_id, {})[ex] = {
                'mean': mean_spectrum,
                'std': std_spectrum,
                'count': cluster_sizes[cluster_id]
            }

            # Only the first excitation contributes legend entries, so each
            # cluster is listed once.
            ax.plot(wavelengths, mean_spectrum, marker=marker,
                    label=f"Cluster {cluster_id}" if i == 0 else None,
                    color=plt.cm.tab10(cluster_id % 10))

    # Finish plot
    ax.set_xlabel('Emission Wavelength (nm)' if len(emission_wavelengths) > 0 else 'Emission Band Index')
    ax.set_ylabel('Intensity')
    ax.set_title('Spectral Profiles by Cluster')
    ax.grid(True, alpha=0.3)
    # Skip the legend when nothing was plotted (avoids an empty-legend warning)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    # Save figure
    fig.savefig(os.path.join(cluster_dir, "cluster_profiles.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Bar chart of cluster sizes
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(
        [f"Cluster {c}" for c in unique_clusters],
        [cluster_sizes[c] for c in unique_clusters],
        color=[plt.cm.tab10(c % 10) for c in unique_clusters]
    )

    ax.set_xlabel('Cluster')
    ax.set_ylabel('Number of Pixels')
    ax.set_title('Cluster Sizes')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(True, axis='y', alpha=0.3)

    # Save figure
    fig.savefig(os.path.join(cluster_dir, "cluster_sizes.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    return cluster_stats

# Compute per-cluster spectral statistics and write the profile/size figures
cluster_stats = analyze_cluster_profiles(
    cluster_labels,
    all_data,
    dataset.emission_wavelengths,
)

# Final status line for the whole notebook run
print("Pipeline complete! All results saved to", output_dir)

Analyzing profiles for 10 clusters
Pipeline complete! All results saved to hyperspectral_results_lime
