# DDCM-Net: Dense Dilated Convolutions Merging Network

## 1. Import Required Libraries

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from model import create_model, create_trainer
from dataset_loader import create_dataloaders

# Plot appearance: seaborn-flavoured matplotlib style with the "husl" palette
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Seed both the torch and numpy RNGs so runs are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Select the compute device: first CUDA GPU when available, otherwise CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Model Architecture Setup

In [None]:
# Model configuration
NUM_CLASSES = 6
BACKBONE = 'resnet50'
PRETRAINED = True

# Class names for visualization; list index == label id
CLASS_NAMES = [
    'Impervious surfaces',  # 0 - White
    'Building',             # 1 - Blue
    'Low vegetation',       # 2 - Cyan
    'Tree',                 # 3 - Green
    'Car',                  # 4 - Yellow
    'Clutter/background'    # 5 - Red
]

# Every later cell indexes CLASS_NAMES by class id, so the two must agree.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"CLASS_NAMES has {len(CLASS_NAMES)} entries but NUM_CLASSES is {NUM_CLASSES}"
    )

# Create model
model = create_model(
    variant='base',
    num_classes=NUM_CLASSES,
    backbone=BACKBONE,
    pretrained=PRETRAINED
)

# Create trainer wrapper (used below for fit / load_model / plotting)
trainer = create_trainer(
    model=model,
    device=device,
    class_names=CLASS_NAMES
)

# Print model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Architecture: DDCM-Net with {BACKBONE} backbone")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Load model from a previous training session

In [None]:
# NOTE(review): the other cells call `trainer.load_model(...)` without using its
# return value. Assigning the result to `trainer` here would replace the trainer
# with whatever load_model returns. Confirm it returns the trainer before
# uncommenting, or drop the assignment.
# trainer = trainer.load_model("best_model.pth")

## 3. Dataset Loading and Preprocessing

In [None]:
# Dataset configuration
DATA_ROOT = "./data"
DATASET = "potsdam"  # or "vaihingen" or "both"
BATCH_SIZE = 5
NUM_WORKERS = 4
PATCH_SIZE = 256

print("Loading dataset...")

if os.path.exists(DATA_ROOT):
    # create_dataloaders returns four loaders: train / val / test / holdout
    train_loader, val_loader, test_loader, holdout_loader = create_dataloaders(
        root_dir=DATA_ROOT,
        dataset=DATASET,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    print(f"Dataset: {DATASET.upper()}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
    print(f"Train batches: {len(train_loader):,}")
    print(f"Validation batches: {len(val_loader):,}")
    print(f"Test batches: {len(test_loader):,}")

    # Optional sanity check on one batch:
    # sample_images, sample_labels = next(iter(train_loader))
    # print(f"Sample batch shape - Images: {sample_images.shape}, Labels: {sample_labels.shape}")
    # print(f"Classes in sample: {torch.unique(sample_labels).tolist()}")
else:
    # Explain the directory layout this notebook expects
    expected_layout = (
        "data/",
        "├── potsdam/",
        "│   ├── images/",
        "│   └── labels/",
        "└── vaihingen/",
        "    ├── images/",
        "    └── labels/",
    )
    print(f"Data directory not found: {DATA_ROOT}")
    print("Please ensure you have the ISPRS dataset in the data folder")
    print("Expected structure:")
    for layout_line in expected_layout:
        print(layout_line)

### Dataset Visualization

In [None]:
# Visualize dataset samples
def visualize_dataset_samples(dataloader, num_samples=4, class_names=None):
    """Show images and ground-truth masks from the first batch of a loader.

    The top row shows de-normalized images and the bottom row the label maps,
    one column per sample. Per-class pixel statistics for the whole batch are
    printed afterwards.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_samples: Number of columns to draw. Columns beyond the batch size
            are left empty.
        class_names: Names indexed by class id. Defaults to ``CLASS_NAMES``.
    """
    names = CLASS_NAMES if class_names is None else class_names

    # squeeze=False keeps `axes` 2-D. Without it, num_samples=1 returns a 1-D
    # array and `axes[row, col]` below raises.
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 10), squeeze=False)

    # Only the first batch is visualized
    images, labels = next(iter(dataloader))

    # De-normalization constants (ImageNet mean/std), built once instead of per sample.
    # NOTE(review): assumes dataset_loader normalizes with these values — confirm.
    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
    max_class_id = len(names) - 1

    for i in range(min(num_samples, len(images))):
        img_denorm = torch.clamp(images[i] * std + mean, 0, 1)

        # Plot image (CHW -> HWC for imshow)
        axes[0, i].imshow(img_denorm.permute(1, 2, 0))
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        # Plot label with a fixed colour range so a class id maps to the same colour in every panel
        axes[1, i].imshow(labels[i], cmap='tab10', vmin=0, vmax=max_class_id)
        axes[1, i].set_title(f'Ground Truth {i+1}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    # Print class statistics over the whole batch
    unique_classes = torch.unique(labels)
    print(f"Classes in this batch: {unique_classes.tolist()}")
    total_pixels = labels.numel()
    for class_id in unique_classes.tolist():
        # Ids without a name (e.g. an ignore label) would otherwise raise IndexError
        class_name = names[class_id] if 0 <= class_id < len(names) else 'unknown'
        pixel_count = (labels == class_id).sum().item()
        percentage = pixel_count / total_pixels * 100
        print(f"  {class_id}: {class_name} - {pixel_count:,} pixels ({percentage:.1f}%)")

# Only draw samples when the dataset cell produced a training loader
if 'train_loader' not in locals():
    print("Dataset not loaded. Please ensure data is available.")
else:
    print("Dataset samples:")
    visualize_dataset_samples(train_loader, num_samples=4)

## 4. Model Training

In [5]:
# Training configuration
EPOCHS = 80
LEARNING_RATE = 8.5e-5 / np.sqrt(2)
WEIGHT_DECAY = 2e-5
USE_DUAL_LR = True

print("Starting model training...")
print("Configuration:")
for config_label, config_value in (
    ("Epochs", EPOCHS),
    ("Learning rate", f"{LEARNING_RATE:.2e}"),
    ("Weight decay", WEIGHT_DECAY),
    ("Device", device),
):
    print(f"{config_label}: {config_value}")

# Train only when both loaders exist (the dataset cell may have found no data)
if 'train_loader' not in locals() or 'val_loader' not in locals():
    print("Cannot start training - dataset not available")
else:
    print("\nTraining started...")

    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        use_mfb=True,  # presumably median-frequency balancing (the log prints class weights)
        lr_scheduler='step',  # Will be overridden by use_dual_lr
        use_dual_lr=USE_DUAL_LR,
    )

