In [3]:
"""
UNet implementation for pet segmentation
This module implements a UNet model that follows the architecture and hyperparameters
defined by nnU-Net for the pet segmentation task.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional, Type


class SpatialDropout2d(nn.Module):
    """
    Spatial dropout for 2D feature maps that drops entire channels.

    In training mode every (sample, channel) pair is zeroed independently with
    probability ``drop_prob`` and the surviving channels are rescaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob`` is 0, the input
    is returned untouched.

    Args:
        drop_prob: Probability of dropping a channel, in the closed range [0, 1].

    Raises:
        ValueError: If ``drop_prob`` is outside [0, 1], or if ``forward`` receives
            an input that is not 4-dimensional.
    """
    def __init__(self, drop_prob):
        super(SpatialDropout2d, self).__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(
                f"drop_prob must be between 0 and 1, got {drop_prob}"
            )
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply channel-wise dropout to a (batch, channels, height, width) tensor."""
        if not self.training or self.drop_prob == 0:
            return x

        if x.dim() != 4:
            raise ValueError(
                f"SpatialDropout2d expects a 4D input, got {x.dim()}D"
            )

        # Dropping everything: return zeros directly. Building the mask would
        # divide by (1 - drop_prob) == 0 and turn every activation into NaN.
        if self.drop_prob >= 1:
            return x * 0

        keep_prob = 1 - self.drop_prob

        # One Bernoulli draw per (sample, channel); the trailing singleton
        # dimensions broadcast the decision over the whole feature map.
        mask = x.new_empty(x.size(0), x.size(1), 1, 1).bernoulli_(keep_prob)
        mask.div_(keep_prob)

        return x * mask

class ConvBlock(nn.Module):
    """
    Stack of ``n_convs`` convolution units used by both encoder and decoder.

    Each unit is: Conv2d -> normalization (optional) -> non-linearity (optional)
    -> spatial dropout (only when the rate is positive) -> regular dropout
    (only when ``dropout_op`` is given). The units are stored, in that order,
    in a single ``nn.Sequential`` named ``block``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        n_convs: int = 2,
        padding: Optional[int] = None,
        norm_op: Type[nn.Module] = nn.InstanceNorm2d,
        norm_op_kwargs: Dict = None,
        dropout_op: Optional[Type[nn.Module]] = None,
        dropout_op_kwargs: Dict = None,
        nonlin: Type[nn.Module] = nn.LeakyReLU,
        nonlin_kwargs: Dict = None,
        conv_bias: bool = True,
        spatial_dropout_rate: float = 0.0
    ):
        """
        Build the sequential stack.

        Args:
            in_channels: Channels entering the first convolution
            out_channels: Channels produced by every convolution in the block
            kernel_size: Convolution kernel size (int or pair)
            stride: Stride of the first convolution; later ones use stride 1
            n_convs: How many convolution units to stack
            padding: Explicit padding; when None, half the kernel size is used
            norm_op: Normalization layer class, or None to skip normalization
            norm_op_kwargs: Keyword arguments for ``norm_op``
                (defaults to ``{'eps': 1e-5, 'affine': True}``)
            dropout_op: Dropout layer class, or None to skip regular dropout
            dropout_op_kwargs: Keyword arguments for ``dropout_op``
            nonlin: Activation layer class, or None to skip the activation
            nonlin_kwargs: Keyword arguments for ``nonlin``
                (defaults to ``{'inplace': True}``)
            conv_bias: Whether the convolutions carry a bias term
            spatial_dropout_rate: Channel-wise dropout rate (0 disables it)
        """
        super(ConvBlock, self).__init__()

        # Fill in the keyword-argument defaults.
        norm_op_kwargs = (
            {'eps': 1e-5, 'affine': True} if norm_op_kwargs is None else norm_op_kwargs
        )
        nonlin_kwargs = {'inplace': True} if nonlin_kwargs is None else nonlin_kwargs
        dropout_op_kwargs = {} if dropout_op_kwargs is None else dropout_op_kwargs

        # Half-kernel padding, per axis when the kernel size is given as a pair.
        if padding is None:
            padding = (
                kernel_size // 2
                if isinstance(kernel_size, int)
                else (kernel_size[0] // 2, kernel_size[1] // 2)
            )

        modules = []
        for conv_idx in range(n_convs):
            is_first = conv_idx == 0

            # Only the first convolution changes the channel count and may
            # downsample; the following ones keep shape.
            modules.append(
                nn.Conv2d(
                    in_channels if is_first else out_channels,
                    out_channels,
                    kernel_size,
                    stride if is_first else 1,
                    padding,
                    bias=conv_bias
                )
            )

            if norm_op is not None:
                modules.append(norm_op(out_channels, **norm_op_kwargs))

            if nonlin is not None:
                modules.append(nonlin(**nonlin_kwargs))

            if spatial_dropout_rate > 0:
                modules.append(SpatialDropout2d(spatial_dropout_rate))

            if dropout_op is not None:
                modules.append(dropout_op(**dropout_op_kwargs))

        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Run the input through the stacked units."""
        return self.block(x)

