# Pre-train 5 CNN layers on CIFAR10
Author: Gabriel Kressin Palacios

[this notebook is not needed if you have the model already]

Pre-train a CNN with the following properties:
- 100 Epochs on [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)
(Krizhevsky 2009)
- 5 convolutional layers with
  - 16, 32, 64, 128, 256 output channels, respectively
  - padding of 1
  - with bias
  - strides of 1, 2, 2, 2, 2, respectively
  - a kernel size of 3
  - a BatchNormalization after each convolution
  - a ReLU nonlinearity after each convolution


This is to make our approach comparable to other approaches
(van de Ven et al. 2020, Vogelstein et al. 2021)

#### Imports

In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

import random
import math
import time
import copy
import itertools

In [2]:
# Hyper-parameters of the pre-training run.
SEED = 0                 # seed shared by all random number generators below
BATCH_SIZE = 256         # samples per mini-batch
N_EPOCHS = 100           # full passes over the training set
LEARNING_RATE = 0.0001   # Adam step size
BETAS = (0.9, 0.999)     # Adam (beta1, beta2) coefficients

# Static description of the dataset the network is pre-trained on.
CONFIG = {
    "dataset": "cifar10",
    "img_size": 32,
    "n_labels": 10,
    "channels": 3,
}

# Seed the PyTorch, Python and NumPy RNGs so that repeated runs of the
# notebook start from the same random state.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

#### ConvLayers

In [3]:
class ConvLayers(nn.Module):
    """
    Stack of five Conv2d -> BatchNorm2d -> ReLU stages.

    The stages output 16, 32, 64, 128 and 256 channels. Every convolution
    uses a 3x3 kernel, padding 1 and a bias term; the first stage has
    stride 1, the other four have stride 2.

    Attributes
    ----------
    img_size: int
        Value passed to the constructor
    in_channels: int
        Value passed to the constructor
    out_channels: int
        Channel count of the last stage (256)
    out_size: int
        ceil(img_size / 16), the side length of the final feature map
    out_units: int
        out_channels * out_size**2, the length of the flattened output
    """

    def __init__(self, in_channels, img_size):
        """
        Builds the five convolutional stages.

        Parameters
        ----------
        in_channels: int
            Number of channels of the input images
        img_size: int
            Number of pixels along one axis of the input images
        """
        super().__init__()

        widths = (16, 32, 64, 128, 256)
        # settings shared by all five convolutions
        shared = dict(kernel_size=3, bias=True, padding=1)

        self.img_size = img_size
        # four of the stages use stride 2, hence the division by 2**4
        self.out_size = int(np.ceil(img_size / 2**4))
        self.in_channels = in_channels
        self.out_channels = widths[-1]
        self.out_units = self.out_channels * self.out_size**2

        # stage 1 (stride 1)
        self.conv1 = nn.Conv2d(in_channels, widths[0], stride=1, **shared)
        self.bn1 = nn.BatchNorm2d(widths[0])
        self.nl1 = nn.ReLU()

        # stage 2 (stride 2)
        self.conv2 = nn.Conv2d(widths[0], widths[1], stride=2, **shared)
        self.bn2 = nn.BatchNorm2d(widths[1])
        self.nl2 = nn.ReLU()

        # stage 3 (stride 2)
        self.conv3 = nn.Conv2d(widths[1], widths[2], stride=2, **shared)
        self.bn3 = nn.BatchNorm2d(widths[2])
        self.nl3 = nn.ReLU()

        # stage 4 (stride 2)
        self.conv4 = nn.Conv2d(widths[2], widths[3], stride=2, **shared)
        self.bn4 = nn.BatchNorm2d(widths[3])
        self.nl4 = nn.ReLU()

        # stage 5 (stride 2)
        self.conv5 = nn.Conv2d(widths[3], widths[4], stride=2, **shared)
        self.bn5 = nn.BatchNorm2d(widths[4])
        self.nl5 = nn.ReLU()

    def forward(self, X):
        """
        Runs X through the five stages in order and returns the output of
        the last one.
        """
        stages = (
            (self.conv1, self.bn1, self.nl1),
            (self.conv2, self.bn2, self.nl2),
            (self.conv3, self.bn3, self.nl3),
            (self.conv4, self.bn4, self.nl4),
            (self.conv5, self.bn5, self.nl5),
        )

        out = X
        for conv, norm, relu in stages:
            out = relu(norm(conv(out)))

        return out

#### LinLayers

