In [None]:
# Importing basic libs
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Select the compute device: the first CUDA GPU if PyTorch can see one, else the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(f'Using {device}.')

In [None]:
class VGG16(Module):
  """
  VGG16 (2015)
  This paper showed that deep convnets are better than wide convnets
  3 x 224 x 224 input

  Layout: 13 convolutions in five blocks, each block closed by a 2x2 max-pool,
  followed by three fully connected layers. Five poolings shrink 224 -> 7,
  which is why FC1 expects 512 * 7 * 7 features; other input resolutions will
  fail at FC1.

  `forward` returns raw class scores (logits). `nn.CrossEntropyLoss` applies
  log-softmax itself, so feeding it softmax outputs would squash the gradients.
  Call `probabilities` when normalised class probabilities are wanted.
  """
  def __init__(self):
    super(VGG16, self).__init__()
    # ReLU non-linearity after every hidden layer
    self.NL = nn.ReLU()
    # (3x3) conv kernels with stride 1 and padding 1 conserve spatial resolution
    self.C1 = nn.Conv2d(3, 64, 3, 1, 1)
    self.C2 = nn.Conv2d(64, 64, 3, 1, 1)
    # Subsampling performed by maxpooling with kernel (2x2) by stride 2
    self.SS1 = nn.MaxPool2d(2, 2)
    self.C3 = nn.Conv2d(64, 128, 3, 1, 1)
    self.C4 = nn.Conv2d(128, 128, 3, 1, 1)
    self.SS2 = nn.MaxPool2d(2, 2)
    self.C5 = nn.Conv2d(128, 256, 3, 1, 1)
    self.C6 = nn.Conv2d(256, 256, 3, 1, 1)
    # NOTE: the last conv of blocks 3-5 (C7, C10, C13) uses a (1x1) kernel, not (3x3).
    # NOTE(review): this looks like configuration C of the paper rather than the
    # all-(3x3) configuration D that is usually called VGG16 - confirm it is intended.
    self.C7 = nn.Conv2d(256, 256, 1)
    self.SS3 = nn.MaxPool2d(2, 2)
    self.C8 = nn.Conv2d(256, 512, 3, 1, 1)
    self.C9 = nn.Conv2d(512, 512, 3, 1, 1)
    self.C10 = nn.Conv2d(512, 512, 1)
    self.SS4 = nn.MaxPool2d(2, 2)
    self.C11 = nn.Conv2d(512, 512, 3, 1, 1)
    self.C12 = nn.Conv2d(512, 512, 3, 1, 1)
    self.C13 = nn.Conv2d(512, 512, 1)
    self.SS5 = nn.MaxPool2d(2, 2)
    self.FC1 = nn.Linear(512 * 7 * 7, 4096)
    self.FC2 = nn.Linear(4096, 4096)
    self.FC3 = nn.Linear(4096, 2) #1000 in original paper
    # Kept for `probabilities`; deliberately NOT applied inside `forward`.
    self.OUT = nn.Softmax(dim = 1)

  def forward(self, X):
    """
    Compute class scores for a batch of images.

    Args:
      X: tensor of shape (batch, 3, 224, 224).

    Returns:
      Tensor of shape (batch, 2) holding unnormalised logits.
    """
    # Block 1
    y = self.NL(self.C1(X))
    y = self.NL(self.C2(y))
    y = self.SS1(y)
    # Block 2
    y = self.NL(self.C3(y))
    y = self.NL(self.C4(y))
    y = self.SS2(y)
    # Block 3
    y = self.NL(self.C5(y))
    y = self.NL(self.C6(y))
    y = self.NL(self.C7(y))
    y = self.SS3(y)
    # Block 4
    y = self.NL(self.C8(y))
    y = self.NL(self.C9(y))
    y = self.NL(self.C10(y))
    y = self.SS4(y)
    # Block 5
    y = self.NL(self.C11(y))
    y = self.NL(self.C12(y))
    y = self.NL(self.C13(y))
    y = self.SS5(y)
    # Flatten everything except the batch dimension for the classifier head
    y = torch.flatten(y, 1)
    y = self.NL(self.FC1(y))
    y = self.NL(self.FC2(y))
    return self.FC3(y)

  def probabilities(self, X):
    """
    Return softmax-normalised class probabilities of shape (batch, 2).

    Intended for inspection/inference; use the logits from `forward` for
    computing a cross-entropy loss.
    """
    return self.OUT(self(X))

