# Progressive Resizing Baseline
**Paper:** https://www.fast.ai/posts/2018-04-30-dawnbench-fastai.html  
Progressive resizing appears to be a commonly accepted technique: train on smaller images at the start of training, then gradually increase the image size as training progresses. It is effectively the opposite of dynamic resolution, except that at any given point in training all images in the dataset share the same size.


**Extra Resources:**
*   https://miguel-data-sc.github.io/2017-11-23-second/
*   https://docs.mosaicml.com/projects/composer/en/stable/method_cards/progressive_resizing.html

In [None]:
import os
import shutil

from google.colab import drive

# Mount Google Drive so the dataset stored there becomes reachable.
drive.mount('/content/drive')

# Copy the dataset from Google Drive to local storage (skipped when the
# local directory already exists).
drive_data_path = '/content/drive/MyDrive/dynamic_res/imagenette2'
local_data_path = '/content/imagenette2'

if os.path.exists(local_data_path):
    print(f'dataset already exists at {local_data_path}')
else:
    shutil.copytree(drive_data_path, local_data_path)

# Later cells read the dataset location from `data_path`.
data_path = local_data_path

In [None]:

import argparse
import time
import csv
import torch
import torch.nn as nn
import torch.optim as optim

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import sys

def parse_args():
    """Parse the training configuration from ``sys.argv``.

    Notebook kernels are started with ``-f <connection file>``; that flag and
    its value are removed before parsing so the connection-file path is never
    mistaken for the positional ``data`` argument. Any other token starting
    with ``-f`` (e.g. ``-f=<path>``) is dropped as well.

    Returns:
        argparse.Namespace with ``data`` plus the options declared below.
        Options without a default are ``None`` when not supplied.
    """
    argv = []
    skip_value = False
    # sys.argv[0] is the program name and is never passed to the parser.
    for arg in sys.argv[1:]:
        if skip_value:
            # This token is the value that belonged to a bare ``-f``.
            skip_value = False
            continue
        if arg == '-f':
            skip_value = True
            continue
        if arg.startswith('-f'):
            continue
        argv.append(arg)

    parser = argparse.ArgumentParser(description='PyTorch Imagenette Progressive Resizing Training')

    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--arch', help='torchvision model constructor name')
    parser.add_argument('--epochs', type=int, help='number of training epochs')
    parser.add_argument('--batch-size', type=int, help='batch size for the train and val loaders')
    parser.add_argument('--workers', type=int, help='number of data loading workers')
    parser.add_argument('--r-min', type=int, help='minimum resolution (parsed but not read by main)')
    parser.add_argument('--r-start', type=int, help='resolution used for the first epochs')
    parser.add_argument('--r-max', type=int, help='upper bound on the training resolution')
    parser.add_argument('--gamma', type=float, help='multiplicative step of the resolution schedule')
    parser.add_argument('--reassign-epoch', type=int, help='number of epochs between resolution increases')

    parser.add_argument('--lr', default=0.1, type=float, help='SGD learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
    parser.add_argument('--weight-decay', default=1e-4, type=float, help='SGD weight decay')
    parser.add_argument('--print-freq', default=10, type=int, help='print progress every N batches')

    return parser.parse_args(argv)

class AverageMeter:
    """Running weighted average of a named metric.

    Attributes (all reset to 0 by ``reset``):
        val:   the most recent value passed to ``update``.
        sum:   sum of ``val * n`` over all updates.
        count: sum of ``n`` over all updates.
        avg:   ``sum / count`` as of the latest update.
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Clear every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for each k in ``topk``.

    ``output`` is ranked along dimension 1 and compared with ``target``
    (one label per row). Gradients are not tracked.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        of rows whose target appears among the k highest-scoring entries.
    """
    with torch.no_grad():
        deepest = max(topk)

        # Indices of the `deepest` best scores per row, transposed so that
        # row r of `ranked` holds every sample's r-th best prediction.
        _, ranked = output.topk(deepest, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / target.size(0)
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

# inverted geometric resolution schedule (starts low, goes high)
def compute_resolution_schedule(r_start, r_max, gamma):
    """Build the list of resolutions visited during training.

    Starting at ``r_start``, each following entry is the previous one
    multiplied by ``gamma``, truncated to an int and capped at ``r_max``.
    The list ends as soon as the next entry would equal the current one.
    """
    schedule = [r_start]
    candidate = min(int(r_start * gamma), r_max)

    # Keep extending until a step no longer changes the resolution.
    while candidate != schedule[-1]:
        schedule.append(candidate)
        candidate = min(int(candidate * gamma), r_max)

    return schedule

class RobustImageFolder(torch.utils.data.Dataset):
    """Dataset wrapper that skips samples whose loading raises an exception.

    When ``base_dataset[idx]`` fails, the failure is printed and the next
    index (wrapping around at the end) is tried instead.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` or, if it fails, the next loadable one.

        Raises:
            RuntimeError: if ``len(base_dataset)`` consecutive attempts all
                fail; the last underlying error is attached as the cause.
        """
        total = len(self.base_dataset)
        if total == 0:
            # Nothing to fall back to; let the wrapped dataset raise.
            return self.base_dataset[idx]

        last_error = None
        # Iterate instead of recursing and cap the number of attempts, so a
        # fully unreadable dataset fails fast with a clear error.
        for _ in range(total):
            try:
                return self.base_dataset[idx]
            except Exception as e:
                print(f'skipping corrupted image at index {idx}: {e}')
                last_error = e
                idx = (idx + 1) % total

        raise RuntimeError(
            f'no loadable sample found after {total} attempts'
        ) from last_error

def train_epoch(train_loader, model, criterion, optimizer, epoch, device, args):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, forwarded through ``model``, and used
    for one optimizer step. Running loss / top-1 / top-5 averages are printed
    every ``args.print_freq`` batches.

    Returns:
        (average loss, average top-1, average top-5, elapsed seconds), where
        the averages are the ``avg`` fields of the per-metric AverageMeters.
    """
    loss_meter = AverageMeter('Loss')
    acc1_meter = AverageMeter('Acc@1')
    acc5_meter = AverageMeter('Acc@5')

    model.train()
    started = time.time()

    for step, (images, target) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight every statistic by the number of samples in this batch.
        n_samples = images.size(0)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        acc1_meter.update(acc1[0], n_samples)
        acc5_meter.update(acc5[0], n_samples)

        if not step % args.print_freq:
            print(f'Epoch [{epoch}][{step}/{len(train_loader)}] Loss {loss_meter.avg} Acc@1 {acc1_meter.avg} Acc@5 {acc5_meter.avg}')

    elapsed = time.time() - started

    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg, elapsed

def validate(val_loader, model, criterion, device, args):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Running loss / top-1 / top-5 averages are printed every
    ``args.print_freq`` batches, followed by a one-line summary.

    Returns:
        (average loss, average top-1, average top-5), taken from the ``avg``
        fields of the per-metric AverageMeters.
    """
    loss_meter = AverageMeter('Loss')
    acc1_meter = AverageMeter('Acc@1')
    acc5_meter = AverageMeter('Acc@5')

    model.eval()

    with torch.no_grad():
        for step, (images, target) in enumerate(val_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            # Weight every statistic by the number of samples in this batch.
            n_samples = images.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            loss_meter.update(loss.item(), n_samples)
            acc1_meter.update(acc1[0], n_samples)
            acc5_meter.update(acc5[0], n_samples)

            if not step % args.print_freq:
                print(f'Test [{step}/{len(val_loader)}] Loss {loss_meter.avg} Acc@1 {acc1_meter.avg} Acc@5 {acc5_meter.avg}')

    print(f'Val: Acc@1 {acc1_meter.avg} Acc@5 {acc5_meter.avg}')
    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg

def _build_train_loader(traindir, resolution, normalize, args):
    """Create the shuffled training loader that crops images to ``resolution``."""
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = RobustImageFolder(datasets.ImageFolder(traindir, train_transform))
    return DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                      num_workers=args.workers, pin_memory=True)


def main():
    """Train a model with progressive resizing and evaluate it once at the end.

    Steps:
      1. Parse arguments and compute the resolution schedule.
      2. Build the model (10 output classes), SGD optimizer and StepLR scheduler.
      3. Train for ``args.epochs`` epochs, logging one CSV row per epoch. After
         every ``args.reassign_epoch`` epochs the training loader is rebuilt
         at the next resolution of the schedule.
      4. Evaluate on the ``val`` split, append a ``FINAL_TEST`` row to the CSV
         and save the model's state dict.

    The validation pipeline always uses Resize(256) + CenterCrop(224),
    independent of the training schedule. The CSV log is written to the
    current working directory and the checkpoint to a fixed Google Drive path.
    """
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f'resolution range: {args.r_start} - {args.r_max}')
    print(f'gamma: {args.gamma}')

    resolution_schedule = compute_resolution_schedule(args.r_start, args.r_max, args.gamma)
    current_res_idx = 0
    last_res_idx = len(resolution_schedule) - 1

    model = models.__dict__[args.arch](num_classes=10)
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_loader = _build_train_loader(traindir, resolution_schedule[current_res_idx], normalize, args)

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = RobustImageFolder(datasets.ImageFolder(valdir, val_transform))
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    log_file = f'training_log_{args.arch}_progressive.csv'
    with open(log_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'train_top1', 'epoch_time'])

    print(f'logging to: {log_file}')
    total_start = time.time()

    for epoch in range(args.epochs):
        train_loss, train_top1, train_top5, epoch_time = train_epoch(train_loader, model, criterion, optimizer, epoch, device, args)
        scheduler.step()

        # Re-open in append mode every epoch so completed rows are on disk
        # even if a later epoch fails.
        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch, train_loss, train_top1.item(), epoch_time])

        print(f'epoch {epoch} completed in {epoch_time}s')

        # progressive resizing -------------------------------------------------
        # Step to the next resolution after every `reassign_epoch` epochs.
        # The loader is rebuilt only when that actually changes the
        # resolution: not once the schedule is exhausted, and not after the
        # final epoch, when the new loader would never be used.
        at_boundary = (epoch + 1) % args.reassign_epoch == 0
        more_epochs_left = epoch + 1 < args.epochs
        if at_boundary and more_epochs_left and current_res_idx < last_res_idx:
            current_res_idx += 1
            train_loader = _build_train_loader(traindir, resolution_schedule[current_res_idx], normalize, args)
        # progressive resizing -------------------------------------------------

    total_time = time.time() - total_start
    print(f'\ntotal training time: {total_time}s')

    # final test set evaluation
    test_loss, test_top1, test_top5 = validate(val_loader, model, criterion, device, args)
    print('\nfinal test results:')
    print(f'Test Loss: {test_loss}')
    print(f'Test Acc@1: {test_top1}')
    print(f'Test Acc@5: {test_top5}')

    # append final test results to CSV
    with open(log_file, 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['FINAL_TEST', test_loss, test_top1.item(), total_time])

    print(f'results saved to: {log_file}')

    # save model to google drive
    save_path = '/content/drive/MyDrive/dynamic_res/model_progressive.pth'
    torch.save(model.state_dict(), save_path)
    print(f'\nmodel saved to: {save_path}')

In [None]:
# Emulate a command-line launch: parse_args() reads its options from sys.argv.
# NOTE: --r-min is accepted by parse_args() but main() never reads it.
sys.argv = [
    'train.py', data_path,
    '--arch', 'resnet18',
    '--epochs', '30',
    '--batch-size', '128',
    '--workers', '2',
    '--r-min', '112',
    '--r-start', '112',
    '--r-max', '224',
    '--gamma', '2.0',
    '--reassign-epoch', '10',
]
main()

In [None]:
from google.colab import files

# Name must match the log file main() writes for --arch resnet18
# (f'training_log_{args.arch}_progressive.csv').
log_csv_name = 'training_log_resnet18_progressive.csv'
files.download(log_csv_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>