In [4]:
class LinLayers(nn.Module):
    """
    Fully connected classifier head.

    Two hidden Linear layers with 2000 units each, both followed by a ReLU,
    and a final Linear output layer without a nonlinearity.
    """

    def __init__(self, in_features, out_features):
        """
        Parameters
        ----------
        in_features: int
            Size of each (flattened) input sample
        out_features: int
            Size of each output sample
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        width_first, width_second = 2000, 2000

        # hidden layer 1
        self.fc1 = nn.Linear(in_features, width_first)
        self.nl1 = nn.ReLU()

        # hidden layer 2
        self.fc2 = nn.Linear(width_first, width_second)
        self.nl2 = nn.ReLU()

        # output layer
        self.fc3 = nn.Linear(width_second, out_features)

    def forward(self, X):
        """
        Applies fc1 -> ReLU -> fc2 -> ReLU -> fc3 to X and returns the result.
        """
        hidden = self.nl1(self.fc1(X))
        hidden = self.nl2(self.fc2(hidden))
        return self.fc3(hidden)

### Data

In [5]:
!mkdir data

#### Compute mean and std

In [6]:
# un_ds_train_cifar10 = datasets.CIFAR10("data", train=True,
#                                        transform=transforms.ToTensor(),
#                                        download=True)
# un_dl_train_cifar10 = DataLoader(un_ds_train_cifar10,
#                                  batch_size=len(un_ds_train_cifar10))
# data, _ = next(iter(un_dl_train_cifar10))

In [7]:
# means = []
# stds = []
# for channel in range(data.shape[1]):
#     means.append(data[:,channel].mean().item())
#     stds.append(data[:,channel].std().item())

In [8]:
# print(means)
# print(stds)

#### Load mean and std

In [9]:
# Per-channel mean and standard deviation, hard-coded so the statistics do
# not have to be recomputed; they are passed to transforms.Normalize below.
means = [
    0.491400808095932,
    0.48215898871421814,
    0.44653093814849854,
]
stds = [
    0.24703224003314972,
    0.24348513782024384,
    0.26158785820007324,
]

#### DataLoader

In [10]:
# Turn each image into a tensor, then standardise every channel with the
# hard-coded statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

# Both splits live under "data"; a download is requested only when the
# training split is created.
ds_train_cifar10 = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform,
    download=True,
)
ds_test_cifar10 = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform,
    download=False,
)

# Training batches are shuffled, test batches keep the dataset order.
dl_train_cifar10 = DataLoader(
    dataset=ds_train_cifar10,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
dl_test_cifar10 = DataLoader(
    dataset=ds_test_cifar10,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data


#### Training functions

In [11]:
def epoch_time(start_time, end_time):
    """
    Splits the interval between two timestamps into whole minutes and the
    remaining whole seconds.

    Parameters
    ----------
    start_time: float
        timestamp in seconds at the start of the interval
    end_time: float
        timestamp in seconds at the end of the interval

    Returns
    -------
    (minutes, seconds): tuple of int
        whole minutes elapsed and the remaining seconds, both truncated
        to integers
    """
    total_seconds = end_time - start_time
    whole_minutes = int(total_seconds / 60)
    # what is left after removing the full minutes
    leftover_seconds = int(total_seconds - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [12]:
def train_cifar10(convLayers, linLayers, dl, opt, crit, clip):
    """
    Runs one training pass over `dl`, updating convLayers and linLayers.

    Each batch goes through convLayers, is flattened to one vector per
    sample and fed to linLayers; the loss from `crit` is back-propagated
    and `opt` performs one step per batch.

    Parameters
    ----------
    convLayers: ConvLayers
        convolutional part of the network, put into training mode
    linLayers: LinLayers
        fully connected part of the network, put into training mode
    dl: DataLoader
        yields (inputs, labels) batches; must support len()
    opt: optim.Optimizer
        optimizer holding the parameters of both modules
    crit: callable
        loss taking (linLayers output, labels)
    clip: int or float
        maximum gradient norm; applied to the parameters of convLayers
        and of linLayers separately

    Returns
    -------
    epoch_loss: float
        Mean of the per-batch losses over this pass
    """
    modules = (convLayers, linLayers)
    for module in modules:
        module.train()

    batch_losses = []

    for inputs, targets in dl:
        opt.zero_grad()

        # flatten everything but the batch dimension before the linear head
        features = convLayers(inputs).flatten(start_dim=1)
        batch_loss = crit(linLayers(features), targets)
        batch_loss.backward()

        # each module's gradients are clipped on their own
        for module in modules:
            torch.nn.utils.clip_grad_norm_(module.parameters(), clip)

        opt.step()
        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(dl)

In [13]:
def hitsLoss(outputs, labels):
    """
    Computes the fraction of samples whose argmax prediction matches the
    label.
    Cannot be used for training: the prediction is taken from a detached
    copy of the outputs and argmax is not differentiable.

    Parameters
    ----------
    outputs: torch.tensor
        model outputs, [batch_size, n_classes]
    labels: torch.tensor
        true labels, [batch_size]

    Returns
    -------
    hits: torch.tensor (0-dim, float) in [0,1]
        Fraction of labels that were classified correctly. NaN for an
        empty batch.
    """
    #outputs = [batch_size, n_classes]
    #labels = [batch_size]

    batch_size = outputs.shape[0]

    preds = outputs.detach().argmax(dim=-1)

    # Vectorised count of matches along the batch axis (instead of Python's
    # builtin sum, which iterates the tensor element by element). The count
    # is cast to float before dividing so the result is a true fraction
    # regardless of how the installed torch divides integer tensors.
    hits = (preds == labels).sum(dim=0).float() / batch_size

    return hits

In [14]:
def eval_cifar10(convLayers, linLayers, dl, crit):
    """
    Evaluates convLayers together with linLayers on given DataLoader.

    Both modules are switched to evaluation mode and the forward passes
    run without gradient tracking; no parameters are updated.

    Parameters
    ----------
    convLayers: ConvLayers
        ConvLayers object that will be evaluated
    linLayers: LinLayers
        LinLayers object that will be evaluated
    dl: DataLoader
        dataloader containing batched data; must support len()
    crit: callable
        criterion taking outputs from linLayers and labels

    Returns
    -------
    epoch_loss: float
        Mean of the per-batch losses over the dataloader

    """
    # evaluation mode for both parts of the network
    convLayers.eval()
    linLayers.eval()

    epoch_loss = 0

    with torch.no_grad():

        for X, y in dl:
            #y = [batch_size]

            # same forward path as in train_cifar10: conv features are
            # flattened per sample before entering the linear layers
            conv_out = convLayers(X)
            output = linLayers(torch.flatten(conv_out, start_dim=1))
            #output = [batch_size, n_classes]

            loss = crit(output, y)

            epoch_loss += loss.item()

    return epoch_loss / len(dl)

#### Pre-train Conv Layers

In [15]:
# Convolutional part; its flattened output size determines the input
# size of the linear part.
convLayers = ConvLayers(
    in_channels=CONFIG["channels"],
    img_size=CONFIG["img_size"],
)
linLayers = LinLayers(
    in_features=convLayers.out_units,
    out_features=CONFIG["n_labels"],
)

# A single Adam optimizer updates the parameters of both modules.
opt = optim.Adam(
    [*convLayers.parameters(), *linLayers.parameters()],
    lr=LEARNING_RATE,
    betas=BETAS,
)

crit = nn.CrossEntropyLoss()

In [None]:
# One pass over the training data per epoch; after each pass the elapsed
# wall-clock time and the mean training loss are printed.
for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss = train_cifar10(
        convLayers=convLayers,
        linLayers=linLayers,
        dl=dl_train_cifar10,
        opt=opt,
        crit=crit,
        clip=1,  # maximum gradient norm
    )
    end_time = time.time()

    e_mins, e_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {e_mins}m {e_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')

Epoch: 01 | Time: 0m 57s
	Train Loss: 1.576
Epoch: 02 | Time: 0m 59s
	Train Loss: 1.207
Epoch: 03 | Time: 0m 57s
	Train Loss: 1.023


In [None]:
# Store the trained convolutional layers; the whole module object is
# serialised (not only its state_dict). The linear layers are not saved.
torch.save(obj=convLayers, f="convLayers_trained_cifar10.pt")

## References

Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, Technical Report, 2009, https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

van de Ven, G.M., Siegelmann, H.T. & Tolias, A.S. Brain-inspired replay for continual learning with artificial neural networks. Nat Commun 11, 4069 (2020). https://doi.org/10.1038/s41467-020-17866-2

Joshua T. Vogelstein, Jayanta Dey, Hayden S. Helm, Will LeVine, Ronak D. Mehta, Ali Geisa, Gido M. van de Ven, Emily Chang, Chenyu Gao, Weiwei Yang, Bryan Tower, Jonathan Larson, Christopher M. White, Carey E. Priebe, Omnidirectional Transfer for Quasilinear Lifelong Learning, 2021, [https://arxiv.org/abs/2004.12908v7](https://arxiv.org/abs/2004.12908v7)