Computing class frequencies: 100%|██████████| 1000/1000 [05:16<00:00,  3.16it/s]
Computing class frequencies: 100%|██████████| 1000/1000 [05:16<00:00,  3.16it/s]


Class frequencies: [0.3154688  0.24464083 0.23769112 0.14688781 0.01443058 0.04088106]
Median frequency: 0.146888
Class weights: [ 0.46561757  0.60042226  0.6179777   0.9999999  10.178922    3.5930521 ]
Using dual LR scheduling: per-iteration polynomial + per-epoch StepLR
Training on cuda
Model parameters: 9,992,628
Dual LR mode: per-iteration polynomial (lr=6.01e-05) + per-epoch StepLR (γ=0.85, step=15)
Using pixel-weighted averaging (matches paper implementation)

Epoch 1/80


                                                                                                

KeyboardInterrupt: 

## 5. Training Visualization and Metrics

In [None]:
# Plot training history (only when this session trained the model)
if not ('history' in locals() and trainer.history['train_loss']):
    print("No training history available.")
    print("Please run training first or load a pre-trained model.")
else:
    hist = trainer.history
    print("Training History Visualization")

    # Trainer's built-in curves
    trainer.plot_training_history(figsize=(18, 6))

    # Last-epoch value of every tracked metric
    final_train_loss, final_val_loss = hist['train_loss'][-1], hist['val_loss'][-1]
    final_train_acc, final_val_acc = hist['train_acc'][-1], hist['val_acc'][-1]
    final_train_miou, final_val_miou = hist['train_miou'][-1], hist['val_miou'][-1]

    print("\nFinal Training Metrics:")
    print(f"  Train Loss: {final_train_loss:.4f} | Val Loss: {final_val_loss:.4f}")
    print(f"  Train Acc:  {final_train_acc:.3f} | Val Acc:  {final_val_acc:.3f}")
    print(f"  Train mIoU: {final_train_miou:.3f} | Val mIoU: {final_val_miou:.3f}")

    # Best validation epoch, reported 1-based
    val_mious = hist['val_miou']
    best_epoch = np.argmax(val_mious) + 1
    best_miou = max(val_mious)
    print(f"\nBest Validation mIoU: {best_miou:.3f} (Epoch {best_epoch})")

## 5.5. Continue Training from Checkpoint

In [None]:
# Configuration for continued training
CONTINUE_TRAINING = True   # Set to False to skip this cell
ADDITIONAL_EPOCHS = 20
LAST_COMPLETED_EPOCH = 1   # Passed to continue_training as current_epoch (sets the resumed LR)
LEARNING_RATE = 1e-6
WEIGHT_DECAY = 2e-5

if not CONTINUE_TRAINING:
    print("Continued training is disabled.")
elif not ('trainer' in locals() and 'train_loader' in locals() and 'val_loader' in locals()):
    print("Trainer or dataloaders not found!")
    print("Please ensure 'trainer', 'train_loader', and 'val_loader' are available.")
elif not os.path.exists('best_model.pth'):
    print("No saved model found at 'best_model.pth'!")
    print("Please train the model first or check the model path.")
