# DDCM-Net: Dense Dilated Convolutions Merging Network

## 1. Import Required Libraries

In [None]:
# Standard libraries
import os
import random
import warnings
warnings.filterwarnings('ignore')

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm

# PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Custom modules
from model import create_model, create_trainer
from dataset_loader import create_dataloaders

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Set random seeds for reproducibility.
# Python's built-in `random` module is seeded alongside NumPy and torch so that
# any code drawing from it (e.g. in the custom modules) is reproducible too.
# NOTE(review): confirm which RNGs dataset_loader / model actually use.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Model Architecture Setup

In [None]:
# Model configuration
NUM_CLASSES = 6
BACKBONE = 'resnet50'
PRETRAINED = True

# Class names for visualization (list index == class id)
CLASS_NAMES = [
    'Impervious surfaces',  # 0 - White
    'Building',             # 1 - Blue
    'Low vegetation',       # 2 - Cyan
    'Tree',                 # 3 - Green
    'Car',                  # 4 - Yellow
    'Clutter/background'    # 5 - Red
]

# Create model
model = create_model(
    variant='base',
    num_classes=NUM_CLASSES,
    backbone=BACKBONE,
    pretrained=PRETRAINED
)

# Create trainer wrapper
trainer = create_trainer(
    model=model,
    device=device,
    class_names=CLASS_NAMES
)

# Print model information (parameter counts were previously computed but never shown)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Architecture: DDCM-Net with {BACKBONE} backbone")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Load model from a previous training session

In [None]:
# Restore weights from a previous training session if the checkpoint exists.
# `load_model` is called for its effect on `trainer` (the same way the inference
# cell uses it) instead of rebinding `trainer` to its return value, which is not
# visible here and may not be a trainer. NOTE(review): confirm load_model's return.
if os.path.exists("best_model.pth"):
    trainer.load_model("best_model.pth")
    print("Loaded weights from best_model.pth")
else:
    print("No checkpoint found at best_model.pth; keeping the freshly created model.")

## 3. Dataset Loading and Preprocessing

In [None]:
# Dataset configuration: these constants are passed straight to create_dataloaders
DATA_ROOT = "./data"
DATASET = "potsdam"  # or "vaihingen" or "both"
BATCH_SIZE = 5
NUM_WORKERS = 4
PATCH_SIZE = 256

print("Loading dataset...")

# Build the loaders only when the data directory is present; otherwise describe
# the directory layout that is expected under DATA_ROOT.
if os.path.exists(DATA_ROOT):
    # holdout_loader is returned by create_dataloaders but not used by the cells shown here.
    train_loader, val_loader, test_loader, holdout_loader = create_dataloaders(
        root_dir=DATA_ROOT,
        dataset=DATASET,
        patch_size=PATCH_SIZE,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )

    print(f"Dataset: {DATASET.upper()}")
    print(f"Batch size: {BATCH_SIZE} (paper specification)")
    print(f"Patch size: {PATCH_SIZE}x{PATCH_SIZE} (paper specification)")
    print("Random sampling: 5000 train patches, 1000 val/test patches (paper-compliant)")
    print(f"Train batches: {len(train_loader):,}")
    print(f"Validation batches: {len(val_loader):,}")
    print(f"Test batches: {len(test_loader):,}")
else:
    print("\n".join([
        f"Data directory not found: {DATA_ROOT}",
        "Please ensure you have the ISPRS dataset in the data folder",
        "Expected structure:",
        "data/",
        "├── potsdam/",
        "│   ├── images/",
        "│   └── labels/",
        "└── vaihingen/",
        "    ├── images/",
        "    └── labels/",
    ]))

### Dataset Visualization

