In [None]:
import torch
from torch import nn

import random
import numpy as np

In [None]:
# Single seed shared by every random number generator used in this notebook.
SEED = 0

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

# Seed PyTorch first, then NumPy and Python's own generator.
# The calls are independent of each other, so their order does not matter.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
import torchvision.datasets

In [None]:
# Build the train and test splits of CIFAR-10 rooted at the current directory;
# `download=True` is passed for both splits.
CIFAR_train, CIFAR_test = (
    torchvision.datasets.CIFAR10('./', download=True, train=is_train_split)
    for is_train_split in (True, False)
)

In [None]:
# Copy the raw images into float32 tensors ...
X_train = torch.tensor(CIFAR_train.data, dtype=torch.float32)
X_test = torch.tensor(CIFAR_test.data, dtype=torch.float32)

# ... and the labels into int64 tensors (the dtype nn.CrossEntropyLoss takes
# for class-index targets).
y_train = torch.tensor(CIFAR_train.targets, dtype=torch.long)
y_test = torch.tensor(CIFAR_test.targets, dtype=torch.long)

In [None]:
# Scale both image tensors by 1/255 in place.
# NOTE: this mutates the tensors, so running the cell a second time without
# re-creating X_train / X_test divides them by 255 again.
X_train.div_(255.)
X_test.div_(255.)

In [None]:
# Display the `classes` attribute of the training dataset as the cell output
# (presumably the human-readable label names -- confirm against torchvision).
CIFAR_train.classes

In [None]:
import matplotlib.pyplot as plt

In [None]:
# fig, ax = plt.subplots(2, 3)
# k = 0
#
# for i in range(2):
#     for j in range(3):
#         ax[i][j].imshow(X_train[k], aspect='auto')
#         k += 1

In [None]:
# Move the last axis to position 1 for both splits, giving the
# (batch, channels, height, width) layout that nn.Conv2d consumes.
# NOTE: not idempotent -- re-running this cell permutes the axes a second time.
X_train, X_test = (
    images.permute(0, 3, 1, 2) for images in (X_train, X_test)
)

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style classifier with configurable activation, pooling and kernels.

    The network has two convolutional stages (6 and then 16 output channels,
    no padding). Each stage is followed by the activation, an optional
    BatchNorm2d and a 2x2 pooling layer. Three fully connected layers with
    120, 84 and 10 outputs follow. `fc1` takes 5 * 5 * 16 inputs, which is
    what both kernel-size variants produce for 3 x 32 x 32 inputs.

    Args:
        activation: 'tanh' or 'relu'. One activation module instance is
            shared by all four activation sites.
        pooling: 'avg' or 'max' (kernel 2, stride 2). One pooling module
            instance is shared by both stages.
        conv_size: 5 builds a single 5x5 convolution per stage; 3 builds two
            stacked 3x3 convolutions per stage (no activation between them).
        use_batch_norm: when True, `forward` applies BatchNorm2d after the
            activation of each convolutional stage. The batch-norm layers are
            always constructed, so the state-dict keys do not depend on it.

    Raises:
        NotImplementedError: if `activation`, `pooling` or `conv_size` has an
            unsupported value. The message names the offending value.
    """

    def __init__(self,
                 activation='tanh',
                 pooling='avg',
                 conv_size=5,
                 use_batch_norm=False):
        super().__init__()

        # All options are validated before any parameterised layer is built,
        # so a rejected configuration never consumes random numbers.
        if activation == 'tanh':
            activation_function = nn.Tanh()
        elif activation == 'relu':
            activation_function = nn.ReLU()
        else:
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; expected 'tanh' or 'relu'")

        if pooling == 'avg':
            pooling_layer = nn.AvgPool2d(kernel_size=2, stride=2)
        elif pooling == 'max':
            pooling_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            raise NotImplementedError(
                f"Unsupported pooling {pooling!r}; expected 'avg' or 'max'")

        if conv_size not in (3, 5):
            raise NotImplementedError(
                f"Unsupported conv_size {conv_size!r}; expected 3 or 5")

        self.conv_size = conv_size
        self.use_batch_norm = use_batch_norm

        # Stage 1: 3 -> 6 channels. Layer creation order is kept as-is so
        # that seeded weight initialisation stays reproducible.
        if conv_size == 5:
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=0)
        else:
            self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, padding=0)
            self.conv1_2 = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3, padding=0)

        self.act1 = activation_function
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.pool1 = pooling_layer

        # Stage 2: 6 -> 16 channels.
        if conv_size == 5:
            self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0)
        else:
            self.conv2_1 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, padding=0)
            self.conv2_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=0)

        self.act2 = activation_function
        self.bn2 = nn.BatchNorm2d(num_features=16)
        self.pool2 = pooling_layer

        # Classifier head.
        self.fl1 = nn.Flatten()
        self.fc1 = nn.Linear(5 * 5 * 16, 120)
        self.act3 = activation_function

        self.fc2 = nn.Linear(120, 84)
        self.act4 = activation_function

        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on `x` and return the raw 10-way class scores.

        `x` must have 3 channels and a spatial size that reduces to 5x5 after
        the two stages (32x32 for the layers built here).
        """
        if self.conv_size == 5:
            x = self.conv1(x)
        elif self.conv_size == 3:
            x = self.conv1_1(x)
            x = self.conv1_2(x)

        x = self.act1(x)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = self.pool1(x)

        if self.conv_size == 5:
            x = self.conv2(x)
        elif self.conv_size == 3:
            x = self.conv2_1(x)
            x = self.conv2_2(x)

        x = self.act2(x)
        if self.use_batch_norm:
            x = self.bn2(x)
        x = self.pool2(x)

        x = self.fl1(x)

        x = self.fc1(x)
        x = self.act3(x)

        x = self.fc2(x)
        x = self.act4(x)

        x = self.fc3(x)

        return x

