# DDCM-Net: Dense Dilated Convolutions Merging Network

## 1. Import Required Libraries

In [None]:
# Standard libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Custom modules
from model import create_model, create_trainer
from dataset_loader import create_dataloaders

# Plot appearance: matplotlib style sheet first, seaborn palette second
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Fix the NumPy and PyTorch RNG seeds so repeated runs start from the same state
np.random.seed(42)
torch.manual_seed(42)

# Select the compute device: CUDA when it is usable, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"CUDA available: {device.type == 'cuda'}")
if device.type == 'cuda':
    # Report the name of the first visible GPU
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Model Architecture Setup

In [None]:
# Model configuration
NUM_CLASSES = 6
BACKBONE = 'resnet50'
PRETRAINED = True

# Class names for visualization (list index == class id)
CLASS_NAMES = [
    'Impervious surfaces',  # 0 - White
    'Building',             # 1 - Blue
    'Low vegetation',       # 2 - Cyan
    'Tree',                 # 3 - Green
    'Car',                  # 4 - Yellow
    'Clutter/background'    # 5 - Red
]

# Create model through the project factory
model = create_model(
    variant='base',
    num_classes=NUM_CLASSES,
    backbone=BACKBONE,
    pretrained=PRETRAINED
)

# Create trainer wrapper around the model
trainer = create_trainer(
    model=model,
    device=device,
    class_names=CLASS_NAMES
)

# Print model information
# Parameter counts: every tensor vs. only those that will receive gradients
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Architecture: DDCM-Net with {BACKBONE} backbone")
# The counts were previously computed but never shown
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 3. Dataset Loading and Preprocessing

In [None]:
# Dataset configuration
DATA_ROOT = "./data"
DATASET = "potsdam"  # or "vaihingen" or "both"
BATCH_SIZE = 8
NUM_WORKERS = 4
PATCH_SIZE = 256
STRIDE = 128

print("Loading dataset...")

