# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeparableConv2d(nn.Module):
    """
    Depthwise separable 2-D convolution in the Xception style.

    A dense KxK convolution is factorized into two cheaper steps:
    a per-channel spatial convolution (``groups == in_channels``) followed by
    a 1x1 convolution that mixes information across channels.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        """
        Build the depthwise and pointwise convolutions.

        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by the pointwise step.
            kernel_size (int): Spatial kernel size of the depthwise step. Default: 3.
            stride (int): Stride of the depthwise step. Default: 1.
            padding (int): Zero-padding of the depthwise step. Default: 1.
            dilation (int): Dilation of the depthwise step. Default: 1.
            bias (bool): Whether both convolutions carry a bias term. Default: False.
        """
        super().__init__()

        # Spatial filtering only: one kernel per input channel, so the channel
        # count is unchanged by this step.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=bias,
        )

        # Channel mixing only: a 1x1 kernel (nn.Conv2d's defaults give
        # stride 1, padding 0 and dilation 1).
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the depthwise step, then the pointwise step.

        Args:
            x (torch.Tensor): Input of shape [B, C_in, H, W].

        Returns:
            torch.Tensor: Output of shape [B, C_out, H', W'], where H' and W'
            follow from the depthwise stride, padding and dilation.
        """
        return self.pointwise(self.depthwise(x))

# In [6]:
# (Assuming the SeparableConv2d class from Step 1 is already defined above this code)

class Block(nn.Module):
    """
    Defines a single, fully configurable Xception block.

    The main path is a stack of ``reps`` units of
    ``ReLU -> SeparableConv2d(3x3) -> BatchNorm2d``. Its output is added to a
    residual (shortcut) path, which is a 1x1 projection when the shape changes.

    NOTE: a simpler two-convolution ``Block`` (with a ``dilation`` argument) is
    defined again immediately after this class. That later definition is the
    one the ``Xception`` backbone instantiates. This variant is kept as the
    configurable reference implementation.
    """
    def __init__(self, in_channels, out_channels, reps, stride=1, start_with_relu=True, grow_first=True):
        """
        Initializes the Xception Block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            reps (int): Number of separable convolution repetitions within the block (>= 1).
            stride (int): Stride for the first convolution in the block, used for downsampling.
            start_with_relu (bool): Whether to apply a ReLU activation before the first conv.
            grow_first (bool): If True, the first separable conv changes the number of channels,
                               otherwise the last one does.

        Raises:
            ValueError: If ``reps`` is smaller than 1.
        """
        super(Block, self).__init__()
        if reps < 1:
            raise ValueError("reps must be at least 1.")

        self.reps = reps
        self.start_with_relu = start_with_relu

        # --- Residual Connection Path ---
        # A projection is needed whenever the main path changes the channel count
        # or the spatial size. The main path uses kernel 3 / padding 1 and this
        # path uses kernel 1 / padding 0, so both produce floor((H - 1) / stride) + 1 rows.
        if out_channels != in_channels or stride != 1:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.skip = None

        # --- Main Convolutional Path ---
        # (input channels, output channels) for each separable conv, in order.
        if grow_first:
            channel_plan = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (reps - 1)
        else:
            channel_plan = [(in_channels, in_channels)] * (reps - 1) + [(in_channels, out_channels)]

        layers = []
        for idx, (c_in, c_out) in enumerate(channel_plan):
            # The first ReLU acts directly on the block input, which forward() also
            # feeds to the residual path, so it must NOT be in-place. Later ReLUs act
            # on BatchNorm outputs and can safely run in-place.
            layers.append(nn.ReLU(inplace=idx > 0))
            # Only the first conv downsamples. This keeps the stride even when
            # reps == 1 and the skip path is strided too.
            layers.append(SeparableConv2d(c_in, c_out, kernel_size=3, stride=stride if idx == 0 else 1,
                                          padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))

        if not start_with_relu:
            # Drop the leading ReLU, so the block starts directly with a convolution.
            layers = layers[1:]

        self.rep = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass for the Xception block.

        Args:
            x (torch.Tensor): Input of shape [B, in_channels, H, W].

        Returns:
            torch.Tensor: Sum of the main path and the residual path, with
            ``out_channels`` channels and spatial size reduced by ``stride``.
        """
        # Compute the shortcut from the untouched input before the main path runs.
        residual = x if self.skip is None else self.skip(x)
        return self.rep(x) + residual

class Block(nn.Module):
    """
    Simplified Xception block used by the backbone.

    The main path is two ``ReLU -> SeparableConv2d(3x3) -> BatchNorm2d`` units.
    It is summed with a shortcut that is either the identity or a 1x1
    convolution + BatchNorm projection when the channel count or stride changes.
    """
    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        """
        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by the block.
            stride (int): Stride of the first separable conv and of the shortcut projection.
            dilation (int): Dilation (and matching padding) of both separable convs.
        """
        super(Block, self).__init__()

        # --- Shortcut ---
        # Project only when the addition would otherwise have mismatched shapes.
        needs_projection = stride != 1 or in_channels != out_channels
        self.skip = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

        # --- Main path ---
        # Padding equals dilation, so a 3x3 kernel preserves the spatial size at
        # stride 1. Only the first conv may downsample.
        sep_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=False)
        self.conv_path = nn.Sequential(
            nn.ReLU(),  # not in-place: forward() still reuses the raw input for the shortcut
            SeparableConv2d(in_channels, out_channels, stride=stride, **sep_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            SeparableConv2d(out_channels, out_channels, stride=1, **sep_kwargs),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return ``conv_path(x) + shortcut(x)``.

        Args:
            x (torch.Tensor): Input of shape [B, in_channels, H, W].

        Returns:
            torch.Tensor: Output with ``out_channels`` channels and spatial size
            reduced by ``stride``.
        """
        shortcut = x if self.skip is None else self.skip(x)
        return self.conv_path(x) + shortcut