In [None]:
def train(net, X_train, y_train, X_test, y_test,
          num_epoch=10, batch_size=512, lr=1.0e-3):
    """Train `net` with Adam and cross-entropy, evaluating on the test set each epoch.

    The model is moved to 'cuda:0' when CUDA is available, otherwise it stays
    on the CPU. Training data stays where it is and is moved to the device one
    mini-batch at a time; the test tensors are moved to the device once and
    evaluated in a single forward pass per epoch. Mini-batch order is drawn
    from NumPy's global random generator, so seed it for reproducible runs.
    One progress line is printed per epoch.

    Args:
        net: the nn.Module to train; it is moved to the device and updated
            in place.
        X_train, y_train: training inputs and class-index targets.
        X_test, y_test: evaluation inputs and class-index targets.
        num_epoch: number of passes over the training data.
        batch_size: mini-batch size; must be at least 1.
        lr: learning rate passed to Adam.

    Returns:
        (test_accuracy_history, test_loss_history): two lists of Python
        floats with one entry per epoch.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    test_accuracy_history = []
    test_loss_history = []

    X_test = X_test.to(device)
    y_test = y_test.to(device)

    for epoch in range(num_epoch):
        # Evaluation below switches to eval mode, so re-enable training mode
        # at the start of every epoch.
        net.train()

        order = np.random.permutation(len(X_train))
        for start_index in range(0, len(X_train), batch_size):
            optimizer.zero_grad()

            batch_indexes = order[start_index:start_index + batch_size]

            X_batch = X_train[batch_indexes].to(device)
            y_batch = y_train[batch_indexes].to(device)

            # Call the module itself rather than .forward() so that
            # registered hooks are executed.
            preds = net(X_batch)

            loss_value = loss(preds, y_batch)
            loss_value.backward()

            optimizer.step()

        net.eval()
        with torch.no_grad():
            test_preds = net(X_test)
            test_loss_history.append(loss(test_preds, y_test).item())

            accuracy = (test_preds.argmax(dim=1) == y_test).float().mean().item()
            test_accuracy_history.append(accuracy)

        print(f'Epoch {epoch+1}/{num_epoch}, Loss(test): {test_loss_history[epoch]:.4f}, Accuracy(test): {accuracy:.4f}')

    return test_accuracy_history, test_loss_history

In [None]:
# Per-experiment result stores, keyed by experiment id.
# Two separate dict objects: one for accuracy curves, one for loss curves.
accuracies, losses = {}, {}

In [None]:
# accuracies['tanh'], losses['tanh'] = train(LeNet5(activation='tanh', conv_size=5),
#                                            X_train, y_train, X_test, y_test)
# accuracies['relu'], losses['relu'] = train(LeNet5(activation='relu', conv_size=5),
#                                            X_train, y_train, X_test, y_test)
# accuracies['relu_3'], losses['relu_3'] = train(LeNet5(activation='relu', conv_size=3),
#                                                X_train, y_train, X_test, y_test)
# accuracies['relu_3_max_pool'], losses['relu_3_max_pool'] = train(LeNet5(activation='relu', conv_size=3, pooling='max'),
#                                                                  X_train, y_train, X_test, y_test)
# accuracies['relu_3_max_pool_bn'], losses['relu_3_max_pool_bn'] = train(LeNet5(activation='relu', conv_size=3, pooling='max', use_batch_norm=True),
#                                                                        X_train, y_train, X_test, y_test)

In [None]:
# for experiment_id in accuracies.keys():
#     plt.plot(accuracies[experiment_id], label=experiment_id)
# plt.legend()
# plt.title('Validation Accuracy');

In [None]:
# for experiment_id in losses.keys():
#     plt.plot(losses[experiment_id], label=experiment_id)
# plt.legend()
# plt.title('Validation Loss');

In [None]:
class CIFARNet(nn.Module):
    """Convolutional classifier with batch normalisation and 10 output scores.

    Layout:
      * BatchNorm2d on the 3 input channels;
      * three 3x3 convolution blocks (16, 32, 64 channels, padding 1), each
        Conv -> ReLU -> BatchNorm2d, the first two followed by 2x2 max pooling;
      * two Linear -> Tanh -> BatchNorm1d blocks (256 and 64 units);
      * a final Linear layer producing 10 scores.

    `linear_block1` takes 8 * 8 * 64 inputs, which corresponds to 3 x 32 x 32
    inputs after the two pooling layers.
    """

    def __init__(self):
        super().__init__()

        # Normalises the raw input channels before the first convolution.
        self.batch_norm0 = nn.BatchNorm2d(3)

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, 2),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),
        )

        # No pooling here: the spatial size is unchanged by this block.
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )

        self.linear_block1 = nn.Sequential(
            nn.Linear(8 * 8 * 64, 256),
            nn.Tanh(),
            nn.BatchNorm1d(256),
        )

        self.linear_block2 = nn.Sequential(
            nn.Linear(256, 64),
            nn.Tanh(),
            nn.BatchNorm1d(64),
        )

        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw 10-way class scores for a batch of images `x`."""
        x = self.batch_norm0(x)

        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)

        # flatten() copies when the strides do not allow a view, so this also
        # works for non-contiguous activations (e.g. channels_last inputs),
        # where .view() can raise.
        x = x.flatten(start_dim=1)

        x = self.linear_block1(x)
        x = self.linear_block2(x)
        x = self.fc3(x)

        return x

In [None]:
# accuracies['cifar_net'], losses['cifar_net'] = train(CIFARNet(),
#                                                      X_train, y_train, X_test, y_test)

In [None]:
# One accuracy curve per recorded experiment, labelled with its experiment id.
for experiment_id, accuracy_curve in accuracies.items():
    plt.plot(accuracy_curve, label=experiment_id)
plt.legend()
plt.title('Validation Accuracy');

In [None]:
# One loss curve per recorded experiment, labelled with its experiment id.
for experiment_id, loss_curve in losses.items():
    plt.plot(loss_curve, label=experiment_id)
plt.legend()
plt.title('Validation Loss');