In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from torchvision.transforms import Compose, ToTensor, Resize
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive on the Colab VM so the dataset archive in MyDrive is readable.
from google.colab import drive

GDRIVE_MOUNT_POINT = "/content/gdrive"
drive.mount(GDRIVE_MOUNT_POINT)

In [None]:
# change the path based on your Drive
! cd gdrive/MyDrive/AVision_CW
! unzip /content/gdrive/MyDrive/AVision_CW/train.zip

In [None]:
# Per-channel (R, G, B) statistics passed to transforms.Normalize.
# TODO: placeholder values, not measured -- compute the real per-channel mean and
# std over the training images and substitute them here.
mean = np.full(3, 0.5)
std = np.full(3, 0.25)

In [None]:
# Split the training images into train and validation subsets.
def train_val_dataset(dataset, val_split=0.25, random_state=None):
    """Randomly split `dataset` into train/val subsets with separate transforms.

    Args:
        dataset: an ImageFolder-style dataset whose ``__getitem__`` applies its
            ``transform`` attribute. It is not modified; each split gets its
            own shallow copy carrying its own transform.
        val_split: fraction of the samples assigned to the validation split.
        random_state: optional seed forwarded to ``train_test_split``. Pass an
            int to make the split reproducible. ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        dict with keys ``'train'`` and ``'val'`` mapping to ``Subset`` objects.
    """
    train_idx, val_idx = train_test_split(
        list(range(len(dataset))), test_size=val_split, random_state=random_state)

    # A Subset stores only indices and reads samples through its parent dataset.
    # Setting `.transform` on one shared parent means the last assignment wins
    # for BOTH splits, so training would silently use the validation transform
    # (no augmentation). Give each split its own shallow copy of the parent
    # instead: attributes such as the sample list are shared, not duplicated.
    train_parent = copy.copy(dataset)
    val_parent = copy.copy(dataset)

    # Data augmentation starts.
    # These are placeholder augmentations; replace them with your own strategy.
    train_parent.transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Deterministic resize + centre crop for evaluation.
    val_parent.transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    return {
        'train': Subset(train_parent, train_idx),
        'val': Subset(val_parent, val_idx),
    }

In [None]:
# Load the unzipped training images from the VM's local disk
# (ImageFolder treats each sub-folder of the root as one class).
dataset = ImageFolder(root='/content/train')

In [None]:
# Split the images into train/val subsets and report the size of each.
image_datasets = train_val_dataset(dataset)
print('total images:', len(dataset))
print('train images:', len(image_datasets['train']))
print('val images:', len(image_datasets['val']))
# Sanity check: every image must land in exactly one split.
assert len(image_datasets['train']) + len(image_datasets['val']) == len(dataset)

In [None]:
# batch_size=64 with num_workers=2 was noted as the fastest loader on Colab; a
# batch of 16 is used here (presumably to fit ResNet-50 in GPU memory -- raise it
# if memory allows).
BATCH_SIZE = 16
NUM_WORKERS = 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        # Only training benefits from reshuffling each epoch; validation
        # accuracy does not depend on sample order.
        shuffle=(split == 'train'),
        num_workers=NUM_WORKERS,
        # Page-locked host memory speeds up host-to-GPU copies; useless on CPU.
        pin_memory=(device.type == 'cuda'),
    )
    for split in ['train', 'val']
}
dataset_sizes = {split: len(image_datasets[split]) for split in ['train', 'val']}

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
  """Run one optimisation pass over `loader`; return (mean per-sample loss, accuracy)."""
  model.train()
  total_loss = 0.0
  correct = 0
  seen = 0
  num_batches = len(loader)
  for step, (images, labels) in enumerate(loader):
    images = images.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    batch_size = labels.size(0)
    # Assumes criterion averages over the batch (reduction='mean', the
    # nn.CrossEntropyLoss default): re-weight by batch size so a short final
    # batch does not skew the epoch mean.
    total_loss += loss.item() * batch_size
    correct += (outputs.argmax(dim=1) == labels).sum().item()
    seen += batch_size

    # Text progress bar, redrawn in place with '\r'.
    rate = (step + 1) / num_batches
    a = "*" * int(rate * 50)
    b = "." * int((1 - rate) * 50)
    print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")

  # max(..., 1) guards against division by zero on an empty loader.
  seen = max(seen, 1)
  return total_loss / seen, correct / seen


def _evaluate(model, loader, criterion, device):
  """Evaluate `model` on `loader` without gradients; return (mean per-sample loss, accuracy)."""
  model.eval()
  total_loss = 0.0
  correct = 0
  seen = 0
  with torch.no_grad():
    for images, labels in loader:
      images = images.to(device)
      labels = labels.to(device)
      outputs = model(images)
      batch_size = labels.size(0)
      total_loss += criterion(outputs, labels).item() * batch_size
      correct += (outputs.argmax(dim=1) == labels).sum().item()
      seen += batch_size
  seen = max(seen, 1)
  return total_loss / seen, correct / seen


