In [2]:
import albumentations as A

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CNA(nn.Module):
    """Conv -> BatchNorm -> GELU building block.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        stride: Convolution stride; with the 3x3 kernel and padding 1,
            stride 1 keeps H x W and stride 2 roughly halves it.
    """

    def __init__(self, in_nc, out_nc, stride=1):
        super().__init__()

        # The conv bias is redundant here: BatchNorm's learnable shift absorbs it.
        self.conv = nn.Conv2d(in_nc, out_nc, 3, stride=stride, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_nc)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply convolution, batch normalization and GELU in sequence."""
        return self.act(self.norm(self.conv(x)))
        

In [None]:
class UnetBlock(nn.Module):
    """One U-Net level: downsample, process, optionally recurse, upsample, fuse.

    The input is downsampled by a stride-2 CNA, processed, optionally passed
    through ``inner_block`` (a deeper level), processed again, upsampled back
    to the input's spatial size and concatenated with the input (skip
    connection) before a final 3x3 convolution.

    Args:
        in_nc: Channels of the input tensor.
        out_nc: Channels of the output tensor.
        stride: Accepted for backward compatibility; currently unused — the
            downsampling stride of ``conv1`` is fixed at 2.
        inner_block: Optional module applied at the reduced resolution. It
            must map ``inner_nc`` channels to ``inner_nc`` channels.
            ``None`` makes this the innermost level (the step is skipped).
        inner_nc: Channels used inside this level; defaults to ``in_nc``.
    """

    def __init__(self, in_nc, out_nc, stride=1, inner_block=None, inner_nc=None):
        super().__init__()
        if inner_nc is None:
            inner_nc = in_nc

        self.conv1 = CNA(in_nc, inner_nc, stride=2)
        self.conv2 = CNA(inner_nc, inner_nc)
        self.inner_block = inner_block
        self.conv3 = CNA(inner_nc, inner_nc)
        # Skip connection: the input (in_nc channels) is concatenated with the
        # upsampled inner features (inner_nc channels).
        self.conv_cat = nn.Conv2d(in_nc + inner_nc, out_nc, 3, padding=1)

    def forward(self, x):
        """Run the level on ``x``; assumes a 4-D (N, C, H, W) input.

        Returns a tensor with ``out_nc`` channels and the same H x W as ``x``.
        """
        _, _, h, w = x.shape

        inner = self.conv1(x)
        inner = self.conv2(inner)
        if self.inner_block is not None:
            inner = self.inner_block(inner)
        inner = self.conv3(inner)

        # Restore the exact input resolution (handles odd H/W after stride 2).
        inner = F.interpolate(inner, size=(h, w), mode='bilinear', align_corners=False)
        inner = torch.cat((x, inner), dim=1)
        return self.conv_cat(inner)
        

In [7]:
class Unet(nn.Module):
    """Two-level U-Net style network for single-channel images.

    A full-resolution branch (conv1, conv2) extracts ``nc`` features. A deep
    branch (conv3, conv4) processes them at half resolution with ``2*nc``
    channels; its output is upsampled back and concatenated with the
    full-resolution features (``nc + 2*nc = 3*nc`` channels) before the
    output head (conv5, conv6).

    Args:
        nc: Base number of feature channels.

    Input:  tensor of shape (N, 1, H, W).
    Output: tensor of shape (N, 1, H, W).
    """

    def __init__(self, nc):
        super().__init__()

        self.act = nn.GELU()

        # Full-resolution branch; "same" padding (k // 2) keeps H x W unchanged.
        self.conv1 = nn.Conv2d(1, nc, 7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(nc, nc, 3, stride=1, padding=1)

        # Deep branch: stride 2 halves the resolution, channels widen to 2*nc.
        self.conv3 = nn.Conv2d(nc, 2 * nc, 7, stride=2, padding=3)
        self.conv4 = nn.Conv2d(2 * nc, 2 * nc, 7, stride=1, padding=3)

        # Head consumes nc (skip) + 2*nc (deep) = 3*nc concatenated channels.
        self.conv5 = nn.Conv2d(3 * nc, nc, 7, stride=1, padding=3)
        self.conv6 = nn.Conv2d(nc, 1, 7, stride=1, padding=3)

    def forward(self, x):
        """Map an (N, 1, H, W) batch to an (N, 1, H, W) batch."""
        fea = self.act(self.conv1(x))
        fea = self.act(self.conv2(fea))

        _, _, h, w = fea.shape

        fea_deep = self.act(self.conv3(fea))
        fea_deep = self.act(self.conv4(fea_deep))
        # Bring the deep features back to the skip branch's exact H x W.
        fea_deep = F.interpolate(fea_deep, size=(h, w), mode='bilinear', align_corners=False)

        fea = torch.cat((fea, fea_deep), dim=1)
        # NOTE(review): both inputs are already GELU-activated; this extra
        # activation is kept from the original design but may be unintended.
        fea = self.act(fea)
        fea = self.act(self.conv5(fea))
        return self.conv6(fea)


In [None]:
# Instantiate the network with a base width of 32 feature channels.
unet_model = Unet(nc=32)

In [None]:
# Dummy single-image batch in the (N, C, H, W) layout the network expects
# (N=1, C=1 channel, 64x64 pixels); a 2-D (1, 1) tensor cannot be passed
# through nn.Conv2d.
tensor = torch.rand((1, 1, 64, 64))