Learning to Compose Domain-Specific Transformations for Data Augmentation
Or: Transformation Adversarial Networks for Data Augmentations (TANDA)
Paper (NeurIPS 2017): Learning to Compose Domain-Specific Transformations for Data Augmentation
Corresponding authors: Alex Ratner (firstname.lastname@example.org), Henry Ehrenberg (email@example.com)
TANDA blog post
For more on using Transformation Functions (TFs) for data augmentation, see the Snorkel project.
NEW: an easy-to-use Keras interface
Just in time for NeurIPS 2017, we're releasing an easy-to-use substitute for Keras'
ImageDataGenerator data augmentation class. Just swap in
TANDAImageDataGenerator and you'll be using our trained data augmentation models!
For a recipe on how to use it, see our example script: all we did was copy
Keras' CIFAR-10 CNN example script and plug in the
TANDAImageDataGenerator. Easy as that.
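The point of a drop-in substitute is that training code written against the ImageDataGenerator interface doesn't have to change. As a hedged illustration only (this is not TANDA's implementation, and the real TANDAImageDataGenerator constructor arguments should be taken from its source), here is a pure-NumPy stand-in class, `ToyAugmentingGenerator` (hypothetical), that exposes a `flow(x, y, batch_size)` method shaped like Keras' `ImageDataGenerator.flow` and applies a user-supplied augmentation function to every image in each batch.

```python
import numpy as np


class ToyAugmentingGenerator:
    """Hypothetical stand-in mimicking the shape of ImageDataGenerator.flow()."""

    def __init__(self, augment_fn, seed=0):
        # augment_fn(image, rng) -> augmented image of the same shape
        self.augment_fn = augment_fn
        self.rng = np.random.default_rng(seed)

    def flow(self, x, y, batch_size=32):
        """Yield (augmented_x_batch, y_batch) forever, reshuffling every epoch."""
        n = len(x)
        while True:
            order = self.rng.permutation(n)
            for start in range(0, n, batch_size):
                idx = order[start:start + batch_size]
                batch = np.stack([self.augment_fn(img, self.rng) for img in x[idx]])
                yield batch, y[idx]


def random_flip(img, rng):
    # A simple label-preserving TF for natural images: horizontal flip half the time.
    return img[:, ::-1] if rng.random() < 0.5 else img


x = np.zeros((100, 32, 32, 3), dtype=np.float32)  # dummy CIFAR-10-shaped images
y = np.arange(100)
gen = ToyAugmentingGenerator(random_flip).flow(x, y, batch_size=32)
xb, yb = next(gen)
```

With Keras 2, a generator like this is consumed the same way the CIFAR-10 example consumes `datagen.flow(...)`, i.e. passed to `model.fit_generator(...)`; only the object that produces batches changes.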
Using data augmentation on benchmark machine learning tasks, like MNIST and CIFAR-10, yields large performance gains. But using data augmentation on new tasks can prove difficult. We've found that while it's usually easy for practitioners to
- obtain large quantities of unlabeled data; and
- come up with individual label-preserving data transformations (e.g. small image rotations),
constructing and tuning the more sophisticated compositions typically needed to achieve state-of-the-art results is a time-consuming manual task. The TANDA library takes unlabeled data points and arbitrary, user-provided transformation functions as input, and learns how to compose them to generate realistic, augmented data points.
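To make the inputs concrete: a transformation function (TF) is just a function from a data point to a transformed data point, and an augmentation applies a sequence of TFs in order. The sketch below is a hypothetical plain-NumPy illustration, not the TANDA API. It defines a few toy TFs on 1-D signals and samples a length-`seq_len` sequence from a categorical distribution over them. In TANDA the sequence comes from a learned generator (which may condition each choice on earlier ones); here the probabilities `probs` are fixed by hand as a placeholder for what training would learn, and choices are sampled independently for simplicity.

```python
import numpy as np


# Toy, hypothetical TFs on 1-D signals; each maps an array to a same-shape array.
def shift_right(x):
    return np.roll(x, 1)


def shift_left(x):
    return np.roll(x, -1)


def scale_up(x):
    return 1.1 * x


def identity(x):
    return x


TFS = [shift_right, shift_left, scale_up, identity]


def sample_sequence(probs, seq_len, rng):
    """Sample seq_len TF indices independently from a categorical distribution."""
    return rng.choice(len(TFS), size=seq_len, p=probs)


def apply_sequence(x, seq):
    """Compose the TFs by applying them left to right."""
    for i in seq:
        x = TFS[i](x)
    return x


rng = np.random.default_rng(0)
probs = np.array([0.4, 0.4, 0.1, 0.1])  # placeholder for learned probabilities
x = np.array([0.0, 1.0, 2.0, 3.0])
seq = sample_sequence(probs, seq_len=5, rng=rng)
x_aug = apply_sequence(x, seq)
```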
The original data points (blue) are distributed at random within the purple dotted line. We define several random displacement vectors as transformations, and the orange points are augmented copies of blue data points. At first, the transformations are applied effectively at random, yielding many augmented points outside of the true data distribution. After a few iterations, the augmentation model learns how to create sequences of displacements that yield augmented data points within the distribution of interest.
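A rough reproduction of the toy setup described above, under stated assumptions: the region inside the dotted line is taken to be a unit disc (its exact shape isn't specified), the displacement vectors in `DISPLACEMENTS` are made up, and instead of training a model the sketch compares uniformly random sequences against a hand-picked distribution that stands in for what a trained augmentation model might learn. The quantity it reports, the fraction of augmented points that stay inside the region, is what the learned model should push up.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_disc(n, radius, rng):
    """Sample n points uniformly inside a disc (sqrt gives uniform area density)."""
    r = radius * np.sqrt(rng.random(n))
    theta = 2 * np.pi * rng.random(n)
    return np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)


# Hypothetical displacement TFs: each adds a fixed 2-D vector to a point.
DISPLACEMENTS = np.array([[0.3, 0.0], [-0.3, 0.0], [0.0, 0.3], [0.0, -0.3], [0.6, 0.6]])


def augment(points, seq_len, probs, rng):
    """Apply seq_len displacements per point, each drawn from `probs`."""
    choices = rng.choice(len(DISPLACEMENTS), size=(len(points), seq_len), p=probs)
    return points + DISPLACEMENTS[choices].sum(axis=1)


def fraction_inside(points, radius):
    return np.mean(np.linalg.norm(points, axis=1) <= radius)


radius = 1.0
data = sample_disc(500, radius, rng)

uniform = np.full(len(DISPLACEMENTS), 1 / len(DISPLACEMENTS))
random_aug = augment(data, seq_len=3, probs=uniform, rng=rng)

# Hand-picked stand-in for a learned policy: never use the large drift vector.
tuned = np.array([0.25, 0.25, 0.25, 0.25, 0.0])
tuned_aug = augment(data, seq_len=3, probs=tuned, rng=rng)

frac_random = fraction_inside(random_aug, radius)
frac_tuned = fraction_inside(tuned_aug, radius)
print(frac_random, frac_tuned)
```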
We learned an augmentation model for the MNIST data set using rotation, shear, elastic deformation, and rescaling transformation functions. The figure shows 100 augmented MNIST images. While they initially do not look like realistic digits, the model learns to compose the image transformations to generate realistic augmented images.
pip install --requirement python-package-requirement.txt
If you're using the Keras interface, you'll need to install Keras as well.
Note: currently, TANDA only works with TensorFlow 1.2. This is enforced in
python-package-requirement.txt. We do not recommend using newer versions
right now, as models will not train correctly.
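Since a newer TensorFlow will install without complaint but train incorrectly, it can be worth checking the version before starting a long run. A minimal sketch (the helper names are hypothetical, not part of TANDA) that accepts any 1.2.x release string:

```python
def is_supported_tf_version(version):
    """Return True if `version` is a TensorFlow 1.2.x version string."""
    return version.split(".")[:2] == ["1", "2"]


def check_tensorflow():
    """Return (ok, message) describing the installed TensorFlow, without raising."""
    try:
        import tensorflow as tf
    except ImportError:
        return False, "TensorFlow is not installed; see python-package-requirement.txt"
    if is_supported_tf_version(tf.__version__):
        return True, "TensorFlow %s looks compatible" % tf.__version__
    return False, "Found TensorFlow %s; TANDA currently expects 1.2.x" % tf.__version__
```

Calling `ok, msg = check_tensorflow(); print(msg)` at the top of a script surfaces a version mismatch immediately rather than after training has silently gone wrong.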
TANDA includes example TAN training scripts for MNIST and CIFAR-10. You'll need
to add the TANDA library to your path first. From the top-level
directory, just run
The example scripts can be found in
example-scripts. To train an MNIST TAN:
Before running experiments with CIFAR-10, you'll need to download the data:
cd experiments/cifar10
./download-data.sh
cd $TANDAHOME
Then to train a CIFAR-10 TAN, run:
Running experiments with custom parameters
To run a single experiment, for example on CIFAR-10:
source set_env.sh
python experiments/cifar10/train.py --run_name test_run [FLAGS]
Most flags are shared across the train scripts; individual train scripts (e.g.
experiments/cifar10/train.py) may also define additional flags of their own. The
run_type flag determines the mode to run in:
tanda-full [default]: Train a TAN, then use this to train a data-augmented end model
tan-only: Train TAN only
tanda-pretrained: Load trained TAN, then use this to train a data-augmented end model
random: Train a randomly-augmented end model
baseline: Train an end model with no data augmentation
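The five modes differ only in which stages run and what kind of augmentation the end model sees. The table below restates the list as a small lookup; it is a hypothetical summary for reference, not TANDA's own dispatch code. The mode name is presumably passed like the other flags (e.g. `--run_type baseline`, by analogy with `--run_name`).

```python
from typing import NamedTuple


class RunPlan(NamedTuple):
    train_tan: bool        # train a new TAN in this run
    load_tan: bool         # load a previously trained TAN instead
    train_end_model: bool  # train an end (classifier) model
    augmentation: str      # end-model augmentation: "tan", "random", "none", or "n/a"


# Hypothetical restatement of the run_type list above.
RUN_TYPES = {
    "tanda-full": RunPlan(True, False, True, "tan"),
    "tan-only": RunPlan(True, False, False, "n/a"),  # no end model is trained
    "tanda-pretrained": RunPlan(False, True, True, "tan"),
    "random": RunPlan(False, False, True, "random"),
    "baseline": RunPlan(False, False, True, "none"),
}


def plan_for(run_type="tanda-full"):
    """Look up a mode, defaulting to tanda-full as the README states."""
    try:
        return RUN_TYPES[run_type]
    except KeyError:
        raise ValueError(
            "unknown run_type %r; expected one of %s" % (run_type, sorted(RUN_TYPES))
        ) from None
```

Read this way, `random` and `baseline` are the controls for a `tanda-full` or `tanda-pretrained` run: same end model, different augmentation.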
TensorBoard visualizations are available during (and after) training:
tensorboard --logdir experiments/log/[DATESTAMP]/[RUN_NAME]_[TIMESTAMP]
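Because each run lands in a date- and time-stamped directory, finding the latest run for a given name by hand gets tedious. A small helper sketch, `latest_run_dir` (hypothetical), globs for `[RUN_NAME]_*` under every `[DATESTAMP]` directory. It assumes the stamps sort lexicographically in time order (e.g. zero-padded); the actual stamp format isn't documented here, so check it against your `experiments/log` directory. Run names that are prefixes of other run names (e.g. `test` vs. `test_run`) can also collide under this pattern.

```python
import glob
import os


def latest_run_dir(run_name, log_root="experiments/log"):
    """Return the newest <log_root>/<DATESTAMP>/<run_name>_<TIMESTAMP> dir, or None."""
    pattern = os.path.join(log_root, "*", glob.escape(run_name) + "_*")
    candidates = [p for p in glob.glob(pattern) if os.path.isdir(p)]
    if not candidates:
        return None

    def stamp_key(path):
        # Order by (DATESTAMP, TIMESTAMP), assuming both sort lexicographically.
        datestamp = os.path.basename(os.path.dirname(path))
        timestamp = os.path.basename(path)[len(run_name) + 1:]
        return (datestamp, timestamp)

    return max(candidates, key=stamp_key)
```

For example, `print("tensorboard --logdir", latest_run_dir("test_run"))` prints a ready-to-run TensorBoard command for the most recent `test_run`.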
To launch a set of experiments in parallel, first define a config file (see
experiments/cifar10/config/ for examples), then run e.g.:
source set_env.sh
python experiments/launch_run.py --script experiments/cifar10/train.py --config experiments/cifar10/config/tan_search_config.json
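The config format launch_run.py expects is defined by the examples in experiments/cifar10/config/ and isn't reproduced here. As a generic sketch of the idea behind a parameter search (not launch_run.py's actual logic), the snippet below expands a dict of flag-value lists into one command line per combination and can run them with a bounded thread pool. Apart from `--run_name`, the flag names are placeholders, not confirmed TANDA flags; as with the commands above, `source set_env.sh` should be run in the launching shell first so the child processes inherit the environment.

```python
import itertools
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor


def expand_grid(script, base_name, grid):
    """Turn {"flag": [v1, v2], ...} into one argv list per combination."""
    keys = sorted(grid)
    commands = []
    for i, values in enumerate(itertools.product(*(grid[k] for k in keys))):
        argv = ["python", script, "--run_name", "%s_%d" % (base_name, i)]
        for k, v in zip(keys, values):
            argv += ["--" + k, str(v)]
        commands.append(argv)
    return commands


def run_all(commands, max_parallel=4):
    """Run the commands with at most max_parallel at once; return their exit codes."""
    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        return list(pool.map(subprocess.call, commands))


# Placeholder flag names; substitute real ones from the train script.
grid = {"hypothetical_lr": [1e-3, 1e-4], "hypothetical_seq_len": [5, 10]}
cmds = expand_grid("experiments/cifar10/train.py", "tan_search", grid)
for argv in cmds:
    print(" ".join(shlex.quote(a) for a in argv))
```

Printing the commands first (a dry run) before calling `run_all(cmds)` is a cheap way to confirm the grid is what you intended.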
To see quick stats from the TAN training, run:
python experiments/print_tan_stats.py --log_root [LOG_ROOT]
One procedure is to train a set of TAN models (in tan-only mode), choose the best
ones (e.g. by visual appearance or generative-to-random loss ratio), then run these
with end models. This can be done in parallel:
python experiments/launch_end_models.py --script experiments/cifar10/train.py --end_model_config experiments/cifar10/config/end_model_config.json --tan_log_root [LOG_ROOT] --model_indexes 1 5 7
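The "choose the best ones" step can be made mechanical when selecting by loss ratio. A minimal sketch, `select_tan_indexes` (hypothetical), takes per-model pairs of (generative loss, random loss), as you might copy them from the print_tan_stats.py output (whose exact format isn't shown here), and returns the indexes to pass to `--model_indexes`. Whether a lower ratio is better is an assumption exposed as the `lower_is_better` parameter, and the loss values below are made-up numbers for illustration only.

```python
def select_tan_indexes(stats, k=3, lower_is_better=True):
    """Pick the k model indexes with the best generative-to-random loss ratio.

    `stats` maps model index -> (generative_loss, random_loss).
    """
    ratios = {}
    for idx, (gen_loss, rand_loss) in stats.items():
        if rand_loss == 0:
            continue  # ratio undefined; skip rather than divide by zero
        ratios[idx] = gen_loss / rand_loss
    ranked = sorted(ratios, key=ratios.get, reverse=not lower_is_better)
    return sorted(ranked[:k])


# Made-up losses for four hypothetical TAN runs.
stats = {1: (0.42, 0.70), 2: (0.65, 0.66), 5: (0.30, 0.61), 7: (0.38, 0.59)}
best = select_tan_indexes(stats, k=3)
print("--model_indexes " + " ".join(str(i) for i in best))  # --model_indexes 1 5 7
```

A ratio only summarizes the numbers, so it's still worth eyeballing samples from the selected TANs in TensorBoard before spending compute on end models.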