In [114]:
import torch
import torch.nn as nn
import torchsummary
import torch.nn.functional as F

In [115]:
class ResidualBlock(nn.Module):
    """Two (3x3 conv -> ReLU -> BatchNorm) stages plus an additive skip connection.

    Spatial size is preserved (kernel 3, stride 1, padding 1). When ``in_ch``
    differs from ``out_ch`` the skip path is projected with a 1x1 convolution so
    it can be summed with the main path; otherwise the skip path is the identity.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        drop_out: If True, apply dropout between the two conv stages.
        p: Dropout probability, used only when ``drop_out`` is True.
    """

    def __init__(self, in_ch, out_ch, drop_out=True, p=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.BatchNorm1 = nn.BatchNorm2d(out_ch)
        self.BatchNorm2 = nn.BatchNorm2d(out_ch)

        if drop_out:
            # In-place is safe here: it only overwrites the BatchNorm1 output,
            # which is a fresh intermediate tensor.
            self.drop_out = nn.Dropout(p, inplace=True)
        else:
            self.drop_out = nn.Identity()

        # The skip path maps the block INPUT (in_ch channels) to out_ch channels.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``; ``x`` itself is not modified."""
        out = F.relu(self.conv1(x))
        out = self.BatchNorm1(out)
        out = self.drop_out(out)
        out = F.relu(self.conv2(out))
        out = self.BatchNorm2(out)

        # The shortcut must be applied to the input, not to the main-path output.
        # The add is out-of-place so the caller's tensor is never mutated.
        return out + self.shortcut(x)

class Block(nn.Module):
    """Bottleneck stage: 1x1 conv down to ``res_ch``, a growable residual stack, 1x1 conv up.

    ``res_ch`` is ``max(in_ch, out_ch) // 2``. The stack starts empty; an empty
    ``nn.Sequential`` returns its input unchanged, so a fresh ``Block`` is just
    ``tail(head(x))``. Call ``_append`` to add ``ResidualBlock`` + ``ReLU`` units.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.res_ch = max(in_ch, out_ch) // 2

        self.head = nn.Conv2d(self.in_ch, self.res_ch, 1, bias=False)
        self.tail = nn.Conv2d(self.res_ch, out_ch, 1, bias=False)

        self.block = nn.Sequential()

    def _append(self, n=1):
        """Append ``n`` (ResidualBlock, ReLU) units to the residual stack.

        The width stays at ``res_ch`` because ``head`` and ``tail`` are fixed at
        construction; changing it would make the next unit (or ``tail``) receive
        the wrong number of channels. New units are moved to the device/dtype of
        ``head`` so calling this after ``.to(device)`` still yields a usable model.

        Args:
            n: Number of units to append (default 1).
        """
        ref = self.head.weight
        for _ in range(n):
            unit = ResidualBlock(self.res_ch, self.res_ch)
            self.block.append(unit.to(device=ref.device, dtype=ref.dtype))
            self.block.append(nn.ReLU(inplace=True))

    def forward(self, x):
        """Project to ``res_ch``, run the residual stack, project to ``out_ch``."""
        x = self.head(x)
        x = self.block(x)
        x = self.tail(x)

        return x
    

In [116]:
class CNNModel(nn.Module):
    """Five-stage convolutional feature extractor.

    Every stage is a ``Block`` followed by 2x2 max-pooling. The spatial size
    halves per stage while the channel count goes 64 -> 128 -> 256 -> 512 -> 1024.
    Shape trace for a 3 x 256 x 256 input:

        256 x 256 x 3
        128 x 128 x 64
         64 x  64 x 128
         32 x  32 x 256
         16 x  16 x 512
          8 x   8 x 1024   (intended input to fully connected layers)

    Args:
        channels: Number of input image channels.
        classes: Number of target classes. It is stored but not used yet; there
            is no classifier head, so ``forward`` returns the final feature map.
            TODO: add the fully connected layers.
    """

    def __init__(self, channels, classes):
        super().__init__()
        self.channels, self.classes = channels, classes

        # Stages are registered in order block_1..block_5. Callers reach them by
        # name, e.g. ``model.block_3._append()``.
        self.block_1 = Block(channels, 64)
        self.block_2 = Block(64, 128)
        self.block_3 = Block(128, 256)
        self.block_4 = Block(256, 512)
        self.block_5 = Block(512, 1024)

    def forward(self, x):
        """Run each stage and halve the spatial size after it."""
        stages = (self.block_1, self.block_2, self.block_3, self.block_4, self.block_5)
        for stage in stages:
            x = F.max_pool2d(stage(x), 2)
        return x
        
    

In [119]:
# Build the model and report the parameter count before and after growing
# every stage by one residual unit.
# NOTE(review): `random` shadows the stdlib module name; kept because the
# summary cell below refers to it.
random = CNNModel(3, 10)


def _n_params(module):
    """Total number of elements across all parameters of ``module``."""
    return sum(p.numel() for p in module.parameters())


total_params = _n_params(random)
print(f"{total_params} - params before _append")

# Deepest stage first, matching the original call order.
for stage in (random.block_5, random.block_4, random.block_3, random.block_2, random.block_1):
    stage._append()

total_params = _n_params(random)
print(f"{total_params} - params after _append")


1046624 - params before _append
7337888 - params after _append


In [120]:
# summary() prints the table itself. It also returns a stats object whose repr
# is that same table, so the trailing ';' stops IPython from showing it twice.
torchsummary.summary(random, input_data=(3, 256, 256));

Layer (type:depth-idx)                   Output Shape              Param #
├─Block: 1-1                             [-1, 64, 256, 256]        --
|    └─Conv2d: 2-1                       [-1, 32, 256, 256]        96
|    └─Sequential: 2-2                   [-1, 32, 256, 256]        --
|    |    └─ResidualBlock: 3-1           [-1, 32, 256, 256]        18,624
|    |    └─ReLU: 3-2                    [-1, 32, 256, 256]        --
|    └─Conv2d: 2-3                       [-1, 64, 256, 256]        2,048
├─Block: 1-2                             [-1, 128, 128, 128]       --
|    └─Conv2d: 2-4                       [-1, 64, 128, 128]        4,096
|    └─Sequential: 2-5                   [-1, 64, 128, 128]        --
|    |    └─ResidualBlock: 3-3           [-1, 64, 128, 128]        74,112
|    |    └─ReLU: 3-4                    [-1, 64, 128, 128]        --
|    └─Conv2d: 2-6                       [-1, 128, 128, 128]       8,192
├─Block: 1-3                             [-1, 256, 64, 64]         -

Layer (type:depth-idx)                   Output Shape              Param #
├─Block: 1-1                             [-1, 64, 256, 256]        --
|    └─Conv2d: 2-1                       [-1, 32, 256, 256]        96
|    └─Sequential: 2-2                   [-1, 32, 256, 256]        --
|    |    └─ResidualBlock: 3-1           [-1, 32, 256, 256]        18,624
|    |    └─ReLU: 3-2                    [-1, 32, 256, 256]        --
|    └─Conv2d: 2-3                       [-1, 64, 256, 256]        2,048
├─Block: 1-2                             [-1, 128, 128, 128]       --
|    └─Conv2d: 2-4                       [-1, 64, 128, 128]        4,096
|    └─Sequential: 2-5                   [-1, 64, 128, 128]        --
|    |    └─ResidualBlock: 3-3           [-1, 64, 128, 128]        74,112
|    |    └─ReLU: 3-4                    [-1, 64, 128, 128]        --
|    └─Conv2d: 2-6                       [-1, 128, 128, 128]       8,192
├─Block: 1-3                             [-1, 256, 64, 64]         -