
### What is a U-Net?

U-Net is a convolutional neural network architecture designed for image segmentation. Image segmentation assigns a label to every pixel of an image, dividing it into segments so that objects or regions of interest can be identified and isolated. In medical imaging, for example, it can be used to segment different tissues, organs, or anomalies.
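Concretely, a segmentation network outputs a score for every class at every pixel, and the predicted mask picks the highest-scoring class at each location. The minimal sketch below uses random scores in place of a real model's output; the sizes (one image, 3 classes, 4×4 pixels) are arbitrary assumptions chosen for illustration.

```python
import torch

# Hypothetical model output: scores for 3 classes at every pixel of a 4x4 image,
# laid out as (batch, classes, height, width), the usual PyTorch convention.
torch.manual_seed(0)
logits = torch.randn(1, 3, 4, 4)

# Multi-class segmentation: each pixel gets the index of its highest-scoring class.
label_map = logits.argmax(dim=1)  # shape (1, 4, 4), values in {0, 1, 2}
print(label_map.shape)

# Binary segmentation (one output channel): squash scores to probabilities with
# a sigmoid and threshold at 0.5 to get a foreground/background mask.
binary_logits = torch.randn(1, 1, 4, 4)
mask = torch.sigmoid(binary_logits) > 0.5  # boolean tensor of shape (1, 1, 4, 4)
print(mask.shape, mask.dtype)
```

The binary case is the one the model later in this notebook targets: it ends in a single output channel followed by a sigmoid.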

### Structure of a U-Net

The U-Net architecture is named for its U-shaped structure, which consists of two main parts, the contracting path (encoder) and the expansive path (decoder), joined at the lowest resolution by a bottleneck.

1. **Contracting Path (Encoder)**:
   - The encoder is similar to a traditional convolutional neural network (CNN).
   - It repeatedly applies two 3×3 convolutions (each followed by a ReLU activation) and then a 2×2 max-pooling operation that halves the height and width of the feature maps.
   - The purpose of the encoder is to capture the context and features of the input image by progressively reducing its spatial dimensions while increasing the number of feature channels.
   - This path captures "what" is in the image.

2. **Expansive Path (Decoder)**:
   - The decoder upsamples the feature maps to the original image size.
   - At each step, a 2×2 up-convolution (transposed convolution) doubles the height and width and halves the number of channels, and two 3×3 convolutional layers then refine the result.
   - The purpose of the decoder is to construct a detailed segmentation map using the features captured by the encoder.
   - Skip connections concatenate each encoder feature map with the decoder feature map of the same resolution, which preserves spatial detail and improves segmentation accuracy.
   - This path determines "where" the objects are located in the image.
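The size bookkeeping in the two paths comes down to three layer types. The sketch below applies each one to a hypothetical 64-channel, 32×32 feature map (an arbitrary choice) in PyTorch's (N, C, H, W) layout, using the same kernel sizes and strides as the model later in this notebook.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)  # hypothetical feature map: (N, C, H, W)

# A 3x3 convolution with padding=1 keeps H and W unchanged.
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1)
# 2x2 max-pooling with stride 2 halves H and W (encoder step).
pool = nn.MaxPool2d(kernel_size=2, stride=2)
# A 2x2 transposed convolution with stride 2 doubles H and W (decoder step).
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

y = conv(x)
print(y.shape)  # torch.Size([1, 128, 32, 32])
z = pool(y)
print(z.shape)  # torch.Size([1, 128, 16, 16])
w = up(z)
print(w.shape)  # torch.Size([1, 64, 32, 32])
```

Pooling followed by a stride-2 transposed convolution returns exactly to the starting resolution only when H and W are even at that level, which is why the input size matters once several levels are stacked.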

### Key Features

- **Skip Connections**: These connections between corresponding encoder and decoder levels carry high-resolution features that would otherwise be lost during downsampling. They let the decoder combine coarse, global context from the deep layers with fine, local detail from the shallow layers.
  
- **Symmetry**: The decoder mirrors the encoder level by level, so every decoder stage has an encoder feature map of the same spatial size to concatenate with. This pairing is what makes the skip connections possible.
  
- **Localization and Context**: U-Net's design allows it to understand both the precise localization of objects and the broader context of the entire image, making it effective for segmentation tasks.
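Skip connections and symmetry meet in one operation: the decoder upsamples its input to the size of the matching encoder map and concatenates the two along the channel dimension. The minimal sketch below uses made-up sizes (64 and 128 channels, 32×32 and 16×16) chosen only for illustration. It also shows why, in the code below, each decoder block such as `dec4 = conv_block(1024, 512)` takes twice as many input channels as its up-convolution produces.

```python
import torch
import torch.nn as nn

# Hypothetical tensors at one decoder level: the encoder feature map saved for
# the skip connection, and the deeper features about to be upsampled.
encoder_features = torch.randn(1, 64, 32, 32)
deeper_features = torch.randn(1, 128, 16, 16)

# Upsample the deeper features to the encoder's resolution, halving the channels.
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
upsampled = up(deeper_features)  # (1, 64, 32, 32)

# Skip connection: concatenate along the channel dimension (dim=1).
# Spatial sizes must match exactly; channel counts add up.
merged = torch.cat((upsampled, encoder_features), dim=1)  # (1, 128, 32, 32)

# A convolution then fuses the two sources back down to 64 channels.
fuse = nn.Conv2d(128, 64, kernel_size=3, padding=1)
out = fuse(merged)
print(merged.shape, out.shape)
```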

### Applications

U-Net was originally developed for biomedical image segmentation but has since been widely adopted for various other applications, including:

- **Medical Imaging**: Segmenting different tissues, organs, or tumors in medical scans like MRI or CT images.
- **Satellite Imagery**: Segmenting different land types, water bodies, and built-up areas.
- **Autonomous Driving**: Identifying different objects on the road such as cars, pedestrians, and lanes.
- **Agriculture**: Segmenting crops and plants from background soil or other objects.


```python
# U-Net Model from Scratch in PyTorch
# ===================================

# Step 1: Import Libraries
import torch
import torch.nn as nn
import numpy as np

# Step 2: Define the U-Net Model
class UNet(nn.Module):
    """Four-level U-Net: encoder, bottleneck, and decoder with skip connections.

    All 3x3 convolutions use padding=1, so they keep H and W unchanged and the
    skip connections can be concatenated without cropping. Because the input
    is max-pooled four times, its height and width must be divisible by 16.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Two 3x3 convolutions, each followed by ReLU; spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def upconv_block(in_ch, out_ch):
            # 2x2 transposed convolution with stride 2 doubles H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        # Contracting path (encoder): the channel count doubles at each level.
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        # 2x2 max-pooling halves H and W between levels.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = conv_block(512, 1024)

        # Expansive path (decoder): each dec block receives the upconv output
        # concatenated with the matching encoder map, hence twice the channels.
        self.upconv4 = upconv_block(1024, 512)
        self.dec4 = conv_block(1024, 512)
        self.upconv3 = upconv_block(512, 256)
        self.dec3 = conv_block(512, 256)
        self.upconv2 = upconv_block(256, 128)
        self.dec2 = conv_block(256, 128)
        self.upconv1 = upconv_block(128, 64)
        self.dec1 = conv_block(128, 64)

        # 1x1 convolution maps the 64 features at each pixel to output scores.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # The print calls trace how the tensor shape changes at each stage.
        print(f"Input: {x.shape}")

        # Encoder: keep each level's output for the skip connections.
        enc1 = self.enc1(x)
        print(f"After enc1: {enc1.shape}")
        enc2 = self.enc2(self.pool(enc1))
        print(f"After enc2: {enc2.shape}")
        enc3 = self.enc3(self.pool(enc2))
        print(f"After enc3: {enc3.shape}")
        enc4 = self.enc4(self.pool(enc3))
        print(f"After enc4: {enc4.shape}")

        bottleneck = self.bottleneck(self.pool(enc4))
        print(f"After bottleneck: {bottleneck.shape}")

        # Decoder: upsample, concatenate with the encoder map along the
        # channel dimension (dim=1), then fuse with a conv block.
        dec4 = self.dec4(torch.cat((self.upconv4(bottleneck), enc4), dim=1))
        print(f"After dec4: {dec4.shape}")
        dec3 = self.dec3(torch.cat((self.upconv3(dec4), enc3), dim=1))
        print(f"After dec3: {dec3.shape}")
        dec2 = self.dec2(torch.cat((self.upconv2(dec3), enc2), dim=1))
        print(f"After dec2: {dec2.shape}")
        dec1 = self.dec1(torch.cat((self.upconv1(dec2), enc1), dim=1))
        print(f"After dec1: {dec1.shape}")

        # Sigmoid converts each pixel's score to a probability (binary mask).
        output = torch.sigmoid(self.final_conv(dec1))
        print(f"Output: {output.shape}")

        return output

# Step 3: Create Synthetic Data
def create_synthetic_data(input_shape, num_samples):
    # Random "images" with values in [0, 1), laid out as (N, H, W, C).
    X = np.random.rand(num_samples, *input_shape).astype(np.float32)
    X = np.transpose(X, (0, 3, 1, 2))  # Convert to (N, C, H, W), as PyTorch expects
    return torch.tensor(X)

input_shape = (128, 128, 3)  # (H, W, C); H and W must be divisible by 16
num_samples = 1  # A single sample is enough to show the tensor sizes

X = create_synthetic_data(input_shape, num_samples)

# Step 4: Instantiate the Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)

# Step 5: Run a Forward Pass
test_input = X.to(device)
model.eval()
with torch.no_grad():  # inference only, so skip gradient tracking
    predicted_output = model(test_input)
```


```text
Input: torch.Size([1, 3, 128, 128])
After enc1: torch.Size([1, 64, 128, 128])
After enc2: torch.Size([1, 128, 64, 64])
After enc3: torch.Size([1, 256, 32, 32])
After enc4: torch.Size([1, 512, 16, 16])
After bottleneck: torch.Size([1, 1024, 8, 8])
After dec4: torch.Size([1, 512, 16, 16])
After dec3: torch.Size([1, 256, 32, 32])
After dec2: torch.Size([1, 128, 64, 64])
After dec1: torch.Size([1, 64, 128, 128])
Output: torch.Size([1, 1, 128, 128])
```
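The trace shows four halvings (128, 64, 32, 16, 8) and four doublings back to 128. Because each upsampled map is concatenated with an encoder map of the same size, the input height and width must be divisible by 16 (four halvings); a 100×100 input, for instance, would fail at the `torch.cat` feeding `dec3`, where a 24×24 upsampled map meets a 25×25 encoder map. One way to handle arbitrary sizes, sketched here as an assumption rather than part of the original notebook, is to zero-pad the input to the next multiple of 16 and crop the output back. `pad_to_multiple` and `crop_to` are hypothetical helper names, and a 1×1 `nn.Conv2d` stands in for the U-Net so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16):
    """Zero-pad the bottom and right of an (N, C, H, W) tensor so H and W
    become multiples of `multiple`. Also returns the original (H, W)."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)

def crop_to(y, size):
    """Crop an (N, C, H, W) tensor back to the given (H, W)."""
    h, w = size
    return y[..., :h, :w]

# Stand-in for the U-Net: any module that preserves H and W works here.
model = nn.Conv2d(3, 1, kernel_size=1)

x = torch.randn(1, 3, 100, 150)  # 100 and 150 are not multiples of 16
x_padded, orig_size = pad_to_multiple(x)
print(x_padded.shape)  # torch.Size([1, 3, 112, 160])
y = crop_to(model(x_padded), orig_size)
print(y.shape)         # torch.Size([1, 1, 100, 150])
```

Zero padding is the simplest choice; `F.pad(..., mode='reflect')` is a common alternative that avoids constant-valued borders, though it requires each pad amount to be smaller than the corresponding input dimension.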
