# U-Net

U-Net is a Deep Learning architecture used for semantic segmentation tasks in image analysis.

First intorduced in a paper titled in a paper titled "U-Net: Convolutional Networks for Biomedical Image Segmentation".

## Architecture

![Unet Diagram](./../../resources/images/unet1.png)

<b>Encoder(Contraction Path)</b>

The encoder is a series of convolutional and pooling layers that progressively downsample the input image to extract features at multiple scales.

In the Encoder, the size of the image is gradually reduced while the depth gradually increases. The subsequent feature maps of the convolutons learn the general features of the image.

<b>Decoder(Expansion Path)</b>

The decoder consists of a series of convolutional and upsampling layers that upsample the feature maps to the original input image size while also incorporating the high-resolution features from the encoder. This allows the decoder to produce segmentation masks that have the same size as the original input image.


In the Decoder, the size of the image gradually increases while the depth gradually decreases. By upsampling, the general fatures of the image learned by the convolutional layers are "redistributed" back in an image of the same size.

In [1]:
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
from torch.utils.data import Subset

In [2]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder
        # In the encoder, convolutional layers with the Conv2d function are used to extract features from the input image. 
        # Each block in the encoder consists of two convolutional layers followed by a max-pooling layer, 
        # with the exception of the last block which does not include a max-pooling layer.
        # -------
        # input: 572x572x3
        self.e11 = nn.Conv2d(self.n_channels, 64, kernel_size=3, padding='same') # output: 570x570x64
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding='same') # output: 568x568x64
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 284x284x64

        # input: 284x284x64
        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding='same') # output: 282x282x128
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding='same') # output: 280x280x128
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 140x140x128

        # input: 140x140x128
        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding='same') # output: 138x138x256
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding='same') # output: 136x136x256
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 68x68x256

        # input: 68x68x256
        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding='same') # output: 66x66x512
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding='same') # output: 64x64x512
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 32x32x512

        # input: 32x32x512
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding='same') # output: 30x30x1024
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding='same') # output: 28x28x1024

        # Decoder
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) # output: 56x56x512
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding='same') # output: 54x54x512
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding='same') # output: 52x52x512

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) # output: 104x104x256
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding='same') # output: 102x102x256
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding='same') # output: 100x100x256

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) # output: 200x200x128
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding='same') # output: 198x198x128
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding='same') # output: 196x196x128

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) # output: 392x392x64
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding='same') # output: 390x390x64
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding='same') # output: 388x388x64

        # Output layer
        self.outconv = nn.Conv2d(64, n_classes, kernel_size=1) # output: 388x388xn_classes

    def forward(self, x):
        # Encoder
        xe11 = relu(self.e11(x))
        #print(xe11.shape)
        xe12 = relu(self.e12(xe11))
        #print(xe12.shape)
        xp1 = self.pool1(xe12)
        #print(xp1.shape)

        xe21 = relu(self.e21(xp1))
        #print(xe21.shape)
        xe22 = relu(self.e22(xe21))
        #print(xe22.shape)
        xp2 = self.pool2(xe22)
        #print(xp2.shape)

        xe31 = relu(self.e31(xp2))
        #print(xe31.shape)
        xe32 = relu(self.e32(xe31))
        #print(xe32.shape)
        xp3 = self.pool3(xe32)
        #print(xp3.shape)

        xe41 = relu(self.e41(xp3))
        #print(xe41.shape)
        xe42 = relu(self.e42(xe41))
        #print(xe42.shape)
        xp4 = self.pool4(xe42)
        #print(xp4.shape)

        xe51 = relu(self.e51(xp4))
        #print(xe51.shape)
        xe52 = relu(self.e52(xe51))
        #print(xe52.shape)

        # Decoder
        xu1 = self.upconv1(xe52)
        #print(xu1.shape)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = relu(self.d11(xu11))
        xd12 = relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = relu(self.d21(xu22))
        xd22 = relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = relu(self.d31(xu33))
        xd32 = relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = relu(self.d41(xu44))
        xd42 = relu(self.d42(xd41))

        # Output layer
        out = self.outconv(xd42)

        return out

    


## Training

In [3]:
from data_loading import BasicDataset, CarvanaDataset
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import wandb
import logging
from pathlib import Path
from dice_score import dice_loss
from evaluate import evaluate
import torch.nn.functional as F
import os

In [4]:
dir_img = Path('/Users/mp/viscode-github/datasets/kaggle/carvana-image-masking-challenge/train')
dir_mask = Path('/Users/mp/viscode-github/datasets/kaggle/carvana-image-masking-challenge/train_masks')
dir_checkpoint = Path('./checkpoints/')