class UpBlock(nn.Module):
    """
    Decoder stage of the UNet.

    Resizes the incoming feature map to the skip connection's spatial size
    (bilinear interpolation, only when the sizes differ), concatenates both
    along the channel axis and runs the result through a ConvBlock.
    """

    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_convs: int = 2,
        norm_op: Type[nn.Module] = nn.InstanceNorm2d,
        norm_op_kwargs: Dict = None,
        dropout_op: Optional[Type[nn.Module]] = None,
        dropout_op_kwargs: Dict = None,
        nonlin: Type[nn.Module] = nn.LeakyReLU,
        nonlin_kwargs: Dict = None,
        conv_bias: bool = True,
        spatial_dropout_rate: float = 0.0
    ):
        """
        Build the decoder stage.

        Args:
            in_channels: Channels of the feature map coming from the deeper level
            skip_channels: Channels of the skip connection
            out_channels: Channels produced by this stage
            kernel_size: Convolution kernel size
            n_convs: Number of convolution units in the ConvBlock
            norm_op: Normalization layer class
            norm_op_kwargs: Keyword arguments for ``norm_op``
            dropout_op: Optional dropout layer class
            dropout_op_kwargs: Keyword arguments for ``dropout_op``
            nonlin: Activation layer class
            nonlin_kwargs: Keyword arguments for ``nonlin``
            conv_bias: Whether the convolutions carry a bias term
            spatial_dropout_rate: Channel-wise dropout rate (0 disables it)
        """
        super(UpBlock, self).__init__()

        # The convolutions see the resized input stacked with the skip features.
        merged_channels = in_channels + skip_channels

        self.conv_block = ConvBlock(
            merged_channels,
            out_channels,
            kernel_size,
            stride=1,
            n_convs=n_convs,
            padding=None,
            norm_op=norm_op,
            norm_op_kwargs=norm_op_kwargs,
            dropout_op=dropout_op,
            dropout_op_kwargs=dropout_op_kwargs,
            nonlin=nonlin,
            nonlin_kwargs=nonlin_kwargs,
            conv_bias=conv_bias,
            spatial_dropout_rate=spatial_dropout_rate
        )

    def forward(self, x, skip):
        """
        Merge ``x`` with ``skip`` and apply the convolution block.

        Args:
            x: Feature map from the deeper level
            skip: Feature map from the matching encoder stage

        Returns:
            Output feature map of this decoder stage
        """
        current_hw = (x.shape[2], x.shape[3])
        wanted_hw = (skip.shape[2], skip.shape[3])

        # Resize only when the spatial sizes actually differ.
        if current_hw != wanted_hw:
            x = F.interpolate(
                x,
                size=skip.shape[2:],
                mode='bilinear',
                align_corners=False
            )

        merged = torch.cat([x, skip], dim=1)
        return self.conv_block(merged)

