In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import os

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Training on device: {device}")

# I. Building Blocks of DenseNet

We'll start by implementing the core components of the DenseNet architecture.

## 1. DenseNet Simple Layer

The simple layer in a DenseNet consists of a Batch Normalization layer, a ReLU activation function, and a 3x3 Convolutional layer.

In [None]:
class DenseNetSimpleLayer(nn.Module):
    """Basic dense layer: BN -> ReLU -> 3x3 conv, concatenated with its input."""

    def __init__(self, in_channels, growth_rate):
        """
        Build the BatchNorm / ReLU / 3x3 convolution stack.

        Args:
            in_channels (int): Number of channels in the incoming feature map.
            growth_rate (int): Number of channels produced by the convolution
                (k in the paper). The layer's output therefore carries
                in_channels + growth_rate channels.
        """
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # padding=1 with a 3x3 kernel and stride 1 keeps the spatial size, which
        # the channel-wise concatenation in forward() relies on.
        self.conv1 = nn.Conv2d(
            in_channels,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """
        Compute the new feature maps and append them to the input.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: The input followed by the new feature maps,
            concatenated along dimension 1.
        """
        normalized = self.bn1(x)
        activated = self.relu1(normalized)
        new_features = self.conv1(activated)
        return torch.cat((x, new_features), dim=1)

## 2. DenseNet Bottleneck Layer

The bottleneck layer is a more computationally efficient version of the simple layer. It introduces a 1x1 convolution to reduce the number of feature maps before the more expensive 3x3 convolution. The 1x1 convolution produces 4 * growth_rate feature maps.

In [None]:
class DenseNetBottleneckLayer(nn.Module):
    """Bottleneck dense layer: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv,
    with the result concatenated onto the input."""

    def __init__(self, in_channels, growth_rate):
        """
        Build the 1x1 bottleneck stage followed by the 3x3 stage.

        Args:
            in_channels (int): Number of channels in the incoming feature map.
            growth_rate (int): Number of channels produced by the 3x3
                convolution; the 1x1 convolution produces 4 * growth_rate.
        """
        super().__init__()
        bottleneck_width = 4 * growth_rate

        # Stage 1: 1x1 convolution down/up to the bottleneck width.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_channels, bottleneck_width, kernel_size=1, stride=1, bias=False
        )

        # Stage 2: 3x3 convolution producing growth_rate new feature maps.
        self.bn2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            bottleneck_width,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """
        Run both stages and append the new feature maps to the input.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: The input followed by the new feature maps,
            concatenated along dimension 1.
        """
        squeezed = self.conv1(self.relu1(self.bn1(x)))
        new_features = self.conv2(self.relu2(self.bn2(squeezed)))
        return torch.cat((x, new_features), dim=1)


## 3. Transition Layer

The transition layer connects two dense blocks. It consists of a Batch Normalization layer, a ReLU activation, a 1x1 Convolutional layer to reduce the number of channels (compression), and a 2x2 Average Pooling layer with stride 2 to halve the spatial dimensions.

