In [None]:
import os
import time
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from sklearn.cluster import MiniBatchKMeans

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Fix the numpy and torch global RNGs so repeated runs are comparable
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# --- Paths ------------------------------------------------------------------
# Processed Lime cube (pickled dict), its spatial mask, and the results root folder.
data_path = "../Data/Lime Experiment/processed/masked_data_cutoff_30nm_exposure_max_power_min.pkl"
mask_path = "../Data/Lime Experiment/processed/mask.npy"
output_dir = "Lime"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# --- Hyperspectral cube -----------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point data_path at files from a trusted source.
print("Loading hyperspectral data...")
with open(data_path, 'rb') as pickle_file:
    data_dict = pickle.load(pickle_file)

print("Data Summary:")
print(f"Number of excitation wavelengths: {len(data_dict['excitation_wavelengths'])}")
print(f"Excitation wavelengths: {data_dict['excitation_wavelengths']}")

# --- Spatial mask -------------------------------------------------------------
print("Loading mask...")
mask = np.load(mask_path)
total_pixels = mask.size
valid_pixels = mask.sum()
print(f"Mask loaded: {valid_pixels}/{total_pixels} valid pixels ({valid_pixels/total_pixels*100:.2f}%)")

# --- Dataset ------------------------------------------------------------------
from AutoencoderPipeline import MaskedHyperspectralDataset

dataset = MaskedHyperspectralDataset(
    data_dict=data_dict, mask=mask, normalize=True, downscale_factor=1
)

# Spatial size after the dataset's own processing (used by later cells)
height, width = dataset.get_spatial_dimensions()
print(f"Data dimensions after processing: {height}x{width}")

# Mapping excitation -> per-excitation tensor (values expose .numpy(), see below)
all_data = dataset.get_all_data()

In [None]:
from AutoencoderPipeline import HyperspectralCAEWithMasking

# Autoencoder hyper-parameters, collected in one place so they are easy to tune
cae_hparams = dict(
    k1=20,
    k3=20,
    filter_size=5,
    sparsity_target=0.1,
    sparsity_weight=1.0,
    dropout_rate=0.5,
)

# The constructor receives a numpy copy of every excitation cube
model = HyperspectralCAEWithMasking(
    excitations_data={ex: cube.numpy() for ex, cube in all_data.items()},
    **cae_hparams,
)

param_count = sum(p.numel() for p in model.parameters())
print(f"Model created with {param_count} parameters")
model = model.to(device)

In [None]:
# Spatial tiling helper shared by training, evaluation and feature extraction.
# Note: the keyword for the overlap is `overlap` (not `chunk_overlap`).