class UNet(nn.Module):
    """
    UNet model for semantic segmentation with original architecture and added spatial dropout.

    The encoder is a stack of ``n_stages`` ConvBlocks; the output of every
    encoder stage except the last one is kept as a skip connection. The decoder
    is a stack of ``n_stages - 1`` UpBlocks applied deepest-first, followed by a
    1x1 convolution that maps the shallowest feature map to ``num_classes``
    output channels.
    """

    # Built-in per-stage defaults. They were written for the 8-stage network;
    # _pad_default() stretches them when more stages are requested.
    _DEFAULT_FEATURES = (32, 64, 128, 256, 512, 512, 512, 512)
    _DEFAULT_ENCODER_DROPOUT = (0.0, 0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3)
    _DEFAULT_DECODER_DROPOUT = (0.3, 0.2, 0.2, 0.1, 0.1, 0.0, 0.0)

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 3,
        n_stages: int = 8,  # Original depth
        features_per_stage: List[int] = None,
        kernel_sizes: List[Tuple[int, int]] = None,
        strides: List[Tuple[int, int]] = None,
        n_conv_per_stage: List[int] = None,
        n_conv_per_stage_decoder: List[int] = None,
        conv_bias: bool = True,
        norm_op: Type[nn.Module] = nn.InstanceNorm2d,
        norm_op_kwargs: Dict = None,
        dropout_op: Optional[Type[nn.Module]] = None,
        dropout_op_kwargs: Dict = None,
        nonlin: Type[nn.Module] = nn.LeakyReLU,
        nonlin_kwargs: Dict = None,
        encoder_dropout_rates: List[float] = None,
        decoder_dropout_rates: List[float] = None
    ):
        """
        Initialize the UNet model.

        Args:
            in_channels: Number of input channels (3 for RGB images)
            num_classes: Number of output classes (3 for background, cat, dog)
            n_stages: Number of stages in the encoder (at least 1)
            features_per_stage: Number of features per encoder stage
            kernel_sizes: Kernel sizes for each stage
            strides: Strides for each encoder stage
            n_conv_per_stage: Number of convolutions per encoder stage
            n_conv_per_stage_decoder: Number of convolutions per decoder stage,
                indexed by resolution level (entry ``j`` belongs to the decoder
                stage that consumes the skip connection of encoder stage ``j``)
            conv_bias: Whether to use bias in convolutions
            norm_op: Normalization operation to use
            norm_op_kwargs: Arguments for normalization operation
            dropout_op: Dropout operation to use (if any)
            dropout_op_kwargs: Arguments for dropout operation
            nonlin: Non-linear activation function to use
            nonlin_kwargs: Arguments for non-linear activation
            encoder_dropout_rates: Spatial dropout rate for each encoder stage
            decoder_dropout_rates: Spatial dropout rate for each decoder stage,
                indexed in execution order (entry 0 is the deepest decoder stage)

        Raises:
            ValueError: If ``n_stages`` is smaller than 1 or if any per-stage
                list is shorter than the number of stages it has to cover.
        """
        super(UNet, self).__init__()

        if n_stages < 1:
            raise ValueError(f"n_stages must be at least 1, got {n_stages}")
        n_decoder_stages = n_stages - 1

        # Set default values for parameters if not provided
        if features_per_stage is None:
            features_per_stage = self._pad_default(self._DEFAULT_FEATURES, n_stages)

        if kernel_sizes is None:
            kernel_sizes = [[3, 3]] * n_stages

        if strides is None:
            strides = [[1, 1]] + [[2, 2]] * n_decoder_stages

        if n_conv_per_stage is None:
            n_conv_per_stage = [2] * n_stages

        if n_conv_per_stage_decoder is None:
            n_conv_per_stage_decoder = [2] * n_decoder_stages

        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True}

        if nonlin_kwargs is None:
            nonlin_kwargs = {'inplace': True}

        # Default dropout rates if not provided
        if encoder_dropout_rates is None:
            # Light dropout in encoder, increasing with depth
            encoder_dropout_rates = self._pad_default(
                self._DEFAULT_ENCODER_DROPOUT, n_stages
            )

        if decoder_dropout_rates is None:
            # Dropout in decoder, decreasing toward output
            decoder_dropout_rates = self._pad_default(
                self._DEFAULT_DECODER_DROPOUT, n_decoder_stages
            )

        # Fail early with a readable message instead of an IndexError raised
        # half-way through building the stages. Longer lists are tolerated;
        # the surplus entries are simply not used.
        required_lengths = (
            ("features_per_stage", features_per_stage, n_stages),
            ("kernel_sizes", kernel_sizes, n_stages),
            ("strides", strides, n_stages),
            ("n_conv_per_stage", n_conv_per_stage, n_stages),
            ("encoder_dropout_rates", encoder_dropout_rates, n_stages),
            ("n_conv_per_stage_decoder", n_conv_per_stage_decoder, n_decoder_stages),
            ("decoder_dropout_rates", decoder_dropout_rates, n_decoder_stages),
        )
        for name, values, needed in required_lengths:
            if len(values) < needed:
                raise ValueError(
                    f"{name} needs at least {needed} entries for "
                    f"n_stages={n_stages}, got {len(values)}"
                )

        # Store parameters
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.n_stages = n_stages
        self.features_per_stage = features_per_stage

        # Create encoder stages
        self.encoder_stages = nn.ModuleList()

        current_channels = in_channels

        for stage in range(n_stages):
            self.encoder_stages.append(
                ConvBlock(
                    current_channels,
                    features_per_stage[stage],
                    kernel_sizes[stage],
                    strides[stage],
                    n_convs=n_conv_per_stage[stage],
                    norm_op=norm_op,
                    norm_op_kwargs=norm_op_kwargs,
                    dropout_op=dropout_op,
                    dropout_op_kwargs=dropout_op_kwargs,
                    nonlin=nonlin,
                    nonlin_kwargs=nonlin_kwargs,
                    conv_bias=conv_bias,
                    spatial_dropout_rate=encoder_dropout_rates[stage]
                )
            )

            # The next stage consumes what this one produced
            current_channels = features_per_stage[stage]

        # Create decoder stages
        self.decoder_stages = nn.ModuleList()

        for stage in range(n_decoder_stages):
            # Decoder stages are built deepest-first, so the resolution level
            # they operate on runs from n_stages - 2 down to 0.
            decoder_idx = n_stages - 2 - stage

            self.decoder_stages.append(
                UpBlock(
                    features_per_stage[decoder_idx + 1],
                    features_per_stage[decoder_idx],
                    features_per_stage[decoder_idx],
                    kernel_sizes[decoder_idx],
                    n_convs=n_conv_per_stage_decoder[decoder_idx],
                    norm_op=norm_op,
                    norm_op_kwargs=norm_op_kwargs,
                    dropout_op=dropout_op,
                    dropout_op_kwargs=dropout_op_kwargs,
                    nonlin=nonlin,
                    nonlin_kwargs=nonlin_kwargs,
                    conv_bias=conv_bias,
                    # Dropout rates follow construction order, not level index
                    spatial_dropout_rate=decoder_dropout_rates[stage]
                )
            )

        # Create final segmentation output layer
        self.segmentation_output = nn.Conv2d(
            features_per_stage[0],  # First decoder stage features
            num_classes,           # Number of output classes
            kernel_size=1,         # 1x1 convolution
            stride=1,
            padding=0,
            bias=True
        )

        # Initialize weights
        self.initialize_weights()

    @staticmethod
    def _pad_default(values, length):
        """Return ``values`` as a list, repeating its last entry up to ``length`` items.

        Sequences that are already long enough are returned in full, so
        configurations with fewer stages keep using the leading entries.
        """
        padded = list(values)
        if padded and len(padded) < length:
            padded.extend([padded[-1]] * (length - len(padded)))
        return padded

    def initialize_weights(self):
        """Initialize the weights of the network.

        Conv2d weights get Kaiming-normal initialization with zero bias;
        InstanceNorm2d affine parameters are set to weight 1 and bias 0.
        Other module types are left at their own defaults.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.InstanceNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Forward pass of the UNet model.

        Args:
            x: Input tensor of shape (batch_size, in_channels, height, width)

        Returns:
            output: Output of the final 1x1 convolution, with ``num_classes``
                channels
        """
        # Store skip connections
        skip_connections = []

        # Encoder path
        for stage in self.encoder_stages[:-1]:  # All but the last stage
            x = stage(x)
            skip_connections.append(x)

        # Bottom stage (without skip connection)
        x = self.encoder_stages[-1](x)

        # Decoder path
        for idx, decoder_stage in enumerate(self.decoder_stages):
            # Use the appropriate skip connection (in reverse order)
            skip_idx = len(skip_connections) - 1 - idx
            skip = skip_connections[skip_idx]

            # Decoder block
            x = decoder_stage(x, skip)

        # Final 1x1 convolution to produce segmentation map
        output = self.segmentation_output(x)

        return output

