Data Augmentation for Meta-Learning

Abstract

Conventional image classifiers are trained by randomly sampling mini-batches of images. To achieve state-of-the-art performance, sophisticated data augmentation schemes are used to expand the amount of training data available for sampling. In contrast, meta-learning algorithms sample not only images, but classes as well. We investigate how data augmentation can be used not only to expand the number of images available per class, but also to generate entirely new classes. We systematically dissect the meta-learning pipeline and investigate the distinct ways in which data augmentation can be integrated at both the image and class levels. Our proposed meta-specific data augmentation significantly improves the performance of meta-learners on few-shot classification benchmarks.
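As a concrete illustration of class-level (task-level) augmentation, the sketch below shows one way new classes can be manufactured from existing ones: every image of a class is rotated by 90 degrees and the rotated copies are assigned fresh class ids. This is an illustration only, not the repository's implementation; it assumes images are PyTorch tensors of shape (N, C, H, W), and the helper name `make_rotated_classes` is hypothetical.

    import torch

    def make_rotated_classes(images: torch.Tensor, labels: torch.Tensor, num_classes: int):
        """Create new classes by rotating every image of each class by 90 degrees.

        images: (N, C, H, W) tensor, labels: (N,) tensor with values in [0, num_classes).
        Returns an augmented (images, labels, num_classes) triple where each rotated
        copy is treated as a brand-new class.
        """
        # Rotate in the spatial plane (H, W); content is unchanged, but the
        # rotated images are given new class ids so the sampler sees extra classes.
        rotated = torch.rot90(images, k=1, dims=(2, 3))
        new_labels = labels + num_classes  # shift into a fresh block of class ids
        aug_images = torch.cat([images, rotated], dim=0)
        aug_labels = torch.cat([labels, new_labels], dim=0)
        return aug_images, aug_labels, 2 * num_classes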

Dependencies

Usage

Installation

  1. Clone this repository:
    git clone https://github.com/RenkunNi/MetaAug.git
    cd MetaAug
  2. Download and decompress the dataset files: miniImageNet (courtesy of Spyros Gidaris) and CIFAR-FS.

Meta-training Examples

  1. To train with data augmentations (e.g. query CutMix and task-level large rotation; see the sketch after this list) on 5-way CIFAR-FS:

    python train_aug.py --gpu 0 --save-path "./experiments/ResNet_R2D2_qcm_tlr" --train-shot 5 \
    --head R2D2 --network ResNet --dataset CIFAR_FS --query_aug cutmix --q_p 1. --task_aug Rot90 --t_p 0.25
  2. To train Meta-MaxUp (4 samples) on 5-way CIFAR-FS:

    python train_maxup.py --gpu 0,1,2,3 --save-path "./experiments/ResNet_R2D2_maxup_4" --train-shot 5 \
    --head R2D2 --network ResNet --dataset CIFAR_FS --m 4
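
For intuition about the `--query_aug cutmix` option above, here is a minimal sketch of CutMix applied to a batch of query images. It follows the usual CutMix recipe and is not the repository's code; the function name `cutmix_queries` and the `alpha` parameter are assumptions.

    import torch

    def cutmix_queries(images: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0):
        """CutMix for a query batch: paste a random patch from a shuffled copy of
        the batch into each image and mix the labels by the patch-area ratio.

        images: (N, C, H, W) query images, labels: (N,) integer labels.
        Returns mixed images plus (labels_a, labels_b, lam) for a mixed loss.
        """
        n, _, h, w = images.shape
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        perm = torch.randperm(n)

        # Sample a rectangle whose area fraction is roughly (1 - lam).
        cut_ratio = (1.0 - lam) ** 0.5
        cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
        cy, cx = torch.randint(h, (1,)).item(), torch.randint(w, (1,)).item()
        y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
        x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)

        mixed = images.clone()
        mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]

        # Adjust lam to the patch area actually used (clipping may shrink it).
        lam = 1.0 - ((y2 - y1) * (x2 - x1)) / float(h * w)
        return mixed, labels, labels[perm], lam

The corresponding loss is `lam * loss(logits, labels_a) + (1 - lam) * loss(logits, labels_b)`. Meta-MaxUp (the second command) follows the MaxUp idea instead: it draws several augmented versions of each task (here `--m 4`) and backpropagates the largest of the resulting losses.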

Meta-testing Examples

  1. To test models on 5-way N-shot CIFAR-FS:

    python test.py --gpu 0 --load ./experiments/ResNet_R2D2_maxup_4/best_model.pth --episode 1000 \
    --way 5 --shot N --query 15 --head R2D2 --network ResNet --dataset CIFAR_FS
  2. To test models on 5-way N-shot CIFAR-FS with shot augmentation (flip; see the sketch after this list):

    python test.py --gpu 0 --load ./experiments/ResNet_R2D2_maxup_4/best_model.pth --episode 1000 \
    --way 5 --shot N --query 15 --head R2D2 --network ResNet --dataset CIFAR_FS --shot_aug fliplr --s_du 2
  3. To test models on 5-way N-shot CIFAR-FS with shot augmentation (flip) and model ensembling:

    python test_ens.py --gpu 0 --load ./experiments/ResNet_R2D2_maxup_4/ --episode 1000 \
    --way 5 --shot N --query 15 --head R2D2 --network ResNet --dataset CIFAR_FS --shot_aug fliplr --s_du 2
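
Shot augmentation (`--shot_aug fliplr --s_du 2`) enlarges the support set at evaluation time. The sketch below illustrates the idea, assuming support images are a (way * shot, C, H, W) tensor; the helper name `augment_support_with_flip` is hypothetical and not the repository's API.

    import torch

    def augment_support_with_flip(support: torch.Tensor, support_labels: torch.Tensor):
        """Double the number of shots by adding a horizontally flipped copy of
        every support image, keeping its original label.

        support: (way * shot, C, H, W), support_labels: (way * shot,)
        """
        flipped = torch.flip(support, dims=(3,))  # flip left-right along the width axis
        return (torch.cat([support, flipped], dim=0),
                torch.cat([support_labels, support_labels], dim=0))

Doubling the shots this way (`--s_du 2`) gives the classification head twice as many support examples per class without requiring any extra labels.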

Acknowledgments

This code is based on the implementations of Prototypical Networks, Dynamic Few-Shot Visual Learning without Forgetting, DropBlock, MetaOptNet and TaskLevelAug.
