In [1]:
import torch
import torch.nn as nn

class SeparableConv2d(nn.Module):
    """
    Depthwise separable 2-D convolution, the basic building block of Xception.

    The operation is factored into two steps:
      1. a depthwise convolution that filters each input channel on its own
         (spatial filtering, no channel mixing), then
      2. a 1x1 pointwise convolution that mixes the channels to produce the
         requested number of output channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the depthwise convolution kernel.
        stride (int, optional): Stride of the depthwise convolution. Defaults to 1.
        padding (int, optional): Zero-padding added to all four sides of the input
            before the depthwise convolution. Defaults to 0.
        bias (bool, optional): If True, BOTH the depthwise and the pointwise
            convolution get a learnable bias. Defaults to False.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()

        # Spatial step: groups == in_channels gives every input channel its own
        # single filter, so the channel count is preserved and nothing is mixed.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, groups=in_channels, bias=bias,
        )

        # Channel step: a 1x1 kernel learns linear combinations of the depthwise
        # outputs at every spatial position, yielding out_channels feature maps.
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1,
            stride=1, padding=0, bias=bias,
        )

    def forward(self, x):
        """
        Apply the depthwise convolution, then the pointwise convolution.

        Args:
            x (torch.Tensor): Input tensor with ``in_channels`` channels.

        Returns:
            torch.Tensor: Output tensor with ``out_channels`` channels.
        """
        return self.pointwise(self.depthwise(x))

In [None]:
class Block(nn.Module):
    """
    The Xception Block. This is the main repeating unit of the Xception architecture.

    It stacks ``reps`` 3x3 separable convolutions, each followed by batch norm and
    separated by ReLUs, and adds a residual connection around the whole stack.
    No activation is applied after the residual addition.

    Spatial downsampling (``stride > 1``) is performed by the LAST separable
    convolution of the stack, and matched in the residual path by a strided 1x1
    projection.
    NOTE(review): the original paper downsamples with a 3x3 max-pool after the
    convolutions instead; confirm this deviation is intended.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        reps (int): Number of separable convolution layers in this block. Must be >= 1.
        stride (int, optional): Stride of the last separable convolution and of the
            projection shortcut. Used for downsampling. Defaults to 1.
        start_with_relu (bool, optional): Whether to apply ReLU before the first
            convolution. Defaults to True.
        grow_first (bool, optional): If True, the first convolution already maps to
            ``out_channels``. If False, the stack keeps ``in_channels`` until the last
            convolution, which performs the expansion. Defaults to True.

    Raises:
        ValueError: If ``reps`` is less than 1.
    """
    def __init__(self, in_channels, out_channels, reps, stride=1, start_with_relu=True, grow_first=True):
        super().__init__()

        if reps < 1:
            raise ValueError(f"reps must be >= 1, got {reps}")

        # --- Residual Connection Path ---
        # The residual connection adds the input 'x' to the output of the convolutional path.
        # If the dimensions don't match (due to striding or change in channel count),
        # we need a 'projection' convolution in the residual path to make them match.
        if out_channels != in_channels or stride != 1:
            # A 1x1 conv with the same stride as the main path's last conv.
            # For stride 2 this gives floor((H-1)/2)+1 rows. A 3x3 conv with padding 1
            # and stride 2 also gives floor((H-1)/2)+1 rows, so the two paths add cleanly.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            # Dimensions already match: the shortcut is the identity.
            self.shortcut = None

        # --- Main Convolutional Path ---
        # 'grow_first' decides whether the intermediate convs already run at
        # out_channels (expand early) or stay at in_channels (expand at the end).
        intermediate_channels = out_channels if grow_first else in_channels

        # Channel width at every boundary of the stack: reps convolutions need
        # reps + 1 widths. The first conv maps in -> intermediate and the last maps
        # intermediate -> out. With reps == 1 the single conv maps in -> out directly.
        widths = [in_channels] + [intermediate_channels] * (reps - 1) + [out_channels]

        layers = []
        if start_with_relu:
            # Must NOT be in-place: forward() also feeds the untouched 'x' to the shortcut.
            layers.append(nn.ReLU())

        last_index = reps - 1
        for i, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i > 0:
                # ReLU between consecutive conv/BN pairs (none after the final BN).
                layers.append(nn.ReLU())
            layers.append(SeparableConv2d(
                c_in, c_out, kernel_size=3,
                stride=stride if i == last_index else 1,  # only the last conv downsamples
                padding=1, bias=False,
            ))
            layers.append(nn.BatchNorm2d(c_out))

        # Create the sequential module for the main path
        self.main_path = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass for the Xception block.

        Args:
            x (torch.Tensor): Input tensor with ``in_channels`` channels.

        Returns:
            torch.Tensor: Sum of the main-path output and the (identity or projected)
            residual, with ``out_channels`` channels.
        """
        # Get the output from the main convolutional path
        main_output = self.main_path(x)

        # Get the output from the residual path (either identity or projection)
        if self.shortcut is not None:
            residual_output = self.shortcut(x)
        else:
            residual_output = x

        # Add the main path output and the residual path output
        return main_output + residual_output

