# DenseNet Architecture in PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

## Introduction to DenseNet

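DenseNet (Densely Connected Convolutional Network) is an architecture in which, within each dense block, every layer receives the concatenated feature maps of all preceding layers as its input and passes its own feature maps on to all subsequent layers, in a feed-forward fashion.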

In [None]:
# Load a pre-trained DenseNet121
# NOTE(review): `pretrained=True` is the legacy torchvision flag; newer
# releases deprecate it in favour of a `weights=` argument. Confirm the
# installed torchvision version before switching.
densenet = models.densenet121(pretrained=True)
# Printing the module shows its nested layer structure.
print(densenet)

## Key Components of DenseNet

In [None]:
# Count parameters
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count; a model without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Report the parameter count in millions, formatted to one decimal place.
print(f'DenseNet121 has {count_parameters(densenet) / 1e6:.1f}M parameters')

## Custom DenseNet Block

In [None]:
class DenseBlock(nn.Module):
    """One BatchNorm -> ReLU -> 3x3 convolution unit.

    Args:
        in_channels: Number of channels of the incoming tensor.
        growth_rate: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=True)
        # A 3x3 kernel with padding=1 leaves the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=growth_rate,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Apply batch norm, ReLU and the convolution to ``x``, in that order."""
        normalized = self.bn(x)
        activated = self.relu(normalized)
        return self.conv(activated)

class DenseSequential(nn.Module):
    """Chain of dense blocks whose outputs are concatenated along channels.

    Block ``i`` is built for ``in_channels + i * growth_rate`` input channels,
    because it consumes the original input plus the outputs of all earlier
    blocks.

    Args:
        num_blocks: How many dense blocks to stack.
        in_channels: Number of channels of the input tensor.
        growth_rate: Number of channels each block adds.
    """

    def __init__(self, num_blocks, in_channels, growth_rate):
        super().__init__()
        self.blocks = nn.ModuleList([
            DenseBlock(in_channels + idx * growth_rate, growth_rate)
            for idx in range(num_blocks)
        ])

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with every block's output."""
        collected = [x]
        for layer in self.blocks:
            # Each block sees everything produced so far, joined on channels.
            collected.append(layer(torch.cat(collected, dim=1)))
        return torch.cat(collected, dim=1)

# Smoke-test the stacked dense blocks on a random batch.
# Four blocks with growth rate 32 on a 64-channel input give
# 64 + 4 * 32 = 192 output channels at the same 32x32 resolution.
dense_block = DenseSequential(num_blocks=4, in_channels=64, growth_rate=32)
x = torch.randn((2, 64, 32, 32))
out = dense_block(x)
print(
    f'Input shape: {x.shape}',
    f'Output shape: {out.shape}',
    sep='\n',
)

## Transfer Learning with DenseNet

In [None]:
# Load CIFAR-10 dataset
# The images are fed to a DenseNet pre-trained on ImageNet, so they are
# upsampled to 224x224 and normalized with the ImageNet per-channel mean and
# standard deviation that the pre-trained weights were trained with (rather
# than a generic 0.5 / 0.5 scaling).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# download=True fetches the archive into ./data if it is not already there.
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform)

# Use a subset for quick training: keep only the first 5000 training and
# first 1000 test samples. `data` and `targets` are truncated together so
# images and labels stay aligned.
train_dataset.data = train_dataset.data[:5000]
train_dataset.targets = train_dataset.targets[:5000]
test_dataset.data = test_dataset.data[:1000]
test_dataset.targets = test_dataset.targets[:1000]

# Shuffle only the training data; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Modify the final layer for CIFAR-10 (10 classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the pre-trained network without moving it yet: the new classifier
# below is created on the CPU, so the whole model is transferred to `device`
# exactly once, after the head has been swapped.
model = models.densenet121(pretrained=True)

# Replace the final classifier with a freshly initialized 10-way linear layer
# that takes the same number of input features as the original head.
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 10)
model = model.to(device)

# Sanity check: report where the model's parameters live.
print(f'Model on device: {next(model.parameters()).device}')

In [None]:
# Train for a few epochs
# Every parameter of the model is handed to Adam, so the whole network is
# fine-tuned, not only the replaced classifier.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 3
train_losses = []  # mean per-batch training loss, one entry per epoch
val_accs = []      # accuracy (in percent) on test_loader, one entry per epoch

for epoch in range(epochs):
    # Training
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Clear gradients left over from the previous step before
        # back-propagating the current batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Average of the per-batch losses over the epoch.
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    val_accs.append(val_acc)

    print(f'Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%')