# DinoV2-based Semantic Segmentation

This notebook implements semantic segmentation using Facebook's DinoV2 vision transformer as the backbone with a custom decoder head. This approach leverages the strong visual features learned by DinoV2 through self-supervised training on large-scale datasets.

Key features:
- DinoV2-base backbone (frozen)
- Custom decoder head for segmentation
- Three-stage progressive upsampling of the final-layer patch features
- Five-class segmentation output

The model maintains DinoV2's strong feature extraction while adapting it for the specific task of peatland segmentation.

## 1. Imports

In [None]:
# Standard library imports
import os
from pathlib import Path
from datetime import datetime

# Deep learning and numerical computing
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Image processing and data manipulation
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Progress tracking
from tqdm import tqdm

# Vision transformer and transformations
from transformers import Dinov2Model, AutoImageProcessor
import albumentations as A
from albumentations.pytorch import ToTensorV2


## 2. DinoV2 Dataset Class and Transforms

The dataset implementation is specialized for DinoV2's requirements:

1. **Image Processing Pipeline**:
   - Images are loaded and converted to RGB
   - Initial transformations (resize, augmentations) are applied
   - DinoV2's image processor handles final normalization

2. **Mask Handling**:
   - Masks are loaded as-is
   - Transformations maintain alignment with images
   - Converted to long tensors for loss computation

3. **DinoV2 Specifics**:
   - Input size fixed at 224x224 pixels, which DinoV2's 14-pixel patches divide into a 16x16 grid
   - Uses DinoV2's custom processor
   - Maintains compatibility with transformer architecture
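
Because the masks end up as long tensors fed to a cross-entropy loss, every mask value has to be a valid class index. The following is a minimal sanity-check sketch, not part of the original pipeline: `check_mask` and `NUM_CLASSES` are hypothetical names, and it assumes masks are single-channel arrays whose pixel values are class IDs in the range 0 to 4.

```python
import numpy as np

NUM_CLASSES = 5  # assumption: matches the five-class output described above


def check_mask(mask: np.ndarray, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Return the sorted class IDs found in `mask`.

    Raises ValueError if the mask is not 2D or holds IDs outside [0, num_classes).
    """
    if mask.ndim != 2:
        raise ValueError(f"expected a 2D (H, W) mask, got shape {mask.shape}")
    labels = np.unique(mask)
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f"mask contains labels outside [0, {num_classes}): {labels.tolist()}"
        )
    return labels


# A tiny hand-made mask standing in for a real annotation file
example_mask = np.array([[0, 1], [4, 2]], dtype=np.uint8)
print(check_mask(example_mask).tolist())  # [0, 1, 2, 4]
```

Running a check like this once over the mask directory catches RGB-encoded masks or stray values such as 255 before they surface as a loss error mid-training.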

In [None]:
class PeatlandDinoDataset(Dataset):
    """Custom PyTorch Dataset for DinoV2-based semantic segmentation.
    
    This dataset handles the specific requirements of the DinoV2 vision transformer,
    including proper image preprocessing and tensor formatting.
    
    Args:
        images_dir (str or Path): Directory containing the input images
        masks_dir (str or Path): Directory containing the segmentation masks
        image_processor: DinoV2's image processor for normalization
        transform (callable, optional): Optional transform to be applied before DinoV2 processing
    """
    def __init__(self, images_dir, masks_dir, image_processor, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.image_processor = image_processor
        self.transform = transform
        self.image_filenames = sorted(os.listdir(self.images_dir))

    def __len__(self):
        """Returns the total number of image-mask pairs in the dataset."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Retrieves and processes an image-mask pair.
        
        Args:
            idx (int): Index of the image-mask pair to retrieve
            
        Returns:
            tuple: (pixel_values, mask) where pixel_values is processed by DinoV2's
                  processor and mask is a long tensor
        """
        img_name = self.image_filenames[idx]
        img_path = self.images_dir / img_name
        mask_path = self.masks_dir / img_name
        
        # Load image in PIL format for DinoV2 processor
        image = Image.open(img_path).convert("RGB")
        mask = np.array(Image.open(mask_path))

        # Apply augmentations before DinoV2 processing
        if self.transform:
            # Convert PIL to numpy for albumentations
            augmented = self.transform(image=np.array(image), mask=mask)
            image = Image.fromarray(augmented['image'])
            mask = augmented['mask']
        
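        # Normalize with DinoV2's processor. By default this processor also resizes
        # (shortest edge 256) and center-crops to 224x224, which would alter the
        # image but not the mask and break pixel alignment. Both steps are disabled
        # here because the albumentations pipeline already resizes image and mask together.
        pixel_values = self.image_processor(
            image, do_resize=False, do_center_crop=False, return_tensors="pt"
        ).pixel_values.squeeze(0)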
        
        # Convert mask to appropriate tensor type
        mask = torch.from_numpy(mask).long()
        return pixel_values, mask

# Initialize DinoV2 processor and define image size
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
DINO_IMAGE_SIZE = 224

# Training transformations with augmentation
dino_train_transform = A.Compose([
    A.Resize(height=DINO_IMAGE_SIZE, width=DINO_IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),  # 50% chance of horizontal flip
])

# Validation transformations (no augmentation)
dino_val_transform = A.Compose([
    A.Resize(height=DINO_IMAGE_SIZE, width=DINO_IMAGE_SIZE),
])

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## 3. DinoV2 Segmentation Model

The DinoV2-based segmentation model architecture consists of:

1. **Backbone**:
   - DinoV2-base pretrained model
   - Frozen weights to preserve learned features
   - Patch-based transformer architecture

2. **Segmentation Head**:
   - Custom convolutional decoder
   - Progressive upsampling layers
   - Channel reduction for memory efficiency

3. **Feature Processing**:
   - Patch token extraction and reshaping
   - Spatial feature map reconstruction
   - Final logits upsampling to input size

4. **Design Choices**:
   - Skip CLS token for dense prediction
   - Bilinear upsampling for smooth outputs
   - Multi-stage feature refinement
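
The patch-token reshaping described above can be sketched in isolation on a dummy tensor, without loading the backbone. This is an illustrative sketch: `tokens_to_feature_map` is a hypothetical helper name, and the constants assume a 224x224 input, 14-pixel patches, and a 768-dimensional hidden state with a single leading CLS token.

```python
import math

import torch

IMAGE_SIZE = 224   # input resolution used in this notebook
PATCH_SIZE = 14    # assumption: DinoV2 patch size
HIDDEN_SIZE = 768  # channel width the segmentation head expects


def tokens_to_feature_map(hidden_states: torch.Tensor) -> torch.Tensor:
    """Turn (B, 1 + N, C) transformer tokens into a (B, C, H, W) feature map."""
    patch_tokens = hidden_states[:, 1:, :]  # drop the CLS token
    batch, num_tokens, channels = patch_tokens.shape
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError(f"{num_tokens} patch tokens do not form a square grid")
    # Tokens are ordered row by row, so token r * side + c lands at (r, c)
    return patch_tokens.permute(0, 2, 1).reshape(batch, channels, side, side)


grid = IMAGE_SIZE // PATCH_SIZE  # 16 patches per side
dummy_tokens = torch.randn(2, 1 + grid * grid, HIDDEN_SIZE)
feature_map = tokens_to_feature_map(dummy_tokens)
print(feature_map.shape)  # torch.Size([2, 768, 16, 16])
```

Three 2x upsampling stages take this 16x16 map to 128x128, which is why a final interpolation to 224x224 is still needed.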

In [None]:
class DinoV2ForSemanticSegmentation(nn.Module):
    """DinoV2-based semantic segmentation model.
    
    Combines a frozen DinoV2 backbone with a trainable segmentation head
    to perform dense pixel-wise classification.
    
    Args:
        num_classes (int): Number of segmentation classes (default: 5)
    """
    def __init__(self, num_classes=5):
        super().__init__()
        # Initialize pretrained DinoV2 backbone
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")
        
        # Freeze backbone weights
        for param in self.dinov2.parameters():
            param.requires_grad = False
            
        # Segmentation head: progressive upsampling with channel reduction
        self.head = nn.Sequential(
            # First stage: 768 -> 256 channels
            nn.Conv2d(768, 256, kernel_size=3, padding=1), 
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # Second stage: 256 -> 128 channels
            nn.Conv2d(256, 128, kernel_size=3, padding=1), 
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # Third stage: 128 -> 64 channels
            nn.Conv2d(128, 64, kernel_size=3, padding=1), 
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            # Final layer: project to class logits
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, pixel_values):
        """Forward pass through the model.
        
        Args:
            pixel_values: Input tensor processed by DinoV2's processor
            
        Returns:
            torch.Tensor: Segmentation logits of shape (B, C, H, W)
        """
        # Get transformer features (only the final layer is used, so the
        # intermediate hidden states are not requested)
        outputs = self.dinov2(pixel_values)
        last_hidden_state = outputs.last_hidden_state
        
        # Process patch tokens (exclude CLS token)
        patch_tokens = last_hidden_state[:, 1:, :]
        batch_size, seq_len, num_channels = patch_tokens.shape
        
        # Reshape into 2D feature map
        height = width = int(seq_len**0.5)
        feature_map = patch_tokens.permute(0, 2, 1).contiguous().reshape(
            batch_size, num_channels, height, width)
        
        # Apply segmentation head
        logits = self.head(feature_map)
        
        # Upsample to the input's spatial resolution (read from the tensor itself
        # rather than from a global constant)
        final_logits = nn.functional.interpolate(
            logits, 
            size=pixel_values.shape[-2:], 
            mode='bilinear', 
            align_corners=False
        )
        return final_logits
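
To confirm that the backbone really is frozen and only the head trains, it can help to count parameters by their `requires_grad` flag. The sketch below uses a hypothetical helper, `count_parameters`, and demonstrates it on a toy two-layer module rather than the real model so it runs without downloading weights.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> tuple[int, int]:
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


# Toy stand-in: a frozen "backbone" and a trainable "head"
toy = nn.ModuleDict({"backbone": nn.Linear(8, 4), "head": nn.Linear(4, 2)})
for param in toy["backbone"].parameters():
    param.requires_grad = False

print(count_parameters(toy))  # (10, 36)
```

Applied to `DinoV2ForSemanticSegmentation`, the same call should report the head's parameters as trainable and the whole backbone as frozen.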

## 4. Configuration and Data Loading

Training configuration and data setup:

1. **Training Parameters**:
   - Learning rate: 1e-3 (higher due to frozen backbone)
   - Batch size: 4 samples
   - Number of epochs: 10
   - Hardware-adaptive device selection

2. **Run Management**:
   - Unique timestamp-based run names
   - Organized metrics directory structure
   - Automatic output path creation

3. **Data Organization**:
   - Structured directory hierarchy
   - Clear train/val split
   - Consistent naming conventions
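
The dataset class looks up each mask by the image's filename, so the naming convention is load-bearing. A small pre-flight check can list files that lack a partner. This is a sketch under that same-filename assumption; `find_unpaired` is a hypothetical helper, and the demo uses a temporary directory rather than the real data paths.

```python
import tempfile
from pathlib import Path


def find_unpaired(images_dir, masks_dir):
    """Return (images without a mask, masks without an image), each sorted by name."""
    image_names = {p.name for p in Path(images_dir).iterdir() if p.is_file()}
    mask_names = {p.name for p in Path(masks_dir).iterdir() if p.is_file()}
    return sorted(image_names - mask_names), sorted(mask_names - image_names)


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    images = Path(tmp) / "images"
    masks = Path(tmp) / "masks"
    images.mkdir()
    masks.mkdir()
    for name in ("a.png", "b.png"):
        (images / name).touch()
    (masks / "a.png").touch()
    missing_masks, orphan_masks = find_unpaired(images, masks)

print(missing_masks, orphan_masks)  # ['b.png'] []
```

An empty result for both lists means every index the `DataLoader` requests will resolve to a real image-mask pair.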

In [None]:
# Training hyperparameters
LEARNING_RATE = 1e-3  # Higher LR since only training the head
BATCH_SIZE = 4
NUM_EPOCHS = 10

# Hardware configuration
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Set up run identification and metrics storage
run_name = f"dinov2_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
metrics_dir = Path("metrics") / run_name
metrics_dir.mkdir(parents=True, exist_ok=True)

print(f"Using device: {DEVICE}")
print(f"Metrics for this run will be saved in: {metrics_dir}")

# Data directory configuration
BASE_PROCESSED_DIR = Path("../data/processed/segmentation")
TRAIN_IMG_DIR = BASE_PROCESSED_DIR / "train" / "images"
TRAIN_MASK_DIR = BASE_PROCESSED_DIR / "train" / "masks"
VAL_IMG_DIR = BASE_PROCESSED_DIR / "val" / "images"
VAL_MASK_DIR = BASE_PROCESSED_DIR / "val" / "masks"

# Initialize datasets with appropriate transforms
train_dataset = PeatlandDinoDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, processor, dino_train_transform)
val_dataset = PeatlandDinoDataset(VAL_IMG_DIR, VAL_MASK_DIR, processor, dino_val_transform)

# Create data loaders for batched processing
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

Using device: mps
Metrics for this run will be saved in: metrics/dinov2_2025-08-03_17-48-54


## 5. Training Loop

The training implementation includes:

1. **Model Setup**:
   - DinoV2 model instantiation
   - Cross entropy loss for multi-class segmentation
   - AdamW optimizer with weight decay
   - Learning rate scheduling

2. **Training Process**:
   - Epoch-based training loop
   - Batch-wise processing
   - Gradient computation and updates
   - Loss tracking and logging

3. **Validation**:
   - Regular model evaluation
   - Pixel-wise accuracy calculation
   - Loss monitoring for early stopping
   - Best model checkpointing

4. **Progress Tracking**:
   - Per-epoch statistics
   - Training/validation metrics logging
   - Comprehensive CSV output
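
The early-stopping rule can be isolated from the training loop as a small counter. The sketch below is illustrative only: `EarlyStopping` is a hypothetical class name, and it is replayed on the ten validation losses logged by this notebook's run with a patience of 3.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without a new best loss."""

    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Validation losses from the run logged in this notebook
val_losses = [0.4033, 0.4714, 0.3678, 0.3602, 0.3608,
              0.3628, 0.3495, 0.3769, 0.3252, 0.3456]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # None 0.3252
```

The longest streak without improvement in that run is two epochs (5 and 6), so the rule never fires and all ten epochs complete.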

In [None]:
# Initialize model, loss function, and optimizer
model = DinoV2ForSemanticSegmentation(num_classes=5).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.head.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)

# Initialize training tracking
training_log = []
best_val_loss = float('inf')
epochs_no_improve = 0
patience = 3  # Early stopping after 3 epochs without improvement

# Training loop
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    
    # Training phase
    model.train()
    train_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training"):
        # Move data to device
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        
        # Forward pass and loss calculation
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    # Validation phase
    model.eval()
    val_loss, val_correct_pixels, val_total_pixels = 0.0, 0, 0
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating"):
            # Move data to device
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Calculate metrics
            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            val_correct_pixels += (preds == masks).sum().item()
            val_total_pixels += torch.numel(masks)
            
    # Calculate epoch statistics
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = (val_correct_pixels / val_total_pixels) * 100
    
    # Print epoch results
    print(f"Average Training Loss: {avg_train_loss:.4f}")
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    print(f"Validation Pixel Accuracy: {val_accuracy:.2f}%")
    
    # Learning rate scheduling
    scheduler.step(avg_val_loss)

    # Log metrics
    training_log.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_accuracy': val_accuracy
    })

    # Model checkpointing
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), metrics_dir / "best_model.pth")
        print(f"Validation loss improved. Saving best model to {metrics_dir / 'best_model.pth'}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epoch(s).")

    # Early stopping check
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

# Save the final-epoch weights and metrics (the lowest-validation-loss weights are in best_model.pth)
model_save_path = metrics_dir / "peatland_dinov2_model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"\nTraining complete. Model saved to {model_save_path}")

# Save training history
log_df = pd.DataFrame(training_log)
log_df.to_csv(metrics_dir / "training_log.csv", index=False)
print(f"Training log saved to {metrics_dir / 'training_log.csv'}")


--- Epoch 1/10 ---


Training: 100%|██████████| 998/998 [05:10<00:00,  3.21it/s]
Validating: 100%|██████████| 51/51 [00:15<00:00,  3.29it/s]


Average Training Loss: 0.5401
Average Validation Loss: 0.4033
Validation Pixel Accuracy: 84.40%
Validation loss improved. Saving best model to metrics/dinov2_2025-08-03_17-48-54/best_model.pth

--- Epoch 2/10 ---


Training: 100%|██████████| 998/998 [05:05<00:00,  3.26it/s]
Validating: 100%|██████████| 51/51 [00:15<00:00,  3.35it/s]


Average Training Loss: 0.4270
Average Validation Loss: 0.4714
Validation Pixel Accuracy: 81.19%
Validation loss did not improve for 1 epoch(s).

--- Epoch 3/10 ---


Training: 100%|██████████| 998/998 [05:12<00:00,  3.19it/s]
Validating: 100%|██████████| 51/51 [00:15<00:00,  3.22it/s]


Average Training Loss: 0.3901
Average Validation Loss: 0.3678
Validation Pixel Accuracy: 86.10%
Validation loss improved. Saving best model to metrics/dinov2_2025-08-03_17-48-54/best_model.pth

--- Epoch 4/10 ---


Training: 100%|██████████| 998/998 [05:10<00:00,  3.22it/s]
Validating: 100%|██████████| 51/51 [00:14<00:00,  3.52it/s]


Average Training Loss: 0.3707
Average Validation Loss: 0.3602
Validation Pixel Accuracy: 85.35%
Validation loss improved. Saving best model to metrics/dinov2_2025-08-03_17-48-54/best_model.pth

--- Epoch 5/10 ---


Training: 100%|██████████| 998/998 [04:54<00:00,  3.38it/s]
Validating: 100%|██████████| 51/51 [00:14<00:00,  3.54it/s]


Average Training Loss: 0.3552
Average Validation Loss: 0.3608
Validation Pixel Accuracy: 86.45%
Validation loss did not improve for 1 epoch(s).

--- Epoch 6/10 ---


Training: 100%|██████████| 998/998 [04:52<00:00,  3.41it/s]
Validating: 100%|██████████| 51/51 [00:14<00:00,  3.52it/s]


Average Training Loss: 0.3403
Average Validation Loss: 0.3628
Validation Pixel Accuracy: 85.98%
Validation loss did not improve for 2 epoch(s).

--- Epoch 7/10 ---


Training: 100%|██████████| 998/998 [05:00<00:00,  3.32it/s]
Validating: 100%|██████████| 51/51 [00:16<00:00,  3.07it/s]


Average Training Loss: 0.3305
Average Validation Loss: 0.3495
Validation Pixel Accuracy: 86.25%
Validation loss improved. Saving best model to metrics/dinov2_2025-08-03_17-48-54/best_model.pth

--- Epoch 8/10 ---


Training: 100%|██████████| 998/998 [05:06<00:00,  3.26it/s]
Validating: 100%|██████████| 51/51 [00:15<00:00,  3.39it/s]


Average Training Loss: 0.3208
Average Validation Loss: 0.3769
Validation Pixel Accuracy: 85.38%
Validation loss did not improve for 1 epoch(s).

--- Epoch 9/10 ---


Training: 100%|██████████| 998/998 [05:02<00:00,  3.30it/s]
Validating: 100%|██████████| 51/51 [00:15<00:00,  3.35it/s]


Average Training Loss: 0.3134
Average Validation Loss: 0.3252
Validation Pixel Accuracy: 87.32%
Validation loss improved. Saving best model to metrics/dinov2_2025-08-03_17-48-54/best_model.pth

--- Epoch 10/10 ---


Training: 100%|██████████| 998/998 [04:54<00:00,  3.39it/s]
Validating: 100%|██████████| 51/51 [00:14<00:00,  3.49it/s]


Average Training Loss: 0.3099
Average Validation Loss: 0.3456
Validation Pixel Accuracy: 87.37%
Validation loss did not improve for 1 epoch(s).

Training complete. Model saved to metrics/dinov2_2025-08-03_17-48-54/peatland_dinov2_model.pth
Training log saved to metrics/dinov2_2025-08-03_17-48-54/training_log.csv
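
Pixel accuracy, the metric logged above, can be dominated by the most common class. Per-class intersection-over-union is a common complement for segmentation. It is not computed anywhere in this notebook, so the following is only a sketch of how it could be added: `confusion_matrix` and `per_class_iou` are hypothetical helpers, shown on a tiny hand-made example with three classes.

```python
import numpy as np


def confusion_matrix(preds: np.ndarray, targets: np.ndarray, num_classes: int) -> np.ndarray:
    """Count pixels per (true class, predicted class) pair."""
    flat = targets.ravel().astype(np.int64) * num_classes + preds.ravel().astype(np.int64)
    return np.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def per_class_iou(cm: np.ndarray) -> np.ndarray:
    """IoU per class; NaN for classes absent from both predictions and targets."""
    true_pos = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - true_pos
    return np.where(union > 0, true_pos / np.maximum(union, 1), np.nan)


targets = np.array([[0, 0, 1], [1, 2, 2]])
preds = np.array([[0, 1, 1], [1, 2, 0]])

cm = confusion_matrix(preds, targets, num_classes=3)
iou = per_class_iou(cm)
pixel_accuracy = np.diag(cm).sum() / cm.sum()

print(np.round(iou, 3).tolist())         # [0.333, 0.667, 0.5]
print(round(float(np.nanmean(iou)), 3))  # 0.5
print(round(float(pixel_accuracy), 3))   # 0.667
```

In the validation loop, the confusion matrix could be accumulated across batches from `preds.cpu().numpy()` and `masks.cpu().numpy()`, with the IoU computed once per epoch.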
