In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
import gc
# Free memory left over from a previous run of this notebook.
# Collect unreachable Python objects first so that any tensors they still hold
# are handed back to PyTorch's caching allocator; only then can empty_cache()
# release those cached blocks.
gc.collect()
torch.cuda.empty_cache()

0

# U-Net

 - Originally introduced as an image segmentation tool for detecting tumours
 - U-Net is computationally less expensive and minimises information loss, compared to predecessors

## Key Operations

### Convolutions

 - Retains the influence of all input pixels but keeps them only loosely connected to reduce computation cost
 - Passes a filter matrix K over the image
 - Consider an input matrix with dimensions (Height x Width x Depth/Channels) A x B x C
 - Consider a filter matrix with dimensions (Height, Width, Depth (same as image), Number of filters) f x f x C x G
 - We have padding p and stride s
 - The output matrix will have dimensions H x W x G
   - H = $\lfloor\frac{A+2p-f}{s}+1\rfloor$

   - W = $\lfloor\frac{B+2p-f}{s}+1\rfloor$

### Transposed Convolutions (up-conv)

 - Transposed convolutions upscale images (compared to standard convolutions which reduce resolution)
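 - This is achieved by multiplying each input pixel by the filter and writing the result into a larger output grid, stepping by the stride; contributions that overlap are summed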

### Pooling (max pool)

 - Pooling downsamples the feature map, which reduces the computation (and parameters) needed by later layers
 - Also provides regularisation
 - Average or max
 - We create subsets of the input based on filter size f and stride s

### Skip Connections (copy and crop)

 - These copy the image matrix from earlier layers and use it as part of the later layers
 - This preserves detail from the richer, higher-resolution matrix and prevents information loss

## Defining the Architecture

In [2]:
# conv(3x3) -> [batch norm] -> relu -> conv(3x3) -> [batch norm] -> relu
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU, with optional batch norm.

    Both convolutions use kernel_size=3 with padding=1, so the spatial size of
    the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by both convolutions.
        use_batchnorm: When True, ``bn1``/``bn2`` are applied between each
            convolution and its ReLU. Defaults to False, which reproduces the
            plain conv -> relu -> conv -> relu path.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=False):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # bn1/bn2 are always constructed so the state_dict keys do not depend
        # on the flag; they are simply bypassed when use_batchnorm is False.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply both conv stages to ``inputs`` and return the activated result."""
        x = self.conv1(inputs)
        if self.use_batchnorm:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        if self.use_batchnorm:
            x = self.bn2(x)
        x = self.relu(x)
        return x

In [3]:
# Encoder stage: conv block followed by a 2x2 max pool
class EncBlock(nn.Module):
    """Run a ConvBlock and also return a 2x2 max-pooled copy of its output.

    Args:
        in_c: Channel count passed to ConvBlock as its input channels.
        out_c: Channel count passed to ConvBlock as its output channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = ConvBlock(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        """Return ``(features, pooled)``: the conv output and its pooled version."""
        features = self.conv(inputs)
        pooled = self.pool(features)
        return features, pooled

In [4]:
# convtransp (2x upsample) -> match skip size -> concat skip -> conv block
class DecBlock(nn.Module):
    """Decoder stage: upsample, concatenate the skip tensor, then convolve.

    Args:
        in_c: Channels of the tensor coming from the previous (deeper) stage.
        out_c: Channels produced by the transposed convolution and by the
            final conv block. The conv block receives ``out_c + out_c``
            channels, so ``skip`` is expected to carry ``out_c`` channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = ConvBlock(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, align it to ``skip``, concatenate and convolve.

        Args:
            inputs: Tensor indexed as (batch, channel, height, width).
            skip: Tensor from the matching encoder stage, same layout.

        Returns:
            The output of the conv block applied to the channel-wise
            concatenation of the upsampled tensor and ``skip``.
        """
        x = self.up(inputs)
        # The stride-2 upsample always yields an even size, while the skip may
        # be one pixel larger when the encoder pooled an odd-sized map. Pad
        # (negative values crop) so both tensors agree before concatenating.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x

In [5]:
# in -> encblock x4 -> convblock (bottleneck) -> decblock x4 -> conv(1x1) -> out
class UNET(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Channel widths are 64 -> 128 -> 256 -> 512 in the encoder, 1024 in the
    bottleneck, and mirror back down to 64 in the decoder. Each decoder stage
    receives the un-pooled output of the encoder stage at the same depth.

    Args:
        in_channels: Channels of the input image. Defaults to 3.
        out_channels: Channels of the output map produced by the final 1x1
            convolution. Defaults to 1.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder
        self.e1 = EncBlock(in_channels, 64)
        self.e2 = EncBlock(64, 128)
        self.e3 = EncBlock(128, 256)
        self.e4 = EncBlock(256, 512)
        # Bottleneck
        self.b = ConvBlock(512, 1024)
        # Decoder
        self.d1 = DecBlock(1024, 512)
        self.d2 = DecBlock(512, 256)
        self.d3 = DecBlock(256, 128)
        self.d4 = DecBlock(128, 64)
        # Classifier: 1x1 conv maps the 64 decoder channels to the output channels
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the full encoder/decoder pass and return the raw output map.

        The input is pooled by 2 four times, so a height and width divisible
        by 16 keep every skip tensor the same size as its upsampled
        counterpart. No activation is applied to the returned tensor.
        """
        # Encoder: s* are the skip tensors, p* their pooled versions
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        # Bottleneck
        b = self.b(p4)
        # Decoder: deepest skip is consumed first
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        # Classifier
        outputs = self.outputs(d4)
        return outputs

In [6]:
# Smoke test: push a random batch through the network and print the output
# shape plus a per-layer summary. Gradients are not needed for this check.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
inputs = torch.randn((2, 3, 512, 512)).to(dev)
model = UNET().to(dev)

with torch.no_grad():
    y = model(inputs)
    print(y.shape)
    summary(model, input_size=(3, 512, 512))

torch.Size([2, 1, 512, 512])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
              ReLU-2         [-1, 64, 512, 512]               0
            Conv2d-3         [-1, 64, 512, 512]          36,928
              ReLU-4         [-1, 64, 512, 512]               0
         ConvBlock-5         [-1, 64, 512, 512]               0
         MaxPool2d-6         [-1, 64, 256, 256]               0
          EncBlock-7  [[-1, 64, 512, 512], [-1, 64, 256, 256]]               0
            Conv2d-8        [-1, 128, 256, 256]          73,856
              ReLU-9        [-1, 128, 256, 256]               0
           Conv2d-10        [-1, 128, 256, 256]         147,584
             ReLU-11        [-1, 128, 256, 256]               0
        ConvBlock-12        [-1, 128, 256, 256]               0
        MaxPool2d-13        [-1, 128, 128, 128]            