In [1]:
import torch
import torch.nn as nn

In [2]:
class norm_act(nn.Module):
    """Normalisation followed by an activation.

    Applies an affine ``nn.InstanceNorm3d`` over ``channels`` feature maps and
    then the ``act`` module.

    Args:
        channels: number of channels the instance norm expects.
        act: activation module applied after the norm. The default
            ``nn.ReLU()`` is evaluated once at class definition, so every
            ``norm_act`` built with the default shares that one (parameter-free)
            module.
    """

    def __init__(self, channels, act=nn.ReLU()):
        super().__init__()
        # Registration order (act, then norm) is kept so submodule/state_dict ordering is unchanged.
        self.act = act
        self.norm = nn.InstanceNorm3d(channels, affine=True)

    def forward(self, x):
        normalised = self.norm(x)
        return self.act(normalised)


In [3]:
class conv_block(nn.Module):
    """Pre-activation 3-D convolution: ``norm_act`` then ``nn.Conv3d``.

    The normalisation + activation (InstanceNorm and the ``norm_act`` default
    ReLU) runs on the *input* channels; the convolution then maps
    ``in_channels`` to ``out_channels`` using reflection padding instead of
    zero padding.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1):
        super().__init__()
        self.act = norm_act(in_channels)
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            padding_mode="reflect",
        )

    def forward(self, x):
        activated = self.act(x)
        return self.conv(activated)


In [4]:
class stem(nn.Module):
    """Input stem: a residual unit lifting ``in_channels`` to ``out_channels``.

    Main path: plain ``nn.Conv3d`` (reflect padding) followed by a
    pre-activation ``conv_block``. Shortcut: 1x1x1 ``nn.Conv3d`` followed by
    InstanceNorm with an identity activation. The two paths are summed.

    ``stride`` is applied exactly once on each path (the first main conv and
    the shortcut conv), so both paths produce the same spatial size.

    NOTE(review): the sum only lines up when the main path is "same"-padded,
    i.e. ``padding == (kernel_size - 1) // 2`` for odd kernels (the defaults
    kernel_size=3, padding=1 satisfy this).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1):
        super(stem, self).__init__()

        self.main_conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
                      padding_mode="reflect"),
            # stride=1 here: the first conv already downsampled. Re-applying
            # ``stride`` would shrink the main path by stride**2 while the
            # shortcut only shrinks by ``stride``, breaking the residual sum.
            conv_block(in_channels=out_channels, out_channels=out_channels,
                       kernel_size=kernel_size, stride=1, padding=padding)
        )
        self.shortcut_conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=(1, 1, 1), stride=stride, padding=0,
                      padding_mode="replicate"),
            norm_act(out_channels, nn.Identity())
        )

    def forward(self, x):
        main = self.main_conv(x)
        shortcut = self.shortcut_conv(x)
        return main + shortcut

