### DeepLab-MSc

In [2]:
import torch # The fundamental PyTorch library for tensor operations and neural networks.
import torch.nn as nn # Contains modules for building neural networks (e.g., Conv2d, ReLU, MaxPool2d).
import torch.nn.functional as F # Provides functions that don't have trainable parameters (e.g., interpolation).
import torchvision.models as models # Provides access to pre-trained models like VGG.

In [3]:
class MLPBranch(nn.Module):
    """Two-layer convolutional "MLP" attached to an intermediate feature map.

    The branch applies a 3x3 convolution (padding 1, so height and width are
    preserved) followed by a 1x1 convolution, each followed by a ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions
            (defaults to 128).
    """

    def __init__(self, in_channels, out_channels=128):
        super().__init__()

        # Layer 1: spatial 3x3 convolution; padding=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1
        )
        self.relu1 = nn.ReLU(inplace=True)

        # Layer 2: 1x1 convolution, i.e. a per-pixel linear mix of channels.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through conv1 -> ReLU -> conv2 -> ReLU and return it."""
        hidden = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(hidden))

In [4]:
class DeepLabMSc(nn.Module):
    """DeepLab-style VGG-16 segmentation network with multi-scale branches.

    The VGG-16 convolutional trunk is modified so that its last two max-pool
    layers no longer subsample, the conv5 block uses dilation 2, and the
    former fully connected layers fc6/fc7 become convolutions (fc6 with
    dilation 4).  Feature maps tapped after pool1..pool4 are each processed by
    an ``MLPBranch``, resized to the resolution of the main path, concatenated
    with the main-path features and classified by a 1x1 convolution.  The
    logits are finally resized bilinearly to the input resolution.

    Args:
        num_classes: Number of output classes (channels of the returned
            logits).  Defaults to 21.
        pretrained: If True (default), initialise the VGG-16 trunk and the
            converted fc6/fc7 layers from ImageNet weights provided by
            torchvision.  If False, no weights are downloaded.
    """

    def __init__(self, num_classes=21, pretrained=True):
        super(DeepLabMSc, self).__init__()

        # --- VGG-16 backbone -------------------------------------------------
        # Newer torchvision releases select weights through the `weights`
        # argument and deprecate `pretrained`; fall back to the legacy
        # keyword when the weights enum is not available.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            vgg16 = models.vgg16(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            vgg16 = models.vgg16(pretrained=pretrained)

        # Convolutional part of VGG-16 (conv / ReLU / max-pool stack).
        self.features = vgg16.features

        # Indices of the MaxPool2d layers inside `vgg16.features`.
        self.pool1_idx = 4   # after conv1_2
        self.pool2_idx = 9   # after conv2_2
        self.pool3_idx = 16  # after conv3_3
        self.pool4_idx = 23  # after conv4_3
        self.pool5_idx = 30  # after conv5_3

        # Layer index -> key under which forward() keeps that layer's output
        # for the multi-scale branches.
        self._tap_names = {
            self.pool1_idx: 'pool1',
            self.pool2_idx: 'pool2',
            self.pool3_idx: 'pool3',
            self.pool4_idx: 'pool4',
        }

        # Remove the subsampling of pool4 and pool5.  Merely setting stride=1
        # on the stock 2x2 pool would shrink each map by one pixel per layer
        # and break the concatenation with the pool3 branch, so the pools are
        # replaced by 3x3 / stride 1 / padding 1 pools, which keep the spatial
        # size exactly.  Max-pool layers have no parameters, so the state_dict
        # layout is unaffected.
        for idx in (self.pool4_idx, self.pool5_idx):
            self.features[idx] = nn.MaxPool2d(
                kernel_size=3, stride=1, padding=1
            )

        # Dilate conv5_1, conv5_2, conv5_3 by 2.  For a 3x3 kernel the padding
        # that preserves the spatial size is dilation * (3 - 1) // 2 = 2.
        for i in [24, 26, 28]:
            if isinstance(self.features[i], nn.Conv2d):
                self.features[i].dilation = (2, 2)
                self.features[i].padding = (2, 2)

        # --- Fully connected layers converted to convolutions ----------------
        # nn.Sequential has no clone(); the Linear layers are read directly
        # and their values are copied into the new convolutions below.
        classifier = vgg16.classifier
        fc6, fc7, fc8 = classifier[0], classifier[3], classifier[6]

        # fc6 -> conv6: 7x7 kernel over 512 channels, dilation 4.
        # Size-preserving padding = dilation * (7 - 1) // 2 = 12.
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=7, padding=12, dilation=4)
        # fc7 -> conv7: 1x1 convolution.
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        # fc8 -> per-pixel classifier on the main path.
        # NOTE: this layer is not used by forward(); it is kept so that the
        # module's attributes and state_dict keys stay unchanged.
        self.final_conv_main = nn.Conv2d(4096, num_classes, kernel_size=1)

        with torch.no_grad():
            # A Linear weight of shape (4096, 512*7*7) is exactly a conv
            # weight of shape (4096, 512, 7, 7) after reshaping; no rescaling
            # is applied.
            self.conv6.weight.copy_(fc6.weight.view(4096, 512, 7, 7))
            self.conv6.bias.copy_(fc6.bias)
            self.conv7.weight.copy_(fc7.weight.view(4096, 4096, 1, 1))
            self.conv7.bias.copy_(fc7.bias)
            # VGG's fc8 predicts 1000 ImageNet classes; its weights can only
            # be reused when the class count matches.  Otherwise the layer
            # keeps its default initialisation.
            if fc8.out_features == num_classes:
                self.final_conv_main.weight.copy_(
                    fc8.weight.view(num_classes, 4096, 1, 1)
                )
                self.final_conv_main.bias.copy_(fc8.bias)

        # ReLU activations following conv6 and conv7.
        self.relu6 = nn.ReLU(inplace=True)
        self.relu7 = nn.ReLU(inplace=True)

        # --- Multi-scale branches --------------------------------------------
        # in_channels equal the channel count of the tapped pooling output.
        self.branch1 = MLPBranch(in_channels=64)   # pool1
        self.branch2 = MLPBranch(in_channels=128)  # pool2
        self.branch3 = MLPBranch(in_channels=256)  # pool3
        self.branch4 = MLPBranch(in_channels=512)  # pool4

        # --- Classifier over the concatenated features -----------------------
        # Input channels: 4096 (conv7) + 4 * 128 (branches) = 4608.
        # NOTE(review): the paper speaks of 5 * 128 = 640 added channels,
        # presumably from a fifth branch that this implementation does not
        # have -- confirm against the paper if exact parity is required.
        self.final_classifier_msc = nn.Conv2d(
            4096 + (4 * 128), num_classes, kernel_size=1
        )

    def forward(self, x):
        """Compute per-pixel class logits.

        Args:
            x: Input batch of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, num_classes, H, W) holding the logits
            bilinearly resized to the input resolution.
        """
        # Remember the input size for the final resize.
        input_H, input_W = x.shape[2], x.shape[3]

        # Run the VGG trunk, keeping the outputs of pool1..pool4.
        pool_outputs = {}
        for i, layer in enumerate(self.features):
            x = layer(x)
            tap_name = self._tap_names.get(i)
            if tap_name is not None:
                pool_outputs[tap_name] = x

        # Main path: converted fc6 / fc7.
        x = self.relu6(self.conv6(x))
        x = self.relu7(self.conv7(x))

        # Every branch is brought to the spatial size of the main path.  The
        # size is read from the tensor instead of being recomputed from the
        # input size so that the concatenation below cannot disagree with the
        # trunk's actual output size.
        target_size = tuple(x.shape[2:])

        branch_features = []
        branches = (
            ('pool1', self.branch1),  # higher resolution than main path
            ('pool2', self.branch2),  # higher resolution than main path
            ('pool3', self.branch3),  # same resolution as main path
            ('pool4', self.branch4),  # same resolution as main path
        )
        for tap_name, branch in branches:
            feat = branch(pool_outputs[tap_name])
            if tuple(feat.shape[2:]) != target_size:
                # pool1 / pool2 maps are larger than the main-path map, so
                # this is a bilinear *down*-sizing to the common resolution.
                feat = F.interpolate(feat, size=target_size,
                                     mode='bilinear', align_corners=False)
            branch_features.append(feat)

        # Concatenate along the channel dimension: 4096 + 4 * 128 channels.
        x_concat = torch.cat([x] + branch_features, dim=1)

        # 1x1 convolution producing per-pixel class logits at reduced size.
        logits = self.final_classifier_msc(x_concat)

        # Resize the logits back to the original input resolution.
        output = F.interpolate(logits, size=(input_H, input_W),
                               mode='bilinear', align_corners=False)

        return output