def create_chunks(data_tensor, chunk_size=96, overlap=32):
    """
    Split a hyperspectral cube into overlapping spatial tiles.

    Tiles start every ``chunk_size - overlap`` pixels along each spatial axis.
    The last tile in each direction is shifted back so that it ends exactly on
    the image border, so the whole image is covered. An axis shorter than
    ``chunk_size`` yields a single tile spanning that whole axis.

    Args:
        data_tensor: Array-like with ``.shape`` and slicing, either
            [height, width, emission_bands] or
            [num_excitations, height, width, emission_bands] (detected by rank 4).
        chunk_size: Side length of each tile in pixels (must be > 0).
        overlap: Pixels shared by neighbouring tiles (0 <= overlap < chunk_size).

    Returns:
        (chunks, positions): the sliced tiles, and one
        (y_start, y_end, x_start, x_end) tuple per tile, in row-major tile order.

    Raises:
        ValueError: if ``chunk_size <= 0`` or ``overlap`` is outside
            [0, chunk_size) (stride would be zero/negative or leave gaps).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )

    is_stacked = len(data_tensor.shape) == 4  # [num_excitations, H, W, bands]
    if is_stacked:
        height, width = data_tensor.shape[1], data_tensor.shape[2]
    else:
        height, width = data_tensor.shape[0], data_tensor.shape[1]

    stride = chunk_size - overlap

    def _axis_spans(length):
        """(start, end) tile spans along one axis; the last span is pinned to the border."""
        count = max(1, (length - overlap) // stride)
        # One extra tile when the regular grid stops short of the border
        if stride * count + overlap < length:
            count += 1
        spans = []
        for k in range(count):
            start = k * stride
            end = min(start + chunk_size, length)
            if end == length:
                # Shift the border tile back so it keeps the full chunk_size when possible
                start = max(0, length - chunk_size)
            spans.append((start, end))
        return spans

    chunks = []
    positions = []
    for y_start, y_end in _axis_spans(height):
        for x_start, x_end in _axis_spans(width):
            if is_stacked:
                chunk = data_tensor[:, y_start:y_end, x_start:x_end, :]
            else:
                chunk = data_tensor[y_start:y_end, x_start:x_end, :]
            chunks.append(chunk)
            positions.append((y_start, y_end, x_start, x_end))

    print(f"Created {len(chunks)} chunks of size up to {chunk_size}x{chunk_size} with {overlap} overlap")
    return chunks, positions

In [None]:
# Training: every spatial chunk of every excitation becomes one single-sample batch.

def _build_chunk_batches(dataset, chunk_size, overlap, device):
    """Tile every excitation cube and the processed mask identically and move the tiles to `device`."""
    all_data = dataset.get_all_data()
    if not all_data:
        raise ValueError("dataset.get_all_data() returned no excitations to train on")

    # The mask gets a dummy trailing axis so create_chunks sees a [H, W, 1] cube.
    mask_chunks, _ = create_chunks(dataset.processed_mask[..., np.newaxis], chunk_size, overlap)
    mask_batches = [
        torch.tensor(chunk[..., 0], dtype=torch.float32).to(device).unsqueeze(0)
        for chunk in mask_chunks
    ]

    print("Creating chunks for each excitation wavelength...")
    chunks_dict = {
        ex: create_chunks(all_data[ex].numpy(), chunk_size, overlap)[0]
        for ex in all_data
    }

    num_chunks = len(next(iter(chunks_dict.values())))
    if len(mask_batches) != num_chunks:
        # Only possible if the mask and the data cubes have different spatial sizes.
        raise ValueError(
            f"Mask produced {len(mask_batches)} chunks but data produced {num_chunks}; "
            "check that processed_mask matches the data's spatial dimensions"
        )

    batches = [
        {
            ex: torch.tensor(chunks[i], dtype=torch.float32).unsqueeze(0).to(device)
            for ex, chunks in chunks_dict.items()
        }
        for i in range(num_chunks)
    ]
    return batches, mask_batches


def train_model(
    model,
    dataset,
    num_epochs=30,
    learning_rate=0.001,
    chunk_size=96,
    overlap=32,
    device='cuda',
    scheduler_patience=5,
    early_stopping_patience=None,
):
    """
    Train the masked autoencoder on overlapping spatial chunks.

    Each chunk index yields one batch holding that chunk for every excitation, plus
    the matching mask chunk. The loss is the model's masked reconstruction loss plus
    ``model.sparsity_weight`` times its sparsity loss. Adam is used with a
    ReduceLROnPlateau schedule (factor 0.5) on the mean epoch loss.

    Checkpoints are written to ``<output_dir>/model/``: ``best_model.pth`` whenever
    the epoch loss improves, and ``final_model.pth`` at the end. The best weights
    from this run are loaded back into ``model`` before returning.

    Args:
        model: Autoencoder exposing forward, encode, compute_masked_loss,
            compute_sparsity_loss and a sparsity_weight attribute.
        dataset: Object providing get_all_data() (excitation -> tensor) and processed_mask.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial Adam learning rate.
        chunk_size, overlap: Tiling parameters forwarded to create_chunks.
        device: Device the batches are moved to.
        scheduler_patience: Bad epochs the LR scheduler tolerates before halving the LR.
        early_stopping_patience: Consecutive epochs without improvement before training
            stops. Defaults to ``2 * scheduler_patience`` so the LR can actually be reduced
            at least once before stopping. (Previously both were 5, so early stopping
            always fired before the scheduler.)

    Returns:
        (model, train_losses): the model (best weights of this run if any were saved)
        and the list of mean per-epoch losses.
    """
    if early_stopping_patience is None:
        early_stopping_patience = 2 * scheduler_patience

    model_dir = os.path.join(output_dir, "model")
    os.makedirs(model_dir, exist_ok=True)
    best_model_path = os.path.join(model_dir, "best_model.pth")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated in recent PyTorch releases; LR reductions are printed below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=scheduler_patience
    )

    batches, mask_batches = _build_chunk_batches(dataset, chunk_size, overlap, device)
    num_batches = len(batches)

    print(f"Training for {num_epochs} epochs with {num_batches} batches...")
    train_losses = []
    best_loss = float('inf')
    best_epoch = 0
    best_saved = False  # guards against reloading a stale checkpoint from an earlier run
    no_improvement_count = 0

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_sparsity_loss = 0.0

        for i, (batch, mask_batch) in enumerate(zip(batches, mask_batches)):
            output = model(batch)

            recon_loss = model.compute_masked_loss(
                output_dict=output,
                target_dict=batch,
                spatial_mask=mask_batch
            )

            # NOTE(review): encode() runs the encoder a second time (with fresh dropout in
            # train mode). If forward() can expose the latent, reuse it instead.
            encoded = model.encode(batch)
            sparsity_loss = model.compute_sparsity_loss(encoded)

            loss = recon_loss + model.sparsity_weight * sparsity_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_sparsity_loss += sparsity_loss.item()

            if (i + 1) % 10 == 0 or i == num_batches - 1:
                print(f"  Batch {i+1}/{num_batches}", end="\r")

        avg_loss = epoch_loss / num_batches
        avg_recon_loss = epoch_recon_loss / num_batches
        avg_sparsity_loss = epoch_sparsity_loss / num_batches
        train_losses.append(avg_loss)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f} "
              f"(Recon: {avg_recon_loss:.4f}, Sparsity: {avg_sparsity_loss:.4f}), "
              f"Time: {epoch_time:.2f}s")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch
            no_improvement_count = 0
            torch.save(model.state_dict(), best_model_path)
            best_saved = True
            print(f"  New best model saved to {best_model_path}")
        else:
            no_improvement_count += 1
            print(f"  No improvement for {no_improvement_count} epochs (best: {best_loss:.4f} at epoch {best_epoch+1})")
            if no_improvement_count >= early_stopping_patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    final_model_path = os.path.join(model_dir, "final_model.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    if best_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("Warning: no best checkpoint was saved in this run; keeping the final weights")

    return model, train_losses

# Train on 256-px chunks that overlap by 128 px
model, losses = train_model(
    model=model, dataset=dataset,
    num_epochs=30, learning_rate=0.001,
    chunk_size=256, overlap=128,
    device=device,
)

# Mean loss per epoch on a log axis, written next to the other outputs
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, marker='o')
ax.set_title('Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True, alpha=0.3)
ax.set_yscale('log')
fig.savefig(os.path.join(output_dir, "training_loss.png"))
plt.close(fig)

In [None]:
def merge_chunk_reconstructions(chunks, positions, full_height, full_width):
    """
    Stitch per-chunk reconstructions back into one full-size image, averaging overlaps.

    Args:
        chunks: Sequence of torch tensors shaped [batch, chunk_h, chunk_w, bands].
            Rank 4 is required. The first chunk fixes batch size, band count and device.
        positions: Sequence of (y_start, y_end, x_start, x_end) tuples, one per chunk
            (as produced by create_chunks).
        full_height, full_width: Spatial size of the merged output.

    Returns:
        Tensor [batch, full_height, full_width, bands]. A pixel covered by several
        chunks gets their mean; a pixel covered by none stays 0. The dtype is at least
        float32.

    Raises:
        ValueError: if no chunks are given, the chunk and position counts differ, or
            the chunks are not rank 4.
    """
    if len(chunks) == 0:
        raise ValueError("No chunks to merge")
    if len(chunks) != len(positions):
        # zip() would silently drop the surplus, hiding a chunk/position bookkeeping error.
        raise ValueError(f"Got {len(chunks)} chunks but {len(positions)} positions")

    first_chunk = chunks[0]
    if len(first_chunk.shape) != 4:
        raise ValueError(f"Unexpected chunk shape: {first_chunk.shape}")

    batch_size, _, _, num_bands = first_chunk.shape
    # Accumulate in at least float32 (half inputs are promoted, float64 inputs stay exact).
    acc_dtype = torch.promote_types(first_chunk.dtype, torch.float32)
    merged = torch.zeros((batch_size, full_height, full_width, num_bands),
                         dtype=acc_dtype, device=first_chunk.device)
    # One coverage count per pixel, broadcast over batch and bands when dividing.
    counts = torch.zeros((1, full_height, full_width, 1),
                         dtype=acc_dtype, device=first_chunk.device)

    for chunk, (y_start, y_end, x_start, x_end) in zip(chunks, positions):
        merged[:, y_start:y_end, x_start:x_end, :] += chunk
        counts[:, y_start:y_end, x_start:x_end, :] += 1

    # clamp keeps uncovered pixels at 0 instead of dividing by zero
    return merged / torch.clamp(counts, min=1.0)

def _psnr_from_mse(mse):
    """PSNR in dB for a peak value of 1.0, or inf when mse is 0."""
    # NOTE(review): peak = 1.0 assumes the dataset's normalize=True maps values into [0, 1] — confirm.
    return 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')


def _reconstruction_metrics(reconstruction, target, mask_tensor=None):
    """
    MSE / MAE / PSNR between a reconstruction and its target.

    Without a mask the errors are plain means over every element. With a mask
    ([H, W], broadcast over the band axis), errors are weighted by the mask and
    divided by the mask's total weight over all bands. In that case the result also
    carries that weight as 'valid_pixels', and None is returned when the weight is 0.
    """
    if mask_tensor is None:
        mse = F.mse_loss(reconstruction, target).item()
        mae = torch.mean(torch.abs(reconstruction - target)).item()
        return {'mse': mse, 'mae': mae, 'psnr': _psnr_from_mse(mse)}

    weights = mask_tensor.unsqueeze(-1).expand_as(target)
    # For a binary mask this counts valid elements (pixels x bands), not pixels.
    valid_pixels = weights.sum().item()
    if not valid_pixels > 0:
        return None

    diff = reconstruction - target
    inside = weights > 0
    squared = diff ** 2 * weights
    absolute = torch.abs(diff) * weights
    # NaN * 0 is still NaN, so zero masked-out elements explicitly instead of relying on the weight.
    squared = torch.where(inside, squared, torch.zeros_like(squared))
    absolute = torch.where(inside, absolute, torch.zeros_like(absolute))

    mse = squared.sum().item() / valid_pixels
    mae = absolute.sum().item() / valid_pixels
    return {
        'mse': mse,
        'mae': mae,
        'psnr': _psnr_from_mse(mse),
        'valid_pixels': valid_pixels
    }


def evaluate_model(model, dataset, chunk_size=96, overlap=32, device='cuda'):
    """
    Reconstruct every excitation chunk-by-chunk and compute reconstruction metrics.

    Each excitation is tiled with create_chunks and fed to the model as a
    single-excitation dict. The outputs for that excitation are stitched back with
    merge_chunk_reconstructions. When ``dataset.processed_mask`` is set, metrics are
    computed only on masked-in pixels.

    Args:
        model: Autoencoder called as ``model({excitation: tensor})`` that returns a dict.
        dataset: Object providing get_all_data() and processed_mask.
        chunk_size, overlap: Tiling parameters forwarded to create_chunks.
        device: Device used for inference and metric computation.

    Returns:
        Dict with:
            'reconstructions': excitation -> tensor [H, W, bands] on ``device``.
            'metrics': excitation -> {'mse', 'mae', 'psnr'[, 'valid_pixels']}, plus
                'overall' holding the mean MSE/MAE across excitations and the PSNR
                of that mean MSE.
    """
    # Created for evaluation artefacts; nothing is written here by this function itself.
    eval_dir = os.path.join(output_dir, "evaluation")
    os.makedirs(eval_dir, exist_ok=True)

    model.eval()

    all_data = dataset.get_all_data()
    mask = dataset.processed_mask
    # Build the mask tensor once instead of once per excitation.
    mask_tensor = None if mask is None else torch.tensor(mask, dtype=torch.float32, device=device)

    results = {
        'metrics': {},
        'reconstructions': {}
    }

    print("Evaluating model...")
    with torch.no_grad():
        overall_mse = 0.0
        overall_mae = 0.0
        num_excitations = 0

        for ex, data in all_data.items():
            target = data.to(device)
            # Use this cube's own size rather than notebook-global height/width.
            full_height, full_width = data.shape[0], data.shape[1]

            chunks, positions = create_chunks(data.numpy(), chunk_size, overlap)

            # Positions are kept in lock-step with outputs so a skipped chunk cannot shift the rest.
            reconstructed_chunks = []
            kept_positions = []
            for i, (chunk, position) in enumerate(zip(chunks, positions)):
                chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(device)
                output = model({ex: chunk_tensor})
                if ex in output:
                    reconstructed_chunks.append(output[ex])
                    kept_positions.append(position)

                if (i + 1) % 20 == 0 or i == len(chunks) - 1:
                    print(f"  Chunk {i+1}/{len(chunks)} for Ex={ex}nm", end="\r")

            if not reconstructed_chunks:
                print(f"Warning: No valid reconstructions for excitation {ex}")
                continue

            # [0] drops the batch axis
            full_reconstruction = merge_chunk_reconstructions(
                reconstructed_chunks, kept_positions, full_height, full_width
            )[0]
            results['reconstructions'][ex] = full_reconstruction

            metrics = _reconstruction_metrics(full_reconstruction, target, mask_tensor)
            if metrics is None:
                continue  # mask selects nothing: no metrics for this excitation

            results['metrics'][ex] = metrics
            overall_mse += metrics['mse']
            overall_mae += metrics['mae']
            num_excitations += 1
            print(f"Excitation {ex}nm - MSE: {metrics['mse']:.4f}, PSNR: {metrics['psnr']:.2f} dB")

        if num_excitations > 0:
            mean_mse = overall_mse / num_excitations
            results['metrics']['overall'] = {
                'mse': mean_mse,
                'mae': overall_mae / num_excitations,
                # guarded: a perfect reconstruction (mse == 0) used to raise ZeroDivisionError
                'psnr': _psnr_from_mse(mean_mse)
            }

            print(f"Overall - MSE: {results['metrics']['overall']['mse']:.4f}, "
                  f"PSNR: {results['metrics']['overall']['psnr']:.2f} dB")

    return results

# Reconstruct every excitation with the trained model (same tiling as training) and score it
evaluation_results = evaluate_model(
    model,
    dataset,
    chunk_size=256,
    overlap=128,
    device=device,
)

In [None]:
def _rgb_band_indices(num_bands, wavelengths=None, targets=(650, 550, 450)):
    """(R, G, B) band indices: nearest to `targets` nm when wavelengths are known, else 80/50/20 % of the band axis."""
    if wavelengths is not None:
        wl = np.array(wavelengths)
        return tuple(np.argmin(np.abs(wl - target)) for target in targets)
    return int(num_bands * 0.8), int(num_bands * 0.5), int(num_bands * 0.2)


def create_rgb_visualization(data_dict, emission_wavelengths, mask=None, output_dir=None):
    """
    Create false-colour RGB composites from hyperspectral cubes.

    Up to the first three excitations in ``data_dict`` are rendered. For each one,
    the emission bands nearest 650/550/450 nm become R/G/B. When no wavelength list
    exists for that excitation, bands at 80/50/20 % of the band axis are used. All
    images are rescaled with one shared min/max, taken over masked-in pixels and
    ignoring NaNs, so intensities are comparable across excitations.

    Args:
        data_dict: Mapping excitation -> array or tensor indexed [row, col, band].
        emission_wavelengths: Mapping excitation -> per-band emission wavelengths.
        mask: Optional 2D array. Pixels with mask > 0 define the intensity range, and
            the RGB images are multiplied by the mask.
        output_dir: Root under which ``visualizations/rgb_comparison.png`` and
            ``visualizations/rgb_ex<ex>.png`` are written. If None, nothing is plotted
            or saved and only the RGB arrays are returned. (Previously None crashed.)

    Returns:
        Dict excitation -> RGB float array in [0, 1] with NaNs replaced by 0.
    """
    excitations = list(data_dict.keys())[:3]
    if not excitations:
        return {}

    # Pass 1: build the raw RGB stacks and the shared intensity range.
    rgb_raw = {}
    global_min, global_max = float('inf'), float('-inf')
    for ex in excitations:
        values = data_dict[ex]
        data = values.cpu().numpy() if isinstance(values, torch.Tensor) else values

        wavelengths = emission_wavelengths[ex] if ex in emission_wavelengths else None
        r_idx, g_idx, b_idx = _rgb_band_indices(data.shape[2], wavelengths)
        rgb = np.stack([data[:, :, r_idx], data[:, :, g_idx], data[:, :, b_idx]], axis=2)

        pixels = rgb.reshape(-1, 3)
        if mask is not None:
            pixels = pixels[mask.flatten() > 0]

        # Skip the range update when some channel has no finite value (or no pixel is valid).
        if not np.isnan(pixels).all(axis=0).any():
            global_min = min(global_min, np.nanmin(pixels))
            global_max = max(global_max, np.nanmax(pixels))

        if mask is not None:
            rgb = rgb * np.stack([mask, mask, mask], axis=2)
        rgb_raw[ex] = rgb

    # Pass 2: shared normalisation into [0, 1].
    have_range = np.isfinite(global_min) and np.isfinite(global_max)
    rgb_dict = {}
    for ex, rgb in rgb_raw.items():
        if have_range:
            rgb_normalized = np.clip((rgb - global_min) / (global_max - global_min + 1e-8), 0, 1)
            rgb_normalized = np.nan_to_num(rgb_normalized, nan=0.0)
        else:
            # No usable intensities anywhere: render black (avoids inf - inf warnings).
            rgb_normalized = np.zeros(rgb.shape)
        rgb_dict[ex] = rgb_normalized

    if output_dir is None:
        return rgb_dict

    vis_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(vis_dir, exist_ok=True)

    # Side-by-side panel of all rendered excitations
    fig, axes = plt.subplots(1, len(excitations), figsize=(len(excitations) * 6, 5), squeeze=False)
    for ax, ex in zip(axes[0], excitations):
        ax.imshow(rgb_dict[ex])
        ax.set_title(f'Excitation {ex}nm')
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(vis_dir, "rgb_comparison.png"), dpi=300)
    plt.close(fig)

    # One image per excitation
    for ex, rgb in rgb_dict.items():
        single_fig, single_ax = plt.subplots(figsize=(8, 8))
        single_ax.imshow(rgb)
        single_ax.set_title(f'Excitation {ex}nm')
        single_ax.axis('off')
        single_fig.savefig(os.path.join(vis_dir, f"rgb_ex{ex}.png"), dpi=300)
        plt.close(single_fig)

    return rgb_dict

# RGB composites of the original data -> <output_dir>/visualizations/
original_rgb = create_rgb_visualization(
    all_data,
    dataset.emission_wavelengths,
    mask=dataset.processed_mask,
    output_dir=output_dir
)

# RGB composites of the reconstructions -> <output_dir>/reconstructed/visualizations/
# create_rgb_visualization uses fixed file names (rgb_comparison.png, rgb_ex<ex>.png),
# so sharing output_dir with the call above would overwrite the original-data images.
reconstructions = evaluation_results['reconstructions']
recon_rgb = create_rgb_visualization(
    reconstructions,
    dataset.emission_wavelengths,
    mask=dataset.processed_mask,
    output_dir=os.path.join(output_dir, "reconstructed")
)

# Original vs reconstruction vs difference figure for a single excitation
def create_comparison(original_data, reconstructed_data, excitation, emission_wavelengths=None, mask=None):
    """
    Save a three-panel figure: original RGB, reconstructed RGB and 5x-enhanced |difference|.

    The bands nearest 650/550/450 nm become R/G/B when ``emission_wavelengths`` is
    given. Otherwise bands at 80/50/20 % of the band axis are used. Both images share
    one min/max scaling, so they are directly comparable. The figure is written to
    ``<output_dir>/visualizations/comparison_ex<excitation>.png``.

    Args:
        original_data: [H, W, bands] array or tensor.
        reconstructed_data: [H, W, bands] array or tensor.
        excitation: Excitation label, used in the title and file name.
        emission_wavelengths: Optional per-band emission wavelengths for this excitation.
        mask: Optional [H, W] array the RGB images are multiplied by.

    Returns:
        Dict with 'original', 'reconstructed' and 'difference' RGB arrays in [0, 1]
        (NaNs replaced by 0, consistent with create_rgb_visualization).
    """
    vis_dir = os.path.join(output_dir, "visualizations")
    # Create the folder here instead of relying on another function having made it.
    os.makedirs(vis_dir, exist_ok=True)

    if isinstance(original_data, torch.Tensor):
        original_data = original_data.cpu().numpy()
    if isinstance(reconstructed_data, torch.Tensor):
        reconstructed_data = reconstructed_data.cpu().numpy()

    num_bands = original_data.shape[2]
    if emission_wavelengths is not None:
        wl = np.array(emission_wavelengths)
        r_idx, g_idx, b_idx = (np.argmin(np.abs(wl - target)) for target in (650, 550, 450))
    else:
        r_idx, g_idx, b_idx = int(num_bands * 0.8), int(num_bands * 0.5), int(num_bands * 0.2)

    rgb_original = np.stack([original_data[:, :, r_idx],
                             original_data[:, :, g_idx],
                             original_data[:, :, b_idx]], axis=2)
    rgb_recon = np.stack([reconstructed_data[:, :, r_idx],
                          reconstructed_data[:, :, g_idx],
                          reconstructed_data[:, :, b_idx]], axis=2)

    if mask is not None:
        mask_rgb = np.stack([mask, mask, mask], axis=2)
        rgb_original = rgb_original * mask_rgb
        rgb_recon = rgb_recon * mask_rgb

    # Shared scaling so brightness differences reflect the reconstruction, not the scaling
    min_val = min(np.nanmin(rgb_original), np.nanmin(rgb_recon))
    max_val = max(np.nanmax(rgb_original), np.nanmax(rgb_recon))
    scale = max_val - min_val + 1e-8

    rgb_original_norm = np.nan_to_num(np.clip((rgb_original - min_val) / scale, 0, 1), nan=0.0)
    rgb_recon_norm = np.nan_to_num(np.clip((rgb_recon - min_val) / scale, 0, 1), nan=0.0)

    diff = np.abs(rgb_original_norm - rgb_recon_norm)
    diff_enhanced = np.clip(diff * 5, 0, 1)  # amplified so small errors are visible

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, image, title in zip(
        axes,
        (rgb_original_norm, rgb_recon_norm, diff_enhanced),
        ('Original', 'Reconstructed', 'Difference (enhanced 5x)'),
    ):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(f'Reconstruction Comparison - Excitation {excitation}nm')
    fig.tight_layout()
    fig.savefig(os.path.join(vis_dir, f"comparison_ex{excitation}.png"), dpi=300)
    plt.close(fig)

    return {
        'original': rgb_original_norm,
        'reconstructed': rgb_recon_norm,
        'difference': diff_enhanced
    }

# Save a comparison figure for every excitation that has a reconstruction
for ex in all_data.keys():
    if ex not in reconstructions:
        continue
    create_comparison(
        all_data[ex],
        reconstructions[ex],
        ex,
        emission_wavelengths=dataset.emission_wavelengths.get(ex, None),
        mask=dataset.processed_mask,
    )

In [None]:
def extract_encoded_features(model, data_dict, mask=None, chunk_size=64, overlap=8, device='cuda'):
    """
    Run the model's encoder over every excitation chunk-by-chunk and stitch full-size feature maps.

    Where chunks overlap, every chunk contributes with equal weight (sum / coverage
    count). The previous version treated exact-zero features as "not yet written" and
    averaged pairwise, which biased results for sparse features and for pixels
    covered by 3 or more chunks.

    Args:
        model: Object exposing eval() and encode({excitation: tensor}). After dropping
            the batch axis, the encoding is assumed to be
            [n_features, 1, chunk_h, chunk_w]; axis 1 is squeezed (ndarray.squeeze
            raises if it is not size 1).
        data_dict: Mapping excitation -> [height, width, bands] tensor or array.
        mask: Not used by this function; kept for call compatibility.
        chunk_size, overlap: Tiling parameters forwarded to create_chunks.
        device: Device chunks are moved to before encoding.

    Returns:
        (encoded_features, spatial_shapes): excitation -> float64 array
        [n_features, height, width], and excitation -> (height, width).
        height/width are taken from the first excitation.
    """
    model.eval()

    if not data_dict:
        return {}, {}

    first_ex = next(iter(data_dict.keys()))
    height, width = data_dict[first_ex].shape[:2]

    encoded_features = {}
    spatial_shapes = {}

    with torch.no_grad():
        for ex, data in data_dict.items():
            print(f"Extracting features for excitation {ex}...")

            data_np = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
            chunks, positions = create_chunks(data_np, chunk_size, overlap)

            feature_sum = None                  # [n_features, H, W], allocated on the first chunk
            coverage = np.zeros((height, width))  # number of chunks covering each pixel

            for i, (chunk, (y_start, y_end, x_start, x_end)) in enumerate(zip(chunks, positions)):
                chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).to(device)

                encoded = model.encode({ex: chunk_tensor})
                features = encoded.cpu().numpy()[0]  # drop batch axis
                spatial_features = features.squeeze(1)  # drop the singleton axis

                if feature_sum is None:
                    feature_sum = np.zeros((spatial_features.shape[0], height, width))

                feature_sum[:, y_start:y_end, x_start:x_end] += spatial_features
                coverage[y_start:y_end, x_start:x_end] += 1

                if (i + 1) % 20 == 0 or i == len(chunks) - 1:
                    print(f"  Processed {i+1}/{len(chunks)} chunks", end="\r")

            # Mean over all covering chunks; uncovered pixels (coverage 0) stay 0.
            features_array = feature_sum / np.maximum(coverage, 1)

            encoded_features[ex] = features_array
            spatial_shapes[ex] = (height, width)

            print(f"\nExtracted {features_array.shape[0]} features for excitation {ex}")

    return encoded_features, spatial_shapes

def run_kmeans_clustering(features, n_clusters=10, random_state=42, mask=None):
    """
    Cluster pixels by their encoded feature vectors with MiniBatchKMeans.

    Args:
        features: Array [n_features, height, width].
        n_clusters: Number of k-means clusters.
        random_state: Seed passed to MiniBatchKMeans.
        mask: Optional [height, width] array. When given, only pixels with mask > 0
            are used to fit and predict, and all other pixels get label -1. When None
            (the default), every pixel is clustered, as before.

    Returns:
        (cluster_map, kmeans): integer label map [height, width] and the fitted model.

    Raises:
        ValueError: if ``mask`` does not have height * width elements.
    """
    print(f"Running K-means clustering with {n_clusters} clusters...")

    n_features, height, width = features.shape

    # One row per pixel, one column per feature
    features_reshaped = features.reshape(n_features, -1).T
    print(f"Feature matrix shape: {features_reshaped.shape}")

    if mask is None:
        valid = None
        fit_matrix = features_reshaped
    else:
        valid = np.asarray(mask).reshape(-1) > 0
        if valid.size != height * width:
            raise ValueError(f"mask has {valid.size} elements, expected {height * width}")
        # Background pixels would otherwise pull centroids towards themselves.
        fit_matrix = features_reshaped[valid]

    # MiniBatchKMeans keeps memory/time manageable for full-image pixel counts
    kmeans = MiniBatchKMeans(
        n_clusters=n_clusters,
        batch_size=1000,
        max_iter=300,
        random_state=random_state
    )

    print("Fitting K-means model...")
    fitted_labels = kmeans.fit_predict(fit_matrix)

    if valid is None:
        labels = fitted_labels
    else:
        labels = np.full(height * width, -1, dtype=fitted_labels.dtype)
        labels[valid] = fitted_labels

    cluster_map = labels.reshape(height, width)
    print(f"Clustering complete. Found {len(np.unique(fitted_labels))} unique clusters")

    return cluster_map, kmeans

In [None]:
cluster_dir = os.path.join(output_dir, "clustering")
os.makedirs(cluster_dir, exist_ok=True)

# Number of k-means clusters; the figure title reads it too, so the two cannot drift apart.
N_CLUSTERS = 5

print("Starting feature extraction...")
encoded_features, spatial_shapes = extract_encoded_features(
    model=model,
    data_dict=all_data,
    mask=dataset.processed_mask,
    chunk_size=256,
    overlap=128,
    device=device
)

# Cluster on the first excitation's features by default
excitation_to_use = list(encoded_features.keys())[0]
print(f"Using excitation {excitation_to_use} for clustering")

cluster_labels, clustering_model = run_kmeans_clustering(
    features=encoded_features[excitation_to_use],
    n_clusters=N_CLUSTERS
)

# Label masked-out pixels -1 (they were still part of the clustering above)
if dataset.processed_mask is not None:
    cluster_labels[dataset.processed_mask == 0] = -1

# Hide -1 pixels instead of painting them in a cluster colour. vmin/vmax pin cluster k
# to tab10 colour k, matching the colours used by create_cluster_overlay.
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(
    np.ma.masked_less(cluster_labels, 0),
    cmap='tab10',
    interpolation='nearest',
    vmin=0,
    vmax=max(9, N_CLUSTERS - 1),
)
fig.colorbar(im, ax=ax, label='Cluster ID')
ax.set_title(f'Pixel-wise Clustering (Ex={excitation_to_use}nm, K={N_CLUSTERS})')
ax.axis('off')
fig.savefig(os.path.join(cluster_dir, f"cluster_map_ex{excitation_to_use}.png"), dpi=300)
plt.close(fig)

np.save(os.path.join(cluster_dir, f"cluster_labels_ex{excitation_to_use}.npy"), cluster_labels)

# Visualize cluster overlay on RGB image
def create_cluster_overlay(cluster_labels, rgb_image, alpha=0.5, output_path=None):
    """
    Draw semi-transparent cluster colours over an RGB image, with a discrete colorbar.

    Cluster ids >= 0 are coloured with tab10 (the i-th sorted id gets colour i % 10).
    Pixels labelled < 0 (masked) are left fully transparent.

    Args:
        cluster_labels: 2D integer label map.
        rgb_image: Image drawn underneath, same spatial size as ``cluster_labels``.
        alpha: Opacity of the cluster colours.
        output_path: If given, the figure is saved there (dpi=300, tight bbox).

    Returns:
        The RGBA overlay array of shape (*cluster_labels.shape, 4).
    """
    from matplotlib.colors import BoundaryNorm, ListedColormap

    # Cluster ids, excluding -1 which marks masked areas
    unique_clusters = sorted(c for c in np.unique(cluster_labels) if c >= 0)
    n_clusters = len(unique_clusters)

    base_cmap = plt.colormaps['tab10']
    cluster_colors = [base_cmap(i % 10) for i in range(n_clusters)]

    overlay = np.zeros((*cluster_labels.shape, 4))
    for cluster_id, color in zip(unique_clusters, cluster_colors):
        overlay[cluster_labels == cluster_id] = (*color[:3], alpha)
    # Masked areas stay fully transparent
    overlay[cluster_labels < 0] = (0, 0, 0, 0)

    fig, ax = plt.subplots(figsize=(12, 10))
    ax.imshow(rgb_image)
    # The RGBA alpha channel already carries `alpha`; passing alpha= again would square it.
    ax.imshow(overlay)

    if n_clusters > 0:
        # Discrete colorbar whose bins line up with the colours actually used above
        norm = BoundaryNorm(np.arange(n_clusters + 1) - 0.5, n_clusters)
        sm = plt.cm.ScalarMappable(norm=norm, cmap=ListedColormap(cluster_colors))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, ticks=np.arange(n_clusters))
        cbar.set_ticklabels([f'Cluster {c}' for c in unique_clusters])

    ax.set_title('Cluster Overlay on RGB Image')
    ax.axis('off')

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    plt.close(fig)

    return overlay
# Overlay the cluster map on the original-data RGB composite of the same excitation, if one was rendered
if excitation_to_use in original_rgb:
    overlay = create_cluster_overlay(
        cluster_labels,
        original_rgb[excitation_to_use],
        alpha=0.5,
        output_path=os.path.join(cluster_dir, "cluster_overlay.png"),
    )

In [None]:
def analyze_cluster_profiles(cluster_labels, all_data, emission_wavelengths):
    """
    Plot the mean emission spectrum of every cluster for every excitation, plus cluster sizes.

    Args:
        cluster_labels: 2-D cluster label map; negative values mark masked pixels and are ignored.
        all_data: dict excitation -> numpy array or torch.Tensor. It is indexed with a 2-D
            boolean mask and ``shape[2]`` is used as the band count, so presumably
            [height, width, bands] -- TODO confirm.
        emission_wavelengths: mapping excitation -> x-axis values for that excitation's bands;
            excitations missing from it are plotted against band indices.

    Returns:
        dict cluster_id -> {excitation -> {'mean', 'std', 'count'}}; mean/std are taken over
        the cluster's pixels ignoring NaNs, count is the number of pixels in the cluster.

    Side effects:
        Saves ``cluster_profiles.png`` and ``cluster_sizes.png`` under
        ``<output_dir>/clustering`` (the directory is created if needed).
    """
    cluster_dir = os.path.join(output_dir, "clustering")
    # Don't rely on an earlier cell having created the directory
    os.makedirs(cluster_dir, exist_ok=True)

    # Get unique clusters (excluding negative IDs, which mark masked areas)
    unique_clusters = sorted([c for c in np.unique(cluster_labels) if c >= 0])
    print(f"Analyzing profiles for {len(unique_clusters)} clusters")

    # One marker per excitation (cycled) so curves of different excitations can be told apart
    markers = ['o', 's', '^', 'D', 'v']

    cluster_stats = {}
    fig, ax = plt.subplots(figsize=(12, 8))

    for i, ex in enumerate(all_data.keys()):
        data = all_data[ex]
        if isinstance(data, torch.Tensor):
            data = data.cpu().numpy()

        if ex in emission_wavelengths:
            wavelengths = emission_wavelengths[ex]
        else:
            wavelengths = np.arange(data.shape[2])
        marker = markers[i % len(markers)]

        for cluster_id in unique_clusters:
            mask = cluster_labels == cluster_id
            if not np.any(mask):
                continue

            cluster_data = data[mask]
            mean_spectrum = np.nanmean(cluster_data, axis=0)
            std_spectrum = np.nanstd(cluster_data, axis=0)

            cluster_stats.setdefault(cluster_id, {})[ex] = {
                'mean': mean_spectrum,
                'std': std_spectrum,
                'count': np.sum(mask),
            }

            # Only the first excitation's curves are labelled -> one legend entry per cluster
            label_kwargs = {'label': f"Cluster {cluster_id}"} if i == 0 else {}
            ax.plot(wavelengths, mean_spectrum, marker=marker,
                    color=plt.cm.tab10(cluster_id % 10), **label_kwargs)

    ax.set_xlabel('Emission Wavelength (nm)' if len(emission_wavelengths) > 0 else 'Emission Band Index')
    ax.set_ylabel('Intensity')
    ax.set_title('Spectral Profiles by Cluster')
    ax.grid(True, alpha=0.3)
    if unique_clusters:
        ax.legend()  # avoid the "no artists with labels" warning when nothing was plotted
    fig.savefig(os.path.join(cluster_dir, "cluster_profiles.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Bar chart of cluster sizes (pixel counts)
    cluster_sizes = [np.sum(cluster_labels == c) for c in unique_clusters]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(
        [f"Cluster {c}" for c in unique_clusters],
        cluster_sizes,
        color=[plt.cm.tab10(c % 10) for c in unique_clusters]
    )
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Number of Pixels')
    ax.set_title('Cluster Sizes')
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(True, axis='y', alpha=0.3)
    fig.savefig(os.path.join(cluster_dir, "cluster_sizes.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

    return cluster_stats

# Summarise the spectra of every cluster found above (arguments in signature order)
cluster_stats = analyze_cluster_profiles(
    cluster_labels,
    all_data,
    dataset.emission_wavelengths,
)

print("Pipeline complete! All results saved to", output_dir)

In [None]:
# APPROACH 2: FEATURE CONCATENATION (preserves more information but increases dimensions)
# NOTE(review): `encoded_features` is produced by the feature-extraction cells further down;
# this cell only works if one of those has already been executed.
print(f"Concatenating features from all {len(encoded_features)} excitations...")
all_excitation_features = np.concatenate([encoded_features[ex] for ex in encoded_features.keys()], axis=0)
print(f"Concatenated feature shape: {all_excitation_features.shape}")

# Run clustering on concatenated features
cluster_labels, clustering_model = run_kmeans_clustering(
    features=all_excitation_features,
    n_clusters=10
)

# Mark pixels outside the processed mask as -1 (as the other clustering cells do) so the
# profile analysis that follows does not mix masked pixels into the clusters.
if dataset.processed_mask is not None:
    cluster_labels[dataset.processed_mask == 0] = -1

In [None]:
# Per-cluster spectral summary for the most recent cluster_labels
cluster_stats = analyze_cluster_profiles(
    cluster_labels,
    all_data,
    dataset.emission_wavelengths,
)

In [None]:
# Extract features
cluster_dir = os.path.join(output_dir, "clustering")
os.makedirs(cluster_dir, exist_ok=True)

# K used for K-means below; also drives the plot title and colour range
n_clusters_all = 10

print("Starting feature extraction...")
encoded_features, spatial_shapes = extract_encoded_features(
    model=model,
    data_dict=all_data,
    mask=dataset.processed_mask,
    chunk_size=256,
    overlap=128,
    device=device
)

# APPROACH 1: FEATURE AVERAGING (combines excitations while keeping the same dimensionality)
print(f"Combining features from all {len(encoded_features)} excitations...")
all_excitation_features = np.stack([encoded_features[ex] for ex in encoded_features.keys()])
combined_features = np.mean(all_excitation_features, axis=0)
print(f"Combined feature shape: {combined_features.shape}")

# Run clustering on combined features from ALL excitations
print("Running clustering on features from all excitations...")
cluster_labels, clustering_model = run_kmeans_clustering(
    features=combined_features,
    n_clusters=n_clusters_all
)

# Apply mask to cluster labels (-1 = outside the processed mask)
if dataset.processed_mask is not None:
    cluster_labels[dataset.processed_mask == 0] = -1

# Visualize cluster map. Masked pixels are hidden and the colour range is pinned to
# [-0.5, K - 0.5]. Auto-scaling over -1..K-1 would give the masked area a cluster colour
# and can map the two highest cluster IDs to the same tab10 colour.
plt.figure(figsize=(10, 8))
plt.imshow(np.ma.masked_less(cluster_labels, 0), cmap='tab10',
           vmin=-0.5, vmax=n_clusters_all - 0.5, interpolation='nearest')
plt.colorbar(label='Cluster ID', ticks=np.arange(n_clusters_all))
plt.title(f'Pixel-wise Clustering (ALL Excitations, K={n_clusters_all})')
plt.axis('off')
plt.savefig(os.path.join(cluster_dir, "cluster_map_all_excitations.png"), dpi=300)
plt.close()

# Save cluster labels
np.save(os.path.join(cluster_dir, "cluster_labels_all_excitations.npy"), cluster_labels)

# Create cluster overlay (choose one excitation for background visualization only)
excitation_for_viz = list(original_rgb.keys())[0]
overlay = create_cluster_overlay(
    cluster_labels=cluster_labels,
    rgb_image=original_rgb[excitation_for_viz],
    alpha=0.5,
    output_path=os.path.join(cluster_dir, "cluster_overlay_all_excitations.png")
)

In [None]:
# Extract features
cluster_dir = os.path.join(output_dir, "clustering")
os.makedirs(cluster_dir, exist_ok=True)

# K used for K-means below; also drives the plot title and colour range
n_clusters_concat = 10

print("Starting feature extraction...")
encoded_features, spatial_shapes = extract_encoded_features(
    model=model,
    data_dict=all_data,
    mask=dataset.processed_mask,
    chunk_size=256,
    overlap=128,
    device=device
)

# APPROACH 2: FEATURE CONCATENATION (preserves more information but increases dimensions)
print(f"Concatenating features from all {len(encoded_features)} excitations...")
all_excitation_features = np.concatenate([encoded_features[ex] for ex in encoded_features.keys()], axis=0)
print(f"Concatenated feature shape: {all_excitation_features.shape}")

# Run clustering on concatenated features
cluster_labels, clustering_model = run_kmeans_clustering(
    features=all_excitation_features,
    n_clusters=n_clusters_concat
)

# Apply mask to cluster labels (-1 = outside the processed mask)
if dataset.processed_mask is not None:
    cluster_labels[dataset.processed_mask == 0] = -1

# Visualize cluster map. Masked pixels are hidden and the colour range is pinned so every
# cluster ID keeps its own tab10 colour. Auto-scaling over -1..K-1 can merge the last two.
plt.figure(figsize=(10, 8))
plt.imshow(np.ma.masked_less(cluster_labels, 0), cmap='tab10',
           vmin=-0.5, vmax=n_clusters_concat - 0.5, interpolation='nearest')
plt.colorbar(label='Cluster ID', ticks=np.arange(n_clusters_concat))
plt.title(f'Pixel-wise Clustering (ALL Excitations, concatenated, K={n_clusters_concat})')
plt.axis('off')
plt.savefig(os.path.join(cluster_dir, "cluster_map_all_excitations_concat.png"), dpi=300)
plt.close()

# Save cluster labels (file name kept unchanged for existing consumers)
np.save(os.path.join(cluster_dir, "cluster_map_all_excitations_concat.npy"), cluster_labels)

# Create cluster overlay (choose one excitation for background visualization only).
# It gets its own file name; it previously reused the cluster-map PNG name and overwrote it.
excitation_for_viz = list(original_rgb.keys())[0]
overlay = create_cluster_overlay(
    cluster_labels=cluster_labels,
    rgb_image=original_rgb[excitation_for_viz],
    alpha=0.5,
    output_path=os.path.join(cluster_dir, "cluster_overlay_all_excitations_concat.png")
)

In [None]:
# Add after clustering in both approaches to calculate validation metrics
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
import numpy as np
from scipy import ndimage

def evaluate_clustering(features, cluster_labels, mask=None, silhouette_sample_size=None, random_state=None):
    """Calculate clustering quality metrics.

    Args:
        features: [n_features, height, width] feature cube, or an already-flattened
            [n_pixels, n_features] matrix.
        cluster_labels: [height, width] label map or flat [n_pixels] labels; negative labels
            mark masked pixels and are excluded.
        mask: optional array with as many elements as ``cluster_labels``; pixels where
            ``mask <= 0`` are excluded as well. Previously this argument was accepted but ignored.
        silhouette_sample_size: if given, the silhouette score is computed on this many
            randomly sampled pixels. The full score scales quadratically with the pixel count.
        random_state: seed for the silhouette sub-sampling.

    Returns:
        dict with ``silhouette_score``, ``davies_bouldin_score``, ``calinski_harabasz_score``
        and, for 2-D label maps, ``spatial_coherence``. Returns
        ``{"error": "Not enough clusters for evaluation"}`` when fewer than two clusters remain.

    Raises:
        ValueError: if ``mask`` does not have the same number of elements as the labels.
    """
    # Flatten a [n_features, height, width] cube to [n_pixels, n_features]
    if features.ndim == 3:
        n_features = features.shape[0]
        features_flat = features.reshape(n_features, -1).T
    else:
        features_flat = features

    labels_flat = cluster_labels.reshape(-1) if cluster_labels.ndim == 2 else cluster_labels

    # Valid pixels: non-negative label and, when a mask is given, mask > 0
    valid = labels_flat >= 0
    if mask is not None:
        mask_flat = np.asarray(mask).reshape(-1)
        if mask_flat.shape != labels_flat.shape:
            raise ValueError(
                f"mask has {mask_flat.size} elements but cluster_labels has {labels_flat.size}")
        valid = valid & (mask_flat > 0)

    valid_features = features_flat[valid]
    valid_labels = labels_flat[valid]

    # The metrics below need at least two clusters
    if len(np.unique(valid_labels)) <= 1:
        return {"error": "Not enough clusters for evaluation"}

    metrics = {
        "silhouette_score": silhouette_score(valid_features, valid_labels,
                                             sample_size=silhouette_sample_size,
                                             random_state=random_state),
        "davies_bouldin_score": davies_bouldin_score(valid_features, valid_labels),
        "calinski_harabasz_score": calinski_harabasz_score(valid_features, valid_labels),
    }

    # Spatial coherence: mean (over clusters) of the average fraction of a pixel's 8 neighbours
    # sharing its label. Out-of-image neighbours count as non-matching (cval=0).
    if cluster_labels.ndim == 2:
        label_map = np.where(valid.reshape(cluster_labels.shape), cluster_labels, -1)
        kernel = np.ones((3, 3), dtype=np.int32)
        kernel[1, 1] = 0  # exclude the centre pixel itself
        present_clusters = np.unique(label_map[label_map >= 0])
        ratios = []
        for cluster_id in present_clusters:
            in_cluster = label_map == cluster_id
            neighbor_count = ndimage.convolve(in_cluster.astype(np.int32), kernel, mode='constant', cval=0)
            ratios.append(np.mean(neighbor_count[in_cluster] / 8.0))
        metrics["spatial_coherence"] = sum(ratios) / len(ratios)

    return metrics

# Calculate metrics for the concatenation approach.
# NOTE: this uses whichever `all_excitation_features` / `cluster_labels` were produced last.
# The averaging cell leaves the np.stack-ed array (one dimension more than the concatenated
# cube) in `all_excitation_features`, which evaluate_clustering cannot flatten.
assert all_excitation_features.ndim == 3, (
    "expected concatenated [n_features, height, width] features - "
    "run the feature-concatenation cell before this one"
)
concat_metrics = evaluate_clustering(all_excitation_features, cluster_labels, dataset.processed_mask)
print("Metrics for feature concatenation approach:")
for metric, value in concat_metrics.items():
    # On failure the dict holds an "error" message string, which cannot be formatted as a float
    if isinstance(value, (int, float, np.number)):
        print(f"  {metric}: {value:.4f}")
    else:
        print(f"  {metric}: {value}")

# Save metrics; default=float converts numpy scalars json cannot serialise natively
import json
with open(os.path.join(cluster_dir, "concat_metrics.json"), "w") as f:
    json.dump(concat_metrics, f, indent=2, default=float)

# Create comparison visualization
fig, ax = plt.subplots(figsize=(10, 6))
metrics_to_plot = ["silhouette_score", "spatial_coherence"]
colors = ['#3498db', '#2ecc71']
for i, metric in enumerate(metrics_to_plot):
    if metric in concat_metrics:
        ax.bar(i, concat_metrics[metric], color=colors[i])
ax.set_xticks(range(len(metrics_to_plot)))
ax.set_xticklabels(metrics_to_plot)
ax.set_ylabel('Score')
ax.set_title('Clustering Quality Metrics')
fig.savefig(os.path.join(cluster_dir, "clustering_metrics.png"), dpi=300)
# Figure intentionally left open so it renders inline in the notebook

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from scipy import ndimage
import pandas as pd

# Dedicated output folder for the comparison figures and tables
comparison_dir = os.path.join(output_dir, "clustering_comparison")
os.makedirs(comparison_dir, exist_ok=True)

# Use whichever excitations the encoder actually produced features for (not hard-coded values)
available_excitations = list(encoded_features.keys())
print(f"Available excitations: {available_excitations}")

# Compare at most 4 excitations; with more available, take every `step`-th one from the start
if len(available_excitations) > 4:
    step = len(available_excitations) // 4
    excitations_to_compare = available_excitations[::step][:4]
else:
    excitations_to_compare = available_excitations

print(f"Using excitations for comparison: {excitations_to_compare}")

# Function to evaluate clustering quality
def evaluate_clustering(features, cluster_labels, mask=None, silhouette_sample_size=None, random_state=None):
    """Calculate clustering quality metrics.

    Args:
        features: [n_features, height, width] feature cube, or an already-flattened
            [n_pixels, n_features] matrix.
        cluster_labels: [height, width] label map or flat [n_pixels] labels; negative labels
            mark masked pixels and are excluded.
        mask: optional array with as many elements as ``cluster_labels``; pixels where
            ``mask <= 0`` are excluded as well. Previously this argument was accepted but ignored.
        silhouette_sample_size: if given, the silhouette score is computed on this many
            randomly sampled pixels. The full score scales quadratically with the pixel count.
        random_state: seed for the silhouette sub-sampling.

    Returns:
        dict with ``silhouette_score``, ``davies_bouldin_score``, ``calinski_harabasz_score``
        and, for 2-D label maps, ``spatial_coherence``. Returns ``{"error": ...}`` when fewer
        than two clusters remain or a metric computation fails (the error is also printed).

    Raises:
        ValueError: if ``mask`` does not have the same number of elements as the labels.
    """
    # Flatten a [n_features, height, width] cube to [n_pixels, n_features]
    if features.ndim == 3:
        n_features = features.shape[0]
        features_flat = features.reshape(n_features, -1).T
    else:
        features_flat = features

    labels_flat = cluster_labels.reshape(-1) if cluster_labels.ndim == 2 else cluster_labels

    # Valid pixels: non-negative label and, when a mask is given, mask > 0
    valid = labels_flat >= 0
    if mask is not None:
        mask_flat = np.asarray(mask).reshape(-1)
        if mask_flat.shape != labels_flat.shape:
            raise ValueError(
                f"mask has {mask_flat.size} elements but cluster_labels has {labels_flat.size}")
        valid = valid & (mask_flat > 0)

    valid_features = features_flat[valid]
    valid_labels = labels_flat[valid]

    # The metrics below need at least two clusters
    if len(np.unique(valid_labels)) <= 1:
        return {"error": "Not enough clusters for evaluation"}

    # Deliberate best-effort: a failing metric is reported rather than aborting the comparison
    metrics = {}
    try:
        metrics["silhouette_score"] = silhouette_score(valid_features, valid_labels,
                                                       sample_size=silhouette_sample_size,
                                                       random_state=random_state)
        metrics["davies_bouldin_score"] = davies_bouldin_score(valid_features, valid_labels)
        metrics["calinski_harabasz_score"] = calinski_harabasz_score(valid_features, valid_labels)
    except Exception as e:
        print(f"Error calculating metrics: {str(e)}")
        return {"error": str(e)}

    # Spatial coherence: mean (over clusters) of the average fraction of a pixel's 8 neighbours
    # sharing its label. Out-of-image neighbours count as non-matching (cval=0).
    if cluster_labels.ndim == 2:
        label_map = np.where(valid.reshape(cluster_labels.shape), cluster_labels, -1)
        kernel = np.ones((3, 3), dtype=np.int32)
        kernel[1, 1] = 0  # exclude the centre pixel itself
        present_clusters = np.unique(label_map[label_map >= 0])
        ratios = []
        for cluster_id in present_clusters:
            in_cluster = label_map == cluster_id
            neighbor_count = ndimage.convolve(in_cluster.astype(np.int32), kernel, mode='constant', cval=0)
            ratios.append(np.mean(neighbor_count[in_cluster] / 8.0))
        metrics["spatial_coherence"] = sum(ratios) / len(ratios)

    return metrics

# Store all clustering results and metrics, keyed by 'combined' or by excitation
all_results = {}

# First, ensure we have the combined clustering result
print("1. Processing combined 4D clustering result...")
all_results['combined'] = {
    'features': all_excitation_features,
    'labels': cluster_labels,
    'metrics': evaluate_clustering(all_excitation_features, cluster_labels, dataset.processed_mask),
    'name': 'Combined 4D'
}

# Save the combined overlay separately. create_cluster_overlay creates and closes its own
# figure, so no surrounding plt.figure()/plt.close() pair is needed.
excitation_for_viz = list(original_rgb.keys())[0]
overlay = create_cluster_overlay(
    cluster_labels=cluster_labels,
    rgb_image=original_rgb[excitation_for_viz],
    alpha=0.5,
    output_path=os.path.join(comparison_dir, "overlay_combined.png")
)

# 2. Run clustering on individual excitations
n_clusters_per_ex = 10  # same K as the combined clustering
print("2. Processing individual excitation wavelengths...")
for ex in excitations_to_compare:
    print(f"  Processing excitation {ex}...")

    # Run clustering on this excitation's features
    ex_cluster_labels, _ = run_kmeans_clustering(
        features=encoded_features[ex],
        n_clusters=n_clusters_per_ex
    )

    # Apply mask (-1 = outside the processed mask)
    if dataset.processed_mask is not None:
        ex_cluster_labels[dataset.processed_mask == 0] = -1

    # Same mask as the combined result so every method is scored on the same pixels
    ex_metrics = evaluate_clustering(encoded_features[ex], ex_cluster_labels, dataset.processed_mask)

    all_results[ex] = {
        'features': encoded_features[ex],
        'labels': ex_cluster_labels,
        'metrics': ex_metrics,
        'name': f'Excitation {ex}'
    }

    # Cluster map: masked pixels hidden, colour range pinned so each ID keeps its own colour
    plt.figure(figsize=(10, 8))
    plt.imshow(np.ma.masked_less(ex_cluster_labels, 0), cmap='tab10',
               vmin=-0.5, vmax=n_clusters_per_ex - 0.5, interpolation='nearest')
    plt.colorbar(label='Cluster ID', ticks=np.arange(n_clusters_per_ex))
    plt.title(f'Clustering on Excitation {ex}')
    plt.axis('off')
    plt.savefig(os.path.join(comparison_dir, f"cluster_map_ex{ex}.png"), dpi=300)
    plt.close()

    # Individual overlay, on the same background image for consistency across methods
    overlay = create_cluster_overlay(
        cluster_labels=ex_cluster_labels,
        rgb_image=original_rgb[excitation_for_viz],
        alpha=0.5,
        output_path=os.path.join(comparison_dir, f"overlay_ex{ex}.png")
    )

# 3. Create side-by-side comparison of all overlays (this is your main poster visual)
print("3. Creating side-by-side comparison visualization...")
background = original_rgb[excitation_for_viz]
n_methods = len(all_results)
# squeeze=False always yields a 2-D array of axes, so a single method needs no special case
fig, axes = plt.subplots(1, n_methods, figsize=(n_methods*5, 5), squeeze=False)

for ax, result in zip(axes[0], all_results.values()):
    overlay = create_cluster_overlay(
        cluster_labels=result['labels'],
        rgb_image=background,
        alpha=0.5
    )
    # The overlay is semi-transparent with transparent masked pixels. Draw the RGB image
    # first so the panel is really an overlay; on its own it would appear over white.
    ax.imshow(background)
    ax.imshow(overlay)
    ax.set_title(result['name'])
    ax.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(comparison_dir, "all_overlays_comparison.png"), dpi=300)
plt.close(fig)

# 4. Create metrics comparison table and visualization
print("4. Creating metrics comparison...")
# One record per method: display name first, then whatever metrics (or error) it produced
metrics_table = [
    {'Method': result['name'], **result['metrics']}
    for result in all_results.values()
]
metrics_df = pd.DataFrame(metrics_table)

# Persist the table for the poster / later inspection
metrics_df.to_csv(os.path.join(comparison_dir, "clustering_metrics.csv"), index=False)

# Create normalized metrics plot (perfect for poster)
fig, ax = plt.subplots(figsize=(12, 6))

# Define metrics and their interpretation
metrics_info = {
    "silhouette_score": {"name": "Silhouette Score", "higher_better": True, "color": "#3498db"},
    "davies_bouldin_score": {"name": "Davies-Bouldin Index", "higher_better": False, "color": "#e74c3c"},
    "spatial_coherence": {"name": "Spatial Coherence", "higher_better": True, "color": "#2ecc71"}
}

# Set up for grouped bar chart
methods = list(all_results.keys())
x = np.arange(len(methods))
width = 0.25

# Scale every metric relative to its largest value (lower-is-better metrics are inverted)
normalized_scores = {}
for metric, info in metrics_info.items():
    if metric not in metrics_df.columns:
        continue  # Skip metrics that we don't have

    # A method whose evaluation failed has no value: use NaN (bar omitted) instead of 0.
    # After the lower-is-better inversion, 0 would have plotted as the *best* score.
    values = np.array([all_results[m]['metrics'].get(metric, np.nan) for m in methods], dtype=float)
    if np.isnan(values).all():
        continue

    max_val = np.nanmax(values)
    if max_val <= 0:
        max_val = 1
    if info["higher_better"]:
        normalized_scores[metric] = values / max_val
    else:
        normalized_scores[metric] = 1 - values / max_val

# Plot each metric as one bar per method, side by side within each method's group
plotted_metrics = [metric for metric in metrics_info if metric in normalized_scores]
for multiplier, metric in enumerate(plotted_metrics):
    info = metrics_info[metric]
    ax.bar(x + width * multiplier, normalized_scores[metric], width, label=info["name"], color=info["color"])

# Add labels and legend
ax.set_xlabel('Clustering Method')
ax.set_ylabel('Normalized Score (higher is better)')
ax.set_title('Clustering Quality Metrics Comparison')
# Centre each method's label under its group, whatever the number of metrics drawn
ax.set_xticks(x + width * (len(plotted_metrics) - 1) / 2)
ax.set_xticklabels([all_results[m]['name'] for m in methods])
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
ax.grid(axis='y', alpha=0.3)

fig.tight_layout()
fig.savefig(os.path.join(comparison_dir, "normalized_metrics.png"), dpi=300)
plt.close(fig)

# 5. Create individual plots with metrics text underneath
print("5. Creating individual visualizations with metrics...")
background = original_rgb[excitation_for_viz]
for key, result in all_results.items():
    # The overlay is deterministic for given labels, so build it once and reuse it below
    overlay = create_cluster_overlay(
        cluster_labels=result['labels'],
        rgb_image=background,
        alpha=0.5
    )

    # Overlay in the top two thirds, metrics text in the bottom third
    fig = plt.figure(figsize=(8, 10))
    grid = fig.add_gridspec(3, 1)
    ax1 = fig.add_subplot(grid[0:2, 0])
    # Draw the RGB image under the semi-transparent overlay (masked pixels are transparent)
    ax1.imshow(background)
    ax1.imshow(overlay)
    ax1.set_title(f"Cluster Overlay: {result['name']}")
    ax1.axis('off')

    ax2 = fig.add_subplot(grid[2, 0])
    ax2.axis('off')

    # Format metrics text
    metrics_text = "Metrics:\n"
    if isinstance(result['metrics'], dict) and 'error' not in result['metrics']:
        for metric, value in result['metrics'].items():
            # Format the metric name to be more readable
            nice_name = metric.replace('_', ' ').title()
            metrics_text += f"{nice_name}: {value:.4f}\n"
    else:
        metrics_text += "Metrics calculation failed"

    ax2.text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=12)

    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, f"overlay_with_metrics_{key}.png"), dpi=300)
    plt.close(fig)

    # Also save without metrics for flexibility
    fig_only, ax_only = plt.subplots(figsize=(8, 6))
    ax_only.imshow(background)
    ax_only.imshow(overlay)
    ax_only.set_title(f"Cluster Overlay: {result['name']}")
    ax_only.axis('off')
    fig_only.tight_layout()
    fig_only.savefig(os.path.join(comparison_dir, f"overlay_only_{key}.png"), dpi=300)
    plt.close(fig_only)

print(f"All comparison visualizations saved to: {comparison_dir}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import NMF
import os
import pandas as pd
import torch  # Added this import that was missing

# Output folder for the NMF baseline that the clustering results are compared against
nmf_dir = os.path.join(output_dir, "nmf_comparison")
os.makedirs(nmf_dir, exist_ok=True)
print("Running Non-negative Matrix Factorization (NMF) for comparison...")

# 1. Prepare the data for NMF
def prepare_hyperspectral_data_for_nmf(all_data, mask=None):
    """
    Flatten every excitation cube to [pixels, bands] and place them side by side for NMF.

    Args:
        all_data: dict excitation -> numpy array or torch.Tensor unpacked as
            (height, width, bands); the spatial size is taken from the first entry.
        mask: optional 2-D array; only pixels with ``mask > 0`` are kept.

    Returns:
        X: [n_kept_pixels, total_bands] matrix, excitations concatenated along the columns.
        pixel_indices: flat (row-major) indices of the kept pixels.
        (height, width): spatial size of the first excitation.
        wavelength_labels: one ``"Ex{ex}_Em{band}"`` label per column of X.
    """
    def _as_numpy(value):
        # Tensors (possibly on the GPU) are copied to host memory first
        return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

    cubes = {ex: _as_numpy(cube) for ex, cube in all_data.items()}
    height, width, _ = next(iter(cubes.values())).shape
    n_pixels = height * width

    column_blocks = []
    wavelength_labels = []
    for ex, cube in cubes.items():
        _, _, n_bands = cube.shape
        column_blocks.append(cube.reshape(n_pixels, n_bands))
        wavelength_labels.extend(f"Ex{ex}_Em{b}" for b in range(n_bands))

    X = np.hstack(column_blocks)

    if mask is None:
        pixel_indices = np.arange(n_pixels)
    else:
        keep = mask.flatten() > 0
        X = X[keep]
        pixel_indices = np.where(keep)[0]

    return X, pixel_indices, (height, width), wavelength_labels

# 2. Run NMF with same number of components as clusters
n_components = 10  # Same as n_clusters used in K-means

# Prepare data
print("Preparing data for NMF...")
X, pixel_indices, spatial_dims, wavelength_labels = prepare_hyperspectral_data_for_nmf(
    all_data, dataset.processed_mask)
height, width = spatial_dims

print(f"Data shape for NMF: {X.shape}")

# Initialize and fit NMF model. It is bound to `nmf_model`, not `model`, so the trained
# autoencoder in `model` is not silently replaced; re-running the feature-extraction cells
# after this one would otherwise hand an NMF object to extract_encoded_features.
print(f"Fitting NMF with {n_components} components...")
nmf_model = NMF(
    n_components=n_components,
    init='random',
    random_state=42,
    max_iter=200,
    solver='cd',  # Coordinate descent usually works well for this
)

# Fit model and get components (endmembers) and coefficients (abundances)
W = nmf_model.fit_transform(X)  # abundance coefficients [pixels, components]
H = nmf_model.components_       # endmembers [components, wavelengths]

print("NMF model fit complete.")

# 3. Create abundance maps: scatter the per-pixel abundances back onto the (row-major) image
# grid in one vectorized step; pixels outside the mask stay 0.
print("Creating abundance maps...")
abundance_flat = np.zeros((height * width, n_components))
abundance_flat[pixel_indices] = W
abundance_maps = abundance_flat.reshape(height, width, n_components)

# 4. Visualize abundance maps in a grid sized from n_components (at most 5 per row)
print("Visualizing abundance maps...")
n_cols = max(1, min(5, n_components))
n_rows = max(1, int(np.ceil(n_components / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for i in range(n_components):
    ax = axes[i]
    im = ax.imshow(abundance_maps[:, :, i], cmap='viridis')
    ax.set_title(f'Component {i+1}')
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(im, ax=ax)

# Hide grid cells left over when n_components is not a multiple of n_cols
for ax in axes[n_components:]:
    ax.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(nmf_dir, "abundance_maps.png"), dpi=300)
plt.close(fig)

# 5. Visualize endmember spectra: one curve per NMF component (row of H)
print("Visualizing endmember spectra...")
fig, ax = plt.subplots(figsize=(12, 6))

for comp_idx in range(n_components):
    ax.plot(H[comp_idx], label=f'Component {comp_idx+1}')

ax.set_xlabel('Wavelength Index')
ax.set_ylabel('Intensity')
ax.set_title('NMF Endmember Spectra')
ax.legend()
ax.grid(alpha=0.3)
fig.savefig(os.path.join(nmf_dir, "endmember_spectra.png"), dpi=300)
plt.close(fig)

# 6. Compare with clustering by creating dominant component map
print("Creating dominant component map...")
dominant_component = np.argmax(abundance_maps, axis=2)

# Pixels outside the mask have all-zero abundances, so argmax would call them component 0.
# Mark them -1 (the convention used for masked cluster labels) so they are not confused with
# component 0 in the map, the overlay, or the comparison with K-means.
is_valid_pixel = np.zeros(dominant_component.size, dtype=bool)
is_valid_pixel[pixel_indices] = True
dominant_component[~is_valid_pixel.reshape(dominant_component.shape)] = -1

# Masked pixels hidden; colour range pinned so each component keeps its own tab10 colour
plt.figure(figsize=(10, 8))
plt.imshow(np.ma.masked_less(dominant_component, 0), cmap='tab10',
           vmin=-0.5, vmax=n_components - 0.5, interpolation='nearest')
plt.colorbar(label='Dominant Component', ticks=np.arange(n_components))
plt.title('Dominant NMF Component Map')
plt.axis('off')
plt.savefig(os.path.join(nmf_dir, "dominant_component_map.png"), dpi=300)
plt.close()

# 7. Create RGB overlay for dominant component
print("Creating RGB overlay for dominant component...")
excitation_for_viz = list(original_rgb)[0]
dominant_overlay_path = os.path.join(nmf_dir, "dominant_component_overlay.png")
overlay = create_cluster_overlay(
    cluster_labels=dominant_component,
    rgb_image=original_rgb[excitation_for_viz],
    alpha=0.5,
    output_path=dominant_overlay_path,
)

# 8. Quantitatively compare NMF dominant components with clustering
print("Comparing NMF with clustering results...")

# Compute overlap between dominant NMF components and cluster labels
def compute_component_cluster_overlap(components, clusters, n_components, n_clusters):
    """
    Cross-tabulate NMF dominant components against cluster assignments.

    Args:
        components: integer label map of dominant NMF components; negative values = masked.
        clusters: integer cluster label map with the same number of elements; negative = masked.
        n_components: number of rows of the table.
        n_clusters: number of columns of the table.

    Returns:
        confusion: [n_components, n_clusters] float array of pixel counts.
        normalized_confusion: ``confusion`` divided by each cluster's (column's) total, so
            every non-empty column sums to 1.

    Raises:
        IndexError: if a non-negative label is >= n_components / n_clusters.
    """
    components_flat = np.asarray(components).reshape(-1)
    clusters_flat = np.asarray(clusters).reshape(-1)

    # Only pixels labelled in BOTH maps. A negative component would otherwise wrap around to
    # the last row through negative indexing.
    valid = (clusters_flat >= 0) & (components_flat >= 0)

    confusion = np.zeros((n_components, n_clusters))
    # Vectorized equivalent of `confusion[component, cluster] += 1` for every pixel;
    # np.add.at accumulates repeated index pairs correctly.
    np.add.at(confusion, (components_flat[valid], clusters_flat[valid]), 1)

    # Normalize by cluster size (epsilon keeps empty clusters at 0 instead of NaN)
    cluster_sizes = confusion.sum(axis=0)
    normalized_confusion = confusion / (cluster_sizes + 1e-10)

    return confusion, normalized_confusion

# Number of K-means clusters, taken from the labels so it follows whatever K was used
# (instead of repeating the literal 10 in two places)
n_kmeans_clusters = max(int(cluster_labels.max()) + 1, 1)

confusion, normalized_confusion = compute_component_cluster_overlap(
    dominant_component, cluster_labels, n_components, n_kmeans_clusters)

# Visualize confusion matrix (columns normalised by cluster size)
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(normalized_confusion, cmap='Blues', interpolation='nearest')
fig.colorbar(im, ax=ax, label='Overlap Ratio')
ax.set_xlabel('Cluster ID')
ax.set_ylabel('NMF Component ID')
ax.set_title('NMF Component vs K-means Cluster Overlap')
for i in range(n_components):
    for j in range(n_kmeans_clusters):
        ax.text(j, i, f'{normalized_confusion[i, j]:.2f}',
                ha='center', va='center', color='white' if normalized_confusion[i, j] > 0.3 else 'black')
fig.tight_layout()
fig.savefig(os.path.join(nmf_dir, "component_cluster_overlap.png"), dpi=300)
plt.close(fig)

# 9. Calculate overall agreement (assigned to same group)
print("Calculating overall agreement...")

# Map each NMF component to the cluster with its largest column-normalised overlap
# (row-wise argmax). Several components may map to the same cluster.
component_to_cluster = np.argmax(normalized_confusion, axis=1)

# Relabel the dominant-component map with the matched cluster IDs
remapped_component = np.zeros_like(dominant_component)
for comp_idx, target_cluster in enumerate(component_to_cluster):
    remapped_component[dominant_component == comp_idx] = target_cluster

# Fraction of clustered (non-masked) pixels where both methods agree
valid_mask = cluster_labels >= 0
agreement = np.mean(remapped_component[valid_mask] == cluster_labels[valid_mask])

print(f"Overall agreement between NMF and clustering: {agreement:.4f} ({agreement*100:.1f}%)")

# 10. Side-by-side comparison of clustering and NMF.
# Negative (masked) labels are hidden and each map's colour range is pinned to
# [-0.5, K - 0.5]. With auto-scaling over -1..K-1, tab10 paints the masked area like a
# cluster and can merge neighbouring IDs into one colour.
n_kmeans_groups = max(int(cluster_labels.max()) + 1, 1)

fig, (ax_kmeans, ax_nmf) = plt.subplots(1, 2, figsize=(16, 8))

ax_kmeans.imshow(np.ma.masked_less(cluster_labels, 0), cmap='tab10',
                 vmin=-0.5, vmax=n_kmeans_groups - 0.5, interpolation='nearest')
ax_kmeans.set_title('K-means Clustering')
ax_kmeans.axis('off')

ax_nmf.imshow(np.ma.masked_less(dominant_component, 0), cmap='tab10',
              vmin=-0.5, vmax=n_components - 0.5, interpolation='nearest')
ax_nmf.set_title('NMF Dominant Component')
ax_nmf.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(nmf_dir, "clustering_vs_nmf.png"), dpi=300)
plt.close(fig)

# 11. Summary table for poster: one column per property, one row per method
agreement_text = f'{agreement*100:.1f}%'
summary = {
    'Method': ['K-means Clustering', 'NMF Decomposition'],
    'Approach': ['Hard assignment to clusters', 'Soft assignment (abundance weights)'],
    'Number of Groups': [10, n_components],
    'Features Used': ['Autoencoder latent space', 'Raw spectral data'],
    'Agreement': ['-', agreement_text],
}

summary_csv_path = os.path.join(nmf_dir, "method_comparison.csv")
summary_df = pd.DataFrame(summary)
summary_df.to_csv(summary_csv_path, index=False)

print(f"NMF analysis complete. Results saved to: {nmf_dir}")

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from scipy import ndimage
import pandas as pd

# Create a dedicated folder for spectral profiles
profiles_dir = os.path.join(output_dir, "spectral_profiles")
os.makedirs(profiles_dir, exist_ok=True)

# 1. MODIFY SIDE-BY-SIDE COMPARISON TO INCLUDE ORIGINAL IMAGE
print("3. Creating side-by-side comparison visualization with original image...")
background = original_rgb[excitation_for_viz]
n_methods = len(all_results) + 1  # +1 for the original image
# squeeze=False keeps `axes` indexable even if there is only one panel
fig, axes = plt.subplots(1, n_methods, figsize=(n_methods*5, 5), squeeze=False)
axes = axes[0]

# First plot: original RGB image
axes[0].imshow(background)
axes[0].set_title("Original Image")
axes[0].axis('off')

# Then plot all clustering results
for ax, result in zip(axes[1:], all_results.values()):
    overlay = create_cluster_overlay(
        cluster_labels=result['labels'],
        rgb_image=background,
        alpha=0.5
    )
    # Draw the RGB image under the semi-transparent overlay; on its own the overlay would
    # be shown over a white background
    ax.imshow(background)
    ax.imshow(overlay)
    ax.set_title(result['name'])
    ax.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(comparison_dir, "all_overlays_comparison_with_original.png"), dpi=300)
plt.close(fig)

# 2. EXTRACT SPECTRAL PROFILES FOR EACH CLUSTER
print("6. Extracting and plotting spectral profiles for each clustering approach...")

# Function to extract spectral profiles for a specific excitation
def extract_spectral_profiles(excitation, cluster_labels, all_data, mask=None):
    """
    Extract average spectral profiles for each cluster in a specific excitation.

    Args:
        excitation: Excitation wavelength. It is looked up in ``all_data`` first
            as-is (e.g. a float key) and then as ``str(excitation)``.
        cluster_labels: 2D array with one cluster ID per pixel. Negative IDs
            (e.g. -1 for masked areas) are ignored.
        all_data: Dictionary mapping excitation keys to hyperspectral cubes
            (numpy arrays or torch tensors); axis 2 is the emission band axis.
        mask: Optional array with the same spatial shape as ``cluster_labels``;
            only pixels where ``mask > 0`` contribute.

    Returns:
        Dictionary mapping cluster ID -> {'mean', 'std', 'count'}. 'mean' and
        'std' are NaN-aware per-band statistics over the cluster's pixels, and
        'count' is the number of contributing pixels (int). Clusters with no
        contributing pixels are omitted. If the excitation is not found, a
        warning is printed and an empty dict is returned.
    """
    # all_data may be keyed by the raw excitation value (e.g. a float) or by its
    # string form. Looking up only str(excitation) silently finds nothing in the
    # former case, so try the value itself first.
    if excitation in all_data:
        data_key = excitation
    elif str(excitation) in all_data:
        data_key = str(excitation)
    else:
        print(f"Warning: Excitation {excitation} not found in data.")
        return {}

    # Convert to numpy if needed
    data = all_data[data_key]
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()

    labels = np.asarray(cluster_labels)

    # Pixels allowed to contribute at all: everything, or only mask > 0.
    if mask is None:
        valid = np.ones(labels.shape, dtype=bool)
    else:
        valid = np.asarray(mask) > 0

    # Unique cluster IDs, excluding negative IDs such as -1 for masked areas
    unique_clusters = sorted(c for c in np.unique(labels) if c >= 0)

    cluster_spectra = {}
    for cluster_id in unique_clusters:
        cluster_mask = (labels == cluster_id) & valid
        pixel_count = int(np.count_nonzero(cluster_mask))
        if pixel_count == 0:
            continue

        # (pixels, bands) spectra of every pixel in this cluster
        cluster_data = data[cluster_mask]
        cluster_spectra[cluster_id] = {
            'mean': np.nanmean(cluster_data, axis=0),  # ignores NaNs
            'std': np.nanstd(cluster_data, axis=0),    # for ±1 std error bands
            'count': pixel_count,
        }

    return cluster_spectra

# For each excitation, plot clusters from individual and combined clustering.
def _save_spectra_plot(curves, wavelengths, xlabel, title, path, show_std=True):
    """
    Plot one mean-spectrum line per curve and save the figure to ``path``.

    Args:
        curves: Iterable of (legend_label, spectrum, linestyle) tuples, where
            ``spectrum`` is a dict with 'mean' and 'std' arrays as produced by
            extract_spectral_profiles.
        wavelengths: x-axis values shared by all curves.
        xlabel: x-axis label.
        title: Figure title.
        path: Output image path.
        show_std: If True, shade a ±1 standard deviation band around each line.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    for label, spectrum, linestyle in curves:
        ax.plot(wavelengths, spectrum['mean'], label=label, linestyle=linestyle, linewidth=2)
        if show_std:
            ax.fill_between(wavelengths,
                            spectrum['mean'] - spectrum['std'],
                            spectrum['mean'] + spectrum['std'],
                            alpha=0.2)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Normalized Intensity')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    if ax.lines:  # avoid matplotlib's "no artists with labels" warning
        ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


