In [20]:
import torch
import torch.nn as nn
import math
import torchvision.transforms.functional as TF

In [3]:
def calculate_conv_out_shape(height, width, kernel, stride, padding):
  new_width = math.floor(((width - kernel + 2 * padding)/stride) + 1)
  new_height = math.floor(((height - kernel + 2 * padding)/stride) + 1)

  return new_width, new_height

In [7]:
calculate_conv_out_shape(572, 572, 3, 1, 1)

(572, 572)

In [4]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchNorm1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchNorm2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(self.batchNorm1(x))
        x = self.conv2(x)
        x = self.relu2(self.batchNorm2(x))

        return x

In [25]:
class Unet(nn.Module):
  def __init__(self, in_channels, out_channels, model_channels=[64, 128, 256, 512]):
    super(Unet, self).__init__()
    self.input_channel = in_channels
    self.output_channel = out_channels
    self.encoder = nn.ModuleList()
    self.decoder = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # halves the 2d tensors
    self.final_conv = nn.Conv2d(model_channels[0], out_channels, kernel_size=1) # restoring the channels at the final convolution

    # encoder part
    for channel in model_channels:
      self.encoder.append(DoubleConv(self.input_channel, channel))
      self.input_channel = channel
      print(self.input_channel)

    self.bottleneck = DoubleConv(model_channels[-1], model_channels[-1]*2) # Channels: 512 -> 1024

    # decoder part
    for channel in reversed(model_channels):
      self.decoder.append(nn.ConvTranspose2d(channel * 2, channel, kernel_size=2, stride=2)) # (B, 256, 64, 64) -> (B, 128, 128, 128)
      self.decoder.append(DoubleConv(channel * 2, channel))


  def forward(self, x):
    skip_connections = []

    for encoder_layer in self.encoder:
      x = encoder_layer(x)
      print('After one enc layer: ', x.shape)
      skip_connections.append(x)
      x = self.pool(x)
      print('After one maxpool layer: ', x.shape)

    x = self.bottleneck(x)
    print('After bottleneck: ', x.shape)
    print()
    skip_connections = skip_connections[::-1] # reversing the skip connection list

    print(len(self.decoder))
    print(len(skip_connections))

    for idx in range(0, len(self.decoder), 2):
      print('Before:', x.shape)
      x = self.decoder[idx](x) # transpose conv
      skip_connection = skip_connections[idx//2] # concatenation

      print('After x shape: ', x.shape)
      print('skip_connection shape: ', skip_connection.shape)

      if x.shape != skip_connection.shape:
        x = TF.resize(x, (skip_connection.shape[2], skip_connection.shape[3]))

      x = torch.cat((x, skip_connection),dim=1) # (B, 512, 64, 64) + (B, 512, 64, 64) = (B, 1024, 64, 64)
      x = self.decoder[idx+1](x) # double conv

    return self.final_conv(x) # final convolution to desired segmentation mask


In [None]:
def test():
    net = DoubleConv(3, 1024, 1, 1)
    print(net)
    y = net(torch.randn(4, 3, 256, 256)).to("cuda")
    print(y.size())
    print(y)

test()

In [26]:
def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    assert preds.shape == x.shape

if __name__ == "__main__":
    test()

64
128
256
512
After one enc layer:  torch.Size([3, 64, 161, 161])
After one maxpool layer:  torch.Size([3, 64, 80, 80])
After one enc layer:  torch.Size([3, 128, 80, 80])
After one maxpool layer:  torch.Size([3, 128, 40, 40])
After one enc layer:  torch.Size([3, 256, 40, 40])
After one maxpool layer:  torch.Size([3, 256, 20, 20])
After one enc layer:  torch.Size([3, 512, 20, 20])
After one maxpool layer:  torch.Size([3, 512, 10, 10])
After bottleneck:  torch.Size([3, 1024, 10, 10])

8
4
Before: torch.Size([3, 1024, 10, 10])
After x shape:  torch.Size([3, 512, 20, 20])
skip_connection shape:  torch.Size([3, 512, 20, 20])
Before: torch.Size([3, 512, 20, 20])
After x shape:  torch.Size([3, 256, 40, 40])
skip_connection shape:  torch.Size([3, 256, 40, 40])
Before: torch.Size([3, 256, 40, 40])
After x shape:  torch.Size([3, 128, 80, 80])
skip_connection shape:  torch.Size([3, 128, 80, 80])
Before: torch.Size([3, 128, 80, 80])
After x shape:  torch.Size([3, 64, 160, 160])
skip_connection sh



In [22]:
x = torch.randn((3, 64, 160, 160))
skip = torch.randn((3, 64, 161, 161))

x = TF.resize(x, (skip.shape[2], skip.shape[3]))
x.shape



torch.Size([3, 64, 161, 161])

In [24]:
x = torch.cat((x, skip), dim=1)
x.shape

torch.Size([3, 128, 161, 161])