In [1]:
####

In [2]:
import torch
from torch import nn

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:

class Conv(nn.Module):
    def __init__(self ,
                 in_channels , 
                 out_channels , 
                 kernel_size = (3 , 3) , 
                 stride = (1 , 1) , 
                 padding = 1 , 
                 use_norm = True , 
                 use_activation = True ):
        super(Conv , self).__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation


        self.conv1 = nn.Conv2d(in_channels , 
                               out_channels ,
                               kernel_size , 
                               stride , 
                               padding)
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)


    def forward(self , x):
        if self.use_norm:
            x = self.norm(x)
        if self.use_activation:
            x = self.activation(x)
        x = self.conv1(x)
        return x

In [5]:

class DenseNet_Conv(nn.Module):
    def __init__(self, 
                 in_channels ,
                 out_channels):
        super(DenseNet_Conv , self).__init__()

        self.conv1 = Conv(in_channels , 
                          in_channels , 
                          kernel_size=(1 , 1) , 
                          stride = (1 , 1) , 
                          padding = 0)
        self.conv2 = Conv(in_channels , 
                          out_channels)
        
    def forward(self , x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

In [6]:
class DenseNet_Block(nn.Module):
    def __init__(self , 
                 in_channels , 
                 repeats , 
                 k = 12):
        super(DenseNet_Block , self).__init__()

        self.layers = nn.ModuleList()
        out_channels_last = 0
        in_channels_conv = in_channels // 3
        self.remaining_in_channels = in_channels - in_channels_conv

        self.in_channels_ = in_channels
        self.in_channels_conv_ = in_channels_conv
        for r in range(repeats):
            #in_channels = in_channels + out_channels_last
            out_channels = in_channels_conv + k 
            #print(in_channels , out_channels)
            self.layers.append(DenseNet_Conv(in_channels_conv , 
                                             out_channels))
            
            out_channels_last += in_channels_conv
            in_channels_conv = out_channels + out_channels_last
        #print(out_channels)
        self.out_channels = out_channels + self.remaining_in_channels+ out_channels_last

        
    def _concat(self , x , prev):
        for prev_ in prev:
            x = torch.cat([x , prev_] , dim=1)
        return x

    def forward(self , x):
        x_cpy = x.clone()
        x = x[: , :self.in_channels_conv_ , : , :]
        x_ = x_cpy[: , self.in_channels_conv_ : self.in_channels_ , : , :]
        #print(x_.shape)
        prev = [x]
        for layer in self.layers:
            #print(x.shape)
            x = layer(x)
            x = self._concat(x , prev)
            prev.append(x)
        #print(x.shape , x_.shape)
        x = torch.cat([x , x_] , dim=1)
        return x 

In [9]:
densenet_block = DenseNet_Block(3 , 4).to(device)
x = torch.randn(2 , 3 , 112 , 112).to(device)
z = densenet_block(x)
z.shape

torch.Size([2, 2, 112, 112])
torch.Size([2, 286, 112, 112]) torch.Size([2, 2, 112, 112])


torch.Size([2, 288, 112, 112])

In [10]:
config = [
          #[out_channels , kernel_size , stride , padding]
          [512 , (7 , 7) , (2 , 2) , 3] , 
          ('P' , (3 , 3) , (2 , 2)) ,  # Tuple => (padding , (kernel_size) , (stride))
          
          6 , # int => DenseNet Block
          ('P' , (2 , 2) , (2 , 2)) , 

          12 , 
          ('P' , (2 , 2) , (2 , 2)) , 

          24 , 
          ('P' , (2 , 2) , (2 , 2)) , 

          16 , 
          ('P' , (2 , 2) , (2 , 2)) , 
]

In [11]:
class CSPDenseNet(nn.Module):
    def __init__(self , 
                 in_channels , 
                 config = config):
        super(CSPDenseNet , self).__init__()

        self.layers = nn.ModuleList()
        out_channels_list = [128 , 256 , 512 , 1024]
        i = 0
        for  layer in config:
            if isinstance(layer , list):
                out_channels , kernel_size , stride , padding = layer
                self.layers.append(Conv(in_channels , out_channels , kernel_size , stride , padding))
                in_channels = out_channels
            elif isinstance(layer , tuple):
                kernel_size , stride = layer[1] , layer[2]
                self.layers.append(nn.MaxPool2d(kernel_size , stride))
            elif isinstance(layer , int):
                repeats = layer
                a = DenseNet_Block(in_channels , repeats)
                self.layers.append(a)
                in_channels = a.out_channels
                out_channels = out_channels_list[i]
                self.layers.append(Conv(in_channels , out_channels , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0))
                i += 1
                in_channels = out_channels

                
    def forward(self , x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
cspdensenet = CSPDenseNet(3).to(device)
x = torch.randn(2 , 3 , 112 , 112).to(device)
z = cspdensenet(x)
print(z.shape)