In [5]:
def train(
        model,
        dataset: BasicDataset,
        device,
        epochs: int = 5,
        batch_size: int = 32,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        #save_checkpoint: bool = True,
        img_scale: float = 0.5,
        #amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        #gradient_clipping: float = 1.0,
):
    
    dataset = dataset
    print('split')
    train_fraction = 0.1
    # 2. Split into train / validation partitions
    indices = torch.randperm(len(dataset)) # Select 700 random samples
    small_dataset = Subset(dataset, indices)
    n_train = int(len(small_dataset) * (1 - val_percent))
    n_val = len(small_dataset) - n_train
    
    train_set, val_set = random_split(small_dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    print(f'len train_set: {len(train_set)}')
    print(f'len val_set: {len(val_set)}')

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count())
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score

    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    print('begin training')
    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last, non_blocking=True)
                true_masks = true_masks.to(device=device, dtype=torch.long, non_blocking=True)

                masks_pred = model(images)

                loss = criterion(masks_pred, true_masks)
                loss += dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True
                )

                optimizer.zero_grad(set_to_none=True)

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        print('Evaluating...')
                        val_score = evaluate(model, val_loader, device)
                        scheduler.step(val_score)
                        print('Validation Dice score: {}'.format(val_score))

In [6]:
device = torch.device('mps' if torch.mps.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=2)
model = model.to(memory_format=torch.channels_last)
model.to(device=device)

