# Multi-Class Semantic Segmentation Training

This notebook implements semantic segmentation for peatland navigation using a U-Net architecture with a ResNet34 backbone. The model performs pixel-wise classification of peatland terrain features for safe and efficient navigation.

Target Classes:
1. Path (Primary navigation routes)
2. Natural Ground (Safe traversable terrain)
3. Tree (Obstacles and landmarks)
4. Vegetation (Secondary terrain features)
5. Ignore/Background (Areas outside annotation scope)

Key Features:
- ResNet34 encoder with ImageNet pretraining
- U-Net decoder for precise segmentation
- Albumentations pipeline: resize, random horizontal flip, ImageNet normalization
- Multi-class pixel-wise classification
- Per-epoch tracking of training loss, validation loss, and validation pixel accuracy

## 1. Import Required Libraries

In [None]:
# Standard library imports
import os
from pathlib import Path
from datetime import datetime

# Deep learning and numerical computing
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Computer vision and image processing
import segmentation_models_pytorch as smp
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Data handling and utilities
import pandas as pd
from tqdm import tqdm

## 2. Dataset Implementation

The custom `PeatlandDataset` class provides:

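1. Data Loading:
   - RGB image loading, with every image converted to 3-channel RGB
   - Single-channel mask loading, with pixel values used directly as class labels
   - Image-mask pairing by identical filename across the two directories

2. Data Processing:
   - Optional transform applied jointly to each image-mask pair
   - Albumentations integration for augmentation
   - Mask conversion to `torch.long`, as required by the loss function

3. Memory Efficiency:
   - Only the sorted list of filenames is held in memory
   - Images and masks are read from disk on demand in `__getitem__`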

The implementation ensures proper alignment between images and their corresponding segmentation masks throughout the training process.

In [None]:
class PeatlandDataset(Dataset):
    """Custom PyTorch Dataset for peatland semantic segmentation.
    
    This dataset handles the loading and processing of peatland images and their
    corresponding segmentation masks for multi-class terrain classification.
    
    Args:
        images_dir (str or Path): Directory containing input RGB images
        masks_dir (str or Path): Directory containing segmentation masks
        transform (callable, optional): Transformations to apply to image-mask pairs
        
    The dataset expects:
    1. Matching filenames between images and masks
    2. RGB images in standard formats (jpg, png)
    3. Single-channel integer masks with class labels
    """
    def __init__(self, images_dir, masks_dir, transform=None):
        # Convert paths to Path objects for consistent handling
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform
        
        # Ensure consistent ordering of files; skip hidden entries such as .DS_Store
        self.image_filenames = sorted(
            f for f in os.listdir(self.images_dir) if not f.startswith(".")
        )

    def __len__(self):
        """Returns the total number of image-mask pairs in the dataset."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Retrieves and processes an image-mask pair by index.
        
        Args:
            idx (int): Index of the image-mask pair to retrieve
            
        Returns:
            tuple: (image, mask) where image is a transformed RGB tensor and
                  mask is a long tensor with class labels
        """
        img_name = self.image_filenames[idx]
        
        # Construct full paths for data loading
        img_path = self.images_dir / img_name
        mask_path = self.masks_dir / img_name
        
        # Load and preprocess input data
        image = np.array(Image.open(img_path).convert("RGB"))  # Ensure RGB format
        mask = np.array(Image.open(mask_path))                 # Load as-is for classes
        
        # Apply transformations if specified
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            
        # Convert mask to a long tensor for loss computation; torch.as_tensor
        # also covers transform=None, where the mask is still a NumPy array
        mask = torch.as_tensor(mask).long()
            
        return image, mask
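
Because `__getitem__` looks up each mask under the same filename as its image, a missing or misnamed mask only surfaces as an error partway through an epoch. A minimal pre-flight check could look like the sketch below. `find_unpaired` is a hypothetical helper, not part of the notebook, and the demonstration uses empty placeholder files in a temporary directory.

```python
from pathlib import Path
import tempfile


def find_unpaired(images_dir, masks_dir):
    """Return sorted image filenames that have no same-named mask file."""
    image_names = {p.name for p in Path(images_dir).iterdir() if p.is_file()}
    mask_names = {p.name for p in Path(masks_dir).iterdir() if p.is_file()}
    return sorted(image_names - mask_names)


# Demonstration on a throwaway directory tree (file contents are irrelevant here)
with tempfile.TemporaryDirectory() as tmp:
    images = Path(tmp) / "images"
    masks = Path(tmp) / "masks"
    images.mkdir()
    masks.mkdir()
    for name in ("a.png", "b.png"):
        (images / name).touch()
    (masks / "a.png").touch()  # "b.png" deliberately has no mask
    missing = find_unpaired(images, masks)

print(missing)  # ['b.png']
```

Running the same function on the real image and mask directories before training should return an empty list if every image has a mask.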

## 3. Data Augmentation Pipeline

The data augmentation strategy includes:

1. Image Sizing:
   - Height: 480 pixels
   - Width: 640 pixels
   - Fixed 4:3 output shape; `A.Resize` stretches images that have a different aspect ratio

2. Training Augmentations:
   - Resize to target dimensions
   - Horizontal flipping (50% probability)
   - ImageNet normalization

3. Validation Pipeline:
   - Resize only
   - Normalization
   - No random augmentations

The transforms ensure that:
- Input sizes match model requirements
- Data distribution matches pretrained expectations
- Augmentations preserve semantic validity
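
To make the normalization step concrete, the sketch below reproduces the per-pixel arithmetic in plain Python: each 8-bit channel value is divided by 255, then the channel mean is subtracted and the result is divided by the channel standard deviation. `normalize_pixel` is a hypothetical helper for illustration only; the notebook itself relies on `A.Normalize`.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
MAX_PIXEL_VALUE = 255.0


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return [
        (value / MAX_PIXEL_VALUE - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    ]


# The two extremes of the 8-bit range bound every possible normalized value
black = normalize_pixel([0, 0, 0])
white = normalize_pixel([255, 255, 255])
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

With these constants, normalized inputs fall roughly between -2.1 and 2.6, which is the input range the ImageNet-pretrained encoder was trained on.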

In [None]:
# Define input image dimensions
IMG_HEIGHT = 480
IMG_WIDTH = 640

# ImageNet normalization parameters for pretrained models
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms with augmentation
train_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH, p=1.0),    # Consistent sizing
    A.HorizontalFlip(p=0.5),                                # Random horizontal flips
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD,       # ImageNet normalization
               max_pixel_value=255.0, p=1.0),
    ToTensorV2(),                                           # Convert to PyTorch tensors
])

# Validation transforms (no augmentation)
val_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH, p=1.0),    # Consistent sizing
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD,       # ImageNet normalization
               max_pixel_value=255.0, p=1.0),
    ToTensorV2(),                                           # Convert to PyTorch tensors
])

# Configure data directories
BASE_PROCESSED_DIR = Path("../data/processed/segmentation")
TRAIN_IMG_DIR = BASE_PROCESSED_DIR / "train" / "images"
TRAIN_MASK_DIR = BASE_PROCESSED_DIR / "train" / "masks"
VAL_IMG_DIR = BASE_PROCESSED_DIR / "val" / "images"
VAL_MASK_DIR = BASE_PROCESSED_DIR / "val" / "masks"
TEST_IMG_DIR = BASE_PROCESSED_DIR / "test" / "images"
TEST_MASK_DIR = BASE_PROCESSED_DIR / "test" / "masks"

# Initialize datasets with appropriate transforms
train_dataset = PeatlandDataset(
    images_dir=TRAIN_IMG_DIR, 
    masks_dir=TRAIN_MASK_DIR, 
    transform=train_transform
)

val_dataset = PeatlandDataset(
    images_dir=VAL_IMG_DIR, 
    masks_dir=VAL_MASK_DIR, 
    transform=val_transform
)

test_dataset = PeatlandDataset(
    images_dir=TEST_IMG_DIR, 
    masks_dir=TEST_MASK_DIR, 
    transform=val_transform
)

# Verify dataset initialization
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Verify data format and types
image, mask = train_dataset[0]
print(f"\nSample 0 from training set:")
print(f"Image shape: {image.shape}, Image dtype: {image.dtype}")
print(f"Mask shape: {mask.shape}, Mask dtype: {mask.dtype}")

Number of training samples: 3989
Number of validation samples: 203
Number of test samples: 204

Sample 0 from training set:
Image shape: torch.Size([3, 480, 640]), Image dtype: torch.float32
Mask shape: torch.Size([480, 640]), Mask dtype: torch.int64


## 4. Training Configuration

The training pipeline is configured with:

1. Learning Parameters:
   - Learning rate: 1e-4 (AdamW optimizer)
   - Batch size: 4 (chosen to fit GPU memory)
   - Training epochs: 10 (fixed; the training loop does not implement early stopping)

2. Hardware Setup:
   - Automatic device selection: CUDA, then MPS, then CPU
   - GPU acceleration when available

3. Run Management:
   - Timestamped run identification
   - Per-run metrics directory
   - Final model weights and training log saved at the end of training

In [None]:
# Training hyperparameters
LEARNING_RATE = 1e-4      # Initial learning rate for optimizer
BATCH_SIZE = 4           # Batch size balanced for memory and speed
NUM_EPOCHS = 10          # Number of training epochs (no early stopping)

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else \
         "mps" if torch.backends.mps.is_available() else "cpu"

# Run identification and logging setup
run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
metrics_dir = Path("metrics") / run_name
metrics_dir.mkdir(parents=True, exist_ok=True)

# Display configuration
print(f"Using device: {DEVICE}")
print(f"Metrics for this run will be saved in: {metrics_dir}")

Using device: mps
Metrics for this run will be saved in: metrics/2025-08-03_15-56-59


## 5. DataLoader Configuration

Initialize data loaders with:

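1. Training Settings:
   - Shuffled sample order each epoch
   - Batch size of 4 (`BATCH_SIZE`)
   - Memory pinning when running on CUDA

2. Validation Settings:
   - Sequential sample order
   - Same batch size as training
   - No random augmentation, so results are comparable across epochs

The configuration ensures:
- Randomized training batches and a deterministic validation pass
- Single-process loading (`num_workers=0`), which is the simplest setup but may limit throughput
- Pinned host memory only on CUDA, where it speeds up host-to-device copies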

In [None]:
# Initialize training data loader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,              # Shuffle data each epoch
    num_workers=0,            # Single process data loading
    pin_memory=(DEVICE == "cuda")  # Pin memory for faster GPU transfer
)

# Initialize validation data loader
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,             # Sequential processing for validation
    num_workers=0,            # Single process data loading
    pin_memory=(DEVICE == "cuda")  # Pin memory for faster GPU transfer
)
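
The number of batches per epoch follows directly from the dataset sizes printed earlier and the batch size. The sketch below reproduces that arithmetic, assuming the `DataLoader` default of `drop_last=False`, under which a final partial batch is kept. `batches_per_epoch` is a hypothetical helper.

```python
import math

BATCH_SIZE = 4


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for a dataset of the given size."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_batches = batches_per_epoch(3989, BATCH_SIZE)
val_batches = batches_per_epoch(203, BATCH_SIZE)
print(train_batches, val_batches)  # 998 51
```

Neither dataset size is a multiple of 4, so the last batch of each epoch is smaller than the rest (1 training sample, 3 validation samples).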

## 6. Model Architecture and Training Components

The segmentation pipeline consists of:

1. U-Net Architecture:
   - ResNet34 encoder backbone
   - Pretrained ImageNet weights
   - Skip connections for detail preservation
   - 5-class output head

2. Loss Function:
   - Cross-Entropy Loss
   - Multi-class classification
   - Pixel-wise error computation

3. Optimization:
   - AdamW optimizer
   - Decoupled weight decay at the PyTorch default (0.01), since no value is passed explicitly
   - Learning rate of 1e-4 (`LEARNING_RATE`)

In [None]:
# Initialize U-Net model with pretrained encoder
model = smp.Unet(
    encoder_name="resnet34",        # Encoder architecture
    encoder_weights="imagenet",     # Pretrained weights
    in_channels=3,                  # RGB input
    classes=5,                      # Number of segmentation classes
).to(DEVICE)                        # Move model to appropriate device

# Initialize loss function for multi-class segmentation
criterion = nn.CrossEntropyLoss()

# Configure AdamW; weight_decay is not set, so the PyTorch default (0.01) applies
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,              # Learning rate from configuration
)
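
For intuition about the loss values reported during training, the sketch below computes cross-entropy for a single pixel in plain Python: the negative log of the softmax probability assigned to the true class. `nn.CrossEntropyLoss` with its default settings averages this quantity over all pixels in the batch. The function name and the logit values are made up for illustration.

```python
import math


def pixel_cross_entropy(logits, target):
    """Cross-entropy for one pixel: -log(softmax(logits)[target])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]


# Five illustrative logits, one per class
uniform = pixel_cross_entropy([0.0, 0.0, 0.0, 0.0, 0.0], target=2)
confident = pixel_cross_entropy([0.0, 0.0, 8.0, 0.0, 0.0], target=2)
print(round(uniform, 4))    # 1.6094, i.e. ln(5)
print(confident < uniform)  # True
```

A 5-class model that predicts every class with equal probability scores ln(5), about 1.609, per pixel, which gives a reference point for reading the loss curve.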

## 7. Training Loop Implementation

The training process includes:

1. Epoch-level Processing:
   - Training phase with gradient updates
   - Validation phase with metric computation
   - Progress tracking and logging

2. Batch Processing:
   - Forward pass through model
   - Loss computation
   - Backpropagation
   - Optimizer updates

3. Validation:
   - No gradient computation
   - Loss calculation
   - Accuracy measurement
   - `model.eval()` to put batch normalization layers in inference mode

4. Metrics and Logging:
   - Loss tracking for both phases
   - Pixel-wise accuracy computation
   - Progress visualization
   - Checkpoint saving
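
The pixel accuracy tracked here is the share of pixels whose predicted class matches the label. The sketch below computes it on flattened toy labels, and alongside it a per-class intersection-over-union. The IoU metric is an addition for illustration and is not computed by this notebook; it is shown because overall accuracy can stay high while a rare class is missed entirely. Both function names and the label values are hypothetical.

```python
def pixel_accuracy(preds, targets):
    """Percentage of pixels whose predicted class equals the label."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return 100.0 * correct / len(targets)


def per_class_iou(preds, targets, num_classes):
    """Intersection-over-union per class; None when a class is absent from both."""
    ious = []
    for cls in range(num_classes):
        intersection = sum(p == cls and t == cls for p, t in zip(preds, targets))
        union = sum(p == cls or t == cls for p, t in zip(preds, targets))
        ious.append(intersection / union if union else None)
    return ious


# Flattened toy labels for 8 pixels and 3 classes (made-up values)
targets = [0, 0, 0, 0, 0, 0, 1, 2]
preds = [0, 0, 0, 0, 0, 0, 0, 2]

accuracy = pixel_accuracy(preds, targets)
ious = per_class_iou(preds, targets, num_classes=3)
print(accuracy)  # 87.5
print(ious)
```

In this toy case the accuracy is 87.5% even though class 1 has an IoU of 0, because the single class-1 pixel is predicted as class 0.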

In [None]:
# Initialize training metrics log
training_log = []

# Training loop
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    
    # Training phase
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training"):
        # Move data to appropriate device
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        
        # Forward pass and loss computation
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate batch loss
        train_loss += loss.item()
        
    # Validation phase
    model.eval()
    val_loss, val_correct_pixels, val_total_pixels = 0.0, 0, 0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating"):
            # Move data to appropriate device
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            
            # Forward pass and loss computation
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Accumulate validation metrics
            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            val_correct_pixels += (preds == masks).sum().item()
            val_total_pixels += torch.numel(masks)
            
    # Calculate epoch statistics
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = (val_correct_pixels / val_total_pixels) * 100
    
    # Display epoch results
    print(f"Average Training Loss: {avg_train_loss:.4f}")
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    print(f"Validation Pixel Accuracy: {val_accuracy:.2f}%")
    
    # Log training metrics
    training_log.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy
    })

# Save final model and training metrics
model_save_path = metrics_dir / "peatland_segmentation_model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"\nTraining complete. Model saved to {model_save_path}")

# Save training history
log_df = pd.DataFrame(training_log)
log_df.to_csv(metrics_dir / "training_log.csv", index=False)
print(f"Training log saved to {metrics_dir / 'training_log.csv'}")


--- Epoch 1/10 ---


Training: 100%|██████████| 998/998 [08:48<00:00,  1.89it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.89it/s]


Average Training Loss: 0.7886
Average Validation Loss: 0.4593
Validation Pixel Accuracy: 85.18%

--- Epoch 2/10 ---


Training: 100%|██████████| 998/998 [08:42<00:00,  1.91it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.92it/s]


Average Training Loss: 0.5364
Average Validation Loss: 0.3958
Validation Pixel Accuracy: 85.52%

--- Epoch 3/10 ---


Training: 100%|██████████| 998/998 [08:37<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.94it/s]


Average Training Loss: 0.4546
Average Validation Loss: 0.3434
Validation Pixel Accuracy: 88.75%

--- Epoch 4/10 ---


Training: 100%|██████████| 998/998 [08:35<00:00,  1.94it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.96it/s]


Average Training Loss: 0.3965
Average Validation Loss: 0.3538
Validation Pixel Accuracy: 87.22%

--- Epoch 5/10 ---


Training: 100%|██████████| 998/998 [08:36<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.98it/s]


Average Training Loss: 0.3617
Average Validation Loss: 0.3149
Validation Pixel Accuracy: 88.82%

--- Epoch 6/10 ---


Training: 100%|██████████| 998/998 [08:38<00:00,  1.92it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.96it/s]


Average Training Loss: 0.3294
Average Validation Loss: 0.3362
Validation Pixel Accuracy: 87.87%

--- Epoch 7/10 ---


Training: 100%|██████████| 998/998 [08:42<00:00,  1.91it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.96it/s]


Average Training Loss: 0.3146
Average Validation Loss: 0.3269
Validation Pixel Accuracy: 87.88%

--- Epoch 8/10 ---


Training: 100%|██████████| 998/998 [08:38<00:00,  1.92it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  3.00it/s]


Average Training Loss: 0.2961
Average Validation Loss: 0.3587
Validation Pixel Accuracy: 86.58%

--- Epoch 9/10 ---


Training: 100%|██████████| 998/998 [08:41<00:00,  1.91it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.88it/s]


Average Training Loss: 0.2860
Average Validation Loss: 0.2834
Validation Pixel Accuracy: 89.77%

--- Epoch 10/10 ---


Training: 100%|██████████| 998/998 [08:38<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.87it/s]


Average Training Loss: 0.2642
Average Validation Loss: 0.2856
Validation Pixel Accuracy: 89.27%

Training complete. Model saved to metrics/2025-08-03_15-56-59/peatland_segmentation_model.pth
Training log saved to metrics/2025-08-03_15-56-59/training_log.csv
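
The loop above always runs all 10 epochs and saves the weights from the last one. Using the validation losses logged above, the sketch below shows which epoch a keep-the-best rule would select and where a patience-based early-stopping rule would halt. `best_epoch` and `early_stop_epoch` are hypothetical helpers, and the patience values are assumptions chosen for illustration.

```python
# Validation losses copied from the epoch logs above
VAL_LOSSES = [0.4593, 0.3958, 0.3434, 0.3538, 0.3149,
              0.3362, 0.3269, 0.3587, 0.2834, 0.2856]


def best_epoch(val_losses):
    """1-based epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=val_losses.__getitem__) + 1


def early_stop_epoch(val_losses, patience):
    """1-based epoch at which training would halt, or None if it never does.

    Training halts once `patience` consecutive epochs fail to improve on the
    best validation loss seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


print(best_epoch(VAL_LOSSES))                    # 9
print(early_stop_epoch(VAL_LOSSES, patience=3))  # 8
print(early_stop_epoch(VAL_LOSSES, patience=4))  # None
```

On this run the lowest validation loss occurs at epoch 9, while the saved weights come from epoch 10. A patience of 3 would have stopped at epoch 8, before that improvement, which shows how sensitive the rule is to the patience setting when validation loss is this noisy.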