if os.path.exists(DATA_ROOT):
    # Build the train / validation / test loaders from the settings above
    train_loader, val_loader, test_loader = create_dataloaders(
        root_dir=DATA_ROOT,
        dataset=DATASET,
        patch_size=PATCH_SIZE,
        stride=STRIDE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    # Summary of what was loaded
    print(
        f"Dataset: {DATASET.upper()}",
        f"Batch size: {BATCH_SIZE}",
        f"Train batches: {len(train_loader):,}",
        f"Validation batches: {len(val_loader):,}",
        f"Test batches: {len(test_loader):,}",
        f"Patch size: {PATCH_SIZE}x{PATCH_SIZE}",
        sep="\n",
    )

    # Pull one training batch to report its tensor shapes and label ids
    sample_images, sample_labels = next(iter(train_loader))
    print(f"Sample batch shape - Images: {sample_images.shape}, Labels: {sample_labels.shape}")
    print(f"Classes in sample: {torch.unique(sample_labels).tolist()}")
else:
    # Data folder is missing: explain the expected layout instead of failing.
    # The loaders stay undefined, which later cells detect before using them.
    print("\n".join((
        f"⚠️  Data directory not found: {DATA_ROOT}",
        "Please ensure you have the ISPRS dataset in the data folder",
        "Expected structure:",
        "data/",
        "├── potsdam/",
        "│   ├── images/",
        "│   └── labels/",
        "└── vaihingen/",
        "    ├── images/",
        "    └── labels/",
    )))

### Dataset Visualization

In [None]:
# Visualize dataset samples
def visualize_dataset_samples(dataloader, num_samples=4):
    """Show images (top row) and their label maps (bottom row) from one batch.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Images are
            assumed to be CPU tensors of shape (B, 3, H, W) normalised with the
            mean/std constants below, and labels integer maps of shape
            (B, H, W) -- TODO confirm against ``dataset_loader``.
        num_samples: Maximum number of samples to draw; capped at the batch size.

    Returns:
        None. Draws a figure and prints per-class pixel statistics computed
        over the whole batch (not only the samples that were drawn).
    """
    # Get a batch
    images, labels = next(iter(dataloader))

    num_shown = min(num_samples, len(images))
    if num_shown < 1:
        # Nothing to draw; also avoids dividing by zero in the statistics below
        print("No samples available to visualize.")
        return

    num_classes = len(CLASS_NAMES)

    # squeeze=False keeps `axes` two-dimensional even for a single column
    fig, axes = plt.subplots(2, num_shown, figsize=(20, 10), squeeze=False)

    # Normalisation statistics used to undo the input normalisation
    # (loop-invariant, so built once)
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

    for i in range(num_shown):
        # Denormalize image for visualization and clip to the displayable range
        img_denorm = images[i] * std[:, None, None] + mean[:, None, None]
        img_denorm = torch.clamp(img_denorm, 0, 1)

        # Plot image (channels-last for matplotlib)
        axes[0, i].imshow(img_denorm.permute(1, 2, 0))
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        # Plot label with a fixed colour range so colours match across samples
        im = axes[1, i].imshow(labels[i], cmap='tab10', vmin=0, vmax=num_classes - 1)
        axes[1, i].set_title(f'Ground Truth {i+1}')
        axes[1, i].axis('off')

    # Add one shared colorbar for the label row
    plt.colorbar(im, ax=axes[1, :], fraction=0.046, pad=0.04,
                 ticks=range(num_classes), label='Land Cover Classes')

    plt.tight_layout()
    plt.show()

    # Print class statistics
    unique_classes = torch.unique(labels)
    print(f"Classes in this batch: {unique_classes.tolist()}")
    for class_id in unique_classes.tolist():
        class_id = int(class_id)
        # Ids outside the known range (e.g. an ignore label) get a placeholder
        # name instead of raising IndexError or wrapping around for negatives
        if 0 <= class_id < num_classes:
            class_name = CLASS_NAMES[class_id]
        else:
            class_name = 'Unknown'
        pixel_count = (labels == class_id).sum().item()
        percentage = pixel_count / labels.numel() * 100
        print(f"  {class_id}: {class_name} - {pixel_count:,} pixels ({percentage:.1f}%)")

# Draw the samples only when the dataset cell managed to create train_loader
if 'train_loader' not in locals():
    print("Dataset not loaded. Please ensure data is available.")
else:
    print("Dataset samples:")
    visualize_dataset_samples(train_loader, num_samples=4)

## 4. Model Training

In [None]:
# Training hyperparameters
EPOCHS = 15
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# Echo the configuration before starting
print(
    "Starting model training...",
    "Configuration:",
    f"Epochs: {EPOCHS}",
    f"Learning rate: {LEARNING_RATE}",
    f"Weight decay: {WEIGHT_DECAY}",
    f"Device: {device}",
    sep="\n",
)

# Training needs both loaders; they only exist if the dataset cell succeeded
if 'train_loader' not in locals() or 'val_loader' not in locals():
    print("Cannot start training - dataset not available")
    print("Please ensure the data is properly loaded before training")
else:
    print("\nTraining started...")

    # Hand the loop over to the trainer; its return value is kept as `history`
    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    print("\nTraining completed!")

## 5. Training Visualization and Metrics

In [None]:
# Plot training history
if 'history' not in locals() or not trainer.history['train_loss']:
    # Either training never ran in this session or nothing was recorded
    print("No training history available.")
    print("Please run training first or load a pre-trained model.")
else:
    print("Training History Visualization")

    # The curves themselves are drawn by the trainer
    trainer.plot_training_history(figsize=(18, 6))

    # Last recorded value of every tracked metric
    final_train_loss, final_val_loss = (
        trainer.history[key][-1] for key in ('train_loss', 'val_loss')
    )
    final_train_acc, final_val_acc = (
        trainer.history[key][-1] for key in ('train_acc', 'val_acc')
    )
    final_train_miou, final_val_miou = (
        trainer.history[key][-1] for key in ('train_miou', 'val_miou')
    )

    print(
        "\nFinal Training Metrics:",
        f"  Train Loss: {final_train_loss:.4f} | Val Loss: {final_val_loss:.4f}",
        f"  Train Acc:  {final_train_acc:.3f} | Val Acc:  {final_val_acc:.3f}",
        f"  Train mIoU: {final_train_miou:.3f} | Val mIoU: {final_val_miou:.3f}",
        sep="\n",
    )

    # Best epoch = first epoch reaching the highest validation mIoU (1-based)
    best_miou = max(trainer.history['val_miou'])
    best_epoch = np.argmax(trainer.history['val_miou']) + 1
    print(f"\nBest Validation mIoU: {best_miou:.3f} (Epoch {best_epoch})")

## 6. Model Prediction and Visualization

In [None]:
# Prefer the saved best checkpoint for inference; otherwise keep current weights
if not os.path.exists('best_model.pth'):
    print("No saved model found. Using current model state.")
else:
    print("Loading best trained model...")
    trainer.load_model('best_model.pth')
    print("Model loaded successfully!")

# Side-by-side ground truth / prediction plots need the test loader
if 'test_loader' not in locals():
    print("Test data not available for prediction visualization")
else:
    print("\nModel Predictions Visualization")
    print("Comparing ground truth vs predictions on test samples...")

    # Plotting is delegated to the trainer
    trainer.visualize_predictions(test_loader, num_samples=4, figsize=(20, 8))

### Detailed Prediction Analysis

In [None]:
# Detailed evaluation on test set
if 'test_loader' in locals():
    print("Detailed Model Evaluation")

    # Evaluate model on test set (eval mode, no gradient tracking)
    model.eval()
    all_predictions = []
    all_targets = []
    test_losses = []
    batch_sizes = []  # used to weight per-batch losses; the last batch may be smaller

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)
            # Predicted class per pixel = index of the largest logit on dim 1
            predictions = torch.argmax(outputs, dim=1)

            # Keep results on the CPU so accumulation does not use device memory
            all_predictions.append(predictions.cpu())
            all_targets.append(targets.cpu())
            test_losses.append(loss.item())
            batch_sizes.append(images.size(0))

    if not all_predictions:
        # torch.cat on an empty list raises, so report the situation instead
        print("Test loader yielded no batches - nothing to evaluate")
    else:
        # Concatenate all results along the batch dimension
        all_predictions = torch.cat(all_predictions)
        all_targets = torch.cat(all_targets)

        # Calculate overall metrics
        overall_acc = (all_predictions == all_targets).float().mean()
        # Each entry of test_losses is a per-batch mean, so weight by batch size
        # to get the dataset-level average rather than a mean of means
        test_loss = float(np.average(test_losses, weights=batch_sizes))

        print(f"\nTest Set Results:")
        print(f"  Overall Accuracy: {overall_acc:.3f}")
        print(f"  Average Loss: {test_loss:.4f}")

        # Calculate class-wise IoU
        class_ious = []
        print(f"\nClass-wise IoU:")
        for class_id in range(NUM_CLASSES):
            pred_mask = (all_predictions == class_id)
            target_mask = (all_targets == class_id)

            if target_mask.sum() == 0:
                # Class absent from the ground truth: score 1.0 if it was also
                # never predicted, 0.0 if it was predicted anywhere
                if pred_mask.sum() == 0:
                    iou = 1.0
                else:
                    iou = 0.0
            else:
                intersection = (pred_mask & target_mask).float().sum()
                union = (pred_mask | target_mask).float().sum()
                iou = (intersection / union).item()

            class_ious.append(iou)
            print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: {iou:.3f}")

        mean_iou = np.mean(class_ious)
        print(f"\nMean IoU: {mean_iou:.3f}")

