In [2]:
'''
U-Net Architecture
+-------------+-----------------+-------------------------------+------------------+
| Layer       | Type            | Configuration                 | Output Size      |
+-------------+-----------------+-------------------------------+------------------+
| Input       | Image           | 572 x 572 x 1 (Grayscale)     | 572 x 572 x 1    |
|             |                 |                               |                  |
| ========== ENCODER (CONTRACTING PATH) ==========                                 |
|             |                 |                               |                  |
| Conv1-1     | Convolution     | 64 filters (3x3), No padding  | 570 x 570 x 64   |
| Conv1-2     | Convolution     | 64 filters (3x3), No padding  | 568 x 568 x 64   |
| Pool1       | Max Pooling     | 2x2 window, Stride 2          | 284 x 284 x 64   |
|             |                 |                               |                  |
| Conv2-1     | Convolution     | 128 filters (3x3), No padding | 282 x 282 x 128  |
| Conv2-2     | Convolution     | 128 filters (3x3), No padding | 280 x 280 x 128  |
| Pool2       | Max Pooling     | 2x2 window, Stride 2          | 140 x 140 x 128  |
|             |                 |                               |                  |
| Conv3-1     | Convolution     | 256 filters (3x3), No padding | 138 x 138 x 256  |
| Conv3-2     | Convolution     | 256 filters (3x3), No padding | 136 x 136 x 256  |
| Pool3       | Max Pooling     | 2x2 window, Stride 2          | 68 x 68 x 256    |
|             |                 |                               |                  |
| Conv4-1     | Convolution     | 512 filters (3x3), No padding | 66 x 66 x 512    |
| Conv4-2     | Convolution     | 512 filters (3x3), No padding | 64 x 64 x 512    |
| Pool4       | Max Pooling     | 2x2 window, Stride 2          | 32 x 32 x 512    |
|             |                 |                               |                  |
| ========== BOTTLENECK ==========                                                 |
|             |                 |                               |                  |
| Conv5-1     | Convolution     | 1024 filters (3x3), No padding| 30 x 30 x 1024   |
| Conv5-2     | Convolution     | 1024 filters (3x3), No padding| 28 x 28 x 1024   |
|             |                 |                               |                  |
| ========== DECODER (EXPANDING PATH) ==========                                   |
|             |                 |                               |                  |
| UpConv4     | Up-Convolution  | 512 filters (2x2), Stride 2   | 56 x 56 x 512    |
| Crop4       | Crop & Concat   | From Conv4-2 (64x64 → 56x56)  | 56 x 56 x 1024   |
| Conv6-1     | Convolution     | 512 filters (3x3), No padding | 54 x 54 x 512    |
| Conv6-2     | Convolution     | 512 filters (3x3), No padding | 52 x 52 x 512    |
|             |                 |                               |                  |
| UpConv3     | Up-Convolution  | 256 filters (2x2), Stride 2   | 104 x 104 x 256  |
| Crop3       | Crop & Concat   | From Conv3-2 (136x136→104x104)| 104 x 104 x 512  |
| Conv7-1     | Convolution     | 256 filters (3x3), No padding | 102 x 102 x 256  |
| Conv7-2     | Convolution     | 256 filters (3x3), No padding | 100 x 100 x 256  |
|             |                 |                               |                  |
| UpConv2     | Up-Convolution  | 128 filters (2x2), Stride 2   | 200 x 200 x 128  |
| Crop2       | Crop & Concat   | From Conv2-2 (280x280→200x200)| 200 x 200 x 256  |
| Conv8-1     | Convolution     | 128 filters (3x3), No padding | 198 x 198 x 128  |
| Conv8-2     | Convolution     | 128 filters (3x3), No padding | 196 x 196 x 128  |
|             |                 |                               |                  |
| UpConv1     | Up-Convolution  | 64 filters (2x2), Stride 2    | 392 x 392 x 64   |
| Crop1       | Crop & Concat   | From Conv1-2 (568x568→392x392)| 392 x 392 x 128  |
| Conv9-1     | Convolution     | 64 filters (3x3), No padding  | 390 x 390 x 64   |
| Conv9-2     | Convolution     | 64 filters (3x3), No padding  | 388 x 388 x 64   |
|             |                 |                               |                  |
| Output      | Convolution     | 2 filters (1x1) + Softmax     | 388 x 388 x 2    |
+-------------+-----------------+-------------------------------+------------------+

Key Characteristics of U-Net:

- **Symmetric U-Shape**: Encoder (contracting) path followed by decoder (expanding) path
- **Skip Connections**: Feature maps from encoder are concatenated with decoder maps
  (helps recover spatial information lost during downsampling)
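- **No Padding**: Original paper uses unpadded ("valid") convolutions, so each 3x3
  convolution trims 2 pixels per dimension and the output is smaller than the
  input (388x388 vs 572x572)
- **No Fully Connected Layers**: Fully convolutional architecture accepts inputs of
  varying size, as long as every 2x2 max pooling sees an even height and width
- **ReLU Activation**: Applied after every 3x3 convolution (the final 1x1 convolution
  feeds the softmax instead)
- **Dropout**: The original paper places dropout layers at the end of the contracting
  path; a rate of 0.5 is a common choice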
- **Total Parameters**: ~31 million (original U-Net)

Architecture Philosophy:

1. **Contracting Path** (Left side of U):
   - Captures context through repeated convolution and max pooling
   - Progressively reduces spatial dimensions while increasing feature channels
   - Standard convolutional network architecture

2. **Bottleneck** (Bottom of U):
   - Smallest spatial dimensions, highest number of feature channels
   - Captures the most abstract features

3. **Expanding Path** (Right side of U):
   - Recovers spatial resolution through up-convolutions (transposed convolutions)
   - Skip connections from encoder provide fine-grained spatial information
   - Enables precise localization for segmentation

4. **Skip Connections**:
   - Concatenate encoder features with corresponding decoder features
   - Encoder features are cropped to match decoder size (due to no padding)
   - Combines high-resolution encoder features with upsampled decoder features
   - Critical for accurate pixel-wise segmentation

Advantages:
- Can be trained from very few annotated images when combined with strong data augmentation
- Achieves precise localization with skip connections
- End-to-end training for image-to-image tasks
- Fast inference (fully convolutional, no dense layers)

Common Use Cases:
- Medical image segmentation (original purpose)
- Satellite image analysis
- Any pixel-wise prediction task requiring spatial precision
'''
print()
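The output sizes in the table follow from three rules: an unpadded 3x3 convolution trims 2 pixels, a 2x2 max pooling halves the size, and a 2x2 up-convolution with stride 2 doubles it. The sketch below replays that arithmetic in plain Python; the function name `trace_unet_sizes` and its arguments are illustrative assumptions, not part of any library.

```python
def trace_unet_sizes(size=572, depth=4):
    """Trace the spatial size through an unpadded U-Net (hypothetical helper).

    Returns (output_size, crops), where crops lists one
    (encoder_size, decoder_size) pair per skip connection, deepest first.
    """
    skips = []
    for _ in range(depth):
        size -= 4                 # two valid 3x3 convolutions, each trims 2 pixels
        skips.append(size)        # feature map kept for the skip connection
        if size % 2:
            raise ValueError(f"2x2 pooling needs an even size, got {size}")
        size //= 2                # 2x2 max pooling, stride 2

    size -= 4                     # bottleneck: two more valid convolutions

    crops = []
    for encoder_size in reversed(skips):
        size *= 2                 # 2x2 up-convolution, stride 2
        crops.append((encoder_size, size))  # encoder map is cropped to `size`
        size -= 4                 # two valid 3x3 convolutions
    return size, crops


output_size, crops = trace_unet_sizes(572)
print(output_size)
for encoder_size, decoder_size in crops:
    print(f"crop {encoder_size} -> {decoder_size}")
```

Running it with other input sizes is a quick way to find tiles that keep every pooling step even, which the unpadded design requires.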




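The "~31 million" parameter figure can be checked by hand: a convolution with a k x k kernel has `c_in * c_out * k * k` weights plus `c_out` biases, and a transposed convolution has the same count. The sketch below sums this over the layers listed in the table, assuming biases everywhere and no normalization layers; `conv_params` and `unet_param_count` are hypothetical helper names.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a k x k (transposed) convolution."""
    return c_in * c_out * k * k + c_out


def unet_param_count(in_channels=1, num_classes=2, features=(64, 128, 256, 512)):
    """Parameter count of the U-Net layout in the table (hypothetical helper)."""
    total = 0
    channels = in_channels

    # Encoder: two 3x3 convolutions per level
    for f in features:
        total += conv_params(channels, f, 3) + conv_params(f, f, 3)
        channels = f

    # Bottleneck: two 3x3 convolutions at twice the deepest width
    bottleneck = 2 * features[-1]
    total += conv_params(channels, bottleneck, 3) + conv_params(bottleneck, bottleneck, 3)
    channels = bottleneck

    # Decoder: 2x2 up-convolution, then two 3x3 convolutions;
    # the first one sees 2*f channels because of the concatenated skip
    for f in reversed(features):
        total += conv_params(channels, f, 2)
        total += conv_params(2 * f, f, 3) + conv_params(f, f, 3)
        channels = f

    # Final 1x1 convolution to the class maps
    total += conv_params(channels, num_classes, 1)
    return total


print(f"{unet_param_count() / 1e6:.2f} million parameters")
```

The count depends on the input channels, the number of classes, and whether biases or BatchNorm are used, so other implementations will report slightly different totals.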
In [12]:
###################
## UNet building ##
###################

import torch
from torch import nn

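#-------
## Define DoubleConv
## Note: unlike the original U-Net described above (unpadded convolutions, no normalization),
## this block uses padding=1 and BatchNorm2d, so the spatial size is preserved.
#-------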

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), # bias=False for BatchNorm2d()
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)            
        )

    def forward(self, X):
        return self.cnn(X)

#-------
## Building
#-------

import torchvision.transforms.functional as TF

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)): # Binary segmentation in this example, so out_channels=1; a tuple avoids a mutable default argument
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature # set in_channels for shape matching of the next layer

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)) # ConvTranspose: https://www.youtube.com/watch?v=qb4nRoEAASA (doubles height/width, halves channels: 1024->512->256->128->64)
            self.ups.append(DoubleConv(feature*2, feature))                                  # Input has feature*2 channels because the skip connection is concatenated with the upsampled map

        # Bottleneck part (transition from down to up)
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # Final conv layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, X):
        skip_connections = [] # Store values from down block to use later on

        for down in self.downs:
            X = down(X)
            skip_connections.append(X)
            X = self.pool(X)

        X = self.bottleneck(X)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2): # So idx takes values: 0, 2, 4, 6 (accessing only the ConvTranspose layers for upsampling)
            X = self.ups[idx](X)
            skip_connection = skip_connections[idx//2]

            if X.shape[2:] != skip_connection.shape[2:]: # spatial sizes differ when pooling floored an odd height/width on the way down
                X = TF.resize(X, skip_connection.shape[2:]) # ensure X and skip_connection have the same height and width for torch.cat()
            
            concat_skip = torch.cat((skip_connection, X), dim=1) # Combining high-resolution encoder features with upsampled decoder features.
            X = self.ups[idx+1](concat_skip)

        return self.final_conv(X)
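The `TF.resize` branch in `forward` only runs when the upsampled map and its skip connection disagree in height or width. With padded convolutions this happens when a pooling step floors an odd size: `MaxPool2d(2, 2)` maps 25 to 12, but the transposed convolution maps 12 back to 24. A minimal plain-Python sketch of that bookkeeping, assuming four levels as in the default `features`; the name `skip_vs_upsampled` is hypothetical.

```python
def skip_vs_upsampled(size, depth=4):
    """Return (skip_size, upsampled_size) per decoder level, deepest first.

    Mirrors the padded UNet above along one spatial dimension:
    DoubleConv keeps the size, MaxPool2d(2, 2) floors size / 2,
    ConvTranspose2d(kernel_size=2, stride=2) doubles it.
    """
    skips = []
    for _ in range(depth):
        skips.append(size)   # size stored in skip_connections
        size //= 2           # max pooling floors odd sizes

    pairs = []
    for skip in reversed(skips):
        size *= 2            # transposed convolution doubles the size
        pairs.append((skip, size))
        size = skip          # after the resize, X matches the skip connection
    return pairs


for n in (160, 161, 100):
    mismatches = [(s, u) for s, u in skip_vs_upsampled(n) if s != u]
    print(n, "resize needed at:", mismatches)
```

Inputs whose height and width are divisible by 16 (2 to the power of the number of levels) never reach the resize branch; for other sizes the resize interpolates the decoder features, so padding or cropping inputs to a multiple of 16 beforehand is a common alternative.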