In [None]:
class TransitionLayer(nn.Module):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pool."""

    def __init__(self, in_channels, out_channels):
        """
        Build the channel-reducing convolution and the pooling step.

        Args:
            in_channels (int): Number of channels entering the layer.
            out_channels (int): Number of channels produced by the 1x1 convolution.
        """
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, bias=False
        )
        # Kernel 2 with stride 2 halves both spatial dimensions.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """
        Change the channel count, then downsample.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Feature map with out_channels channels after pooling.
        """
        activated = self.relu1(self.bn1(x))
        projected = self.conv1(activated)
        return self.avg_pool(projected)

# II. Assembling the Full DenseNet Model

Now we will combine these building blocks to create the complete DenseNet architecture.

In [None]:
class DenseNet(nn.Module):
    """DenseNet classifier: a stem, four dense blocks separated by three
    transition layers, then BN -> ReLU -> global average pool -> linear head."""

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, init_weights=True, dataset_used="cifar10"):
        """
        Assemble the network.

        Args:
            block: Dense-layer class, instantiated as block(in_channels, growth_rate).
                Each instance is expected to emit in_channels + growth_rate
                channels, as the channel bookkeeping below assumes.
            nblocks (list of int): Number of layers in each dense block;
                entries 0 through 3 are read.
            growth_rate (int): The growth rate (k).
            reduction (float): Compression factor; each transition outputs
                int(channels * reduction) channels.
            num_classes (int): Size of the final linear layer's output.
            init_weights (bool): When true, run _initialize_weights() at the end.
            dataset_used (str): "cifar10" selects a single 3x3 stem convolution;
                any other value selects the 7x7 stride-2 conv + BN + ReLU +
                max-pool stem.
        """
        super().__init__()
        self.growth_rate = growth_rate
        stem_channels = 2 * growth_rate

        if dataset_used != "cifar10":
            # Stem for larger images (ImageNet-style): strided conv then max-pool.
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, stem_channels, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(stem_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )
        else:
            # Stem for CIFAR-10: one 3x3 convolution, no downsampling.
            self.conv1 = nn.Conv2d(3, stem_channels, kernel_size=3, padding=1, bias=False)

        # Dense blocks 1-3, each followed by a compressing transition. Modules
        # are registered as dense1/trans1 ... dense3/trans3, in that order.
        channels = stem_channels
        for stage in (1, 2, 3):
            depth = nblocks[stage - 1]
            setattr(self, f"dense{stage}", self._make_dense_block(block, channels, depth))
            channels += depth * growth_rate
            compressed = int(channels * reduction)
            setattr(self, f"trans{stage}", TransitionLayer(channels, compressed))
            channels = compressed

        # Dense block 4 has no transition after it.
        self.dense4 = self._make_dense_block(block, channels, nblocks[3])
        channels += nblocks[3] * growth_rate

        # Classification head.
        self.bn = nn.BatchNorm2d(channels)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(channels, num_classes)

        if init_weights:
            self._initialize_weights()

    def _make_dense_block(self, block, in_planes, nblock):
        """
        Stack nblock dense layers into one nn.Sequential.

        Layer i receives in_planes + i * growth_rate input channels, since
        every preceding layer appends growth_rate channels.
        """
        step = self.growth_rate
        return nn.Sequential(
            *[block(in_planes + index * step, step) for index in range(nblock)]
        )

    def _initialize_weights(self):
        """
        Re-initialize parameters: Kaiming-normal (fan_out, relu) for conv
        weights, ones/zeros for batch-norm weight/bias, zeros for every bias of
        conv and linear layers. Linear weights keep their construction-time values.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """
        Run the stem, the dense blocks with their transitions, and the head.

        Args:
            x (torch.Tensor): Input batch with 3 channels (the stem convolution
                takes 3 input channels).

        Returns:
            torch.Tensor: Output of the final linear layer, one row per sample.
        """
        features = self.conv1(x)
        stages = (
            (self.dense1, self.trans1),
            (self.dense2, self.trans2),
            (self.dense3, self.trans3),
        )
        for dense, transition in stages:
            features = transition(dense(features))
        features = self.dense4(features)

        pooled = self.avg_pool(F.relu(self.bn(features)))
        flat = torch.flatten(pooled, 1)
        return self.linear(flat)

# III. Training and Evaluation

In [None]:
def Densenet_cifar():
    """Build the CIFAR-10 DenseNet: bottleneck layers, dense blocks of
    6/12/24/16 layers, growth rate 12, and the remaining DenseNet defaults."""
    layers_per_block = [6, 12, 24, 16]
    return DenseNet(DenseNetBottleneckLayer, layers_per_block, growth_rate=12)

In [None]:
def DenseNet121(num_classes=10):
    """
    Build DenseNet-121: bottleneck layers, dense blocks of 6/12/24/16 layers,
    growth rate 32, and the ImageNet-style stem.

    Args:
        num_classes (int): Size of the classification head. Defaults to 10,
            the value DenseNet uses when the argument is omitted, so existing
            zero-argument calls build the same model as before.

    Returns:
        DenseNet: The constructed model.
    """
    return DenseNet(DenseNetBottleneckLayer, [6, 12, 24, 16], growth_rate=32,
                    num_classes=num_classes, dataset_used="imagenet")

def DenseNet169(num_classes=10):
    """
    Build DenseNet-169: bottleneck layers, dense blocks of 6/12/32/32 layers,
    growth rate 32, and the ImageNet-style stem.

    Args:
        num_classes (int): Size of the classification head. Defaults to 10,
            the value DenseNet uses when the argument is omitted, so existing
            zero-argument calls build the same model as before.

    Returns:
        DenseNet: The constructed model.
    """
    return DenseNet(DenseNetBottleneckLayer, [6, 12, 32, 32], growth_rate=32,
                    num_classes=num_classes, dataset_used="imagenet")