In [4]:
#!/usr/bin/env python
"""
Script: train.py

This script trains a UNet model for pet segmentation using the Oxford-IIIT Pet Dataset.
It handles the training loop, validation, checkpointing, and logging.

Example Usage:
    python src/train.py --data_dir data/processed --output_dir models/unet_pet_segmentation
"""

import argparse
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import cv2
from tqdm import tqdm

# Import local modules

class PetSegmentationDataset(Dataset):
    """Dataset class for the Oxford-IIIT Pet segmentation dataset.

    Pairs every ``*.jpg`` in ``images_dir`` with the ``.png`` of the same stem
    in ``masks_dir``. When requested and present, the sibling directory
    ``augmented/images`` (next to ``images_dir``) is appended, with its masks
    taken from ``augmented/masks``.
    """

    def __init__(
        self,
        images_dir: str,
        masks_dir: str,
        include_augmented: bool = True,
        target_size: Tuple[int, int] = (512, 512)
    ):
        """
        Initialize the dataset.

        Args:
            images_dir: Directory containing images
            masks_dir: Directory containing mask annotations
            include_augmented: Whether to include augmented images (if available)
            target_size: Target size for images and masks (height, width)
        """
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        # Stored as a tuple so it compares equal to ndarray.shape; a list would
        # never match and would force a resize of every sample.
        self.target_size = tuple(target_size)

        # Get all image files from the directory
        self.image_files = sorted(self.images_dir.glob("*.jpg"))

        # Check for augmented data directory
        aug_images_dir = self.images_dir.parent / "augmented" / "images"
        if include_augmented and aug_images_dir.exists():
            self.aug_image_files = sorted(aug_images_dir.glob("*.jpg"))
            self.aug_masks_dir = self.images_dir.parent / "augmented" / "masks"

            # Add augmented images to the dataset
            self.image_files.extend(self.aug_image_files)
        else:
            self.aug_image_files = []
            self.aug_masks_dir = None

        # Set of augmented paths for O(1) membership tests in __getitem__.
        self._aug_image_set = frozenset(self.aug_image_files)

        # Per-channel statistics used to standardize images; built once here
        # instead of on every sample.
        self._mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample to retrieve

        Returns:
            Dict with three tensors:
                "image": float image, channels first, scaled to 0-1 and then
                    standardized per channel
                "mask": long label map resized to ``target_size``; values other
                    than 0, 1, 2 and 255 are replaced by 0
                "original_dims": (height, width) of the mask before resizing

            If the image or mask cannot be loaded, the error is printed and an
            all-zero image/mask pair of ``target_size`` is returned instead.
        """
        # Get image file path
        img_path = self.image_files[idx]

        # Augmented images keep their masks in the augmented mask directory
        is_augmented = img_path in self._aug_image_set
        if is_augmented and self.aug_masks_dir is not None:
            mask_path = self.aug_masks_dir / f"{img_path.stem}.png"
        else:
            mask_path = self.masks_dir / f"{img_path.stem}.png"

        # Load image and mask
        try:
            image = cv2.imread(str(img_path))
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise ValueError(f"Failed to load mask: {mask_path}")

            # Store original dimensions before resizing
            original_dims = mask.shape[:2]  # (height, width)
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort a run.
            print(f"Error loading image or mask: {e}")
            # Return a blank sample as fallback
            image = np.zeros((self.target_size[0], self.target_size[1], 3), dtype=np.uint8)
            mask = np.zeros(self.target_size, dtype=np.uint8)
            original_dims = self.target_size

        # cv2.resize takes (width, height), hence the swapped order below
        cv_size = (self.target_size[1], self.target_size[0])

        # Ensure image and mask have the correct dimensions
        if image.shape[:2] != self.target_size:
            image = cv2.resize(image, cv_size, interpolation=cv2.INTER_LINEAR)

        if mask.shape != self.target_size:
            # Nearest neighbour keeps the labels discrete
            mask = cv2.resize(mask, cv_size, interpolation=cv2.INTER_NEAREST)

        # Keep 255 as is - we'll handle it properly in the loss function
        # Just ensure other values are within valid range (0, 1, 2)
        mask = np.where((mask > 2) & (mask != 255), 0, mask)

        # Convert image to tensor and normalize (0-1)
        image = torch.from_numpy(image).float().permute(2, 0, 1) / 255.0

        # Apply standardization (approximately equivalent to ImageNet normalization)
        image = (image - self._mean) / self._std

        # Convert mask to tensor
        mask = torch.from_numpy(mask).long()

        # Store original dimensions as tensor
        original_dims = torch.tensor(original_dims)

        return {
            "image": image,
            "mask": mask,
            "original_dims": original_dims
        }



In [5]:
# 📦 Imports
import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 🧠 Load best model
# Resolve the evaluation device once and reuse it for loading and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture of the 6-stage ("reduced") network; it has to match the
# checkpoint loaded below.
model = UNet(
    in_channels=3,
    num_classes=3,
    n_stages=6,
    features_per_stage=[32, 64, 128, 256, 512, 512],
    kernel_sizes=[[3, 3] for _ in range(6)],
    strides=[[1, 1]] + [[2, 2] for _ in range(5)],
    n_conv_per_stage=[2, 2, 2, 2, 2, 2],
    n_conv_per_stage_decoder=[2, 2, 2, 2, 2],
    conv_bias=True,
    norm_op=torch.nn.InstanceNorm2d,
    norm_op_kwargs={"eps": 1e-5, "affine": True},
    dropout_op=None,
    nonlin=torch.nn.LeakyReLU,
    nonlin_kwargs={"inplace": True},
    encoder_dropout_rates=[0.0, 0.0, 0.1, 0.2, 0.3, 0.3],
    decoder_dropout_rates=[0.3, 0.2, 0.2, 0.1, 0.0],
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(
        "/home/ulixes/segmentation_cv/unet/models/unet_pet_segmentation_reduced/best_model.pth",
        map_location=device,
    )["model_state_dict"]
)

model = model.to(device)
# Switch to inference mode (disables the dropout layers).
model.eval()


UNet(
  (encoder_stages): ModuleList(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (5): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (1): ConvBlock(
      (block): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, trac

In [6]:
from torch.utils.data import DataLoader

# Test split only: no augmented samples, images and masks resized to 512x512.
test_dataset = PetSegmentationDataset(
    "/home/ulixes/segmentation_cv/unet/data/processed/Test/resized",
    "/home/ulixes/segmentation_cv/unet/data/processed/Test/processed_labels",
    include_augmented=False,
    target_size=(512, 512),
)

# Fixed order (no shuffling) so that evaluation runs are reproducible.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)


In [7]:
import numpy as np

class SegmentationMetrics:
    """
    Accumulates segmentation statistics over many prediction/target pairs.

    Counts are summed across every call to ``update`` and the metrics (Dice,
    IoU, pixel accuracy) are computed from the totals, i.e. they are
    dataset-level values rather than averages of per-image scores. Pixels whose
    target equals ``ignore_index`` are excluded from all counts.
    """

    def __init__(self, num_classes, ignore_index=255):
        """
        Args:
            num_classes: Number of classes; labels 0 .. num_classes - 1 are tracked
            ignore_index: Target label that marks pixels to leave out
        """
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        # Per-class accumulators
        self.intersections = np.zeros(self.num_classes)
        self.unions = np.zeros(self.num_classes)
        self.true_positives = np.zeros(self.num_classes)
        self.false_positives = np.zeros(self.num_classes)
        self.false_negatives = np.zeros(self.num_classes)
        # Class-independent pixel counts (plain Python ints)
        self.total_pixels = 0
        self.correct_pixels = 0

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a numpy array; tensors are detached and moved to CPU."""
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    def update(self, pred, target):
        """
        Update accumulators with a new batch of predictions and targets

        Args:
            pred: prediction tensor/array of class labels
            target: ground truth tensor/array of class labels, same shape as pred

        Raises:
            ValueError: If ``pred`` and ``target`` have different shapes.
        """
        pred = self._to_numpy(pred)
        target = self._to_numpy(target)

        # Without this check numpy would broadcast the two arrays and silently
        # produce meaningless counts.
        if pred.shape != target.shape:
            raise ValueError(
                f"pred and target must have the same shape, "
                f"got {pred.shape} and {target.shape}"
            )

        valid = target != self.ignore_index
        self.total_pixels += int(valid.sum())
        self.correct_pixels += int(((pred == target) & valid).sum())

        # Update class-wise metrics
        for cls in range(self.num_classes):
            pred_cls = (pred == cls) & valid
            target_cls = (target == cls) & valid

            intersection = int((pred_cls & target_cls).sum())
            n_pred = int(pred_cls.sum())
            n_target = int(target_cls.sum())

            self.intersections[cls] += intersection
            # |A or B| = |A| + |B| - |A and B|
            self.unions[cls] += n_pred + n_target - intersection

            # For dice coefficient
            self.true_positives[cls] += intersection
            self.false_positives[cls] += n_pred - intersection
            self.false_negatives[cls] += n_target - intersection

    def compute_dice(self, cls):
        """
        Compute Dice coefficient for a specific class using accumulated statistics

        Args:
            cls: class index

        Returns:
            dice: 2*TP / (2*TP + FP + FN), or NaN when the class never appeared
                in either predictions or targets
        """
        numerator = 2 * self.true_positives[cls]
        denominator = numerator + self.false_positives[cls] + self.false_negatives[cls]

        if denominator > 0:
            return float(numerator / denominator)
        return float('nan')

    def compute_pixel_accuracy(self):
        """
        Compute overall pixel accuracy using accumulated statistics

        Returns:
            accuracy: correct / total over all non-ignored pixels, or NaN when
                no pixel has been counted yet
        """
        if self.total_pixels > 0:
            return float(self.correct_pixels / self.total_pixels)
        return float('nan')

    def compute_iou(self, cls):
        """
        Compute IoU for a specific class using accumulated statistics

        Args:
            cls: class index

        Returns:
            iou: intersection / union, or NaN when the union is empty
        """
        if self.unions[cls] > 0:
            return float(self.intersections[cls] / self.unions[cls])
        return float('nan')

    def compute_mean_iou(self):
        """
        Compute mean IoU across all classes

        Returns:
            mean_iou: Mean of the per-class IoUs that are not NaN, or NaN when
                every class is undefined
        """
        valid_ious = [
            iou
            for iou in (self.compute_iou(cls) for cls in range(self.num_classes))
            if not np.isnan(iou)
        ]

        if valid_ious:
            return sum(valid_ious) / len(valid_ious)
        return float('nan')

