# Training a CNN model on CIFAR10

### Install dependencies

In [None]:
%pip install matplotlib seaborn torch torchvision torchaudio numpy

#### Prepare imports

In [None]:
import matplotlib.pyplot as plt # This is to load plotting functions
import seaborn as sns; sns.set() # This is to make the plots prettier
import torch # This is the ML library we will use
import torchvision # This is the supporting ML library for computer vision
from torchvision import datasets # This is to access the CIFAR10 dataset
from torch.utils.data import DataLoader # This is used to load the data efficiently
import torchvision.transforms as transforms # This is used to transform data when preparing the dataset
import numpy as np # This is used to handle arrays of data

#### Select the device from GPU and CPU

In [None]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

#### Download the CIFAR10 dataset and create train and test data loaders

In [None]:
# Convert images to tensors, then map every RGB channel through (x - 0.5) / 0.5
# so that pixel values in [0, 1] end up in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 4

# download=True fetches each split into ./data unless it is already present.
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Reshuffle the training split every epoch; keep the test split in a fixed order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable names, indexed by the integer labels the dataset yields.
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

#### Show example images from the dataset.

In [None]:

def imshow(img):
    """Display a normalized, channels-first image tensor with matplotlib.

    Args:
        img: Tensor of shape (channels, height, width) whose values were
            normalized with mean 0.5 and std 0.5 per channel.
    """
    # Invert x_norm = (x - 0.5) / 0.5 to recover the original pixel range.
    pixels = (img / 2 + 0.5).numpy()
    # matplotlib wants (height, width, channels); the tensor is channels-first.
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()

# Pull one shuffled batch from the training loader.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Tile the whole batch into a single grid image and display it.
grid = torchvision.utils.make_grid(images)
imshow(grid)
# Print each image's class name, in batch order, padded to 5 characters.
names = [f'{classes[labels[j]]:5s}' for j in range(batch_size)]
print(' '.join(names))

#### Create a simple CNN classifier model

In [None]:
class Cifar10Classifier(torch.nn.Module):
    """Small LeNet-style CNN for CIFAR10.

    Two conv -> ReLU -> max-pool stages are followed by three fully connected
    layers. The flatten size ``16 * 6 * 6`` assumes 3x32x32 input images.
    The network returns raw, unnormalized class scores (logits), which is the
    input ``torch.nn.CrossEntropyLoss`` expects.

    Args:
        num_classes: Number of scores produced per image (10 for CIFAR10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.body = torch.nn.Sequential(
            # 3x32x32 -> 6x32x32 (padding=2 keeps the size) -> pool -> 6x16x16
            torch.nn.Conv2d(
                in_channels=3,
                out_channels=6,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            # 6x16x16 -> 16x12x12 (no padding) -> pool -> 16x6x6
            torch.nn.Conv2d(
                in_channels=6,
                out_channels=16,
                kernel_size=5,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 6 * 6, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            # No activation after the last layer. CrossEntropyLoss applies
            # log-softmax itself, and a ReLU here would clamp every negative
            # logit to 0, cutting off the gradient for those classes.
            torch.nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for images of shape (batch, 3, 32, 32)."""
        return self.body(x)

In [None]:
# Build the CNN and move it to the selected device (Module.to returns the module itself).
model = Cifar10Classifier().to(device)
model

#### Create a loss function and an optimizer.

In [None]:
# Multi-class cross-entropy loss, minimized with Adam at a fixed step size.
loss_func = torch.nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Train the model.

In [None]:
def train(num_epochs, model, loss_func, optimizer):
    """Train ``model`` for ``num_epochs`` passes over the global ``train_loader``.

    Every batch is moved to the global ``device``. The loss is
    back-propagated and ``optimizer`` takes one step per batch. The current
    batch loss is printed every 100 batches, and "Done." is printed at the end.

    Args:
        num_epochs: Number of full passes over the training data.
        model: Network to train; it is switched to training mode first.
        loss_func: Callable ``(outputs, labels)`` returning a scalar loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
    """
    model.train()
    steps_per_epoch = len(train_loader)

    for epoch_idx in range(1, num_epochs + 1):
        for step, (inputs, targets) in enumerate(train_loader, start=1):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass.
            loss = loss_func(model(inputs), targets)

            # Backward pass: drop stale gradients, recompute them, update weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                message = "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    epoch_idx, num_epochs, step, steps_per_epoch, loss.item()
                )
                print(message)

    print("Done.")


train(num_epochs=2, model=model, loss_func=loss_func, optimizer=optimizer)

Now, test the trained model.

In [None]:
def test(model):
    """Print the classification accuracy of ``model`` on the global ``test_loader``.

    The model is put in evaluation mode (and left there) and gradients are
    disabled. Each batch is moved to the global ``device``. The predicted
    class is the index of the largest output score along dimension 1.

    Args:
        model: Classifier returning one score per class for every image.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # argmax over the class dimension. Under no_grad there is no need
            # for `.data`, and the result keeps its batch dimension even for a
            # batch of one image (unlike `.squeeze()`).
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = correct / total
    print(f"Test Accuracy of the model: {accuracy * 100:.2f}%")


test(model)

#### Now, let's train for one more epoch and see how the result changes.

In [None]:
# One more epoch on the same model, reusing the same optimizer object (and its state).
train(num_epochs=1, model=model, loss_func=loss_func, optimizer=optimizer)

In [None]:
# Re-evaluate after the extra epoch.
test(model=model)

## Pretrained ResNet18 model

#### Load the pretrained model from torchvision.

In [None]:
# Start from ResNet-18 initialized with the IMAGENET1K_V1 pretrained weights.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1,
)
resnet18

#### ImageNet uses a different image size and number of classes than CIFAR10.

We can adapt this model by replacing its first and last layers to fit our needs.
The new layers start out untrained, but the knowledge stored in all the other layers is still relevant and helps the model train faster.

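The new first convolutional layer should take a 3-channel image as input and have 64 output channels, a 3x3 kernel, stride 1, padding 1, and no bias. (The original 7x7, stride-2 layer is designed for large ImageNet images and would shrink the small 32x32 CIFAR10 images too aggressively.)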

The new last fully connected layer should take the same number of input features as the original one, but its number of outputs should equal the number of classes in CIFAR10.

<details>
<summary>Solution</summary>
    <code>
    resnet18.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)<br>
    num_features = resnet18.fc.in_features<br>
    resnet18.fc = torch.nn.Linear(num_features, num_classes)<br>
    </code>
</details>

In [None]:
# Exercise: replace each `...` below with the layer described in the text above
# (open the "Solution" dropdown if you get stuck). `...` is only a placeholder,
# not a layer, so this cell and the cells that use resnet18 are expected to
# fail until both lines are filled in.
num_classes = 10  # CIFAR10 has 10 classes
resnet18.conv1 = ...  # TODO: 3 -> 64 channels, 3x3 kernel, stride 1, padding 1, bias=False
resnet18.fc = ...  # TODO: in_features = resnet18.fc.in_features, out_features = num_classes
resnet18.to(device)

#### The model was pretrained on ImageNet and its new layers are untrained, so its starting accuracy on CIFAR10 should be poor. Let's check.

In [None]:
# Baseline accuracy before any training on CIFAR10.
test(model=resnet18)

#### Prepare the training objects.

In [None]:
# Fresh loss and optimizer; the optimizer must track resnet18's parameters.
learning_rate = 1e-3
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=resnet18.parameters(), lr=learning_rate)

#### Train and test

In [None]:
# Fine-tune the adapted ResNet-18 for two epochs.
train(num_epochs=2, model=resnet18, loss_func=loss_func, optimizer=optimizer)

In [None]:
# Accuracy after fine-tuning.
test(model=resnet18)