# Training the ResNet-18 with CIFAR-10

This notebook trains a ResNet-18 network on the CIFAR-10 dataset in a deterministic fashion and stores the obtained weights.

### 0. Import libraries and define settings

In [None]:
# import Python packages
import time
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

# allows to automatically update the imported modules
%load_ext autoreload
%autoreload 2

In [None]:
# Use the GPU when it is requested and available; otherwise fall back to the CPU.
use_cuda = True

# Seed the global torch RNG so that weight initialisation, DataLoader shuffling
# and the random train-time augmentations (crop / flip) are reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG of every device (CPU and all GPUs)

# Without these flags cuDNN may select non-deterministic convolution algorithms,
# which would make runs differ even with a fixed seed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)


device(type='cuda')

### 1. Load and preprocess data

In [None]:
# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these std values are the ones popularised by common CIFAR-10
# training scripts; the exact per-pixel dataset std is reportedly closer to
# (0.247, 0.243, 0.261). Kept unchanged so results remain comparable.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop from a 4-pixel padded image and a random
# horizontal flip (data augmentation), then tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Test pipeline: no augmentation, same normalisation as for training.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data_cifar", train=True, download=True, transform=transform_train
)

testset = torchvision.datasets.CIFAR10(
    root="./data_cifar", train=False, download=True, transform=transform_test
)

batch_size = 128

# number of image channels, width and height of a CIFAR-10 image
c, w, h = 3, 32, 32

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Evaluation does not need shuffling: accuracy does not depend on the order, and a
# sequential sampler does not draw from the global RNG, so evaluating between
# epochs does not perturb the random state used for training.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# human-readable class names, indexed by the integer label
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data_cifar/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 50265825.56it/s]


Extracting ./data_cifar/cifar-10-python.tar.gz to ./data_cifar
Files already downloaded and verified


### 2. Define architecture

In [None]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Computes ``relu(batch2(conv2(relu(batch1(conv1(x))))) + skip_connection(x))``.
    The skip connection is the identity when the input can be added as-is.
    Otherwise it is a 1x1 convolution followed by BatchNorm that projects the
    input to ``out_channels`` (and to the output resolution).

    Args:
        in_channels: number of channels of the input tensor.
        hidden_channels: number of channels produced by the first 3x3 convolution.
        out_channels: number of channels of the output tensor.
        stride: stride of the first 3x3 convolution and of the projection.
            The default of 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,  # the following BatchNorm has its own bias
        )
        self.batch1 = nn.BatchNorm2d(hidden_channels)
        # ReLU is stateless, so one instance is reused for both activations.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            hidden_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.batch2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            # Shapes differ: project the input so it can be added to the residual branch.
            self.skip_connection = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # nn.Identity (unlike a lambda) is picklable, so torch.save(model) works.
            # It has no parameters, so the state_dict keys are unchanged.
            self.skip_connection = nn.Identity()

    def forward(self, x):
        """Apply the block to a 4-D (N, C, H, W) batch and return the activated sum."""
        skip = self.skip_connection(x)
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.batch2(x)
        x = self.relu(x + skip)
        return x


In [None]:
class ResNet(nn.Module):
    """ResNet-18-style image classifier built from eight ``ResidualBlock``s.

    Layout:
    - stem: 3x3 conv (64 channels) -> BatchNorm -> ReLU -> 3x3 max-pool with stride 2;
    - eight residual blocks with 64, 64, 128, 128, 256, 256, 512, 512 output
      channels, constructed without a stride argument;
    - global average pooling, then a linear layer producing the class logits.

    Args:
        in_channels: number of channels of the input images.
        out_size: number of output logits (classes).
    """

    def __init__(self, in_channels, out_size):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.batch1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        # halves the spatial size (e.g. 32x32 -> 16x16 for CIFAR-10 images)
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)

        self.res_blocks = nn.ModuleList(
            [
                ResidualBlock(64, 64, 64),
                ResidualBlock(64, 64, 64),
                ResidualBlock(64, 128, 128),
                ResidualBlock(128, 128, 128),
                ResidualBlock(128, 256, 256),
                ResidualBlock(256, 256, 256),
                ResidualBlock(256, 512, 512),
                ResidualBlock(512, 512, 512),
            ]
        )

        self.dense_layer = nn.Linear(512, out_size)

        # Kaiming (He) initialisation for every convolution, including those
        # inside the residual blocks (self.modules() recurses into children).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )

    def forward(self, x):
        """Map a 4-D (N, C, H, W) image batch to logits of shape (N, out_size)."""
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.relu(x)
        x = self.pool1(x)
        for block in self.res_blocks:
            # Call the module rather than .forward() so registered hooks are executed.
            x = block(x)
        # Global average pooling over the spatial dimensions: (N, 512, H, W) -> (N, 512, 1, 1).
        x = F.adaptive_avg_pool2d(x, 1)

        x = torch.flatten(x, 1)
        x = self.dense_layer(x)

        return x


