# DDCM-Net: Dense Dilated Convolutions Merging Network

## 1. Import Required Libraries

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from model import create_model, create_trainer
from dataset_loader import create_dataloaders

# Shared plotting look for every figure in the notebook
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Seed the torch and numpy RNGs so runs can be repeated
torch.manual_seed(42)
np.random.seed(42)

# Pick the GPU when one is visible, otherwise fall back to the CPU
cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if cuda_ok else torch.device('cpu')
print(f"Using device: {device}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Model Architecture Setup

In [None]:
# Model configuration
NUM_CLASSES = 6
BACKBONE = 'resnet50'
PRETRAINED = True

# Class names for visualization; the list index is the label id used in the masks
CLASS_NAMES = [
    'Impervious surfaces',  # 0 - White
    'Building',             # 1 - Blue
    'Low vegetation',       # 2 - Cyan
    'Tree',                 # 3 - Green
    'Car',                  # 4 - Yellow
    'Clutter/background'    # 5 - Red
]

# Create model
model = create_model(
    variant='base',
    num_classes=NUM_CLASSES,
    backbone=BACKBONE,
    pretrained=PRETRAINED
)

# Create trainer wrapper (owns training, checkpoint loading and plotting helpers)
trainer = create_trainer(
    model=model,
    device=device,
    class_names=CLASS_NAMES
)

# Print model information (the counts were previously computed but never shown)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Architecture: DDCM-Net with {BACKBONE} backbone")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Load model from a previous training session

In [None]:
# trainer = trainer.load_model("best_model.pth")

## 3. Dataset Loading and Preprocessing

In [None]:
# Dataset configuration
DATA_ROOT = "./data"
DATASET = "potsdam"  # or "vaihingen" or "both"
BATCH_SIZE = 5
NUM_WORKERS = 4
PATCH_SIZE = 256

print("Loading dataset...")

if os.path.exists(DATA_ROOT):
    # create_dataloaders hands back four loaders: train, validation, test, holdout
    train_loader, val_loader, test_loader, holdout_loader = create_dataloaders(
        root_dir=DATA_ROOT,
        dataset=DATASET,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    print(f"Dataset: {DATASET.upper()}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
    print(f"Train batches: {len(train_loader):,}")
    print(f"Validation batches: {len(val_loader):,}")
    print(f"Test batches: {len(test_loader):,}")

    # Show a sample batch
    # sample_images, sample_labels = next(iter(train_loader))
    # print(f"Sample batch shape - Images: {sample_images.shape}, Labels: {sample_labels.shape}")
    # print(f"Classes in sample: {torch.unique(sample_labels).tolist()}")
else:
    # Explain the on-disk layout expected under DATA_ROOT
    expected_layout = (
        "data/",
        "├── potsdam/",
        "│   ├── images/",
        "│   └── labels/",
        "└── vaihingen/",
        "    ├── images/",
        "    └── labels/",
    )
    print(f"Data directory not found: {DATA_ROOT}")
    print("Please ensure you have the ISPRS dataset in the data folder")
    print("Expected structure:")
    for layout_line in expected_layout:
        print(layout_line)

### Dataset Visualization

In [None]:
# Visualize dataset samples
def visualize_dataset_samples(dataloader, num_samples=4):
    """Plot the first images of one batch above their ground-truth masks.

    Draws a 2 x ``num_samples`` grid (images on top, label maps below) and then
    prints how many pixels of the whole batch belong to each class.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Images are
            assumed to be ImageNet-normalised ``[B, 3, H, W]`` tensors and labels
            ``[B, H, W]`` class-id maps (TODO confirm against dataset_loader).
        num_samples: Number of columns to draw; columns beyond the batch size
            are left blank.
    """
    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 10), squeeze=False)

    # Get a batch
    images, labels = next(iter(dataloader))

    # ImageNet statistics used to undo the input normalisation (loop-invariant)
    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

    shown = min(num_samples, len(images))
    for i in range(num_samples):
        img_ax, label_ax = axes[0, i], axes[1, i]
        img_ax.axis('off')
        label_ax.axis('off')
        if i >= shown:
            continue  # batch smaller than num_samples: leave the column empty

        img_denorm = torch.clamp(images[i].cpu() * std + mean, 0, 1)
        img_ax.imshow(img_denorm.permute(1, 2, 0))
        img_ax.set_title(f'Image {i+1}')

        label_ax.imshow(labels[i].cpu(), cmap='tab10', vmin=0, vmax=5)
        label_ax.set_title(f'Ground Truth {i+1}')

    plt.tight_layout()
    plt.show()

    # Print class statistics; .tolist() gives plain ints so ids print as "0",
    # not "tensor(0)", and can index CLASS_NAMES safely.
    unique_classes = torch.unique(labels).tolist()
    print(f"Classes in this batch: {unique_classes}")
    total_pixels = labels.numel()
    for class_id in unique_classes:
        if 0 <= class_id < len(CLASS_NAMES):
            class_name = CLASS_NAMES[class_id]
        else:
            class_name = 'unknown'  # e.g. an ignore label outside the class list
        pixel_count = (labels == class_id).sum().item()
        percentage = pixel_count / total_pixels * 100
        print(f"  {class_id}: {class_name} - {pixel_count:,} pixels ({percentage:.1f}%)")

# Only draw samples once the dataset cell has created the loaders
if 'train_loader' not in locals():
    print("Dataset not loaded. Please ensure data is available.")
else:
    print("Dataset samples:")
    visualize_dataset_samples(train_loader, num_samples=4)

## 4. Model Training

In [None]:
# Training configuration
EPOCHS = 80
LEARNING_RATE = 8.5e-5
WEIGHT_DECAY = 2e-5

# Echo the run configuration before training starts
print("Starting model training...")
print("Configuration:")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Weight decay: {WEIGHT_DECAY}")
print(f"Device: {device}")

have_loaders = 'train_loader' in locals() and 'val_loader' in locals()
if not have_loaders:
    print("Cannot start training - dataset not available")
    print("Please ensure the data is properly loaded before training")
else:
    print("\nTraining started...")

    # use_mfb presumably toggles median-frequency class balancing in the loss
    # (confirm in model.create_trainer); lr_scheduler selects the 'poly' schedule.
    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        use_mfb=True,
        lr_scheduler='poly',
    )

    print("\nTraining completed!")