else:
    # Load the saved model
    print("Loading best model for continued training...")
    trainer.load_model('best_model.pth')

    # Show current training state
    total_epochs_trained = len(trainer.history['train_loss'])
    print(f"Model has been trained for {total_epochs_trained} epochs")

    if trainer.history['val_miou']:
        current_best = max(trainer.history['val_miou'])
        print(f"Current best validation mIoU: {current_best:.3f}")

    # LAST_COMPLETED_EPOCH is set by hand. Warn when it disagrees with the loaded
    # history so a stale value doesn't silently resume the LR at the wrong epoch.
    if LAST_COMPLETED_EPOCH != total_epochs_trained:
        print(f"Warning: LAST_COMPLETED_EPOCH={LAST_COMPLETED_EPOCH}, but the loaded "
              f"history contains {total_epochs_trained} epochs.")

    print(f"\nContinuing training for {ADDITIONAL_EPOCHS} more epochs...")
    updated_history = trainer.continue_training(
        train_loader=train_loader,
        val_loader=val_loader,
        additional_epochs=ADDITIONAL_EPOCHS,
        current_epoch=LAST_COMPLETED_EPOCH,  # Use this to set the LR correctly
        initial_lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    # Plot updated training history
    print("\nUpdated Training History:")
    trainer.plot_training_history(figsize=(18, 6))

## 6. Model Prediction and Visualization

In [None]:
# Load the best checkpoint for inference; otherwise keep the in-memory weights
if not os.path.exists('best_model.pth'):
    print("No saved model found. Using current model state.")
else:
    print("Loading best trained model...")
    trainer.load_model('best_model.pth')
    print("Model loaded successfully!")

# Ground truth vs prediction panels on a few test samples
if 'test_loader' not in locals():
    print("Test data not available for prediction visualization")
else:
    print("\nModel Predictions Visualization")
    print("Comparing ground truth vs predictions on test samples...")
    trainer.visualize_predictions(test_loader, num_samples=4, figsize=(20, 8))

### Detailed Prediction Analysis

In [None]:
# Detailed evaluation on test set
import math


def batch_confusion_matrix(predictions, targets, num_classes):
    """Return the (num_classes, num_classes) confusion matrix of one batch.

    Rows are ground-truth classes and columns are predicted classes. Pixels
    whose target or prediction lies outside ``[0, num_classes)`` (e.g. an
    ignore label) are skipped.

    Args:
        predictions: Integer tensor of predicted class ids (any shape).
        targets: Integer tensor of ground-truth class ids with the same
            number of elements as ``predictions``.
        num_classes: Number of classes.

    Returns:
        An int64 tensor of shape ``(num_classes, num_classes)`` on the
        inputs' device.
    """
    pred_flat = predictions.reshape(-1).long()
    target_flat = targets.reshape(-1).long()
    valid = ((target_flat >= 0) & (target_flat < num_classes)
             & (pred_flat >= 0) & (pred_flat < num_classes))
    # Encode each (target, prediction) pair as one flat index so a single
    # bincount fills the whole matrix, with no per-class loop or host syncs.
    flat_index = target_flat[valid] * num_classes + pred_flat[valid]
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def iou_from_confusion(confusion):
    """Return per-class IoU values from a confusion matrix.

    IoU = TP / (TP + FP + FN) = TP / (row_sum + col_sum - TP), with rows as
    ground truth and columns as predictions. A class absent from both the
    ground truth and the predictions has an undefined IoU. It is reported as
    ``float('nan')`` so it can be left out of the mean instead of counting as 0.
    """
    conf = confusion.double()
    tp = torch.diag(conf)
    denominators = conf.sum(dim=1) + conf.sum(dim=0) - tp
    return [t / d if d > 0 else float('nan')
            for t, d in zip(tp.tolist(), denominators.tolist())]


if 'test_loader' in locals():
    print("Detailed Model Evaluation")

    # Evaluate model on test set
    model.eval()
    test_loss = 0.0

    num_classes = NUM_CLASSES
    # int64 counts: float32 stops representing integers exactly above 2**24,
    # which per-class pixel totals over a full test set easily exceed.
    confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)
    # Unweighted per-pixel cross-entropy, summed so it can be averaged over all pixels
    criterion = nn.CrossEntropyLoss(reduction='sum')
    total_pixels = 0

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            predictions = torch.argmax(outputs, dim=1)
            confusion_matrix += batch_confusion_matrix(predictions, targets, num_classes)

            total_pixels += targets.numel()

    # Average loss per pixel (guarding against an empty loader)
    test_loss /= max(total_pixels, 1)

    # Overall pixel accuracy
    overall_acc = torch.diag(confusion_matrix).sum() / confusion_matrix.sum().clamp_min(1)

    print("\nTest Set Results:")
    print(f"  Overall Accuracy: {overall_acc.item():.3f}")
    print(f"  Average Loss: {test_loss:.4f}")

    # Class-wise IoU ("n/a" for classes absent from both ground truth and predictions)
    class_ious = iou_from_confusion(confusion_matrix)
    print("\nClass-wise IoU:")
    for class_id, iou in enumerate(class_ious):
        iou_text = "n/a" if math.isnan(iou) else f"{iou:.3f}"
        print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: {iou_text}")

    # Mean over classes whose IoU is defined
    defined_ious = [iou for iou in class_ious if not math.isnan(iou)]
    mean_iou = sum(defined_ious) / len(defined_ious) if defined_ious else 0.0
    print(f"\nMean IoU: {mean_iou:.3f}")

    # Row-normalized confusion matrix. clamp_min avoids 0/0 = NaN rows for
    # classes with no ground-truth pixels.
    plt.figure(figsize=(12, 10))
    row_totals = confusion_matrix.sum(dim=1, keepdim=True).clamp_min(1)
    confusion_norm = confusion_matrix / row_totals
    confusion_np = confusion_norm.cpu().numpy()

    sns.heatmap(confusion_np, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel('Predicted')
    plt.ylabel('Ground Truth')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()
    plt.show()

else:
    print("Test data not available for detailed evaluation")

# Test Time Augmentation

In [None]:
from tta_utils import TTAPredictor, TTAEvaluator

## Model Saving and Loading

In [None]:
# Save the trained model
import os

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# `history` is only bound when trainer.fit() returns normally. After an
# interrupted run, fall back to the history the trainer keeps itself
# (trainer.history is what the plotting cells read).
training_history = history if 'history' in locals() else trainer.history

# Save model state dict
model_path = 'models/ddcm_net_trained85v2tta.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'training_history': training_history,
    'class_names': CLASS_NAMES,
    # Architecture settings needed to rebuild the model before loading the
    # weights (the commented-out loader below reads 'num_classes' and 'variant').
    'model_config': {
        'variant': 'base',
        'num_classes': NUM_CLASSES,
        'backbone': BACKBONE,
    },
}, model_path)