In [None]:
# Just a sanity check: push a dummy batch through the network and print the output shape
X = torch.ones([32, 3, 224, 224])
net = VGG16()
# Only the shape matters here, so skip autograd bookkeeping; otherwise the
# activations of all 32 images are kept alive for a backward pass that never happens.
with torch.no_grad():
  y = net(X)
print(y.shape)

In [None]:
# WARNING : VGG16 IS A HUGE MODEL AND IS NOT MEANT TO BE TRAINED ON SMALL DATASETS
# THIS IS SIMPLY A SAMPLE TRAINING PLATFORM
# VGG16 WILL SEVERELY OVERFIT ON SMALL DATASETS
import os
%mkdir data
%cd /content/data/
!wget http://files.fast.ai/data/examples/dogscats.tgz
!tar -zxvf dogscats.tgz


In [None]:
data_dir = '/content/data/dogscats'
# Sub-folders of data_dir that are loaded, each as an ImageFolder dataset
splits = ['train', 'valid']

# Per-channel mean/std applied after ToTensor (the usual ImageNet statistics)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Crop the central 224x224 patch (the resolution VGG16 expects), convert to a
# float tensor and normalise. No resizing is performed before the crop.
imagenet_format = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), imagenet_format)
         for x in splits}

dset_sizes = {x: len(dsets[x]) for x in splits}

# Training batches are reshuffled every epoch; validation order is kept fixed.
training_loader = DataLoader(dsets['train'], batch_size = 32, shuffle = True)
validation_loader = DataLoader(dsets['valid'], batch_size = 5, shuffle = False)

