In [7]:
import torch
import torch.nn as nn


In [33]:
class ConvBnAct(nn.Module):
    
    def __init__(self, 
                 n_in, # Input Channels
                 n_out, # Output Channels
                 kernel_size = 3, 
                 stride = 1, 
                 padding = 0, 
                 groups = 1, 
                 bn = True, # Batch Norm
                 act = True, # Activation
                 bias = False
                ):
        
        super(ConvBnAct, self).__init__()
        
        self.conv = nn.Conv2d(in_channels=n_in, 
                              out_channels=n_out, 
                              kernel_size=kernel_size, 
                              stride=stride, 
                              padding=padding, 
                              groups=groups, 
                              bias=bias
                             )
        self.batch_norm = nn.BatchNorm2d(num_features=n_out) if bn else nn.Identity()
        self.activation = nn.LeakyReLU() if act else nn.Identity()

    def forward(self,x):
        return self.activation(self.batch_norm(self.conv(x)))

In [34]:
class OSAModule(nn.Module):
    
    def __init__(self, 
                 n_in, # Input Channels
                 n_stage,# Channels in stage
                 n_out, # Output channel (after aggregation)
                 n_layers_per_block # number of layers per block
                ):
        super(OSAModule, self).__init__()
        self.layers = nn.ModuleList()
        n_input = n_in

        for _ in range(n_layers_per_block):
            self.layers.append(ConvBnAct(n_input, 
                                         n_stage, 
                                         kernel_size=3, 
                                         stride=1, 
                                         padding=1, 
                                         groups=1, 
                                         bn=True, 
                                         act=True
                                        ))
            n_input = n_stage
        
        #self.layers = nn.Sequential(*self.layers)
            
        # Feature Aggregation
        n_input = n_in + (n_layers_per_block * n_stage)
        
        self.feat_aggregation = ConvBnAct(n_input, 
                                          n_out, 
                                          kernel_size=1, 
                                          stride=1, 
                                          padding=0, 
                                          groups=1, 
                                          bn=True, 
                                          act=True
                                         )
        
    def forward(self, x):
        
        output = []
        output.append(x)
        
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        
        # Concat and Aggregate
        x = torch.cat(output, dim=1)
        #print(x.shape)
        x = self.feat_aggregation(x)
        print(self.feat_aggregation.conv.in_channels)
        
        return x
        
        

In [35]:
block = OSAModule(n_in = 128, #VoVNet-27-Slim-Configuration (Stage-2 OSA Module)
                  n_stage=64, 
                  n_out=128,
                  n_layers_per_block = 5
                 )
x = torch.randn(1, 128, 64, 64)
block(x).shape

448


torch.Size([1, 128, 64, 64])

In [36]:
nn.Conv2d(3,12, 1).in_channels

3

In [37]:
class RepConv(nn.Module):
    def __init__(self, 
                 n_in, # Input Channels
                 n_out):
        super(RepConv, self).__init__()
        self.conv3x3 = nn.Conv2d(n_in, n_out,3,padding=1)
        self.conv1X1 = nn.Conv2d(n_in, n_out,1)
        self.bn = nn.BatchNorm2d(n_in)
    def forward(self, x):
        x = self.conv3x3(x) + self.conv1X1(x) + self.bn(x)
        return x

    
    

In [38]:
model = RepConv(3,3)

In [39]:
y = torch.randn(1,3,224,224)
model(y).shape

torch.Size([1, 3, 224, 224])

In [41]:
class ELAN(nn.Module):
    
    '''
      Combine CSPNet + VoVNet (OSA Stages)
    '''
    def __init__(self, 
                 n_in, # Input channels
                 n_stage, # Channels in stage
                 n_out, # Output channels
                 n_layers_per_block = 2, # number of layers per block,
                 n_stages_to_combine = 2 # number of stages before aggregating
                ):
        super(ELAN, self).__init__()
        
        # CSP Layers
        self.base_layer = nn.Sequential(
            ConvBnAct(n_in, 
                      2 * n_in, 
                      kernel_size=3, 
                      stride=1, 
                      padding=1, 
                      groups=1, 
                      bn=True, 
                      act=True
                     ), 
            ConvBnAct(2 * n_in, 
                      n_in, 
                      kernel_size=1, 
                      stride=1, 
                      padding=0, 
                      groups=1, 
                      bn=True, 
                      act=True
                     ),
        )
        # VoVNet layers
        n_input = n_in//2
        self.stage_layers = []
        
        for _ in range(n_stages_to_combine):
            block_layers = []
            for _ in range(n_layers_per_block):
                block_layers.append(ConvBnAct(n_input, 
                                              n_stage, 
                                              kernel_size=3, 
                                              stride=1, 
                                              padding=1, 
                                              groups=1, 
                                              bn=True, 
                                              act=True
                                             )
                                   ) 
                n_input = n_stage
             # Collect feature after every stage
            block_layers = nn.Sequential(*block_layers)
            self.stage_layers.append(block_layers)
        
        self.stage_layers = nn.Sequential(*self.stage_layers)
        
        # Feature Aggregation
        n_input = n_in + n_in//2 + (n_stages_to_combine * n_stage)
        
        self.feat_aggregation = ConvBnAct(n_input, 
                                          n_out, 
                                          kernel_size=1, 
                                          stride=1, 
                                          padding=0, 
                                          groups=1, 
                                          bn=True, 
                                          act=True
                                         )

    def forward(self, x):
        output = []
        output.append(x)
        
        # Base layer
        x = self.base_layer(x)
        
        # Partition in 2 parts
        x1, x2 = x.chunk(2, dim=1)
        output.append(x1)
        
        for layer in self.stage_layers:
            x2 = layer(x2)
            output.append(x2)
        
        # Concat and Aggregate
        [print(i.shape) for i in output]
        x = torch.cat(output, dim=1)
        x = self.feat_aggregation(x)
        
        return x
        