print(f"Model saved to: {model_path}")

### Load a Model

In [None]:
# Loader for checkpoints written by the save cell above (currently disabled).
# NOTE(review): as written, this expects checkpoint['model_config'] (with
# 'num_classes' and 'variant') and a `DDCMNet` class, which is not imported in
# this notebook. `create_model(variant=..., num_classes=..., backbone=...)` is
# the constructor used elsewhere. The example path below is absolute
# ('/models/...'), whereas the save cell writes to the relative 'models/' dir.
# torch.load(weights_only=False) unpickles arbitrary objects, so only load
# trusted checkpoints. Running the last line rebinds the notebook's `model`
# and `history`.
# def load_trained_model(model_path):
#     """Load a trained DDCM-Net model"""
#     checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
#     # Recreate model with same config
#     config = checkpoint['model_config']
#     loaded_model = DDCMNet(
#         num_classes=config['num_classes'],
#         variant=config['variant']
#     )
    
#     # Load weights
#     loaded_model.load_state_dict(checkpoint['model_state_dict'])
#     loaded_model.to(device)
#     loaded_model.eval()
    
#     print(f"Model loaded from: {model_path}")
#     print(f"Configuration: {config}")
    
#     return loaded_model, checkpoint['training_history']

# model_path = '/models/ddcm_net_trained.pth'
# model, history = load_trained_model(model_path)

# Enhanced DDCM-Net with Global Context via Self-Attention

This section demonstrates the enhanced DDCM-Net model, which integrates transformer-style self-attention blocks to provide global context modeling. The attention blocks are intended to improve long-range dependency modeling, while windowed attention (`use_windowed=True`) keeps the extra computational cost manageable. The comparisons below measure whether the enhancement actually improves accuracy.

## Model Variants Comparison

In [None]:
print("=== Enhanced DDCM-Net Model Comparison ===\n")

# Create different model variants for comparison
print("1. Creating model variants...")

# Original DDCM-Net
model_base = create_model(variant='base', num_classes=NUM_CLASSES, backbone=BACKBONE)
base_params = sum(p.numel() for p in model_base.parameters())
print(f"   Base DDCM-Net: {base_params:,} parameters")

# Enhanced DDCM-Net with default global context
model_enhanced = create_model(variant='enhanced', num_classes=NUM_CLASSES, backbone=BACKBONE)
enhanced_params = sum(p.numel() for p in model_enhanced.parameters())
print(f"   Enhanced DDCM-Net: {enhanced_params:,} parameters (+{enhanced_params - base_params:,})")

# Enhanced DDCM-Net with custom configuration
custom_config = {
    'num_heads': 8,           # Number of attention heads
    'num_layers': 2,          # Number of transformer layers
    'use_windowed': True,     # Use windowed attention for efficiency
    'window_size': 7,         # Window size for windowed attention
    'dropout': 0.1,           # Dropout rate
    'pos_embed': True         # Use positional embeddings
}

model_custom = create_model(
    variant='enhanced',
    num_classes=NUM_CLASSES,
    backbone=BACKBONE,
    global_context_config=custom_config
)
custom_params = sum(p.numel() for p in model_custom.parameters())
print(f"   Custom Enhanced DDCM-Net: {custom_params:,} parameters")

# Test forward pass compatibility
print("\n2. Testing forward pass compatibility...")
sample_input = torch.randn(2, 3, 512, 512)
print(f"   Input shape: {sample_input.shape}")

model_base.eval()
model_enhanced.eval()

with torch.no_grad():
    output_base = model_base(sample_input)
    output_enhanced = model_enhanced(sample_input)

print(f"   Base model output: {output_base.shape}")
print(f"   Enhanced model output: {output_enhanced.shape}")
print(f"   Output shapes match: {output_base.shape == output_enhanced.shape}")

# Show parameter breakdown
print("\n3. Parameter Analysis...")
parameter_overhead = ((enhanced_params - base_params) / base_params * 100)
print(f"   Parameter overhead: {parameter_overhead:.1f}%")
print(f"   Additional parameters: {enhanced_params - base_params:,}")

# Show different configuration options
print("\n4. Available Configuration Examples...")
configs = [
    {
        'name': 'Lightweight (for limited compute)',
        'config': {'num_heads': 4, 'num_layers': 1, 'window_size': 7}
    },
    {
        'name': 'Standard (recommended)',
        'config': {'num_heads': 8, 'num_layers': 2, 'window_size': 7}
    },
    {
        'name': 'High-capacity (for best accuracy)',
        'config': {'num_heads': 8, 'num_layers': 3, 'window_size': 14}
    }
]