### 3. Train the network

In [None]:
# Build the classifier (one input channel per image channel, one logit per class)
# and move its parameters to the selected device.
net = ResNet(in_channels=c, out_size=len(classes)).to(device)

# Cross-entropy on raw logits for multi-class classification.
criterion = nn.CrossEntropyLoss()

# Adam optimiser with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)


In [None]:
# Training hyper-parameters
NUM_EPOCHS = 200
LOG_EVERY = 20  # print a progress line every LOG_EVERY training batches

start = time.time()

# per-epoch train and test accuracy, plotted at the end of the notebook
train_acc = []
test_acc = []

for epoch in range(NUM_EPOCHS):

    # ---- training pass ----
    net.train()  # BatchNorm uses batch statistics and updates its running estimates
    correct_train_total = 0  # number of correct train predictions in the current epoch

    for i, (x_batch, y_batch) in enumerate(trainloader):
        # Move the data to the device that is used
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()  # clear the gradients of the previous step
        y_pred = net(x_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        # Predicted label = index of the largest logit; count matches with the targets.
        y_pred_max = torch.argmax(y_pred, dim=1)
        correct = torch.sum(torch.eq(y_pred_max, y_batch)).item()
        correct_train_total += correct

        elapsed = time.time() - start  # time since training started

        if i % LOG_EVERY == 0:
            # Divide by the actual batch size: the last batch can be smaller than batch_size.
            print(
                f"epoch: {epoch}, time: {elapsed:.3f}s, loss: {loss.item():.3f}, train accuracy: {correct / y_batch.size(0):.3f}"
            )

    # store train accuracy for current epoch
    train_acc.append(correct_train_total / len(trainset))

    # ---- evaluation pass ----
    net.eval()  # BatchNorm uses its running statistics
    correct_total = 0  # reset once per epoch, right before counting test predictions
    with torch.no_grad():  # inference only: no autograd graph, saves memory and time
        for i, (x_batch, y_batch) in enumerate(testloader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_pred = net(x_batch)
            y_pred_max = torch.argmax(y_pred, dim=1)
            correct_total += torch.sum(torch.eq(y_pred_max, y_batch)).item()

    epoch_test_accuracy = correct_total / len(testset)
    print(f"Accuracy on the test set: {epoch_test_accuracy:.3f}")

    # store test accuracy for current epoch
    test_acc.append(epoch_test_accuracy)


epoch: 0, time: 1.718s, loss: 2.473, train accuracy: 0.109
epoch: 0, time: 6.944s, loss: 1.855, train accuracy: 0.344
epoch: 0, time: 12.277s, loss: 1.789, train accuracy: 0.312
epoch: 0, time: 17.511s, loss: 1.617, train accuracy: 0.359
epoch: 0, time: 22.686s, loss: 1.559, train accuracy: 0.406
epoch: 0, time: 28.103s, loss: 1.631, train accuracy: 0.359
epoch: 0, time: 33.285s, loss: 1.417, train accuracy: 0.461
epoch: 0, time: 38.700s, loss: 1.438, train accuracy: 0.383
epoch: 0, time: 43.942s, loss: 1.477, train accuracy: 0.453
epoch: 0, time: 49.270s, loss: 1.357, train accuracy: 0.453
epoch: 0, time: 54.655s, loss: 1.302, train accuracy: 0.500
epoch: 0, time: 59.979s, loss: 1.531, train accuracy: 0.438
epoch: 0, time: 65.481s, loss: 1.251, train accuracy: 0.508
epoch: 0, time: 70.875s, loss: 1.364, train accuracy: 0.516
epoch: 0, time: 76.710s, loss: 1.332, train accuracy: 0.508
epoch: 0, time: 82.030s, loss: 1.072, train accuracy: 0.602
epoch: 0, time: 87.418s, loss: 1.287, trai

In [None]:
# Final evaluation of the trained network on the whole test set.
# Set eval mode explicitly instead of relying on the mode left by a previous cell.
net.eval()
correct_total = 0

with torch.no_grad():  # inference only: no autograd graph needed
    for i, (x_batch, y_batch) in enumerate(testloader):
        # Move the data to the device that is used
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        y_pred = net(x_batch)
        y_pred_max = torch.argmax(y_pred, dim=1)

        correct_total += torch.sum(torch.eq(y_pred_max, y_batch)).item()

print(f"Accuracy on the test set: {correct_total / len(testset):.3f}")


In [None]:
# Per-epoch train vs. test accuracy; the x axis is the epoch index.
fig, ax = plt.subplots()
ax.plot(range(len(train_acc)), train_acc, label="train")
ax.plot(range(len(test_acc)), test_acc, label="test")
ax.legend()
ax.set_xlabel("epoch number")
ax.set_ylabel("accuracy")
ax.set_title("Train and test accuracy evolution")
plt.show()  # render the figure without echoing the repr of the last call
