# Training VGG16 - Tutorial

In this tutorial we show how to train a VGG16 network and save a checkpoint file of the trained network.

### Imports
First, we import all the necessary modules.

In [1]:
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
import sys
# path to a local copy of the Smithers repository: adjust it to your setup
sys.path.insert(0, '/scratch/lmeneghe/Smithers/')

from smithers.ml.models.vgg import VGG

import warnings
warnings.filterwarnings("ignore")

### Setting the proper device
The following lines detect whether a GPU is available on the system running this tutorial. If so, the network and the data tensors used in the rest of the tutorial are moved to the GPU, which considerably speeds up training.

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"{device} has been detected as the device which the script will be run on.")

cuda has been detected as the device which the script will be run on.


### Loading of the model
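We use the VGG16 model implemented in ***smithers/ml/models/vgg.py***. Instead of random initial weights (obtained by setting `init_weights=True`), the network is initialized with weights pre-trained on ImageNet (`init_weights='imagenet'`), a common choice. We also define the loss function (cross-entropy) and the optimizer (stochastic gradient descent with learning rate 0.001 and momentum 0.9).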

In [3]:
VGGnet = VGG(cfg=None,
             classifier='cifar',
             batch_norm=False,
             num_classes=10,
             init_weights='imagenet').to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(VGGnet.parameters(), lr=0.001, momentum=0.9)


Loaded base model.
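
The optimizer defined above is `optim.SGD` with `lr=0.001` and `momentum=0.9`. As described in the PyTorch documentation for `torch.optim.SGD`, with the default `dampening=0` and `nesterov=False` each step first updates a velocity buffer, v <- momentum * v + g, and then the parameters, p <- p - lr * v. The NumPy sketch below only illustrates that rule and is not PyTorch's code: the helper `sgd_momentum` and the toy quadratic loss 0.5 * ||w - target||^2 (whose gradient is w - target) are hypothetical, and a larger learning rate than in the tutorial is used so that it converges within a few hundred steps.

```python
import numpy as np


def sgd_momentum(grad_fn, w0, lr=0.001, momentum=0.9, n_steps=100):
    """Plain SGD with (heavy-ball) momentum, dampening=0, no Nesterov."""
    w = np.asarray(w0, dtype=float).copy()
    velocity = np.zeros_like(w)  # momentum buffer, starts at zero
    for _ in range(n_steps):
        g = grad_fn(w)
        velocity = momentum * velocity + g  # v <- momentum * v + g
        w = w - lr * velocity               # p <- p - lr * v
    return w


# hypothetical toy problem: minimise 0.5 * ||w - target||^2
target = np.array([3.0, -2.0])
w_final = sgd_momentum(lambda w: w - target, w0=[0.0, 0.0], lr=0.1, n_steps=300)
print(w_final)  # close to [3, -2]
```

Starting the buffer at zero makes the first step identical to PyTorch's, where the buffer is initialised with the first gradient.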



### Loading of the CIFAR10 dataset
We use the CIFAR10 dataset, available in `torchvision.datasets`, to train and test the network. It is a computer-vision dataset for object recognition, consisting of 60000 32 × 32 colour images divided into 10 mutually exclusive classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck), with 6000 images per class: 50000 images are used for training and 10000 for testing.

See https://www.cs.toronto.edu/~kriz/cifar.html for more details on this dataset. With `download=True`, torchvision downloads it automatically into the folder given as `root`.

In [4]:
batch_size = 8 #this can be changed
data_path = '../cifar/' 
# transforms: each one takes a PIL image as input and applies the
# listed transformations in order
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = datasets.CIFAR10(root=data_path + 'CIFAR10/',
                                 train=True,
                                 download=True,
                                 transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_dataset = datasets.CIFAR10(root=data_path + 'CIFAR10/',
                                train=False,
                                download=True,
                                transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)  # no need to shuffle for evaluation
train_labels = torch.tensor(train_loader.dataset.targets).to(device)
targets = list(train_labels)
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
n_classes = len(classes)

Files already downloaded and verified
Files already downloaded and verified


### Custom Dataset
To use a custom dataset instead, we first need to build it, for example by following the tutorial on constructing a custom dataset for image recognition. The previous cell is then replaced by the following one. Note that the network must be created with `num_classes` equal to the number of classes of the custom dataset.

In [None]:
import pandas as pd
from torch.utils.data.sampler import SubsetRandomSampler
from smithers.ml.imagerec_dataset import Imagerec_Dataset

# load custom dataset for training and testing
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = pd.read_csv('../dataset_imagerec/dataframe.csv')
data_path = '../dataset_imagerec/'
# SPLIT OF THE DATASET
batch_size = 128
validation_split = .2
shuffle_dataset = True
random_seed = 42

dataset_size = len(data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
print('train data', len(train_indices))
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
resize_dim = [32, 32]

dataset_imagerec = Imagerec_Dataset(data, data_path, resize_dim, transform)
train_dataset = dataset_imagerec.getdata(train_indices)
train_loader = torch.utils.data.DataLoader(dataset_imagerec,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(dataset_imagerec,
                                          batch_size=batch_size,
                                          sampler=valid_sampler)

# class names ordered by encoded label (sorting a copy avoids reordering
# the DataFrame passed to Imagerec_Dataset)
classes = data.sort_values(by=['encoded_labels'])['labels'].unique()
#classes = ('class_1', 'class_2', 'class_3', 'class_4')
n_classes = len(classes)
targets = list(dataset_imagerec.targets[train_indices])
train_labels = torch.tensor(targets)
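
The train/validation split performed in the cell above takes the first `floor(validation_split * dataset_size)` shuffled indices for validation and keeps the rest for training; `SubsetRandomSampler` then draws from each list in random order. The sketch below isolates that logic in a hypothetical helper, `split_indices`; it uses a dedicated `np.random.RandomState(seed)` instead of seeding NumPy's global generator, which, as far as we know, yields the same permutation for the same seed without changing the global random state.

```python
import numpy as np


def split_indices(dataset_size, validation_split=0.2, shuffle=True, seed=42):
    """Return (train_indices, val_indices) as in the custom-dataset cell."""
    indices = list(range(dataset_size))
    # number of samples reserved for validation, rounded down
    split = int(np.floor(validation_split * dataset_size))
    if shuffle:
        # a dedicated generator leaves NumPy's global random state untouched
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)
    return indices[split:], indices[:split]


train_idx, val_idx = split_indices(100)
print(len(train_idx), len(val_idx))  # 80 20
```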

### Training phase
The following cell trains the network on the training images. Beware that training can take a long time, so a GPU is recommended.

The number of training epochs, `n_epochs`, can be changed: more epochs usually reduce the training loss further but take longer, and too many can lead to overfitting, where performance on the training images keeps improving without a matching improvement on the test images. Fewer epochs are faster but may leave the network undertrained.

In [5]:
n_epochs = 5
for epoch in range(n_epochs):
    VGGnet.train()  # enable training-mode behaviour (e.g. dropout)
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        # move the mini-batch to the selected device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = VGGnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate the loss and print statistics at the end of the epoch
        running_loss += loss.item()
        if i == len(train_loader) - 1:
            # CrossEntropyLoss already averages over the batch, so this is the
            # average batch loss of the epoch divided by batch_size
            print('Epoch {}, Loss Value: {:.5f}'.format(
                epoch + 1, running_loss / ((i + 1) * batch_size)))

print("The network has been successfully trained for {} epochs.".format(n_epochs))

Epoch 1, Loss Value: 0.09345
Epoch 2, Loss Value: 0.05922
Epoch 3, Loss Value: 0.04876
Epoch 4, Loss Value: 0.04112
Epoch 5, Loss Value: 0.03649
The network has been successfully trained for 5 epochs.
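
`nn.CrossEntropyLoss()` with its default `reduction='mean'` returns the average, over the mini-batch, of minus the log softmax probability of the true class, so `loss.item()` is already a per-sample value. The NumPy sketch below (an illustration, not PyTorch's implementation; `cross_entropy` is a hypothetical helper and the logits are random) reproduces that computation and shows that the printed quantity `running_loss / ((i + 1) * batch_size)` equals the average batch loss divided by `batch_size`.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits), like reduction='mean'."""
    # subtract the row maximum for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # pick the log-probability of the true class for each sample, then average
    return -log_probs[np.arange(len(labels)), labels].mean()


rng = np.random.default_rng(0)
batch_size, n_classes, n_batches = 8, 10, 5
batch_losses = []
for _ in range(n_batches):
    logits = rng.normal(size=(batch_size, n_classes))
    labels = rng.integers(0, n_classes, size=batch_size)
    batch_losses.append(cross_entropy(logits, labels))

running_loss = sum(batch_losses)
mean_batch_loss = running_loss / n_batches
printed_value = running_loss / (n_batches * batch_size)
print(mean_batch_loss, printed_value)  # printed_value == mean_batch_loss / batch_size
```

With `batch_size = 8`, a printed value of 0.09345 therefore corresponds to an average batch loss of roughly 0.75.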


### Accuracy after the training phase
Once the model is trained, we can evaluate its accuracy on the test set, both overall and for each class.

In [6]:
def testAccuracy(net, test_loader, device):
    '''
    Function for testing the accuracy of the model.

    :param nn.Module net: network under consideration
    :param iterable test_loader: iterable object, it loads the dataset for
            testing. It iterates over the given dataset, obtained combining a
            dataset(images, labels) and a sampler.
    :param torch.device device: device on which the evaluation is performed
    :return: accuracy (in percentage) over the whole test set
    :rtype: float
    '''
    net.eval()
    correct = 0.0
    total = 0.0
    net.to(device)

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            # run the model on the test set to predict labels
            outputs = net(images)
            # the label with the highest score will be our prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # compute the accuracy over all test images
    accuracy = 100 * correct / total
    return accuracy


def testClasses(net, n_classes, test_loader, classes, device):
    '''
    Function testing the accuracy reached for each class
    composing the dataset.

    :param nn.Module net: network under consideration
    :param int n_classes: number of classes composing the dataset
    :param iterable test_loader: iterable object, it loads the dataset for
            testing. It iterates over the given dataset, obtained combining a
            dataset(images, labels) and a sampler.
    :param sequence classes: names of the classes, ordered by label index
    :param torch.device device: device on which the evaluation is performed
    '''
    class_correct = [0.0] * n_classes
    class_total = [0.0] * n_classes
    net.eval()
    net.to(device)

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            # boolean tensor: True where the prediction is correct
            correct = (predicted == labels)
            for i in range(len(labels)):
                label = labels[i].item()
                class_correct[label] += correct[i].item()
                class_total[label] += 1

    for i in range(n_classes):
        if class_total[i] > 0:
            print('Accuracy of {} : {:.2f}%'.format(
                classes[i], 100 * class_correct[i] / class_total[i]))
        else:
            print('Accuracy of {} : no test samples'.format(classes[i]))
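
Both functions above loop over the test loader and accumulate counts in Python. Once all network outputs and labels have been collected, the same two metrics can also be computed in vectorized form. The NumPy sketch below is an illustrative alternative, not part of Smithers: the helper `accuracy_from_scores` is hypothetical and assumes the outputs have been gathered into an (N, n_classes) array and the labels into an integer array.

```python
import numpy as np


def accuracy_from_scores(scores, labels, n_classes):
    """Overall and per-class accuracy (in %) from an (N, n_classes) score array."""
    # predicted label = index of the highest score, as torch.max(outputs, 1) does
    predicted = scores.argmax(axis=1)
    correct = (predicted == labels).astype(float)
    overall = 100.0 * correct.mean()
    # per-class counts: samples of each class, and how many were predicted correctly
    class_total = np.bincount(labels, minlength=n_classes)
    class_correct = np.bincount(labels, weights=correct, minlength=n_classes)
    with np.errstate(invalid="ignore", divide="ignore"):
        per_class = 100.0 * class_correct / class_total  # NaN for classes with no samples
    return overall, per_class


scores = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
labels = np.array([0, 1, 1, 1])
overall, per_class = accuracy_from_scores(scores, labels, n_classes=2)
print(overall, per_class)  # 75.0 [100.  66.66...]
```

`np.bincount` with `minlength=n_classes` keeps one entry per class even when a class does not appear in the test labels.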

In [7]:
accuracy = testAccuracy(VGGnet, test_loader, device)
print('The accuracy over the whole test set is {:.2f}%'.format(accuracy))
testClasses(VGGnet, n_classes, test_loader, classes, device)

The accuracy over the whole test set is 87.51%
Accuracy of airplane : 92.30%
Accuracy of automobile : 96.80%
Accuracy of bird : 87.90%
Accuracy of cat : 82.40%
Accuracy of deer : 87.30%
Accuracy of dog : 73.80%
Accuracy of frog : 89.50%
Accuracy of horse : 87.30%
Accuracy of ship : 92.80%
Accuracy of truck : 85.00%


### Creating a checkpoint of the state of the network
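Once training is done, the state of the network should be saved for later use. Here the whole model object is saved with `torch.save`; the PyTorch documentation generally recommends saving only the parameters, `VGGnet.state_dict()`, which is more portable but requires re-creating the model before loading them.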

In [8]:
import copy
torch.save(copy.deepcopy(VGGnet), 'check_vgg.pth')

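### Loading the state of the network from a checkpoint file
Once the state of the network has been saved in a ***.pth*** file, we can load it for further use, such as additional training or other tests. Since the whole model object was saved, the `VGG` class must be importable when the file is loaded.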

In [9]:
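# weights_only=False is required by recent PyTorch versions (where it defaults
# to True) to unpickle a full nn.Module; only load checkpoints you trust
VGGnet = torch.load('check_vgg.pth', map_location=device, weights_only=False)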

In [10]:
accuracy = testAccuracy(VGGnet, test_loader, device)
print('The accuracy over the whole test set is {:.2f}%'.format(accuracy))
testClasses(VGGnet, n_classes, test_loader, classes, device)

The accuracy over the whole test set is 87.51%
Accuracy of airplane : 92.30%
Accuracy of automobile : 96.80%
Accuracy of bird : 87.90%
Accuracy of cat : 82.40%
Accuracy of deer : 87.30%
Accuracy of dog : 73.80%
Accuracy of frog : 89.50%
Accuracy of horse : 87.30%
Accuracy of ship : 92.80%
Accuracy of truck : 85.00%