for config_info in configs:
    name = config_info['name']
    config = config_info['config']

    # Use the same backbone as `model_base`. Otherwise the overhead below would
    # compare against whatever backbone create_model defaults to.
    test_model = create_model(
        variant='enhanced',
        num_classes=NUM_CLASSES,
        backbone=BACKBONE,
        global_context_config=config
    )
    params = sum(p.numel() for p in test_model.parameters())
    overhead = ((params - base_params) / base_params * 100)

    print(f"   {name}: {params:,} params (+{overhead:.1f}%)")

print("\n=== Model variants created successfully! ===")
print("The enhanced model is backward compatible and uses the same training interface.")

## Enhanced Model Training

Now let's train the enhanced model with the same configuration as the base model to compare performance:

In [None]:
# Enhanced Model Training Configuration
TRAIN_ENHANCED = True  # Set to False to skip enhanced-model training
ENHANCED_EPOCHS = 80
ENHANCED_LR = 8.5e-5 / np.sqrt(2)
ENHANCED_WEIGHT_DECAY = 2e-5
USE_ENHANCED_DUAL_LR = True

BASE_CHECKPOINT = 'best_model.pth'
ENHANCED_CHECKPOINT = 'best_enhanced_model.pth'


def checkpoint_signature(path):
    """Return ``(mtime_ns, size)`` of *path*, or ``None`` if it does not exist.

    Used to detect whether a training run (re)wrote a checkpoint file.
    """
    try:
        stat_result = os.stat(path)
    except FileNotFoundError:
        return None
    return (stat_result.st_mtime_ns, stat_result.st_size)


if TRAIN_ENHANCED:
    if 'train_loader' in locals() and 'val_loader' in locals():
        import shutil

        print("=== Training Enhanced DDCM-Net ===")

        # Create enhanced model with standard configuration
        enhanced_model = create_model(
            variant='enhanced',
            num_classes=NUM_CLASSES,
            backbone=BACKBONE,
            pretrained=PRETRAINED,
            global_context_config={
                'num_heads': 8,
                'num_layers': 2,
                'use_windowed': True,
                'window_size': 7,
                'dropout': 0.1,
                'pos_embed': True
            }
        )

        # Create enhanced trainer
        enhanced_trainer = create_trainer(
            model=enhanced_model,
            device=device,
            class_names=CLASS_NAMES
        )

        # Print model comparison
        enhanced_params = sum(p.numel() for p in enhanced_model.parameters())
        base_params = sum(p.numel() for p in model.parameters())
        print(f"Enhanced model parameters: {enhanced_params:,}")
        print(f"Base model parameters: {base_params:,}")
        print(f"Parameter increase: +{enhanced_params - base_params:,} (+{((enhanced_params - base_params) / base_params * 100):.1f}%)")

        print(f"\nStarting enhanced model training...")
        print(f"Configuration:")
        print(f"  Epochs: {ENHANCED_EPOCHS}")
        print(f"  Learning rate: {ENHANCED_LR:.2e}")
        print(f"  Weight decay: {ENHANCED_WEIGHT_DECAY}")
        print(f"  Device: {device}")

        # NOTE(review): fit() takes no checkpoint path, yet 'best_model.pth'
        # exists after the base training cell (the continue-training and
        # inference cells load it). So fit() presumably writes its best weights
        # there, and this run would overwrite the base checkpoint that the TTA
        # cell later loads as the 'base' variant. Back the base checkpoint up,
        # and keep anything this run writes there as the enhanced checkpoint.
        base_backup = BASE_CHECKPOINT + '.base_backup'
        base_signature = checkpoint_signature(BASE_CHECKPOINT)
        if base_signature is not None:
            shutil.copy2(BASE_CHECKPOINT, base_backup)  # copy2 keeps mtime, so the signature survives restore

        fit_wrote_checkpoint = False
        try:
            # Train the enhanced model
            enhanced_history = enhanced_trainer.fit(
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=ENHANCED_EPOCHS,
                lr=ENHANCED_LR,
                weight_decay=ENHANCED_WEIGHT_DECAY,
                use_mfb=True,
                lr_scheduler='step',  # Will be overridden by use_dual_lr
                use_dual_lr=USE_ENHANCED_DUAL_LR
            )
        finally:
            # Runs on success and on interruption (e.g. KeyboardInterrupt)
            after_signature = checkpoint_signature(BASE_CHECKPOINT)
            if after_signature is not None and after_signature != base_signature:
                os.replace(BASE_CHECKPOINT, ENHANCED_CHECKPOINT)
                fit_wrote_checkpoint = True
            if base_signature is not None:
                os.replace(base_backup, BASE_CHECKPOINT)

        # If fit() did not write a checkpoint, save the trainer's current weights instead
        if not fit_wrote_checkpoint:
            enhanced_trainer.save_model(ENHANCED_CHECKPOINT)
        print("Enhanced model saved as 'best_enhanced_model.pth'")

    else:
        print("Cannot start enhanced training - dataset not available")
else:
    print("Enhanced model training is disabled.")

### Enhanced Model Training Visualization