In [None]:
# Visualize dataset samples
def visualize_dataset_samples(dataloader, num_samples=4, *, class_names=None,
                              mean=(0.485, 0.456, 0.406),
                              std=(0.229, 0.224, 0.225)):
    """Plot images and ground-truth masks from the first batch of a loader.

    Draws a 2 x ``num_samples`` grid (images on top, label maps below), then
    prints per-class pixel counts over the whole first batch.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches. Assumed to be
            normalized images of shape (B, 3, H, W) and integer label maps of
            shape (B, H, W) -- TODO confirm against dataset_loader.
        num_samples: number of columns to draw; columns beyond the batch size
            are left blank.
        class_names: names indexed by class id. Defaults to the notebook-level
            ``CLASS_NAMES``.
        mean, std: per-channel normalization constants to undo before display
            (defaults are the values this cell previously hard-coded).
    """
    if class_names is None:
        class_names = CLASS_NAMES
    num_classes = len(class_names)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1, so
    # axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 10), squeeze=False)

    # Only the first batch is visualized.
    images, labels = next(iter(dataloader))

    # (C, 1, 1) tensors broadcast over H and W; built once instead of per sample.
    mean_t = torch.tensor(mean, device=images.device).view(-1, 1, 1)
    std_t = torch.tensor(std, device=images.device).view(-1, 1, 1)

    shown = min(num_samples, len(images))
    im = None
    for i in range(shown):
        # Undo normalization and clamp to the [0, 1] range imshow expects.
        img = torch.clamp(images[i] * std_t + mean_t, 0, 1).cpu()
        axes[0, i].imshow(img.permute(1, 2, 0))
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        # Fixed vmin/vmax so a class gets the same colour in every panel.
        im = axes[1, i].imshow(labels[i].cpu(), cmap='tab10', vmin=0, vmax=num_classes - 1)
        axes[1, i].set_title(f'Ground Truth {i+1}')
        axes[1, i].axis('off')

    # Hide columns left empty because the batch was smaller than num_samples.
    for i in range(shown, num_samples):
        axes[0, i].axis('off')
        axes[1, i].axis('off')

    # Add colorbar for labels (only if at least one label map was drawn).
    if im is not None:
        fig.colorbar(im, ax=axes[1, :], fraction=0.046, pad=0.04,
                     ticks=range(num_classes), label='Land Cover Classes')

    plt.tight_layout()
    plt.show()

    # Print class statistics over the whole batch, not just the plotted samples.
    unique_classes = torch.unique(labels).tolist()
    print(f"Classes in this batch: {unique_classes}")
    total_pixels = labels.numel()
    for class_id in unique_classes:
        pixel_count = (labels == class_id).sum().item()
        percentage = pixel_count / total_pixels * 100
        # Ids outside class_names (e.g. an ignore label) would otherwise raise IndexError.
        class_name = class_names[class_id] if 0 <= class_id < num_classes else 'Unknown'
        print(f"  {class_id}: {class_name} - {pixel_count:,} pixels ({percentage:.1f}%)")

# Only draw samples once the dataset cell has produced a training loader.
if 'train_loader' not in locals():
    print("Dataset not loaded. Please ensure data is available.")
else:
    print("Dataset samples:")
    visualize_dataset_samples(train_loader, num_samples=4)

## 4. Model Training

In [None]:
# Training hyperparameters
EPOCHS = 50
LEARNING_RATE = 6.01e-5  # 8.5e-5/√2
WEIGHT_DECAY = 2e-5

print("Starting model training...")
print("Configuration:")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Weight decay: {WEIGHT_DECAY}")
print(f"Device: {device}")

# Training needs both loaders from the dataset cell; bail out with a hint otherwise.
if not ('train_loader' in locals() and 'val_loader' in locals()):
    print("Cannot start training - dataset not available")
    print("Please ensure the data is properly loaded before training")
else:
    print("\nTraining started...")

    # use_mfb=True is passed through to trainer.fit; presumably enables
    # median-frequency class balancing -- confirm in model.py.
    history = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        use_mfb=True,
    )

    print("\nTraining completed!")

## 5. Training Visualization and Metrics

In [None]:
# Plot training history.
# The metrics read below all come from `trainer.history`, so gate on that record
# itself rather than on this session's `history` variable; otherwise a trainer
# that already holds metrics (NOTE(review): whether load_model restores them is
# not visible here) would be reported as having no history.
training_log = getattr(trainer, 'history', None) or {}

