In [15]:
import albumentations as A

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [34]:
class CNA(nn.Module):
    """Convolution -> Normalisation -> Activation building block.

    Chains a bias-free 3x3 convolution (padding 1), a ``BatchNorm2d`` over the
    convolution's output channels, and a GELU non-linearity.

    Args:
        in_nc: number of channels of the incoming feature map.
        out_nc: number of channels produced by the convolution.
        stride: stride of the convolution; defaults to 1.
    """

    def __init__(self, in_nc, out_nc, stride=1):
        super().__init__()

        # The convolution carries no bias term; it feeds straight into BatchNorm2d.
        self.conv = nn.Conv2d(
            in_channels=in_nc,
            out_channels=out_nc,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(num_features=out_nc)
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``act(norm(conv(x)))`` for the input tensor ``x``."""
        return self.act(self.norm(self.conv(x)))

In [49]:
class UnetBlock(nn.Module):
    """One level of a recursively nested U-Net.

    The block halves the resolution with a stride-2 ``CNA``, processes the
    result at the lower resolution (optionally delegating to a nested
    ``inner_block``), resizes it back to the input's spatial size and fuses
    it with the untouched input (skip connection) through a 3x3 convolution.

    Args:
        in_nc: number of channels of the block input.
        inner_nc: number of channels used at the reduced resolution.
        out_nc: number of channels of the block output.
        inner_block: optional module applied between ``conv2`` and ``conv3``.
            It receives ``inner_nc`` channels and must return ``inner_nc``
            channels, since its output is fed to ``conv3``.
    """

    def __init__(self, in_nc, inner_nc, out_nc, inner_block=None):
        super().__init__()

        self.conv1 = CNA(in_nc, inner_nc, stride=2)
        self.conv2 = CNA(inner_nc, inner_nc)
        self.inner_block = inner_block
        self.conv3 = CNA(inner_nc, inner_nc)
        # Fuses the skip connection (in_nc channels) with the upsampled
        # low-resolution branch (inner_nc channels).
        self.conv_cat = nn.Conv2d(inner_nc + in_nc, out_nc, 3, padding=1)

    def forward(self, x):
        """Return the fused feature map; spatial size equals that of ``x``.

        Args:
            x: 4-D tensor with ``in_nc`` channels on dimension 1.
        """
        # Remember the input size so odd heights/widths are restored exactly.
        _, _, h, w = x.shape

        inner = self.conv1(x)
        inner = self.conv2(inner)
        if self.inner_block is not None:
            inner = self.inner_block(inner)
        inner = self.conv3(inner)

        # F.upsample is deprecated and warns on every call; F.interpolate is
        # its replacement. align_corners=False is what the deprecated call
        # resolved to, so the numerical result is unchanged.
        inner = F.interpolate(inner, size=(h, w), mode='bilinear', align_corners=False)
        merged = torch.cat((x, inner), dim=1)

        return self.conv_cat(merged)
        

In [50]:
class Unet(nn.Module):
    """U-Net assembled from four recursively nested ``UnetBlock`` levels.

    Layout: two ``CNA`` stem layers, the nested ``UnetBlock`` stack (each
    level contains one stride-2 convolution, so the innermost level works at
    1/16 of the input resolution), one ``CNA`` and a final 3x3 convolution.

    Args:
        nc: base channel width. The nested levels use ``nc``, ``2*nc``,
            ``4*nc`` and ``8*nc`` channels.
        in_nc: number of channels of the input image. Defaults to 1, the
            value that used to be hard-coded.
        out_nc: number of channels of the network output. Defaults to 1,
            the value that used to be hard-coded.
    """

    def __init__(self, nc, in_nc=1, out_nc=1):
        super().__init__()

        self.cna1 = CNA(in_nc, nc)
        self.cna2 = CNA(nc, nc)

        # Build from the innermost (lowest resolution) level outwards; each
        # level wraps the one created before it.
        unet_block = UnetBlock(8*nc, 8*nc, 8*nc)
        unet_block = UnetBlock(4*nc, 8*nc, 4*nc, unet_block)
        unet_block = UnetBlock(2*nc, 4*nc, 2*nc, unet_block)
        self.unet_block = UnetBlock(nc, 2*nc, nc, unet_block)

        self.cna3 = CNA(nc, nc)

        # Plain convolution without normalisation or activation: the network
        # returns unbounded values.
        self.conv_last = nn.Conv2d(nc, out_nc, 3, padding=1)

    def forward(self, x):
        """Run the network on ``x`` (``in_nc`` channels on dimension 1)."""
        out = self.cna1(x)
        out = self.cna2(out)
        out = self.unet_block(out)
        out = self.cna3(out)
        out = self.conv_last(out)
        return out


In [51]:
# Instantiate the network with a base width of 32 channels.
unet_model = Unet(nc=32)

In [52]:
# Dummy input for a smoke test: a batch of one single-channel 64x64 image
# with values drawn uniformly from [0, 1).
tensor = torch.rand(1, 1, 64, 64)

In [54]:
# Smoke test: run a single forward pass on the dummy input and display the
# result. Autograd is not disabled here, so the displayed tensor carries a
# grad_fn.
res = unet_model(tensor)
res

tensor([[[[ 0.0187,  0.0090, -0.0205,  ..., -0.0490, -0.0365, -0.0842],
          [ 0.5774,  0.4081,  0.2086,  ..., -0.2683,  0.0204, -0.1817],
          [ 0.0774, -0.0725,  0.1987,  ...,  0.0563,  0.0691,  0.0029],
          ...,
          [ 0.1125, -0.4866, -0.1995,  ..., -0.3262, -0.3429, -0.5147],
          [ 0.1388,  0.1593,  0.1902,  ..., -0.2675, -0.3342,  0.0208],
          [ 0.3765,  0.4812,  0.3691,  ..., -0.0015,  0.1953, -0.0622]]]],
       grad_fn=<ConvolutionBackward0>)