In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [6]:
class ResNetBackbone(nn.Module):
    """Dilated torchvision ResNet used as the DeepLabV3 feature extractor.

    The strided first block of ``layer4`` (and of ``layer3`` when
    ``output_stride == 8``) is switched to stride 1 and the 3x3 convolutions
    of those layers are dilated instead, so the final feature map is 1/16 or
    1/8 of the input resolution rather than 1/32.

    Args:
        name: ``'resnet101'`` or ``'resnet50'``.
        pretrained: Forwarded to the torchvision model constructor.
        output_stride: Ratio of input size to final feature-map size; 8 or 16.
        freeze_bn: When true, every ``BatchNorm2d`` inside the backbone is
            kept in eval mode, including after later calls to ``train()``.

    Raises:
        ValueError: If ``name`` or ``output_stride`` is not supported.
    """

    def __init__(self, name='resnet101', pretrained=True, output_stride=8, freeze_bn=True):
        super(ResNetBackbone, self).__init__()

        # Validate before building the network so a bad argument does not
        # first pay for constructing (and possibly downloading) the model.
        if name not in ('resnet101', 'resnet50'):
            raise ValueError("Unsupported ResNet backbone name.")
        if output_stride not in (8, 16):
            raise ValueError("Unsupported output_stride. Must be 8 or 16.")

        # Load the specified ResNet model from torchvision.
        self.resnet = getattr(models, name)(pretrained=pretrained)

        self.output_stride = output_stride
        self.freeze_bn = freeze_bn

        if self.output_stride == 16:
            # 32 -> 16: only layer4 gives up its stride, compensated by rate 2.
            self._remove_stride_and_dilate(self.resnet.layer4, dilation=2)
        else:
            # 32 -> 8: layer3 and layer4 both give up their stride. layer4
            # uses rate 4 because it now sits behind two removed strides.
            # This setting is the more memory-intensive of the two.
            self._remove_stride_and_dilate(self.resnet.layer3, dilation=2)
            self._remove_stride_and_dilate(self.resnet.layer4, dilation=4)

        if self.freeze_bn:
            self._freeze_bn()

    @staticmethod
    def _remove_stride_and_dilate(layer, dilation):
        """Turn a strided ResNet stage into a stride-1 dilated stage in place."""
        # The Bottleneck block applies its stride in conv2, and the matching
        # stride lives in the first module of the 'downsample' branch.
        first_block = layer[0]
        first_block.conv2.stride = (1, 1)
        first_block.downsample[0].stride = (1, 1)

        for block in layer:
            # conv2 is the only 3x3 convolution in a Bottleneck; conv1 and
            # conv3 are 1x1, so dilation has no effect on them.
            block.conv2.dilation = (dilation, dilation)
            # For a 3x3 kernel, padding == dilation keeps the spatial size.
            block.conv2.padding = (dilation, dilation)

    def _freeze_bn(self):
        """Put every BatchNorm2d of the backbone into eval mode.

        In eval mode a BatchNorm layer normalizes with its stored running
        mean/variance and stops updating them. This does NOT change
        ``requires_grad``: the affine weight and bias stay trainable unless
        the caller freezes them separately.
        """
        for m in self.resnet.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen BatchNorm layers in eval.

        ``nn.Module.train`` recursively flips every submodule, which would
        undo ``_freeze_bn``; re-applying it here keeps the freeze in effect.
        """
        super(ResNetBackbone, self).train(mode)
        if self.freeze_bn:
            self._freeze_bn()
        return self

    def forward(self, x):
        """Return ``(high_level_features, low_level_features)``.

        ``high_level_features`` is the output of ``layer4`` and
        ``low_level_features`` is the output of ``layer2``.
        """
        # Stem
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)  # Output stride 4
        x = self.resnet.layer2(x)  # Output stride 8 (never modified)
        low_level_features = x  # Potentially useful for a DeepLabV3+ decoder

        x = self.resnet.layer3(x)  # Output stride 16, or 8 when output_stride == 8
        x = self.resnet.layer4(x)  # Output stride equals self.output_stride

        return x, low_level_features

In [None]:
class _ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super(_ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        # Explanation of parameters:
        # in_channels: Number of input feature channels from the previous layer (e.g., ResNet's layer4 output channels).
        # out_channels: Number of output feature channels for this specific ASPP branch.
        # kernel_size=3: Standard 3x3 kernel size for feature extraction.
        # padding=dilation: Crucial for maintaining the output spatial dimensions when stride=1.
        #                   For a 3x3 kernel with dilation 'd', padding 'd' ensures output_size = input_size.
        # dilation=dilation: This is the atrous rate. It controls how sparsely the input is sampled.
        # bias=False: Typically set to False when BatchNorm2d is used immediately after, as BatchNorm
        #             introduces its own learnable bias (beta), making the Conv2d's bias redundant.
        # nn.BatchNorm2d(out_channels): Applies Batch Normalization.
        # nn.ReLU(): Applies Rectified Linear Unit activation.

In [None]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches (one 1x1 convolution, one
    dilated 3x3 convolution per entry of ``atrous_rates``, and one global
    average pooling branch), concatenates them along the channel axis and
    fuses the result with a final 1x1 convolution followed by dropout.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels_per_branch: Channels produced by each branch and by the
            fused output.
        atrous_rates: List or tuple of dilation rates, one atrous branch each.

    Raises:
        TypeError: If ``atrous_rates`` is not a list or tuple.
    """

    # The default is a tuple rather than a list so the default object cannot
    # be mutated and shared between instances.
    def __init__(self, in_channels, out_channels_per_branch=256, atrous_rates=(6, 12, 18)):
        super(ASPP, self).__init__()

        if not isinstance(atrous_rates, (list, tuple)):
            raise TypeError("atrous_rates must be a list or tuple of integers.")

        # 1. 1x1 convolution branch: point-wise features, reduces channels.
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels_per_branch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels_per_branch),
            nn.ReLU()
        )

        # 2. Atrous branches. ModuleList registers them as submodules so
        #    their parameters are visible to optimizers and state_dict.
        self.atrous_convs = nn.ModuleList(
            _ASPPConv(in_channels, out_channels_per_branch, dilation=rate)
            for rate in atrous_rates
        )

        # 3. Image-level branch: global average pooling to 1x1, then 1x1 conv.
        # NOTE(review): the BatchNorm here sees a single value per channel per
        # sample, so training with batch size 1 is expected to fail in
        # BatchNorm - confirm against the training configuration.
        self.image_pooling = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels_per_branch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels_per_branch),
            nn.ReLU()
        )

        # 4. Fusion: the concatenation has one slice per atrous branch plus
        #    the 1x1 branch and the pooling branch, e.g. 256 * (3 + 2) = 1280.
        num_branches = len(atrous_rates) + 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(out_channels_per_branch * num_branches, out_channels_per_branch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels_per_branch),
            nn.ReLU(),
            nn.Dropout(0.5)  # Regularization on the fused features
        )

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` to ``(N, out_channels_per_branch, H, W)``."""
        # Spatial size (H, W), needed to expand the pooled branch back.
        input_size = x.shape[2:]

        branch_outputs = [self.conv1x1(x)]
        branch_outputs.extend(atrous_conv(x) for atrous_conv in self.atrous_convs)

        # The pooled branch is (N, C, 1, 1); bilinear resizing from a single
        # pixel broadcasts that global context to every spatial location.
        pooled = self.image_pooling(x)
        branch_outputs.append(
            F.interpolate(pooled, size=input_size, mode='bilinear', align_corners=True)
        )

        # Concatenate along channels (dim=1), then fuse with the 1x1 conv.
        fused_input = torch.cat(branch_outputs, dim=1)
        return self.final_conv(fused_input)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchvision import models # Already imported in Step 1
# Assuming ResNetBackbone and ASPP classes from Step 1 and 2 are defined above this.

class DeepLabV3(nn.Module):
    """DeepLabV3 semantic segmentation network.

    Pipeline: dilated ResNet backbone -> ASPP -> 1x1 classifier -> bilinear
    upsampling of the logits back to the input resolution.

    Args:
        num_classes: Number of output channels of the classifier.
        backbone_name: Forwarded to ``ResNetBackbone`` as ``name``.
        pretrained_backbone: Forwarded to ``ResNetBackbone`` as ``pretrained``.
        output_stride: 8 or 16; selects the backbone dilation and ASPP rates.
        freeze_bn_backbone: Forwarded to ``ResNetBackbone`` as ``freeze_bn``.

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, num_classes, backbone_name='resnet101', pretrained_backbone=True, output_stride=16, freeze_bn_backbone=True):
        super(DeepLabV3, self).__init__()

        self.num_classes = num_classes
        self.output_stride = output_stride

        # 1. Backbone: returns (layer4 features, layer2 features).
        self.backbone = ResNetBackbone(name=backbone_name,
                                       pretrained=pretrained_backbone,
                                       output_stride=output_stride,
                                       freeze_bn=freeze_bn_backbone)

        # layer4 of the Bottleneck ResNets outputs 2048 channels; stride and
        # dilation changes affect resolution only, not the channel count.
        aspp_in_channels = 2048

        # ASPP rates are doubled at output_stride 8 because the feature map
        # is twice as dense as at output_stride 16.
        if output_stride == 16:
            aspp_atrous_rates = [6, 12, 18]
        elif output_stride == 8:
            aspp_atrous_rates = [12, 24, 36]
        else:
            raise ValueError("ASPP atrous rates are only defined for output_stride 8 or 16.")

        # 2. ASPP on the high-level features, producing 256 channels.
        self.aspp = ASPP(in_channels=aspp_in_channels,
                         out_channels_per_branch=256,
                         atrous_rates=aspp_atrous_rates)

        # 3. Classification head: 1x1 conv from 256 channels to num_classes.
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

        # Initialize only the newly added layers (ASPP and classifier).
        self._init_weights()

    def _init_weights(self):
        """Initialize the ASPP and classifier layers, leaving the backbone alone.

        Iterating over ``self.modules()`` would also reach the backbone and
        overwrite its (possibly pretrained) weights, so only the heads are
        visited. Convolutions get Kaiming-normal weights (suited to ReLU) and
        BatchNorm layers get weight 1 / bias 0.
        """
        for head in (self.aspp, self.classifier):
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return per-pixel class logits of shape ``(N, num_classes, H, W)``."""
        # Input spatial size (H, W), needed for the final upsampling.
        input_size = x.size()[2:]

        # 1. Backbone. The second output (layer2 features, stride 8) is not
        #    used by plain DeepLabV3; it is only kept on the backbone's
        #    interface for a possible DeepLabV3+ decoder.
        backbone_output, _ = self.backbone(x)

        # 2. ASPP: (N, 256, H/output_stride, W/output_stride)
        aspp_output = self.aspp(backbone_output)

        # 3. Classifier: (N, num_classes, H/output_stride, W/output_stride)
        logits = self.classifier(aspp_output)

        # 4. Upsample logits to the input resolution for per-pixel prediction.
        output = F.interpolate(logits, size=input_size, mode='bilinear', align_corners=True)

        return output

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image # For loading images
import os # For path manipulation
import numpy as np # For potential mask operations (e.g., handling ignore_index)

In [None]:
class SemanticSegmentationDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Every ``.jpg`` or ``.png`` file in the image directory is paired with a
    mask of the same base name and a ``.png`` extension in the mask directory.
    """

    def __init__(self, root_dir, image_dir, mask_dir, transform=None, ignore_index=255):
        """
        Args:
            root_dir (str): Base directory for the dataset (e.g., 'path/to/voc2012').
            image_dir (str): Subdirectory containing images (e.g., 'JPEGImages').
            mask_dir (str): Subdirectory containing segmentation masks (e.g., 'SegmentationClassAug').
            transform (callable, optional): Called as ``transform(image, mask)`` and
                must return the transformed ``(image, mask)`` pair.
            ignore_index (int): Label value to ignore in loss calculation, typically 255.
                Stored for callers; mask values are not remapped by this class.
        """
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, image_dir)
        self.mask_dir = os.path.join(root_dir, mask_dir)
        self.transform = transform
        self.ignore_index = ignore_index

        # Example pairing: image_dir/2007_000033.jpg <-> mask_dir/2007_000033.png
        self.image_filenames = sorted(
            f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))
        )
        # Derive each mask name from its image, in the same order. Sorting the
        # mask list separately could break the index alignment, and only the
        # final extension is swapped (not every '.jpg' occurrence in the name).
        self.mask_filenames = [
            os.path.splitext(f)[0] + '.png' for f in self.image_filenames
        ]

        # Best-effort sanity check: report images without a mask up front
        # instead of failing later inside a DataLoader worker.
        missing = [
            fn for fn in self.mask_filenames
            if not os.path.exists(os.path.join(self.mask_dir, fn))
        ]
        if missing:
            print(f"Warning: {len(missing)} mask file(s) not found in {self.mask_dir} "
                  f"(first missing: {missing[0]}). Ensure proper filename logic.")

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """
        Loads and returns one sample (image, mask) from the dataset.
        Args:
            idx (int): Index of the sample to load.
        Returns:
            tuple: (image, mask). ``mask`` is a ``torch.int64`` tensor. ``image``
            is whatever ``transform`` returns, or an RGB PIL image when no
            transform is set.
        """
        img_name = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_name = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # 'with' closes the file handle; convert()/copy() return a fully
        # loaded image that stays valid after the file is closed.
        with Image.open(img_name) as img_file:
            image = img_file.convert('RGB')

        with Image.open(mask_name) as mask_file:
            if mask_file.mode == 'P':
                # Palette PNGs (e.g. Pascal VOC SegmentationClass) store the
                # class ID as the palette index. convert('L') would replace it
                # with the luminance of the palette colour, so keep it as-is.
                mask = mask_file.copy()
            else:
                # Single-channel masks already hold class IDs.
                # NOTE(review): colour-coded RGB masks need an explicit
                # colour-to-class mapping; 'L' only yields grayscale values.
                mask = mask_file.convert('L')

        if self.transform:
            # The transform must apply identical geometry to image and mask.
            image, mask = self.transform(image, mask)

        # nn.CrossEntropyLoss expects int64 targets. np.int64 is used because
        # the np.long alias was removed from NumPy.
        mask = torch.from_numpy(np.array(mask, dtype=np.int64))

        return image, mask