if training_log.get('train_loss'):
    print("Training History Visualization")

    # Plot using trainer's built-in visualization
    trainer.plot_training_history(figsize=(18, 6))

    # Print final (last-epoch) metrics
    final_train_loss = training_log['train_loss'][-1]
    final_val_loss = training_log['val_loss'][-1]
    final_train_acc = training_log['train_acc'][-1]
    final_val_acc = training_log['val_acc'][-1]
    final_train_miou = training_log['train_miou'][-1]
    final_val_miou = training_log['val_miou'][-1]

    print("\nFinal Training Metrics:")
    print(f"  Train Loss: {final_train_loss:.4f} | Val Loss: {final_val_loss:.4f}")
    print(f"  Train Acc:  {final_train_acc:.3f} | Val Acc:  {final_val_acc:.3f}")
    print(f"  Train mIoU: {final_train_miou:.3f} | Val mIoU: {final_val_miou:.3f}")

    # Find best epoch (1-based) with a single argmax lookup.
    val_mious = training_log['val_miou']
    best_index = int(np.argmax(val_mious))
    best_epoch = best_index + 1
    best_miou = val_mious[best_index]
    print(f"\nBest Validation mIoU: {best_miou:.3f} (Epoch {best_epoch})")

else:
    print("No training history available.")
    print("Please run training first or load a pre-trained model.")

## 6. Model Prediction and Visualization

In [None]:
# Use the best checkpoint for inference when one exists; otherwise keep the
# weights currently held by the trainer.
if not os.path.exists('best_model.pth'):
    print("No saved model found. Using current model state.")
else:
    print("Loading best trained model...")
    trainer.load_model('best_model.pth')
    print("Model loaded successfully!")

# Side-by-side ground truth vs prediction panels for a few test samples.
if 'test_loader' not in locals():
    print("Test data not available for prediction visualization")
else:
    print("\nModel Predictions Visualization")
    print("Comparing ground truth vs predictions on test samples...")
    trainer.visualize_predictions(test_loader, num_samples=4, figsize=(20, 8))

### Detailed Prediction Analysis

In [None]:
# Detailed evaluation on test set


def update_confusion_matrix(conf_mat, predictions, targets):
    """Accumulate one batch into a square confusion matrix, in place.

    Rows index ground-truth classes and columns index predicted classes.
    Target values outside ``[0, num_classes)`` are skipped, so pixels carrying
    an out-of-range (e.g. ignore) label do not land in any cell.

    Args:
        conf_mat: (num_classes, num_classes) tensor, updated in place.
        predictions: tensor of predicted class ids (any shape).
        targets: tensor of ground-truth class ids with the same number of elements.

    Returns:
        ``conf_mat``, for convenience.
    """
    num_classes = conf_mat.shape[0]
    pred_flat = predictions.reshape(-1).long()
    target_flat = targets.reshape(-1).long()
    valid = (target_flat >= 0) & (target_flat < num_classes)
    # Encode each (target, prediction) pair as one flat index so a single
    # bincount fills the whole matrix -- no per-class loop or host syncs.
    pair_index = target_flat[valid] * num_classes + pred_flat[valid]
    counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
    conf_mat += counts.reshape(num_classes, num_classes).to(conf_mat)
    return conf_mat


def per_class_iou(conf_mat):
    """Return per-class IoU (list of floats) from a confusion matrix.

    IoU = TP / (TP + FP + FN) = TP / (row_sum + col_sum - TP), with rows as
    ground truth. A class that appears in neither ground truth nor predictions
    has an undefined IoU and is returned as NaN (not 0.0) so callers can
    exclude it from the mean.
    """
    conf = conf_mat.double()
    tp = torch.diag(conf)
    # row sum = TP + FN (actual class), column sum = TP + FP (predicted class)
    denominator = conf.sum(dim=1) + conf.sum(dim=0) - tp
    nan = torch.full_like(tp, float('nan'))
    ious = torch.where(denominator > 0, tp / denominator.clamp(min=1), nan)
    return ious.tolist()