In [None]:
# Plot enhanced model training history and compare with base model
if 'enhanced_history' in locals() and enhanced_trainer.history['train_loss']:
    enhanced_hist = enhanced_trainer.history
    print("Enhanced Model Training History Visualization")

    # Plot enhanced training history
    enhanced_trainer.plot_training_history(figsize=(18, 6))

    # Print final enhanced metrics
    final_enhanced_train_loss = enhanced_hist['train_loss'][-1]
    final_enhanced_val_loss = enhanced_hist['val_loss'][-1]
    final_enhanced_train_acc = enhanced_hist['train_acc'][-1]
    final_enhanced_val_acc = enhanced_hist['val_acc'][-1]
    final_enhanced_train_miou = enhanced_hist['train_miou'][-1]
    final_enhanced_val_miou = enhanced_hist['val_miou'][-1]

    print("\nFinal Enhanced Model Metrics:")
    print(f"  Train Loss: {final_enhanced_train_loss:.4f} | Val Loss: {final_enhanced_val_loss:.4f}")
    print(f"  Train Acc:  {final_enhanced_train_acc:.3f} | Val Acc:  {final_enhanced_val_acc:.3f}")
    print(f"  Train mIoU: {final_enhanced_train_miou:.3f} | Val mIoU: {final_enhanced_val_miou:.3f}")

    # Find best epoch for enhanced model (1-based)
    best_enhanced_epoch = np.argmax(enhanced_hist['val_miou']) + 1
    best_enhanced_miou = max(enhanced_hist['val_miou'])
    print(f"\nBest Enhanced Validation mIoU: {best_enhanced_miou:.3f} (Epoch {best_enhanced_epoch})")

    # Compare with base model if available
    if 'trainer' in locals() and trainer.history['val_miou']:
        base_hist = trainer.history
        base_best_miou = max(base_hist['val_miou'])
        improvement = best_enhanced_miou - base_best_miou
        print("\nModel Comparison:")
        print(f"  Base Model Best mIoU: {base_best_miou:.3f}")
        print(f"  Enhanced Model Best mIoU: {best_enhanced_miou:.3f}")
        # The relative change is undefined when the base mIoU is 0
        if base_best_miou > 0:
            print(f"  Improvement: {improvement:+.3f} ({(improvement/base_best_miou*100):+.1f}%)")
        else:
            print(f"  Improvement: {improvement:+.3f}")

        # Side-by-side training curves. Both val_miou lists are non-empty here,
        # since the max() calls above would have raised otherwise.
        fig, axes = plt.subplots(1, 3, figsize=(21, 6))
        runs = (
            (base_hist, range(1, len(base_hist['train_loss']) + 1), 'b', 'Base'),
            (enhanced_hist, range(1, len(enhanced_hist['train_loss']) + 1), 'r', 'Enhanced'),
        )
        panels = (
            ('loss', 'Loss Comparison', 'Loss'),
            ('acc', 'Accuracy Comparison', 'Accuracy'),
            ('miou', 'mIoU Comparison', 'mIoU'),
        )
        for ax, (metric, title, ylabel) in zip(axes, panels):
            for run_hist, run_epochs, color, run_label in runs:
                ax.plot(run_epochs, run_hist[f'train_{metric}'], f'{color}-',
                        label=f'{run_label} Train', linewidth=2)
                ax.plot(run_epochs, run_hist[f'val_{metric}'], f'{color}--',
                        label=f'{run_label} Val', linewidth=2)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

else:
    print("No enhanced training history available.")
    print("Please run enhanced training first.")

## Test Time Augmentation (TTA) Comparison

Now let's compare both the base and enhanced models using Test Time Augmentation to see the performance differences:

In [None]:
# Compare the available DDCM-Net models with Test Time Augmentation (TTA).
#
# Every model found is collected into `models_to_compare` and passed to
# `TTAEvaluator.compare_models_with_tta`. The returned `comparison_results`
# is read by the analysis cell that follows.
#
# Model sources, in the order they are added:
#   * 'best_model.pth'          -> base variant (only if `trainer` exists)
#   * 'best_enhanced_model.pth' -> enhanced variant (only if `enhanced_trainer` exists)
#   * otherwise the in-memory `enhanced_model`, provided it is an nn.Module
from tta_utils import TTAEvaluator, load_model_for_tta

print("=== Test Time Augmentation Model Comparison ===\n")

# TTA Configuration
RUN_TTA_COMPARISON = True  # Set to False to skip the TTA comparison
MAX_TTA_IMAGES = 3         # Limit images for faster demonstration
TTA_PATCH_SIZE = 448       # Square patch edge length, in pixels
TTA_STRIDE = 100           # Stride in pixels, forwarded to compare_models_with_tta
TTA_VISUALIZE = True       # Show visualizations for each image