# In [7]:
# (Assuming SeparableConv2d and Block classes are defined above)

class Xception(nn.Module):
    """
    Implements the modified Xception backbone for DeepLabv3+.
    The architecture is modified to allow for atrous convolution to control
    the output stride, which is crucial for semantic segmentation.
    """
    def __init__(self, in_channels=3, output_stride=16):
        """
        Initializes the Xception backbone.

        Args:
            in_channels (int): Number of input channels, typically 3 for RGB images.
            output_stride (int): The ratio of input image spatial resolution to the final
                                 output feature map spatial resolution. Must be 8 or 16.

        Raises:
            ValueError: If ``output_stride`` is neither 8 nor 16.
        """
        super(Xception, self).__init__()

        # --- Set up stride and dilation rates based on output_stride ---
        # With OS=8, entry block3 stops downsampling and the lost stride is
        # compensated by dilating the later blocks instead.
        if output_stride == 16:
            entry_block3_stride = 2
            middle_block_dilation = 1
            exit_block_dilations = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            raise ValueError("Output stride must be 8 or 16.")

        # ================== ENTRY FLOW ==================
        # First two convolutions are standard, not separable. conv1 halves the resolution.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # The entry flow consists of 3 blocks
        self.block1 = Block(in_channels=64, out_channels=128, stride=2)
        self.block2 = Block(in_channels=128, out_channels=256, stride=2)
        # The stride of this block depends on the desired output_stride
        self.block3 = Block(in_channels=256, out_channels=728, stride=entry_block3_stride)

        # ================== MIDDLE FLOW ==================
        # Consists of 16 identical blocks. Dilation rate is controlled by output_stride.
        middle_blocks = []
        for i in range(16):
            middle_blocks.append(
                Block(in_channels=728, out_channels=728, stride=1, dilation=middle_block_dilation)
            )
        self.middle_flow = nn.Sequential(*middle_blocks)

        # =================== EXIT FLOW ===================
        self.exit_block1 = Block(
            in_channels=728,
            out_channels=1024,
            stride=1,
            dilation=exit_block_dilations[0]
        )
        self.exit_block2 = nn.Sequential(
            nn.ReLU(),
            SeparableConv2d(1024, 1536, 3, stride=1, padding=exit_block_dilations[1], dilation=exit_block_dilations[1], bias=False),
            nn.BatchNorm2d(1536),
            nn.ReLU(),
            SeparableConv2d(1536, 2048, 3, stride=1, padding=exit_block_dilations[1], dilation=exit_block_dilations[1], bias=False),
            nn.BatchNorm2d(2048),
            nn.ReLU()
        )


    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Defines the forward pass of the Xception backbone.

        Args:
            x (torch.Tensor): Input image batch of shape [B, in_channels, H, W].

        Returns:
            tuple: ``(x, low_level_features)``, in that order:
            - x: the final high-level features from the exit flow, with 2048 channels
              at about 1/output_stride of the input resolution.
            - low_level_features: the output of entry block1, with 128 channels at
              about 1/4 of the input resolution. These are used by the DeepLabv3+ decoder.
        """
        # --- Entry Flow ---
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        # No relu here, as the Block starts with a ReLU

        x = self.block1(x)
        # --- The point to extract low-level features ---
        # These features have a larger spatial size and contain finer details.
        # They will be used in the DeepLabv3+ decoder.
        low_level_features = x
        x = self.block2(x)
        x = self.block3(x)

        # --- Middle Flow ---
        x = self.middle_flow(x)

        # --- Exit Flow ---
        x = self.exit_block1(x)
        x = self.exit_block2(x)

        return x, low_level_features

# In [8]:
# (Assuming all previous classes are defined)

class _ASPPConv(nn.Module):
    """
    One atrous branch of ASPP: a dilated 3x3 Conv -> BatchNorm -> ReLU.

    Padding equals the dilation rate, so the spatial size is preserved.
    """
    def __init__(self, in_channels, out_channels, dilation):
        super(_ASPPConv, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        out = self.module(x)
        return out

class _ASPPPooling(nn.Module):
    """
    Image-level branch of ASPP.

    Global average pooling -> 1x1 Conv -> BatchNorm -> ReLU, then bilinear
    upsampling back to the input's spatial size.
    """
    def __init__(self, in_channels, out_channels):
        super(_ASPPPooling, self).__init__()
        layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # collapse H x W to a single global descriptor
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        height, width = x.shape[2:]
        pooled = self.module(x)
        # Broadcast the 1x1 descriptor back to the original spatial extent.
        return F.interpolate(pooled, size=(height, width), mode='bilinear', align_corners=False)

class ASPP(nn.Module):
    """
    Implements Atrous Spatial Pyramid Pooling (ASPP) as described in DeepLabv3.

    Five parallel branches are applied to the same input:
    - a 1x1 convolution;
    - three dilated 3x3 convolutions;
    - an image-level pooling branch.
    Their outputs are concatenated and fused by a 1x1 projection.
    """
    def __init__(self, in_channels, output_stride, out_channels=256, atrous_rates=None):
        """
        Args:
            in_channels (int): Channels of the backbone feature map fed to ASPP.
            output_stride (int): Backbone output stride (8 or 16). It selects the
                default atrous rates: (6, 12, 18) for 16 and (12, 24, 36) for 8.
                Ignored when ``atrous_rates`` is given.
            out_channels (int): Width of every branch and of the fused output. Default: 256.
            atrous_rates (sequence of 3 ints, optional): Explicit dilation rates for
                the three atrous branches. These override the ``output_stride`` defaults.

        Raises:
            ValueError: If ``atrous_rates`` is None and ``output_stride`` is not 8 or 16,
                or if ``atrous_rates`` does not contain exactly three rates.
        """
        super(ASPP, self).__init__()

        # Determine dilation rates based on the output stride from the backbone
        if atrous_rates is None:
            if output_stride == 16:
                atrous_rates = (6, 12, 18)
            elif output_stride == 8:
                atrous_rates = (12, 24, 36)
            else:
                raise ValueError("Output stride must be 8 or 16.")
        else:
            atrous_rates = tuple(atrous_rates)
            if len(atrous_rates) != 3:
                raise ValueError("atrous_rates must contain exactly three dilation rates.")

        # Number of channels for each of the parallel branches
        aspp_out_channels = out_channels

        # --- Branch 1: 1x1 Convolution ---
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, aspp_out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(aspp_out_channels),
            nn.ReLU(inplace=True)
        )

        # --- Branch 2, 3, 4: Atrous Convolutions ---
        # The attribute names are fixed labels kept for checkpoint compatibility.
        # The actual rates come from `atrous_rates` (e.g. 12/24/36 for OS=8).
        self.conv_d6 = _ASPPConv(in_channels, aspp_out_channels, atrous_rates[0])
        self.conv_d12 = _ASPPConv(in_channels, aspp_out_channels, atrous_rates[1])
        self.conv_d18 = _ASPPConv(in_channels, aspp_out_channels, atrous_rates[2])

        # --- Branch 5: Image Pooling ---
        self.image_pool = _ASPPPooling(in_channels, aspp_out_channels)

        # --- Fusion Layer ---
        # Concatenates the outputs of all branches and fuses them
        self.project = nn.Sequential(
            nn.Conv2d(
                in_channels=aspp_out_channels * 5, # 5 branches
                out_channels=aspp_out_channels,
                kernel_size=1,
                bias=False
            ),
            nn.BatchNorm2d(aspp_out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the ASPP module.

        Args:
            x (torch.Tensor): Backbone features, assumed [B, in_channels, H, W].

        Returns:
            torch.Tensor: Fused features of shape [B, out_channels, H, W]. Every
            branch preserves H and W, so concatenation along channels is valid.
        """
        # Pass the input through each parallel branch
        branches = [
            self.conv1x1(x),
            self.conv_d6(x),
            self.conv_d12(x),
            self.conv_d18(x),
            self.image_pool(x),
        ]

        # Concatenate along the channel dimension, then fuse with the projection layer.
        return self.project(torch.cat(branches, dim=1))