In [43]:
# Test
block = ELAN(n_in = 128,
             n_stage=64, 
             n_out=128,
             n_layers_per_block = 2,
             n_stages_to_combine = 2
            )

x = torch.randn(1, 128, 64, 64)
block(x)

torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 64, 64])
torch.Size([1, 64, 64, 64])
torch.Size([1, 64, 64, 64])


tensor([[[[-7.6246e-04,  3.7295e-01,  5.0131e-01,  ..., -1.7312e-02,
            7.3806e-01, -1.5050e-02],
          [ 9.3829e-01, -1.3767e-02,  6.8926e-01,  ..., -5.1319e-03,
            1.6566e+00,  9.0356e-01],
          [ 1.1516e+00, -6.7091e-03,  5.6886e-01,  ..., -2.0229e-02,
            2.2276e-01, -1.2774e-02],
          ...,
          [ 5.6693e-02, -1.4866e-02, -1.1851e-02,  ...,  8.7788e-01,
           -2.6503e-04, -8.9352e-03],
          [-1.2073e-02,  1.5202e-01, -7.4300e-03,  ...,  1.2531e+00,
           -1.8616e-02, -3.8209e-03],
          [-1.7929e-02,  1.1516e+00,  3.7885e-01,  ...,  1.9867e-01,
           -9.1930e-03, -1.1661e-02]],

         [[ 1.7268e+00,  4.7714e-01,  4.1079e-01,  ...,  2.9288e-01,
           -6.6743e-03, -1.0654e-04],
          [ 1.3292e+00,  1.7059e+00, -7.0571e-03,  ...,  5.1221e-01,
            2.1367e-01,  6.6429e-02],
          [ 4.6191e-01, -1.2336e-03, -1.3481e-02,  ...,  2.7611e-01,
            8.9970e-01,  5.1949e-01],
          ...,
     

In [None]:
class E_ELAN(nn.Module):
    
    '''
      Extension of ELAN
    '''
    def __init__(self, 
                 n_in, # Input channels
                 n_stage, # Channels in stage
                 n_out, # Output channels
                 n_layers_per_block = 2, # number of layers per block,
                 n_stages_to_combine = 2 # number of stages before aggregating
                ):
        super(E_ELAN, self).__init__()
        
        # CSP Layers
        self.base_layer = nn.Sequential(
            ConvBnAct(n_in, 
                      2 * n_in, 
                      kernel_size=3, 
                      stride=1, 
                      padding=1, 
                      groups=n_in, 
                      bn=True, 
                      act=True
                     ), 
            ConvBnAct(2 * n_in, 
                      n_in, 
                      kernel_size=1, 
                      stride=1, 
                      padding=0, 
                      groups=1, 
                      bn=True, 
                      act=True
                     ),
        )
        # VoVNet layers
        n_input = n_in//2
        self.stage_layers = []
        
        for _ in range(n_stages_to_combine):
            block_layers = []
            for _ in range(n_layers_per_block):
                block_layers.append(ConvBnAct(n_input, 
                                              n_stage, 
                                              kernel_size=3, 
                                              stride=1, 
                                              padding=1, 
                                              groups=1, 
                                              bn=True, 
                                              act=True
                                             )
                                   ) 
                n_input = n_stage
             # Collect feature after every stage
            block_layers = nn.Sequential(*block_layers)
            self.stage_layers.append(block_layers)
        
        self.stage_layers = nn.Sequential(*self.stage_layers)
        
        # Feature Aggregation
        n_input = n_in + n_in//2 + (n_stages_to_combine * n_stage)
        
        self.feat_aggregation = ConvBnAct(n_input, 
                                          n_out, 
                                          kernel_size=1, 
                                          stride=1, 
                                          padding=0, 
                                          groups=1, 
                                          bn=True, 
                                          act=True
                                         )

    def forward(self, x):
        output = []
        output.append(x)
        
        # Base layer
        x = self.base_layer(x)
        
        # Partition in 2 parts
        x1, x2 = x.chunk(2, dim=1)
        output.append(x1)
        
        for layer in self.stage_layers:
            x2 = layer(x2)
            output.append(x2)
        
        # Concat and Aggregate
        [print(i.shape) for i in output]
        x = torch.cat(output, dim=1)
        x = self.feat_aggregation(x)
        
        return x