# Import requirements

In [30]:
import os
import random
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Device: {DEVICE}')

Device: cuda


# Data Augmentation Using `transforms`

In [2]:
transform_train = transforms.Compose([transforms.Resize((32, 32)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                                      transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])

transform_test = transforms.Compose([transforms.Resize((32, 32)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                 ])

train_dataset = datasets.CIFAR10(root='../data/CIFAR_10',
                                 train=True,
                                 download=True,
                                 transform=transform_train)
test_dataset = datasets.CIFAR10(root='../data/CIFAR_10',
                                train=False,
                                transform=transform_test)

Files already downloaded and verified


In [3]:
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, 
    batch_size=BATCH_SIZE,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Check Shape of Images

In [4]:
for (input, label) in train_loader:
    print(input.shape)
    break

torch.Size([32, 3, 32, 32])


# Define the Convolution Neural Network (CNN)

In [55]:
class ConvNet(nn.Module):
    # similar with VGG-16
    def __init__(self):
        super(ConvNet, self).__init__()
        # input shape = (32, 32)
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1, bias=False), # (32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False), # (32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (16, 16)
            
            nn.Conv2d(64, 128, 3, padding=1, bias=False), # (16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1, bias=False), # (16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (8, 8)
            
            nn.Conv2d(128, 256, 3, padding=1, bias=False), # (8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1, bias=False), # (8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1, bias=False), # (8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (4, 4)

            nn.Conv2d(256, 512, 3, padding=1, bias=False), # (4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1, bias=False), # (4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1, bias=False), # (4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (2, 2)

            nn.Conv2d(512, 512, 3, padding=1, bias=False), # (2, 2)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1, bias=False), # (4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1, bias=False), # (4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (1, 1)
            
        )


        self.clssify = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 512)
        x = self.clssify(x)
        return x

# Define the train, evaluation

In [73]:
def train(model, train_loader, optimizer, log_interval):
    model.train()
    
    for batch_idx, (image, label) in enumerate(train_loader):
        image = image.to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            pct = 100 * batch_idx / len(train_loader) # percent
            train_loss = loss.item()
            print(f'Train Epoch: {Epoch} [{batch_idx * len(image)}/{len(train_loader.dataset)} ({pct:.0f}%)]\tTrain Loss: {train_loss:.6f}')


def evaluate(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(DEVICE)
            label = label.to(DEVICE)
            output = model(image)
            test_loss += criterion(output, label).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100 * correct / len(test_loader.dataset)

    return test_loss, test_accuracy

# set seeds

In [57]:
def fix_seeds(seed = 42, use_torch=False):
    # fix the seed for reproducibility 
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    if use_torch: 
        torch.manual_seed(seed) 
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

In [58]:
def init_weights(m):
    # initialize the weight, bias
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            torch.nn.init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight.data, mean=1, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight.data)
        torch.nn.init.normal_(m.bias.data)

In [64]:
SEED = 42
EPOCHS = 50

fix_seeds(seed=SEED, use_torch=True)
model = ConvNet().to(device=DEVICE)
model.apply(init_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


for Epoch in range(1, EPOCHS + 1):
    train_acc = train(model, train_loader, optimizer, log_interval=200)
    test_loss, test_acc = evaluate(model, test_loader)
    print(f'\nEpoch: {Epoch}')
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {test_acc:.2f}\n')


Epoch: 1
Test Loss: 0.0514
Test Accuracy: 37.26


Epoch: 2
Test Loss: 0.0397
Test Accuracy: 52.40


Epoch: 3
Test Loss: 0.0345
Test Accuracy: 60.02


Epoch: 4
Test Loss: 0.0296
Test Accuracy: 66.72


Epoch: 5
Test Loss: 0.0263
Test Accuracy: 70.77


Epoch: 6
Test Loss: 0.0222
Test Accuracy: 76.33


Epoch: 7
Test Loss: 0.0222
Test Accuracy: 76.93


Epoch: 8
Test Loss: 0.0202
Test Accuracy: 78.78


Epoch: 9
Test Loss: 0.0184
Test Accuracy: 81.27


Epoch: 10
Test Loss: 0.0168
Test Accuracy: 82.27


Epoch: 11
Test Loss: 0.0157
Test Accuracy: 83.34


Epoch: 12
Test Loss: 0.0156
Test Accuracy: 84.06


Epoch: 13
Test Loss: 0.0148
Test Accuracy: 84.29


Epoch: 14
Test Loss: 0.0138
Test Accuracy: 85.69


Epoch: 15
Test Loss: 0.0138
Test Accuracy: 85.62


Epoch: 16
Test Loss: 0.0136
Test Accuracy: 85.61


Epoch: 17
Test Loss: 0.0128
Test Accuracy: 86.38


Epoch: 18
Test Loss: 0.0125
Test Accuracy: 86.84


Epoch: 19
Test Loss: 0.0125
Test Accuracy: 86.94


Epoch: 20
Test Loss: 0.0120
Test Accura