if RUN_TTA_COMPARISON:
    # At notebook top level locals() is the global namespace, so these
    # membership checks ask whether earlier cells defined the named objects.
    if 'test_loader' in locals():
        print("Preparing models for TTA comparison...")
        
        # Create TTA evaluator
        tta_evaluator = TTAEvaluator(class_names=CLASS_NAMES)
        
        # Model name -> model; the insertion order (base first) is relied on
        # by the analysis cell when it pairs up "base" and "enhanced" results.
        models_to_compare = {}
        
        # Add base model from its best checkpoint
        if 'trainer' in locals() and os.path.exists('best_model.pth'):
            print("Loading base model...")
            base_model_for_tta = load_model_for_tta(
                'best_model.pth', 
                variant='base', 
                num_classes=NUM_CLASSES, 
                backbone=BACKBONE,
                device=device
            )
            models_to_compare['Base DDCM-Net'] = base_model_for_tta
        
        # Add enhanced model: prefer the saved checkpoint, else the in-memory model
        if 'enhanced_trainer' in locals() and os.path.exists('best_enhanced_model.pth'):
            print("Loading enhanced model...")
            enhanced_model_for_tta = load_model_for_tta(
                'best_enhanced_model.pth',
                variant='enhanced',
                num_classes=NUM_CLASSES,
                backbone=BACKBONE,
                device=device
            )
            models_to_compare['Enhanced DDCM-Net'] = enhanced_model_for_tta
        elif 'enhanced_model' in locals():
            # The name may have been rebound to something other than a model
            # (e.g. a results dict from another cell); never hand that to TTA.
            if isinstance(enhanced_model, nn.Module):
                print("Using current enhanced model...")
                models_to_compare['Enhanced DDCM-Net'] = enhanced_model
            else:
                print("Skipping in-memory `enhanced_model`: it is not an nn.Module.")
        
        if len(models_to_compare) > 0:
            print(f"Models to compare: {list(models_to_compare.keys())}")
            print("TTA Configuration:")
            print(f"  Patch size: {TTA_PATCH_SIZE}×{TTA_PATCH_SIZE}")
            print(f"  Stride: {TTA_STRIDE} pixels")
            print(f"  Max images: {MAX_TTA_IMAGES}")
            print(f"  Visualizations: {'Enabled' if TTA_VISUALIZE else 'Disabled'}")
            
            # Run TTA comparison; results feed the analysis cell below
            comparison_results = tta_evaluator.compare_models_with_tta(
                models_dict=models_to_compare,
                test_loader=test_loader,
                max_images=MAX_TTA_IMAGES,
                visualize=TTA_VISUALIZE,
                patch_size=TTA_PATCH_SIZE,
                stride=TTA_STRIDE
            )
            
            print("\n=== TTA Comparison Complete ===")
            
        else:
            print("No trained models available for TTA comparison.")
            print("Please train at least one model first.")
            
    else:
        print("Test loader not available.")
        print("Please run the dataset loading cells first to enable TTA evaluation.")
else:
    print("TTA comparison is disabled.")
    print("Set RUN_TTA_COMPARISON = True to enable TTA model comparison.")

### TTA Results Visualization and Analysis

