<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/DenseNet_with_Dense_Blocks_for_Efficient_Feature_Reuse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class DenseBlock(nn.Module):
    """Densely connected stack of BN -> ReLU -> 3x3 conv layers.

    Every layer is fed the channel-wise concatenation of the block input and
    all outputs produced by earlier layers in the block, and it emits
    ``growth_rate`` new channels. The block therefore maps ``in_channels``
    channels to ``in_channels + num_layers * growth_rate`` channels. The 3x3
    convolutions use padding 1 and stride 1, so height and width are unchanged.

    Args:
        num_layers: How many BN-ReLU-Conv layers the block contains.
        in_channels: Channel count of the tensor passed to ``forward``.
        growth_rate: Channels produced by each individual layer.
    """

    def __init__(self, num_layers, in_channels, growth_rate):
        super().__init__()
        self.layers = nn.ModuleList()
        width = in_channels
        for _ in range(num_layers):
            # Layer k consumes the block input plus the k feature maps made before it.
            self.layers.append(self._make_layer(width, growth_rate))
            width += growth_rate

    def _make_layer(self, in_channels, growth_rate):
        """Return one BN -> ReLU -> 3x3 conv unit that outputs ``growth_rate`` channels."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        """Apply each layer to the running concatenation; return everything stacked on dim 1."""
        collected = [x]
        for unit in self.layers:
            # The in-place ReLU acts on the BatchNorm output, not on stored features.
            collected.append(unit(torch.cat(collected, dim=1)))
        return torch.cat(collected, dim=1)

class TransitionLayer(nn.Module):
    """Link between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    The 1x1 convolution maps ``in_channels`` to ``out_channels``. The
    stride-2 average pool halves height and width (rounding down for odd
    sizes, since ``ceil_mode`` is left at its default).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the normalize / project / downsample pipeline."""
        result = self.transition(x)
        return result

class DenseNet(nn.Module):
    """DenseNet classifier built from DenseBlock / TransitionLayer stages.

    Layout:
        3x3 stem conv (3 -> 2 * growth_rate channels)
        -> for each entry of ``num_blocks``: a DenseBlock with that many layers,
           followed by a TransitionLayer except after the last block
        -> BatchNorm -> ReLU -> global average pooling -> linear classifier.

    Args:
        num_blocks: Number of layers in each dense block, e.g. ``[6, 12, 24, 16]``.
        growth_rate: Channels added by every dense layer.
        reduction: Factor applied to the channel count by each transition.
            The transition output width is ``int(channels * reduction)``.
        num_classes: Number of output logits.

    The input is a 3-channel batch of shape (N, 3, H, W). The head pools
    globally, so the logits do not depend on a fixed H, W. However, each
    transition halves the spatial size, so H and W must survive
    ``len(num_blocks) - 1`` halvings.
    """

    def __init__(self, num_blocks, growth_rate, reduction, num_classes):
        super().__init__()
        self.growth_rate = growth_rate
        num_channels = 2 * growth_rate
        self.conv1 = nn.Conv2d(3, num_channels, kernel_size=3, padding=1, bias=False)

        self.blocks = nn.ModuleList()
        last_index = len(num_blocks) - 1
        for i, num_layers in enumerate(num_blocks):
            self.blocks.append(DenseBlock(num_layers, num_channels, growth_rate))
            num_channels += num_layers * growth_rate
            if i != last_index:
                # Shrink the concatenated width by `reduction` before the next block.
                reduced = int(num_channels * reduction)
                self.blocks.append(TransitionLayer(num_channels, reduced))
                num_channels = reduced

        self.bn = nn.BatchNorm2d(num_channels)
        self.fc = nn.Linear(num_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for input ``x``."""
        out = self.conv1(x)
        for block in self.blocks:
            out = block(out)
        out = torch.relu(self.bn(out))
        # Global average pooling collapses (N, C, H, W) to (N, C, 1, 1).
        # Without it, flatten would produce C*H*W features, which does not
        # match the C input features that self.fc was built for.
        out = nn.functional.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)

# Instantiate the network: four dense blocks of 6, 12, 24 and 16 layers,
# 12 new channels per layer, each transition halving the channel count,
# and a 10-way output layer.
model = DenseNet(
    num_blocks=[6, 12, 24, 16],
    growth_rate=12,
    reduction=0.5,
    num_classes=10,
)