In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary

# Select the compute device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare trailing expression so the notebook echoes which device was chosen.
device

'cuda'

In [8]:
class Double_Conv(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by batch norm and ReLU.

    Because both convolutions use ``kernel_size=3`` with ``padding=0``, each
    one trims 2 pixels from the height and width, so the block as a whole
    shrinks the spatial size by 4 in each dimension.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0)
        # Each convolution gets its own normalisation layer. Reusing a single
        # BatchNorm2d for both would tie their affine parameters together and
        # blend the running statistics of two different activation
        # distributions. `batchNorm` keeps its original name (first conv);
        # `batchNorm2` is new, so checkpoints saved before this change lack
        # its keys and need `load_state_dict(..., strict=False)`.
        self.batchNorm = nn.BatchNorm2d(out_channels)
        self.batchNorm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU twice and return the result."""
        x = F.relu(self.batchNorm(self.conv1(x)))
        x = F.relu(self.batchNorm2(self.conv2(x)))
        return x
    
class encoder(nn.Module):
    """Contracting-path stage: a Double_Conv block followed by 2x2 max pooling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution block.

    ``forward`` returns a pair ``(features, pooled)``: the un-pooled
    convolution output (used later as the skip connection) and its
    max-pooled version (fed to the next stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = Double_Conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        """Return ``(conv_output, pooled_conv_output)`` for input ``x``."""
        features = self.conv(x)
        return features, self.pool(features)
    
class decoder(nn.Module):
    """Expanding-path stage: upsample, merge with a skip connection, convolve.

    Args:
        in_channels: Channels of the feature map coming from the stage below.
        out_channels: Channels produced by the transposed convolution and by
            the final Double_Conv block. The Double_Conv is built for
            ``2 * out_channels`` inputs, so the skip tensor has to carry
            ``out_channels`` channels for the concatenation to fit.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # kernel_size=2 with stride=2 doubles the spatial resolution.
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, padding=0
        )
        self.conv = Double_Conv(2 * out_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, resize ``skip`` to match it, concatenate and convolve.

        ``x.shape[2:]`` is used as the target size, i.e. the trailing
        (height, width) dimensions of an (N, C, H, W) tensor.
        """
        upsampled = self.deconv(x)
        target_hw = upsampled.shape[2:]
        # NOTE(review): the skip map is resampled with nearest-neighbour
        # interpolation to the upsampled size rather than centre-cropped;
        # confirm this is the intended way to reconcile the size mismatch.
        resized_skip = F.interpolate(skip, size=target_hw, mode='nearest')
        # Stack along the channel dimension: upsampled features first, skip second.
        merged = torch.cat((upsampled, resized_skip), dim=1)
        return self.conv(merged)

class UNet(nn.Module):
    """U-Net built from four encoder stages, a bottleneck and four decoder stages.

    Channel widths go 64 -> 128 -> 256 -> 512 on the way down, 1024 in the
    bottleneck, and back 512 -> 256 -> 128 -> 64 on the way up. A final 1x1
    convolution maps the 64 feature channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input image. Defaults to 1,
            which reproduces the previously hard-coded single-channel input.
        out_channels: Number of channels of the output map. Defaults to 1,
            which reproduces the previously hard-coded single-channel output.

    Note:
        The convolution blocks in this file are unpadded, so the output is
        spatially smaller than the input.
    """

    def __init__(self, in_channels=1, out_channels=1) -> None:
        super().__init__()
        # Contracting path.
        self.e1 = encoder(in_channels, 64)
        self.e2 = encoder(64, 128)
        self.e3 = encoder(128, 256)
        self.e4 = encoder(256, 512)

        self.bottleneck = Double_Conv(512, 1024)

        # Expanding path; each stage halves the channel count.
        self.d1 = decoder(1024, 512)
        self.d2 = decoder(512, 256)
        self.d3 = decoder(256, 128)
        self.d4 = decoder(128, 64)

        # 1x1 convolution projecting the 64 features to the requested output channels.
        self.conv = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the full encoder/bottleneck/decoder pass and return the output map."""
        # Each encoder returns (skip features, pooled features); keep the skip
        # for the matching decoder and pass the pooled tensor downwards.
        s1, x = self.e1(x)
        s2, x = self.e2(x)
        s3, x = self.e3(x)
        s4, x = self.e4(x)
        x = self.bottleneck(x)
        # Decoders consume the skips in reverse order (deepest first).
        x = self.d1(x, s4)
        x = self.d2(x, s3)
        x = self.d3(x, s2)
        x = self.d4(x, s1)

        return self.conv(x)
    
# Build the network with its default arguments and move it to the selected device.
model = UNet().to(device)

In [9]:
# Print a per-layer table of output shapes and parameter counts for a 1-channel 572x572 input.
summary(model, (1, 572, 572))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
            Conv2d-3         [-1, 64, 568, 568]          36,928
       BatchNorm2d-4         [-1, 64, 568, 568]             128
       Double_Conv-5         [-1, 64, 568, 568]               0
         MaxPool2d-6         [-1, 64, 284, 284]               0
           encoder-7  [[-1, 64, 568, 568], [-1, 64, 284, 284]]               0
            Conv2d-8        [-1, 128, 282, 282]          73,856
       BatchNorm2d-9        [-1, 128, 282, 282]             256
           Conv2d-10        [-1, 128, 280, 280]         147,584
      BatchNorm2d-11        [-1, 128, 280, 280]             256
      Double_Conv-12        [-1, 128, 280, 280]               0
        MaxPool2d-13        [-1, 128, 140, 140]               0
          encoder-14  [[

In [10]:
# Clear the GPU memory cache: run Python garbage collection first, then ask
# PyTorch to release its cached CUDA memory.
# NOTE(review): `import gc` would conventionally live in the notebook's first import cell.
import gc
gc.collect()
# NOTE(review): this should only release cached memory that is no longer in use;
# tensors that are still referenced (e.g. `model`) presumably stay allocated.
torch.cuda.empty_cache()