In [5]:
class residual_block(nn.Module):
    """
    Pre-activation residual block with a 1x1x1 projection shortcut.

    Main path: ``conv_block`` (in -> out, ``stride``) then ``conv_block``
    (out -> out, stride 1). Shortcut: 1x1x1 ``nn.Conv3d`` with the same
    ``stride`` followed by InstanceNorm (identity activation). The sum is
    passed through ``nn.Dropout3d``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: odd kernel size of both main-path convolutions. Both use
            padding ``kernel_size // 2`` so the main path stays aligned with
            the shortcut.
        stride: downsampling factor applied by the first conv and the shortcut.
        drop: ``nn.Dropout3d`` probability applied to the block output.
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size=3,
                 stride=1,
                 drop=0):
        super(residual_block, self).__init__()

        self.main_conv = nn.Sequential(
            conv_block(in_channels=in_channels, out_channels=out_channels,
                       kernel_size=kernel_size, stride=stride, padding=kernel_size//2),
            # Same-size conv: padding must follow kernel_size (a fixed 1 is
            # only correct for kernel_size=3) or the residual sum mismatches.
            conv_block(in_channels=out_channels, out_channels=out_channels,
                       kernel_size=kernel_size, stride=1, padding=kernel_size//2)
        )
        self.shortcut_conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=(1, 1, 1), stride=stride, padding=0,
                      padding_mode="replicate"),
            norm_act(out_channels, nn.Identity())
        )
        self.dropout = nn.Dropout3d(p=drop)

    def forward(self, x):
        main = self.main_conv(x)
        shortcut = self.shortcut_conv(x)
        out = main + shortcut
        return self.dropout(out)


In [6]:
class upsample_block(nn.Module):
    """Learned upsampling with a single ``nn.ConvTranspose3d``.

    For an odd ``kernel_size`` the output spatial size is exactly ``stride``
    times the input. Padding is ``(kernel_size - 1) // 2`` and
    ``output_padding = stride - 1``. The decoder relies on this exact scaling
    to line up with the encoder skip connections.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: transposed-conv kernel size (int or 3-tuple; odd values
            give exact ``stride``x upsampling).
        stride: upsampling factor (int).
        padding: unused; kept only so existing call sites keep working. The
            effective padding is derived from ``kernel_size``.
        drop: unused; kept only for signature compatibility.
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size=3,
                 stride=2, padding=2,
                 drop=0):
        super(upsample_block, self).__init__()

        # out = (in - 1)*stride - 2*pad + kernel_size + (stride - 1)
        # equals in*stride when 2*pad == kernel_size - 1. A fixed pad of 1 is
        # only correct for kernel_size=3, so derive it from the kernel instead.
        if isinstance(kernel_size, int):
            pad = (kernel_size - 1) // 2
        else:
            pad = tuple((k - 1) // 2 for k in kernel_size)
        self.unconv = torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                                               padding=pad, output_padding=stride - 1)

    def forward(self, x):
        return self.unconv(x)