for ex in excitations_to_compare:
    print(f"  Extracting spectral profiles for excitation {ex}...")

    # Create a directory for this excitation
    ex_dir = os.path.join(profiles_dir, f"excitation_{ex}")
    os.makedirs(ex_dir, exist_ok=True)

    if ex not in all_results:
        print(f"  Individual clustering not found for excitation {ex}; skipping.")
        continue

    # all_data may be keyed by the raw excitation value (e.g. a float) or by its
    # string form; indexing with str(ex) alone raises KeyError in the former case.
    data_key = ex if ex in all_data else str(ex)

    # Use real emission wavelengths when the dataset provides them, else band
    # indices. The x-label follows the branch taken, because
    # isinstance(wavelengths[0], (int, float)) is unreliable for numpy scalars
    # (np.float32 / np.int64 are not Python float / int).
    if hasattr(dataset, 'emission_wavelengths') and ex in dataset.emission_wavelengths:
        wavelengths = dataset.emission_wavelengths[ex]
        xlabel = 'Wavelength (nm)'
    else:
        wavelengths = np.arange(all_data[data_key].shape[2])
        xlabel = 'Emission Band Index'

    # Profiles for this excitation's own clustering and for the combined 4D clustering
    individual_labels = all_results[ex]['labels']
    individual_spectra = extract_spectral_profiles(
        ex, individual_labels, all_data, dataset.processed_mask)
    combined_labels = all_results['combined']['labels']
    combined_spectra = extract_spectral_profiles(
        ex, combined_labels, all_data, dataset.processed_mask)

    # 1. Individual clustering spectra (mean ± 1 std)
    _save_spectra_plot(
        [(f'Cluster {cid} (n={spec["count"]})', spec, '-')
         for cid, spec in individual_spectra.items()],
        wavelengths, xlabel,
        f'Spectral Profiles by Cluster - Individual Clustering (Ex {ex}nm)',
        os.path.join(ex_dir, f"individual_spectra_{ex}.png"))

    # 2. Combined 4D clustering spectra (mean ± 1 std)
    _save_spectra_plot(
        [(f'Cluster {cid} (n={spec["count"]})', spec, '-')
         for cid, spec in combined_spectra.items()],
        wavelengths, xlabel,
        f'Spectral Profiles by Cluster - Combined 4D Clustering (Ex {ex}nm)',
        os.path.join(ex_dir, f"combined_spectra_{ex}.png"))

    # 3. Top 3 clusters by pixel count from each method:
    #    solid lines = individual, dashed lines = combined.
    individual_top = sorted(individual_spectra.items(),
                            key=lambda item: item[1]['count'], reverse=True)[:3]
    combined_top = sorted(combined_spectra.items(),
                          key=lambda item: item[1]['count'], reverse=True)[:3]
    _save_spectra_plot(
        [(f'Individual Cluster {cid}', spec, '-') for cid, spec in individual_top]
        + [(f'Combined Cluster {cid}', spec, '--') for cid, spec in combined_top],
        wavelengths, xlabel,
        f'Comparison of Top Clusters - Individual vs Combined (Ex {ex}nm)',
        os.path.join(ex_dir, f"comparison_spectra_{ex}.png"),
        show_std=False)