# Functions that maintain the same API as original but use the accumulation approach
def compute_dice(pred, target, cls, ignore_index=255):
    """
    Dice coefficient of class ``cls`` for one prediction-target pair.

    Convenience wrapper around a throw-away SegmentationMetrics accumulator;
    use the class directly for dataset-level metrics.
    """
    # Tracking classes 0..cls is enough to read out the statistics of ``cls``.
    single_pair = SegmentationMetrics(cls + 1, ignore_index)
    single_pair.update(pred, target)
    return single_pair.compute_dice(cls)

def compute_pixel_accuracy(pred, target, ignore_index=255):
    """
    Pixel accuracy for a single prediction-target pair.

    This function should be used for a single prediction-target pair.
    For dataset-level metrics, use SegmentationMetrics class.

    Args:
        pred: prediction tensor/array of class labels
        target: ground truth tensor/array of class labels
        ignore_index: target label whose pixels are excluded

    Returns:
        Fraction of non-ignored pixels where pred equals target, or NaN when
        there is no such pixel.
    """
    # Pixel accuracy does not depend on the number of classes: the total and
    # correct counts are accumulated independently of the per-class loop.
    # Deriving the class count from the data would count the ignore label
    # (255) as a class and run hundreds of pointless per-class passes, and
    # would fail outright on empty inputs.
    metrics = SegmentationMetrics(num_classes=1, ignore_index=ignore_index)
    metrics.update(pred, target)
    return metrics.compute_pixel_accuracy()

