### U-Net

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class DoubleConv(nn.Module):
    """
    Two stacked "3x3 convolution -> BatchNorm -> ReLU" stages.

    Both convolutions are unpadded (padding=0), as in the original U-Net, so each
    stage trims one pixel from every border: the block shrinks height and width
    by 4 in total. This shrinkage is why skip connections must later be cropped.
    """
    def __init__(self, in_channels, out_channels):
        """
        Build the two conv/norm/activation stages.

        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by both convolutions.
        """
        super().__init__()

        stages = []
        # The first stage maps in_channels -> out_channels; the second keeps
        # out_channels. Layers are appended in the order conv, norm, relu, so
        # the Sequential indices (0..5) and parameter names match a hand-written
        # list of the same six modules.
        for stage_in_channels in (in_channels, out_channels):
            stages += [
                # bias=False: the following BatchNorm supplies its own shift term.
                nn.Conv2d(stage_in_channels, out_channels, kernel_size=3, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """
        Apply both stages.

        Args:
            x (torch.Tensor): Feature map with in_channels channels.

        Returns:
            torch.Tensor: Feature map with out_channels channels, 4 pixels smaller
                          in height and width than x.
        """
        return self.conv(x)

In [4]:
class Down(nn.Module):
    """
    One step down the contracting path: 2x2 max-pool followed by a DoubleConv.

    The pool halves height and width (MaxPool2d floors odd sizes by default);
    the DoubleConv then maps in_channels -> out_channels and trims a further
    4 pixels because its convolutions are unpadded.
    """
    def __init__(self, in_channels, out_channels):
        """
        Build the pooling + convolution pipeline.

        Args:
            in_channels (int): Channels of the feature map entering the pool
                               (pooling does not change the channel count).
            out_channels (int): Channels produced by the DoubleConv.
        """
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        features = DoubleConv(in_channels, out_channels)
        # The attribute name `maxpool_conv` and the Sequential order (pool at
        # index 0, DoubleConv at index 1) determine the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, features)

    def forward(self, x):
        """
        Downsample x by 2 and convolve it.

        Args:
            x (torch.Tensor): Feature map with in_channels channels.

        Returns:
            torch.Tensor: Feature map with out_channels channels.
        """
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved

In [5]:
class Up(nn.Module):
    """
    Upsampling block in the U-Net's expanding path.

    A 2x2 transposed convolution doubles the spatial size and halves the channel
    count. The skip connection from the contracting path is then center-cropped
    to the upsampled size ("copy and crop", as in the original paper). The two
    are concatenated along the channel dimension and refined by a DoubleConv.
    """
    def __init__(self, in_channels, out_channels):
        """
        Initializes the Up block.

        Args:
            in_channels (int): Channels of the lower-resolution feature map fed into
                               this block. The transposed convolution reduces it to
                               in_channels // 2; the skip connection is expected to
                               supply the other in_channels // 2 channels, so the
                               concatenated tensor has in_channels channels again.
            out_channels (int): Number of output feature channels of the DoubleConv.
        """
        super().__init__()

        # kernel_size=2, stride=2 (no padding): output size is exactly 2x the input.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        # Input = upsampled (in_channels // 2) + cropped skip (in_channels // 2) channels.
        self.conv = DoubleConv(in_channels, out_channels)

    @staticmethod
    def _center_crop(tensor, height, width):
        """
        Return the central height x width window of a 4-D (N, C, H, W) tensor.

        When H - height (or W - width) is odd, the extra row (column) is
        dropped from the bottom (right) edge.

        Args:
            tensor (torch.Tensor): Tensor to crop, indexed as (N, C, H, W).
            height (int): Target height.
            width (int): Target width.

        Returns:
            torch.Tensor: A view of `tensor` with spatial size (height, width).

        Raises:
            ValueError: If `tensor` is smaller than the target size in either
                        spatial dimension (cropping cannot enlarge it).
        """
        diff_y = tensor.size(2) - height
        diff_x = tensor.size(3) - width
        if diff_y < 0 or diff_x < 0:
            raise ValueError(
                f"cannot center-crop a {tensor.size(2)}x{tensor.size(3)} skip connection "
                f"to the larger upsampled size {height}x{width}"
            )
        top = diff_y // 2
        left = diff_x // 2
        return tensor[:, :, top:top + height, left:left + width]

    def forward(self, x, skip_connection):
        """
        Defines the forward pass for the Up block.

        Args:
            x (torch.Tensor): Feature map from the previous (lower) level, with
                              in_channels channels.
            skip_connection (torch.Tensor): Corresponding higher-resolution feature map
                                            copied from the contracting path; it must be
                                            at least as large as the upsampled x.

        Returns:
            torch.Tensor: The output feature map after upsampling, concatenation, and convolution.

        Raises:
            ValueError: If skip_connection is spatially smaller than the upsampled x.
        """
        # 1. Upsample: spatial size doubles, channels halve.
        x = self.up(x)

        # 2. "Copy and crop": because every Conv2d uses padding=0, encoder feature maps
        #    are larger than the decoder's, so the skip connection is center-cropped
        #    to the upsampled size.
        cropped_skip_connection = self._center_crop(skip_connection, x.size(2), x.size(3))

        # 3. Concatenate along the channel dimension (skip first, then upsampled).
        #    The order fixes which input channels the DoubleConv weights see.
        x = torch.cat([cropped_skip_connection, x], dim=1)

        # 4. Refine the combined features.
        return self.conv(x)

In [6]:
class OutConv(nn.Module):
    """
    Final 1x1 convolution of the U-Net.

    Projects each pixel's feature vector onto one score per class without
    changing the spatial size (1x1 kernel, default padding=0).
    """
    def __init__(self, in_channels, out_channels):
        """
        Create the 1x1 projection.

        Args:
            in_channels (int): Channels produced by the last expanding-path block.
            out_channels (int): Number of segmentation classes (one logit map each).
        """
        super().__init__()
        # A 1x1 kernel acts as a per-pixel linear layer across channels.
        # Unlike the convolutions inside DoubleConv, it keeps its bias term.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """
        Map features to raw, un-normalized per-pixel class scores.

        Args:
            x (torch.Tensor): Feature map with in_channels channels.

        Returns:
            torch.Tensor: Logits with out_channels channels and the same spatial
                          size as x. No softmax/sigmoid is applied here.
        """
        logits = self.conv(x)
        return logits

In [None]:
class UNet(nn.Module):
    """
    The full U-Net architecture for biomedical image segmentation.

    Consists of a contracting path (encoder) and an expanding path (decoder)
    connected by skip connections. Every 3x3 convolution is unpadded, so the
    output segmentation map is smaller than the input image (e.g. a 572x572
    input yields a 388x388 output). Inputs must be large enough that no
    intermediate feature map shrinks to zero size.
    """
    def __init__(self, in_channels=1, num_classes=2):
        """
        Initializes the U-Net model.

        Args:
            in_channels (int): Number of input image channels (e.g., 1 for grayscale, 3 for RGB).
            num_classes (int): Number of output segmentation classes (e.g., 2 for background/foreground).
        """
        super().__init__()

        # --- Contracting Path (Encoder) ---
        # Level 1 (in_channels -> 64, no downsampling)
        self.inc = DoubleConv(in_channels, 64)

        # Level 2 (64 -> 128, with downsampling)
        self.down1 = Down(64, 128)

        # Level 3 (128 -> 256, with downsampling)
        self.down2 = Down(128, 256)

        # Level 4 (256 -> 512, with downsampling)
        self.down3 = Down(256, 512)

        # Level 5, the bottleneck (512 -> 1024, with downsampling).
        self.down4 = Down(512, 1024)

        # --- Expanding Path (Decoder) ---
        # The in_channels argument of Up is the channel count *before* its up-convolution:
        # 1024 (from down4) is up-convolved to 512, concatenated with the 512-channel skip
        # connection (1024 again), and the inner DoubleConv maps 1024 -> 512.
        self.up1 = Up(1024, 512)

        # Level 4 (512 -> 256)
        self.up2 = Up(512, 256)

        # Level 3 (256 -> 128)
        self.up3 = Up(256, 128)

        # Level 2 (128 -> 64)
        self.up4 = Up(128, 64)

        # --- Output Layer ---
        # Maps the final 64 channels to the desired number of output classes.
        self.outc = OutConv(64, num_classes)

    def forward(self, x):
        """
        Defines the forward pass for the entire U-Net.

        The spatial sizes in the comments below trace a 572x572 input, the paper's
        example tile size.

        Args:
            x (torch.Tensor): The input image, shape (batch_size, in_channels, H, W).

        Returns:
            torch.Tensor: Segmentation logits, shape (batch_size, num_classes, H_out, W_out),
                          with H_out < H and W_out < W because all convolutions are unpadded.
        """
        # --- Contracting Path ---
        # x1..x4 are kept as skip connections; x5 is the bottleneck.
        x1 = self.inc(x)     # 64 ch.   572 -> (conv) 568
        x2 = self.down1(x1)  # 128 ch.  568 -> (pool) 284 -> (conv) 280
        x3 = self.down2(x2)  # 256 ch.  280 -> (pool) 140 -> (conv) 136
        x4 = self.down3(x3)  # 512 ch.  136 -> (pool) 68  -> (conv) 64
        x5 = self.down4(x4)  # 1024 ch. 64  -> (pool) 32  -> (conv) 28

        # --- Expanding Path ---
        # Each Up block doubles the spatial size, crops the skip connection to match,
        # concatenates, and its DoubleConv then trims 4 more pixels.
        x = self.up1(x5, x4)  # 28 -> (up) 56, x4 cropped 64 -> 56, (conv) -> 52.   512 ch.
        x = self.up2(x, x3)   # 52 -> (up) 104, x3 cropped 136 -> 104, (conv) -> 100. 256 ch.
        x = self.up3(x, x2)   # 100 -> (up) 200, x2 cropped 280 -> 200, (conv) -> 196. 128 ch.
        x = self.up4(x, x1)   # 196 -> (up) 392, x1 cropped 568 -> 392, (conv) -> 388. 64 ch.

        # --- Output Layer ---
        # 1x1 convolution to per-pixel class scores (logits); spatial size unchanged (388x388).
        logits = self.outc(x)

        return logits

In [None]:
if __name__ == '__main__':
    # Demo configuration: grayscale input, 2 output classes (e.g. foreground/background),
    # and the paper's 572x572 example tile size.
    batch_size = 1
    input_channels = 1
    num_classes = 2
    input_height = 572
    input_width = 572

    # 1. Instantiate the U-Net model.
    model = UNet(in_channels=input_channels, num_classes=num_classes)
    print(f"Model architecture:\n{model}")

    # Evaluation mode: BatchNorm uses its running statistics instead of
    # overwriting them with statistics of the random dummy batch below.
    model.eval()

    # 2. Create a dummy input tensor.
    dummy_input = torch.randn(batch_size, input_channels, input_height, input_width)
    print(f"\nDummy input shape: {dummy_input.shape}")

    # 3. Forward pass. torch.no_grad() avoids building an autograd graph, which
    #    would otherwise keep every intermediate activation of this large model
    #    in memory for a backward pass that never happens.
    with torch.no_grad():
        output_logits = model(dummy_input)
    print(f"Output logits shape: {output_logits.shape}")

    # The output (388x388) is smaller than the input (572x572) because every
    # Conv2d uses padding=0: the U-Net's characteristic output shrinking.
    expected_shape = (batch_size, num_classes, 388, 388)
    if tuple(output_logits.shape) != expected_shape:
        raise RuntimeError(
            f"unexpected output shape {tuple(output_logits.shape)}, expected {expected_shape}"
        )

    # 4. (Optional) Move model to GPU if available.
    if torch.cuda.is_available():
        model.cuda()
        dummy_input = dummy_input.cuda()
        with torch.no_grad():
            output_logits_gpu = model(dummy_input)
        print(f"Output logits shape on GPU: {output_logits_gpu.shape}")
        print("Model and data moved to GPU.")
    else:
        print("CUDA not available. Running on CPU.")

    # 5. Count trainable parameters (a measure of model complexity);
    #    roughly 31 million for this configuration.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal trainable parameters: {total_params:,}")