In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        self.bn = nn.BatchNorm2d(growth_rate)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)
        
    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return torch.cat([x, out], dim= 1)

In [36]:
class DenseBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList([
            DenseLayer(in_channels + i * growth_rate, growth_rate) 
            for i in range(num_layers)
        ])
        
    def forward(self, x):
        for layer in self.layers:
            #print(layer)
            x = layer(x)
        return x

In [37]:
DenseBlock(4, 3,12).layers[-1].conv.out_channels

12

In [38]:
class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, pool=True):
        super(TransitionLayer, self).__init__()
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
    
    def forward(self, x):
        x = self.conv(x)
        if pool:
            x = self.pool(self.relu(self.bn()))
            return x
        else:
            return x
            

In [39]:
class CSPBlock(nn.Module):
    def __init__(self, dense_block, in_channels, growth_rate, split_ratio=0.5):
        super(CSPBlock, self).__init__()
        self.split_channels = int(in_channels * split_ratio)
        self.remaining_channels = in_channels - self.split_channels

        # Split and dense processing
        self.dense_block = dense_block

        self.transition = TransitionLayer(in_channels + dense_block.layers[-1].conv.out_channels, 
                                    in_channels, pool=False)
    def forward(self, x):
        # Split the input into two parts
        part1, part2 = torch.split(x, [self.split_channels, self.remaining_channels], dim=1)
        # Process one part through the dense block
        part1 = self.dense_block(part1)
        # Concatenate the processed part with the bypassed part
        out = torch.cat([part1, part2], dim=1)
        out = self.transition(out)
        return out

In [40]:
class CSPDenseNet(nn.Module):
    def __init__(self, growth_rate = 32, block_config=[6, 12, 24, 16], num_init_features=64, split_ratio=0.5, num_classes=1000 ):
        super(CSPDenseNet, self).__init__()

        # Initial Convolution
        self.conv1 = nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_init_features)
        self.relu = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Dense Blocks and Trasition Layers
        num_features = num_init_features
        self.layers = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            dense_block = DenseBlock(num_layers, num_features, growth_rate)
            csp_block = CSPBlock(dense_block, num_features, growth_rate, split_ratio)
            self.layers.append(csp_block)
            num_features += num_layers * growth_rate

            if i != len(block_config) - 1:
                self.layers.append(TransitionLayer(num_features, num_features//2))
                num_features = num_features // 2
            
        # Final Batch Norm Layer
        self.bn2 = nn.BatchNorm2d(num_features)
        
        # Fully Connected Layer for Classification
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):

        x = self.pool1(self.relu(self.bn1(self.conv1(x))))

        # Pass through the dense blocks and transition layers
        for layer in self.layers:
            x = layer(x)

        # Final batch normalization and global average pooling
        x = self.relu(self.bn2(x))
        x = nn.AdaptiveAvgPool2d(1)(x)  # Global Average Pooling

        # Flatten and pass through the fully connected layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [41]:
block = {'CSPDenseNet121':[6,12,24,16],
         'CSPDenseNet169':[6,12,32,32],
         'CSPDenseNet201':[6,12,48,32],
         'CSPDenseNet264':[6,12,64,48]
        }

In [42]:
CSPDenseNet121 = CSPDenseNet(growth_rate=32, block_config=block['CSPDenseNet121'], num_init_features=64,split_ratio=0.5, num_classes=1000)

In [43]:
x = torch.randn(1, 3, 224, 224)
x.shape

torch.Size([1, 3, 224, 224])

In [44]:
CSPDenseNet121(x).shape

RuntimeError: Given groups=1, weight of size [32, 64, 3, 3], expected input[1, 32, 56, 56] to have 64 channels, but got 32 channels instead

In [31]:
DenseNet169 = DenseNet(growth_rate=32, block_config=block['DenseNet169'], num_init_features=64, num_classes=1000)
DenseNet169(x).shape

torch.Size([1, 1000])

In [32]:
DenseNet201 = DenseNet(growth_rate=32, block_config=block['DenseNet201'], num_init_features=64, num_classes=1000)
DenseNet201(x).shape

torch.Size([1, 1000])

In [33]:
DenseNet264 = DenseNet(growth_rate=32, block_config=block['DenseNet264'], num_init_features=64, num_classes=1000)
DenseNet264(x).shape

torch.Size([1, 1000])

In [34]:
DenseNet201 = DenseNet(growth_rate=32, block_config=[2,64,34,42,32], num_init_features=64, num_classes=1000)
DenseNet201(x).shape

torch.Size([1, 1000])

In [38]:
p1, p2 = torch.split(x, [int(3*.5), 3 - int(3*.5)],dim=1)
p1.shape, p2.shape

(torch.Size([1, 1, 224, 224]), torch.Size([1, 2, 224, 224]))

In [None]:
nn.Conv2d(in_channels=3, out_channels=