# 3. CREATE SUMMARY VISUALIZATION FOR ONE REPRESENTATIVE EXCITATION
# For the poster: a compact 2x2 figure with cluster maps on top (individual vs
# combined) and the corresponding per-cluster mean spectra underneath.
print("7. Creating summary spectral profile visualization...")

# Select a single representative excitation for clarity
representative_ex = excitations_to_compare[0]
rep_ex_str = str(representative_ex)
# all_data may be keyed by the raw excitation value or by its string form.
rep_data_key = representative_ex if representative_ex in all_data else rep_ex_str

# Real emission wavelengths when available, otherwise band indices. The x-label
# follows the branch taken rather than the runtime type of wavelengths[0].
if hasattr(dataset, 'emission_wavelengths') and representative_ex in dataset.emission_wavelengths:
    wavelengths = dataset.emission_wavelengths[representative_ex]
    summary_xlabel = 'Wavelength (nm)'
else:
    wavelengths = np.arange(all_data[rep_data_key].shape[2])
    summary_xlabel = 'Emission Band'

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# Top row: cluster maps. 'nearest' avoids resampling filters blending label colours.
ax1.imshow(all_results[representative_ex]['labels'], cmap='tab10', interpolation='nearest')
ax1.set_title(f'Individual Clustering (Ex {representative_ex}nm)')
ax1.axis('off')