def compute_iou(pred, target, cls, ignore_index=255):
    """
    Intersection-over-union of class ``cls`` for one prediction-target pair.

    Convenience wrapper around a throw-away SegmentationMetrics accumulator;
    use the class directly for dataset-level metrics.
    """
    # Tracking classes 0..cls is enough to read out the statistics of ``cls``.
    single_pair = SegmentationMetrics(cls + 1, ignore_index)
    single_pair.update(pred, target)
    return single_pair.compute_iou(cls)

# Example usage:
# For dataset-level metrics:
# metrics = SegmentationMetrics(num_classes=3)
# for pred, target in dataset:
#     metrics.update(pred, target)
# 
# # Get metrics after processing all images
# mean_iou = metrics.compute_mean_iou()
# class_0_iou = metrics.compute_iou(0)
# pixel_accuracy = metrics.compute_pixel_accuracy()

In [8]:
# Sanity check: number of samples the test dataset picked up from disk.
print(f"Test samples found: {len(test_dataset)}")


Test samples found: 3694


In [9]:

def _resize_nearest(label_map, size):
    """Resize a 2D label map to ``size`` (height, width) with nearest-neighbour sampling."""
    # F.interpolate expects (batch, channel, H, W) float input, hence the
    # temporary leading dimensions and the round trip through float.
    resized = F.interpolate(label_map[None, None].float(), size=size, mode="nearest")
    return resized.squeeze().long()