In [None]:
# Analyze and visualize TTA comparison results.
#
# Reads `comparison_results` (produced by the TTA comparison cell). It is a
# dict mapping model name -> metrics dict with the keys 'all_reg_accuracies',
# 'all_tta_accuracies', 'all_reg_ious', 'all_tta_ious' (per-image lists) and
# 'avg_reg_acc', 'avg_tta_acc', 'avg_reg_miou', 'avg_tta_miou' (averages).
# With at least two models it draws a 2x2 figure (per-image scatter plots plus
# average bar charts) and prints a textual comparison.
#
# NOTE: the per-model summaries are deliberately NOT named `base_model` /
# `enhanced_model`. Binding those names at notebook top level would overwrite
# the notebook's model objects; the TTA cell falls back on `enhanced_model`.
if 'comparison_results' in locals() and comparison_results:
    print("=== TTA Results Analysis ===\n")
    
    # Extract model names (dict insertion order)
    model_names = list(comparison_results.keys())
    
    if len(model_names) >= 2:
        # Create comprehensive comparison visualization
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        
        # Flatten each model's results into one record for plotting/printing
        models_data = []
        for model_name in model_names:
            results = comparison_results[model_name]
            models_data.append({
                'name': model_name,
                'reg_accs': results['all_reg_accuracies'],
                'tta_accs': results['all_tta_accuracies'],
                'reg_ious': results['all_reg_ious'],
                'tta_ious': results['all_tta_ious'],
                'avg_reg_acc': results['avg_reg_acc'],
                'avg_tta_acc': results['avg_tta_acc'],
                'avg_reg_miou': results['avg_reg_miou'],
                'avg_tta_miou': results['avg_tta_miou']
            })
        
        # Plot 1: Accuracy comparison per image.
        # Each model is shifted by 0.1 on x so overlapping points stay visible.
        for i, model_data in enumerate(models_data):
            x_positions = range(len(model_data['reg_accs']))
            axes[0, 0].scatter([x + i*0.1 for x in x_positions], model_data['reg_accs'], 
                              alpha=0.7, label=f"{model_data['name']} Regular", marker='o')
            axes[0, 0].scatter([x + i*0.1 for x in x_positions], model_data['tta_accs'], 
                              alpha=0.7, label=f"{model_data['name']} TTA", marker='^')
        
        axes[0, 0].set_xlabel('Image Index')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].set_title('Accuracy Comparison Across Images')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Plot 2: mIoU comparison per image (same x offset scheme)
        for i, model_data in enumerate(models_data):
            x_positions = range(len(model_data['reg_ious']))
            axes[0, 1].scatter([x + i*0.1 for x in x_positions], model_data['reg_ious'], 
                              alpha=0.7, label=f"{model_data['name']} Regular", marker='o')
            axes[0, 1].scatter([x + i*0.1 for x in x_positions], model_data['tta_ious'], 
                              alpha=0.7, label=f"{model_data['name']} TTA", marker='^')
        
        axes[0, 1].set_xlabel('Image Index')
        axes[0, 1].set_ylabel('mIoU')
        axes[0, 1].set_title('mIoU Comparison Across Images')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # Plot 3: Average accuracy comparison (grouped bar chart)
        x_pos = np.arange(len(model_names))
        width = 0.35
        
        reg_accs = [model_data['avg_reg_acc'] for model_data in models_data]
        tta_accs = [model_data['avg_tta_acc'] for model_data in models_data]
        
        bars1 = axes[1, 0].bar([x - width/2 for x in x_pos], reg_accs, width, label='Regular', alpha=0.8)
        bars2 = axes[1, 0].bar([x + width/2 for x in x_pos], tta_accs, width, label='TTA', alpha=0.8)
        
        axes[1, 0].set_xlabel('Model')
        axes[1, 0].set_ylabel('Average Accuracy')
        axes[1, 0].set_title('Average Accuracy Comparison')
        axes[1, 0].set_xticks(x_pos)
        axes[1, 0].set_xticklabels(model_names)
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Add value labels on bars
        for bar in bars1 + bars2:
            height = bar.get_height()
            axes[1, 0].annotate(f'{height:.3f}',
                               xy=(bar.get_x() + bar.get_width() / 2, height),
                               xytext=(0, 3),  # 3 points vertical offset
                               textcoords="offset points",
                               ha='center', va='bottom', fontsize=9)
        
        # Plot 4: Average mIoU comparison (grouped bar chart)
        reg_ious = [model_data['avg_reg_miou'] for model_data in models_data]
        tta_ious = [model_data['avg_tta_miou'] for model_data in models_data]
        
        bars3 = axes[1, 1].bar([x - width/2 for x in x_pos], reg_ious, width, label='Regular', alpha=0.8)
        bars4 = axes[1, 1].bar([x + width/2 for x in x_pos], tta_ious, width, label='TTA', alpha=0.8)
        
        axes[1, 1].set_xlabel('Model')
        axes[1, 1].set_ylabel('Average mIoU')
        axes[1, 1].set_title('Average mIoU Comparison')
        axes[1, 1].set_xticks(x_pos)
        axes[1, 1].set_xticklabels(model_names)
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        
        # Add value labels on bars
        for bar in bars3 + bars4:
            height = bar.get_height()
            axes[1, 1].annotate(f'{height:.3f}',
                               xy=(bar.get_x() + bar.get_width() / 2, height),
                               xytext=(0, 3),  # 3 points vertical offset
                               textcoords="offset points",
                               ha='center', va='bottom', fontsize=9)
        
        plt.tight_layout()
        plt.show()
        
        # Print detailed comparison
        print("\n" + "="*60)
        print("DETAILED COMPARISON ANALYSIS")
        print("="*60)
        
        for model_data in models_data:
            # Use the explicit sign format so a TTA regression prints as
            # "-0.0123" rather than "+-0.0123".
            acc_gain = model_data['avg_tta_acc'] - model_data['avg_reg_acc']
            miou_gain = model_data['avg_tta_miou'] - model_data['avg_reg_miou']
            print(f"\n{model_data['name']}:")
            print(f"  Regular -> TTA Accuracy: {model_data['avg_reg_acc']:.4f} -> {model_data['avg_tta_acc']:.4f} "
                  f"({acc_gain:+.4f})")
            print(f"  Regular -> TTA mIoU: {model_data['avg_reg_miou']:.4f} -> {model_data['avg_tta_miou']:.4f} "
                  f"({miou_gain:+.4f})")
        
        # Compare models against each other
        if len(models_data) == 2:
            # Assumes entry 0 is the base model and entry 1 the enhanced one,
            # i.e. that compare_models_with_tta preserves the insertion order
            # of `models_to_compare` (base added first) — TODO confirm.
            base_data, enhanced_data = models_data
            
            print("\n" + "="*40)
            print("MODEL vs MODEL COMPARISON")
            print("="*40)
            
            # Regular prediction comparison (positive = enhanced is better)
            reg_acc_diff = enhanced_data['avg_reg_acc'] - base_data['avg_reg_acc']
            reg_miou_diff = enhanced_data['avg_reg_miou'] - base_data['avg_reg_miou']
            
            print("Regular Prediction Comparison:")
            print(f"  Accuracy: {enhanced_data['name']} vs {base_data['name']} = {reg_acc_diff:+.4f}")
            print(f"  mIoU: {enhanced_data['name']} vs {base_data['name']} = {reg_miou_diff:+.4f}")
            
            # TTA prediction comparison (positive = enhanced is better)
            tta_acc_diff = enhanced_data['avg_tta_acc'] - base_data['avg_tta_acc']
            tta_miou_diff = enhanced_data['avg_tta_miou'] - base_data['avg_tta_miou']
            
            print("\nTTA Prediction Comparison:")
            print(f"  Accuracy: {enhanced_data['name']} vs {base_data['name']} = {tta_acc_diff:+.4f}")
            print(f"  mIoU: {enhanced_data['name']} vs {base_data['name']} = {tta_miou_diff:+.4f}")
            
            # Overall assessment, based on the TTA deltas only
            print("\nOverall Assessment:")
            if tta_acc_diff > 0 and tta_miou_diff > 0:
                print("✓ Enhanced model shows improvement in both accuracy and mIoU with TTA")
            elif tta_acc_diff > 0 or tta_miou_diff > 0:
                print("~ Enhanced model shows mixed results compared to base model")
            else:
                print("- Enhanced model does not show clear improvement with TTA")
    
    else:
        print(f"Only {len(model_names)} model(s) available for comparison.")
        print("Need at least 2 models for comprehensive comparison.")

else:
    print("No TTA comparison results available.")
    print("Please run the TTA comparison first.")