def DenseNet201(num_classes=10):
    """
    Build DenseNet-201: bottleneck layers, dense blocks of 6/12/48/32 layers,
    growth rate 32, and the ImageNet-style stem.

    Args:
        num_classes (int): Size of the classification head. Defaults to 10,
            the value DenseNet uses when the argument is omitted, so existing
            zero-argument calls build the same model as before.

    Returns:
        DenseNet: The constructed model.
    """
    return DenseNet(DenseNetBottleneckLayer, [6, 12, 48, 32], growth_rate=32,
                    num_classes=num_classes, dataset_used="imagenet")

def DenseNet161(num_classes=10):
    """
    Build DenseNet-161: bottleneck layers, dense blocks of 6/12/36/24 layers,
    growth rate 48, and the ImageNet-style stem.

    Args:
        num_classes (int): Size of the classification head. Defaults to 10,
            the value DenseNet uses when the argument is omitted, so existing
            zero-argument calls build the same model as before.

    Returns:
        DenseNet: The constructed model.
    """
    return DenseNet(DenseNetBottleneckLayer, [6, 12, 36, 24], growth_rate=48,
                    num_classes=num_classes, dataset_used="imagenet")

## 1. For CIFAR-10

In [None]:
# Per-channel mean / std handed to Normalize in both CIFAR-10 pipelines.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip for augmentation, then tensor conversion and normalization.
transform_train_cifar = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
)

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
)

In [None]:
# TRAINING FUNCTION

def train(epoch, model, trainloader, optimizer, criterion, device):
    """
    Train the model for one pass over trainloader.

    Prints the epoch header, then the running mean loss and running accuracy
    every 100 batches (starting with batch 0).

    Args:
        epoch (int): Epoch index, used only in the printed messages.
        model (nn.Module): Model to optimize; switched to training mode.
        trainloader: Iterable of (inputs, targets) batches.
        optimizer: Optimizer holding the model's parameters.
        criterion: Loss callable taking (outputs, targets).
        device: Device each batch is moved to before the forward pass.
    """
    print(f'\nEpoch: {epoch}')
    model.train()

    running_loss = 0
    num_correct = 0
    num_seen = 0
    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # One optimization step on this batch.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics over the batches seen so far.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += predicted.eq(targets).sum().item()

        if step % 100 != 0:
            continue
        print(f'Epoch {epoch} | Batch {step}/{len(trainloader)} | Loss: {running_loss/(step+1):.3f} | Acc: {100.*num_correct/num_seen:.3f}% ({num_correct}/{num_seen})')

# EVALUATION FUNCTION

def evaluate_cifar(model, testloader, criterion, device):
    """
    Evaluate the model on testloader and print a short report.

    Args:
        model (nn.Module): Model to evaluate; switched to eval mode.
        testloader: Sized iterable of (inputs, targets) batches.
        criterion: Loss callable taking (outputs, targets).
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Top-1 accuracy in percent over all samples seen.
    """
    model.eval()

    loss_sum = 0
    hits = 0
    seen = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)

            loss_sum += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]
            seen += targets.size(0)
            hits += predicted.eq(targets).sum().item()

    acc = 100. * hits / seen
    print("\n--- Test Results ---")
    # The reported loss is the mean of the per-batch losses.
    print(f"Average Loss: {loss_sum / len(testloader):.4f}")
    print(f"Top-1 Accuracy: {acc:.2f}% ({hits}/{seen})")
    print("--------------------\n")
    return acc

In [None]:
# DataLoaders: CIFAR-10 is downloaded into ./data on first use. The training
# split is shuffled and augmented; the test split is neither.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train_cifar)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test_cifar)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Use our from-scratch model. The factory is defined as `Densenet_cifar`
# (capital D); the lowercase spelling is not defined and raised a NameError.
model = Densenet_cifar().to(device)
print("Custom DenseNet model for CIFAR-10 created successfully.")

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# T_max matches the number of epochs below, so the cosine schedule spans the run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training Loop: one training pass and one test-set evaluation per epoch,
# stepping the learning-rate schedule once per epoch.
for epoch in range(200):
    train(epoch, model, trainloader, optimizer, criterion, device)
    evaluate_cifar(model, testloader, criterion, device)
    scheduler.step()

## 2. For ImageNet

In [None]:
# DATA LOADING AND TRANSFORMATION

# Channel-wise normalization with the ImageNet statistics; shared by both
# pipelines below.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random resized 224x224 crop and random horizontal flip for
# augmentation, then tensor conversion and normalization.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
transform_train = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize to 256 followed by a 224x224
# center crop, then tensor conversion and normalization.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
transform_val = transforms.Compose(_val_steps)

