- CIFAR-10 consists of 32x32 color images in 10 different classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). This dataset is more complex than MNIST or FashionMNIST, making it a good candidate for demonstrating the power of CNNs over simple ANNs for image tasks.

- This example will cover loading CIFAR-10, defining a basic CNN architecture, training, and evaluation.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns  # Heatmap used to draw the confusion matrix
from sklearn.metrics import confusion_matrix, classification_report  # Evaluation reports

# --- 1. Device Configuration ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- 2. Hyperparameters ---
num_epochs = 10 # Increase epochs for a more complex dataset/model
batch_size = 64 # Smaller batch size might be needed depending on GPU memory
learning_rate = 0.001
num_classes = 10 # CIFAR-10 has 10 classes

# --- 3. Load and Prepare CIFAR-10 Dataset ---
print("Loading CIFAR-10 dataset...")
# CIFAR-10 images are 3x32x32 (3 color channels)
# We need to normalize them. The means and stds are commonly used values for CIFAR-10.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # Mean/Std for CIFAR-10
])

# Download and load the training data
train_dataset = torchvision.datasets.CIFAR10(root='./data',
                                             train=True,
                                             transform=transform,
                                             download=True)

# Download and load the test data (the download is skipped if the files already exist)
test_dataset = torchvision.datasets.CIFAR10(root='./data',
                                            train=False,
                                            transform=transform,
                                            download=True)

# Create DataLoaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2) # num_workers can speed up loading
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Dataset loaded. Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
cifar10_classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# --- 4. Define the Convolutional Neural Network (CNN) Model ---
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # Convolutional Layer 1
        # Input: 3 channels (RGB), Output: 16 channels, Kernel size: 3x3, Padding: 1 (to maintain size)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        # Pooling Layer 1
        # Kernel size: 2x2, Stride: 2 (downsamples by factor of 2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # After pool1: Image size becomes 16x16 (32/2 = 16)

        # Convolutional Layer 2
        # Input: 16 channels, Output: 32 channels, Kernel size: 3x3, Padding: 1
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        # Pooling Layer 2
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # After pool2: Image size becomes 8x8 (16/2 = 8)

        # Fully Connected Layer 1
        # Input features: 32 channels * 8 width * 8 height = 32 * 8 * 8 = 2048
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.relu3 = nn.ReLU()
        # Fully Connected Layer 2 (Output Layer)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Input shape: (batch_size, 3, 32, 32)
        out = self.conv1(x)       # -> (batch_size, 16, 32, 32)
        out = self.relu1(out)
        out = self.pool1(out)     # -> (batch_size, 16, 16, 16)

        out = self.conv2(out)     # -> (batch_size, 32, 16, 16)
        out = self.relu2(out)
        out = self.pool2(out)     # -> (batch_size, 32, 8, 8)

        # Flatten the output for the fully connected layers
        # Reshape from (batch_size, 32, 8, 8) to (batch_size, 32*8*8)
        out = out.view(out.size(0), -1) # -> (batch_size, 2048)

        out = self.fc1(out)       # -> (batch_size, 256)
        out = self.relu3(out)
        out = self.fc2(out)       # -> (batch_size, 10) (Logits)
        return out

# --- 5. Instantiate the Model, Loss, and Optimizer ---
model = SimpleCNN(num_classes=num_classes).to(device)
print("\nModel Architecture:")
print(model)

# Loss Function: CrossEntropyLoss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- 6. Training Loop ---
print("\nStarting Training...")
n_total_steps = len(train_loader)
train_losses = []
train_accuracies = [] # Optional: track training accuracy per epoch

for epoch in range(num_epochs):
    model.train() # Set model to training mode
    running_loss = 0.0
    n_correct_train = 0
    n_samples_train = 0

    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

        # Calculate training accuracy
        _, predicted_train = torch.max(outputs.data, 1)
        n_samples_train += labels.size(0)
        n_correct_train += (predicted_train == labels).sum().item()

        if (i+1) % 100 == 0: # Print progress every 100 steps
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

    epoch_loss = running_loss / n_samples_train
    epoch_acc = 100.0 * n_correct_train / n_samples_train
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    print(f'--- Epoch {epoch+1} Summary --- Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}% ---')

print("Finished Training.")

# --- 7. Evaluation Loop (Testing) ---
print("\nStarting Evaluation on Test Set...")
model.eval() # Set model to evaluation mode
with torch.no_grad():
    n_correct_test = 0
    n_samples_test = 0
    all_labels_test = []
    all_predicted_test = []
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs.data, 1)
        n_samples_test += labels.size(0)
        n_correct_test += (predicted == labels).sum().item()
        
        # Store labels and predictions for confusion matrix/report
        all_labels_test.extend(labels.cpu().numpy())
        all_predicted_test.extend(predicted.cpu().numpy())


# Metrics and plots do not involve autograd, so they sit outside the no_grad block.
accuracy_test = 100.0 * n_correct_test / n_samples_test
print(f'Accuracy of the network on the {len(test_dataset)} test images: {accuracy_test:.2f} %')

# --- 8. Confusion Matrix and Classification Report ---
print("\nConfusion Matrix (CNN Test Set):")
cm_cnn_test = confusion_matrix(all_labels_test, all_predicted_test)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_cnn_test, annot=True, fmt="d", cmap="Blues",
            xticklabels=cifar10_classes, yticklabels=cifar10_classes)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix - CNN (CIFAR-10 Test Set)")
plt.show()

print("\nClassification Report (CNN Test Set):")
print(classification_report(all_labels_test, all_predicted_test, target_names=cifar10_classes))


# --- 9. Plot Training Loss ---
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_losses, marker='o', label='Training Loss (CrossEntropy)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss per Epoch (CNN - CIFAR-10)')
plt.legend()
plt.grid(True)
plt.show()


# PyTorch CIFAR-10 CNN Classifier: Code Explanation

This document explains a Python script that implements a Convolutional Neural Network (CNN) using PyTorch to classify images from the CIFAR-10 dataset.

## 1. Dataset: CIFAR-10
- **Source**: `torchvision.datasets.CIFAR10`.
- **Nature**: A dataset of 60,000 32x32 color images in 10 classes (e.g., airplane, automobile, bird, cat), with 6,000 images per class, split into 50,000 training images and 10,000 test images.
- **Normalization**:
    - Pixel values are transformed to PyTorch tensors and then normalized.
    - The normalization uses mean and standard deviation values specific to the CIFAR-10 dataset, applied per color channel (Red, Green, Blue). This helps stabilize training and improve model performance.
- **`DataLoader`**:
    - `num_workers`: This parameter in `DataLoader` can be set to a value greater than 0 to use multiple subprocesses for data loading. This can significantly speed up the data pipeline, especially if preprocessing is non-trivial, by loading data in parallel while the GPU is busy with computations.
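
To make the per-channel normalization concrete, the following minimal sketch applies the same arithmetic that `transforms.Normalize` performs, `(x - mean) / std`, by hand. The helper `normalize_per_channel` and the random tensor `fake_image` are illustrative assumptions, not part of the original script; only the mean and standard deviation values are taken from it.

```python
import torch

# Mean/std values copied from the script's transforms.Normalize call.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

def normalize_per_channel(image, mean, std):
    """Apply (x - mean) / std to each channel of a (C, H, W) tensor."""
    # Reshape to (C, 1, 1) so the values broadcast over height and width.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    return (image - mean_t) / std_t

# Hypothetical stand-in for a ToTensor() output: values in [0, 1], shape (3, 32, 32).
fake_image = torch.rand(3, 32, 32)
normalized = normalize_per_channel(fake_image, CIFAR10_MEAN, CIFAR10_STD)
print(normalized.shape)

# A pixel that equals the channel mean is mapped to zero in every channel.
mean_pixel = torch.tensor(CIFAR10_MEAN).view(3, 1, 1)
centered = normalize_per_channel(mean_pixel, CIFAR10_MEAN, CIFAR10_STD)
print(centered.flatten())
```

After normalization, values are no longer confined to `[0, 1]`; each channel is roughly centered around zero, which is the property that helps stabilize training.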

## 2. Model Definition (`SimpleCNN` class)
- **Core Components**: The network primarily uses `nn.Conv2d` for convolutional operations and `nn.MaxPool2d` for pooling.
- **`nn.Conv2d` (Convolutional Layers)**:
    - **Purpose**: Apply learnable filters to input image data to extract features like edges, textures, and patterns.
    - **Key Arguments**:
        - `in_channels (int)`: Number of channels in the input image (e.g., 3 for RGB color images like CIFAR-10, 1 for grayscale).
        - `out_channels (int)`: Number of filters (also known as kernels or feature detectors) to be applied. This determines the depth (number of channels) of the output feature map.
        - `kernel_size (int or tuple)`: Specifies the dimensions of the convolutional filter (e.g., `3` for a 3x3 filter, or `(3, 5)` for a 3x5 filter).
        - `stride (int or tuple, optional)`: The step size with which the filter moves across the input image (default is 1).
        - `padding (int or tuple, optional)`: Adds a specified number of zero-valued pixels around the border of the input image (default is 0). This controls the spatial dimensions of the output feature map, which along each axis equal `floor((input + 2 * padding - kernel_size) / stride) + 1`. For instance, `padding=1` with `kernel_size=3` and `stride=1` preserves the input width and height, as in both convolutional layers of `SimpleCNN`.
- **Activation Function (`nn.ReLU`)**:
    - Applied after each convolutional layer to introduce non-linearity, enabling the network to learn more complex features.
- **`nn.MaxPool2d` (Max Pooling Layers)**:
    - **Purpose**: Downsamples the feature maps, reducing their spatial dimensions (width and height) while retaining the most prominent features. This helps to reduce computational complexity, control overfitting, and create some translation invariance.
    - **Key Arguments**:
        - `kernel_size (int or tuple)`: The size of the window over which to take the maximum.
        - `stride (int or tuple, optional)`: The step size of the window. Often, `stride` is set equal to `kernel_size` (e.g., `kernel_size=2, stride=2`) to achieve non-overlapping pooling, effectively halving the input dimensions.
- **Flattening Operation**:
    - **Necessity**: Before connecting the output of the convolutional/pooling layers to fully connected (dense) layers, the multi-dimensional feature maps need to be flattened into a 1D vector for each sample in the batch.
    - **Implementation**: `out.view(out.size(0), -1)` reshapes the tensor.
        - `out.size(0)` gets the batch size.
        - `-1` infers the remaining dimensions, effectively collapsing all feature dimensions (channels, height, width) into a single long vector.
        - If the output of the last pooling layer is `[batch_size, num_output_channels, final_height, final_width]`, it gets flattened to `[batch_size, num_output_channels * final_height * final_width]`.
- **`nn.Linear` (Fully Connected Layers)**:
    - Used after the flattening operation, similar to their use in ANNs/MLPs, to perform classification based on the extracted features.
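
The layer sizes described above can be checked with a little arithmetic. The sketch below traces the spatial size of a 32x32 input through the two convolution and pooling stages of `SimpleCNN` and derives the number of features that reaches the first `nn.Linear` layer. The helper `conv_output_size` is a hypothetical name introduced here; it uses the standard output-size formula for convolution and pooling windows, assuming no dilation.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a convolution or pooling window (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32  # CIFAR-10 input height and width
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv1: 32 -> 32
size = conv_output_size(size, kernel_size=2, stride=2)             # pool1: 32 -> 16
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv2: 16 -> 16
size = conv_output_size(size, kernel_size=2, stride=2)             # pool2: 16 -> 8

out_channels = 32  # number of channels produced by conv2
flattened_features = out_channels * size * size

print(size, flattened_features)  # 8 2048
```

The result, 2048, matches the `32 * 8 * 8` input size of `fc1`. If the input resolution or any kernel, stride, or padding value changes, this number has to be recomputed, otherwise the first fully connected layer raises a shape mismatch error.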

## 3. Training and Evaluation Loops
- **Structural Similarity**: The overall structure of the training and evaluation loops remains very similar to those used for the ANN/MLP examples. This includes:
    - Iterating through epochs and batches.
    - Moving data to the configured device.
    - Zeroing gradients (`optimizer.zero_grad()`).
    - Performing the forward pass (`outputs = model(images)`).
    - Calculating the loss (`loss = criterion(outputs, labels)`).
    - Performing the backward pass (`loss.backward()`).
    - Updating model parameters (`optimizer.step()`).
    - Using `model.train()` and `model.eval()` modes appropriately.
    - Using `with torch.no_grad():` during evaluation.
- **Key Difference from ANN Input Handling**:
    - Unlike the ANN examples, where input images were explicitly flattened in the `DataLoader` or at the beginning of the `forward` method, CNNs consume images directly in their spatial format: a batch tensor of shape `(batch_size, channels, height, width)`.
    - The `forward` method of the CNN handles the spatial dimensions through its convolutional and pooling layers. The flattening step is an internal part of the CNN's `forward` method, specifically before transitioning to the fully connected layers.
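
The difference in input handling can be seen by feeding the same batch to the first layer of each kind of network. This is a small sketch under stated assumptions: `batch` is random data standing in for four CIFAR-10 images, and `mlp_first_layer` and `cnn_first_layer` are illustrative names that do not appear in the original script.

```python
import torch
import torch.nn as nn

# Hypothetical batch of 4 CIFAR-10-sized images: (batch_size, channels, height, width).
batch = torch.randn(4, 3, 32, 32)

# MLP-style input: every image must be flattened to a 3072-element vector first.
mlp_first_layer = nn.Linear(3 * 32 * 32, 256)
flat_batch = batch.view(batch.size(0), -1)   # (4, 3072)
mlp_out = mlp_first_layer(flat_batch)        # (4, 256)

# CNN-style input: the convolution consumes the 4D tensor as-is,
# so the spatial layout of the pixels is preserved.
cnn_first_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
cnn_out = cnn_first_layer(batch)             # (4, 16, 32, 32)

print(flat_batch.shape, mlp_out.shape, cnn_out.shape)
```

Passing the unflattened `batch` to `mlp_first_layer` would fail with a shape error, since `nn.Linear` expects the last dimension to have 3072 elements.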

## 4. Performance
- **Advantage of CNNs**: A simple CNN, like the one described, is expected to achieve significantly better classification accuracy on image datasets like CIFAR-10 compared to a basic MLP/ANN that only uses flattened pixel inputs.
- **Reason**: CNNs are designed to effectively capture spatial hierarchies and local patterns in images through their use of convolutions (shared weights, local receptive fields) and pooling.
- **State-of-the-Art**: While this `SimpleCNN` demonstrates the core concepts, achieving state-of-the-art results on CIFAR-10 typically requires much deeper and more sophisticated architectures, such as ResNets and DenseNets, or non-convolutional models such as Vision Transformers.
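
One way to see what shared weights buy is to compare parameter counts. The sketch below counts the parameters of a convolution configured like `conv1` and compares them with a hypothetical fully connected layer that maps the same flattened input to an output of the same size. The helper `count_parameters` is an illustrative name, and the dense layer is only counted analytically rather than built, since allocating it would take a large amount of memory.

```python
import torch.nn as nn

def count_parameters(module):
    """Total number of learnable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

# Same configuration as conv1 in SimpleCNN.
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
conv_params = count_parameters(conv)  # 16 * 3 * 3 * 3 weights + 16 biases

# Hypothetical dense layer: flattened 3x32x32 input -> flattened 16x32x32 output.
in_features = 3 * 32 * 32
out_features = 16 * 32 * 32
dense_params = in_features * out_features + out_features  # weights + biases

print(conv_params, dense_params)  # 448 50348032
```

The convolution reuses the same 3x3 filters at every image position, so its parameter count does not depend on the image size, while the dense layer needs a separate weight for every input-output pair.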

## Summary
This example provides a fundamental but complete illustration of how to construct, train, and evaluate a Convolutional Neural Network using PyTorch for an image classification task. It highlights the specific layers and operations (Conv2d, MaxPool2d, flattening) that distinguish CNNs from simpler feedforward networks and showcases their suitability for processing image data.