# Binary Path Segmentation Training

This notebook implements binary semantic segmentation for peatland path detection using a U-Net architecture with a ResNet34 backbone. The model is trained to identify two classes:
- Path (1)
- Background (0)

Collapsing the labels to path versus background simplifies the segmentation task and concentrates the model on path detection, the signal that matters for navigation.

## 1. Import Required Libraries

In [None]:
# Standard library imports
import os
from pathlib import Path
from datetime import datetime

# Deep learning and numerical computing
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Image processing and data manipulation
from PIL import Image
import numpy as np
import pandas as pd

# Progress tracking
from tqdm import tqdm

# Computer vision libraries
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2


## 2. Dataset and Transformations

Implementation of the dataset class and data augmentation pipeline:
- Custom dataset for binary segmentation
- Image resizing and normalization
- Data augmentation strategies
- Proper tensor conversion for PyTorch

### Dataset Implementation Details

The `BinaryPeatlandDataset` class extends PyTorch's Dataset class to handle our binary segmentation data:

1. **Data Organization**:
   - Images are stored in one directory
   - Corresponding masks in another directory
   - Filenames match between images and masks

2. **Image Processing**:
   - RGB images are loaded in full color
   - Masks are loaded as single-channel binary images
   - Optional data augmentation is applied to both

3. **Return Format**:
   - Images: Normalized tensors in RGB format
   - Masks: Long tensors with binary values (0 for background, 1 for path)
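
The mask dtype matters because `nn.CrossEntropyLoss` expects one integer class index per pixel as the target, alongside one logit per class from the model. As a rough illustration of what the loss computes for a single pixel, here is a plain-Python sketch; the helper name `pixel_cross_entropy` and the logit values are made up for the example and are not part of the notebook.

```python
import math

def pixel_cross_entropy(logits, target):
    """Cross-entropy for one pixel: -log(softmax(logits)[target])."""
    # Subtract the largest logit before exponentiating for numerical stability.
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    prob_of_target = exps[target] / sum(exps)
    return -math.log(prob_of_target)

# Hypothetical logits for one pixel: [background score, path score].
logits = [0.5, 2.0]

loss_if_path = pixel_cross_entropy(logits, target=1)        # label matches the larger logit
loss_if_background = pixel_cross_entropy(logits, target=0)  # label contradicts the larger logit

print(f"label=path: {loss_if_path:.4f}, label=background: {loss_if_background:.4f}")
```

The loss is small when the label agrees with the larger logit and grows when it does not. A target outside `{0, 1}` (for example a mask stored as 0/255) has no corresponding logit, which is why the masks must hold class indices.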

In [None]:
class BinaryPeatlandDataset(Dataset):
    """Custom PyTorch Dataset for binary path segmentation in peatland images.
    
    This dataset handles pairs of images and their corresponding binary segmentation masks,
    where paths are labeled as 1 and all other terrain as 0.
    
    Args:
        images_dir (str or Path): Directory containing the input images
        masks_dir (str or Path): Directory containing the binary segmentation masks
        transform (callable, optional): Optional transform to be applied to both image and mask
    """
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform
        self.image_filenames = sorted(os.listdir(self.images_dir))

    def __len__(self):
        """Returns the total number of image-mask pairs in the dataset."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Retrieves an image-mask pair by index.
        
        Args:
            idx (int): Index of the image-mask pair to retrieve
            
        Returns:
            tuple: (image, mask) where image is a transformed RGB tensor and
                  mask is a binary tensor with 0s and 1s
        """
        img_name = self.image_filenames[idx]
        img_path = self.images_dir / img_name
        mask_path = self.masks_dir / img_name
        
        # Load image in RGB format for consistent processing
        image = np.array(Image.open(img_path).convert("RGB"))
        # Load mask as stored on disk; it is assumed to be a single-channel
        # image whose pixel values are already 0 (background) and 1 (path)
        mask = np.array(Image.open(mask_path))
        
        # Apply transformations if specified
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            
        # Convert mask to a long tensor for CrossEntropyLoss; torch.as_tensor
        # also handles the NumPy array left over when no transform is given
        mask = torch.as_tensor(mask).long()
        return image, mask
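
Because the dataset looks up each mask by the image's filename, a missing or misnamed mask only surfaces as an error when that item is loaded. A minimal sketch of a pre-flight check, using only the standard library, might look like this; `find_unpaired_files` is a hypothetical helper and the filenames are invented for the demonstration.

```python
from pathlib import Path
import tempfile

def find_unpaired_files(images_dir, masks_dir):
    """Return (images without a mask, masks without an image), matched by filename."""
    image_names = {p.name for p in Path(images_dir).iterdir() if p.is_file()}
    mask_names = {p.name for p in Path(masks_dir).iterdir() if p.is_file()}
    return sorted(image_names - mask_names), sorted(mask_names - image_names)

# Demonstration on a throwaway directory tree with made-up filenames.
with tempfile.TemporaryDirectory() as tmp:
    images = Path(tmp) / "images"
    masks = Path(tmp) / "masks"
    images.mkdir()
    masks.mkdir()
    for name in ["a.png", "b.png", "c.png"]:
        (images / name).touch()
    for name in ["a.png", "b.png"]:
        (masks / name).touch()
    missing_masks, orphan_masks = find_unpaired_files(images, masks)

print("Images without masks:", missing_masks)
print("Masks without images:", orphan_masks)
```

Running such a check on the train and val directories before constructing the datasets would catch pairing problems up front rather than partway through an epoch.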

### Data Transformations

The training pipeline uses Albumentations for efficient data augmentation:

1. **Image Dimensions**:
   - Height: 480 pixels
   - Width: 640 pixels
   - `A.Resize` stretches every image to this fixed 4:3 shape, so the source aspect ratio is preserved only for images that are already 4:3

2. **Normalization**:
   - Using ImageNet mean and standard deviation
   - Ensures model compatibility with pretrained weights
   - Scales pixel values for better training stability

3. **Augmentation Strategy**:
   - Training: Includes horizontal flips for diversity
   - Validation: Only resizing and normalization
   - Geometric transforms (resize, flip) are applied identically to image and mask, so the two stay aligned; normalization touches only the image

In [None]:
# Define image dimensions for consistent processing
IMG_HEIGHT = 480
IMG_WIDTH = 640

# ImageNet normalization parameters
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transformations including data augmentation
train_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    A.HorizontalFlip(p=0.5),  # 50% chance of horizontal flip
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation transformations (no augmentation)
val_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
    ToTensorV2(),
])
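
To make the normalization step concrete: with `max_pixel_value=255.0`, `A.Normalize` maps each channel value `v` to `(v / 255 - mean) / std`. The following plain-Python sketch reproduces that arithmetic for a single pixel; `normalize_pixel` is an illustrative helper, not an Albumentations function.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0):
    """Apply (value / max_pixel_value - mean) / std to each channel of one RGB pixel."""
    return [(v / max_pixel_value - m) / s for v, m, s in zip(rgb, mean, std)]

black = normalize_pixel([0, 0, 0])        # darkest possible input
white = normalize_pixel([255, 255, 255])  # brightest possible input

print("black:", [round(x, 3) for x in black])
print("white:", [round(x, 3) for x in white])
```

The two extremes bound the range of values the network sees: roughly -2.1 to 2.6 rather than 0 to 255.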

## 3. Configuration and Data Loading

The training configuration includes several key parameters and setup steps:

1. **Hyperparameters**:
   - Learning rate: 1e-4
   - Batch size: 4
   - Number of epochs: 20
   - Early stopping patience: 4

2. **Hardware Configuration**:
   - Automatic device selection in order of preference: CUDA, then MPS, then CPU
   - The same code runs unchanged on NVIDIA GPUs, Apple Silicon, or CPU-only machines

3. **Data Organization**:
   - Structured directory setup for train/val splits
   - Separate image and mask directories
   - Consistent naming conventions

4. **Output Management**:
   - Timestamped run names
   - Organized metrics storage
   - Automatic directory creation

In [None]:
# Training hyperparameters
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_EPOCHS = 20 

# Device configuration
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Set up run identification and metrics storage
run_name = f"binary_unet_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
metrics_dir = Path("metrics") / run_name
metrics_dir.mkdir(parents=True, exist_ok=True)

print(f"Using device: {DEVICE}")
print(f"Metrics for this run will be saved in: {metrics_dir}")

# Data directory configuration
BASE_PROCESSED_DIR = Path("../data/processed/binary_segmentation")
TRAIN_IMG_DIR = BASE_PROCESSED_DIR / "train" / "images"
TRAIN_MASK_DIR = BASE_PROCESSED_DIR / "train" / "masks"
VAL_IMG_DIR = BASE_PROCESSED_DIR / "val" / "images"
VAL_MASK_DIR = BASE_PROCESSED_DIR / "val" / "masks"

# Initialize datasets with appropriate transforms
train_dataset = BinaryPeatlandDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, train_transform)
val_dataset = BinaryPeatlandDataset(VAL_IMG_DIR, VAL_MASK_DIR, val_transform)

# Create data loaders for batched processing
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

Using device: mps
Metrics for this run will be saved in: metrics/binary_unet_2025-08-03_22-53-32
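
With the default `drop_last=False`, a `DataLoader` yields `ceil(len(dataset) / batch_size)` batches per epoch, which is the total that the tqdm progress bars report. The sketch below works through that arithmetic and its inverse; both helper names are hypothetical, and the dataset sizes are inferred from the batch counts in the training output (998 and 51) rather than read from the data.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of num_samples items."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

def sample_range_for(num_batches, batch_size):
    """Smallest and largest dataset size that produce num_batches (drop_last=False)."""
    return (num_batches - 1) * batch_size + 1, num_batches * batch_size

BATCH_SIZE = 4

train_range = sample_range_for(998, BATCH_SIZE)  # training bar shows 998 batches
val_range = sample_range_for(51, BATCH_SIZE)     # validation bar shows 51 batches

print("Training images:", train_range)
print("Validation images:", val_range)
```

This also explains why the per-epoch losses are averages over batches: the final batch may hold fewer than four images but carries the same weight as a full one.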


## 4. Training Loop

The training loop implements a complete training pipeline with:

1. **Model Architecture**:
   - U-Net with ResNet34 backbone
   - Pretrained ImageNet weights
   - Binary classification (2 classes)

2. **Optimization Strategy**:
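   - AdamW optimizer with weight decay of 1e-5
   - `ReduceLROnPlateau` scheduling: the learning rate is multiplied by 0.1 once validation loss has gone more than 2 epochs without improving
   - CrossEntropy loss over the two class logits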

3. **Training Features**:
   - Progress tracking with tqdm
   - Early stopping mechanism
   - Best model checkpointing

4. **Monitoring**:
   - Training loss tracking
   - Validation loss monitoring
   - Pixel-wise accuracy calculation
   - Comprehensive metrics logging
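
Pixel-wise accuracy is the share of pixels whose predicted class equals the mask label. When paths cover a small part of each image, a model that predicts background everywhere can still score well, so the figure is easiest to read next to that baseline. The sketch below illustrates the point on a tiny hand-made mask; the grids and the helper `pixel_accuracy` are invented for the example.

```python
def pixel_accuracy(preds, masks):
    """Percentage of pixels where the predicted class equals the label."""
    correct = total = 0
    for pred_row, mask_row in zip(preds, masks):
        for p, m in zip(pred_row, mask_row):
            correct += int(p == m)
            total += 1
    return 100.0 * correct / total

# Hypothetical 4x5 mask in which 3 of 20 pixels are path (1).
mask = [
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
]
# A prediction that finds two of the three path pixels.
pred = [
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
# A degenerate prediction that never outputs the path class.
all_background = [[0] * 5 for _ in range(4)]

acc_pred = pixel_accuracy(pred, mask)
acc_background = pixel_accuracy(all_background, mask)

print(f"prediction: {acc_pred:.1f}%, all-background baseline: {acc_background:.1f}%")
```

The gap between the two numbers, rather than the raw accuracy, indicates how much the model has learned about paths.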

In [None]:
# Initialize model with pretrained encoder
model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=3, classes=2).to(DEVICE)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)

# Initialize training tracking variables
training_log = []
best_val_loss = float('inf')
epochs_no_improve = 0
patience = 4  # Early stopping patience

# Training loop
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    
    # Training phase
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training"):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        
        # Forward pass and loss calculation
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    # Validation phase
    model.eval()
    val_loss, val_correct_pixels, val_total_pixels = 0.0, 0, 0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating"):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Calculate metrics
            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            val_correct_pixels += (preds == masks).sum().item()
            val_total_pixels += torch.numel(masks)
            
    # Calculate epoch statistics
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = (val_correct_pixels / val_total_pixels) * 100
    
    # Print epoch results
    print(f"Average Training Loss: {avg_train_loss:.4f}")
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    print(f"Validation Pixel Accuracy: {val_accuracy:.2f}%")
    
    # Learning rate scheduling
    scheduler.step(avg_val_loss)
    
    # Log metrics
    training_log.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy
    })

    # Model checkpointing and early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), metrics_dir / "best_binary_model.pth")
        print(f"Validation loss improved. Saving best model to {metrics_dir / 'best_binary_model.pth'}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

# Save final training metrics
print("\nTraining finished.")
log_df = pd.DataFrame(training_log)
log_df.to_csv(metrics_dir / "training_log.csv", index=False)
print(f"Training log saved to {metrics_dir / 'training_log.csv'}")


--- Epoch 1/20 ---


Training: 100%|██████████| 998/998 [08:52<00:00,  1.87it/s]
Validating: 100%|██████████| 51/51 [00:17<00:00,  2.86it/s]


Average Training Loss: 0.3172
Average Validation Loss: 0.1828
Validation Pixel Accuracy: 92.30%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 2/20 ---


Training: 100%|██████████| 998/998 [08:33<00:00,  1.94it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.04it/s]


Average Training Loss: 0.2266
Average Validation Loss: 0.1539
Validation Pixel Accuracy: 93.91%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 3/20 ---


Training: 100%|██████████| 998/998 [08:36<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.10it/s]


Average Training Loss: 0.1877
Average Validation Loss: 0.1485
Validation Pixel Accuracy: 94.45%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 4/20 ---


Training: 100%|██████████| 998/998 [08:36<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.11it/s]


Average Training Loss: 0.1624
Average Validation Loss: 0.1206
Validation Pixel Accuracy: 95.27%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 5/20 ---


Training: 100%|██████████| 998/998 [08:36<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.06it/s]


Average Training Loss: 0.1460
Average Validation Loss: 0.1345
Validation Pixel Accuracy: 94.97%
Validation loss did not improve for 1 epoch(s).

--- Epoch 6/20 ---


Training: 100%|██████████| 998/998 [08:36<00:00,  1.93it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.06it/s]


Average Training Loss: 0.1346
Average Validation Loss: 0.1360
Validation Pixel Accuracy: 94.50%
Validation loss did not improve for 2 epoch(s).

--- Epoch 7/20 ---


Training: 100%|██████████| 998/998 [08:35<00:00,  1.94it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.02it/s]


Average Training Loss: 0.1237
Average Validation Loss: 0.1399
Validation Pixel Accuracy: 94.91%
Validation loss did not improve for 3 epoch(s).

--- Epoch 8/20 ---


Training: 100%|██████████| 998/998 [08:29<00:00,  1.96it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.05it/s]


Average Training Loss: 0.0944
Average Validation Loss: 0.1060
Validation Pixel Accuracy: 95.68%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 9/20 ---


Training: 100%|██████████| 998/998 [08:32<00:00,  1.95it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.08it/s]


Average Training Loss: 0.0854
Average Validation Loss: 0.1051
Validation Pixel Accuracy: 95.72%
Validation loss improved. Saving best model to metrics/binary_unet_2025-08-03_22-53-32/best_binary_model.pth

--- Epoch 10/20 ---


Training: 100%|██████████| 998/998 [08:32<00:00,  1.95it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.07it/s]


Average Training Loss: 0.0811
Average Validation Loss: 0.1101
Validation Pixel Accuracy: 95.51%
Validation loss did not improve for 1 epoch(s).

--- Epoch 11/20 ---


Training: 100%|██████████| 998/998 [08:33<00:00,  1.95it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.09it/s]


Average Training Loss: 0.0777
Average Validation Loss: 0.1060
Validation Pixel Accuracy: 95.60%
Validation loss did not improve for 2 epoch(s).

--- Epoch 12/20 ---


Training: 100%|██████████| 998/998 [08:33<00:00,  1.95it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.18it/s]


Average Training Loss: 0.0730
Average Validation Loss: 0.1111
Validation Pixel Accuracy: 95.64%
Validation loss did not improve for 3 epoch(s).

--- Epoch 13/20 ---


Training: 100%|██████████| 998/998 [08:33<00:00,  1.94it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.07it/s]

Average Training Loss: 0.0691
Average Validation Loss: 0.1092
Validation Pixel Accuracy: 95.58%
Validation loss did not improve for 4 epoch(s).
Early stopping triggered after 13 epochs.

Training finished.
Training log saved to metrics/binary_unet_2025-08-03_22-53-32/training_log.csv
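
The stopping point in the log can be checked by replaying the recorded validation losses through the same rules. The sketch below does this in plain Python: the early-stopping part mirrors the notebook's loop, while the scheduler part is a simplified re-implementation of `ReduceLROnPlateau` under its default relative threshold of 1e-4. The learning rate was not logged during the run, so the drops it reports are an inference, not a recorded fact.

```python
def replay_validation_losses(val_losses, stop_patience=4, lr=1e-4,
                             lr_patience=2, factor=0.1, threshold=1e-4):
    """Replay early stopping and a simplified plateau scheduler over recorded losses."""
    best_loss, best_epoch, epochs_no_improve = float("inf"), None, 0  # early stopping
    sched_best, sched_bad_epochs = float("inf"), 0                    # scheduler
    lr_drops, stopped_at = [], None

    for epoch, loss in enumerate(val_losses, start=1):
        # Scheduler: an epoch counts as better only if it beats the best
        # loss by the relative threshold; otherwise it is a "bad" epoch.
        if loss < sched_best * (1 - threshold):
            sched_best, sched_bad_epochs = loss, 0
        else:
            sched_bad_epochs += 1
        if sched_bad_epochs > lr_patience:
            lr *= factor
            sched_bad_epochs = 0
            lr_drops.append(epoch)

        # Early stopping: strict improvement resets the counter.
        if loss < best_loss:
            best_loss, best_epoch, epochs_no_improve = loss, epoch, 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= stop_patience:
            stopped_at = epoch
            break

    return {"best_epoch": best_epoch, "best_loss": best_loss,
            "stopped_at": stopped_at, "lr_drops": lr_drops, "final_lr": lr}

# Average validation losses copied from the epoch logs above.
logged_val_losses = [0.1828, 0.1539, 0.1485, 0.1206, 0.1345, 0.1360, 0.1399,
                     0.1060, 0.1051, 0.1101, 0.1060, 0.1111, 0.1092]

result = replay_validation_losses(logged_val_losses)
print(result)
```

Under these assumptions the replay stops at epoch 13 with the best checkpoint from epoch 9, matching the log. It also places a learning-rate reduction after epoch 7, which would be consistent with the sharp fall in training loss between epochs 7 and 8.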