ax2.imshow(all_results['combined']['labels'], cmap='tab10', interpolation='nearest')
ax2.set_title('Combined 4D Clustering')
ax2.axis('off')

# Bottom row: mean spectrum of every cluster at the representative excitation
individual_spectra = extract_spectral_profiles(
    representative_ex, all_results[representative_ex]['labels'],
    all_data, dataset.processed_mask)
combined_spectra = extract_spectral_profiles(
    representative_ex, all_results['combined']['labels'],
    all_data, dataset.processed_mask)

for ax, spectra, panel_title in ((ax3, individual_spectra, 'Individual Clustering Spectra'),
                                 (ax4, combined_spectra, 'Combined 4D Clustering Spectra')):
    for cluster_id, spectrum in spectra.items():
        ax.plot(wavelengths, spectrum['mean'], label=f'Cluster {cluster_id}')
    ax.set_xlabel(summary_xlabel)
    ax.set_ylabel('Normalized Intensity')
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    if ax.lines:  # avoid an empty-legend warning when no clusters were found
        ax.legend()

fig.tight_layout()
fig.savefig(os.path.join(profiles_dir, "summary_comparison.png"), dpi=300)
plt.close(fig)

print(f"Spectral profiles saved to: {profiles_dir}")

In [None]:
def extract_and_plot_spectral_profiles():
    """
    Create spectral profile plots for both individual excitation clustering
    and combined 4D clustering.

    An excitation is processed when it is a key of the notebook-global
    ``encoded_features``, can be matched to a key of ``all_data``, and has an
    entry in ``all_results``. For each such excitation, up to three figures are
    saved under ``<output_dir>/spectral_profiles/excitation_<ex>/``:

    * ``individual_spectra_<ex>.png``: mean ± 1 std spectrum per individual cluster.
    * ``combined_spectra_<ex>.png``: the same for the ``all_results['combined']``
      clusters, when that entry exists.
    * ``comparison_spectra_<ex>.png``: the 3 largest clusters of each method.

    Relies on notebook globals: ``output_dir``, ``all_data``, ``encoded_features``,
    ``all_results`` and ``dataset`` (``processed_mask`` and, if present,
    ``emission_wavelengths``).
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import torch  # needed to detect and convert tensor-valued entries of all_data

    # Create a dedicated folder for spectral profiles
    profiles_dir = os.path.join(output_dir, "spectral_profiles")
    os.makedirs(profiles_dir, exist_ok=True)

    print("\n=== Extracting and plotting spectral profiles ===")

    # Print what's actually in all_data to identify the correct keys
    print("Keys in all_data:")
    if isinstance(all_data, dict):
        all_data_keys = list(all_data.keys())
        print(f"  all_data contains {len(all_data_keys)} keys: {all_data_keys[:5]}...")
    else:
        print(f"  all_data is not a dictionary, it's a {type(all_data)}")

    # Get list of all excitations
    available_excitations = list(encoded_features.keys())
    print(f"Available excitations: {available_excitations}")

    def _candidate_keys(ex):
        """Yield plausible all_data keys for excitation ``ex``, most literal first."""
        yield ex
        yield str(ex)
        try:
            value = float(ex)
        except (TypeError, ValueError):
            return  # non-numeric excitation: only the literal/str forms apply
        # Offer the int form only when exact, so e.g. 350.7 never matches key 350
        # (and int() is never called on inf/nan).
        if value.is_integer():
            yield int(value)
        yield f"{value:.1f}"

    # encoded_features and all_data may use different key types (e.g. float vs
    # str) for the same excitation, so map one onto the other.
    ex_to_data_key = {}
    for ex in available_excitations:
        for candidate in _candidate_keys(ex):
            if candidate in all_data:
                ex_to_data_key[ex] = candidate
                break

    print(f"Found matching keys for {len(ex_to_data_key)} excitations")

    def _to_numpy(array):
        """Return ``array`` as a numpy array, moving torch tensors to host memory first."""
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return array

    def _cluster_spectra(data, cluster_labels, mask=None):
        """
        Per-cluster mean/std spectra of ``data`` (band axis is axis 2).

        Negative cluster IDs are ignored; if ``mask`` is given, only pixels with
        ``mask > 0`` contribute. Returns {cluster_id: {'mean', 'std', 'count'}},
        omitting clusters left with no pixels.
        """
        labels = np.asarray(cluster_labels)
        unique_clusters = sorted(c for c in np.unique(labels) if c >= 0)
        print(f"Found {len(unique_clusters)} unique clusters")

        valid = np.ones(labels.shape, dtype=bool) if mask is None else (np.asarray(mask) > 0)

        cluster_spectra = {}
        for cluster_id in unique_clusters:
            cluster_mask = (labels == cluster_id) & valid
            pixel_count = int(np.count_nonzero(cluster_mask))
            if pixel_count == 0:
                continue
            cluster_data = data[cluster_mask]
            cluster_spectra[cluster_id] = {
                'mean': np.nanmean(cluster_data, axis=0),  # ignores NaNs
                'std': np.nanstd(cluster_data, axis=0),    # for ±1 std bands
                'count': pixel_count,
            }
        return cluster_spectra

    def _save_spectra_plot(curves, wavelengths, xlabel, title, path, show_std=True):
        """Plot one mean line per (label, spectrum, linestyle) curve, optionally with ±1 std bands, and save."""
        fig, ax = plt.subplots(figsize=(12, 6))
        for label, spectrum, linestyle in curves:
            ax.plot(wavelengths, spectrum['mean'], label=label, linestyle=linestyle, linewidth=2)
            if show_std:
                ax.fill_between(wavelengths,
                                spectrum['mean'] - spectrum['std'],
                                spectrum['mean'] + spectrum['std'],
                                alpha=0.2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Normalized Intensity')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        if ax.lines:
            ax.legend()
        fig.tight_layout()
        fig.savefig(path, dpi=300)
        plt.close(fig)

    # Combined 4D labels are shared by every excitation; tolerate their absence.
    combined_labels = all_results['combined']['labels'] if 'combined' in all_results else None
    if combined_labels is None:
        print("Combined clustering not found in all_results; skipping combined plots.")
    pixel_mask = getattr(dataset, 'processed_mask', None)

    # Process each excitation that has a mapping
    for ex in available_excitations:
        if ex not in ex_to_data_key:
            continue

        print(f"\nProcessing excitation {ex}...")

        # Create a directory for this excitation
        ex_dir = os.path.join(profiles_dir, f"excitation_{ex}")
        os.makedirs(ex_dir, exist_ok=True)

        if ex not in all_results:
            print(f"Individual clustering not found for excitation {ex}")
            continue

        # Convert the cube once and reuse it for both clusterings.
        data_key = ex_to_data_key[ex]
        data = _to_numpy(all_data[data_key])
        print(f"Using data_key '{data_key}' for excitation {ex}, data shape: {data.shape}")

        # Real emission wavelengths when available, else band indices; the x-label
        # follows the branch (isinstance checks misfire on numpy scalar types).
        if hasattr(dataset, 'emission_wavelengths') and ex in dataset.emission_wavelengths:
            wavelengths = dataset.emission_wavelengths[ex]
            xlabel = 'Wavelength (nm)'
        else:
            wavelengths = np.arange(data.shape[2])
            xlabel = 'Emission Band Index'

        print("Extracting individual clustering spectra...")
        individual_spectra = _cluster_spectra(data, all_results[ex]['labels'], pixel_mask)
        if individual_spectra:
            _save_spectra_plot(
                [(f'Cluster {cid} (n={spec["count"]})', spec, '-')
                 for cid, spec in individual_spectra.items()],
                wavelengths, xlabel,
                f'Spectral Profiles by Cluster - Individual Clustering (Ex {ex}nm)',
                os.path.join(ex_dir, f"individual_spectra_{ex}.png"))

        combined_spectra = {}
        if combined_labels is not None:
            print("Extracting combined 4D clustering spectra...")
            combined_spectra = _cluster_spectra(data, combined_labels, pixel_mask)
            if combined_spectra:
                _save_spectra_plot(
                    [(f'Cluster {cid} (n={spec["count"]})', spec, '-')
                     for cid, spec in combined_spectra.items()],
                    wavelengths, xlabel,
                    f'Spectral Profiles by Cluster - Combined 4D Clustering (Ex {ex}nm)',
                    os.path.join(ex_dir, f"combined_spectra_{ex}.png"))

        # Top 3 clusters by size from each method: solid = individual, dashed = combined
        if individual_spectra and combined_spectra:
            individual_top = sorted(individual_spectra.items(),
                                    key=lambda item: item[1]['count'], reverse=True)[:3]
            combined_top = sorted(combined_spectra.items(),
                                  key=lambda item: item[1]['count'], reverse=True)[:3]
            _save_spectra_plot(
                [(f'Individual Cluster {cid}', spec, '-') for cid, spec in individual_top]
                + [(f'Combined Cluster {cid}', spec, '--') for cid, spec in combined_top],
                wavelengths, xlabel,
                f'Comparison of Top Clusters - Individual vs Combined (Ex {ex}nm)',
                os.path.join(ex_dir, f"comparison_spectra_{ex}.png"),
                show_std=False)

    print(f"Spectral profiles saved to: {profiles_dir}")

# Run the function: writes per-excitation spectral profile figures under
# <output_dir>/spectral_profiles/.
extract_and_plot_spectral_profiles()

In [None]:
def create_4d_cluster_profiles_by_excitation():
    """
    Create spectral profiles for 4D clusters with separate lines for each excitation.

    For every non-negative cluster ID in ``all_results['combined']['labels']``,
    plots that cluster's NaN-aware mean spectrum at every numeric-keyed
    excitation of ``all_data`` (one line per excitation). Each cluster's figure
    is saved to ``<output_dir>/spectral_profiles_4d/cluster_<id>_all_excitations.png``.

    Relies on notebook globals: ``output_dir``, ``all_data``, ``all_results``
    and ``dataset`` (``processed_mask`` and, if present, ``emission_wavelengths``).
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import torch

    # Create directory
    profiles_dir = os.path.join(output_dir, "spectral_profiles_4d")
    os.makedirs(profiles_dir, exist_ok=True)

    print("\n=== Creating 4D cluster profiles with separate lines for each excitation ===")

    # First, identify what keys are in all_data
    if not isinstance(all_data, dict):
        print(f"all_data is not a dictionary (got {type(all_data)}); nothing to plot.")
        return
    all_data_keys = list(all_data.keys())
    print(f"all_data keys: {all_data_keys[:10]}...")

    # (float value, ORIGINAL key) for each numeric key, sorted by value. Data is
    # always fetched with the original key. Rebuilding the key as str(float(key))
    # fails for float keys (350.0 vs "350.0") and for strings like "350".
    excitation_keys = []
    for key in all_data_keys:
        try:
            excitation_keys.append((float(key), key))
        except (ValueError, TypeError):
            continue  # not a numerical key
    excitation_keys.sort(key=lambda pair: pair[0])
    available_excitations = [value for value, _ in excitation_keys]
    print(f"Available excitations: {available_excitations}")

    emission_wavelengths = getattr(dataset, 'emission_wavelengths', None)

    # Load each excitation's cube and emission axis once, rather than once per cluster.
    excitation_cubes = []  # (excitation value, cube, x values, has real wavelengths)
    for value, key in excitation_keys:
        cube = all_data[key]
        if isinstance(cube, torch.Tensor):
            cube = cube.detach().cpu().numpy()
        wavelengths = None
        if emission_wavelengths is not None:
            for candidate in (key, value):
                if candidate in emission_wavelengths:
                    wavelengths = emission_wavelengths[candidate]
                    break
        has_wavelengths = wavelengths is not None
        if not has_wavelengths:
            wavelengths = np.arange(cube.shape[2])
        excitation_cubes.append((value, cube, wavelengths, has_wavelengths))

    # x-label reflects whether lines use real wavelengths, band indices, or a mix
    axis_kinds = {has_wl for *_, has_wl in excitation_cubes}
    if axis_kinds == {True}:
        xlabel = 'Wavelength (nm)'
    elif axis_kinds == {False}:
        xlabel = 'Emission Band Index'
    else:
        xlabel = 'Wavelength/Band Index'

    # Get combined clustering labels and their cluster IDs (excluding negatives such as -1)
    combined_labels = np.asarray(all_results['combined']['labels'])
    unique_clusters = sorted(c for c in np.unique(combined_labels) if c >= 0)
    print(f"Found {len(unique_clusters)} unique clusters")

    processed_mask = getattr(dataset, 'processed_mask', None)
    valid = None if processed_mask is None else (np.asarray(processed_mask) > 0)

    for cluster_id in unique_clusters:
        print(f"Processing cluster {cluster_id}...")

        cluster_mask = combined_labels == cluster_id
        if valid is not None:
            cluster_mask = cluster_mask & valid

        pixel_count = int(np.count_nonzero(cluster_mask))
        print(f"  Cluster {cluster_id} has {pixel_count} pixels")
        if pixel_count == 0:
            continue

        fig, ax = plt.subplots(figsize=(12, 6))
        for value, cube, wavelengths, _ in excitation_cubes:
            # NaN-aware mean spectrum of this cluster's pixels at this excitation
            mean_spectrum = np.nanmean(cube[cluster_mask], axis=0)
            ax.plot(wavelengths, mean_spectrum, label=f'Ex {value}nm', linewidth=1.5)

        ax.set_xlabel(xlabel)
        ax.set_ylabel('Normalized Intensity')
        ax.set_title(f'Cluster {cluster_id} Spectral Profile Across All Excitations (n={pixel_count})')
        ax.grid(True, alpha=0.3)
        if ax.lines:
            ax.legend(loc='best', ncol=3)
        fig.tight_layout()
        fig.savefig(os.path.join(profiles_dir, f"cluster_{cluster_id}_all_excitations.png"), dpi=300)
        plt.close(fig)

    print(f"4D cluster profiles saved to: {profiles_dir}")

