Code for DeepCompress-ViT: Rethinking Model Compression to Enhance Efficiency of Vision Transformers at the Edge
Requirements:
- Python 3.8.19
- PyTorch 2.3.1
- torchvision 0.18.1
- timm 1.0.7
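
A minimal requirements.txt consistent with the versions listed above would contain the pinned packages below (the file shipped with the repository may include additional dependencies):

torch==2.3.1
torchvision==0.18.1
timm==1.0.7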
Install dependencies:

pip install -r requirements.txt

Training:

# Train a compressed DeiT-Small backbone (rank 277)
python main.py \
--model_name deit_small_patch16_224 \
--batch_size 256 \
--epochs 275 \
--eval_interval 20000 \
--mixed_precision \
--device cuda:0 \
--base_dir small_rank_277 \
--distillation_weight 3000 \
--initial_iters 1000 \
--finetune_other_params \
--rank 277 \
--distilled_model
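
The --rank argument controls the compression level. As an illustration only (an assumption about the general low-rank technique, not the repository's actual implementation), a rank-r factorization replaces a dense d_out x d_in weight matrix with two rank-r factors, which reduces the stored parameters when r is small enough:

import torch.nn as nn

class LowRankLinear(nn.Module):
    """Hypothetical sketch: approximate a dense weight W (d_out x d_in)
    by two factors U (d_out x r) and V (r x d_in). Not the repo's module."""
    def __init__(self, d_in, d_out, rank, bias=True):
        super().__init__()
        self.V = nn.Linear(d_in, rank, bias=False)  # r x d_in factor
        self.U = nn.Linear(rank, d_out, bias=bias)  # d_out x r factor

    def forward(self, x):
        return self.U(self.V(x))

# For a single DeiT-Small MLP projection (384 -> 1536) at rank 277:
# dense:    384 * 1536         = 589,824 weights
# low-rank: 277 * (384 + 1536) = 531,840 weights
layer = LowRankLinear(384, 1536, rank=277)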
# Train a compressed DeiT-Base backbone (rank 502)
python main.py \
--model_name deit_base_patch16_224 \
--batch_size 256 \
--epochs 275 \
--eval_interval 20000 \
--mixed_precision \
--device cuda:0 \
--base_dir base_rank_502 \
--distillation_weight 3000 \
--initial_iters 1000 \
--finetune_other_params \
--rank 502 \
--distilled_model

Download our compressed checkpoints from Google Drive:
The files should be placed in the following directory structure:
saved_models/
├── small_rank_277/
│   └── deit_small_patch16_224.pth
└── base_rank_502/
    └── deit_base_patch16_224.pth
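
After placing the files, you can optionally sanity-check that a checkpoint loads and inspect its stored tensors before running inference. This is a generic PyTorch snippet, not part of the repository; it assumes the .pth file holds a state dict (or a dict wrapping one under a "state_dict" key):

import torch

# Load the compressed DeiT-Small checkpoint on CPU and list its tensors.
ckpt = torch.load("saved_models/small_rank_277/deit_small_patch16_224.pth",
                  map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
for name, value in state_dict.items():
    shape = tuple(value.shape) if hasattr(value, "shape") else value
    print(name, shape)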
# Inference with small model
python inference.py \
--model_name deit_small_patch16_224 \
--batch_size 256 \
--device cuda:0 \
--rank 277 \
--mixed_precision \
--state_path saved_models/small_rank_277/deit_small_patch16_224.pth
# Inference with base model
python inference.py \
--model_name deit_base_patch16_224 \
--batch_size 256 \
--device cuda:0 \
--rank 502 \
--mixed_precision \
--state_path saved_models/base_rank_502/deit_base_patch16_224.pth

CIFAR-10 Evaluation:
Evaluate our models on the CIFAR-10 dataset (resized to 224x224) using our pre-trained checkpoints.
Download Pre-trained Models:
- DeiT-Small (uncompressed, 97.29% Top-1 Acc)
- DeepCompress-ViT-DeiT-S (compressed rank 277, 96.73% Top-1 Acc)
Directory Structure:
saved_models/
├── small_rank_277_cifar10/
│   └── deit_small_patch16_224.pth
└── deit_small_cifar10_best.pth
Run Evaluation:
# Evaluate compressed DeiT-Small on CIFAR-10
python inference.py \
--model_name deit_small_patch16_224 \
--batch_size 256 \
--device cuda:0 \
--rank 277 \
--mixed_precision \
--state_path saved_models/small_rank_277_cifar10/deit_small_patch16_224.pth \
--dataset cifar10
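
The DeiT backbones expect 224x224 inputs, so the 32x32 CIFAR-10 images must be resized and normalized before evaluation. The snippet below shows a typical torchvision preprocessing pipeline for this; the repository's inference.py may use slightly different normalization statistics:

import torch
from torchvision import datasets, transforms

# Resize 32x32 CIFAR-10 images to the 224x224 resolution expected by DeiT
# and normalize with ImageNet statistics (a common choice for timm/DeiT models).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

test_set = datasets.CIFAR10(root="data", train=False, download=True,
                            transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)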