In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ELU-Net construction

<img src="img/elunet_arch.png"  width="300" height="300">

### ELU-Net Components

In [168]:
class DoubleConv(nn.Module):
    """Two stacked stages of (Conv2d 3x3 -> BatchNorm2d -> ReLU).

    Both convolutions use stride 1 with "same" padding, so the spatial size of
    the input is kept; the first stage maps `in_channels` to `out_channels`
    and the second stage keeps `out_channels`.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = []
        # Stage 1 consumes `in_channels`; stage 2 consumes the output of stage 1.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding="same")
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x` and return the result."""
        return self.double_conv(x)

In [169]:
class DownSample(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        conv = DoubleConv(in_channels, out_channels)
        # Pooling runs first, then the channel change happens in the DoubleConv.
        self.down_sample = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool `x`, then convolve it to `out_channels` channels."""
        return self.down_sample(x)

In [170]:
import math
class UpSample(nn.Module):
    """Upsample a feature map with transposed convolutions, then apply a 3x3 conv."""

    def __init__(self,in_channels,out_channels,c:int) -> None:
        """ UpSample input tensor by a factor of `c`
                - the value of base 2 log c defines the number of upsample
                layers that will be applied

        Args:
            in_channels: channel count of the input; kept by every
                transposed convolution.
            out_channels: channel count produced by the final 3x3 convolution.
            c: upsampling factor. `0` (or `1`) means "no upsampling": only the
                3x3 convolution is applied. A value that is not a power of two
                is rounded down to one, because the layer count is
                `int(log2(c))`.

        Raises:
            ValueError: if `c` is negative (raised by `math.log2`).
        """
        super().__init__()
        # log2(0) is undefined, so c == 0 is special-cased as "no upsample layers".
        # math.log2 is used instead of math.log(c, 2) to avoid floating-point
        # error before the int() truncation.
        n = 0 if c == 0 else int(math.log2(c))

        # Each ConvTranspose2d (kernel 2, stride 2) doubles height and width.
        self.upsample = nn.ModuleList(
            [nn.ConvTranspose2d(in_channels,in_channels,2,2) for _ in range(n)]
        )
        self.conv_3 = nn.Conv2d(in_channels,out_channels,3,padding="same",stride=1)

    def forward(self,x):
        """Apply the upsample layers in order, then the 3x3 convolution."""
        for layer in self.upsample:
            x = layer(x)
        return self.conv_3(x)

## Construct ELUnet

In [171]:
class ELUnet(nn.Module):
    """Encoder/decoder segmentation network with multi-level skip connections.

    Encoder: an input DoubleConv (64 channels) followed by four DownSample
    stages (128, 256, 512, 1024 channels).

    Decoder: each decoder block receives the upsampled output of the previous
    (deeper) stage concatenated with the encoder features of its own level and
    of every deeper level up to the 512-channel one, each brought to the
    block's resolution and channel count by an UpSample module.

    NOTE(review): `up_512_256`, `up_256_128` and `up_128_64` are each applied to
    two different tensors in `forward` (a decoder output and an encoder
    feature), so those weights are shared between both paths — confirm this is
    intended.
    NOTE(review): `out_conv` is a DoubleConv, so the network output passes
    through BatchNorm and ReLU — confirm this is intended for the loss in use.
    NOTE(review): input height/width presumably must be divisible by 16 for the
    concatenations to line up — TODO confirm.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()

        # ------ Input convolution ------
        self.in_conv = DoubleConv(in_channels, 64)

        # ------ Encoder ------
        self.down_1 = DownSample(64, 128)
        self.down_2 = DownSample(128, 256)
        self.down_3 = DownSample(256, 512)
        self.down_4 = DownSample(512, 1024)

        # ------ Upsampling ------
        # Naming: up_<in channels>_<out channels>; the third argument is the
        # spatial scale factor (0 = keep resolution).
        self.up_1024_512 = UpSample(1024, 512, 2)

        self.up_512_64 = UpSample(512, 64, 8)
        self.up_512_128 = UpSample(512, 128, 4)
        self.up_512_256 = UpSample(512, 256, 2)
        self.up_512_512 = UpSample(512, 512, 0)

        self.up_256_64 = UpSample(256, 64, 4)
        self.up_256_128 = UpSample(256, 128, 2)
        self.up_256_256 = UpSample(256, 256, 0)

        self.up_128_64 = UpSample(128, 64, 2)
        self.up_128_128 = UpSample(128, 128, 0)

        self.up_64_64 = UpSample(64, 64, 0)

        # ------ Decoder blocks ------
        # Input widths are the sums of the concatenated feature maps.
        self.dec_4 = DoubleConv(1024, 512)  # 2 x 512
        self.dec_3 = DoubleConv(768, 256)   # 3 x 256
        self.dec_2 = DoubleConv(512, 128)   # 4 x 128
        self.dec_1 = DoubleConv(320, 64)    # 5 x 64

        # ------ Output convolution ------
        self.out_conv = DoubleConv(64, out_channels)

    def forward(self, x):
        """Run the encoder, then the four decoder stages, then the output conv."""
        # ---- Encoder (variable suffix = channel count)
        enc_64 = self.in_conv(x)
        enc_128 = self.down_1(enc_64)
        enc_256 = self.down_2(enc_128)
        enc_512 = self.down_3(enc_256)
        enc_1024 = self.down_4(enc_512)

        # ---- Decoder stage 4 (512 channels)
        stage_4_inputs = [
            self.up_1024_512(enc_1024),
            self.up_512_512(enc_512),
        ]
        dec_512 = self.dec_4(torch.cat(stage_4_inputs, dim=1))

        # ---- Decoder stage 3 (256 channels)
        stage_3_inputs = [
            self.up_512_256(dec_512),
            self.up_512_256(enc_512),
            self.up_256_256(enc_256),
        ]
        dec_256 = self.dec_3(torch.cat(stage_3_inputs, dim=1))

        # ---- Decoder stage 2 (128 channels)
        stage_2_inputs = [
            self.up_256_128(dec_256),
            self.up_512_128(enc_512),
            self.up_256_128(enc_256),
            self.up_128_128(enc_128),
        ]
        dec_128 = self.dec_2(torch.cat(stage_2_inputs, dim=1))

        # ---- Decoder stage 1 (64 channels)
        stage_1_inputs = [
            self.up_128_64(dec_128),
            self.up_512_64(enc_512),
            self.up_256_64(enc_256),
            self.up_128_64(enc_128),
            self.up_64_64(enc_64),
        ]
        dec_64 = self.dec_1(torch.cat(stage_1_inputs, dim=1))

        return self.out_conv(dec_64)

## Test

In [172]:
# Smoke test: one input channel, three output channels, a single 256x256 sample.
elunet = ELUnet(in_channels=1, out_channels=3)
x = torch.randn(1, 1, 256, 256)
# Last expression of the cell: display the output shape.
elunet(x).shape


torch.Size([1, 3, 256, 256])

In [173]:
from torch.utils.tensorboard import SummaryWriter

# Log the model graph for `elunet` traced with the sample input `x`.
# Using the writer as a context manager closes it even if add_graph raises.
with SummaryWriter() as tb:
    tb.add_graph(elunet, x)

In [167]:
from torchsummary import summary
# Print a per-layer table (layer type, output shape, parameter count) for `elunet`.
# The tuple is the input size without the batch dimension, as the -1 batch
# entries in the printed output shapes indicate.
summary(elunet,(1,256,256))

torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64])
torch.Size([2, 512, 32, 32])
torch.Size([2, 1024, 16, 16])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,928
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
        DoubleConv-7         [-1, 64, 256, 256]               0
         MaxPool2d-8         [-1, 64, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]          73,856
      BatchNorm2d-10        [-1, 128, 128, 128]             256
             ReLU-11        [-1, 128, 128, 128]               0
           Conv2d-12        [-1, 128, 128, 128] 