In [None]:
class Xception(nn.Module):
    """
    The Xception architecture model.
    Based on the paper "Xception: Deep Learning with Depthwise Separable Convolutions"
    by F. Chollet.

    Structure: a two-conv stem, three downsampling entry-flow blocks, eight
    shape-preserving middle-flow blocks, one downsampling exit-flow block, two
    separable convolutions widening to 2048 channels, global average pooling, and a
    linear classifier.

    NOTE(review): downsampling inside each Block is done by a strided separable
    convolution rather than the max-pooling used in the paper; confirm this is intended.

    Args:
        num_classes (int, optional): Number of classes for the final classifier. Defaults to 1000.
        in_channels (int, optional): Number of channels of the input images
            (e.g. 1 for grayscale). Defaults to 3 (RGB).
    """
    def __init__(self, num_classes=1000, in_channels=3):
        super().__init__()

        # ==================== Entry Flow ====================

        # The entry flow starts with two standard convolution layers. They downsample
        # the image and build low-level features before the separable-conv blocks.
        self.entry_flow_conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.entry_flow_bn1 = nn.BatchNorm2d(32)
        self.entry_flow_relu1 = nn.ReLU()

        self.entry_flow_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.entry_flow_bn2 = nn.BatchNorm2d(64)
        self.entry_flow_relu2 = nn.ReLU()

        # The rest of the entry flow consists of Xception blocks, each halving the
        # spatial size (stride=2) and increasing the channel count.
        # Block 1 skips the leading ReLU because its input was just passed through
        # entry_flow_relu2.
        self.entry_flow_block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False, grow_first=True)
        self.entry_flow_block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
        self.entry_flow_block3 = Block(256, 728, reps=2, stride=2, start_with_relu=True, grow_first=True)

        # ==================== Middle Flow ====================

        # Eight identical blocks that keep both the spatial size and the channel
        # depth (728) unchanged. They repeatedly refine the features at a fixed
        # resolution: 19x19 for a 299x299 input.
        middle_flow_blocks = []
        for _ in range(8):
            middle_flow_blocks.append(Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True))
        self.middle_flow = nn.Sequential(*middle_flow_blocks)

        # ==================== Exit Flow ====================

        # Block 1: keeps 728 channels until its last conv (grow_first=False), then
        # widens to 1024 and halves the spatial size.
        self.exit_flow_block1 = Block(728, 1024, reps=2, stride=2, start_with_relu=True, grow_first=False)

        # Two separable convolutions without a residual connection widen the
        # features to 2048 channels, the final representation before pooling.
        self.exit_flow_sepconv1 = SeparableConv2d(1024, 1536, kernel_size=3, stride=1, padding=1, bias=False)
        self.exit_flow_bn1 = nn.BatchNorm2d(1536)
        self.exit_flow_relu1 = nn.ReLU()

        self.exit_flow_sepconv2 = SeparableConv2d(1536, 2048, kernel_size=3, stride=1, padding=1, bias=False)
        self.exit_flow_bn2 = nn.BatchNorm2d(2048)
        self.exit_flow_relu2 = nn.ReLU()

        # ==================== Classifier ====================

        # Global average pooling reduces each of the 2048 HxW feature maps to a single
        # value, whatever H and W are. This avoids the large parameter count of
        # flattening into fully-connected layers.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # The final fully-connected (linear) layer maps the 2048 features to the number of classes.
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """
        Forward pass for the full Xception model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, H, W).
                The spatial size is not fixed because of the adaptive pooling;
                299x299 gives the 19x19 middle-flow resolution of the paper.

        Returns:
            torch.Tensor: Output logits, of shape (batch_size, num_classes).
        """
        # Entry Flow
        x = self.entry_flow_conv1(x)
        x = self.entry_flow_bn1(x)
        x = self.entry_flow_relu1(x)

        x = self.entry_flow_conv2(x)
        x = self.entry_flow_bn2(x)
        x = self.entry_flow_relu2(x)

        x = self.entry_flow_block1(x)
        x = self.entry_flow_block2(x)
        x = self.entry_flow_block3(x)

        # Middle Flow
        x = self.middle_flow(x)

        # Exit Flow
        x = self.exit_flow_block1(x)

        x = self.exit_flow_sepconv1(x)
        x = self.exit_flow_bn1(x)
        x = self.exit_flow_relu1(x)

        x = self.exit_flow_sepconv2(x)
        x = self.exit_flow_bn2(x)
        x = self.exit_flow_relu2(x)

        # Classifier
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, 2048, 1, 1) -> (batch, 2048) for the FC layer
        x = self.fc(x)

        return x