## 5. Training Visualization and Metrics

In [None]:
# Plot training history.
# Read trainer.history directly instead of requiring the ``history`` variable
# from section 4: that variable only exists if training ran in this session,
# whereas section 5.5 reads trainer.history right after ``trainer.load_model``,
# so a restored model's record can be shown too.
train_record = getattr(trainer, 'history', None) or {}
if train_record.get('train_loss'):
    print("Training History Visualization")

    # Plot using trainer's built-in visualization
    trainer.plot_training_history(figsize=(18, 6))

    # Print final metrics (last recorded epoch)
    final_train_loss = train_record['train_loss'][-1]
    final_val_loss = train_record['val_loss'][-1]
    final_train_acc = train_record['train_acc'][-1]
    final_val_acc = train_record['val_acc'][-1]
    final_train_miou = train_record['train_miou'][-1]
    final_val_miou = train_record['val_miou'][-1]

    print("\nFinal Training Metrics:")
    print(f"  Train Loss: {final_train_loss:.4f} | Val Loss: {final_val_loss:.4f}")
    print(f"  Train Acc:  {final_train_acc:.3f} | Val Acc:  {final_val_acc:.3f}")
    print(f"  Train mIoU: {final_train_miou:.3f} | Val mIoU: {final_val_miou:.3f}")

    # Find best epoch (1-based) by validation mIoU; argmax on an empty list would raise
    val_mious = train_record.get('val_miou')
    if val_mious:
        best_epoch = np.argmax(val_mious) + 1
        best_miou = max(val_mious)
        print(f"\nBest Validation mIoU: {best_miou:.3f} (Epoch {best_epoch})")

else:
    print("No training history available.")
    print("Please run training first or load a pre-trained model.")

## 5.5. Continue Training from Checkpoint

In [None]:
# Configuration for continued training
CONTINUE_TRAINING = True  # Set to False to skip this cell
ADDITIONAL_EPOCHS = 30
LAST_COMPLETED_EPOCH = 60  # passed as current_epoch so the trainer can resume the LR schedule
# NOTE: this overwrites the LEARNING_RATE defined in section 4
LEARNING_RATE = 8.5e-5 / np.sqrt(2)
WEIGHT_DECAY = 2e-5

resume_ready = 'trainer' in locals() and 'train_loader' in locals() and 'val_loader' in locals()

if not CONTINUE_TRAINING:
    print("Continued training is disabled.")
elif not resume_ready:
    print("Dataloaders not found!")
    print("Please ensure 'trainer', 'train_loader', and 'val_loader' are available.")
elif not os.path.exists('best_model.pth'):
    print("No saved model found at 'best_model.pth'!")
    print("Please train the model first or check the model path.")
