In [None]:
# Only needed when running on Google Colab:
# mount Google Drive inside the Colab runtime.
#from google.colab import drive
#drive.mount("/content/drive", force_remount=True)

# CV project
# NOTE: the original import order is kept on purpose. The wildcard import from
# matplotlib.pyplot may rebind names, so moving other imports across it could
# silently change which object a name refers to.
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import os
import torchvision
from torchvision import datasets, transforms, models
from IPython import display
from PIL import Image
import glob
import matplotlib
from matplotlib.pyplot import *
import copy
import random
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from cvxopt import matrix, solvers
# Seed every random number generator the notebook draws from, not only torch:
# `split` picks the validation indices with `random.sample`, so without seeding
# Python's `random` the train/validation split differs on every run.
random.seed(11)
np.random.seed(11)
# Kept last so the cell still displays the torch Generator, as before.
torch.manual_seed(11)

Build our own dataset and dataloaders 

In [None]:
class customDataset(Dataset):
  '''
    Image dataset built on top of torchvision's ImageFolder.

    Every item is an (image, class index) pair. The image is passed through
    the augmentation transformation when augmentation is switched on and one
    was supplied, otherwise through the plain (resize) transformation.
  '''
  def __init__(self, rootPath, resize_transformation, data_augmentation_transformation=None, augment=True):
    ''' 
      @rootPath: path of the folder containing class subfolders
      @resize_transformation: transformation applied to an image when
                              augmentation is off (or not available)
      @data_augmentation_transformation: transformation applied instead of
                              resize_transformation when augmentation is on
      @augment: boolean indicating if data augmentation should be performed
    '''
    self.data = torchvision.datasets.ImageFolder(rootPath)
    self.transformation = resize_transformation
    self.data_augmentation = data_augmentation_transformation
    # Honour the caller's choice; it used to be overwritten with True.
    self.augment = augment
    
  def __getitem__(self, key):
    ''' Return (transformed image, class index) for the sample at position key. '''
    # ImageFolder keeps (path, class index) pairs in `imgs`; reading the pair
    # directly avoids decoding the image a second time just to get its label.
    path, true_class = self.data.imgs[key]

    use_augmentation = self.augment and self.data_augmentation is not None
    transform = self.data_augmentation if use_augmentation else self.transformation

    # Close the file deterministically. The pixel data is loaded first so the
    # image stays usable even if the transformation hands it back untouched.
    with Image.open(path) as im:
      im.load()
      img_tensor = transform(im)

    return img_tensor, true_class

  def __len__(self):
    ''' Number of images found by ImageFolder. '''
    return len(self.data)

  def set_augment(self, value):
    ''' Switch data augmentation on (True) or off (False). '''
    self.augment = value

  def change_transformation(self, new_transformation):
    ''' Replace the transformation used when augmentation is off. '''
    self.transformation = new_transformation

  def change_augmentation(self, new_augmentation):
    ''' Replace the transformation used when augmentation is on. '''
    self.data_augmentation = new_augmentation

  def get_keys(self):
    ''' Valid sample indices: 0 .. len(self) - 1. '''
    return range(len(self.data))

def split(dataset, val_size, n_classes=15, imgs_per_class=100):
    '''
    Draw a class-balanced train/validation split and return one sampler each.

    The dataset is assumed to be ordered by class, with class c occupying the
    indices [c*imgs_per_class, (c+1)*imgs_per_class). The same number of
    validation images is drawn at random (Python's `random`) from every class.

    @ dataset: a customDataset object (only get_keys() is used)
    @ val_size: fraction of each class to hold out for validation (e.g. 0.2)
    @ n_classes: number of classes in the dataset
    @ imgs_per_class: number of images of each class
    @ returns: (train_sampler, validation_sampler), both SubsetRandomSampler
    @ raises ValueError: if the class layout needs more images than available
    '''
    index = list(dataset.get_keys())
    if n_classes * imgs_per_class > len(index):
        # Otherwise the sampler would hand out indices the dataset cannot serve.
        raise ValueError(
            "split expects {} classes x {} images but the dataset only has {} samples".format(
                n_classes, imgs_per_class, len(index)))

    val_per_class = int(val_size * imgs_per_class)
    validation_index = []
    for c in range(n_classes):
        first = imgs_per_class * c
        validation_index.extend(
            random.sample(range(first, first + imgs_per_class), val_per_class))

    # Everything that was not held out is used for training.
    held_out = set(validation_index)
    train_index = [i for i in index if i not in held_out]

    train_sampler = SubsetRandomSampler(train_index)
    validation_sampler = SubsetRandomSampler(validation_index)

    return train_sampler, validation_sampler

def loaders(dataset, val_size, batch_size, num_workers):
  ''' 
    Build the training and validation DataLoaders over the same dataset.

    @dataset: a customDataset object
    @val_size: fraction of the data used for validation (passed to split)
    @batch_size: the number of examples in each training batch
    @num_workers: number of subprocesses to use in the data loader
    @returns: (train_loader, val_loader); val_loader yields the whole
              validation subset as one batch
  '''

  train_sampler, validation_sampler = split(dataset, val_size)
  train_loader = DataLoader(dataset,
                            batch_size = batch_size,
                            sampler = train_sampler,
                            num_workers = num_workers)

  # Size the validation batch from the sampler itself so it always matches
  # the subset that was actually drawn. max(1, ...) keeps DataLoader happy
  # when the validation subset is empty (batch_size must be positive).
  val_batch_size = max(1, len(validation_sampler))
  val_loader = DataLoader(dataset,
                          batch_size = val_batch_size,
                          sampler = validation_sampler,
                          num_workers = num_workers)
  return train_loader, val_loader