In [None]:
# Some sanity checks
def viz(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
  """
  Show a channel-first image tensor with matplotlib.

  Args:
    image: CPU float tensor of shape (C, H, W), normalised per channel as
      (pixel - mean) / std.
    mean, std: per-channel statistics used to undo that normalisation. The
      defaults match the `normalize` transform used for the datasets in this
      notebook. Pass `mean=None` or `std=None` to display the tensor as-is.

  Without un-normalising, many values fall outside [0, 1]; imshow clips them
  and the picture comes out with wrong colours.
  """
  # matplotlib expects channel-last (H, W, C)
  image = image.numpy().transpose((1, 2, 0))
  if mean is not None and std is not None:
    image = image * np.asarray(std) + np.asarray(mean)
    # Guard against tiny overshoots so imshow receives a valid [0, 1] range
    image = np.clip(image, 0, 1)
  plt.imshow(image)

# Look at (at most) the first four training batches and report their shapes.
for i, data in zip(range(4), training_loader):
  batch, labels = data
  print("Checking training loader : ")
  print(f'Batch shape : {batch.shape}')
print("==================")
# Same for validation, additionally printing the labels and drawing each batch as an image grid.
for i, data in zip(range(4), validation_loader):
  batch, labels = data
  print("Checking validation loader : ")
  print(f'Batch shape : {batch.shape} batch labels : {labels}')
  batch_viz = torchvision.utils.make_grid(batch)
  viz(batch_viz)

# Baseline training configuration. The author reports that these defaults do
# lead to learning; they have not been tuned and can likely be improved.
lr = 0.001
criterion = nn.CrossEntropyLoss()
vgg = VGG16()
vgg = vgg.to(device)
optim = torch.optim.SGD(params = vgg.parameters(), lr = lr)

# Defining single epoch training pass
def trainer(model, dataloader, criterion, optimizer):
  """
  Run one optimisation pass (one epoch) over `dataloader`.

  Each batch is moved to the global `device`, the loss is computed as
  `criterion(model(inputs), targets)` and one optimizer step is taken.

  Returns:
    (model, optimizer, epoch_loss) where epoch_loss is the sum of per-batch
    losses weighted by batch size, divided by `len(dataloader.dataset)`.
  """
  model.train()
  weighted_losses = []

  for inputs, targets in dataloader:
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()
    batch_loss = criterion(model(inputs), targets)
    # Weight by batch size so a smaller final batch does not skew the average
    weighted_losses.append(batch_loss.item() * inputs.size(0))

    batch_loss.backward()
    optimizer.step()

  return model, optimizer, sum(weighted_losses) / len(dataloader.dataset)

# Define single epoch testing pass
def tester(model, dataloader, criterion, optimizer):
  model.eval()
  running_loss = 0.0

  for X, y_T in dataloader:
      X = X.to(device)
      y_T = y_T.to(device)

      probs = model(X)
      loss = criterion(probs, y_T)
      running_loss += loss.item() * X.size(0)
  epoch_loss = running_loss / len(dataloader.dataset)
  return model, epoch_loss

# Custom function to compute accuracy
def compute_accuracy(model, dataloader):
    """
    Fraction of samples in `dataloader` whose arg-max prediction equals the label.

    The model is switched to eval mode and run without gradient tracking; batches
    are moved to the global `device`.

    Returns:
        A 0-dim float tensor on the CPU. Returning a CPU value means results can
        be collected in a list and handed to numpy/matplotlib even when the model
        runs on a GPU. If `dataloader` yields no samples the accuracy is undefined
        and NaN is returned.
    """
    correct_preds = 0
    n = 0

    model.eval()
    with torch.no_grad():
        for X, y_T in dataloader:

            X = X.to(device)
            y_T = y_T.to(device)

            outputs = model(X)
            # Index of the largest score per row is the predicted class
            _, predictions = torch.max(outputs, 1)

            n += y_T.size(0)
            # .item() keeps the running count a plain Python int
            correct_preds += (predictions == y_T).sum().item()

    if n == 0:
        return torch.tensor(float('nan'))
    return torch.tensor(correct_preds / n)

# Defining model trainer
def training(model, training_loader, validation_loader, criterion, optimizer, epochs):
  """
  Train `model` for `epochs` epochs and plot the loss/accuracy curves.

  Per epoch: one `trainer` pass over `training_loader`, accuracy on the training
  set, then loss (`tester`) and accuracy on `validation_loader`. The four values
  are printed every epoch and plotted against the epoch index at the end.

  Args:
    model: network to train.
    training_loader, validation_loader: loaders passed on to `trainer`,
      `tester` and `compute_accuracy`.
    criterion: loss callable forwarded to `trainer` / `tester`.
    optimizer: optimizer forwarded to `trainer`.
    epochs: number of epochs to run.

  Returns:
    None. Output is the printed log and the matplotlib figure.
  """
  train_losses = []
  train_acc = []
  valid_losses = []
  valid_acc = []

  for epoch in range(epochs):
    print(f'\nEpoch {epoch} : ')
    print("===========")

    model, optimizer, training_loss = trainer(model, training_loader, criterion, optimizer)
    # float() turns the 0-dim accuracy tensor into a plain number. A list of
    # CUDA tensors cannot be converted by np.array below, which made the
    # function crash after the last epoch when running on a GPU.
    training_acc = float(compute_accuracy(model, training_loader))
    train_losses.append(training_loss)
    train_acc.append(training_acc)

    with torch.no_grad():
      model, valid_loss = tester(model, validation_loader, criterion, optimizer)
      validation_acc = float(compute_accuracy(model, validation_loader))
      valid_losses.append(valid_loss)
      valid_acc.append(validation_acc)

    print(f'Training loss : {training_loss}')
    print(f'Training acc : {training_acc}')
    print(f'Validation loss : {valid_loss}')
    print(f'Validation acc : {validation_acc}')

  train_losses = np.array(train_losses)
  valid_losses = np.array(valid_losses)
  train_acc = np.array(train_acc)
  valid_acc = np.array(valid_acc)

  fig, ax = plt.subplots(figsize = (10, 5))
  ax.plot(train_losses, color='blue', label='Training loss')
  ax.plot(train_acc, color='green', label='Training accuracy')
  ax.plot(valid_losses, color='red', label='Validation loss')
  ax.plot(valid_acc, color='black', label='Validation accuracy')

  ax.set(title="Loss evolution",
            xlabel='Epoch',
            ylabel='Loss/Accuracy')
  ax.legend()
  # plt.show() works with the inline notebook backend; fig.show() only warns
  # that a non-GUI backend cannot show the figure.
  plt.show()

# Kick off the 15-epoch run with the objects configured above.
training(
  model = vgg,
  training_loader = training_loader,
  validation_loader = validation_loader,
  criterion = criterion,
  optimizer = optim,
  epochs = 15,
)