In [None]:
if __name__ == '__main__':
    # Smoke test: build the model, run one forward pass on random data and
    # check that the logits come back at the input resolution.

    # Number of semantic classes (e.g., 21 for PASCAL VOC).
    num_classes = 21

    # Instantiate the DeepLabMSc model.
    print(f"Initializing DeepLabMSc model with {num_classes} classes...")
    model = DeepLabMSc(num_classes=num_classes)
    # Move the model to GPU if available for faster computation.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Switch to evaluation mode so that any train/eval-dependent layers
    # behave deterministically during inference.
    model.eval()

    # Dummy input: batch of 1, 3 channels, 224x224 pixels.
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    print(f"Dummy input shape: {dummy_input.shape}")
    in_h, in_w = dummy_input.shape[2], dummy_input.shape[3]

    # Forward pass without gradient tracking.
    print("Performing forward pass...")
    with torch.no_grad():
        output_logits = model(dummy_input)

    # The output should be (Batch_Size, Num_Classes, Input_Height, Input_Width).
    print(f"Output logits shape: {output_logits.shape}")

    # Proof of Concept: Output Resolution
    expected_output_shape = (1, num_classes, 224, 224)
    assert output_logits.shape == expected_output_shape, \
        f"Output shape mismatch! Expected {expected_output_shape}, got {output_logits.shape}"
    print("Output shape matches expected full resolution.")

    # --- Illustrative Example of Feature Map Sizes at Each Stage ---
    print("\n--- Illustrative Feature Map Sizes (Conceptual Flow) ---")
    # Nominal sizes of the *modified* VGG trunk for a 224x224 input.  These
    # are computed from the input size, not measured from the model:
    #   pool1: 112x112 (1/2)  - branch1
    #   pool2:  56x56  (1/4)  - branch2
    #   pool3:  28x28  (1/8)  - branch3
    #   pool4 (no subsampling): 28x28 (1/8) - branch4
    #   conv5_x (dilation 2), pool5 (no subsampling): 28x28 (1/8)
    #   conv6 (dilation 4), conv7 (1x1): 28x28 (1/8)
    #   concatenated features and logits before the final resize: 28x28 (1/8)

    print(f"Input: {in_h}x{in_w}")
    print(f"pool1 output (raw): {in_h//2}x{in_w//2}")
    print(f"pool2 output (raw): {in_h//4}x{in_w//4}")
    print(f"pool3 output (raw): {in_h//8}x{in_w//8}")
    print(f"pool4 output (modified VGG): {in_h//8}x{in_w//8} (stride 1)")
    print(f"Main path features (after conv7): {in_h//8}x{in_w//8}")

    print(f"Branch1 output (after MLPBranch, before upsample): {in_h//2}x{in_w//2}")
    print(f"Branch1 output (after upsample for concat): {in_h//8}x{in_w//8}")
    # Similar for other branches.

    # The concatenated features live at 1/8 of the input resolution.  The
    # previous expression divided by (224 // input_size), which printed the
    # full resolution for a 224 input and divided by zero for larger inputs.
    print(f"Final concatenated features: {in_h//8}x{in_w//8}")
    print(f"Final output logits (after final upsample): {output_logits.shape[2]}x{output_logits.shape[3]}")