In [7]:
class attention_gate(nn.Module):
    """Additive attention gate that rescales ``x1`` by a learned one-channel mask.

    Both inputs are projected to ``intermediate_channels`` with 1x1x1
    convolutions, summed, passed through ``act``, reduced to a single channel
    by another 1x1x1 convolution and squashed by a sigmoid. The resulting mask
    (values in (0, 1)) multiplies ``x1`` and broadcasts over its channels.

    NOTE(review): ``x1`` and ``x2`` must share batch and spatial dimensions
    for the sum to be valid; nothing here resamples them.
    """
    def __init__(self, in_channels1, in_channels2, intermediate_channels, act=nn.LeakyReLU()):
        super().__init__()
        # Attribute order (conv1, conv2, conv, act, sigmoid) is preserved so
        # parameter initialisation order and state_dict keys do not change.
        self.conv1 = nn.Conv3d(in_channels1, intermediate_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv3d(in_channels2, intermediate_channels, kernel_size=1, stride=1, padding=0)
        self.conv = nn.Conv3d(intermediate_channels, 1, kernel_size=1, stride=1, padding=0)
        self.act = act
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        fused = self.act(self.conv1(x1) + self.conv2(x2))
        mask = self.sigmoid(self.conv(fused))
        return x1 * mask

In [8]:
class attention_concat(nn.Module):
    """Attention-gated skip connection followed by channel concatenation.

    The skip tensor is rescaled by an ``attention_gate`` driven by the main
    (decoder) tensor, then concatenated after ``main`` along dim 1. The output
    therefore has ``main_channels + skip_channels`` channels.

    Args:
        main_channels: channels of the main tensor; also used as the gate's
            intermediate width.
        skip_channels: channels of the skip tensor.
        act: activation used inside the attention gate.
    """
    def __init__(self, main_channels, skip_channels, act=nn.LeakyReLU()):
        super(attention_concat, self).__init__()
        # Pass ``act`` through to the gate; previously it was accepted but ignored.
        self.att_gate = attention_gate(skip_channels, main_channels, main_channels, act=act)

    def forward(self, main, skip):
        attention_across = self.att_gate(skip, main)
        return torch.cat([main, attention_across], dim=1)

In [29]:
class ResUNet(nn.Module):
    """3-D residual U-Net with attention-gated skip connections.

    Layout:
    1. ``stem``.
    2. Four stride-2 ``residual_block`` encoder stages.
    3. A two-``conv_block`` bridge.
    4. Four decoder stages. Each stage upsamples with a transposed conv, applies
       ``attention_concat`` with the matching encoder feature, then a
       ``residual_block``.
    5. A 1x1x1 conv followed by ``out_act``.

    Channel widths are ``channels_coef * [1, 2, 4, 8, 16]``.

    Args:
        in_channels: channels of the input volume.
        drop: ``nn.Dropout3d`` probability of the first encoder stage.
        dropout_change_per_layer: added to ``drop`` once per deeper encoder
            stage. Decoder stages use no dropout.
        channels_coef: base channel width.
        out_act: activation applied after the final 1x1x1 conv.
        use_input_noise: currently unused.
        out_channels: channels of the output map (default 1).

    NOTE(review): each spatial dim of the input must be divisible by 16 (four
    stride-2 stages). Otherwise the upsampled features do not match the skip
    features and ``torch.cat`` fails.
    """
    def __init__(self, in_channels=1,
                 drop=0.2,
                 dropout_change_per_layer=0.0,
                 channels_coef=16,
                 out_act=nn.Sigmoid(),
                 use_input_noise=False,
                 out_channels=1):
        super(ResUNet, self).__init__()
        n_levels = 4
        lc = [channels_coef * 2 ** i for i in range(n_levels + 1)]

        self.stem = stem(in_channels=in_channels, out_channels=lc[0])

        # Encoder stage i maps lc[i] -> lc[i+1] and halves the spatial size.
        self.encoder = nn.ModuleList([
            residual_block(lc[i], lc[i + 1], stride=2,
                           drop=drop + i * dropout_change_per_layer)
            for i in range(n_levels)
        ])
        self.bridge = nn.Sequential(
            conv_block(lc[-1], lc[-1]),
            conv_block(lc[-1], lc[-1])
        )
        # Decoder stages run deepest-first: (upsample, attention concat, refine).
        self.decoder = nn.ModuleList([
            nn.ModuleList([upsample_block(lc[i + 1], lc[i]),
                           attention_concat(lc[i], lc[i]),
                           residual_block(2 * lc[i], lc[i], stride=1)])
            for i in reversed(range(n_levels))
        ])

        self.output_block = nn.Sequential(
            nn.Conv3d(in_channels=lc[0], out_channels=out_channels, kernel_size=1, stride=1, padding=0),
            out_act
        )

    def forward(self, x):
        x = self.stem(x)
        skip_layers = [x]

        # encode
        for enc_block in self.encoder:
            x = enc_block(x)
            skip_layers.append(x)

        # bridge
        x = self.bridge(x)

        # decode: the deepest encoder output fed the bridge, so the skips used
        # are the remaining ones, deepest first.
        for (upsample, attend, refine), skip in zip(self.decoder, reversed(skip_layers[:-1])):
            x = upsample(x)
            x = attend(x, skip)
            x = refine(x)

        return self.output_block(x)

In [32]:
# Smoke test: a random single-channel 64^3 volume should come out with the same shape.
torch.manual_seed(0)  # reproducible weight init and input
x = torch.rand(1, 1, 64, 64, 64)
RN = ResUNet(in_channels=1)
with torch.no_grad():  # forward pass only; no autograd graph needed
    pred = RN(x)
assert pred.shape == x.shape, f"unexpected output shape {tuple(pred.shape)}"
print(pred.shape)

torch.Size([1, 1, 64, 64, 64])