if 'test_loader' in locals():
    print("Detailed Model Evaluation")

    # Evaluate model on test set
    model.eval()
    test_loss = 0.0

    # Confusion matrix: ground truth on rows, predictions on columns.
    # int64 instead of float32: float32 holds integers exactly only up to 2**24
    # (~16.7M), and per-cell pixel counts over a test set can exceed that.
    num_classes = NUM_CLASSES
    confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)
    criterion = nn.CrossEntropyLoss(reduction='sum')
    total_pixels = 0

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            # reduction='sum' so dividing by total_pixels below gives a per-pixel mean.
            test_loss += criterion(outputs, targets).item()

            predictions = torch.argmax(outputs, dim=1)
            update_confusion_matrix(confusion_matrix, predictions, targets)
            total_pixels += targets.numel()

    # Per-pixel average loss; max(..., 1) avoids ZeroDivisionError on an empty loader.
    test_loss /= max(total_pixels, 1)

    # Overall accuracy = correctly classified pixels / all counted pixels.
    overall_acc = torch.diag(confusion_matrix).sum() / confusion_matrix.sum().clamp(min=1)

    print("\nTest Set Results:")
    print(f"  Overall Accuracy: {overall_acc.item():.3f}")
    print(f"  Average Loss: {test_loss:.4f}")

    # Calculate class-wise IoU
    class_ious = per_class_iou(confusion_matrix)
    print("\nClass-wise IoU:")
    for class_id, iou in enumerate(class_ious):
        if np.isnan(iou):
            print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: N/A (class absent)")
        else:
            print(f"  {class_id}: {CLASS_NAMES[class_id]:<20} IoU: {iou:.3f}")

    # Mean IoU over classes that actually occur; absent classes have no defined
    # IoU and would otherwise drag the mean towards zero.
    present_ious = [iou for iou in class_ious if not np.isnan(iou)]
    mean_iou = sum(present_ious) / len(present_ious) if present_ious else float('nan')
    print(f"\nMean IoU: {mean_iou:.3f}")

    # Plot the row-normalized confusion matrix; rows of classes absent from the
    # ground truth stay at 0 instead of becoming NaN (0/0).
    plt.figure(figsize=(12, 10))
    row_totals = confusion_matrix.sum(dim=1, keepdim=True).clamp(min=1)
    confusion_norm = confusion_matrix.double() / row_totals
    confusion_np = confusion_norm.cpu().numpy()

    sns.heatmap(confusion_np, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel('Predicted')
    plt.ylabel('Ground Truth')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()
    plt.show()

else:
    print("Test data not available for detailed evaluation")

## 7. Model Saving and Loading

In [None]:
# Save the trained model together with the metadata needed to rebuild it.
# (`os` is already imported in the first cell.)

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# `history` exists only if the training cell ran in this session; fall back to
# the trainer's own record so this cell does not raise NameError after a reload.
try:
    training_history = history
except NameError:
    training_history = getattr(trainer, 'history', None)

# Store the model constructor arguments next to the weights (read back as
# checkpoint['model_config'] by a loader). 'variant' must match the value
# passed to create_model in the model-setup cell.
model_path = 'models/ddcm_net_trained82.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'num_classes': NUM_CLASSES,
        'variant': 'base',
        'backbone': BACKBONE,
        'pretrained': PRETRAINED,
    },
    'training_history': training_history,
    'class_names': CLASS_NAMES,
}, model_path)

print(f"Model saved to: {model_path}")

### Load a Model

In [None]:
# def load_trained_model(model_path):
#     """Load a trained DDCM-Net model"""
#     checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
#     # Recreate model with same config
#     config = checkpoint['model_config']
#     loaded_model = DDCMNet(
#         num_classes=config['num_classes'],
#         variant=config['variant']
#     )
    
#     # Load weights
#     loaded_model.load_state_dict(checkpoint['model_state_dict'])
#     loaded_model.to(device)
#     loaded_model.eval()
    
#     print(f"Model loaded from: {model_path}")
#     print(f"Configuration: {config}")
    
#     return loaded_model, checkpoint['training_history']

# model_path = 'models/ddcm_net_trained82.pth'  # relative path written by the save cell
# model, history = load_trained_model(model_path)