LeNet with and without CBAM

In [None]:
pip install torchmetrics


Collecting torchmetrics
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 926.4/926.4 kB 17.9 MB/s eta 0:00:00
Downloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.9 torchmetrics-1.6.0


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from tqdm import tqdm

# CBAM (Convolutional Block Attention Module): channel gate followed by spatial gate
class CBAM(nn.Module):
    """Refine a feature map with channel attention and then spatial attention.

    Args:
        channels: Number of channels of the incoming feature map; forwarded to
            ``ChannelAttention``.
        reduction: Reduction ratio forwarded to ``ChannelAttention``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Sub-modules are built channel-first, then spatial, so parameter
        # registration order stays the same.
        self.channel_attention = ChannelAttention(channels, reduction=reduction)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        """Return ``x`` after the channel gate and then the spatial gate."""
        channel_refined = self.channel_attention(x)
        return self.spatial_attention(channel_refined)

class ChannelAttention(nn.Module):
    """Channel attention gate built from a shared two-layer 1x1-conv bottleneck.

    The input is globally average-pooled and max-pooled to 1x1, both pooled
    descriptors go through the same ``fc1 -> ReLU -> fc2`` stack, the two
    results are summed, and a sigmoid of that sum scales the input per channel.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction`` but never less than 1.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp the bottleneck width: with channels < reduction the integer
        # division is 0, which would create a zero-channel convolution.
        hidden = max(channels // reduction, 1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the attention weights."""
        # Global descriptors with spatial size 1x1.
        avg_pool = F.adaptive_avg_pool2d(x, 1)
        max_pool = F.adaptive_max_pool2d(x, 1)
        # The same fc1/fc2 weights process both descriptors.
        avg_out = self.fc2(F.relu(self.fc1(avg_pool)))
        max_out = self.fc2(F.relu(self.fc1(max_pool)))
        out = avg_out + max_out
        # The 1x1 gate broadcasts over the spatial dimensions of x.
        return x * self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial attention gate computed from channel-wise mean and max maps.

    The input is reduced over the channel dimension with a mean and a max, the
    two single-channel maps are concatenated, a ``kernel_size`` convolution
    turns them into one map, and a sigmoid of that map scales the input.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``padding=kernel_size // 2`` keeps the
            attention map the same spatial size as the input.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # An even kernel with padding k // 2 yields an output one pixel larger
        # than the input, so the multiplication in forward() would not line up.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled position-wise by the attention map."""
        # Reduce over channels (dim=1), keeping a singleton channel dimension.
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_pool, max_pool], dim=1)
        # The single-channel gate broadcasts over the channels of x.
        return x * self.sigmoid(self.conv(x_cat))

# LeNet model with and without CBAM
class LeNet(nn.Module):
    """LeNet-style CNN with an optional CBAM block after the last convolution.

    The fully connected head expects ``120 * 2 * 2`` features, which is what
    the three convolutions produce for 3x32x32 inputs:
    32 -> conv1 (pad 2) 32 -> pool 16 -> conv2 12 -> pool 6 -> conv3 2.
    Other spatial sizes will not match ``fc1``.

    Args:
        use_cbam: If True, apply ``CBAM(120)`` to the output of ``conv3``.
        num_classes: Size of the output layer. Defaults to 100 (CIFAR-100).
    """

    def __init__(self, use_cbam=False, num_classes=100):
        super(LeNet, self).__init__()
        self.use_cbam = use_cbam

        # LeNet layers
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=5)
        self.fc1 = nn.Linear(120 * 2 * 2, 84)  # 120 channels x 2 x 2 after conv3
        self.fc2 = nn.Linear(84, num_classes)

        # CBAM module (applied after final convolution layer); None when disabled
        self.cbam = CBAM(120) if self.use_cbam else None

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        # Two conv + max-pool + ReLU stages, then a third conv + ReLU.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(self.conv3(x))

        # Apply CBAM after the final convolution if enabled
        if self.cbam is not None:
            x = self.cbam(x)

        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)

# CIFAR-100 preprocessing: convert to tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Train and test splits share the same transform and the ./data cache directory
trainset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Shuffle only the training batches; evaluation order stays fixed
trainloader = DataLoader(dataset=trainset, batch_size=128, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=100, shuffle=False)

# Training function
def train(model, trainloader, criterion, optimizer, device, num_epochs=20):
    """Train ``model`` and report per-epoch training loss and accuracy.

    Args:
        model: Module whose output for a batch is a ``(batch, classes)`` score
            tensor; it is moved to ``device`` in place.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``trainloader``.

    Returns:
        The highest training accuracy (in percent) seen over all epochs, or
        0.0 if no samples were processed.
    """
    best_acc = 0.0
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass, backward pass, parameter update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the loss value of each batch for the epoch summary
            running_loss += loss.item()
            num_batches += 1

            # Calculate training accuracy
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against an empty loader so the epoch summary never divides by zero
        train_acc = 100 * correct / total if total else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}, Training Accuracy: {train_acc:.2f}%")
        print(f"Epoch {epoch+1}/{num_epochs}, Mean Batch Loss: {avg_loss:.4f}")

        # Update best accuracy
        best_acc = max(best_acc, train_acc)
        print(f"Best Accuracy so far: {best_acc:.2f}%")

    print("Finished Training")
    return best_acc

# Testing function for top-1 and top-5 accuracy
def test(model, testloader, device):
    model.eval()
    top1 = 0
    top5 = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(testloader, desc="Testing"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.topk(outputs, 5, dim=1)
            correct1 = (preds[:, 0] == labels).sum().item()
            correct5 = (preds == labels.view(-1, 1).expand_as(preds)).sum().item()
            total += labels.size(0)
            top1 += correct1
            top5 += correct5

    top1_accuracy = 100 * top1 / total
    top5_accuracy = 100 * top5 / total
    print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
    print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")

# Initialize models, loss function, and optimizers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()

# Fixed seed for torch's global RNG so a Restart & Run All repeats the same
# weight initialisation and shuffle order. Re-seeding before each model and
# each training run starts both variants from the same RNG state.
# NOTE(review): some CUDA kernels may still be non-deterministic.
SEED = 42

# LeNet with CBAM
torch.manual_seed(SEED)
model_with_cbam = LeNet(use_cbam=True)
optimizer_with_cbam = optim.Adam(model_with_cbam.parameters(), lr=0.001)

# LeNet without CBAM
torch.manual_seed(SEED)
model_without_cbam = LeNet(use_cbam=False)
optimizer_without_cbam = optim.Adam(model_without_cbam.parameters(), lr=0.001)

# Train LeNet with CBAM
print("Training LeNet with CBAM...")
torch.manual_seed(SEED)
train(model_with_cbam, trainloader, criterion, optimizer_with_cbam, device, num_epochs=20)

# Train LeNet without CBAM
print("Training LeNet without CBAM...")
torch.manual_seed(SEED)
train(model_without_cbam, trainloader, criterion, optimizer_without_cbam, device, num_epochs=20)

# Test both models
print("\nEvaluating LeNet with CBAM:")
test(model_with_cbam, testloader, device)

print("\nEvaluating LeNet without CBAM:")
test(model_without_cbam, testloader, device)


Output hidden; open in https://colab.research.google.com to view.