<a href="https://colab.research.google.com/github/ZicoDiegoRR/cifar10-93plus-workflow/blob/main/cifar_10_workflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CIFAR-10 Custom CNN Architecture with 91+% Test Accuracy**

In this notebook, we'll explore the power of a CNN in classifying images into the 10 labels defined by CIFAR-10:

1.   Airplane
2.   Automobile
3.   Bird
4.   Cat
5.   Deer
6.   Dog
7.   Frog
8.   Horse
9.   Ship
10.  Truck

<br>

We'll train a model using our own custom neural network architecture with `nn.Conv2d`, `nn.BatchNorm2d`, `nn.SiLU`, `nn.AdaptiveAvgPool2d`, `nn.Linear`, and `nn.Dropout` modules.

Before we continue, it would be helpful if you read this quick summary about those modules in action so you won't be clueless about the architecture. Let's break them down!

- `nn.Conv2d` processes image tensors by applying a 2D convolution over an input made of several channels (for example, the three RGB channels). In short, this is one of the most important layers to have when working with a CNN.
- `nn.BatchNorm2d` applies batch normalization, which normalizes each channel's activations over the batch and stabilizes the training. There's a dedicated formula for this, but this isn't a research paper.
- `nn.SiLU` applies the Sigmoid Linear Unit, `x * sigmoid(x)`, to each activation. It's a smooth non-linearity that often performs as well as or better than `nn.ReLU()`, though the gain depends on the task.
- `nn.AdaptiveAvgPool2d` applies a 2D average pooling over an input signal and lets you choose the output size directly, whatever the input size is. This works similarly to `nn.MaxPool2d`, but this finds the average value instead of the biggest value.
- `nn.Linear` applies a linear transformation to the input.
- `nn.Dropout` randomly zeroes some elements of its input (the activations, not the weights) during training, so the model can't rely too heavily on any single feature, enhancing generalization.
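
To make the `nn.SiLU` description above concrete, here is a minimal pure-Python sketch of the function it applies to each element, next to ReLU for comparison. The helper names `silu` and `relu` are illustrative and not part of PyTorch; the real modules work on whole tensors.

```python
import math

def silu(x: float) -> float:
    """Sigmoid Linear Unit: x * sigmoid(x), written as x / (1 + e^-x)."""
    return x / (1.0 + math.exp(-x))

def relu(x: float) -> float:
    """Rectified Linear Unit: clips negative values to zero."""
    return max(0.0, x)

# Compare both activations on a few sample inputs
for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  silu={silu(x):.4f}")
```

Unlike ReLU, SiLU is smooth and lets small negative values through instead of clipping them to zero.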

Now that we've covered all of them, let's go train our model, shall we?
(GPU acceleration is recommended to continue)

## Step 1: Import the mandatory libraries and modules

In [None]:
# Built-in module
import os

# The powerhouse libraries
import torch
import torchvision

# To-shorten-the-function modules
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# For visualization
import matplotlib.pyplot as plt

## Step 2: Let's define the path and the device for our model

In [None]:
# This is where we'll save our model
weight_save_path = "/content/training/CIFAR-10"
os.makedirs(weight_save_path, exist_ok=True)

# GPU or CPU?
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Step 3: Now we can download the dataset and compose the transformation method

In [None]:
# Get the mean and the standard deviation, very useful to normalize the data
def get_mean_std():
    # Still a simple transformation method
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Now we download and transform the images into tensors
    cifar10_dataset = torchvision.datasets.CIFAR10(
        root="/content/training",
        train=True,
        download=True,
        transform=transform
    )

    # After downloading, we can load them
    cifar10_load = DataLoader(
        cifar10_dataset,
        shuffle=False,
        num_workers=1,
        batch_size=50000,
    )

    # Let's obtain the image data
    data = next(iter(cifar10_load))[0]

    # Mean and standard deviation can be calculated here
    mean = torch.mean(data, dim=(0, 2, 3))
    std = torch.std(data, dim=(0, 2, 3))

    # Return the values
    return mean, std

# We get the mean and the standard deviation values
mean, std = get_mean_std()

# Now we create another transformation method with normalization
transform_flow = [
    transforms.ToTensor(), # Convert the data into tensors
    transforms.Normalize((mean), (std)) # Normalize the data
]

# Let's add separate transformation methods with data augmentation for the training data (better for generalization)
train_transform, test_transform = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        # ColorJitter randomly changes the brightness, contrast, saturation, and hue, like the sliders commonly found in an image editor.
        transforms.RandomCrop(32, padding=4), # Crop the images randomly
        transforms.RandomHorizontalFlip(), # Flip them
    ] + transform_flow
), transforms.Compose(transform_flow)
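
The normalization in this step boils down to subtracting a per-channel mean and dividing by a per-channel standard deviation. Below is a minimal pure-Python sketch for a single channel with made-up pixel values (the list `channel` is illustrative, not CIFAR-10 data). It uses `statistics.stdev`, the sample standard deviation, which matches the default behavior of `torch.std`.

```python
import statistics

# Hypothetical pixel values for one color channel, already scaled to [0, 1]
channel = [0.10, 0.40, 0.55, 0.80, 0.95]

mean = statistics.mean(channel)
std = statistics.stdev(channel)  # sample standard deviation (divides by n - 1)

# The per-channel operation behind transforms.Normalize: (value - mean) / std
normalized = [(value - mean) / std for value in channel]

print(f"mean={mean:.4f}, std={std:.4f}")
print([round(value, 4) for value in normalized])
```

After this step the channel has zero mean and unit standard deviation, which is the scale the network sees during training.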

## Step 4: Let's prepare the data for training and testing

In [None]:
# Download the training data
train_dataset = torchvision.datasets.CIFAR10(
    root="/content/training", # The path to put the images
    train=True, # This argument is needed to determine if the data is for training or not (helpful for predefined datasets)
    download=True, # Enable downloading
    transform=train_transform # Applying transformation with data augmentation
)

# Same drill as above, but for validation
test_dataset = torchvision.datasets.CIFAR10(
    root="/content/training",
    train=False,
    download=True,
    transform=test_transform # Applying transformation without data augmentation
)

## Step 5: It's time to load the data

In [None]:
# Loading the data
train_tensor = DataLoader(
    train_dataset, # The data for training
    shuffle=True, # Optional, but reshuffling the data every epoch usually helps the model generalize
    num_workers=1, # How many subprocesses are used to load the data
    batch_size=128 # How many images per batch (imagine slicing a pizza into slices of a chosen size)
)

# Same drill as above, but for validation
test_tensor = DataLoader(
    test_dataset, # The data for validating
    shuffle=False,
    num_workers=1,
    batch_size=128
)
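
As a quick sanity check on `batch_size=128`, the number of batches per epoch can be worked out by hand. This sketch assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images and the `DataLoader` default `drop_last=False`, which keeps the smaller final batch; `batch_layout` is a hypothetical helper.

```python
import math

def batch_layout(num_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (number of batches, size of the last batch) when the last partial batch is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

print(batch_layout(50_000, 128))  # (391, 80)
print(batch_layout(10_000, 128))  # (79, 16)
```

Because the last batch is smaller than the others, the training loop weights each batch loss by its batch size before averaging.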

## Step 6: The most confusing part — let's define our neural network

Here, you can experiment with the layers and/or the values. But keep in mind that modifying them can lead to instability, inefficient parameters (too many parameters for minimal gain), or longer training time.

In [None]:

# This is the residual block, the building block of our network
class ResNetBlock(nn.Module):
    def __init__(self, input, out, reduce=False): # `reduce` triggers downsampling process
        super(ResNetBlock, self).__init__()

        # This is the first convolution stack
        self.first_layer = nn.Sequential(
            nn.Conv2d(
                input,
                out,
                kernel_size=3,
                padding=1,
                bias=False,
                stride=1 if not reduce else 2
                # Stride value 2 downsamples the input for less computation
            ),
            nn.BatchNorm2d(out),
            nn.SiLU(),
        )

        # This is the second convolution stack
        self.second_layer = nn.Sequential(
            nn.Conv2d(
                out,
                out,
                kernel_size=3,
                padding=1,
                bias=False,
                stride=1
            ),
            nn.BatchNorm2d(out),
            nn.SiLU(),
        )

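        # This is the skip connection. A 1x1 convolution matches the shape of the main path
        # whenever the channel count or the spatial size changes; otherwise the input passes through as-is
        self.skip = nn.Conv2d(
            input, out, kernel_size=1, bias=False, stride=1 if not reduce else 2,
        ) if input != out or reduce else nn.Identity()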

    # The forward pass
    def forward(self, x):
        # Passing the input through the skip connection
        res = self.skip(x)

        # Feeding the input to the two convolution stacks
        out = self.second_layer(self.first_layer(x))

        # Adding the input to the output (skip connection)
        return out + res

# This is the main neural network
class CIFAR_10(nn.Module):
    def __init__(self):
        super(CIFAR_10, self).__init__()

        # This is a convolution stack
        self.conv_layer = nn.Sequential(
            ResNetBlock(3, 32),
            ResNetBlock(32, 64),
            ResNetBlock(64, 128, reduce=True),
            ResNetBlock(128, 256, reduce=True),
            ResNetBlock(256, 512),
            ResNetBlock(512, 1024),
        )

        # This is a linear stack with AdaptiveAvgPool2d and Flatten
        self.classify_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            self.lin_layer(1024, 512, dropout=0.2),
            self.lin_layer(512, 256),
            self.lin_layer(256, 128, dropout=0.2),
            self.lin_layer(128, 64),
            nn.Linear(64, 10)
        )

        # We use two skip connections in this example
        self.skips = nn.ModuleList([
            self.resnet(3, 128, stride=2),
            self.resnet(3, 1024, stride=4)
        ])


    # For encapsulation
    def lin_layer(self, input, out, dropout=0.0):
        dropout = nn.Dropout(dropout) if dropout != 0.0 else nn.Identity()
        return nn.Sequential(
            nn.Linear(input, out), dropout,
        )

    # For encapsulation
    def resnet(self, in_channels, out_channels, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.SiLU(),
        )

    # The forward pass
    def forward(self, x):
        # Here, we add the first skip connection to the output of the first three convolution stacks
        first_down = self.conv_layer[:3](x) + self.skips[0](x)

        # Same as above, but we use the second skip connection and the last three convolution stacks
        second_down = self.conv_layer[3:](first_down) + self.skips[1](x)

        # We classify the extracted features and return the prediction (raw scores, one per class)
        return self.classify_layer(second_down)

# This is our model loaded into the available device
cifar10_model = CIFAR_10().to(device)
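
To see why the two long-range skip connections use strides 2 and 4, it helps to trace the spatial size through the network with the standard convolution output-size formula, `floor((size + 2 * padding - kernel) / stride) + 1`. The sketch below is plain Python (the helper `conv_out` is hypothetical) and assumes 32x32 CIFAR-10 inputs.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32
# Main path: the first convolution of each ResNetBlock uses stride 2 when reduce=True.
# The second convolution (kernel 3, stride 1, padding 1) keeps the size unchanged.
for reduce in (False, False, True, True, False, False):
    size = conv_out(size, kernel=3, stride=2 if reduce else 1, padding=1)
print(size)  # 8

# Long-range skips: 1x1 convolutions applied to the raw 32x32 input
print(conv_out(32, kernel=1, stride=2, padding=0))  # 16, the size after the first three blocks
print(conv_out(32, kernel=1, stride=4, padding=0))  # 8, the size after the last three blocks
```

The sizes line up, which is what allows the skip outputs to be added element-wise to the main path.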

## Step 7: For performance and quality, let's tune our hyperparameters

In [None]:
# You can change the hyperparameters, but beware of instability!
# SGD updates the weights using the gradients. It isn't an adaptive optimizer: every parameter shares
# one learning rate, so it can converge more slowly than adaptive optimizers such as Adam
optimizer = optim.SGD(
    cifar10_model.parameters(),
    lr=0.01, # Initial learning rate (the scheduler below decays it)
    weight_decay=10**(-6), # Penalizing weights to encourage generalization
    momentum=0.9 # Smooths the gradient updates
)

epochs = 100 # Defines how many times to repeat the training
loss_hist = [] # Saves the losses and the accuracy for visualization

# The learning rate scheduler changes the learning rate during training.
# Cosine annealing decays it smoothly along a cosine curve
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, # The optimizer
    T_max=epochs, # The number of scheduler steps (one per epoch here) over which
                  # the learning rate decays gradually to its minimum
)

# This is the loss function using CrossEntropyLoss for the classification task
# Label smoothing softens the one-hot targets, which discourages overconfident predictions and often helps generalization
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

# This section covers the early stopping variables
wait = 0 # Counts the epochs without a new best test loss
tolerance = 10 # Determines how many such epochs are tolerated before halting the training
best_test_loss = np.inf # Saves the best test loss during training
early_stopping = False # Toggleable, this stops the training when the test loss stops improving to save you some time
hopeful_early_stopping = True # Toggleable, but needs early_stopping variable set to True to work.
# When True, 'hopeful_early_stopping' resets the wait whenever the test loss reaches a new best,
# so only consecutive epochs without improvement count toward the tolerance.
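
For intuition about what `CosineAnnealingLR` does to the learning rate, here is a sketch of the closed-form cosine annealing schedule in plain Python. It assumes a minimum learning rate of 0 (the scheduler's default `eta_min`) and one scheduler step per epoch; `cosine_lr` is a hypothetical helper, not a PyTorch function.

```python
import math

def cosine_lr(step: int, base_lr: float = 0.01, t_max: int = 100, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

# Learning rate at a few points of a 100-epoch run
for step in (0, 25, 50, 75, 100):
    print(f"step {step:3d}: lr = {cosine_lr(step):.6f}")
```

The rate starts at 0.01, reaches half of that at the midpoint, and approaches zero at the final step, so the largest updates happen early and the last epochs only fine-tune.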

## Final step: What are we waiting for? Let's train it!

In [None]:
for epoch in range(epochs): # To repeat the training
    print(f"\nEpoch: {epoch+1}")
    cifar10_model.train() # This enables 'training' mode for the model

    running_train_loss = 0.0 # This is to calculate the training loss. Initial value is zero.
    for images, labels in train_tensor: # We get the image tensors and the labels
        optimizer.zero_grad() # We reset the gradients so they don't accumulate across batches
        output = cifar10_model(images.to(device)) # Feeding the data into the model

        loss = loss_fn(output, labels.to(device)) # Calculating the training loss
        loss.backward() # Backpropagation computes the gradient of the loss for every parameter
        optimizer.step() # The optimizer updates the weights using those gradients
        running_train_loss += loss.item() * images.size(0) # The loss is a batch mean, so we weight it by the batch size before accumulating

    epoch_train_loss = running_train_loss / len(train_tensor.dataset) # Calculating the average loss per image

    cifar10_model.eval() # Turning on the `evaluation` mode for the model
    total = 0 # This is to count how many images are there
    correct = 0 # And, this is to count how many times the model classifies the images correctly
    running_test_loss = 0.0 # This is to count the test loss. Zero is the initial value
    with torch.no_grad(): # Disabling gradient calculation since we're not training the model at this phase
        for images, labels in test_tensor: # We get the image tensors and the labels
            images, labels = images.to(device), labels.to(device) # Moving the tensors and the labels to the correct device to avoid error
            test_output = cifar10_model(images) # Now, we classify data the model never trained on

            test_loss = loss_fn(test_output, labels) # We compute the test loss
            running_test_loss += test_loss.item() * images.size(0) # And, we add the batch loss, weighted by the batch size, to the running total

            _, pred = torch.max(test_output, 1) # The index of the largest output score is the predicted class
            total += labels.size(0) # We add the amount of images to the total variable
            correct += (pred == labels).sum().item() # Now, we add the amount of correct labels from the model

    epoch_test_loss = running_test_loss / len(test_tensor.dataset) # We compute the average test loss per image
    acc = 100 * correct/total # We get the accuracy from here
    lr_scheduler.step() # CosineAnnealingLR follows a fixed schedule, so we step it once per epoch without passing the test loss

    # This part is for visualization
    loss_hist.append([
        epoch_train_loss,
        epoch_test_loss,
        acc/100,
    ])

    print(f"Training Loss: {epoch_train_loss:.4f}\nTest Loss: {epoch_test_loss:.4f}\nAccuracy: {acc:.2f}%")

    # Early stopping
    if early_stopping: # Checking if early stopping is True
        if epoch_test_loss < best_test_loss: # This will check if the current test loss is better than the last best test loss
            best_test_loss = epoch_test_loss # We update the best test loss

            # If the hopeful early stopping is True, it'll reset the wait time when the test loss decreases
            if hopeful_early_stopping and wait > 0:
                wait = 0
                print("Detected a decrease in test loss. Resetting the wait...")
            torch.save(cifar10_model, os.path.join(weight_save_path, f"Epoch_{epoch+1}.pth")) # Saving the model

        # Adds one to the wait value when the test loss fails to beat the best one so far
        else:
            wait += 1
            print(f"No improvement in test loss. ({wait}/{tolerance})")

        # This triggers the early stopping when the wait time is equal to the tolerance or patience
        if wait == tolerance:
            print(f"Overfitting detected. ({wait}/{tolerance})\nTriggering early stopping...")
            break

    # If not early stopping, then save every state of the model
    else:
        torch.save(cifar10_model, os.path.join(weight_save_path, f"Epoch_{epoch+1}.pth"))

print("Training finished.")
loss_hist.append(len(loss_hist)) # Adds the epoch count to the history for visualization
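
The early-stopping bookkeeping in the loop can be isolated into a small function to see how `wait`, `tolerance`, and `hopeful_early_stopping` interact. This is a sketch with made-up loss values; `stop_epoch` and its `hopeful` flag are hypothetical names that mirror the variables above.

```python
def stop_epoch(test_losses, tolerance, hopeful=True):
    """Return the 1-based epoch at which training would halt, or None if it never does."""
    wait = 0
    best = float("inf")
    for epoch, loss in enumerate(test_losses, start=1):
        if loss < best:
            best = loss
            if hopeful:
                wait = 0  # a new best resets the counter
        else:
            wait += 1  # no improvement over the best loss so far
        if wait == tolerance:
            return epoch
    return None

losses = [1.0, 0.9, 0.95, 0.8, 0.85, 0.9, 0.82]  # made-up test losses
print(stop_epoch(losses, tolerance=3, hopeful=True))   # 7
print(stop_epoch(losses, tolerance=3, hopeful=False))  # 6
```

With the reset enabled, only consecutive epochs without a new best count toward the tolerance; without it, every such epoch counts, so training halts one epoch earlier on this sequence.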

## Optional: Visualize the loss and the accuracy trends

In [None]:
fig, ax = plt.subplot_mosaic([[1]], figsize=(10, 5), layout="constrained")
loss_list_hist = [element for element in loss_hist if isinstance(element, list)]
x_axis = np.arange(len(loss_list_hist)) + 1
labels = ["Training Loss", "Test Loss", "Accuracy"]
best_epoch_test_loss = []
best_acc = []
for i in range(3):
    y_values = [element[i] for element in loss_list_hist]
    ax[1].plot(x_axis, y_values, label=labels[i])
    if i == 1:
        best_epoch_test_loss = ([np.argmin(y_values) + 1, np.min(y_values)])
    elif i == 2:
        best_acc = ([np.argmax(y_values) + 1, np.max(y_values)])
ax[1].set_xlabel("Epoch")
ax[1].set_title("CIFAR-10 Model's Accuracy and Loss per Epoch")
ax[1].grid(True)
ax[1].legend()
plt.savefig("/content/cifar10.png")
plt.show()
print(f"The lowest test loss is {best_epoch_test_loss[1]:.4f}, reached at epoch {best_epoch_test_loss[0]}.")
print(f"The highest accuracy is {best_acc[1]*100:.2f}%, reached at epoch {best_acc[0]}.")