In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Seed torch's global RNG (all devices) so weight initialization, DataLoader
# shuffling and the random augmentations are repeatable across
# Restart & Run All. NOTE: cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

# Training hyperparameters.
batch_size = 32
EPOCHS = 50
learning_rate_initial = 0.1
num_class = 10
schedule_step = 15   # StepLR: decay the learning rate every this many epochs...
gamma_val = 0.5      # ...by this multiplicative factor

In [None]:
# Training images get light augmentation (small rotation, occasional horizontal
# flip); test images are only converted to tensors. Both pipelines normalize
# with the same per-channel mean/std.
train_transform = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomHorizontalFlip(0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# CIFAR-10 train/test splits, downloaded into ./data if not already present.
train_set, test_set = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
    for is_train, pipeline in ((True, train_transform), (False, test_transform))
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Layer specs for the VGG variants. Every variant has five stages with widths
# (64, 128, 256, 512, 512); they differ only in how many conv layers each stage
# stacks. Each stage is closed by the string 'maxpool'.
configuration = {
    variant: [
        entry
        for width, depth in zip((64, 128, 256, 512, 512), depths)
        for entry in [width] * depth + ['maxpool']
    ]
    for variant, depths in {
        'vgg11': (1, 1, 2, 2, 2),
        'vgg13': (2, 2, 2, 2, 2),
        'vgg16': (2, 2, 3, 3, 3),
        'vgg19': (2, 2, 4, 4, 4),
    }.items()
}

In [None]:
class VGG(nn.Module):
    """VGG-style image classifier assembled from a layer specification.

    Args:
        config: sequence of layer entries. An int ``c`` appends a 3x3 conv
            (padding 1) with ``c`` output channels, followed by BatchNorm2d and
            ReLU; the string ``'maxpool'`` appends a 2x2, stride-2 max-pool.
            Must contain at least one int.
        num_classes: number of output logits. Defaults to 10 (CIFAR-10). The
            class does not read the notebook-global ``num_class``, so pass it
            explicitly if that setting changes.
        in_channels: channel count of the input images (3 for RGB).

    Raises:
        ValueError: if ``config`` contains no conv layer widths.

    ``forward`` returns raw (unnormalized) logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, config, num_classes=10, in_channels=3):
        super().__init__()
        config = list(config)  # iterated twice below; accept any iterable
        conv_widths = [entry for entry in config if entry != 'maxpool']
        if not conv_widths:
            raise ValueError("config must contain at least one conv layer width")
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Width of the last conv layer = length of the pooled feature vector.
        self.out_channels = conv_widths[-1]

        self.conv_layers = self.get_conv_layers(config)
        # Adaptive pooling collapses whatever spatial extent remains to 1x1, so
        # the classifier input size no longer depends on the image resolution.
        # For CIFAR-10's 32x32 images after five max-pools the map is already
        # 1x1, so this matches the previous AvgPool2d(kernel_size=1) exactly.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = self.get_fc_layers()

        # Parameter initialization: Kaiming-normal (fan_out, ReLU gain) for conv
        # weights, N(0, 0.01) for linear weights, zeros for all biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_conv_layers(self, cfg):
        """Build the feature extractor as an ``nn.Sequential`` from the layer spec ``cfg``."""
        layers = []
        in_channels = self.in_channels
        for layer in cfg:
            if layer == 'maxpool':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(in_channels, layer, kernel_size=3, padding=1),
                           nn.BatchNorm2d(layer),
                           nn.ReLU(inplace=True)]
                in_channels = layer
        return nn.Sequential(*layers)

    def get_fc_layers(self):
        """Return the classifier head: one linear layer from pooled features to logits."""
        return nn.Sequential(
            nn.Linear(self.out_channels, self.num_classes)
        )

    def forward(self, x):
        """Map an image batch ``x`` (batch, in_channels, H, W) to class logits."""
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [None]:
# Build VGG-16 from its layer spec and move its parameters/buffers to `device`.
# nn.Module.to() converts in place and returns the module itself, so the bare
# name on the last line displays the same architecture summary.
vgg16 = VGG(configuration['vgg16']).to(device)
vgg16

In [None]:
# Cross-entropy on the model's raw logits; SGD with momentum and L2 weight
# decay; StepLR multiplies the learning rate by gamma_val every schedule_step
# calls to scheduler.step().
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=vgg16.parameters(),
    lr=learning_rate_initial,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=schedule_step,
    gamma=gamma_val,
)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader` and return the per-sample mean loss."""
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses the default reduction='mean': re-weight each
        # batch mean by its size so a smaller final batch does not skew the
        # epoch average.
        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n
        n_samples += batch_n
    return loss_sum / n_samples


def _accuracy(model, loader, device):
    """Return top-1 accuracy (percent) of `model` on `loader`, in eval mode without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


train_loss_list = []
train_acc_list = []
test_acc_list = []

# Train the model, recording loss and train/test accuracy after every epoch.
for epoch in range(EPOCHS):
    print("Epoch: ", epoch + 1)
    running_loss = _train_one_epoch(vgg16, train_loader, criterion, optimizer, device)
    train_loss_list.append(running_loss)
    print('[%d] loss: %.3f' % (epoch + 1, running_loss))

    # Re-scoring the whole training set doubles per-epoch cost. train_loader
    # applies train_transform, so this measures accuracy on augmented images.
    train_acc = _accuracy(vgg16, train_loader, device)
    test_acc = _accuracy(vgg16, test_loader, device)
    train_acc_list.append(train_acc)
    test_acc_list.append(test_acc)
    print('Training Accuracy of current epoch: %.3f %%' % train_acc)
    print('Testing Accuracy of current epoch: %.3f %%' % test_acc)

    # The scheduler is stepped once per epoch.
    scheduler.step()

In [None]:
def _plot_per_epoch(values, title, ylabel):
    """Plot one value per epoch, with 1-based epoch numbers matching the training log."""
    _, ax = plt.subplots()
    ax.plot(range(1, len(values) + 1), values)
    ax.set(title=title, xlabel='Epoch Number', ylabel=ylabel)
    plt.show()


# Accuracies were recorded as percentages by the training cell.
for values, title, ylabel in [
    (train_loss_list, 'Training Loss Plot of Each Epoch', 'Training Loss'),
    (train_acc_list, 'Training Accuracy Plot of Each Epoch', 'Training Accuracy (%)'),
    (test_acc_list, 'Test Accuracy Plot of Each Epoch', 'Testing Accuracy (%)'),
]:
    _plot_per_epoch(values, title, ylabel)