Let's build a more complex Convolutional Neural Network (CNN) in PyTorch, again on the CIFAR-10 dataset. This lets us see how additional layers and techniques such as Batch Normalization, Dropout, and data augmentation can improve performance and robustness.

Complexity additions in this example:

- Deeper Architecture: We'll use more convolutional layers.
- Batch Normalization (nn.BatchNorm2d): Added after convolutional layers (before activation) to stabilize and accelerate training.
- Dropout (nn.Dropout2d, nn.Dropout): Spatial dropout is added after each convolutional block, and standard dropout is added in the fully connected part, to reduce overfitting.
- Data Augmentation: Simple augmentations will be applied to the training dataset to make the model more robust.
- Validation Loop: We'll split the training data to create a validation set and monitor performance on it during training.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# --- 1. Device Configuration ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- 2. Hyperparameters ---
num_epochs = 25 # Increased epochs for a more complex model and data augmentation
batch_size = 128 # Can adjust based on GPU memory
learning_rate = 0.001
num_classes = 10 # CIFAR-10 has 10 classes
dropout_prob = 0.5

# --- 3. Load and Prepare CIFAR-10 Dataset with Data Augmentation ---
print("Loading CIFAR-10 dataset with augmentation...")

# Transformations for the training set (with augmentation)
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(), # Randomly flip images horizontally
    transforms.RandomCrop(32, padding=4), # Randomly crop images with padding
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Transformations for the validation and test set (only normalization)
transform_val_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Download and load the full training data (will be split)
full_train_dataset_aug = torchvision.datasets.CIFAR10(root='./data',
                                                      train=True,
                                                      transform=transform_train, # Apply training transforms
                                                      download=True)

# Create a separate dataset instance for validation with validation transforms
val_dataset_for_transform = torchvision.datasets.CIFAR10(root='./data',
                                                          train=True, # Still from the original training split
                                                          transform=transform_val_test, # Apply validation/test transforms
                                                          download=True)


# Load the test data (the archive downloaded above already contains the test batch)
test_dataset_aug = torchvision.datasets.CIFAR10(root='./data',
                                                train=False,
                                                transform=transform_val_test) # Apply validation/test transforms

# Split the full training data into training and validation sets.
# torchvision datasets apply their transform in __getitem__, so a single dataset
# object cannot serve both splits: the training split needs augmentation and the
# validation split must not have it. We therefore draw one set of indices and
# use it to index two dataset instances that differ only in their transform.
train_size = int(0.85 * len(full_train_dataset_aug)) # 85% for training
val_size = len(full_train_dataset_aug) - train_size   # 15% for validation

# Get indices for split
torch.manual_seed(42) # for reproducibility of split
train_indices, val_indices = random_split(range(len(full_train_dataset_aug)), [train_size, val_size])

# Create training dataset with augmentation
train_dataset_aug = torch.utils.data.Subset(full_train_dataset_aug, train_indices)

# Create validation dataset. For validation, we want the non-augmented version of these images.
# We use the 'val_dataset_for_transform' which has the correct non-augmenting transform.
val_dataset_aug = torch.utils.data.Subset(val_dataset_for_transform, val_indices)


# Create DataLoaders
train_loader_aug = DataLoader(dataset=train_dataset_aug, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader_aug = DataLoader(dataset=val_dataset_aug, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader_aug = DataLoader(dataset=test_dataset_aug, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Dataset loaded. Train samples: {len(train_dataset_aug)}, Validation samples: {len(val_dataset_aug)}, Test samples: {len(test_dataset_aug)}")
cifar10_classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# --- 4. Define a More Complex Convolutional Neural Network (CNN) Model ---
class ComplexCNN(nn.Module):
    def __init__(self, num_classes=10, dropout_prob=0.5):
        super(ComplexCNN, self).__init__()
        # Block 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) # Batch Norm after Conv, before ReLU
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32 -> 16x16
        self.dropout_conv1 = nn.Dropout2d(p=0.25) # Dropout for conv layers (spatial dropout)

        # Block 2
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 16x16 -> 8x8
        self.dropout_conv2 = nn.Dropout2d(p=0.25)

        # Block 3 (Optional, can add more)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.relu5 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) # 8x8 -> 4x4
        self.dropout_conv3 = nn.Dropout2d(p=0.25)


        # Fully Connected Layers
        # Input features: 128 channels * 4 width * 4 height = 128 * 4 * 4 = 2048
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.relu_fc1 = nn.ReLU()
        self.dropout_fc = nn.Dropout(dropout_prob) # Dropout before output layer
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Block 1
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = self.dropout_conv1(x)

        # Block 2
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.relu4(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        x = self.dropout_conv2(x)
        
        # Block 3
        x = self.relu5(self.bn5(self.conv5(x)))
        x = self.pool3(x)
        x = self.dropout_conv3(x)

        # Fully Connected Part
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu_fc1(x)
        x = self.dropout_fc(x)
        x = self.fc2(x) # Logits
        return x

# --- 5. Instantiate the Model, Loss, and Optimizer ---
model = ComplexCNN(num_classes=num_classes, dropout_prob=dropout_prob).to(device)
print("\nModel Architecture:")
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Learning rate scheduler (optional, but often helpful for deeper models)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # Reduce LR every 7 epochs

# --- 6. Training Loop with Validation ---
print("\nStarting Training...")
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(num_epochs):
    model.train()
    running_train_loss, n_correct_train, n_samples_train = 0.0, 0, 0
    for i, (images, labels) in enumerate(train_loader_aug):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        n_samples_train += labels.size(0)
        n_correct_train += (predicted == labels).sum().item()

    epoch_train_loss = running_train_loss / n_samples_train
    epoch_train_acc = 100.0 * n_correct_train / n_samples_train
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_acc)

    model.eval()
    running_val_loss, n_correct_val, n_samples_val = 0.0, 0, 0
    with torch.no_grad():
        for images_val, labels_val in val_loader_aug:
            images_val, labels_val = images_val.to(device), labels_val.to(device)
            outputs_val = model(images_val)
            loss_val = criterion(outputs_val, labels_val)
            running_val_loss += loss_val.item() * images_val.size(0)
            _, predicted_val = torch.max(outputs_val.data, 1)
            n_samples_val += labels_val.size(0)
            n_correct_val += (predicted_val == labels_val).sum().item()

    epoch_val_loss = running_val_loss / n_samples_val
    epoch_val_acc = 100.0 * n_correct_val / n_samples_val
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%, '
          f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%')
    # scheduler.step() # Step the LR scheduler (uncomment together with the scheduler definition above)

print("Finished Training.")

# --- 7. Evaluation on Test Set ---
print("\nStarting Evaluation on Test Set...")
model.eval()
all_labels_test_aug, all_predicted_test_aug = [], []
with torch.no_grad():
    n_correct_test, n_samples_test = 0, 0
    for images_test, labels_test in test_loader_aug:
        images_test, labels_test = images_test.to(device), labels_test.to(device)
        outputs_test = model(images_test)
        _, predicted_test = torch.max(outputs_test.data, 1)
        n_samples_test += labels_test.size(0)
        n_correct_test += (predicted_test == labels_test).sum().item()
        all_labels_test_aug.extend(labels_test.cpu().numpy())
        all_predicted_test_aug.extend(predicted_test.cpu().numpy())

accuracy_test_aug = 100.0 * n_correct_test / n_samples_test
print(f'Accuracy of the complex CNN on test images: {accuracy_test_aug:.2f} %')

# --- 8. Confusion Matrix and Classification Report for Test Set ---
print("\nConfusion Matrix (Complex CNN Test Set):")
cm_cnn_test_aug = confusion_matrix(all_labels_test_aug, all_predicted_test_aug)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_cnn_test_aug, annot=True, fmt="d", cmap="YlGnBu",
            xticklabels=cifar10_classes, yticklabels=cifar10_classes)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix - Complex CNN (CIFAR-10 Test Set)")
plt.show()

print("\nClassification Report (Complex CNN Test Set):")
print(classification_report(all_labels_test_aug, all_predicted_test_aug, target_names=cifar10_classes))

# --- 9. Plot Training and Validation Loss and Accuracy ---
epochs_range = range(1, num_epochs + 1)
plt.figure(figsize=(14, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, 'bo-', label='Training Loss')
plt.plot(epochs_range, val_losses, 'ro-', label='Validation Loss')
plt.title('Training and Validation Loss (Complex CNN)')
plt.xlabel('Epochs'); plt.ylabel('Loss (CrossEntropy)'); plt.legend(); plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_accuracies, 'bs-', label='Training Accuracy')
plt.plot(epochs_range, val_accuracies, 'rs-', label='Validation Accuracy')
plt.title('Training and Validation Accuracy (Complex CNN)')
plt.xlabel('Epochs'); plt.ylabel('Accuracy (%)'); plt.legend(); plt.grid(True)
plt.tight_layout()
plt.show()


# PyTorch "More Complex" CNN: Key Enhancements for CIFAR-10 Classification

This document details significant enhancements applied to a Convolutional Neural Network (CNN) for improved performance and robustness on the CIFAR-10 image classification task. These changes incorporate common techniques used in modern deep learning.

## 1. Data Augmentation (`transform_train`)
- **Purpose**: To artificially increase the diversity of the training dataset, which helps the model learn more invariant features and generalize better to unseen data, thereby reducing overfitting.
- **Techniques Applied (Only to the Training Set)**:
    - **`transforms.RandomHorizontalFlip()`**: Flips each image horizontally with probability 0.5 (the default) every time the image is loaded, so the same image can appear flipped in one epoch and unflipped in the next. This is effective for datasets where horizontal orientation doesn't change the class label (e.g., a cat is still a cat if flipped horizontally).
    - **`transforms.RandomCrop(32, padding=4)`**:
        - First, the 32x32 input image is padded with 4 pixels on each side (making it 40x40).
        - Then, a random 32x32 crop is taken from this padded image.
        - This shifts the object by up to 4 pixels in each direction (a random translation; the scale is unchanged), forcing the model to be less sensitive to exact object placement.
- **Validation/Test Set Transforms**: The validation and test sets **do not** use these random augmentations. They only use `transforms.ToTensor()` and `transforms.Normalize()` to ensure consistent evaluation.
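To make the pad-then-crop step concrete, here is a minimal sketch that reproduces the two augmentations by hand on a single tensor. It illustrates the idea rather than torchvision's actual implementation; the helper names `pad_and_random_crop` and `random_hflip` are hypothetical, and zero padding is assumed.

```python
import torch
import torch.nn.functional as F

def pad_and_random_crop(img, size=32, padding=4, generator=None):
    """Zero-pad a (C, H, W) image on all sides, then take a random size x size crop."""
    padded = F.pad(img, (padding, padding, padding, padding))  # pads W, then H
    _, h, w = padded.shape
    top = int(torch.randint(0, h - size + 1, (1,), generator=generator))
    left = int(torch.randint(0, w - size + 1, (1,), generator=generator))
    return padded[:, top:top + size, left:left + size]

def random_hflip(img, p=0.5, generator=None):
    """Flip a (C, H, W) image along its width axis with probability p."""
    if torch.rand(1, generator=generator).item() < p:
        return img.flip(-1)
    return img

gen = torch.Generator().manual_seed(0)
image = torch.rand(3, 32, 32)  # stand-in for one CIFAR-10 image

padded_shape = tuple(F.pad(image, (4, 4, 4, 4)).shape)
augmented = random_hflip(pad_and_random_crop(image, generator=gen), generator=gen)

print(padded_shape)            # (3, 40, 40)
print(tuple(augmented.shape))  # (3, 32, 32)
```

Because the crop offsets are redrawn on every call, the same source image yields a slightly different training sample each time it is loaded.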

## 2. Deeper and More Sophisticated Architecture (`ComplexCNN` class)
- **Increased Depth**: The network architecture features three convolutional blocks, allowing it to learn a hierarchy of features from simple to more complex.
- **Typical Block Structure**: Each convolutional block generally follows a pattern:
    - One or two `nn.Conv2d` layers.
    - `nn.BatchNorm2d` layer after each convolution.
    - `nn.ReLU` activation function.
    - `nn.MaxPool2d` layer at the end of the block, followed here by `nn.Dropout2d` (other designs place the dropout before pooling).
- **Increasing Filter Depth**: A common design pattern is to increase the number of filters (output channels) in deeper convolutional layers (e.g., starting with 32 filters, then 64, then 128).
    - **Rationale**: Early layers detect low-level features (edges, corners, textures). As the network goes deeper, these features are combined to form more abstract and complex patterns. Increasing the number of filters provides more capacity for these higher-level representations.
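The way spatial size shrinks while channel count grows can be checked by pushing a dummy batch through the blocks. The sketch below is a simplified stand-in for `ComplexCNN`, not the class itself: the helper `conv_block` is a hypothetical name, and each block uses a single convolution and omits dropout so that only the shape bookkeeping is visible.

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """Hypothetical helper: Conv -> BatchNorm -> ReLU -> MaxPool (halves H and W)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),  # padding=1 keeps H and W
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

features = nn.Sequential(
    conv_block(3, 32),    # 32x32 -> 16x16
    conv_block(32, 64),   # 16x16 -> 8x8
    conv_block(64, 128),  # 8x8   -> 4x4
)
features.eval()

shapes = []
with torch.no_grad():
    x = torch.randn(2, 3, 32, 32)  # dummy batch of two CIFAR-10-sized images
    for block in features:
        x = block(x)
        shapes.append(tuple(x.shape))

flat_features = x.flatten(1).shape[1]  # input size for the first nn.Linear
print(shapes)         # [(2, 32, 16, 16), (2, 64, 8, 8), (2, 128, 4, 4)]
print(flat_features)  # 2048
```

The final value, 128 * 4 * 4 = 2048, is the input size that `fc1` expects in the full model.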

## 3. Batch Normalization (`nn.BatchNorm2d`)
- **Implementation**: `nn.BatchNorm2d(num_features)` layers are inserted after each convolutional layer, typically *before* the ReLU activation function. `num_features` corresponds to the number of output channels from the preceding convolutional layer.
- **Benefits**:
    - **Stabilizes Training**: Normalizes the activations of the previous layer by re-centering them to a mean of 0 and re-scaling them to a standard deviation of 1 for each channel across the current mini-batch, then applies a learnable per-channel scale and shift. The technique was originally motivated as a way to reduce "internal covariate shift".
    - **Allows Higher Learning Rates**: Makes the model less sensitive to the scale of parameters and activations, often permitting the use of higher learning rates, which can speed up convergence.
    - **Reduces Sensitivity to Initialization**: Training deep networks becomes less dependent on careful weight initialization.
    - **Acts as a Regularizer**: Has a slight regularization effect, sometimes reducing the need for other forms of regularization like Dropout (though they are often used together).
    - **During `model.eval()`**: Batch Normalization uses the running mean and variance (accumulated during training) instead of mini-batch statistics for normalization.
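The difference between the two modes can be verified numerically. The following sketch compares `nn.BatchNorm2d` against a manual computation, assuming a freshly initialized layer (scale 1, shift 0) so that the learnable parameters do not affect the result; the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(num_features=4)      # weight=1, bias=0 at initialization
x = torch.randn(8, 4, 5, 5) * 3.0 + 2.0  # (N, C, H, W), deliberately off-center

# Training mode: normalize with the statistics of the current mini-batch.
bn.train()
out_train = bn(x)
batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
batch_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual_train = (x - batch_mean) / torch.sqrt(batch_var + bn.eps)
train_matches = torch.allclose(out_train, manual_train, atol=1e-4)

# Eval mode: normalize with the running statistics accumulated so far.
bn.eval()
with torch.no_grad():
    out_eval = bn(x)
running_mean = bn.running_mean.view(1, -1, 1, 1)
running_var = bn.running_var.view(1, -1, 1, 1)
manual_eval = (x - running_mean) / torch.sqrt(running_var + bn.eps)
eval_matches = torch.allclose(out_eval, manual_eval, atol=1e-4)

print(train_matches, eval_matches)
```

After a single training step the running statistics are still far from the batch statistics, so the train-mode and eval-mode outputs differ noticeably for the same input. This is one reason forgetting `model.eval()` can distort validation metrics.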

## 4. Dropout (`nn.Dropout2d` and `nn.Dropout`)
- **Purpose**: A powerful regularization technique to prevent overfitting by randomly deactivating units (neurons or channels) during training.
- **Types Used**:
    - **`nn.Dropout2d(p=dropout_probability)` (Spatial Dropout)**:
        - Applied typically after pooling layers or between convolutional blocks.
        - Instead of dropping individual elements, `nn.Dropout2d` randomly zeros out entire *channels* (feature maps) from its input.
        - This is often considered more effective for convolutional layers because features in feature maps are spatially correlated. Dropping entire channels encourages the network to learn more diverse and less co-dependent feature representations.
    - **`nn.Dropout(p=dropout_probability)` (Standard Dropout)**:
        - Applied before the final fully connected classification layer(s).
        - Randomly zeros out individual elements (neurons) of its input tensor.
- **Behavior**: Dropout is only active during `model.train()` mode. It is automatically disabled during `model.eval()` mode, ensuring deterministic output for inference.
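A small experiment on a tensor of ones shows the difference between the two variants. This is an illustrative sketch with arbitrary sizes; note that in training mode PyTorch also rescales the surviving values by 1 / (1 - p), so the expected activation is unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.5
x = torch.ones(1, 8, 4, 4)  # (N, C, H, W): one sample with 8 feature maps

spatial = nn.Dropout2d(p=p)   # drops whole channels
standard = nn.Dropout(p=p)    # drops individual elements

spatial.train()
standard.train()
y_spatial = spatial(x)
y_standard = standard(x)

# With spatial dropout every channel is either all zeros or all 1 / (1 - p).
flat = y_spatial.reshape(8, -1)
channels_uniform = bool((flat == flat[:, :1]).all())
kept_value = 1.0 / (1.0 - p)  # survivors are scaled up to preserve the expectation

# In eval mode both layers pass the input through unchanged.
spatial.eval()
standard.eval()
identity_in_eval = torch.equal(spatial(x), x) and torch.equal(standard(x), x)

print(channels_uniform, kept_value, identity_in_eval)
```

Printing `y_standard[0, 0]` shows zeros scattered within a single feature map, which is exactly what `nn.Dropout2d` avoids.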

## 5. Enhanced Training Workflow with a Validation Set
- **Data Split**: The original training dataset is partitioned into a (new, smaller) training subset and a validation subset, here 85% and 15%, with the indices drawn by `torch.utils.data.random_split` under a fixed seed so the split is reproducible.
- **Integrated Validation Loop**:
    - After each epoch of training on the training subset (with `model.train()` active):
        1.  The model is switched to evaluation mode: `model.eval()`. This ensures Dropout is turned off and Batch Normalization uses its running statistics.
        2.  Performance metrics (typically loss and accuracy) are computed on the validation set using a `DataLoader` for the validation data.
        3.  Crucially, these computations are done within a `with torch.no_grad():` block to disable gradient calculations, saving memory and computation.
- **Monitoring for Overfitting**: This process allows for continuous monitoring of the model's ability to generalize to unseen data.
    - If training loss/accuracy continues to improve but validation loss/accuracy stagnates or worsens, it's a strong indicator of overfitting. This information can be used for decisions like early stopping or adjusting hyperparameters.
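Early stopping is not implemented in the training loop above, but the decision rule can be sketched in a few lines. The `EarlyStopping` class below is a hypothetical helper, not a PyTorch API, and the validation losses are made-up numbers chosen only to exercise the logic.

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses: improvement stalls after the third epoch.
example_val_losses = [1.20, 0.95, 0.90, 0.92, 0.93, 0.91, 0.94]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(example_val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)  # 6 0.9
```

In a real loop, `stopper.step(epoch_val_loss)` would be called at the end of each epoch, typically alongside saving the model weights whenever a new best loss is recorded.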

## 6. Increased Number of Training Epochs
- **Rationale**: More complex models with data augmentation and regularization often