In [1]:
import torch
import torch.nn as nn

In [8]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_down = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv_down(x)
        p = self.pool(down)

        return down, p
    
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv_up = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], 1)

        return self.conv_up(x)
    




In [11]:
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.downconv_1 = DownSample(in_channels, 64)
        self.downconv_2 = DownSample(64, 128)
        self.downconv_3 = DownSample(128, 256)
        self.downconv_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.upconv_1 = UpSample(1024, 512)
        self.upconv_2 = UpSample(512, 256)
        self.upconv_3 = UpSample(256, 128)
        self.upconv_4 = UpSample(128, 64)

        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)


    def forward(self, x):
        downer_1, p1 = self.downconv_1(x)
        downer_2, p2 = self.downconv_2(p1)
        downer_3, p3 = self.downconv_3(p2)
        downer_4, p4 = self.downconv_4(p3)

        last = self.bottle_neck(p4)

        up_1 = self.upconv_1(last, downer_4)
        up_2 = self.upconv_2(up_1, downer_3)
        up_3 = self.upconv_3(up_2, downer_2)
        up_4 = self.upconv_4(up_3, downer_1)

        bigoutput = self.output(up_4)

        return bigoutput
    

if __name__ == "__main__":
    double_conv = DoubleConv(256, 256)
    print(double_conv)

    input_image = torch.rand((1, 3, 512, 512))
    model = UNet(3, 10)
    print(model)





DoubleConv(
  (conv): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
)
UNet(
  (downconv_1): DownSample(
    (conv_down): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (downconv_2): DownSample(
    (conv_down): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (downconv_3): Down