In [1]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np

from torchvision import datasets
import torchvision.transforms.v2 as transforms
#import torchvision.transforms as transforms

from torchvision.datasets import Cityscapes
from argparse import ArgumentParser

import torch.optim.lr_scheduler as lr_scheduler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import gc
import torch.nn.functional as F

import os
import random

import pdb
import collections
from matplotlib.colors import ListedColormap

import os
import albumentations as A
#import wandb
from torch.optim.lr_scheduler import StepLR

from torch.utils.data.sampler import SubsetRandomSampler

from math import ceil

# Device configuration: use CUDA when torch reports it as available, otherwise the CPU.
# Batches are moved to this device inside the training loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Prepare the data
Load the dataset and apply the transformations to the images and label masks. Then split the training set into training and validation subsets using the desired split ratio.

In [3]:
# Create the transform variable
size = 512

# Input images: resize to (size, 2*size), convert to a float tensor and normalise
# each channel with the mean/std below (the usual ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((size, size*2), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Label masks: every pixel is a class ID, not an intensity, so the mask must be
# resized with nearest-neighbour interpolation. The default (bilinear) blends
# neighbouring IDs at class boundaries and produces IDs that belong to neither class.
# The training loop multiplies the mask by 255 to recover the integer IDs.
target_transforms = transforms.Compose([
    transforms.Resize((size, size*2), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Load the dataset and apply transforms download=False, transform=transform, target_transform=target_transforms
dataset_train = datasets.Cityscapes('./data', split='train', mode='fine', target_type='semantic', transform=transform, target_transform=target_transforms)

# Split training set into training and validation sets
# (first `split` fraction of the samples for training, the remainder for validation).
split = 0.8
boundary = round(split*round(len(dataset_train)))
train_dataset = torch.utils.data.Subset(dataset_train, range(boundary))
val_dataset = torch.utils.data.Subset(dataset_train, range(boundary, len(dataset_train)))

# Create data loaders
trainloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=False)
# Validation only accumulates a loss, so shuffling adds nothing; a fixed order keeps
# the per-iteration validation log comparable between epochs.
validationloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=False)

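### Mapping function and selection of 19 classes.
Please run this cell: it maps the raw Cityscapes label IDs to the 19 training classes (train IDs 0–18) selected out of the 30 Cityscapes classes. All other labels are mapped to the ignore value 255.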

In [4]:
from collections import namedtuple
import torch

# Record type describing one Cityscapes label. Each field is documented directly
# above its name.
Label = namedtuple(
    'Label',
    [
        # The identifier of this label, e.g. 'car', 'person', ... .
        # We use them to uniquely name a class.
        'name',

        # An integer ID that is associated with this label.
        # The IDs are used to represent the label in ground truth images.
        # An ID of -1 means that this label does not have an ID and thus
        # is ignored when creating ground truth images (e.g. license plate).
        # Do not modify these IDs, since exactly these IDs are expected by the
        # evaluation server.
        'id',

        # Feel free to modify these IDs as suitable for your method. Then create
        # ground truth images with train IDs, using the tools provided in the
        # 'preparation' folder. However, make sure to validate or submit results
        # to our evaluation server using the regular IDs above!
        # For trainIds, multiple labels might have the same ID. Then, these labels
        # are mapped to the same class in the ground truth images. For the inverse
        # mapping, we use the label that is defined first in the list below.
        # For example, mapping all void-type classes to the same ID in training,
        # might make sense for some approaches.
        # Max value is 255!
        'trainId',

        # The name of the category that this label belongs to.
        'category',

        # The ID of this category. Used to create ground truth images
        # on category level.
        'categoryId',

        # Whether this label distinguishes between single instances or not.
        'hasInstances',

        # Whether pixels having this class as ground truth label are ignored
        # during evaluations or not.
        'ignoreInEval',

        # The color of this label.
        'color',
    ],
)

# Table of label definitions, one `Label` record per row (columns as in the header
# comment below). In this table the rows with ignoreInEval=False carry the train IDs
# 0..18; every other row has trainId 255, except 'license plate' which uses -1 for
# both id and trainId.
LABELS = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]

def map_id_to_train_id(label_id):
    """Translate raw Cityscapes label IDs into train IDs.

    input: Tensor of label IDs (e.g. shape (batch_size, height, width)); the
           values are compared against the `id` field of every entry in LABELS.
    output: Tensor with the same shape and dtype as the input. Each element holds
            the `trainId` of the LABELS entry whose `id` it matched; elements that
            match no entry keep the fill value 255.
    """
    # Start from "everything is 255" and overwrite only the pixels that match a label.
    mapped = torch.full_like(label_id, 255)
    for entry in LABELS:
        hits = label_id == entry.id
        mapped[hits] = entry.trainId
    return mapped

### Early stopping and training loop
This section is applicable for all the models, please run it. 

In [8]:
class EarlyStopper:
    """Decide, once per epoch, whether training should stop.

    Two families of criteria are evaluated after a warm-up of `epoch_start` epochs:

    * validation criteria - stop when the best validation loss drops below
      `val_lim`, or when the validation loss has exceeded the best value by more
      than `val_min_delta` on `val_patience` epochs since the last improvement;
    * train/validation gap criteria - stop when the training loss is more than
      `diff_lim` below the best validation loss, or has been more than
      `diff_min_delta` below it on `diff_patience` epochs (the counter is reset
      whenever the training loss is above the best validation loss).

    Both losses should be on the same scale (e.g. mean loss per batch) for the
    thresholds and the train/validation comparison to be meaningful.
    """

    def __init__(self, epoch_start=10, diff_patience=10, diff_min_delta=0.2, diff_lim = 0.35, val_patience=10, val_min_delta=0.1, val_lim = 0.3):
        """Store the thresholds and reset the internal counters.

        epoch_start: checks are only active for epochs strictly greater than this.
        diff_patience / diff_min_delta / diff_lim: settings of the gap criteria.
        val_patience / val_min_delta / val_lim: settings of the validation criteria.
        """
        self.epoch_start = epoch_start
        self.diff_patience = diff_patience
        self.diff_min_delta = diff_min_delta
        self.diff_lim = diff_lim
        self.val_patience = val_patience
        self.val_min_delta = val_min_delta
        self.val_lim = val_lim
        self.counter_val = 0
        self.counter_diff = 0
        # Best validation loss seen so far (only tracked after the warm-up).
        self.min_validation_loss = float('inf')

    def early_stop(self, epoch, running_loss, validation_loss):
        """Update the internal state with this epoch's losses.

        epoch: index of the epoch that just finished.
        running_loss: training loss of the epoch.
        validation_loss: validation loss of the epoch.
        Returns True if training should stop, False otherwise. A short reason is
        printed whenever True is returned.
        """
        # Nothing is tracked during the warm-up epochs.
        if epoch <= self.epoch_start:
            return False

        # --- validation criteria -------------------------------------------
        if validation_loss < self.min_validation_loss:
            # New best validation loss: remember it and restart the patience counter.
            self.min_validation_loss = validation_loss
            self.counter_val = 0
            if self.min_validation_loss < self.val_lim:
                print("The validation loss is below the threshold.\n")
                return True
        elif validation_loss > (self.min_validation_loss + self.val_min_delta):
            # Clearly worse than the best value: count the epoch against the patience.
            self.counter_val += 1
            if self.counter_val >= self.val_patience:
                print("The validation loss is not decreasing for multiple epochs.\n")
                return True

        # --- train/validation gap criteria ---------------------------------
        if running_loss > self.min_validation_loss:
            # Training loss still above the best validation loss: no sign of overfitting.
            self.counter_diff = 0
        elif running_loss < (self.min_validation_loss - self.diff_lim):
            # The message previously named the validation loss twice; the quantity
            # compared here is the training loss.
            print("The training loss is too far below the validation loss.\n")
            return True
        elif running_loss < (self.min_validation_loss - self.diff_min_delta):
            self.counter_diff += 1
            if self.counter_diff >= self.diff_patience:
                print("The training loss has been far below the validation loss for multiple epochs.\n")
                return True
        return False

def train_model_segmentation(model, train_loader, num_epochs=5, lr=0.01, step_size=20, val_loader=None):
    """Train a segmentation model with Adam, a StepLR schedule and early stopping.

    model: network mapping an image batch to per-class logits; it is moved to
        the notebook-level `device` in place.
    train_loader: iterable of (image, mask) batches used for optimisation.
    num_epochs: maximum number of epochs.
    lr: initial learning rate for Adam.
    step_size: the learning rate is multiplied by 0.1 every `step_size` scheduler steps.
    val_loader: iterable of (image, mask) batches used for validation. When
        omitted, the notebook-level `validationloader` is used.
    Returns None; progress is printed.
    """
    if val_loader is None:
        # Backward-compatible default: existing calls rely on the global loader.
        val_loader = validationloader

    # The batches are moved to `device`, so the parameters have to live there too.
    model = model.to(device)

    # Pixels with train ID 255 do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Set scheduler to change the learning rate over number of epochs
    scheduler = StepLR(optimizer, step_size=step_size, gamma=0.1)
    # Early stopping
    early_stopper = EarlyStopper(epoch_start=10, diff_patience=20, diff_min_delta=1, diff_lim = 0.8, val_patience=10, val_min_delta=1, val_lim = 0.3)

    def prepare_batch(data):
        """Move one (image, mask) batch to `device` and convert the mask to train IDs."""
        inputs = data[0].to(device)
        # The mask arrives scaled by 1/255 (hence the *255). Round before the integer
        # cast so float error cannot truncate an ID to the one below it.
        labels = (data[1] * 255).round().to(device).long()
        # Drop the channel axis so the target is (batch, height, width).
        labels = map_id_to_train_id(labels).squeeze(1)
        return inputs, labels

    for epoch in range(num_epochs):
        # ---- training pass ----
        # Re-enable dropout / batch-norm updates: the validation pass below
        # switches the model to eval mode at the end of every epoch.
        model.train()
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = prepare_batch(data)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            print(f'Epoch {epoch + 1}, Iteration [{i}/{len(train_loader)}], Loss: {running_loss/(i+1)}')

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                inputs, labels = prepare_batch(data)
                outputs = model(inputs)

                loss = criterion(outputs, labels)
                val_loss += loss.item()

                print(f'TEST: Epoch {epoch + 1}, Iteration [{i}/{len(val_loader)}], Loss: {val_loss/(i+1)}')

        # Mean loss per batch; max(..., 1) keeps an empty loader from dividing by zero.
        mean_train_loss = running_loss / max(len(train_loader), 1)
        mean_val_loss = val_loss / max(len(val_loader), 1)

        #visualize_segmentation(model, val_loader, device)
        print(f'Finished Train epoch [{epoch + 1}/{num_epochs}], Loss: {mean_train_loss:.4f}')
        print(f'Finished TEST epoch [{epoch + 1}/{num_epochs}], Loss: {mean_val_loss:.4f}')
        # Check if stopping is required. The means are passed (not the sums) so the
        # two losses are comparable even though the loaders have different lengths.
        if early_stopper.early_stop(epoch, mean_train_loss, mean_val_loss):
            break
        # Change the learning rate if certain number of epochs is achieved
        # (the schedule is frozen from epoch index 65 onwards).
        if (epoch < 65):
            scheduler.step()
        if ((epoch + 1) % step_size == 0):
            print(scheduler.get_last_lr(), " New Learning Rate")

## Models
All the following cells are optional to run. Each section contains different models: 
- Baseline U-net model with dropout and early stopping. 
- U-net model with the residual blocks and the attention layers. U-net with the ResNet as the backbone. 
- U-net with the efficient blocks. U-net with EfficientNet as the backbone. 

### Baseline U-net
This is a baseline model. The early stopping and the dropout are already added. These can be removed by modifying the following cell and the training function. 

In [None]:
# create Segmentation model
class SegmentationCNN(nn.Module):
    """Baseline U-Net style encoder/decoder producing per-pixel class logits.

    With f = 2**power the encoder widths are 2f, 4f, 8f, 16f, 32f, the bottleneck
    works at 64f, and the decoder mirrors the encoder with skip connections
    (channel concatenation). Each encoder stage is followed by a 2x2 max-pool and
    each decoder stage ends with a stride-2 transposed convolution, so the input
    height and width should be multiples of 32 for the skip tensors to line up.

    in_channels: number of channels of the input image.
    classes: number of output classes (channels of the returned logits).
    power: width exponent, f = 2**power.
    dropout: probability used by every Dropout2d layer; the default 0.5 matches the
        previous hard-coded value and 0.0 effectively disables dropout. The layers
        are always present, so layer indices and state_dict keys do not depend on it.
    """

    def __init__(self, in_channels=3, classes=19, power=4, dropout=0.5):
        super(SegmentationCNN, self).__init__()

        # Read by the block builders below, so it must be set before they are called.
        self.dropout_p = dropout

        # Factor of number of weights
        factor = 2**power
        c1, c2, c3, c4, c5, c6 = ((2**k)*factor for k in range(1, 7))

        # Encoder (contracting path)
        self.conv1 = self.convolve_block(in_channels, c1, kernel_size=3, stride=1, padding=1)
        self.conv2 = self.convolve_block(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv3 = self.convolve_block(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv4 = self.convolve_block(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv5 = self.convolve_block(c4, c5, kernel_size=3, stride=1, padding=1)

        # Max pool (shared by all encoder stages; it has no parameters)
        self.max = nn.MaxPool2d(2, stride=2)

        # Bottleneck
        self.bottleneck = self.bottleneck_block(external_channels=c5, internal_channels=c6, kernel_size=3, stride=1, padding=1)

        # Decoder (expanding path). Each block receives skip + upsampled features,
        # i.e. twice the channel count of the matching encoder stage.
        self.upconv4 = self.expand_block(c6, c5, kernel_size=3, stride=1, padding=1)
        self.upconv3 = self.expand_block(c5, c4, kernel_size=3, stride=1, padding=1)
        self.upconv2 = self.expand_block(c4, c3, kernel_size=3, stride=1, padding=1)
        self.upconv1 = self.expand_block(c3, c2, kernel_size=3, stride=1, padding=1)

        # Output layer
        self.conv_out = self.convolve_block(c2, c1, kernel_size=3, stride=1, padding=1)
        self.output = nn.Conv2d(c1, classes, kernel_size=1)

    def _double_conv(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return the [Conv, BatchNorm, ReLU, Dropout2d] x 2 layer list shared by all blocks."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=self.dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=self.dropout_p),
        ]

    # Convolution block
    def convolve_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Two conv/BN/ReLU/dropout stages mapping in_channels -> out_channels."""
        return nn.Sequential(*self._double_conv(in_channels, out_channels, kernel_size, stride, padding))

    # Bottleneck block
    def bottleneck_block(self, external_channels, internal_channels, kernel_size=3, stride=1, padding=1):
        """Double conv at internal_channels, then a 2x transposed conv back to external_channels."""
        layers = self._double_conv(external_channels, internal_channels, kernel_size, stride, padding)
        layers.append(nn.ConvTranspose2d(internal_channels, external_channels, kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    # Upconvolve block
    def expand_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Double conv to out_channels, then a 2x transposed conv that halves the channels."""
        layers = self._double_conv(in_channels, out_channels, kernel_size, stride, padding)
        layers.append(nn.ConvTranspose2d(out_channels, out_channels=out_channels//2, kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with `classes` channels at the spatial size of `x`."""
        # Encoder (contracting path): keep every pre-pool feature map for the skips.
        conv1 = self.conv1(x)
        x = self.max(conv1)
        conv2 = self.conv2(x)
        x = self.max(conv2)
        conv3 = self.conv3(x)
        x = self.max(conv3)
        conv4 = self.conv4(x)
        x = self.max(conv4)
        conv5 = self.conv5(x)
        x = self.max(conv5)

        # Bottleneck (already upsampled back to the resolution of conv5)
        bottleneck = self.bottleneck(x)

        # Decoder (expanding path): concatenate skip and upsampled features on the channel axis.
        upconv4 = self.upconv4(torch.cat([conv5, bottleneck], dim=1))
        upconv3 = self.upconv3(torch.cat([conv4, upconv4], dim=1))
        upconv2 = self.upconv2(torch.cat([conv3, upconv3], dim=1))
        upconv1 = self.upconv1(torch.cat([conv2, upconv2], dim=1))

        # Output layer
        output = self.conv_out(torch.cat([conv1, upconv1], dim=1))
        output = self.output(output)
        return output

# Define the model: baseline U-Net, 3 input channels, 19 output classes, width factor 2**2.
model = SegmentationCNN(in_channels=3, classes=19, power=2)

# Total parameters and trainable parameters.
# Collect (element count, trainable?) once and derive both totals from it.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_sizes)
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(count for count, trainable in param_sizes if trainable)
print(f"{total_trainable_params:,} training parameters.")

# Set the weights initialization
def init_weights(m):
    """Initialise one module; meant to be passed to `model.apply`.

    nn.Conv2d: Xavier-uniform weights scaled by the ReLU gain, zero bias.
    nn.Linear: all weights set to 1, zero bias.
    Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(m.weight, gain=relu_gain)
    elif isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 1)
    else:
        return
    # Shared by both handled layer types; layers built with bias=False have no bias.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Apply the weights and biases before training
# (`apply` visits every sub-module and calls init_weights on it).
model.apply(init_weights)

# Set parameters
nb_epochs = 3         # number of training epochs
learning_rate = 0.01  # 0.01 for 20
step_size = 20        # epochs between learning-rate decays

# Train the model
train_model_segmentation(
    model,
    trainloader,
    num_epochs=nb_epochs,
    lr=learning_rate,
    step_size=step_size,
)


### U-net with the ResNet as a backbone + Attention layers
The main difference from the conventional U-net is that the encoder uses a ResidualBlock in order to preserve the gradient. 

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-ReLU, conv-BN, add the shortcut, then ReLU.

    in_channels / out_channels: channel counts of the input and output.
    stride: stride of the first 3x3 convolution.
    downsample: optional module applied to the input to build the shortcut; when
        it is not given the input itself is added, so its shape has to match the
        output of the second convolution.
    """

    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        # First stage: the only place where `stride` and the channel change are applied.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm2d(out_channels), nn.ReLU())
        # Second stage: no activation here, it comes after the shortcut is added.
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        """Return relu(conv2(conv1(x)) + shortcut(x))."""
        shortcut = self.downsample(x) if self.downsample else x
        out = self.conv2(self.conv1(x))
        out += shortcut
        return self.relu(out)

class URes(nn.Module):
    """U-Net whose encoder is built from residual blocks, with gated skip connections.

    Notation used in the comments: f = 2**power.

    Encoder: a stem convolution (f channels) followed by four residual stages of
    2f, 4f, 8f and 16f channels, each followed by a 2x2 max-pool. The bottleneck
    works at 32f channels. The decoder mirrors the encoder; every skip feature map
    is multiplied by a single-channel gate before it is concatenated with the
    upsampled decoder features.

    Gate for one level: relu(x_att(skip) + g_att(decoder)) -> conv_psi -> sigmoid ->
    upconv (2x transposed conv). `conv_psi`, `upconv` and the activations are shared
    by all four levels; the gate projections use a fixed width of 128 channels.
    NOTE(review): the transposed conv is applied after the sigmoid, so the gate
    values are not restricted to [0, 1] - confirm this is intended.

    The constructor relies on `self.inplanes` as running state: `_make_layer` and
    the three block builders overwrite it with the channel count they produce, and
    the layers declared right after each call read it. The statement order in
    `__init__` therefore must not be changed.

    Input height and width presumably need to be multiples of 16 so that pooled,
    strided and upsampled maps have matching sizes - TODO confirm.

    block: class used for the residual blocks.
    blocks: number of residual blocks in each of the four encoder stages (the
        default list is only read, never modified).
    in_channels: number of channels of the input image.
    classes: number of output classes (channels of the returned logits).
    power: width exponent, f = 2**power.
    """
    def __init__(self, block = ResidualBlock, blocks = [2, 2, 2, 2], in_channels = 3, classes = 19, power = 5):
        super(URes, self).__init__()
        factor = 2**power
        # Running "current channel count"; updated by the builder methods below.
        self.inplanes = factor
        # First layer for edges
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, self.inplanes, kernel_size = 3, stride = 1, padding = 1),
                        nn.BatchNorm2d(self.inplanes),
                        nn.ReLU())
        # Maxpooling 
        # `maxpool` (3x3, stride 1) keeps the resolution; `max` (2x2, stride 2) halves it.
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
        self.max = nn.MaxPool2d(2, stride=2)

        # Encoder
        # Each _make_layer call doubles self.inplanes; the x_att* projection declared
        # right after it therefore takes that stage's output channels and, with
        # stride 2, halves the resolution of the skip features.
        self.layer0 = self._make_layer(block, self.inplanes*2, blocks[0], stride = 1)
        self.x_att0 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.inplanes*2, blocks[1], stride = 1)
        self.x_att1 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.inplanes*2, blocks[2], stride = 1)
        self.x_att2 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=2, padding=0)
        self.layer3 = self._make_layer(block, self.inplanes*2, blocks[3], stride = 1)
        self.x_att3 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=2, padding=0)

        # Bottleneck
        # bottleneck_block sets self.inplanes to its internal channel count (32f).
        self.bottleneck = self.bottleneck_block(external_channels=self.inplanes, internal_channels=self.inplanes*2, kernel_size=3,stride=1,padding=1)
        self.g_att3 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.upB = nn.ConvTranspose2d(self.inplanes, self.inplanes//2, kernel_size=2, stride=2)
        # Decoder (expanding path)
        # Each expand_block halves self.inplanes; g_att* / up* read the halved value.
        self.upconv3 = self.expand_block(self.inplanes, self.inplanes//2, kernel_size=3, stride=1, padding=1) 
        self.g_att2 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.up3 = nn.ConvTranspose2d(self.inplanes, self.inplanes//2, kernel_size=2, stride=2)

        self.upconv2 = self.expand_block(self.inplanes, self.inplanes//2, kernel_size=3, stride=1, padding=1)  # Adjusting input channels
        self.g_att1 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.up2 = nn.ConvTranspose2d(self.inplanes, self.inplanes//2, kernel_size=2, stride=2)

        self.upconv1 = self.expand_block(self.inplanes, self.inplanes//2, kernel_size=3, stride=1, padding=1)  # Adjusting input channels
        self.g_att0 = nn.Conv2d(in_channels=self.inplanes, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.up1 = nn.ConvTranspose2d(self.inplanes, self.inplanes//2, kernel_size=2, stride=2)

         # Output layer
        # convolve_block sets self.inplanes to 2f, which is the input width of `output`.
        self.conv_out = self.convolve_block(self.inplanes, self.inplanes//2, kernel_size=3, stride=1, padding=1)
        self.output = nn.Conv2d(self.inplanes, classes, kernel_size=1)

        # Attention functions
        # Shared by all four gated skip connections.
        self.activateRelu = nn.ReLU(inplace=True)
        self.activateSig = nn.Sigmoid()
        self.conv_psi = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.upconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
        
    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one encoder stage of `num_blocks` residual blocks with `planes` output channels.

        The first block maps self.inplanes -> planes (with a 1x1 conv + BN shortcut
        when the channel counts differ); the remaining blocks keep `planes`.
        Side effect: sets self.inplanes = planes.
        """
        downsample = None
        if self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def convolve_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Two conv/BN/ReLU stages mapping in_channels -> out_channels.

        Side effect: sets self.inplanes = out_channels.
        """
        self.inplanes = out_channels
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def bottleneck_block(self, external_channels, internal_channels, kernel_size=3, stride=1, padding=1):
        """Two conv/BN/ReLU stages mapping external_channels -> internal_channels.

        Side effect: sets self.inplanes = internal_channels.
        """
        self.inplanes = internal_channels
        return nn.Sequential(
            nn.Conv2d(external_channels, internal_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(internal_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(internal_channels),
            nn.ReLU(inplace=True),
        )
    
    def expand_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Two conv/BN/ReLU stages mapping in_channels -> out_channels (decoder stage).

        Upsampling is not part of this block; it is done by the separate up* layers.
        Side effect: sets self.inplanes = out_channels.
        """
        self.inplanes = out_channels
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        """Return logits with `classes` channels at the spatial size of `x`.

        Comments give channel count (f = 2**power) and resolution relative to the input.
        """
        # Encoder
        conv0 = self.conv1(x)       # f channels, full resolution
        x = self.maxpool(conv0)     # full resolution (stride-1 pooling)
        conv1 = self.layer0(x)      # 2f channels, full resolution
        x_att0 = self.x_att0(conv1) # 128 channels, 1/2 resolution
        x = self.max(conv1)         # 1/2 resolution
        conv2 = self.layer1(x)      # 4f channels, 1/2 resolution
        x_att1 = self.x_att1(conv2) # 128 channels, 1/4 resolution
        x = self.max(conv2)         # 1/4 resolution
        conv3 = self.layer2(x)      # 8f channels, 1/4 resolution
        x_att2 = self.x_att2(conv3) # 128 channels, 1/8 resolution
        x = self.max(conv3)         # 1/8 resolution
        conv4 = self.layer3(x)      # 16f channels, 1/8 resolution
        x_att3 = self.x_att3(conv4) # 128 channels, 1/16 resolution
        x = self.max(conv4)         # 1/16 resolution

        # Bottleneck
        bottleneck = self.bottleneck(x) # 32f channels, 1/16 resolution
        g_att3 = self.g_att3(bottleneck) # 128 channels, 1/16 resolution
        bottleneck = self.upB(bottleneck) # 16f channels, 1/8 resolution
        # Gate conv4: the single-channel map is upsampled to 1/8 resolution and broadcast over channels.
        y = conv4 * self.upconv(self.activateSig(self.conv_psi(self.activateRelu(x_att3 + g_att3))))

        upconv3 = self.upconv3(torch.cat([y, bottleneck], dim=1)) # 16f channels, 1/8 resolution
        g_att2 = self.g_att2(upconv3)
        upconv3 = self.up3(upconv3) # 8f channels, 1/4 resolution
        y = conv3 * self.upconv(self.activateSig(self.conv_psi(self.activateRelu(x_att2 + g_att2))))

        upconv2 = self.upconv2(torch.cat([y, upconv3], dim=1)) # 8f channels, 1/4 resolution
        g_att1 = self.g_att1(upconv2)
        upconv2 = self.up2(upconv2) # 4f channels, 1/2 resolution
        y = conv2 * self.upconv(self.activateSig(self.conv_psi(self.activateRelu(x_att1 + g_att1))))

        upconv1 = self.upconv1(torch.cat([y, upconv2], dim=1)) # 4f channels, 1/2 resolution
        g_att0 = self.g_att0(upconv1)
        upconv1 = self.up1(upconv1) # 2f channels, full resolution
        y = conv1 * self.upconv(self.activateSig(self.conv_psi(self.activateRelu(x_att0 + g_att0))))

        # Output
        output = self.conv_out(torch.cat([y, upconv1], dim=1)) # 2f channels, full resolution
        output = self.output(output)                               # `classes` channels, full resolution

        return output

# Define the model: residual U-Net with gated skips, two residual blocks per encoder
# stage, 3 input channels, 19 output classes, width factor 2**3.
model = URes(block=ResidualBlock, blocks=[2, 2, 2, 2], in_channels=3, classes=19, power=3)

# Total parameters and trainable parameters.
# Collect (element count, trainable?) once and derive both totals from it.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_sizes)
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(count for count, trainable in param_sizes if trainable)
print(f"{total_trainable_params:,} training parameters.")

# Set the weights initialization
def init_weights(m):
    """Initialise one module; meant to be passed to `model.apply`.

    nn.Conv2d: Xavier-uniform weights scaled by the ReLU gain, zero bias.
    nn.Linear: all weights set to 1, zero bias.
    Any other module type is left untouched.
    """
    is_conv = isinstance(m, nn.Conv2d)
    is_linear = isinstance(m, nn.Linear)
    if not (is_conv or is_linear):
        return
    if is_conv:
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
    if is_linear:
        nn.init.constant_(m.weight, 1)
    # Layers built with bias=False have no bias tensor to reset.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Apply the weights and biases before training
# (`apply` visits every sub-module and calls init_weights on it).
model.apply(init_weights)

# Set parameters
nb_epochs = 3         # number of training epochs
learning_rate = 0.01  # 0.01 for 20
step_size = 20        # epochs between learning-rate decays

# Train the model
train_model_segmentation(
    model,
    trainloader,
    num_epochs=nb_epochs,
    lr=learning_rate,
    step_size=step_size,
)

### U-net with the EfficientNet as a backbone.
The main difference from the conventional U-net is that the encoder uses an MBConv block.

In [None]:
# Model
class SqueezeExcitation(nn.Module):
    """Channel-wise gating: rescale each channel of the input by a learned weight.

    The weights come from global average pooling followed by two 1x1 convolutions
    (input_channels -> reduced_dim -> input_channels) with SiLU in between and a
    final sigmoid.
    """

    def __init__(self, input_channels, reduced_dim):
        super().__init__()
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),                    # one value per channel
            nn.Conv2d(input_channels, reduced_dim, 1),  # squeeze
            nn.SiLU(),  # SiLU activation
            nn.Conv2d(reduced_dim, input_channels, 1),  # excite
            nn.Sigmoid(),
        ]
        self.se = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return `x` multiplied by its per-channel gate (broadcast over height and width)."""
        channel_weights = self.se(x)
        return x * channel_weights

class MBConv(nn.Module):
    """Inverted-bottleneck block with a depthwise convolution and squeeze-excitation.

    Pipeline: 1x1 expansion -> BN -> SiLU -> 3x3 depthwise conv -> BN -> SiLU
    -> squeeze-excitation -> 1x1 projection -> BN. The input is added back to
    the result when the block keeps both the channel count and the resolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        expansion_factor: Multiplier giving the width of the depthwise stage
            (``in_channels * expansion_factor``). With a factor of 1 the
            expansion convolution is replaced by ``nn.Identity``; the batch
            norm and activation that follow it are still applied.
        stride: Stride of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, expansion_factor, stride):
        super().__init__()
        mid_channels = in_channels * expansion_factor

        # The skip connection is only valid when input and output shapes match.
        self.use_residual = stride == 1 and in_channels == out_channels

        if expansion_factor != 1:
            self.expand_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        else:
            self.expand_conv = nn.Identity()
        self.bn0 = nn.BatchNorm2d(mid_channels)

        # groups == channels makes this a depthwise convolution (one filter per channel).
        self.depthwise_conv = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=mid_channels,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)

        # Squeeze width: the expanded width divided back by the expansion factor.
        squeeze_width = int(mid_channels / expansion_factor)
        self.se_layer = SqueezeExcitation(mid_channels, reduced_dim=squeeze_width)

        self.project_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.SiLU()

    def forward(self, x):
        """Run the block on ``x`` and return the (optionally residual) result."""
        skip = x
        out = self.activation(self.bn0(self.expand_conv(x)))
        out = self.activation(self.bn1(self.depthwise_conv(out)))
        out = self.project_conv(self.se_layer(out))
        out = self.bn2(out)
        if not self.use_residual:
            return out
        out += skip
        return out

class UNetWithMBConv(nn.Module):
    """U-Net whose encoder stages are MBConv blocks.

    The encoder has four MBConv stages, each followed by a 2x2 max-pool. A
    convolutional bottleneck ends with a transposed convolution that doubles
    the resolution again. The decoder alternates skip-connection
    concatenation, a double 3x3 convolution and a 2x transposed convolution;
    a final 1x1 convolution yields the per-class maps.

    Because of the four poolings, the input height and width need to be
    divisible by 16 for the skip connections to line up with the upsampled maps.

    Args:
        in_channels: Channels of the input image.
        num_classes: Number of channels produced by the final 1x1 convolution.
        initial_power: The first encoder stage outputs ``2**initial_power``
            channels; every deeper level doubles that width.
    """

    def __init__(self, in_channels=3, num_classes=19, initial_power=5):
        super().__init__()
        self.factor = 2**initial_power
        # Channel width at each resolution level (level 4 is the bottleneck interior).
        c0, c1, c2, c3, c4 = (self.factor * 2**level for level in range(5))

        # Stem: maps the raw input channels to the base width without expansion.
        self.encoder0 = MBConv(in_channels=in_channels, out_channels=c0, expansion_factor=1, stride=1)

        # Encoder stages: widths double here; the resolution is halved by self.max in forward().
        self.encoder1 = MBConv(in_channels=c0, out_channels=c1, expansion_factor=6, stride=1)
        self.encoder2 = MBConv(in_channels=c1, out_channels=c2, expansion_factor=6, stride=1)
        self.encoder3 = MBConv(in_channels=c2, out_channels=c3, expansion_factor=6, stride=1)

        # Bottleneck (ends with a 2x upsampling back to the encoder3 resolution)
        self.bottleneck = self.bottleneck_block(
            external_channels=c3,
            internal_channels=c4,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Max pooling shared by all encoder levels
        self.max = nn.MaxPool2d(2, stride=2)

        # Decoder: each double_conv consumes [upsampled, skip], hence twice the output width.
        self.decoder3 = self.double_conv(c4, c3)
        self.upconv2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.decoder2 = self.double_conv(c3, c2)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.decoder1 = self.double_conv(c2, c1)
        self.upconv0 = nn.ConvTranspose2d(c1, c0, 2, stride=2)
        self.decoder0 = self.double_conv(c1, c0)

        # Final 1x1 classifier
        self.final_conv = nn.Conv2d(c0, num_classes, 1)

    def double_conv(self, in_channels, out_channels):
        """Build two successive 3x3 convolutions (padding 1), each followed by ReLU."""
        layers = []
        for width_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(width_in, out_channels, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def bottleneck_block(self, external_channels, internal_channels, kernel_size=3, stride=1, padding=1):
        """Build two conv-BN-ReLU stages followed by a 2x transposed convolution.

        The stages widen ``external_channels`` to ``internal_channels``; the
        transposed convolution maps back to ``external_channels``.
        """
        stages = []
        for width_in in (external_channels, internal_channels):
            stages += [
                nn.Conv2d(width_in, internal_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(internal_channels),
                nn.ReLU(inplace=True),
            ]
        stages.append(nn.ConvTranspose2d(internal_channels, external_channels, kernel_size=2, stride=2))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return the class-score maps for ``x`` at the input resolution."""
        # Encoder path: keep every stage output for the skip connections.
        skip0 = self.encoder0(x)               # size//(2**0)
        skip1 = self.encoder1(self.max(skip0))  # size//(2**1)
        skip2 = self.encoder2(self.max(skip1))  # size//(2**2)
        skip3 = self.encoder3(self.max(skip2))  # size//(2**3)

        # Bottleneck: pooled to size//(2**4), upsampled back to size//(2**3) inside the block.
        up = self.bottleneck(self.max(skip3))

        # Decoder path: concatenate with the matching skip, refine, then upsample.
        up = self.decoder3(torch.cat([up, skip3], dim=1))                # size//(2**3)
        up = self.decoder2(torch.cat([self.upconv2(up), skip2], dim=1))  # size//(2**2)
        up = self.decoder1(torch.cat([self.upconv1(up), skip1], dim=1))  # size//(2**1)
        up = self.decoder0(torch.cat([self.upconv0(up), skip0], dim=1))  # size//(2**0)

        # Output layer
        return self.final_conv(up)
    
def init_weights(m):
    """Initialise the weights and biases of a single layer in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    sub-module of a model.

    * ``nn.Conv2d``: Xavier-uniform weights scaled by the ReLU gain.
    * ``nn.Linear``: Xavier-uniform weights (unit gain).
    * Biases of both layer types, when present, are set to zero.

    Any other module type is left untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Conv2d):
        # Xavier keeps the activation variance roughly constant from layer to layer;
        # the ReLU gain compensates for the half of the activations that ReLU zeroes.
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
    elif isinstance(m, nn.Linear):
        # A constant fill would make every output unit identical (same forward value,
        # same gradient), so the layer could never break symmetry. Use Xavier instead.
        nn.init.xavier_uniform_(m.weight)
    else:
        return

    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Build the MBConv U-Net: RGB input, 19 classes, base width 2**3 channels.
model = UNetWithMBConv(in_channels=3, num_classes=19, initial_power=3)

# Report the model size: every parameter first, then only those that receive gradients.
total_params = sum(weight.numel() for weight in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"{total_trainable_params:,} training parameters.")

# Initialise every layer before training (nn.Module.apply visits all sub-modules).
model.apply(init_weights)

# Training hyper-parameters
nb_epochs, step_size = 3, 20
learning_rate = 0.01  # original note: "0.01 for 20" (presumably a 20-epoch run; TODO confirm)

# Launch training on the training loader
train_model_segmentation(model, trainloader, nb_epochs, learning_rate, step_size)