In [None]:
# TRAINING FUNCTION
def train_imagenet(epoch, model, trainloader, optimizer, criterion, device):
    """
    Train the model for one pass over trainloader.

    Prints the epoch header, then the running mean loss every 100 batches
    (starting with batch 0).

    Args:
        epoch (int): Epoch index, used only in the printed messages.
        model (nn.Module): Model to optimize; switched to training mode.
        trainloader: Iterable of (inputs, targets) batches.
        optimizer: Optimizer holding the model's parameters.
        criterion: Loss callable taking (outputs, targets).
        device: Device each batch is moved to before the forward pass.
    """
    print(f'\nEpoch: {epoch}')
    model.train()
    train_loss = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only the loss is tracked here; the sample counter the original kept
        # was never read, so it has been dropped.
        train_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} | Batch {batch_idx}/{len(trainloader)} | Loss: {train_loss/(batch_idx+1):.3f}')

# EVALUATION FUNCTION (WITH TOP-1 AND TOP-5)
def evaluate_imagenet(model, valloader, criterion, device):
    """
    Evaluate the model on valloader and print loss, Top-1 and Top-5 accuracy.

    Args:
        model (nn.Module): Model to evaluate; switched to eval mode.
        valloader: Sized iterable of (inputs, targets) batches.
        criterion: Loss callable taking (outputs, targets).
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Top-1 accuracy in percent over all samples seen.
    """
    model.eval()
    val_loss = 0
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in valloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, targets).item()

            # topk raises if asked for more entries than there are classes, so
            # cap k for models with fewer than 5 outputs.
            k = min(5, outputs.size(1))
            _, pred = outputs.topk(k, 1, largest=True, sorted=True)

            # hits[i, j] is True when the j-th ranked class of sample i is its
            # target. The k indices per row are distinct, so each row holds at
            # most one True and a plain sum counts samples.
            hits = pred.eq(targets.view(-1, 1))
            correct_top1 += hits[:, :1].sum().item()
            correct_top5 += hits.sum().item()
            total += targets.size(0)

    # Calculate final accuracies
    top1_acc = 100. * correct_top1 / total
    top5_acc = 100. * correct_top5 / total

    print("\n--- Validation Results ---")
    # The reported loss is the mean of the per-batch losses.
    print(f"Average Loss: {val_loss / len(valloader):.4f}")
    print(f"Top-1 Accuracy: {top1_acc:.2f}% ({int(correct_top1)}/{total})")
    print(f"Top-5 Accuracy: {top5_acc:.2f}% ({int(correct_top5)}/{total})")
    print("--------------------------\n")
    # Returned so callers can track the best epoch, as evaluate_cifar allows.
    return top1_acc

In [None]:
data_dir = './imagenette2'
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')

# Stop here with a clear message; previously the message was only printed and
# execution continued into ImageFolder, which then failed on the missing path.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"Error: Dataset directory not found at '{data_dir}'. "
        "Please download a dataset like ImageNette and update the path."
    )

# DataLoaders: ImageFolder reads the class sub-directories under train/ and val/.
train_dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

val_dataset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_val)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4)

# Use our from-scratch model, adjusting for the number of classes in the dataset.
# The DenseNet-121 configuration (bottleneck layers, blocks of 6/12/24/16,
# growth rate 32, ImageNet stem) is passed to DenseNet directly so that
# num_classes reaches the constructor.
num_classes = len(train_dataset.classes)
model = DenseNet(
    DenseNetBottleneckLayer,
    [6, 12, 24, 16],
    growth_rate=32,
    num_classes=num_classes,
    dataset_used="imagenet",
).to(device)
print(f"Custom DenseNet-121 model for ImageNet created successfully with {num_classes} classes.")

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# Learning rate is multiplied by 0.1 every 30 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Training Loop: uses the ImageNet training function defined for this section
# (the loop previously called the CIFAR-10 `train`).
best_acc = 0.0
for epoch in range(90):
    train_imagenet(epoch, model, trainloader, optimizer, criterion, device)
    evaluate_imagenet(model, valloader, criterion, device)
    scheduler.step()
    # Checkpointing is left disabled: it needs an accuracy value `acc` for the
    # epoch, which this loop does not capture.
    # if acc > best_acc:
    #     print("Saving new best model...")
    #     best_acc = acc
    #     torch.save(model.state_dict(), 'densenet_imagenet_scratch_best.pth')