# Population Based Augmentation (PBA)

## Table of Contents

1. Introduction
2. Getting Started
3. Reproduce Results
4. Run PBA Search
5. Citation

## Introduction

Population Based Augmentation (PBA) is an algorithm that quickly and efficiently learns data augmentation functions for neural network training. PBA matches state-of-the-art results on CIFAR with one thousand times less compute, enabling researchers and practitioners to effectively learn new augmentation policies using a single workstation GPU.

This repository contains code for the work "Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules" (http://arxiv.org/abs/1905.05393), implemented in TensorFlow and Python 2. It supports both training models with the reported augmentation schedules and discovering new augmentation policy schedules.

See below for a visualization of our augmentation strategy.
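At a high level, PBA applies a population-based-training style search: many child models train in parallel, and at each step the weakest trials adopt a strong trial's augmentation hyperparameters and perturb them; the values the best trial passes through over training form the learned schedule. The sketch below is a minimal, hypothetical illustration of that loop, not this repo's implementation; all names are made up, and real model training and weight copying are stubbed out.

```python
import random

# Hypothetical sketch of a population-based search over augmentation
# hyperparameters. Illustrative only: AUG_OPS, train_and_eval, and all
# constants are made up, and actual child-model training is stubbed out.

AUG_OPS = ["rotate", "shear_x", "autocontrast"]  # illustrative subset of ops
POPULATION, EPOCHS, REPLACE = 8, 20, 2

def new_policy():
    # Each op carries a probability and a magnitude; the values these take
    # over training time are what make the result a *schedule*.
    return {op: {"prob": 0.0, "mag": 0} for op in AUG_OPS}

def explore(policy):
    """Perturb each hyperparameter slightly, staying inside a valid range."""
    return {
        op: {
            "prob": min(1.0, max(0.0, hp["prob"] + random.uniform(-0.1, 0.1))),
            "mag": min(10, max(0, hp["mag"] + random.choice([-1, 0, 1]))),
        }
        for op, hp in policy.items()
    }

def train_and_eval(policy):
    # Stand-in for one epoch of child-model training + validation accuracy.
    return random.random()

population = [{"policy": new_policy(), "score": 0.0} for _ in range(POPULATION)]
schedule = []  # best trial's hyperparameter values, one entry per epoch

for epoch in range(EPOCHS):
    for trial in population:
        trial["score"] = train_and_eval(trial["policy"])
    population.sort(key=lambda t: t["score"], reverse=True)
    schedule.append(population[0]["policy"])
    # Exploit + explore: the worst trials restart from the best ones.
    for loser, winner in zip(population[-REPLACE:], population[:REPLACE]):
        loser["policy"] = explore(winner["policy"])

print("learned a schedule covering %d epochs" % len(schedule))
```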

## Getting Started

### Install requirements

```bash
pip install -r requirements.txt
```

### Download CIFAR-10/CIFAR-100 datasets

```bash
bash datasets/cifar10.sh
bash datasets/cifar100.sh
```

## Reproduce Results

| Dataset | Model | Test Error (%) |
| --- | --- | --- |
| CIFAR-10 | Wide-ResNet-28-10 | 2.58 |
| CIFAR-10 | Shake-Shake (26 2x32d) | 2.54 |
| CIFAR-10 | Shake-Shake (26 2x96d) | 2.03 |
| CIFAR-10 | Shake-Shake (26 2x112d) | 2.03 |
| CIFAR-10 | PyramidNet+ShakeDrop | 1.46 |
| Reduced CIFAR-10 | Wide-ResNet-28-10 | 12.82 |
| Reduced CIFAR-10 | Shake-Shake (26 2x96d) | 10.64 |
| CIFAR-100 | Wide-ResNet-28-10 | 16.73 |
| CIFAR-100 | Shake-Shake (26 2x96d) | 15.31 |
| CIFAR-100 | PyramidNet+ShakeDrop | 10.94 |
| SVHN | Wide-ResNet-28-10 | 1.18 |
| SVHN | Shake-Shake (26 2x96d) | 1.13 |
| Reduced SVHN | Wide-ResNet-28-10 | 7.83 |
| Reduced SVHN | Shake-Shake (26 2x96d) | 6.46 |

Scripts to reproduce results are located in `scripts/table_*.sh`. Each script requires one argument, the model name. The available options are those reported for each dataset in Tables 1-4 of the paper, among the choices: `wrn_28_10`, `ss_32`, `ss_96`, `ss_112`, `pyramid_net`. Hyperparameters are also located inside each script file.

For example, to reproduce CIFAR-10 results on Wide-ResNet-28-10:

```bash
bash scripts/table_1_cifar10.sh wrn_28_10
```

To reproduce Reduced SVHN results on Shake-Shake (26 2x96d):

```bash
bash scripts/table_4_svhn.sh rsvhn_ss_96
```

A good place to start is Reduced SVHN on Wide-ResNet-28-10, which can complete in under 10 minutes on a Titan XP GPU while reaching 91%+ test accuracy.

Running the larger models for 1800 epochs may require multiple days of training. For example, CIFAR-10 PyramidNet+ShakeDrop takes around 9 days on a Tesla V100 GPU.

## Run PBA Search

Run PBA search on Wide-ResNet-40-2 with the script `scripts/search.sh`. It requires one argument, the dataset name: either `rsvhn` or `rcifar10`.

A fractional GPU size is specified so that multiple trials can be launched on the same GPU. Search on Reduced SVHN takes around an hour on a Titan XP GPU, and Reduced CIFAR-10 takes around 5 hours.

```bash
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
```
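For reference, fractional GPU allocation of this kind is expressed through Ray Tune's per-trial resources. Below is a minimal, hypothetical sketch of that mechanism, not this repo's actual search setup; the trainable and all values are illustrative, and API details vary across Ray versions.

```python
from ray import tune

def train_fn(config):
    """Stand-in for one PBA child-model training run (illustrative)."""
    pass

tune.run(
    train_fn,
    num_samples=16,                               # population size (illustrative)
    resources_per_trial={"cpu": 2, "gpu": 0.25},  # four trials share one GPU
)
```

With `gpu: 0.25`, Ray schedules four concurrent trials per visible GPU; each trial is still responsible for keeping its memory footprint within its share.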

The schedules used during search can be retrieved from the Ray results directory, and the log files can be converted into policy schedules with the `parse_log()` function in `pba/utils.py`. For example, the policy schedule learned on Reduced CIFAR-10 over 200 epochs is split into probability and magnitude hyperparameter values (the two values for each augmentation operation are merged) and visualized below:

(Figures: Probability Hyperparameters over Time; Magnitude Hyperparameters over Time)
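As a starting point for working with the logs, a sketch of driving `parse_log()` is below. The argument list (a trial log path and an epoch count), the example path, and the returned per-epoch structure are all assumptions; confirm them against `pba/utils.py` before relying on this.

```python
from pba.utils import parse_log

# Hypothetical usage: the (log_path, epochs) arguments, the example path,
# and the per-epoch return structure are assumptions -- check pba/utils.py.
schedule = parse_log("results/search/trial_0/result.json", 200)

for epoch, policy in enumerate(schedule):
    print(epoch, policy)  # assumed: one set of hyperparameter values per epoch
```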

## Citation

If you use PBA in your research, please cite:

```bibtex
@inproceedings{ho2019pba,
  title     = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
  author    = {Daniel Ho and
               Eric Liang and
               Ion Stoica and
               Pieter Abbeel and
               Xi Chen},
  booktitle = {ICML},
  year      = {2019}
}
```