# Run HSIDwRD Codespace on Kaggle
This notebook clones the HSIDwRD GitHub repository, installs its dependencies, trains the denoising model with loss tracking and visualization, and evaluates the best checkpoint.

In [None]:
# 1. Clone the GitHub repository
!git clone https://github.com/Simhadri123/HSIDwRD.git /kaggle/working/HSIDwRD
!ls /kaggle/working/HSIDwRD

In [None]:
# 2. Install pytorch-gradual-warmup-lr from GitHub (not on PyPI)
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
# 3. Install other requirements
!pip install -r /kaggle/working/HSIDwRD/requirements.txt

In [None]:
# 4. Basic training script (superseded by the enhanced pipeline in cells 6-11)
# The enhanced pipeline adds:
#   - real-time loss tracking and visualization
#   - comprehensive training progress monitoring
#   - better error handling and GPU utilization
#
# To run the basic training without visualization instead, uncomment:
# !python /kaggle/working/HSIDwRD/train_denoising.py

# Point the reader at the cells that replace this one.
print(
    "INFO: This basic training has been replaced by the enhanced version in cells 6-11",
    "Run cells 6-11 for training with loss visualization",
    "Run cells 12-15 for comprehensive test set evaluation",
    sep="\n",
)

In [None]:
# 5. Install additional plotting libraries
!pip install matplotlib seaborn plotly yacs

In [None]:
# 6. Modified training script with loss tracking and visualization
# This cell gathers the imports, loads the training configuration and selects
# the GPUs. The statement order matters: the repository path is appended to
# sys.path before the project modules are imported, and the CUDA environment
# variables are assigned before torch is imported.
import os
import sys
# Make the cloned repository's modules (config, utils, dataloaders, models,
# losses) importable from this notebook.
sys.path.append('/kaggle/working/HSIDwRD')

# Import required libraries for plotting
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import json

# Set up the environment
from config import Config 
opt = Config('/kaggle/working/HSIDwRD/training.yml')

# opt.GPU is iterated and joined into a comma-separated string, e.g. "0,1".
gpus = ','.join([str(i) for i in opt.GPU])
# These variables are set before `import torch` below.
# NOTE(review): presumably so that CUDA only exposes the configured devices — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

import torch
torch.backends.cudnn.benchmark = True

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from natsort import natsorted
import glob
import random
import time
# NOTE(review): numpy is already imported above; this second import is redundant.
import numpy as np
from scipy import stats

import utils
from dataloaders.data import get_training_data, get_validation_data
# NOTE(review): the wildcard import presumably supplies U_Net_3D, which later
# cells instantiate — confirm against the repository's models package.
from models import *
from losses import CharbonnierLoss
from tqdm import tqdm 
from warmup_scheduler import GradualWarmupScheduler

# Report the runtime environment so GPU/driver problems are visible early.
print("All libraries imported successfully!")
print(f"Using GPU(s): {gpus}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  Device {i}: {torch.cuda.get_device_name(i)}")

In [None]:
# 7. Setup model and training parameters
start_epoch = 1
mode, session = opt.MODEL.MODE, opt.MODEL.SESSION

# Output folders follow <SAVE_DIR>/<mode>/{results,models}/<session>.
result_dir, model_dir = (
    os.path.join(opt.TRAINING.SAVE_DIR, mode, kind, session)
    for kind in ('results', 'models')
)
for _out_dir in (result_dir, model_dir):
    utils.mkdir(_out_dir)

train_dir, val_dir = opt.TRAINING.TRAIN_DIR, opt.TRAINING.VAL_DIR
save_images = opt.TRAINING.SAVE_IMAGES

# Fresh (independent) history lists that the training loop appends to.
train_losses, val_losses, val_psnrs, epochs_logged, iterations_logged = (
    [] for _ in range(5)
)

print("Training setup complete!")
print(f"Model directory: {model_dir}")
print(f"Results directory: {result_dir}")
print(f"Training data: {train_dir}")
print(f"Validation data: {val_dir}")

######### Model ###########
model_restoration = U_Net_3D()
model_restoration.cuda()

# One id per visible CUDA device; used later to decide on DataParallel.
device_ids = list(range(torch.cuda.device_count()))
if len(device_ids) > 1:
    print(f"Using {len(device_ids)} GPUs!")

new_lr = opt.OPTIM.LR_INITIAL
optimizer = optim.AdamW(
    model_restoration.parameters(),
    lr=new_lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4,
)

print(f"Model initialized with learning rate: {new_lr}")

In [None]:
# 8. Setup data loaders and training components
######### Resume ###########
if opt.TRAINING.RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration, path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    lr = utils.load_optim(optimizer, path_chk_rest)
    
    for p in optimizer.param_groups: 
        p['lr'] = lr
    warmup = False
    new_lr = lr
    print('------------------------------------------------------------------------------')
    print(f"==> Resuming Training with learning rate: {new_lr}")
    print('------------------------------------------------------------------------------')
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-start_epoch+1, eta_min=1e-6)
else:
    warmup = True

######### Scheduler ###########
if warmup:
    warmup_epochs = 3
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=1e-6)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    scheduler.step()

if len(device_ids) > 1:
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)

######### Loss ###########
criterion = CharbonnierLoss().cuda()

######### DataLoaders ###########
img_options_train = {'patch_size': opt.TRAINING.TRAIN_PS}

train_dataset = get_training_data(train_dir, opt.MODEL.RATIO, img_options_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

val_dataset = get_validation_data(val_dir, opt.MODEL.RATIO)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False)

print('Data loaders ready!')
print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')
print(f'Training from epoch {start_epoch} to {opt.OPTIM.NUM_EPOCHS}')

In [None]:
# 9. Plotting helpers and training-state initialisation
def plot_training_progress(train_losses, val_losses, val_psnrs, epochs_logged, save_path=None):
    """Plot the training loss and the validation metrics side by side.

    Args:
        train_losses: per-epoch average training losses; plotted against
            1..len(train_losses) on the left panel.
        val_losses: validation losses, one per entry of ``epochs_logged``.
        val_psnrs: validation PSNR values, one per entry of ``epochs_logged``.
        epochs_logged: epochs at which validation was run (x-axis of the
            right panel).
        save_path: optional file path; when given, the figure is also written
            there at 300 dpi.

    The right panel is only drawn when both validation series are non-empty.
    The figure is closed after being shown so that repeated calls (the
    training loop calls this periodically) do not accumulate open figures.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: training loss per epoch.
    if train_losses:
        ax1.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss Over Time')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

    # Right panel: validation loss (left axis) and PSNR (right axis).
    if val_losses and val_psnrs:
        ax2_twin = ax2.twinx()

        line1 = ax2.plot(epochs_logged, val_losses, 'r-', label='Validation Loss', linewidth=2)
        line2 = ax2_twin.plot(epochs_logged, val_psnrs, 'g-', label='Validation PSNR', linewidth=2)

        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Validation Loss', color='r')
        ax2_twin.set_ylabel('PSNR (dB)', color='g')
        ax2.set_title('Validation Metrics Over Time')
        ax2.grid(True, alpha=0.3)

        # The two curves live on different axes, so merge their legend entries by hand.
        lines = line1 + line2
        labels = [l.get_label() for l in lines]
        ax2.legend(lines, labels, loc='upper left')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure explicitly; closing an already-closed figure is a no-op.
    plt.close(fig)

def plot_interactive_progress(train_losses, val_losses, val_psnrs, epochs_logged):
    """Show a 2x2 interactive plotly dashboard of the training history.

    Panels: training loss (top-left), validation loss (top-right),
    validation PSNR (bottom-left) and a combined loss/PSNR view with a
    secondary y-axis (bottom-right). A panel is left empty when the series
    it needs is empty. Training loss is plotted against 1..N; validation
    series are plotted against ``epochs_logged``.
    """
    dashboard = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Training Loss', 'Validation Loss', 'Validation PSNR', 'Combined View'),
        specs=[[{}, {}], [{}, {"secondary_y": True}]]
    )

    def curve(xs, ys, label, colour, style='lines+markers', **extra):
        # Small factory so every trace is built the same way.
        return go.Scatter(x=xs, y=ys, mode=style, name=label, line=dict(color=colour), **extra)

    # Top-left: training loss.
    if train_losses:
        epoch_axis = list(range(1, len(train_losses) + 1))
        dashboard.add_trace(
            curve(epoch_axis, train_losses, 'Training Loss', 'blue', style='lines'),
            row=1, col=1
        )

    # Top-right: validation loss.
    if val_losses:
        dashboard.add_trace(curve(epochs_logged, val_losses, 'Validation Loss', 'red'), row=1, col=2)

    # Bottom-left: validation PSNR.
    if val_psnrs:
        dashboard.add_trace(curve(epochs_logged, val_psnrs, 'Validation PSNR', 'green'), row=2, col=1)

    # Bottom-right: both validation series, PSNR on the secondary axis.
    if val_losses and val_psnrs:
        dashboard.add_trace(curve(epochs_logged, val_losses, 'Val Loss', 'red'), row=2, col=2)
        # NOTE(review): add_trace is also given row/col/secondary_y, which likely
        # assigns the axes itself, making the explicit yaxis redundant — confirm.
        dashboard.add_trace(
            curve(epochs_logged, val_psnrs, 'Val PSNR', 'green', yaxis='y2'),
            row=2, col=2, secondary_y=True
        )

    dashboard.update_layout(height=800, showlegend=True, title_text="Training Progress Dashboard")
    dashboard.update_xaxes(title_text="Epoch")

    # Axis titles are applied in this order; the last entry re-labels the
    # secondary axis of the combined panel.
    axis_titles = (
        ("Loss", dict(row=1, col=1)),
        ("Loss", dict(row=1, col=2)),
        ("PSNR (dB)", dict(row=2, col=1)),
        ("Loss", dict(row=2, col=2)),
        ("PSNR (dB)", dict(row=2, col=2, secondary_y=True)),
    )
    for title, where in axis_titles:
        dashboard.update_yaxes(title_text=title, **where)

    dashboard.show()

print("Plotting functions defined!")

# Initialize variables for training
mixup = utils.MixUp_AUG()
best_psnr = 0
best_epoch = 0
best_iter = 0

# Iteration interval between evaluations: the index of the last batch of an
# epoch. It is clamped to at least 1 so that code using it as a modulus cannot
# divide by zero when the loader yields a single batch.
eval_now = max(len(train_loader) - 1, 1)
print(f"Evaluation after every {eval_now} iterations")
print(f"Starting training from epoch {start_epoch} to {opt.OPTIM.NUM_EPOCHS}")
print("=" * 80)

In [None]:
# 10. Main training loop
# Per epoch: optimise on every training batch, (every 10th epoch) validate on
# the last batch of the epoch, step the scheduler, log, plot and checkpoint.
num_batches = len(train_loader)

for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    epoch_loss = 0

    print(f"Epoch {epoch}/{opt.OPTIM.NUM_EPOCHS}")
    print("-" * 50)

    for i, data in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
        # Zero gradients
        for param in model_restoration.parameters():
            param.grad = None

        target = torch.clamp(data[0].cuda(), 0, 1)
        input_ = torch.clamp(data[1].cuda(), 0, 1)

        # MixUp augmentation is only applied after the first 5 epochs.
        if epoch > 5:
            target, input_ = mixup.aug(target, input_)

        # The model receives the input with an extra axis inserted at position 1;
        # index 0 along that axis of the output is the restored cube.
        restored = model_restoration(input_.unsqueeze(1))[:,0]
        restored = torch.clamp(restored, 0, 1)

        loss = criterion(restored, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        #### Evaluation ####
        # Validate on the last batch of every 10th epoch. Comparing against the
        # last index directly (rather than via a modulus) also works when the
        # loader has a single batch.
        if epoch % 10 == 0 and i == num_batches - 1:
            eval_image_dir = os.path.join(result_dir, str(epoch), str(i))
            if save_images:
                utils.mkdir(eval_image_dir)

            model_restoration.eval()
            with torch.no_grad():
                psnr_val = []
                val_loss_total = 0

                for ii, data_val in enumerate(val_loader):
                    val_target = data_val[0].cuda()
                    val_input = data_val[1].cuda()
                    filenames = data_val[2]

                    val_restored = model_restoration(val_input.unsqueeze(1))[:,0]
                    val_restored = torch.clamp(val_restored, 0, 1)

                    # Calculate validation loss
                    val_loss = criterion(val_restored, val_target)
                    val_loss_total += val_loss.item()

                    psnr_val.append(utils.batch_PSNR(val_restored, val_target, 1.))

                    if save_images:
                        target_np = val_target.permute(0, 2, 3, 1).cpu().detach().numpy()
                        input_np = val_input.permute(0, 2, 3, 1).cpu().detach().numpy()
                        restored_np = val_restored.permute(0, 2, 3, 1).cpu().detach().numpy()

                        for batch in range(input_np.shape[0]):
                            temp = np.concatenate((input_np[batch], restored_np[batch], target_np[batch]), axis=1)
                            # The validation input is not clamped, so clip before the
                            # uint8 cast to avoid wrap-around of out-of-range values.
                            temp = np.clip(temp * 255, 0, 255).astype(np.uint8)
                            stem = os.path.splitext(filenames[batch])[0]
                            utils.save_img(os.path.join(eval_image_dir, stem + '.jpg'), temp)

                # float(): the PSNR helper may return a tensor or numpy scalar
                # (TODO confirm); plain floats keep the history plottable and
                # JSON-serialisable.
                psnr_val = float(sum(psnr_val) / len(psnr_val))
                val_loss_avg = val_loss_total / len(val_loader)

                # Store validation metrics
                val_losses.append(val_loss_avg)
                val_psnrs.append(psnr_val)
                epochs_logged.append(epoch)

                if psnr_val > best_psnr:
                    best_psnr = psnr_val
                    best_epoch = epoch
                    best_iter = i
                    torch.save({
                        'epoch': epoch,
                        'state_dict': model_restoration.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, os.path.join(model_dir, "model_best.pth"))

                print(f"[Ep {epoch} it {i}\t PSNR: {psnr_val:.4f}\t Val Loss: {val_loss_avg:.4f}] ----  [Best Ep {best_epoch} Best PSNR {best_psnr:.4f}]")

            model_restoration.train()

    scheduler.step()

    # Store training loss for this epoch
    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    train_losses.append(avg_epoch_loss)

    epoch_time = time.time() - epoch_start_time
    # Read the learning rate from the optimizer itself: it is the value that
    # will actually be used, and it avoids the deprecated scheduler.get_lr().
    current_lr = optimizer.param_groups[0]['lr']

    print("-" * 80)
    print(f"Epoch: {epoch}\tTime: {epoch_time:.2f}s\tLoss: {avg_epoch_loss:.6f}\tLR: {current_lr:.6f}")
    print("-" * 80)

    # Plot progress every 10 epochs or at the end
    if epoch % 10 == 0 or epoch == opt.OPTIM.NUM_EPOCHS:
        print(f"Plotting training progress at epoch {epoch}...")
        plot_training_progress(train_losses, val_losses, val_psnrs, epochs_logged,
                             save_path=os.path.join(result_dir, f'training_progress_epoch_{epoch}.png'))

    # Save checkpoint periodically
    if epoch % 100 == 0:
        checkpoint = {
            'epoch': epoch,
            'state_dict': model_restoration.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_psnrs': val_psnrs,
            'epochs_logged': epochs_logged
        }
        # Same payload twice: a rolling "latest" file and a per-epoch snapshot.
        for checkpoint_name in ("model_latest.pth", f"model_epoch_{epoch}.pth"):
            torch.save(checkpoint, os.path.join(model_dir, checkpoint_name))

print("Training completed!")
print(f"Best PSNR: {best_psnr:.4f} at epoch {best_epoch}")

# Save final training history
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'val_psnrs': val_psnrs,
    'epochs_logged': epochs_logged,
    'best_psnr': best_psnr,
    'best_epoch': best_epoch
}

history_path = os.path.join(result_dir, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(training_history, f, indent=2)

print(f"Training history saved to: {history_path}")

In [None]:
# 11. Final training visualization
# Both plots are driven by the same four history lists.
_history_series = (train_losses, val_losses, val_psnrs, epochs_logged)
_final_plot_path = os.path.join(result_dir, 'final_training_progress.png')

print("Creating final training visualization...")
plot_training_progress(*_history_series, save_path=_final_plot_path)

print("Creating interactive dashboard...")
plot_interactive_progress(*_history_series)

# Test Set Inference and Evaluation

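Now that training is complete, we run inference with the best checkpoint to evaluate the model's performance. Note that, by default, the code below reuses the validation directory as the test set (`test_dir = val_dir`); point `test_dir` at a held-out dataset to measure performance on truly unseen data.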

In [None]:
# 12. Load best model for testing
def load_model_for_testing(model_path, device='cuda'):
    """Build a U_Net_3D, load trained weights into it and prepare it for inference.

    Args:
        model_path: path of the checkpoint file. The file may hold either a
            dict with a 'state_dict' entry (and optionally 'epoch') or a bare
            state dict.
        device: device the weights are loaded onto and the model is moved to.

    Returns:
        The model in eval mode (wrapped in nn.DataParallel when more than one
        CUDA device is visible), or None when ``model_path`` does not exist.
    """
    # Check for the file first so a missing checkpoint does not cost a model build.
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        return None

    model = U_Net_3D()

    print(f"Loading model from: {model_path}")
    # map_location places the tensors on the requested device regardless of
    # the device they were saved from.
    checkpoint = torch.load(model_path, map_location=device)

    # Handle different checkpoint formats
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        epoch = checkpoint.get('epoch', 'unknown')
        loaded_message = f"Model loaded from epoch {epoch}"
    else:
        state_dict = checkpoint
        loaded_message = "Model loaded successfully"

    # Weights saved from an nn.DataParallel wrapper have every key prefixed
    # with 'module.'; strip it so they fit the unwrapped model built above.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    print(loaded_message)

    model.to(device)
    model.eval()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print(f"Using {torch.cuda.device_count()} GPUs for inference")

    return model

# Load the best model
# model_best.pth is the checkpoint the training loop writes whenever the
# validation PSNR improves. test_model is None when that file does not exist.
best_model_path = os.path.join(model_dir, "model_best.pth")
test_model = load_model_for_testing(best_model_path)

In [None]:
# 13. Setup test dataset and evaluation metrics
import scipy.io as sio
from sklearn.metrics import mean_squared_error
# peak_signal_noise_ratio is provided by scikit-image, not scikit-learn;
# importing it from sklearn.metrics raises ImportError and stops this cell.
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity as ssim
import cv2

def calculate_metrics(gt, restored):
    """Compute PSNR, MSE, SSIM and (for 3-D inputs) SAM between two arrays.

    Both arrays are converted to float64 and clipped to [0, 1] before any
    metric is computed. 3-D arrays are treated as channels-first: SSIM is
    computed per slice along axis 0 and averaged, and SAM is the mean angle
    between the axis-0 vectors at each remaining position.

    Args:
        gt: reference array (2-D, or 3-D channels-first).
        restored: array to compare, same shape as ``gt``.

    Returns:
        dict with keys 'PSNR' (dB, ``inf`` when the arrays are identical),
        'MSE', 'SSIM' and, for 3-D inputs only, 'SAM' (radians; 0 when no
        position has a non-zero vector in both arrays). All values are plain
        Python floats so the result can be passed to ``json.dump``.
    """
    metrics = {}

    # Work in float64 so the metrics do not depend on the input precision.
    gt = np.clip(np.asarray(gt, dtype=np.float64), 0, 1)
    restored = np.clip(np.asarray(restored, dtype=np.float64), 0, 1)

    # MSE over all elements; PSNR uses a peak value of 1.0.
    mse = float(np.mean((gt - restored) ** 2))
    if mse == 0:
        psnr = float('inf')
    else:
        psnr = float(20 * np.log10(1.0 / np.sqrt(mse)))
    metrics['PSNR'] = psnr
    metrics['MSE'] = mse

    # SSIM - calculate per channel and average for multi-channel data
    if gt.ndim == 3:
        ssim_values = [ssim(gt[band], restored[band], data_range=1.0) for band in range(gt.shape[0])]
        metrics['SSIM'] = float(np.mean(ssim_values))
    else:
        metrics['SSIM'] = float(ssim(gt, restored, data_range=1.0))

    # SAM (Spectral Angle Mapper) for multi-channel images
    if gt.ndim == 3:
        gt_flat = gt.reshape(gt.shape[0], -1)  # (channels, pixels)
        restored_flat = restored.reshape(restored.shape[0], -1)

        dot_product = np.sum(gt_flat * restored_flat, axis=0)
        norm_gt = np.linalg.norm(gt_flat, axis=0)
        norm_restored = np.linalg.norm(restored_flat, axis=0)

        # Only pixels whose vectors are non-zero in both arrays have a defined angle.
        valid_pixels = (norm_gt > 0) & (norm_restored > 0)
        if np.any(valid_pixels):
            cos_angles = dot_product[valid_pixels] / (norm_gt[valid_pixels] * norm_restored[valid_pixels])
            # Clip to [-1, 1] to avoid numerical errors in arccos
            angles = np.arccos(np.clip(cos_angles, -1, 1))
            metrics['SAM'] = float(np.mean(angles))
        else:
            metrics['SAM'] = 0.0

    return metrics

def _json_default(value):
    """Convert numpy scalars/arrays to built-in types for json.dump."""
    if isinstance(value, (np.generic, np.ndarray)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_test_results(results, save_dir):
    """Write the test results to ``save_dir``.

    Creates ``save_dir`` if needed and writes:
      * ``test_metrics.json`` - the whole ``results`` dict;
      * ``detailed_metrics.csv`` - one row per entry of
        ``results['per_image_metrics']`` (only when that key is present),
        with the dict key stored in a ``filename`` column.

    Args:
        results: dict of results; numpy scalars and arrays inside it are
            converted to built-in types when writing the JSON file.
        save_dir: output directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Save metrics summary. The default hook handles numpy values (for example
    # float32 scalars), which json cannot serialise on its own.
    metrics_file = os.path.join(save_dir, 'test_metrics.json')
    with open(metrics_file, 'w') as f:
        json.dump(results, f, indent=2, default=_json_default)

    # Create metrics summary table
    if 'per_image_metrics' in results:
        rows = [
            {'filename': filename, **metrics}
            for filename, metrics in results['per_image_metrics'].items()
        ]

        import pandas as pd
        csv_file = os.path.join(save_dir, 'detailed_metrics.csv')
        pd.DataFrame(rows).to_csv(csv_file, index=False)
        print(f"Detailed metrics saved to: {csv_file}")

    print(f"Test results saved to: {save_dir}")

# Setup test data directory (you may need to adjust this path)
# NOTE(review): test_dir reuses val_dir, the same data the training loop uses
# to pick model_best.pth, so the reported test metrics are not measured on
# held-out data. Point test_dir at a separate set for an unbiased estimate.
test_dir = val_dir  # Using validation data as test data for demonstration
# Test outputs are written to a 'test_results' sub-folder of result_dir.
test_result_dir = os.path.join(result_dir, 'test_results')
utils.mkdir(test_result_dir)

print(f"Test data directory: {test_dir}")
print(f"Test results will be saved to: {test_result_dir}")

In [None]:
# 14. Run inference on test set
def _to_uint8_image(array):
    """Scale a float image from [0, 1] to uint8, clipping out-of-range values first."""
    return (np.clip(array, 0, 1) * 255).astype(np.uint8)


def _write_comparison(path, noisy, restored, gt):
    """Write a side-by-side noisy | restored | ground-truth image to ``path``."""
    panel = np.concatenate([_to_uint8_image(part) for part in (noisy, restored, gt)], axis=1)
    if panel.ndim == 3 and panel.shape[2] == 3:
        # cv2.imwrite treats 3-channel data as BGR, so reverse the channel
        # order of the RGB-ordered panel before writing.
        panel = np.ascontiguousarray(panel[:, :, ::-1])
    cv2.imwrite(path, panel)


def _mean_std(values, name):
    """Return {'avg_<name>': mean, 'std_<name>': std} as plain floats."""
    return {f'avg_{name}': float(np.mean(values)), f'std_{name}': float(np.std(values))}


def run_test_inference(model, test_loader, save_images=True, save_dir=None):
    """Run the model over a test loader and collect evaluation metrics.

    Args:
        model: network to evaluate; when None nothing is run.
        test_loader: iterable of batches ``(ground_truth, noisy, filenames)``.
        save_images: when true (and ``save_dir`` is set) restored outputs and
            comparison images are written to ``save_dir``.
        save_dir: existing directory for the saved outputs.

    Returns:
        None when ``model`` is None, otherwise a dict with:
          * 'average_metrics': mean/std of each metric of the restored output;
          * 'improvements': mean/std of the per-image gains over the noisy
            input (empty when no sample was processed);
          * 'per_image_metrics': per-file metrics keyed by filename;
          * 'total_samples': number of images evaluated.
        All metric values are plain Python floats.
    """
    if model is None:
        print("No model available for testing")
        return None

    model.eval()

    all_metrics = []
    per_image_metrics = {}
    total_files = len(test_loader)

    print(f"Running inference on {total_files} test samples...")
    print("=" * 80)

    with torch.no_grad():
        for ii, data_test in enumerate(tqdm(test_loader, desc="Testing")):
            hsi_gt = data_test[0].cuda()
            hsi_noisy = data_test[1].cuda()
            filenames = data_test[2]

            # Same call convention as in training: extra axis at position 1 on
            # the way in, index 0 along it on the way out.
            hsi_restored = model(hsi_noisy.unsqueeze(1))[:,0]
            hsi_restored = torch.clamp(hsi_restored, 0, 1)

            # Convert to numpy
            hsi_gt_np = hsi_gt.cpu().detach().numpy()
            hsi_noisy_np = hsi_noisy.cpu().detach().numpy()
            hsi_restored_np = hsi_restored.cpu().detach().numpy()

            # Process each item in the batch
            for batch_idx in range(len(hsi_gt_np)):
                gt = hsi_gt_np[batch_idx]
                noisy = hsi_noisy_np[batch_idx]
                restored = hsi_restored_np[batch_idx]
                filename = filenames[batch_idx]
                stem = os.path.splitext(filename)[0]

                # float() keeps every stored value JSON-serialisable even if
                # the metric helper hands back numpy scalars.
                metrics = {key: float(value) for key, value in calculate_metrics(gt, restored).items()}
                metrics_noisy = {key: float(value) for key, value in calculate_metrics(gt, noisy).items()}

                entry = {
                    'PSNR_restored': metrics['PSNR'],
                    'PSNR_noisy': metrics_noisy['PSNR'],
                    'PSNR_improvement': metrics['PSNR'] - metrics_noisy['PSNR'],
                    'MSE_restored': metrics['MSE'],
                    'MSE_noisy': metrics_noisy['MSE'],
                    'SSIM_restored': metrics['SSIM'],
                    'SSIM_noisy': metrics_noisy['SSIM'],
                    'SSIM_improvement': metrics['SSIM'] - metrics_noisy['SSIM']
                }

                if 'SAM' in metrics:
                    entry['SAM_restored'] = metrics['SAM']
                    entry['SAM_noisy'] = metrics_noisy['SAM']
                    # SAM is an angle, so a smaller value is better: improvement = noisy - restored.
                    entry['SAM_improvement'] = metrics_noisy['SAM'] - metrics['SAM']

                per_image_metrics[filename] = entry
                all_metrics.append(metrics)

                # Save images if requested
                if save_images and save_dir:
                    comparison_path = os.path.join(save_dir, stem + '_comparison.jpg')
                    if gt.ndim == 3 and gt.shape[0] > 3:  # Hyperspectral
                        # Save the full restored cube as a .mat file
                        denoised_hsi = np.rot90(restored[:, ::-1], axes=(-2,-1))
                        mat_path = os.path.join(save_dir, stem + '_restored.mat')
                        sio.savemat(mat_path, {'R_hsi': np.transpose(denoised_hsi, (1,2,0))})

                        # Also save an RGB representation (using specific bands)
                        if gt.shape[0] >= 50:  # Ensure we have enough bands
                            # Select RGB-like bands (adjust indices based on your data)
                            rgb_bands = [30, 20, 10]  # Example band selection
                            _write_comparison(
                                comparison_path,
                                noisy[rgb_bands].transpose(1, 2, 0),
                                restored[rgb_bands].transpose(1, 2, 0),
                                gt[rgb_bands].transpose(1, 2, 0),
                            )
                    elif gt.ndim == 3:
                        # Channels-first image with at most 3 channels
                        _write_comparison(
                            comparison_path,
                            noisy.transpose(1, 2, 0),
                            restored.transpose(1, 2, 0),
                            gt.transpose(1, 2, 0),
                        )
                    else:
                        # Single-channel image
                        _write_comparison(comparison_path, noisy, restored, gt)

                # Print progress for first few samples
                if ii < 5:
                    print(f"{filename}:")
                    print(f"   PSNR: {metrics_noisy['PSNR']:.2f} -> {metrics['PSNR']:.2f} dB (improvement: {metrics['PSNR'] - metrics_noisy['PSNR']:.2f})")
                    print(f"   SSIM: {metrics_noisy['SSIM']:.4f} -> {metrics['SSIM']:.4f} (improvement: {metrics['SSIM'] - metrics_noisy['SSIM']:.4f})")

    # Calculate average metrics of the restored outputs
    avg_metrics = {}
    if all_metrics:
        for key in all_metrics[0].keys():
            avg_metrics.update(_mean_std([m[key] for m in all_metrics], key))

    # Calculate improvement statistics (skipped when nothing was processed,
    # which would otherwise produce NaNs from empty means).
    improvements = {}
    if per_image_metrics:
        for key in ('PSNR_improvement', 'SSIM_improvement', 'SAM_improvement'):
            gains = [m[key] for m in per_image_metrics.values() if key in m]
            if gains:
                improvements.update(_mean_std(gains, key))

    results = {
        'average_metrics': avg_metrics,
        'improvements': improvements,
        'per_image_metrics': per_image_metrics,
        # Count images actually evaluated rather than loader batches.
        'total_samples': len(all_metrics)
    }

    return results

# Create test dataset
# The test set is built with the same loader function as the validation set
# (get_validation_data) and is read one sample per batch in a fixed order.
test_dataset = get_validation_data(test_dir, opt.MODEL.RATIO)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False)

print(f"Test dataset ready with {len(test_loader)} samples")

# Run inference
# test_results is None when no trained model could be loaded (test_model is None).
test_results = run_test_inference(test_model, test_loader, save_images=True, save_dir=test_result_dir)

In [None]:
# 15. Display and analyze test results
def _print_test_summary(avg_metrics, improvements, total_samples):
    """Print the aggregate metrics and improvement statistics as a text report."""
    print("\n" + "="*80)
    print("TEST RESULTS SUMMARY")
    print("="*80)

    print("\nAVERAGE PERFORMANCE METRICS:")
    print("-" * 50)
    print(f"PSNR: {avg_metrics.get('avg_PSNR', 0):.4f} ± {avg_metrics.get('std_PSNR', 0):.4f} dB")
    print(f"MSE:  {avg_metrics.get('avg_MSE', 0):.6f} ± {avg_metrics.get('std_MSE', 0):.6f}")
    print(f"SSIM: {avg_metrics.get('avg_SSIM', 0):.4f} ± {avg_metrics.get('std_SSIM', 0):.4f}")
    # SAM is only reported when the inference step produced it.
    if 'avg_SAM' in avg_metrics:
        print(f"SAM:  {avg_metrics.get('avg_SAM', 0):.4f} ± {avg_metrics.get('std_SAM', 0):.4f} rad")

    print("\nIMPROVEMENT OVER NOISY INPUT:")
    print("-" * 50)
    print(f"PSNR Improvement: {improvements.get('avg_PSNR_improvement', 0):.4f} ± {improvements.get('std_PSNR_improvement', 0):.4f} dB")
    print(f"SSIM Improvement: {improvements.get('avg_SSIM_improvement', 0):.4f} ± {improvements.get('std_SSIM_improvement', 0):.4f}")
    if 'avg_SAM_improvement' in improvements:
        print(f"SAM Improvement:  {improvements.get('avg_SAM_improvement', 0):.4f} ± {improvements.get('std_SAM_improvement', 0):.4f} rad")

    print(f"\nTOTAL SAMPLES PROCESSED: {total_samples}")


def _collect_metric_series(per_image, filenames, name):
    """Gather the noisy / restored / improvement values of one metric, in filename order."""
    return {
        kind: [per_image[f][f'{name}_{kind}'] for f in filenames]
        for kind in ('noisy', 'restored', 'improvement')
    }


def _plot_metric_row(axes_row, positions, series, name, unit=""):
    """Fill one row of the static figure: comparison bars, improvement bars, histogram."""
    compare_ax, gain_ax, hist_ax = axes_row

    # Noisy vs. restored value per sample (semi-transparent bars drawn at the same x).
    compare_ax.bar(positions, series['noisy'], alpha=0.7, label='Noisy', color='red')
    compare_ax.bar(positions, series['restored'], alpha=0.7, label='Restored', color='blue')
    compare_ax.set_title(f'{name} Comparison')
    compare_ax.set_ylabel(f'{name}{unit}')
    compare_ax.legend()
    compare_ax.grid(True, alpha=0.3)

    # Per-sample gain; the dashed zero line separates improvement from degradation.
    gain_ax.bar(positions, series['improvement'], color='green', alpha=0.7)
    gain_ax.set_title(f'{name} Improvement')
    gain_ax.set_ylabel(f'{name} Improvement{unit}')
    gain_ax.grid(True, alpha=0.3)
    gain_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # Distribution of the metric before and after restoration.
    hist_ax.hist([series['noisy'], series['restored']], bins=15, alpha=0.7,
                 label=['Noisy', 'Restored'], color=['red', 'blue'])
    hist_ax.set_title(f'{name} Distribution')
    hist_ax.set_xlabel(f'{name}{unit}')
    hist_ax.set_ylabel('Frequency')
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)


def _build_plotly_dashboard(psnr, ssim):
    """Assemble the interactive 2x2 Plotly dashboard from the PSNR and SSIM series."""
    sample_idx = list(range(len(psnr['noisy'])))
    dashboard = make_subplots(
        rows=2, cols=2,
        subplot_titles=('PSNR Metrics', 'SSIM Metrics', 'Improvements', 'Correlation Analysis'),
        specs=[[{}, {}], [{}, {}]]
    )

    # Top row: noisy/restored curves, PSNR in column 1 and SSIM in column 2.
    for col, (name, series) in enumerate((('PSNR', psnr), ('SSIM', ssim)), start=1):
        for kind, colour in (('noisy', 'red'), ('restored', 'blue')):
            dashboard.add_trace(
                go.Scatter(x=sample_idx, y=series[kind],
                           mode='lines+markers', name=f'{name} {kind.capitalize()}',
                           line=dict(color=colour)),
                row=1, col=col
            )

    # Bottom-left: per-sample PSNR improvement.
    dashboard.add_trace(
        go.Bar(x=sample_idx, y=psnr['improvement'],
               name='PSNR Improvement', marker_color='green'),
        row=2, col=1
    )

    # Bottom-right: PSNR improvement against SSIM improvement, one point per sample.
    dashboard.add_trace(
        go.Scatter(x=psnr['improvement'], y=ssim['improvement'],
                   mode='markers', name='PSNR vs SSIM Improvement',
                   marker=dict(size=8, color='purple')),
        row=2, col=2
    )

    dashboard.update_layout(height=800, showlegend=True, title_text="Test Results Analysis Dashboard")
    return dashboard


def display_test_results(results, save_dir=None):
    """Display comprehensive test results with visualizations.

    Prints a text summary of the aggregate metrics and, when per-image metrics
    are present, renders a static Matplotlib figure (saved as
    ``test_results_analysis.png``) and an interactive Plotly dashboard.

    Args:
        results: Dict with the keys 'average_metrics', 'improvements',
            'per_image_metrics' and 'total_samples'. ``None`` or an empty
            dict is reported as "no results" instead of raising.
        save_dir: Directory for the PNG. Defaults to the notebook-level
            ``test_result_dir`` so existing single-argument calls keep
            working unchanged.

    Returns:
        None.
    """
    # Treat None and {} alike: the caller also tests truthiness, and an empty
    # dict would otherwise fail on the key lookups below.
    if not results:
        print("No test results available")
        return

    avg_metrics = results.get('average_metrics') or {}
    improvements = results.get('improvements') or {}
    per_image = results.get('per_image_metrics') or {}

    _print_test_summary(avg_metrics, improvements,
                        results.get('total_samples', len(per_image)))

    # Nothing to plot without per-image data.
    if not per_image:
        return

    filenames = list(per_image.keys())
    psnr = _collect_metric_series(per_image, filenames, 'PSNR')
    ssim = _collect_metric_series(per_image, filenames, 'SSIM')

    # Static overview: one row per metric.
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    positions = range(len(filenames))
    _plot_metric_row(axes[0], positions, psnr, 'PSNR', unit=' (dB)')
    _plot_metric_row(axes[1], positions, ssim, 'SSIM')
    fig.tight_layout()

    if save_dir is None:
        # Fall back to the directory configured in an earlier notebook cell.
        save_dir = test_result_dir
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'test_results_analysis.png'), dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Interactive Plotly visualization
    _build_plotly_dashboard(psnr, ssim).show()

# Show, persist, and point to the evaluation artifacts -- or say that there
# is nothing to show when inference produced no results.
if not test_results:
    print("No test results to display")
else:
    display_test_results(test_results)

    # Write the results into the test result directory.
    save_test_results(test_results, test_result_dir)

    print(f"\nAll test results saved to: {test_result_dir}")
    # One line per artifact the user should look for in that directory.
    print("\n".join([
        "Check the following files:",
        "   - test_metrics.json: Overall metrics summary",
        "   - detailed_metrics.csv: Per-image detailed metrics",
        "   - test_results_analysis.png: Static analysis plots",
        "   - Individual comparison images for visual inspection",
    ]))

# Summary and Conclusion

## Complete Pipeline Overview

This notebook provides a comprehensive pipeline for HSI denoising:

### 1. **Training with Visualization**
- Real-time loss tracking for both training and validation
- Interactive plots showing training progress
- Automatic model checkpointing and best model saving

### 2. **Test Set Evaluation**
- Comprehensive metrics: PSNR (dB), SSIM, MSE, and SAM (in radians, reported when it is computed)
- Before/after comparisons with improvement statistics
- Visual analysis with both static and interactive plots
- Detailed per-image and aggregate results

### 3. **Key Features**
- **GPU acceleration** with multi-GPU support
- **Comprehensive metrics** for hyperspectral image evaluation
- **Visual outputs** including comparison images and analysis plots
- **Robust error handling** and progress tracking
- **Flexible configuration** via YAML config files

### 4. **Output Files**
- `training_history.json`: Complete training metrics
- `test_metrics.json`: Comprehensive test results
- `detailed_metrics.csv`: Per-image breakdown
- `test_results_analysis.png`: Static analysis plots
- Individual comparison images for visual inspection