<a href="https://colab.research.google.com/github/Biswajitnahak2003/snn-glacier-segmentation/blob/main/week3/week3_baseline_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# --- =================================================================== ---
# --- Week 3: U-Net Model Architecture Definition                         ---
# --- =================================================================== ---

# --- Step 1: Imports ---
import torch
import torch.nn as nn

# --- Step 2: A Proper, Full U-Net Architecture ---
class UNet(nn.Module):
    """Two-level U-Net producing a per-pixel, sigmoid-activated map.

    Layout:
        * Encoder: two conv blocks (64 and 128 channels), each followed by a
          2x2 max-pool that halves height and width.
        * Bottleneck: one conv block with 256 channels.
        * Decoder: two transposed convolutions that double height and width,
          each followed by concatenation with the matching encoder output
          (skip connection) and a conv block.
        * Head: a 1x1 convolution to ``out_channels`` followed by a sigmoid.

    Inputs whose height or width is not divisible by 4 are supported: the
    upsampled decoder features are zero-padded to the size of the skip
    tensor before concatenation, so the output always has the same height
    and width as the input.

    Note:
        ``forward`` already applies the sigmoid, so the returned values lie
        in [0, 1]. Use a loss that expects probabilities (e.g.
        ``nn.BCELoss``); ``nn.BCEWithLogitsLoss`` would apply a second
        sigmoid.
    """

    def __init__(self, in_channels: int = 5, out_channels: int = 1) -> None:
        """
        Initializes all the layers of the U-Net.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by the final 1x1
                convolution.
        """
        super().__init__()

        # NOTE: attribute names and creation order are kept stable so that
        # existing checkpoints (state_dict keys) and seeded initialisation
        # remain compatible.

        # --- Encoder (Contracting Path) ---
        self.enc1 = self._conv_block(in_channels, 64)   # Output: 64 channels
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = self._conv_block(64, 128)           # Output: 128 channels
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # --- Bottleneck ---
        self.bottleneck = self._conv_block(128, 256)    # Output: 256 channels

        # --- Decoder (Expanding Path) ---
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        # Input: 128 (from upconv) + 128 (from skip2)
        self.dec2 = self._conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        # Input: 64 (from upconv) + 64 (from skip1)
        self.dec1 = self._conv_block(128, 64)

        # --- Final Output Layer ---
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _conv_block(in_c: int, out_c: int) -> nn.Sequential:
        """Return two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

        ``padding=1`` keeps height and width unchanged. The convolutions
        carry no bias because the BatchNorm layer that follows each one
        has its own learnable shift.
        """
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    @staticmethod
    def _match_size(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        A 2x2 max-pool floors odd sizes, so after the stride-2 transposed
        convolution the decoder tensor can be one pixel smaller than the
        encoder tensor it must be concatenated with. When the sizes already
        agree the tensor is returned untouched.
        """
        pad_h = skip.shape[2] - upsampled.shape[2]
        pad_w = skip.shape[3] - upsampled.shape[3]
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # pad() expects the last dimension first: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass - how data flows through the network.

        Args:
            x: 4-D input tensor laid out as (batch, in_channels, height,
                width). Height and width must each be at least 4 so that
                two 2x2 poolings leave a non-empty tensor.

        Returns:
            Tensor of shape (batch, out_channels, height, width) with
            sigmoid-activated values.
        """
        # Encoder
        skip1 = self.enc1(x)
        p1 = self.pool1(skip1)
        skip2 = self.enc2(p1)
        p2 = self.pool2(skip2)

        # Bottleneck
        b = self.bottleneck(p2)

        # Decoder with Skip Connections
        up2 = self._match_size(self.upconv2(b), skip2)
        merge2 = torch.cat([up2, skip2], dim=1)  # Concatenate along channel dimension
        d2 = self.dec2(merge2)

        up1 = self._match_size(self.upconv1(d2), skip1)
        merge1 = torch.cat([up1, skip1], dim=1)
        d1 = self.dec1(merge1)

        # Final Output
        output = self.out_conv(d1)
        return self.sigmoid(output)

# --- Step 3: Verification ---
if __name__ == '__main__':
    # Smoke test: feed one random batch through a default-configured UNet and
    # confirm the output keeps the input's height/width and has one channel.
    # The input is drawn before the model is built so the random-number
    # stream is consumed in the same order as before.
    test_input = torch.randn(4, 5, 256, 256)
    model = UNet()
    output = model(test_input)

    print(
        "--- Model Architecture Verification ---",
        f"Input Shape:  {test_input.shape}",
        f"Output Shape: {output.shape}",
        sep="\n",
    )

    # Spatial dimensions are everything after (batch, channel).
    expected_hw = test_input.shape[2:]
    actual_hw = output.shape[2:]
    assert expected_hw == actual_hw, "Height and Width of output must match input"
    assert output.size(1) == 1, "Output channels must be 1"

    print("\nVerification successful! The model architecture is correct.")

--- Model Architecture Verification ---
Input Shape:  torch.Size([4, 5, 256, 256])
Output Shape: torch.Size([4, 1, 256, 256])

Verification successful! The model architecture is correct.
