In [1]:
# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

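**Note**: several `print` statements were commented out after the cells were run, to keep the output concise. The printed shapes shown below some cells come from those earlier runs; uncomment the `print` lines to reproduce them.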

In [4]:
# Codeblock 2
class DoubleConv(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by BatchNorm and ReLU.

    Neither convolution uses padding, so each one trims one pixel from every
    border: the output is 4 pixels smaller than the input in both height and
    width (e.g. 572 -> 570 -> 568).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # bias=False: each conv feeds a BatchNorm whose learnable shift
        # already plays the role of a bias term.
        self.conv_0 = nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, bias=False)    #(1)
        self.bn_0 = nn.BatchNorm2d(out_channels)

        self.conv_1 = nn.Conv2d(out_channels, out_channels,
                                kernel_size=3, bias=False)    #(2)
        self.bn_1 = nn.BatchNorm2d(out_channels)

        # ReLU has no parameters, so one instance serves both stages.
        self.relu = nn.ReLU(inplace=True)

# Codeblock 3
    def forward(self, x):
        #print(f'original\t\t: {x.size()}')

        # conv -> batch norm -> ReLU, applied twice with separate conv/BN weights.
        stages = ((self.conv_0, self.bn_0), (self.conv_1, self.bn_1))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
            #print(f'after conv\t\t: {x.size()}')

        return x

In [3]:
# Codeblock 4
double_conv = DoubleConv(in_channels=1, out_channels=64)    #(1)
x = torch.randn((1, 1, 572, 572))    #(2)
double_conv_out = double_conv(x)

# Each unpadded 3x3 convolution trims one pixel per border: 572 -> 570 -> 568.
# The debug prints inside DoubleConv are commented out, so check and show the
# result here instead of relying on them.
assert double_conv_out.size() == torch.Size([1, 64, 568, 568])
double_conv_out.size()

original		: torch.Size([1, 1, 572, 572])
after first conv	: torch.Size([1, 64, 570, 570])
after second conv	: torch.Size([1, 64, 568, 568])


In [7]:
# Codeblock 5
class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x2 max pooling.

    ``forward`` returns two tensors: the feature map before pooling (which the
    UNet keeps as a skip connection for the decoder) and the pooled map that is
    passed on to the next, deeper stage.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = DoubleConv(in_channels, out_channels)    #(1)
        # Non-overlapping 2x2 windows halve height and width (e.g. 568 -> 284).
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)    #(2)

    def forward(self, x):
        #print(f'original\t\t: {x.size()}')

        features = self.double_conv(x)
        #print(f'after double conv\t: {features.size()}')

        pooled = self.maxpool(features)
        #print(f'after pooling\t\t: {pooled.size()}')

        return features, pooled    #(3)

In [6]:
# Codeblock 6
down_sample = DownSample(in_channels=1, out_channels=64)
x = torch.randn((1, 1, 572, 572))
convolved, maxpooled = down_sample(x)

# The pre-pooling map keeps the DoubleConv size (572 -> 568); pooling halves it.
assert convolved.size() == torch.Size([1, 64, 568, 568])
assert maxpooled.size() == torch.Size([1, 64, 284, 284])
convolved.size(), maxpooled.size()

original		: torch.Size([1, 1, 572, 572])
after double conv	: torch.Size([1, 64, 568, 568])
after pooling		: torch.Size([1, 64, 284, 284])


In [8]:
# Codeblock 7
def crop_image(original, expected):    #(1)
    """Center-crop ``original`` so its last two dimensions match ``expected``.

    In this U-Net the encoder feature map (skip connection) is larger than the
    upsampled decoder map because every 3x3 convolution is unpadded. It is
    therefore cropped around its center before the two are concatenated.

    Height and width are handled independently, so non-square maps work. The
    crop always ends at ``start + expected size``, so the result has exactly
    the expected size even when the size difference is odd. When the
    difference is odd, the extra row/column is dropped from the bottom/right.

    Args:
        original: tensor whose last two dims are (H, W), e.g. (N, C, H, W).
        expected: tensor whose last two dims give the target (H, W).

    Returns:
        A slice (view) of ``original`` with spatial size equal to ``expected``'s.

    Raises:
        ValueError: if ``expected`` is larger than ``original`` along either
            spatial dimension (a crop cannot enlarge a tensor).
    """
    original_h, original_w = original.size()[-2:]    #(2)
    expected_h, expected_w = expected.size()[-2:]    #(3)

    if expected_h > original_h or expected_w > original_w:
        raise ValueError(
            f'cannot crop spatial size ({original_h}, {original_w}) '
            f'to larger size ({expected_h}, {expected_w})'
        )

    # Offset of the crop window along each axis: half of the size difference.    #(4)
    top = (original_h - expected_h) // 2    #(5)
    left = (original_w - expected_w) // 2

    cropped = original[..., top:top + expected_h, left:left + expected_w]    #(6)

    return cropped

In [11]:
# Codeblock 8
class UpSample(nn.Module):
    """Decoder stage: upsample, merge the skip connection, then refine.

    ``forward`` proceeds in four steps:

    1. A 2x2 transposed convolution with stride 2 doubles the height and width
       of ``x`` and maps ``in_channels`` to ``out_channels``.
    2. The encoder feature map ``connection`` is center-cropped to the same
       spatial size.
    3. The two are concatenated along the channel axis.
    4. A DoubleConv reduces the result to ``out_channels``.

    Channel contract: the concatenation has ``out_channels + C_connection``
    channels and feeds a DoubleConv built for ``in_channels`` inputs. So
    ``connection`` must carry ``in_channels - out_channels`` channels, e.g.
    UpSample(1024, 512) with a 512-channel skip connection.

    Args:
        in_channels: channels of the incoming decoder feature map.
        out_channels: channels produced by this stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # kernel_size=2, stride=2 (no padding) gives an output exactly 2x the input size.
        self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=2, stride=2)    #(1)
        # Input is the concatenation [upsampled x, cropped connection]; see the channel contract above.
        self.double_conv = DoubleConv(in_channels=in_channels,
                                      out_channels=out_channels)

# Codeblock 9
    def forward(self, x, connection):    #(1)
        """Upsample ``x``, concatenate the cropped ``connection``, and refine.

        Args:
            x: decoder feature map with ``in_channels`` channels.
            connection: encoder feature map whose spatial size is at least that
                of the upsampled ``x``. It is center-cropped to match.

        Returns:
            Tensor with ``out_channels`` channels. Its spatial size is the
            upsampled size of ``x`` minus 4 (two unpadded 3x3 convolutions).
        """
        #print(f'x original\t\t\t: {x.size()}')
        #print(f'connection original\t\t: {connection.size()}')

        x = self.conv_transpose(x)    #(2)
        #print(f'x after conv transpose\t\t: {x.size()}')

        cropped_connection = crop_image(connection, x)    #(3)
        #print(f'connection after cropped\t: {cropped_connection.size()}')

        # Channel-wise concatenation of decoder and (cropped) encoder features.
        x = torch.cat([x, cropped_connection], dim=1)    #(4)
        #print(f'after concatenation\t\t: {x.size()}')

        x = self.double_conv(x)    #(5)
        #print(f'after double conv\t\t: {x.size()}')

        return x

In [10]:
# Codeblock 10
up_sample = UpSample(1024, 512)    #(1)

x = torch.randn((1, 1024, 28, 28))    #(2)
connection = torch.randn((1, 512, 64, 64))    #(3)

up_sampled = up_sample(x, connection)

# 28 -> 56 after the transposed conv; the 64x64 connection is center-cropped
# to 56x56; the unpadded DoubleConv then trims 56 -> 52.
assert up_sampled.size() == torch.Size([1, 512, 52, 52])
up_sampled.size()

x original			: torch.Size([1, 1024, 28, 28])
connection original		: torch.Size([1, 512, 64, 64])
x after conv transpose		: torch.Size([1, 512, 56, 56])
connection after cropped	: torch.Size([1, 512, 56, 56])
after concatenation		: torch.Size([1, 1024, 56, 56])
after double conv		: torch.Size([1, 512, 52, 52])


In [16]:
# Codeblock 11
class UNet(nn.Module):
    """U-Net encoder/decoder producing a per-pixel map with one channel per class.

    The network has four parts:

    - Four encoder stages (64 -> 512 channels).
    - A 1024-channel bottleneck.
    - Four decoder stages that mirror the encoder and consume its skip
      connections.
    - A final 1x1 convolution mapping 64 channels to ``num_classes``.

    Every 3x3 convolution is unpadded, so the output map is smaller than the
    input (a 572x572 input yields a 388x388 output).

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output map.
    """

    def __init__(self, in_channels=1, num_classes=2):    #(1)
        super().__init__()

        # Encoder: channels grow 64 -> 512; each stage's pooled output is half the size.    #(2)
        self.downsample_0 = DownSample(in_channels, 64)
        self.downsample_1 = DownSample(64, 128)
        self.downsample_2 = DownSample(128, 256)
        self.downsample_3 = DownSample(256, 512)

        # Bottleneck: deepest DoubleConv, with no pooling afterwards.    #(3)
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: mirrors the encoder, halving the channel count at each stage.    #(4)
        self.upsample_0 = UpSample(1024, 512)
        self.upsample_1 = UpSample(512, 256)
        self.upsample_2 = UpSample(256, 128)
        self.upsample_3 = UpSample(128, 64)

        # Output: 1x1 convolution from 64 feature channels to num_classes.    #(5)
        self.output = nn.Conv2d(64, num_classes, kernel_size=1)

# Codeblock 12
    def forward(self, x):
        #print(f'original\t\t: {x.size()}')

        # Contracting path: keep each stage's pre-pooling map as a skip
        # connection and carry the pooled map down to the next stage.
        encoders = (self.downsample_0, self.downsample_1,
                    self.downsample_2, self.downsample_3)
        skip_connections = []
        for encoder in encoders:
            convolved, x = encoder(x)
            skip_connections.append(convolved)
            #print(f'maxpooled\t\t: {x.size()}')

        x = self.bottleneck(x)
        #print(f'after bottleneck\t: {x.size()}')

        # Expanding path: the deepest skip connection pairs with the first decoder.
        decoders = (self.upsample_0, self.upsample_1,
                    self.upsample_2, self.upsample_3)
        for decoder, skip in zip(decoders, reversed(skip_connections)):
            x = decoder(x, skip)
            #print(f'upsampled\t\t: {x.size()}')

        class_map = self.output(x)
        #print(f'final output\t\t: {class_map.size()}')

        return class_map

In [13]:
# Codeblock 13
unet = UNet()
x = torch.randn((1, 1, 572, 572))

# Shape check only: skip building the autograd graph for this large input.
with torch.no_grad():
    unet_out = unet(x)

# A 572x572 input yields a 388x388 map with num_classes (default 2) channels.
assert unet_out.size() == torch.Size([1, 2, 388, 388])
unet_out.size()

original		: torch.Size([1, 1, 572, 572])
maxpooled_0		: torch.Size([1, 64, 284, 284])
maxpooled_1		: torch.Size([1, 128, 140, 140])
maxpooled_2		: torch.Size([1, 256, 68, 68])
maxpooled_3		: torch.Size([1, 512, 32, 32])
after bottleneck	: torch.Size([1, 1024, 28, 28])
upsampled_0		: torch.Size([1, 512, 52, 52])
upsampled_1		: torch.Size([1, 256, 100, 100])
upsampled_2		: torch.Size([1, 128, 196, 196])
upsampled_3		: torch.Size([1, 64, 388, 388])
final output		: torch.Size([1, 2, 388, 388])


In [17]:
# Codeblock 14
# Per-layer output shapes and parameter counts for one 1x1x572x572 input.
summary(unet, input_size=(1, 1, 572, 572))

original		: torch.Size([1, 1, 572, 572])
maxpooled_0		: torch.Size([1, 64, 284, 284])
maxpooled_1		: torch.Size([1, 128, 140, 140])
maxpooled_2		: torch.Size([1, 256, 68, 68])
maxpooled_3		: torch.Size([1, 512, 32, 32])
after bottleneck	: torch.Size([1, 1024, 28, 28])
upsampled_0		: torch.Size([1, 512, 52, 52])
upsampled_1		: torch.Size([1, 256, 100, 100])
upsampled_2		: torch.Size([1, 128, 196, 196])
upsampled_3		: torch.Size([1, 64, 388, 388])
final output		: torch.Size([1, 2, 388, 388])


Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 2, 388, 388]          --
├─DownSample: 1-1                        [1, 64, 568, 568]         --
│    └─DoubleConv: 2-1                   [1, 64, 568, 568]         --
│    │    └─Conv2d: 3-1                  [1, 64, 570, 570]         576
│    │    └─BatchNorm2d: 3-2             [1, 64, 570, 570]         128
│    │    └─ReLU: 3-3                    [1, 64, 570, 570]         --
│    │    └─Conv2d: 3-4                  [1, 64, 568, 568]         36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 568, 568]         128
│    │    └─ReLU: 3-6                    [1, 64, 568, 568]         --
│    └─MaxPool2d: 2-2                    [1, 64, 284, 284]         --
├─DownSample: 1-2                        [1, 128, 280, 280]        --
│    └─DoubleConv: 2-3                   [1, 128, 280, 280]        --
│    │    └─Conv2d: 3-7                  [1, 128, 282, 282]        73,728
│   