Training function

In [1]:
def training_with_scheduler(net, trainLoader,valLoader, optimizer, criterion, val_patience, validate_each, scheduler=None):
    '''
    Train `net` with periodic validation, early stopping and an optional
    learning-rate scheduler.

    Relies on names defined elsewhere in the notebook: `epochs` (number of
    epochs), `device`, `trainingSet` (dataset whose augmentation is switched
    off while validating) and `validate`.

    @net: the model to train (updated in place)
    @trainLoader: loader yielding (batch, labels) training pairs
    @valLoader: loader handed to `validate`
    @optimizer: optimizer stepping the parameters of `net`
    @criterion: loss function called as criterion(outputs, labels)
    @val_patience: number of consecutive worsening validations tolerated;
                   training stops once the count exceeds it
    @validate_each: validate every `validate_each` batches (must be > 0)
    @scheduler: optional LR scheduler, stepped once per epoch
    @returns: [best_state_dict, train_loss, val_loss, train_accuracy, val_accuracy]
              where best_state_dict is the state with the lowest validation loss
    '''

    # here I save loss and accuracy
    train_loss = []
    val_loss = []
    train_accuracy = []
    val_accuracy = []

    # early stopping; infinity guarantees the first validation becomes the reference
    best_net = copy.deepcopy(net.state_dict())
    best_loss = float("inf")
    worsening_count = 0

    net.train()
    n_batches = len(trainLoader)
    for e in range(epochs):
        for i, data in enumerate(trainLoader):

            batch = data[0].to(device).float()
            labels = data[1].to(device)

            optimizer.zero_grad()
            outputs = net(batch)

            loss = criterion(outputs, labels)
            train_loss.append(loss.item())
            predicted_class = torch.argmax(outputs, dim=1)
            # fraction of the batch classified correctly (kept as a 0-dim tensor)
            acc = (predicted_class == labels).int().sum() / batch.shape[0]
            train_accuracy.append(acc)

            loss.backward()
            optimizer.step()

            # parameter to decide how often to validate
            if i % validate_each == 0:
                # Validate without augmentation and in eval mode, then restore
                # both even if validation fails.
                trainingSet.set_augment(False)
                net.eval()
                try:
                    with torch.no_grad():
                        valLoss, valAcc = validate(valLoader, net, criterion)
                finally:
                    net.train()
                    trainingSet.set_augment(True)
                # save validation loss and accuracy
                val_loss.append(valLoss)
                val_accuracy.append(valAcc)

                # if validation loss got worse than the best one, increase the counter
                if valLoss > best_loss:
                    worsening_count = worsening_count + 1
                    # if I exceed the patience, early stop
                    if worsening_count > val_patience:
                        return [best_net, train_loss, val_loss, train_accuracy, val_accuracy]
                # else reset the counter and use actual validation loss as reference, save the net
                else:
                    worsening_count = 0
                    best_loss = valLoss
                    best_net = copy.deepcopy(net.state_dict())

                # the scheduler is optional: only report the LR when there is one
                if scheduler is not None:
                    print("[LR]: {:.4f}\n".format(scheduler.get_last_lr()[0]))
                print("[EPOCH]: {}, [BATCH]: {}/{}, [LOSS]: t {}, v {},\t [ACC.]: t {},\t v {}".format(e, i, n_batches, loss.item(), valLoss, acc, valAcc))


        if scheduler is not None:
            scheduler.step()
    return [best_net, train_loss, val_loss, train_accuracy, val_accuracy]

# validation 
def validate(valLoader, net, criterion):
  '''
    Evaluate `net` on every batch of `valLoader`.

    The model is switched to eval mode (and gradients are disabled) for the
    duration of the call; its previous train/eval mode is restored afterwards.
    Relies on the notebook-level `device`.

    @valLoader: loader yielding (batch, labels) pairs
    @net: the model to evaluate
    @criterion: loss function called as criterion(outputs, labels); assumed to
                return the mean loss of the batch (PyTorch's default reduction)
    @returns: [loss, acc] - loss is a float averaged over all samples, acc is
              the fraction of correctly classified samples (0-dim tensor)
    @raises ValueError: if the loader yields no samples
  '''
  was_training = net.training
  net.eval()

  correct_count = 0
  loss_sum = 0.0
  size = 0
  try:
    with torch.no_grad():
      for data in valLoader:
        batch = data[0].to(device).float()
        labels = data[1].to(device)

        outputs = net(batch)
        n_samples = batch.shape[0]
        # weight each batch by its size so the result is a per-sample average
        # over the whole loader, not just the loss of the last batch
        loss_sum += criterion(outputs, labels).item() * n_samples
        predicted_class = torch.argmax(outputs, dim=1)
        correct_count = correct_count + (predicted_class == labels).int().sum()
        size += n_samples
  finally:
    net.train(was_training)

  if size == 0:
    raise ValueError("validate: the validation loader produced no samples")

  acc = correct_count/size
  return [loss_sum/size, acc]