
A simple minimal implementation of Reversible Vision Transformers


karttikeya/minREV


minREV

Inspired by minGPT

A PyTorch reimplementation of the Reversible Vision Transformer architecture that prefers simplicity over tricks, hackability over tedious organization, and interpretability over generality.

It is meant to serve as an educational guide for newcomers who are not familiar with the reversible backpropagation algorithm and Reversible Vision Transformers.

The entire Reversible Vision Transformer is implemented from scratch in fewer than 300 lines of PyTorch code, including the memory-efficient reversible backpropagation algorithm (fewer than 100 lines). Even the driver code is under 150 lines. The repo supports both memory-efficient training and testing on CIFAR-10.
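The property that makes memory-efficient backpropagation possible is that every reversible block can be inverted. Instead of caching a block's inputs, you can recompute them from its outputs. The sketch below illustrates this in plain Python (no PyTorch) using the two-stream coupling Y1 = X1 + F(X2), Y2 = X2 + G(Y1) from the RevNet/RevViT line of work. F and G are hypothetical tanh stand-ins for the attention and MLP sub-blocks. The stream naming is ours and may not match rev.py.

```python
import math
import random

def make_subblock(weight):
    # Hypothetical stand-in for an attention or MLP sub-block; any
    # deterministic function works, it does not need to be invertible itself.
    def sub_block(xs):
        return [math.tanh(weight * x) for x in xs]
    return sub_block

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def sub(a, b):
    return [x - y for x, y in zip(a, b)]

def rev_forward(x1, x2, F, G):
    """Two-stream reversible block: Y1 = X1 + F(X2), Y2 = X2 + G(Y1)."""
    y1 = add(x1, F(x2))
    y2 = add(x2, G(y1))
    return y1, y2

def rev_inverse(y1, y2, F, G):
    """Recover the block inputs from its outputs: X2 = Y2 - G(Y1), X1 = Y1 - F(X2)."""
    x2 = sub(y2, G(y1))
    x1 = sub(y1, F(x2))
    return x1, x2

random.seed(0)
blocks = [(make_subblock(random.uniform(-1, 1)), make_subblock(random.uniform(-1, 1)))
          for _ in range(6)]
x1 = [random.gauss(0, 1) for _ in range(4)]
x2 = [random.gauss(0, 1) for _ in range(4)]

# Forward pass: only the running outputs are kept, no per-block activations.
y1, y2 = x1, x2
for F, G in blocks:
    y1, y2 = rev_forward(y1, y2, F, G)

# Walk back through the stack, rebuilding every block's inputs from its outputs.
r1, r2 = y1, y2
for F, G in reversed(blocks):
    r1, r2 = rev_inverse(r1, r2, F, G)

# Maximum reconstruction error: only floating-point rounding remains.
print(max(abs(a - b) for a, b in zip(r1 + r2, x1 + x2)))
```

The final y1, y2 alone are enough to walk back through all six blocks. This is why the activations cached for the backward pass no longer grow with depth. The reconstruction is exact up to floating-point rounding.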

💥 The CVPR 2022 oral talk for a 5-minute introduction to RevViT.

💥 A gentle and in-depth 15-minute introduction to RevViT.

💥 Additional implementations of reversible MViT and Swin for examples of hierarchical transformers.

💥 New implementations of fast, parallelized reversible backpropagation (PaReprop), featured as a spotlight at the Transformers for Vision workshop at CVPR 2023.
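The idea behind PaReprop is to overlap the extra activation recomputation of reversible training with the gradient computation itself. The sketch below is a simplified, conceptual view in plain Python; fast_rev.py remains the reference. It lists which operations could share a time step in each scheme. The function names are hypothetical, and the step counts ignore whether the GPU can actually run both kinds of work concurrently.

```python
def vanilla_rev_schedule(depth):
    """Sequential reversible backprop (simplified): for each block, last to
    first, reconstruct its inputs and only then compute its gradients."""
    steps = []
    for i in reversed(range(depth)):
        steps.append([f"recompute inputs of block {i}"])
        steps.append([f"gradients of block {i}"])
    return steps

def pareprop_schedule(depth):
    """Overlapped schedule: while block i computes gradients, the inputs of
    block i-1 are reconstructed. Both only need block i's inputs, which were
    rebuilt in the previous step."""
    steps = [[f"recompute inputs of block {depth - 1}"]]
    for i in reversed(range(depth)):
        step = [f"gradients of block {i}"]
        if i > 0:
            step.append(f"recompute inputs of block {i - 1}")
        steps.append(step)
    return steps

for t, ops in enumerate(pareprop_schedule(3)):
    print(t, " || ".join(ops))
```

For a depth of 6, the sequential version takes 12 dependent steps and the overlapped one takes 7. The actual wall-clock gain depends on how well the two streams of work overlap on hardware.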

Setting Up

Simple! 🌟

(If you use conda for your environment; otherwise, use pip.)

conda create -n revvit python=3.8
conda activate revvit
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia

If you also want to use RevSwin and RevMViT, install timm as well.

conda install timm=0.9.2
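Before running anything, you can confirm that the packages from the steps above are importable. The snippet below is a small standard-library sketch; the helper name check_env is hypothetical and not part of the repo. It only checks importability, not versions or CUDA support. timm is needed only for RevSwin and RevMViT.

```python
import importlib.util

def check_env(packages=("torch", "torchvision", "timm")):
    """Map each package name to whether it can be found on the import path."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}

if __name__ == "__main__":
    for name, ok in check_env().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```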

Code Organization

The code organization is also minimal 💫:

  • rev.py defines the reversible vision model that supports:
    • The vanilla backpropagation
    • The memory-efficient reversible backpropagation
  • main.py contains the driver code for training on CIFAR-10. By default, --model vit is enabled.
  • fast_rev.py contains a simplified implementation of fast, parallelized reversible backpropagation (PaReprop). Use --pareprop True to enable.
  • rev_swin.py contains the reversible Swin Transformer with PaReprop functionality. Use --model swin to enable.
  • rev_mvit.py and utils.py contain the reversible MViTv2 with PaReprop functionality. Use --model mvit to enable.
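The boolean switches above are passed as strings (for example, --pareprop True). The sketch below shows one way a driver could parse the flags named in this README. It is a hypothetical reconstruction, not the repo's main.py. Apart from --model vit, the defaults are illustrative assumptions taken from the ViT command in the next section. The str2bool helper is there because argparse's type=bool would turn any non-empty string, including "False", into True.

```python
import argparse

def str2bool(value):
    """Parse 'True'/'False'-style strings into booleans."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

def build_parser():
    # Hypothetical: mirrors the flags documented in this README only.
    p = argparse.ArgumentParser(description="minREV-style CIFAR-10 driver (sketch)")
    p.add_argument("--model", choices=["vit", "swin", "mvit"], default="vit")
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--bs", type=int, default=128)
    p.add_argument("--embed_dim", type=int, default=128)
    p.add_argument("--depth", type=int, default=6)
    p.add_argument("--n_head", type=int, default=8)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--vanilla_bp", type=str2bool, default=False)
    p.add_argument("--pareprop", type=str2bool, default=False)
    p.add_argument("--amp", type=str2bool, default=False)
    return p

args = build_parser().parse_args(["--model", "swin", "--pareprop", "True"])
print(args.model, args.pareprop, args.vanilla_bp)  # swin True False
```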

Running CIFAR-10

We currently provide three model options for reversible training: ViT, Swin, and MViT. In some cases, the architectures have been simplified to run on CIFAR-10.

Reversible ViT 🍦

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit

By default, the --model flag is set to vit. This achieves 80%+ validation accuracy on CIFAR-10 when training from scratch!

Here are the Training/Validation Logs 💯

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 6 --n_head 8 --epochs 100 --model vit --vanilla_bp True

This trains the same network without memory-efficient backpropagation and reaches the same accuracy as above. Hence, the memory-efficient reversible backpropagation causes no accuracy drop.

Here are the Training/Validation Logs 💯
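Why is there no accuracy drop? Reversible backpropagation applies the same chain rule as vanilla backpropagation. The only difference is where each block's inputs come from: they are reconstructed from the outputs instead of being cached. The toy check below compares both approaches on a scalar four-block network in plain Python. It uses hypothetical sub-blocks F(u) = tanh(a*u) and G(u) = tanh(b*u). It illustrates the idea and is not the implementation in rev.py.

```python
import math

# Scalar reversible block with hypothetical sub-blocks F(u) = tanh(a*u) and
# G(u) = tanh(b*u), one (a, b) pair per block:
#   y1 = x1 + F(x2),  y2 = x2 + G(y1)

def forward(x1, x2, params):
    """Forward pass; also returns the per-block inputs that vanilla
    backprop would have to cache."""
    cache = []
    for a, b in params:
        cache.append((x1, x2))
        y1 = x1 + math.tanh(a * x2)
        y2 = x2 + math.tanh(b * y1)
        x1, x2 = y1, y2
    return x1, x2, cache

def block_backward(x1, x2, a, b, dy1, dy2):
    """Chain rule through one block given its inputs and output gradients.
    Returns (dx1, dx2, da, db)."""
    tf = math.tanh(a * x2)
    y1 = x1 + tf
    tg = math.tanh(b * y1)
    # y2 = x2 + G(y1): dy2 reaches x2 directly and y1 through G'(y1).
    db = dy2 * (1 - tg * tg) * y1
    dy1_total = dy1 + dy2 * (1 - tg * tg) * b
    # y1 = x1 + F(x2)
    da = dy1_total * (1 - tf * tf) * x2
    dx1 = dy1_total
    dx2 = dy2 + dy1_total * (1 - tf * tf) * a
    return dx1, dx2, da, db

def vanilla_backward(cache, params, dy1, dy2):
    """Uses cached inputs of every block (memory grows with depth)."""
    grads = []
    for (x1, x2), (a, b) in zip(reversed(cache), reversed(params)):
        dy1, dy2, da, db = block_backward(x1, x2, a, b, dy1, dy2)
        grads.append((da, db))
    return dy1, dy2, grads[::-1]

def reversible_backward(y1, y2, params, dy1, dy2):
    """Keeps only the current outputs and reconstructs each block's inputs."""
    grads = []
    for a, b in reversed(params):
        x2 = y2 - math.tanh(b * y1)  # invert y2 = x2 + G(y1)
        x1 = y1 - math.tanh(a * x2)  # invert y1 = x1 + F(x2)
        dy1, dy2, da, db = block_backward(x1, x2, a, b, dy1, dy2)
        grads.append((da, db))
        y1, y2 = x1, x2
    return dy1, dy2, grads[::-1]

params = [(0.7, -0.4), (1.2, 0.3), (-0.5, 0.9), (0.2, 1.1)]
y1, y2, cache = forward(0.3, -1.1, params)
# Toy loss L = y1 + y2, so dL/dy1 = dL/dy2 = 1.
vanilla = vanilla_backward(cache, params, 1.0, 1.0)
reversible = reversible_backward(y1, y2, params, 1.0, 1.0)
print(vanilla[:2])
print(reversible[:2])  # matches the line above up to floating-point rounding
```

Both paths yield the same input and parameter gradients. The optimizer therefore sees the same updates, and only the memory profile differs.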

Reversible Swin 🐬

python main.py --lr 1e-3 --bs 128 --embed_dim 128 --depth 4 --n_head 4 --epochs 100 --model swin

This achieves 80%+ validation accuracy on CIFAR-10 when training a Reversible Swin from scratch!

Reversible MViT 🏰

python main.py --lr 1e-3 --bs 128 --embed_dim 64 --depth 4 --n_head 1 --epochs 100 --model mvit

This achieves 90%+ validation accuracy on CIFAR-10 when training a Reversible MViT from scratch!

You can find Training/Validation Logs for both Swin and MViT here 💯
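The three CIFAR-10 commands above differ only in model size (embed_dim, depth, n_head). The small helper below is hypothetical and not part of the repo. It collects the documented hyperparameters in one place for side-by-side comparison and regenerates the exact command lines. Extra switches such as vanilla_bp=True are appended in the same "--flag True" form used above.

```python
# Hyperparameters copied from the three CIFAR-10 commands in this README.
CONFIGS = {
    "vit":  {"lr": "1e-3", "bs": 128, "embed_dim": 128, "depth": 6, "n_head": 8, "epochs": 100},
    "swin": {"lr": "1e-3", "bs": 128, "embed_dim": 128, "depth": 4, "n_head": 4, "epochs": 100},
    "mvit": {"lr": "1e-3", "bs": 128, "embed_dim": 64,  "depth": 4, "n_head": 1, "epochs": 100},
}

def build_command(model, **extra_flags):
    """Assemble the main.py command line for `model` plus optional extra flags."""
    args = ["python", "main.py"]
    for key, value in CONFIGS[model].items():
        args += [f"--{key}", str(value)]
    args += ["--model", model]
    for key, value in extra_flags.items():
        args += [f"--{key}", str(value)]
    return " ".join(args)

print(build_command("vit", vanilla_bp=True))
```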

👁️ Note: The relatively low accuracy is due to the difficulty of training vision transformers (reversible or vanilla) from scratch on small datasets like CIFAR-10. A much higher accuracy can likely be achieved with the same code through a better choice of model design and optimization parameters. The authors have done no tuning, since this repository is meant for understanding the code, not for pushing performance.

Mixed precision training

Mixed precision training is also supported and can be enabled by adding the --amp True flag to the commands above. Training progresses smoothly and achieves 80%+ validation accuracy on CIFAR-10, similar to training without AMP.

📝 Note: PyTorch's vanilla AMP keeps the weights in full precision (fp32) and uses half precision (fp16) only for intermediate activations. Reversible backpropagation already avoids storing almost all intermediate activations (see the video for an explanation). Using AMP (i.e., half precision on activations) therefore brings little additional memory savings. For example, on a 16 GB V100, AMP increases the maximum CIFAR-10 batch size of the reversible model from 12000 to 14500 (~20%). At the usual training batch size (128), the reduction in GPU training memory is small (about 4%).
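The "~20%" figure follows directly from the two batch sizes quoted in the note. A quick check:

```python
def relative_gain(before, after):
    """Fractional increase from `before` to `after`."""
    return (after - before) / before

# Maximum reversible CIFAR-10 batch size on a 16 GB V100, without and with AMP.
gain = relative_gain(12000, 14500)
print(f"{gain:.1%}")  # 20.8%
```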

Distributed Data Parallel Training

DDP training with reversible backpropagation adds no extra overhead and progresses the same as vanilla training. All results in the paper (see also below) were obtained with DDP setups (>64 GPUs per run). However, implementing distributed training is beyond the purpose of this repo; it can instead be found in the PySlowFast distributed training setup.

Running ImageNet, Kinetics-400 and more

For more use cases, such as reproducing the numbers from the original paper, see the full code in PySlowFast, which supports the following with state-of-the-art accuracies:

  • ImageNet
  • Kinetics-400/600/700
  • RevViT, all sizes with configs
  • RevMViT, all sizes with configs