def create_side_by_side_comparison():
    """
    Create a side-by-side comparison of:
    1. Original RGB image
    2. Individual excitation clustering
    3. Combined 4D clustering

    The RGB composite comes from the excitation closest to 350 nm, using the
    bands at 80% / 50% / 20% of the emission range. "Individual" clustering is
    K-means (10 clusters, random_state=42) on the raw spectra of up to four
    excitations >= 300 nm. Figures are written to
    ``<output_dir>/clustering_comparison_original/``.

    Relies on notebook globals: ``output_dir``, ``all_data``, ``all_results``
    and ``dataset`` (``processed_mask``).

    NOTE(review): a later notebook cell redefines a function with this same
    name; the call in this cell uses this definition.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    from sklearn.cluster import KMeans

    # Create directory
    comparison_dir = os.path.join(output_dir, "clustering_comparison_original")
    os.makedirs(comparison_dir, exist_ok=True)

    print("\n=== Creating side-by-side comparison ===")

    # 1. Find available excitations
    if not isinstance(all_data, dict):
        print(f"all_data is not a dictionary (got {type(all_data)}); skipping comparison.")
        return
    all_data_keys = list(all_data.keys())
    print(f"all_data keys: {all_data_keys[:10]}...")

    # Numeric excitation value -> ORIGINAL all_data key. Data must be fetched with
    # that key: rebuilding it as str(float(key)) fails for float keys (350.0 vs
    # "350.0") and for integer-like string keys ("350" vs "350.0").
    data_key_for = {}
    for key in all_data_keys:
        try:
            data_key_for[float(key)] = key
        except (ValueError, TypeError):
            continue  # not a numerical key
    available_excitations = sorted(data_key_for)
    print(f"Available excitations: {available_excitations}")
    if not available_excitations:
        print("No numeric excitation keys in all_data; skipping comparison.")
        return

    def _as_numpy(excitation):
        """Return the all_data cube for ``excitation`` as a numpy array."""
        cube = all_data[data_key_for[excitation]]
        if isinstance(cube, torch.Tensor):
            return cube.detach().cpu().numpy()
        return cube

    def _hide_negative(labels):
        """Mask negative IDs (-1 = masked/unclustered) so imshow leaves them blank."""
        return np.ma.masked_less(np.asarray(labels), 0)

    # 2. Select up to 4 excitations >= 300nm (ascending) for individual clustering
    selected_excitations = [ex for ex in available_excitations if ex >= 300][:4]
    print(f"Selected excitations for individual clustering: {selected_excitations}")

    # 3. Create RGB image from the excitation closest to 350nm
    reference_ex = min(available_excitations, key=lambda x: abs(x - 350))
    print(f"Using excitation {reference_ex}nm for RGB reference")
    reference_data = _as_numpy(reference_ex)

    # False-colour composite from bands at 80% (R), 50% (G) and 20% (B) of the range
    n_bands = reference_data.shape[2]
    r_idx = min(n_bands - 1, int(n_bands * 0.8))
    g_idx = min(n_bands - 1, int(n_bands * 0.5))
    b_idx = min(n_bands - 1, int(n_bands * 0.2))
    rgb_image = reference_data[:, :, [r_idx, g_idx, b_idx]]

    # Min-max normalise to [0, 1]. A constant (or all-NaN) image would otherwise
    # divide by zero, so fall back to black.
    rgb_min = np.nanmin(rgb_image)
    rgb_range = np.nanmax(rgb_image) - rgb_min
    if rgb_range > 0:
        rgb_image = (rgb_image - rgb_min) / rgb_range
    else:
        rgb_image = np.zeros(rgb_image.shape)
    rgb_image = np.clip(rgb_image, 0, 1)

    # Save RGB image separately
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(rgb_image)
    ax.set_title(f'RGB Composite (Ex {reference_ex}nm)')
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "original_rgb.png"), dpi=300)
    plt.close(fig)

    # 4. Run clustering on original data for selected excitations.
    # The flat indices of in-mask pixels are the same for every excitation.
    pixel_mask = getattr(dataset, 'processed_mask', None)
    valid_indices = None
    if pixel_mask is not None:
        valid_indices = np.flatnonzero(np.asarray(pixel_mask).ravel() > 0)

    original_cluster_results = {}
    for ex in selected_excitations:
        print(f"Running clustering on original data for excitation {ex}nm...")
        data = _as_numpy(ex)

        # Reshape for clustering [pixels, bands], keeping only in-mask pixels
        height, width, n_bands = data.shape
        pixels = data.reshape(-1, n_bands)
        if valid_indices is not None:
            pixels = pixels[valid_indices]

        kmeans = KMeans(n_clusters=10, random_state=42)
        labels = kmeans.fit_predict(pixels)

        # Scatter labels back onto the image grid; out-of-mask pixels stay -1.
        if valid_indices is None:
            cluster_map = labels.reshape(height, width)
        else:
            cluster_map = np.full(height * width, -1.0)
            cluster_map[valid_indices] = labels
            cluster_map = cluster_map.reshape(height, width)

        original_cluster_results[ex] = cluster_map

        # Save individual cluster map
        fig, ax = plt.subplots(figsize=(8, 8))
        im = ax.imshow(_hide_negative(cluster_map), cmap='tab10', interpolation='nearest')
        fig.colorbar(im, ax=ax, label='Cluster')
        ax.set_title(f'Original Data Clustering (Ex {ex}nm)')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(os.path.join(comparison_dir, f"original_clustering_ex{ex}.png"), dpi=300)
        plt.close(fig)

    # 5. Get combined 4D clustering result and save it on its own
    combined_labels = all_results['combined']['labels']

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(_hide_negative(combined_labels), cmap='tab10', interpolation='nearest')
    fig.colorbar(im, ax=ax, label='Cluster')
    ax.set_title('Combined 4D Clustering')
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "combined_clustering.png"), dpi=300)
    plt.close(fig)

    # 6. Side-by-side: RGB | one panel per selected excitation | combined
    n_plots = len(selected_excitations) + 2  # +2 for original RGB and combined
    fig, axes = plt.subplots(1, n_plots, figsize=(n_plots * 4, 5))

    axes[0].imshow(rgb_image)
    axes[0].set_title(f'RGB Composite\n(Ex {reference_ex}nm)')
    axes[0].axis('off')

    for ax, ex in zip(axes[1:-1], selected_excitations):
        ax.imshow(_hide_negative(original_cluster_results[ex]), cmap='tab10', interpolation='nearest')
        ax.set_title(f'Clustering\n(Ex {ex}nm)')
        ax.axis('off')

    axes[-1].imshow(_hide_negative(combined_labels), cmap='tab10', interpolation='nearest')
    axes[-1].set_title('Combined 4D\nClustering')
    axes[-1].axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "full_comparison.png"), dpi=300)
    plt.close(fig)

    print(f"Side-by-side comparison saved to: {comparison_dir}")

# Run both functions: per-cluster 4D spectral profiles, then the RGB /
# individual / combined clustering comparison figures.
# NOTE(review): create_side_by_side_comparison is redefined in a later cell;
# this call uses the definition from this cell.
create_4d_cluster_profiles_by_excitation()
create_side_by_side_comparison()

In [None]:
def _comparison_rgb_composite(cube, band_fractions=(0.8, 0.5, 0.2)):
    """Build a min-max normalised false-colour RGB image from a (height, width, bands) cube.

    Each channel takes the band at the given fraction of the band axis
    (R=80%, G=50%, B=20% by default), clamped to the last band.

    Returns
    -------
    rgb : np.ndarray
        (height, width, 3) array scaled to [0, 1]; NaN inputs stay NaN.
    band_indices : tuple of int
        The (R, G, B) band indices that were used.
    """
    n_bands = cube.shape[2]
    band_indices = tuple(min(n_bands - 1, int(n_bands * f)) for f in band_fractions)
    rgb = np.stack([cube[:, :, idx] for idx in band_indices], axis=2)

    lo, hi = np.nanmin(rgb), np.nanmax(rgb)
    span = hi - lo
    # A constant (or all-NaN) image would otherwise divide by zero and yield NaN everywhere.
    if span > 0:
        rgb = (rgb - lo) / span
    else:
        rgb = np.zeros_like(rgb, dtype=float)
    return np.clip(rgb, 0, 1), band_indices


def _comparison_cluster_map(cube, mask=None, n_clusters=10, n_init=10, random_state=42):
    """K-means cluster the pixel spectra of a (height, width, bands) cube into a label image.

    Pixels outside ``mask`` (``mask <= 0``) or containing any non-finite band
    value are left out of the fit and labelled -1 in the returned map.

    Raises
    ------
    ValueError
        If ``mask`` does not have exactly height * width elements.
    """
    from sklearn.cluster import KMeans

    height, width, n_bands = cube.shape
    pixels = cube.reshape(-1, n_bands)

    # NaN/inf spectra make KMeans raise; skip those pixels instead of aborting the figure.
    keep = np.isfinite(pixels).all(axis=1)
    if mask is not None:
        mask_flat = np.asarray(mask).reshape(-1)
        if mask_flat.size != keep.size:
            raise ValueError(
                f"mask has {mask_flat.size} elements but the cube has "
                f"{height}x{width}={keep.size} pixels"
            )
        keep &= mask_flat > 0

    label_map = np.full(height * width, -1, dtype=int)
    kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=random_state)
    label_map[keep] = kmeans.fit_predict(pixels[keep])
    # Row-major reshape == placing flat index idx at (idx // width, idx % width).
    return label_map.reshape(height, width)


def _comparison_show_labels(ax, label_map, vmax=None):
    """Draw a categorical label image on ``ax`` and return the AxesImage.

    Negative (unclustered/masked) and NaN labels are hidden rather than folded
    into the colour scale, and vmin is pinned to 0 so every label id in
    [0, vmax] gets its own colour bin (the previous auto-scaling over [-1, 9]
    made clusters 8 and 9 share a tab10 colour).
    """
    labels = np.asarray(label_map)
    labels = np.ma.masked_where(~(labels >= 0), labels)
    if vmax is None:
        vmax = int(labels.max()) if labels.count() else 0
    # tab10 only has 10 distinct colours; switch to tab20 for larger label sets.
    cmap = 'tab10' if vmax < 10 else 'tab20'
    return ax.imshow(labels, cmap=cmap, vmin=0, vmax=vmax, interpolation='nearest')


def create_side_by_side_comparison(n_clusters=10, min_excitation=300, n_selected=4,
                                   reference_target=350, n_init=10):
    """
    Create a side-by-side comparison of:
    1. Original RGB image
    2. Individual excitation clustering
    3. Combined 4D clustering

    Reads the notebook globals ``all_data`` (excitation -> tensor whose numpy
    form is unpacked as (height, width, bands)), ``dataset.processed_mask``,
    ``all_results['combined']['labels']`` and ``output_dir``. All figures are
    written as PNGs to ``<output_dir>/clustering_comparison_original``.

    Parameters
    ----------
    n_clusters : int
        Number of K-means clusters for each individual excitation.
    min_excitation : float
        Lowest excitation wavelength (nm) eligible for individual clustering.
    n_selected : int
        Maximum number of excitations (the first ones >= ``min_excitation``) to cluster.
    reference_target : float
        Target wavelength (nm); the closest available excitation supplies the RGB composite.
    n_init : int
        Number of K-means restarts. Set explicitly so results do not change with
        scikit-learn's default for ``n_init``.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt

    # Create directory
    comparison_dir = os.path.join(output_dir, "clustering_comparison_original")
    os.makedirs(comparison_dir, exist_ok=True)

    print("\n=== Creating side-by-side comparison ===")

    # 1. All available excitation wavelengths, ascending
    available_excitations = sorted(all_data.keys())
    print(f"Available excitations: {available_excitations[:5]}... (total: {len(available_excitations)})")

    # 2. The first `n_selected` excitations at or above `min_excitation`
    selected_excitations = [ex for ex in available_excitations if ex >= min_excitation][:n_selected]
    print(f"Selected excitations for individual clustering: {selected_excitations}")

    # 3. RGB composite from the excitation closest to `reference_target`
    reference_ex = min(available_excitations, key=lambda x: abs(x - reference_target))
    print(f"Using excitation {reference_ex}nm for RGB reference")

    reference_data = all_data[reference_ex].cpu().numpy()
    print(f"Reference data shape: {reference_data.shape}")

    rgb_image, (r_idx, g_idx, b_idx) = _comparison_rgb_composite(reference_data)
    print(f"RGB bands from: R={r_idx}, G={g_idx}, B={b_idx}")

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(rgb_image)
    ax.set_title(f'RGB Composite (Ex {reference_ex}nm)')
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "original_rgb.png"), dpi=300)
    plt.close(fig)

    # 4. Cluster the original data of each selected excitation
    original_cluster_results = {}
    for ex in selected_excitations:
        print(f"Running clustering on original data for excitation {ex}nm...")

        data = all_data[ex].cpu().numpy()
        print(f"  Data shape: {data.shape}")

        cluster_map = _comparison_cluster_map(
            data, dataset.processed_mask, n_clusters=n_clusters, n_init=n_init
        )
        original_cluster_results[ex] = cluster_map

        fig, ax = plt.subplots(figsize=(8, 8))
        im = _comparison_show_labels(ax, cluster_map, vmax=n_clusters - 1)
        fig.colorbar(im, ax=ax, label='Cluster')
        ax.set_title(f'Original Data Clustering (Ex {ex}nm)')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(os.path.join(comparison_dir, f"original_clustering_ex{ex}.png"), dpi=300)
        plt.close(fig)

    # 5. Combined 4D clustering result computed earlier in the notebook
    combined_labels = all_results['combined']['labels']

    fig, ax = plt.subplots(figsize=(8, 8))
    im = _comparison_show_labels(ax, combined_labels)
    fig.colorbar(im, ax=ax, label='Cluster')
    ax.set_title('Combined 4D Clustering')
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "combined_clustering.png"), dpi=300)
    plt.close(fig)

    # 6. Side-by-side panel: RGB | one panel per excitation | combined
    n_plots = len(selected_excitations) + 2  # +2 for original RGB and combined
    fig, axes = plt.subplots(1, n_plots, figsize=(n_plots * 4, 5))

    axes[0].imshow(rgb_image)
    axes[0].set_title(f'RGB Composite\n(Ex {reference_ex}nm)')
    axes[0].axis('off')

    for ax, ex in zip(axes[1:-1], selected_excitations):
        _comparison_show_labels(ax, original_cluster_results[ex], vmax=n_clusters - 1)
        ax.set_title(f'Clustering\n(Ex {ex}nm)')
        ax.axis('off')

    _comparison_show_labels(axes[-1], combined_labels)
    axes[-1].set_title('Combined 4D\nClustering')
    axes[-1].axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(comparison_dir, "full_comparison.png"), dpi=300)
    plt.close(fig)

    print(f"Side-by-side comparison saved to: {comparison_dir}")

# Run the comparison. This call resolves to the create_side_by_side_comparison
# defined in this cell, which rebinds the name used by the previous cell's call.
# NOTE(review): the previous cell also calls create_side_by_side_comparison;
# consider deleting the earlier definition and call so the clustering runs once.
create_side_by_side_comparison()