In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseLayer(nn.Module):
    expansion = 4
    
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        zip_channels = self.expansion * growth_rate

        self.features = nn.Sequential(
                            nn.Conv2d(in_channels, zip_channels, kernel_size=1, bias=False),
                            nn.BatchNorm2d(zip_channels),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(zip_channels, growth_rate, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(growth_rate),
                            nn.ReLU(inplace=True),
                            )    
        
    def forward(self, x):
        out = self.features(x)
        return torch.cat([x, out], dim= 1)


x = torch.randn(1,3,224,224)
model = DenseLayer(3,12)
model(x).shape

torch.Size([1, 15, 224, 224])

In [23]:
class DenseBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList([
            DenseLayer(in_channels + i * growth_rate, growth_rate) 
            for i in range(num_layers)
        ])
        
    def forward(self, x):
        for layer in self.layers:
            #print(layer)
            x = layer(x)
        return x
model = DenseBlock(4,3,12)
model(x).shape

torch.Size([1, 51, 224, 224])

In [248]:
import torch.nn as nn

class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, pool = False):
        super(TransitionLayer, self).__init__()
        self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                )
        self.pool = pool
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2)

    
    def forward(self, x):
        if self.pool:
            return self.avg(self.layer(x))
        else:
            return self.layer(x)


In [249]:
class CSPBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate, split_ratio=0.5, pool=True):
        super(CSPBlock, self).__init__()
        self.split_channels = int(in_channels * split_ratio)
        self.remaining_channels = in_channels - self.split_channels

        # Split and dense processing
        # DenseBlock with the split feature map as input
        self.dense_block = DenseBlock(
            num_layers=num_layers,
            in_channels=self.split_channels,
            growth_rate=growth_rate
        )
        
        self.transition_first = TransitionLayer(self.split_channels  + num_layers * growth_rate, 
                                     self.split_channels)
        
        self.transition_last = TransitionLayer(in_channels, 
                                     in_channels)
    def forward(self, x):
        # Split the input into two parts
        part1, part2 = torch.split(x, [self.split_channels, self.remaining_channels], dim=1)
  
        # Process one part through the dense block
        part1 = self.dense_block(part1)

        part1 = self.transition_first(part1)
        #print(self.dense_block.layers[-1].conv.in_channels)
        # Concatenate the processed part with the bypassed part
        out = torch.cat([part1, part2], dim=1)
        
        # Apply the transition layer
        out = self.transition_last(out)
        return out

In [250]:
model = CSPBlock(4,64,12)
x = torch.randn(1,64, 56, 56)
x.shape

torch.Size([1, 64, 56, 56])

In [251]:
model(x).shape

torch.Size([1, 64, 56, 56])

In [261]:
class CSPDenseNet(nn.Module):
    def __init__(self, growth_rate = 32, block_config=[6, 12, 24, 16], num_init_features=64, split_ratio=0.5, num_classes=1000 ):
        super(CSPDenseNet, self).__init__()

        # Initial Convolution
        self.base_layer = nn.Sequential(
                nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(num_init_features),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                 )

        # Dense Blocks and Trasition Layers
        num_features = num_init_features
        #self.csp_block = CSPBlock(block_config[0], num_features, growth_rate, split_ratio=split_ratio) 

    
        self.layers = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            csp_block = CSPBlock(num_layers, num_features, growth_rate, split_ratio)
            self.layers.append(csp_block)
            

            if i != len(block_config) - 1:
                self.layers.append(TransitionLayer(num_features, num_features//2,pool=True))
                num_features = num_features // 2
            
        # Final Batch Norm Layer

        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features)
        
        # Fully Connected Layer for Classification
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):

        x = self.base_layer(x)
        # x = self.csp_block(x)
        # print(x.shape)

        # Pass through the dense blocks and transition layers
        for layer in self.layers:
            #print(layer)
            x = layer(x)
        print(x.shape)

        # Final batch normalization and global average pooling
        x = self.relu(self.bn2(x))
        x = nn.AdaptiveAvgPool2d(1)(x)  # Global Average Pooling

        # Flatten and pass through the fully connected layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [262]:
block = {'CSPDenseNet121':[6,12,24,16],
         'CSPDenseNet169':[6,12,32,32],
         'CSPDenseNet201':[6,12,48,32],
         'CSPDenseNet264':[6,12,64,48]
        }

In [269]:
CSPDenseNet121 = CSPDenseNet(growth_rate=32, block_config=block['CSPDenseNet121'], num_init_features=64,split_ratio=0.5, num_classes=1000)

In [276]:
x = torch.randn(1, 3, 224,224)
x.shape

torch.Size([1, 3, 612, 612])

In [277]:
CSPDenseNet121(x).shape

torch.Size([1, 8, 19, 19])


torch.Size([1, 1000])

In [278]:
CSPDenseNet169 = CSPDenseNet(growth_rate=32, block_config=block['CSPDenseNet169'], num_init_features=64, num_classes=1000)
CSPDenseNet169(x).shape

torch.Size([1, 8, 19, 19])


torch.Size([1, 1000])

In [279]:
CSPDenseNet201 = CSPDenseNet(growth_rate=32, block_config=block['CSPDenseNet201'], num_init_features=64, num_classes=1000)
CSPDenseNet201(x).shape

torch.Size([1, 8, 19, 19])


torch.Size([1, 1000])

In [280]:
CSPDenseNet264 = CSPDenseNet(growth_rate=32, block_config=block['CSPDenseNet264'], num_init_features=64, num_classes=1000)
CSPDenseNet264(x).shape

torch.Size([1, 8, 19, 19])


torch.Size([1, 1000])

In [281]:
x = torch.randn(1,64,56,56)
nn.AvgPool2d(kernel_size=2, stride=2)(x).shape

torch.Size([1, 64, 28, 28])