# Dataset-level accumulator; label 255 is excluded from every metric.
metrics = SegmentationMetrics(num_classes=3, ignore_index=255)

with torch.no_grad():
    for batch in tqdm(test_loader):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)
        original_dims = batch["original_dims"]

        outputs = model(images)
        # Most likely class per pixel
        preds = torch.argmax(outputs, dim=1)

        # Score every sample at the resolution its mask had before resizing.
        for pred, mask, dims in zip(preds, masks, original_dims):
            orig_h, orig_w = dims.tolist()

            pred_resized = _resize_nearest(pred, (orig_h, orig_w))
            mask_resized = _resize_nearest(mask, (orig_h, orig_w))

            metrics.update(pred_resized.cpu().numpy(), mask_resized.cpu().numpy())


100%|██████████| 924/924 [00:27<00:00, 33.14it/s]


In [10]:
# Read every metric once from the accumulator, then print the report.
pixel_acc = metrics.compute_pixel_accuracy()
dice_bg, dice_cat, dice_dog = (metrics.compute_dice(cls) for cls in range(3))
iou_bg, iou_cat, iou_dog = (metrics.compute_iou(cls) for cls in range(3))
# Foreground = cat and dog; nanmean skips a class whose Dice is undefined.
fg_dice = np.nanmean([dice_cat, dice_dog])
mean_iou = metrics.compute_mean_iou()

print(f"Pixel Accuracy:       {pixel_acc:.4f}")
print(f"Dice (Background):    {dice_bg:.4f}")
print(f"Dice (Cat):           {dice_cat:.4f}")
print(f"Dice (Dog):           {dice_dog:.4f}")
print(f"Mean Foreground Dice: {fg_dice:.4f}")
print(f"IoU (Background):     {iou_bg:.4f}")
print(f"IoU (Cat):            {iou_cat:.4f}")
print(f"IoU (Dog):            {iou_dog:.4f}")
print(f"Mean IoU:             {mean_iou:.4f}")


Pixel Accuracy:       0.8757
Dice (Background):    0.9260
Dice (Cat):           0.7207
Dice (Dog):           0.7816
Mean Foreground Dice: 0.7511
IoU (Background):     0.8622
IoU (Cat):            0.5633
IoU (Dog):            0.6414
Mean IoU:             0.6890