else:
    print("Test data not available for detailed evaluation")

In [None]:
def apply_color_map(mask):
    """
    Turn a class-index mask into an RGB image.

    Args:
        mask: numpy array of shape (H, W) holding class labels

    Returns:
        numpy uint8 array of shape (H, W, 3); pixels whose label is not one of
        the six known classes stay black
    """
    # One RGB row per class id
    palette = np.array(
        [
            [255, 255, 255],  # 0 - White - Impervious surfaces
            [0, 0, 255],      # 1 - Blue - Building
            [0, 255, 255],    # 2 - Cyan - Low vegetation
            [0, 255, 0],      # 3 - Green - Tree
            [255, 255, 0],    # 4 - Yellow - Car
            [255, 0, 0],      # 5 - Red - Clutter/background
        ],
        dtype=np.uint8,
    )

    # Unpacking enforces a 2-D mask
    rows, cols = mask.shape

    # (H, W, n_classes) membership flags: at most one flag is set per pixel,
    # so the product below selects that class's colour (or black if none match)
    membership = mask[:, :, None] == np.arange(len(palette))
    rgb = membership.astype(np.uint8) @ palette

    return rgb.reshape(rows, cols, 3)

# Visualize sample predictions with analysis
def visualize_predictions_detailed(model, dataloader, num_samples=6):
    """Plot image, ground truth, prediction and confidence for one batch.

    Each row shows one sample: the min-max rescaled input image, the coloured
    ground-truth mask, the coloured predicted mask (titled with the sample's
    mean confidence) and the per-pixel confidence map, where confidence is the
    largest softmax probability at that pixel. Afterwards the per-sample pixel
    accuracy and class histograms are printed.

    Args:
        model: Network called as ``model(images)``; its output is treated as
            per-class scores along dim 1.
        dataloader: Iterable yielding ``(images, targets)`` batches; targets
            are assumed to be non-negative integer maps of shape (B, H, W)
            -- TODO confirm against ``dataset_loader``.
        num_samples: Maximum number of samples to show; capped at the batch size.

    Returns:
        None.
    """
    model.eval()

    # Get batch of samples
    images, targets = next(iter(dataloader))

    # Never index past the end of the batch
    num_samples = min(num_samples, len(images))
    if num_samples < 1:
        print("No samples available for prediction visualization.")
        return

    # Take subset of samples
    images = images[:num_samples]
    targets = targets[:num_samples]

    # Generate predictions
    with torch.no_grad():
        outputs = model(images.to(device))
        probabilities = torch.softmax(outputs, dim=1)
        predictions = torch.argmax(outputs, dim=1)
        confidence = torch.max(probabilities, dim=1)[0]

    # Move to CPU for visualization
    images = images.cpu()
    targets = targets.cpu()
    predictions = predictions.cpu()
    confidence = confidence.cpu()

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples), squeeze=False)

    for i in range(num_samples):
        # Original image, channels-last and rescaled to [0, 1] for display
        img = images[i].permute(1, 2, 0).numpy()
        lowest = img.min()
        span = img.max() - lowest
        # A constant image has zero span; show it as zeros instead of NaNs
        img = (img - lowest) / span if span > 0 else np.zeros_like(img)
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f'Original Image {i+1}')
        axes[i, 0].axis('off')

        # Ground truth
        gt_colored = apply_color_map(targets[i].numpy())
        axes[i, 1].imshow(gt_colored)
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Prediction; confidence[i] is a per-pixel map, so the title reports
        # its mean (formatting the whole map with ':.3f' raises TypeError)
        pred_colored = apply_color_map(predictions[i].numpy())
        axes[i, 2].imshow(pred_colored)
        axes[i, 2].set_title(f'Prediction\nConf: {confidence[i].mean().item():.3f}')
        axes[i, 2].axis('off')

        # Confidence map
        conf_map = confidence[i].numpy()
        im = axes[i, 3].imshow(conf_map, cmap='hot', vmin=0, vmax=1)
        axes[i, 3].set_title('Confidence Map')
        axes[i, 3].axis('off')

        # Add colorbar for confidence
        plt.colorbar(im, ax=axes[i, 3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

    # Class distribution analysis
    print("\nPrediction Analysis:")
    for i in range(num_samples):
        pred_flat = predictions[i].flatten()
        target_flat = targets[i].flatten()

        # Calculate accuracy for this sample
        sample_acc = (pred_flat == target_flat).float().mean().item()
        print(f"\nSample {i+1} - Accuracy: {sample_acc:.3f}")

        # Show class distributions (pixel count per class id)
        print("  Predicted classes:", np.bincount(pred_flat.numpy(), minlength=NUM_CLASSES))
        print("  Ground truth:     ", np.bincount(target_flat.numpy(), minlength=NUM_CLASSES))

# Run detailed prediction visualization
print("Generating Detailed Prediction Analysis...")
# Prefer the test split, fall back to validation, and skip cleanly when
# neither loader exists (the bare fallback raised NameError in that case)
if 'test_loader' in locals():
    visualize_predictions_detailed(model, test_loader)
elif 'val_loader' in locals():
    visualize_predictions_detailed(model, val_loader)
else:
    print("No test or validation data available for detailed prediction analysis")

## 7. Model Saving and Loading

In [None]:
# Save the trained model
import os

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# Save model weights together with what is needed to rebuild the network
model_path = 'models/ddcm_net_trained.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        # Must mirror the arguments given to create_model() when the model
        # was built, otherwise the weights cannot be loaded back
        'variant': 'base',
        'num_classes': NUM_CLASSES,
        'backbone': BACKBONE,
        'input_channels': 3
    },
    # `history` only exists when the training cell ran; store None otherwise
    'training_history': history if 'history' in locals() else None,
    'class_names': CLASS_NAMES
}, model_path)