else:
    # Resume from the best checkpoint on disk
    print("Loading best model for continued training...")
    trainer.load_model('best_model.pth')

    # Report where the restored model stands
    epochs_done = len(trainer.history['train_loss'])
    print(f"Model has been trained for {epochs_done} epochs")

    val_mious = trainer.history['val_miou']
    if val_mious:
        print(f"Current best validation mIoU: {max(val_mious):.3f}")

    print(f"\nContinuing training for {ADDITIONAL_EPOCHS} more epochs...")
    updated_history = trainer.continue_training(
        train_loader=train_loader,
        val_loader=val_loader,
        additional_epochs=ADDITIONAL_EPOCHS,
        current_epoch=LAST_COMPLETED_EPOCH,  # Use this to set the LR correctly
        initial_lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    # Plot the extended training history
    print("\nUpdated Training History:")
    trainer.plot_training_history(figsize=(18, 6))

## 6. Model Prediction and Visualization

In [None]:
# Prefer the best checkpoint on disk; otherwise keep the in-memory weights
best_checkpoint = 'best_model.pth'
if not os.path.exists(best_checkpoint):
    print("No saved model found. Using current model state.")
else:
    print("Loading best trained model...")
    trainer.load_model(best_checkpoint)
    print("Model loaded successfully!")

# Side-by-side ground truth vs prediction on a few test samples
if 'test_loader' not in locals():
    print("Test data not available for prediction visualization")
else:
    print("\nModel Predictions Visualization")
    print("Comparing ground truth vs predictions on test samples...")
    trainer.visualize_predictions(test_loader, num_samples=4, figsize=(20, 8))

### Detailed Prediction Analysis

In [None]:
# Detailed evaluation on test set: average per-pixel loss, overall accuracy,
# per-class IoU and a row-normalised confusion matrix.
if 'test_loader' in locals():
    print("Detailed Model Evaluation")

    # Evaluate model on test set
    model.eval()
    test_loss = 0.0

    # confusion_matrix[t, p] = number of pixels with ground truth t predicted as p
    num_classes = NUM_CLASSES
    confusion_matrix = torch.zeros(num_classes, num_classes, device=device)
    criterion = nn.CrossEntropyLoss(reduction='sum')  # summed, divided by pixel count below
    total_pixels = 0

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            test_loss += criterion(outputs, targets).item()

            pred_flat = torch.argmax(outputs, dim=1).reshape(-1)
            target_flat = targets.reshape(-1)

            # Only pixels whose ground truth is a known class contribute, then the
            # whole matrix is filled by one bincount over (target, prediction)
            # pairs encoded as target * num_classes + prediction.
            valid = (target_flat >= 0) & (target_flat < num_classes)
            pair_index = target_flat[valid] * num_classes + pred_flat[valid]
            confusion_matrix += torch.bincount(
                pair_index, minlength=num_classes * num_classes
            ).reshape(num_classes, num_classes)

            total_pixels += targets.numel()

    # Calculate metrics
    test_loss /= max(total_pixels, 1)

    # Calculate accuracy (clamp guards against an empty test set)
    overall_acc = torch.diag(confusion_matrix).sum() / confusion_matrix.sum().clamp(min=1)

    print(f"\nTest Set Results:")
    print(f"  Overall Accuracy: {overall_acc.item():.3f}")
    print(f"  Average Loss: {test_loss:.4f}")

    # Class-wise IoU = TP / (TP + FP + FN) = TP / (row_sum + col_sum - TP), where
    # row sums count actual pixels of a class and column sums predicted pixels.
    true_pos = torch.diag(confusion_matrix)
    union = confusion_matrix.sum(dim=1) + confusion_matrix.sum(dim=0) - true_pos

    class_ious = []
    print(f"\nClass-wise IoU:")
    for class_id in range(num_classes):
        denominator = union[class_id].item()
        if denominator > 0:
            iou = true_pos[class_id].item() / denominator
            print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: {iou:.3f}")
        else:
            # Class neither present nor predicted: IoU is undefined, so keep it
            # out of the mean instead of counting it as 0.
            iou = float('nan')
            print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: n/a (class absent)")
        class_ious.append(iou)

    mean_iou = np.nanmean(class_ious)
    print(f"\nMean IoU: {mean_iou:.3f}")

    # Plot confusion matrix; rows with no ground-truth pixels stay 0 instead of NaN
    plt.figure(figsize=(12, 10))
    row_totals = confusion_matrix.sum(dim=1, keepdim=True).clamp(min=1)
    confusion_norm = confusion_matrix / row_totals
    confusion_np = confusion_norm.cpu().numpy()

    sns.heatmap(confusion_np, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel('Predicted')
    plt.ylabel('Ground Truth')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()
    plt.show()

else:
    print("Test data not available for detailed evaluation")

## 7. Test Time Augmentation

In [None]:
# Paper specification: 448×448 patches with 100-pixel stride, flipping and mirroring

import torch.nn.functional as F
from PIL import Image

def apply_tta_transforms(patch):
    """
    Build the four flip variants used for test-time augmentation.

    Args:
        patch: Input tensor of shape [1, 3, H, W]

    Returns:
        List of ``(name, tensor)`` pairs, in this order:
        'original' (the input itself), 'hflip' (width axis mirrored),
        'vflip' (height axis flipped) and 'hvflip' (both axes flipped).
    """
    # Dim 3 is width and dim 2 is height for a [N, C, H, W] tensor
    return [
        ('original', patch),
        ('hflip', torch.flip(patch, [3])),
        ('vflip', torch.flip(patch, [2])),
        ('hvflip', torch.flip(patch, [2, 3])),
    ]

def reverse_tta_transforms(prediction, transform_name):
    """
    Undo a TTA transformation on a model prediction.

    Every transform produced by ``apply_tta_transforms`` is a flip, and a flip
    is its own inverse, so undoing it means flipping the prediction along the
    same spatial dims.

    Args:
        prediction: Model output tensor [1, num_classes, H, W]
        transform_name: One of 'original', 'hflip', 'vflip', 'hvflip'

    Returns:
        Prediction tensor aligned with the un-transformed input ('original'
        returns the same tensor object).

    Raises:
        ValueError: If ``transform_name`` is unknown. Returning the prediction
            unchanged would silently average a misaligned map into the result.
    """
    # Spatial dims to flip back for each transform (None = nothing to undo)
    flip_dims = {
        'original': None,
        'hflip': [3],
        'vflip': [2],
        'hvflip': [2, 3],
    }
    try:
        dims = flip_dims[transform_name]
    except KeyError:
        raise ValueError(f"Unknown TTA transform: {transform_name!r}") from None
    return prediction if dims is None else torch.flip(prediction, dims)

def _window_starts(length, patch_size, stride):
    """Return start offsets of ``patch_size`` windows that tile ``[0, length)``.

    Offsets step by ``stride`` from 0. If the last step does not end flush with
    the far edge, one extra window aligned to that edge is appended so every
    pixel is covered. Requires ``length >= patch_size``.
    """
    last_start = length - patch_size
    starts = list(range(0, last_start + 1, stride))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def test_time_augmentation_sliding_window(model, image, patch_size=448, stride=100, device='cuda'):
    """
    Apply TTA with sliding windows as described in the paper

    The image is tiled with overlapping ``patch_size`` windows. Each window is
    predicted under every variant from ``apply_tta_transforms``, the predictions
    are mapped back with ``reverse_tta_transforms`` and averaged, and the
    overlapping windows are then averaged pixel-wise.

    Images smaller than one window used to crash (no window position could be
    generated). They are now zero-padded on the bottom/right for inference, and
    the padding is cropped off the result.

    Args:
        model: Trained DDCM-Net model
        image: Input image tensor [1, 3, H, W] or [3, H, W]
        patch_size: Size of sliding window patches (default: 448)
        stride: Stride for sliding window (default: 100)
        device: Device to run inference on

    Returns:
        Final averaged prediction [B, C, H, W] at the input's spatial size, where
        C is the channel count of the model output (num_classes for DDCM-Net).
    """
    model.eval()

    # Ensure input is 4D [B, 3, H, W]
    if image.dim() == 3:
        image = image.unsqueeze(0)
    image = image.to(device)
    batch_size, _, height, width = image.shape

    # Pad images smaller than one window so every window has the full size
    pad_h = max(patch_size - height, 0)
    pad_w = max(patch_size - width, 0)
    if pad_h or pad_w:
        image = F.pad(image, (0, pad_w, 0, pad_h))
    padded_h, padded_w = image.shape[-2], image.shape[-1]

    print(f"Applying TTA with {patch_size}×{patch_size} patches, stride={stride}")
    print(f"Image size: {height}×{width}")

    # Window positions with guaranteed coverage of the bottom and right edges
    y_positions = _window_starts(padded_h, patch_size, stride)
    x_positions = _window_starts(padded_w, patch_size, stride)

    total_patches = len(y_positions) * len(x_positions)
    print(f"Processing {total_patches} patches ({len(y_positions)}×{len(x_positions)})")
    print(f"Y positions: {len(y_positions)} (from 0 to {y_positions[-1]})")
    print(f"X positions: {len(x_positions)} (from 0 to {x_positions[-1]})")

    # The prediction canvas is allocated from the first output so its channel
    # count always matches the model; count_canvas tracks window overlap.
    prediction_canvas = None
    count_canvas = torch.zeros(batch_size, 1, padded_h, padded_w, device=device)

    patch_count = 0
    with torch.no_grad():
        for y in y_positions:
            for x in x_positions:
                patch_count += 1
                patch = image[:, :, y:y + patch_size, x:x + patch_size]

                # Predict every augmented view and map it back to the patch frame
                patch_predictions = [
                    reverse_tta_transforms(model(view), name)
                    for name, view in apply_tta_transforms(patch)
                ]
                avg_patch_pred = torch.stack(patch_predictions).mean(dim=0)

                if prediction_canvas is None:
                    prediction_canvas = avg_patch_pred.new_zeros(
                        batch_size, avg_patch_pred.shape[1], padded_h, padded_w
                    )

                prediction_canvas[:, :, y:y + patch_size, x:x + patch_size] += avg_patch_pred
                count_canvas[:, :, y:y + patch_size, x:x + patch_size] += 1

                if patch_count % 50 == 0 or patch_count == total_patches:
                    print(f"Processed {patch_count}/{total_patches} patches")

    # Average overlapping predictions; epsilon guards pixels with no coverage
    epsilon = 1e-8
    final_prediction = prediction_canvas / (count_canvas + epsilon)

    # Drop any padding added above
    final_prediction = final_prediction[:, :, :height, :width]
    count_canvas = count_canvas[:, :, :height, :width]

    # Verify complete coverage
    min_count = count_canvas.min().item()
    max_count = count_canvas.max().item()
    print(f"Coverage verification: min_count={min_count}, max_count={max_count}")

    if min_count == 0:
        print("⚠️  Warning: Some pixels were not covered by any patches!")
        uncovered_pixels = (count_canvas[0, 0] == 0).sum().item()
        print(f"   Uncovered pixels: {uncovered_pixels}/{count_canvas.numel()}")

    print("TTA inference completed!")
    return final_prediction


def visualize_tta_results(original_image, ground_truth, tta_prediction, regular_prediction=None):
    """
    Show the input image, ground truth and TTA prediction side by side, plus the
    regular (single-pass) prediction as a fourth panel when one is given.

    Args:
        original_image: Normalised image tensor [1, 3, H, W] or [3, H, W], on any device
        ground_truth: Label map [H, W] (tensor on any device, or array)
        tta_prediction: TTA prediction tensor [1, num_classes, H, W]
        regular_prediction: Optional regular prediction tensor [1, num_classes, H, W]
    """
    # Move everything to the CPU: the de-normalisation constants below are CPU
    # tensors and matplotlib cannot read CUDA tensors.
    img = original_image[0] if original_image.dim() == 4 else original_image
    img = img.detach().cpu()
    if torch.is_tensor(ground_truth):
        ground_truth = ground_truth.detach().cpu()

    # Convert predictions to class labels
    tta_pred_labels = torch.argmax(tta_prediction, dim=1)[0].cpu()

    # Undo the ImageNet normalisation for display
    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
    img_denorm = torch.clamp(img * std + mean, 0, 1)

    # Same colour scale for every label panel so classes are comparable
    label_style = dict(cmap='tab10', vmin=0, vmax=5)
    panels = [
        (img_denorm.permute(1, 2, 0), 'Original Image', {}),
        (ground_truth, 'Ground Truth', label_style),
        (tta_pred_labels, 'TTA Prediction', label_style),
    ]
    if regular_prediction is not None:
        reg_pred_labels = torch.argmax(regular_prediction, dim=1)[0].cpu()
        panels.append((reg_pred_labels, 'Regular Prediction', label_style))

    fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (data, title, style) in zip(axes, panels):
        ax.imshow(data, **style)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def _per_class_iou(pred, target, num_classes):
    """Return the IoU of each class between two label maps of equal shape.

    A class absent from both ``pred`` and ``target`` has no defined IoU and is
    reported as NaN (not 0.0), so ``np.nanmean`` leaves it out of the mean IoU
    instead of dragging the mean down.
    """
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        union = (pred_c | target_c).sum().item()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((pred_c & target_c).sum().item() / union)
    return ious


def _regular_prediction(model, image, device, max_size=1024):
    """Run one plain forward pass over the whole image for comparison with TTA.

    Images whose longer side exceeds ``max_size`` are bilinearly downscaled
    first (to bound memory), and the logits are resized back to the input size.
    """
    image = image.to(device)
    _, _, h, w = image.shape
    with torch.no_grad():
        if h <= max_size and w <= max_size:
            return model(image)
        scale_factor = max_size / max(h, w)
        resized = F.interpolate(image, scale_factor=scale_factor, mode='bilinear', align_corners=False)
        prediction = model(resized)
        return F.interpolate(prediction, size=(h, w), mode='bilinear', align_corners=False)


def _plot_metric_comparison(ax, reg_values, tta_values, metric_name):
    """Scatter per-image Regular vs TTA values on ``ax`` with dashed mean lines."""
    reg_mean = np.mean(reg_values)
    tta_mean = np.mean(tta_values)
    ax.scatter(range(len(reg_values)), reg_values, alpha=0.7, label='Regular', color='blue')
    ax.scatter(range(len(tta_values)), tta_values, alpha=0.7, label='TTA', color='red')
    ax.plot([reg_mean] * len(reg_values), '--', color='blue', alpha=0.7, label=f'Avg Regular ({reg_mean:.3f})')
    ax.plot([tta_mean] * len(tta_values), '--', color='red', alpha=0.7, label=f'Avg TTA ({tta_mean:.3f})')
    ax.set_xlabel('Image Index')
    ax.set_ylabel(metric_name)
    ax.set_title(f'{metric_name} Comparison Across {len(tta_values)} Images')
    ax.legend()
    ax.grid(True, alpha=0.3)


def run_tta_evaluation(use_holdout=False, max_images=None, visualize=True):
    """
    Run TTA evaluation on full-resolution test images (paper-compliant approach)

    For each image listed by the dataset's ``get_available_images()``:
    - load it at full resolution (centre-cropped to 1500×1500 when a side
      exceeds 2000 px);
    - predict it with sliding-window TTA and with a single regular pass;
    - score both against the label.

    Per-image and aggregated accuracy, class-wise accuracy and mIoU are printed.
    mIoU averages only the classes that occur in the prediction or the ground
    truth of an image; classes absent from both are skipped rather than counted
    as IoU 0.

    Relies on notebook globals: ``trainer``, ``model``, ``device``,
    ``NUM_CLASSES``, ``CLASS_NAMES`` and ``test_loader`` / ``holdout_loader``.

    Args:
        use_holdout: If True, use holdout_loader instead of test_loader
        max_images: Maximum number of images to process (None for all)
        visualize: Whether to show visualizations for each image
    """
    # Inside a function ``locals()`` only holds the function's own names, so the
    # notebook-level loaders must be looked up in ``globals()``. A loader bound
    # to None is treated as unavailable too.
    if use_holdout:
        selected_loader = globals().get('holdout_loader')
        dataset_name = "holdout"
        missing_msg = "Holdout loader not available. Please run the dataset loading cell first."
    else:
        selected_loader = globals().get('test_loader')
        dataset_name = "test"
        missing_msg = "Test loader not available. Please run the dataset loading cell first."
    if selected_loader is None:
        print(missing_msg)
        return

    print(f"Running TTA Evaluation with Full-Resolution Images ({dataset_name.upper()} SET)")
    print("=" * 70)

    # Load best model
    if os.path.exists('best_model.pth'):
        print("Loading best trained model for TTA...")
        trainer.load_model('best_model.pth')

    model.eval()

    test_dataset = selected_loader.dataset

    # Show available images
    available_images = test_dataset.get_available_images()
    print(f"\nAvailable {dataset_name} images: {len(available_images)}")

    # Limit number of images if specified
    if max_images is not None and max_images < len(available_images):
        available_images = available_images[:max_images]
        print(f"Processing first {max_images} images for demonstration")

    # Initialize aggregated metrics
    all_tta_accuracies = []
    all_reg_accuracies = []
    all_tta_ious = []
    all_reg_ious = []
    all_class_tta_accuracies = [[] for _ in range(NUM_CLASSES)]
    all_class_reg_accuracies = [[] for _ in range(NUM_CLASSES)]

    print(f"\nProcessing {len(available_images)} images...")

    for img_idx, img_name in enumerate(available_images):
        print(f"\n{'='*50}")
        print(f"Processing Image {img_idx + 1}/{len(available_images)}: {img_name}")
        print(f"{'='*50}")

        # Load full-resolution image and label
        try:
            full_image, full_label = test_dataset.get_full_image(img_idx)
            print(f"Loaded full image shape: {full_image.shape}")
            print(f"Loaded full label shape: {full_label.shape}")

            _, _, height, width = full_image.shape
            min_size_for_tta = 448 + 100  # patch_size + stride

            if height < min_size_for_tta or width < min_size_for_tta:
                print(f"Image size ({height}×{width}) is smaller than minimum for TTA ({min_size_for_tta}×{min_size_for_tta})")
                print("TTA will still work but with reduced patch overlap.")

            # Very large images are centre-cropped to keep the demonstration fast
            if height > 2000 or width > 2000:
                print(f"Large image detected ({height}×{width}). Cropping to 1500×1500 for faster processing.")
                crop_h = min(1500, height)
                crop_w = min(1500, width)
                start_h = (height - crop_h) // 2
                start_w = (width - crop_w) // 2

                full_image = full_image[:, :, start_h:start_h+crop_h, start_w:start_w+crop_w]
                full_label = full_label[start_h:start_h+crop_h, start_w:start_w+crop_w]

                print(f"Cropped to: {full_image.shape}")

        except Exception as e:
            # Notebook-level boundary: report and move on to the next image
            print(f"Error loading full image {img_name}: {e}")
            continue

        print(f"\nApplying Test Time Augmentation...")
        print(f"Using 448×448 patches with 100-pixel stride (paper specification)")

        try:
            tta_prediction = test_time_augmentation_sliding_window(
                model=model,
                image=full_image,
                patch_size=448,
                stride=100,
                device=device
            )

            print(f"\nRunning regular prediction for comparison...")
            regular_prediction = _regular_prediction(model, full_image, device)

            if visualize:
                print(f"\nVisualizing results for {img_name}...")
                visualize_tta_results(
                    original_image=full_image,
                    ground_truth=full_label,
                    tta_prediction=tta_prediction,
                    regular_prediction=regular_prediction
                )

            # Calculate metrics for this image
            tta_pred_labels = torch.argmax(tta_prediction, dim=1)[0].cpu()
            reg_pred_labels = torch.argmax(regular_prediction, dim=1)[0].cpu()

            tta_acc = (tta_pred_labels == full_label).float().mean().item()
            reg_acc = (reg_pred_labels == full_label).float().mean().item()
            all_tta_accuracies.append(tta_acc)
            all_reg_accuracies.append(reg_acc)

            print(f"\nResults for {img_name}:")
            print(f"Image size: {full_label.shape}")
            print(f"Regular Prediction Accuracy: {reg_acc:.4f}")
            print(f"TTA Prediction Accuracy: {tta_acc:.4f}")
            print(f"TTA Improvement: {tta_acc - reg_acc:+.4f}")

            # Class-wise accuracy (recall) for classes present in this image
            print(f"\nClass-wise Accuracy for {img_name}:")
            for class_id in range(NUM_CLASSES):
                mask = (full_label == class_id)
                if mask.sum() > 0:
                    tta_class_acc = (tta_pred_labels[mask] == class_id).float().mean().item()
                    reg_class_acc = (reg_pred_labels[mask] == class_id).float().mean().item()
                    improvement = tta_class_acc - reg_class_acc

                    all_class_tta_accuracies[class_id].append(tta_class_acc)
                    all_class_reg_accuracies[class_id].append(reg_class_acc)

                    print(f"  {CLASS_NAMES[class_id]:<20}: Regular {reg_class_acc:.3f} | TTA {tta_class_acc:.3f} | Δ{improvement:+.3f}")

            # Mean IoU over the classes that actually occur in this image
            tta_miou = np.nanmean(_per_class_iou(tta_pred_labels, full_label, NUM_CLASSES))
            reg_miou = np.nanmean(_per_class_iou(reg_pred_labels, full_label, NUM_CLASSES))
            all_tta_ious.append(tta_miou)
            all_reg_ious.append(reg_miou)

            print(f"\nMean IoU for {img_name}:")
            print(f"Regular Prediction mIoU: {reg_miou:.4f}")
            print(f"TTA Prediction mIoU: {tta_miou:.4f}")
            print(f"TTA mIoU Improvement: {tta_miou - reg_miou:+.4f}")

        except Exception as e:
            print(f"Error processing image {img_name}: {e}")
            continue

    if not all_tta_accuracies:
        print("No images were successfully processed.")
        return

    # Aggregate results across all images
    print(f"\n{'='*70}")
    print(f"AGGREGATED RESULTS ACROSS ALL {len(all_tta_accuracies)} IMAGES ({dataset_name.upper()} SET)")
    print(f"{'='*70}")

    avg_tta_acc = np.mean(all_tta_accuracies)
    avg_reg_acc = np.mean(all_reg_accuracies)
    avg_tta_miou = np.mean(all_tta_ious)
    avg_reg_miou = np.mean(all_reg_ious)

    print(f"\nOverall Performance:")
    print(f"Average Regular Accuracy: {avg_reg_acc:.4f} (±{np.std(all_reg_accuracies):.4f})")
    print(f"Average TTA Accuracy: {avg_tta_acc:.4f} (±{np.std(all_tta_accuracies):.4f})")
    print(f"Average Accuracy Improvement: {avg_tta_acc - avg_reg_acc:+.4f}")

    print(f"\nAverage Regular mIoU: {avg_reg_miou:.4f} (±{np.std(all_reg_ious):.4f})")
    print(f"Average TTA mIoU: {avg_tta_miou:.4f} (±{np.std(all_tta_ious):.4f})")
    print(f"Average mIoU Improvement: {avg_tta_miou - avg_reg_miou:+.4f}")

    print(f"\nAggregated Class-wise Performance:")
    for class_id in range(NUM_CLASSES):
        if all_class_tta_accuracies[class_id]:  # Only if this class appeared in the images
            avg_tta_class = np.mean(all_class_tta_accuracies[class_id])
            avg_reg_class = np.mean(all_class_reg_accuracies[class_id])
            improvement = avg_tta_class - avg_reg_class

            print(f"  {CLASS_NAMES[class_id]:<20}: Regular {avg_reg_class:.3f} | TTA {avg_tta_class:.3f} | Δ{improvement:+.3f}")

    # Performance distribution plot
    if len(all_tta_accuracies) > 1:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        _plot_metric_comparison(ax1, all_reg_accuracies, all_tta_accuracies, 'Accuracy')
        _plot_metric_comparison(ax2, all_reg_ious, all_tta_ious, 'mIoU')
        plt.tight_layout()
        plt.show()

In [None]:
# Run TTA Evaluation
# This will demonstrate the paper's TTA approach on test samples

# if 'test_loader' in locals():
#     print("TTA Evaluation Options:")
#     print("1. Process all test images with visualization")
#     print("2. Process limited test images (faster)")
#     print("3. Use holdout dataset instead of test dataset")
#     print()

#     print("Running TTA on first 3 test images...")
#     run_tta_evaluation(use_holdout=False, max_images=3, visualize=True)
# else:
#     print("Test loader not available.")
#     print("Please run the dataset loading cells first to enable TTA evaluation.")

# The dataset cell always binds ``holdout_loader`` (create_dataloaders returns
# four values), so a name check alone is not enough. Also treat a None value as
# "not available"; presumably that is what create_dataloaders returns when
# dataset != 'both' — confirm in dataset_loader.
if globals().get('holdout_loader') is not None:
    print("\n" + "="*80)
    print("Running TTA on holdout dataset...")
    run_tta_evaluation(use_holdout=True, visualize=True)
else:
    print("\nHoldout dataset not available (this is normal if dataset='potsdam' or 'vaihingen')")
    print("Holdout dataset is only available when dataset='both'")

## 8. Model Saving and Loading

In [None]:
# Save the trained model
import os

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

model_path = 'models/ddcm_net_trained85v2tta.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    # trainer.history is the trainer's running record (section 5.5 plots it as the
    # updated history). The ``history`` variable only exists if section 4 ran, so
    # relying on it raised NameError when training was skipped or resumed.
    'training_history': trainer.history,
    'class_names': CLASS_NAMES,
    # Constructor arguments needed to rebuild the network before loading weights
    # (the same values passed to create_model in section 2)
    'model_config': {
        'variant': 'base',
        'num_classes': NUM_CLASSES,
        'backbone': BACKBONE,
        'pretrained': PRETRAINED,
    },
}, model_path)

print(f"Model saved to: {model_path}")

### Load a Model

In [None]:
# Reload a checkpoint written by the "Model Saving and Loading" cell above.
# Uncomment to use. NOTE: weights_only=False unpickles arbitrary objects, so only
# load checkpoints from a trusted source.
#
# def load_trained_model(model_path):
#     """Load a trained DDCM-Net model; returns (model, training_history)."""
#     checkpoint = torch.load(model_path, map_location=device, weights_only=False)
#
#     # Rebuild the network with create_model (imported at the top of the notebook).
#     # Checkpoints saved without a 'model_config' entry fall back to this
#     # notebook's settings.
#     config = checkpoint.get('model_config', {
#         'variant': 'base',
#         'num_classes': NUM_CLASSES,
#         'backbone': BACKBONE,
#     })
#     loaded_model = create_model(
#         variant=config['variant'],
#         num_classes=config['num_classes'],
#         backbone=config.get('backbone', BACKBONE),
#         pretrained=False,  # all weights are restored from the checkpoint below
#     )
#
#     # Load weights
#     loaded_model.load_state_dict(checkpoint['model_state_dict'])
#     loaded_model.to(device)
#     loaded_model.eval()
#
#     print(f"Model loaded from: {model_path}")
#     print(f"Configuration: {config}")
#
#     return loaded_model, checkpoint['training_history']
#
# # Relative path, matching where the save cell writes the checkpoint
# model_path = 'models/ddcm_net_trained85v2tta.pth'
# model, history = load_trained_model(model_path)