def train_model(model, dataloaders, criterion, scheduler, optimizer, num_epochs=25, device=None):
  """Train `model` for `num_epochs`, evaluating on the validation split after each epoch.

  Args:
    model: network to train; its parameters are updated in place.
    dataloaders: dict with 'train' and 'val' loaders yielding (images, labels) batches.
    criterion: loss function called as criterion(outputs, labels).
    scheduler: learning-rate scheduler, stepped once per epoch after training.
    optimizer: optimizer over the model's parameters.
    num_epochs: number of epochs to run.
    device: where to move each batch. Defaults to the device of the model's
      first parameter (CPU if the model has none), so inputs always match
      wherever the model was moved.

  Returns:
    (stats, best_acc):
      stats: dict of per-epoch lists 'train_loss', 'train_acc', 'val_acc' and
        'val_loss'. Losses are per-sample means; accuracies are fractions in [0, 1].
      best_acc: the highest validation accuracy over all epochs.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

  # a dictionary used to store all loss and accuracy
  stats = {"train_loss": [], "train_acc": [], "val_acc": [], "val_loss": []}
  # Initialised once, outside the epoch loop, so it is the best over ALL epochs.
  best_acc = 0.0

  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)

    t1 = time.perf_counter()
    train_loss, train_acc = _train_one_epoch(model, dataloaders['train'], criterion, optimizer, device)
    scheduler.step()
    print()
    print("running time is: " + str(time.perf_counter() - t1) + " seconds")

    val_loss, val_acc = _evaluate(model, dataloaders['val'], criterion, device)
    best_acc = max(best_acc, val_acc)

    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, train_loss, val_acc))

    stats["train_loss"].append(train_loss)
    stats["train_acc"].append(train_acc)
    stats["val_acc"].append(val_acc)
    stats["val_loss"].append(val_loss)

  return stats, best_acc

In [None]:
# Build an untrained ResNet-50 and size its classifier to the dataset's classes.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)
# The stock head outputs 1000 ImageNet logits. Predictions could then land on
# class indices that never occur in `dataset`, so replace the head with one that
# has exactly one output per ImageFolder class.
model.fc = nn.Linear(model.fc.in_features, len(dataset.classes))

# Move the model to the target device before constructing the optimizer.
model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.001)

# StepLR decays the learning rate of each parameter group by gamma every
# step_size epochs: here by a factor of 0.1 every 7 epochs.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Optional: resume from a checkpoint saved on Google Drive.
# NOTE(review): the checkpoint's classifier shape must match the resized head above.
# FILE = "/content/gdrive/MyDrive/resnet50_fconv_model_best.pth.tar"
# checkpoint = torch.load(FILE)
# model = nn.DataParallel(model)
# model.load_state_dict(checkpoint['state_dict'])

NUM_EPOCHS = 5

# start to train
stats, _ = train_model(model, dataloaders, criterion, step_lr_scheduler, optimizer, num_epochs=NUM_EPOCHS)

In [None]:
def _plot_curves(name, stats, keys, ylabel, title):
  """Draw one figure with a line per `stats[key]` against 1-based epoch numbers; return the figure."""
  fig, ax = plt.subplots(figsize=(8, 4))
  for k in keys:
    values = stats[k]
    ax.plot(np.arange(1, len(values) + 1), values, label='{}_{}'.format(name, k))
  ax.legend(loc=0)
  ax.set_title(title)
  ax.set_ylabel(ylabel)
  ax.set_xlabel('Epoch number')
  return fig


def plot_figure(name, stats):
  """Plot training/validation loss and accuracy curves from `train_model` stats.

  Args:
    name: model name used as a prefix in legend labels and titles.
    stats: dict with per-epoch lists under 'train_loss', 'val_loss',
      'train_acc' and 'val_acc'.

  Returns:
    (fig_1, fig_2): the loss figure and the accuracy figure.
  """
  fig_1 = _plot_curves(name, stats, ['train_loss', 'val_loss'], 'Loss',
                       '{}: training vs. validation loss'.format(name))
  # Plot the change in the validation and training set accuracy over training.
  fig_2 = _plot_curves(name, stats, ['train_acc', 'val_acc'], 'Accuracy',
                       '{}: training vs. validation accuracy'.format(name))
  return fig_1, fig_2

In [None]:
%matplotlib inline
plt.style.use('ggplot')
plot_figure("Res50",stats)