UNet(
  (e11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (e12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e21): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (e22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e31): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (e32): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e41): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (e42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (e51): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1

In [7]:
img_scale = 0.25
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)

100%|██████████| 5088/5088 [00:12<00:00, 399.06it/s]


In [8]:
train(model, dataset, device)

split
len train_set: 4579
len val_set: 509
begin training


Epoch 1/5:  20%|█▉        | 896/4579 [05:09<15:18,  4.01img/s, loss (batch)=1.13]

Evaluating...


Epoch 1/5:  20%|██        | 928/4579 [05:47<37:02,  1.64img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 1/5:  39%|███▉      | 1792/4579 [08:28<08:33,  5.42img/s, loss (batch)=1.13]

Evaluating...


Epoch 1/5:  40%|███▉      | 1824/4579 [09:04<24:27,  1.88img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 1/5:  59%|█████▊    | 2688/4579 [11:59<06:27,  4.88img/s, loss (batch)=1.14]

Evaluating...


Epoch 1/5:  59%|█████▉    | 2720/4579 [12:34<16:47,  1.85img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 1/5:  78%|███████▊  | 3584/4579 [18:26<04:37,  3.59img/s, loss (batch)=1.14]

Evaluating...


Epoch 1/5:  79%|███████▉  | 3616/4579 [19:01<09:18,  1.72img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 1/5:  98%|█████████▊| 4480/4579 [21:20<00:20,  4.82img/s, loss (batch)=1.13]

Evaluating...


Epoch 1/5:  99%|█████████▊| 4512/4579 [21:59<00:36,  1.82img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 1/5: 100%|██████████| 4579/4579 [22:19<00:00,  3.42img/s, loss (batch)=1.13]
Epoch 2/5:  17%|█▋        | 768/4579 [03:39<14:58,  4.24img/s, loss (batch)=1.13]

Evaluating...


Epoch 2/5:  17%|█▋        | 800/4579 [04:12<33:43,  1.87img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 2/5:  36%|███▋      | 1664/4579 [06:08<05:39,  8.58img/s, loss (batch)=1.13]

Evaluating...


Epoch 2/5:  37%|███▋      | 1696/4579 [06:39<19:03,  2.52img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 2/5:  56%|█████▌    | 2560/4579 [08:16<03:18, 10.16img/s, loss (batch)=1.13]

Evaluating...


Epoch 2/5:  57%|█████▋    | 2592/4579 [08:47<12:43,  2.60img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 2/5:  75%|███████▌  | 3456/4579 [12:19<04:40,  4.01img/s, loss (batch)=1.14]

Evaluating...


Epoch 2/5:  76%|███████▌  | 3488/4579 [12:55<10:18,  1.76img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 2/5:  95%|█████████▌| 4352/4579 [16:17<00:51,  4.41img/s, loss (batch)=1.14]

Evaluating...


Epoch 2/5:  96%|█████████▌| 4384/4579 [16:57<01:54,  1.71img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 2/5: 100%|██████████| 4579/4579 [17:27<00:00,  4.37img/s, loss (batch)=1.12]
Epoch 3/5:  14%|█▍        | 640/4579 [02:47<11:17,  5.82img/s, loss (batch)=1.13]

Evaluating...


Epoch 3/5:  15%|█▍        | 672/4579 [03:21<32:04,  2.03img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 3/5:  34%|███▎      | 1536/4579 [05:12<05:02, 10.06img/s, loss (batch)=1.13]

Evaluating...


Epoch 3/5:  34%|███▍      | 1568/4579 [05:55<25:09,  1.99img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 3/5:  53%|█████▎    | 2432/4579 [09:37<06:36,  5.42img/s, loss (batch)=1.14]

Evaluating...


Epoch 3/5:  54%|█████▍    | 2464/4579 [10:16<18:59,  1.86img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 3/5:  73%|███████▎  | 3328/4579 [13:57<03:28,  6.01img/s, loss (batch)=1.13]

Evaluating...


Epoch 3/5:  73%|███████▎  | 3360/4579 [14:31<10:18,  1.97img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 3/5:  92%|█████████▏| 4224/4579 [16:51<01:08,  5.19img/s, loss (batch)=1.13]

Evaluating...


Epoch 3/5:  93%|█████████▎| 4256/4579 [17:32<03:04,  1.75img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 3/5: 100%|██████████| 4579/4579 [18:19<00:00,  4.16img/s, loss (batch)=1.13]
Epoch 4/5:  11%|█         | 512/4579 [02:13<17:12,  3.94img/s, loss (batch)=1.13] 

Evaluating...


Epoch 4/5:  12%|█▏        | 544/4579 [03:01<48:27,  1.39img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 4/5:  31%|███       | 1408/4579 [05:35<08:49,  5.99img/s, loss (batch)=1.13]

Evaluating...


Epoch 4/5:  31%|███▏      | 1440/4579 [06:11<26:12,  2.00img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 4/5:  50%|█████     | 2304/4579 [07:43<03:36, 10.52img/s, loss (batch)=1.13]

Evaluating...


Epoch 4/5:  51%|█████     | 2336/4579 [08:13<13:50,  2.70img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 4/5:  70%|██████▉   | 3200/4579 [10:50<04:14,  5.41img/s, loss (batch)=1.13]

Evaluating...


Epoch 4/5:  71%|███████   | 3232/4579 [11:28<11:56,  1.88img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 4/5:  89%|████████▉ | 4096/4579 [13:20<00:46, 10.38img/s, loss (batch)=1.13]

Evaluating...


Epoch 4/5:  90%|█████████ | 4128/4579 [13:50<02:48,  2.68img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 4/5: 100%|██████████| 4579/4579 [14:43<00:00,  5.18img/s, loss (batch)=1.14]
Epoch 5/5:   8%|▊         | 384/4579 [01:27<12:46,  5.47img/s, loss (batch)=1.13]

Evaluating...


Epoch 5/5:   9%|▉         | 416/4579 [02:04<36:47,  1.89img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 5/5:  28%|██▊       | 1280/4579 [04:55<09:53,  5.56img/s, loss (batch)=1.13]

Evaluating...


Epoch 5/5:  29%|██▊       | 1312/4579 [05:33<28:19,  1.92img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 5/5:  48%|████▊     | 2176/4579 [08:17<07:44,  5.17img/s, loss (batch)=1.14]

Evaluating...


Epoch 5/5:  48%|████▊     | 2208/4579 [08:52<20:35,  1.92img/s, loss (batch)=1.14]

Validation Dice score: 3.2454001291926104e-11


Epoch 5/5:  67%|██████▋   | 3072/4579 [11:33<04:28,  5.61img/s, loss (batch)=1.13]

Evaluating...


Epoch 5/5:  68%|██████▊   | 3104/4579 [12:09<12:26,  1.98img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 5/5:  87%|████████▋ | 3968/4579 [14:54<02:02,  4.97img/s, loss (batch)=1.13]

Evaluating...


Epoch 5/5:  87%|████████▋ | 4000/4579 [15:32<05:18,  1.82img/s, loss (batch)=1.13]

Validation Dice score: 3.2454001291926104e-11


Epoch 5/5: 100%|██████████| 4579/4579 [16:39<00:00,  4.58img/s, loss (batch)=1.13]


In [1]:
import gc
import torch
torch.mps.empty_cache()
gc.collect()

20

In [64]:
train_model(
            model=model,
            device=device,
        )

100%|██████████| 5088/5088 [00:12<00:00, 422.28it/s]


device train: cpu


Epoch 1/1:   0%|          | 0/4580 [00:07<?, ?img/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 119 for tensor number 1 in the list.