print(f"Model saved to: {model_path}")

# Example of loading the model later
def load_trained_model(model_path):
    """Rebuild a model from a saved checkpoint and restore its weights.

    Args:
        model_path: Path to a checkpoint holding 'model_state_dict',
            'model_config' and optionally 'training_history'.

    Returns:
        Tuple ``(loaded_model, training_history)``: the model moved to
        ``device`` in eval mode, and the stored history (None if the
        checkpoint has none).
    """
    checkpoint = torch.load(model_path, map_location=device)

    # Recreate model with same config through the factory this file imports
    # (no DDCMNet class is imported or defined here)
    config = checkpoint['model_config']
    model_kwargs = {
        'variant': config['variant'],
        'num_classes': config['num_classes'],
        # Weights come from the checkpoint, so pretrained ones are not needed
        'pretrained': False,
    }
    # Checkpoints without a recorded backbone fall back to the factory default
    if 'backbone' in config:
        model_kwargs['backbone'] = config['backbone']
    loaded_model = create_model(**model_kwargs)

    # Load weights
    loaded_model.load_state_dict(checkpoint['model_state_dict'])
    loaded_model.to(device)
    loaded_model.eval()

    print(f"Model loaded from: {model_path}")
    print(f"Configuration: {config}")

    return loaded_model, checkpoint.get('training_history')

# Demonstrate loading (commented out to avoid reloading in this session)
# loaded_model, loaded_history = load_trained_model(model_path)
# print("Model successfully reloaded!")