# In [None]:
# (Assuming all previous classes: SeparableConv2d, Block, Xception, ASPP are defined)

class Decoder(nn.Module):
    """
    The decoder module for DeepLabv3+. It upsamples features from the ASPP module
    and fuses them with low-level features from the backbone.
    """
    def __init__(self, num_classes, low_level_in_channels, high_level_in_channels=256, low_level_out_channels=48):
        """
        Args:
            num_classes (int): Number of output classes (channels of the logits).
            low_level_in_channels (int): Channels of the backbone's low-level feature map.
            high_level_in_channels (int): Channels of the high-level (ASPP) feature map.
                Default: 256, matching the ASPP module's default projection width.
            low_level_out_channels (int): Width that the low-level features are reduced
                to before fusion. Default: 48.
        """
        super(Decoder, self).__init__()

        # --- Low-level Feature Processing ---
        # A 1x1 convolution reduces the number of channels from the backbone's
        # low-level features, so they do not outweigh the high-level features after concatenation.
        self.conv1x1_low_level = nn.Sequential(
            nn.Conv2d(low_level_in_channels, low_level_out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(low_level_out_channels),
            nn.ReLU(inplace=True)
        )

        # --- High-level and Fused Feature Processing ---
        # Convolutions to refine the features after concatenation.
        fused_channels = high_level_in_channels + low_level_out_channels
        self.conv3x3_fused = nn.Sequential(
            # Using SeparableConv2d is more in line with the Xception theme
            SeparableConv2d(fused_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            SeparableConv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # --- Final Classifier ---
        # A final convolution to produce the logits for each class.
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)


    def forward(self, x_high_level: torch.Tensor, x_low_level: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the decoder.

        Args:
            x_high_level (torch.Tensor): High-level features from the ASPP module.
            x_low_level (torch.Tensor): Low-level features from the backbone's entry flow.

        Returns:
            torch.Tensor: Class logits of shape [B, num_classes, H_low, W_low], at the
            spatial size of ``x_low_level``.
        """
        # 1. Upsample the high-level features (from ASPP) to exactly the spatial size
        # of the low-level features. The effective factor depends on the backbone:
        # 4x when high-level is at 1/16 and low-level at 1/4 (OS=16), 2x for OS=8.
        x_high_upsampled = F.interpolate(x_high_level, size=x_low_level.shape[2:],
                                         mode='bilinear', align_corners=False)

        # 2. Process the low-level features
        x_low_processed = self.conv1x1_low_level(x_low_level)

        # 3. Concatenate the two feature maps
        x_fused = torch.cat([x_high_upsampled, x_low_processed], dim=1)

        # 4. Refine the fused features
        x_refined = self.conv3x3_fused(x_fused)

        # 5. Get class predictions
        x_logits = self.classifier(x_refined)

        return x_logits


class DeepLabv3_plus(nn.Module):
    """
    End-to-end DeepLabv3+ segmentation network.

    The pipeline is: Xception backbone -> ASPP on the high-level features ->
    decoder fusing them with the low-level features -> bilinear upsampling of
    the logits to the input resolution.
    """
    def __init__(self, num_classes, output_stride=16, in_channels=3):
        """
        Args:
            num_classes (int): Number of segmentation classes.
            output_stride (int): Backbone output stride, forwarded to Xception and ASPP.
            in_channels (int): Channels of the input image. Default: 3.
        """
        super(DeepLabv3_plus, self).__init__()

        # Channel counts fixed by the Xception backbone:
        # the exit flow emits 2048 channels and entry block1 emits 128.
        backbone_out_channels = 2048
        low_level_channels = 128

        # --- Encoder ---
        self.backbone = Xception(in_channels=in_channels, output_stride=output_stride)
        self.aspp = ASPP(in_channels=backbone_out_channels, output_stride=output_stride)

        # --- Decoder ---
        self.decoder = Decoder(num_classes=num_classes, low_level_in_channels=low_level_channels)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Segment a batch of images.

        Args:
            x (torch.Tensor): Input of shape [B, in_channels, H, W].

        Returns:
            torch.Tensor: Logits of shape [B, num_classes, H, W].
        """
        height, width = x.shape[2:]

        high_level, low_level = self.backbone(x)
        logits = self.decoder(self.aspp(high_level), low_level)

        # Bring the logits back to the original image resolution.
        return F.interpolate(logits, size=(height, width), mode='bilinear', align_corners=False)