ML-Security-Research-LAB/DeepCompress-ViT

# DeepCompress-ViT (CVPR 2025)

Code for the paper *DeepCompress-ViT: Rethinking Model Compression to Enhance Efficiency of Vision Transformers at the Edge*.

## Requirements

- Python 3.8.19
- PyTorch 2.3.1
- torchvision 0.18.1
- timm 1.0.7

Install the dependencies with:

```bash
pip install -r requirements.txt
```

## 1. Training

```bash
# Train a compressed DeiT-Small backbone (rank 277)
python main.py \
  --model_name deit_small_patch16_224 \
  --batch_size 256 \
  --epochs 275 \
  --eval_interval 20000 \
  --mixed_precision \
  --device cuda:0 \
  --base_dir small_rank_277 \
  --distillation_weight 3000 \
  --initial_iters 1000 \
  --finetune_other_params \
  --rank 277 \
  --distilled_model

# Train a compressed DeiT-Base backbone (rank 502)
python main.py \
  --model_name deit_base_patch16_224 \
  --batch_size 256 \
  --epochs 275 \
  --eval_interval 20000 \
  --mixed_precision \
  --device cuda:0 \
  --base_dir base_rank_502 \
  --distillation_weight 3000 \
  --initial_iters 1000 \
  --finetune_other_params \
  --rank 502 \
  --distilled_model
```
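The `--rank` flag sets the rank of the compressed weight factors. As a rough illustration of rank-based compression (this is a generic truncated-SVD sketch, not the repository's actual factorization scheme; the layer sizes below are the standard DeiT-Small MLP dimensions):

```python
import torch
import torch.nn as nn

def low_rank_factorize(linear: nn.Linear, rank: int) -> nn.Sequential:
    """Replace a dense linear layer with two smaller ones via truncated SVD.

    For a weight W of shape (out, in), keep the top-`rank` singular triplets
    so W ~= (U * S) @ Vt, stored as two nn.Linear layers with `rank` units.
    """
    W = linear.weight.data                      # (out_features, in_features)
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    U_r = U[:, :rank] * S[:rank]                # (out, rank)
    V_r = Vt[:rank, :]                          # (rank, in)

    first = nn.Linear(linear.in_features, rank, bias=False)
    second = nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
    first.weight.data = V_r
    second.weight.data = U_r
    if linear.bias is not None:
        second.bias.data = linear.bias.data.clone()
    return nn.Sequential(first, second)

# DeiT-Small MLP up-projection (384 -> 1536), rank 277 as in the command above.
dense = nn.Linear(384, 1536)
factored = low_rank_factorize(dense, rank=277)

dense_params = sum(p.numel() for p in dense.parameters())
factored_params = sum(p.numel() for p in factored.parameters())
print(dense_params, factored_params)  # the factored layer has fewer parameters
```

The parameter saving grows as the rank shrinks relative to the layer's dimensions, at the cost of a larger approximation error.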

## 2. Download Pre-trained Weights

Download our compressed checkpoints from Google Drive and place them in the following directory structure:

```
saved_models/
├── small_rank_277/
│   └── deit_small_patch16_224.pth
└── base_rank_502/
    └── deit_base_patch16_224.pth
```
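A downloaded checkpoint can be sanity-checked before running inference. This is a generic PyTorch sketch; the exact key layout inside the `.pth` files depends on the repository's saving code:

```python
import torch

def inspect_checkpoint(path: str, n: int = 5) -> dict:
    """Print the first n parameter names and shapes in a saved checkpoint."""
    state = torch.load(path, map_location="cpu")  # CPU so no GPU is needed
    # Checkpoints may be a bare state_dict or a dict wrapping one.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    for name, tensor in list(state.items())[:n]:
        print(name, tuple(tensor.shape))
    return state

# Example:
# inspect_checkpoint("saved_models/small_rank_277/deit_small_patch16_224.pth")
```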

## 3. Inference

```bash
# Inference with the small model
python inference.py \
  --model_name deit_small_patch16_224 \
  --batch_size 256 \
  --device cuda:0 \
  --rank 277 \
  --mixed_precision \
  --state_path saved_models/small_rank_277/deit_small_patch16_224.pth

# Inference with the base model
python inference.py \
  --model_name deit_base_patch16_224 \
  --batch_size 256 \
  --device cuda:0 \
  --rank 502 \
  --mixed_precision \
  --state_path saved_models/base_rank_502/deit_base_patch16_224.pth
```
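For intuition about the rank values, a two-factor decomposition of a `d_in × d_out` weight matrix only saves parameters when the rank is below `d_in·d_out / (d_in + d_out)`. The break-even ranks for the standard DeiT layer widths can be computed directly (how the repository distributes the rank across layers is not specified here, so treat this as a back-of-the-envelope check):

```python
def break_even_rank(d_in: int, d_out: int) -> float:
    """Rank below which two factors (d_in x r and r x d_out) hold fewer
    parameters than the dense d_in x d_out weight matrix."""
    return d_in * d_out / (d_in + d_out)

# DeiT-Small: embed dim 384, MLP hidden dim 1536
print(break_even_rank(384, 384))    # 192.0  (square projections)
print(break_even_rank(384, 1536))   # 307.2  (MLP projections)

# DeiT-Base: embed dim 768, MLP hidden dim 3072
print(break_even_rank(768, 768))    # 384.0
print(break_even_rank(768, 3072))   # 614.4
```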

## 4. Optional CIFAR-10 Evaluation

Evaluate our models on the CIFAR-10 dataset (resized to 224×224) using our pre-trained checkpoints.

Download the pre-trained models and arrange them as follows:

```
saved_models/
├── small_rank_277_cifar10/
│   └── deit_small_patch16_224.pth
└── deit_small_cifar10_best.pth
```

Run the evaluation:

```bash
# Evaluate compressed DeiT-Small on CIFAR-10
python inference.py \
  --model_name deit_small_patch16_224 \
  --batch_size 256 \
  --device cuda:0 \
  --rank 277 \
  --mixed_precision \
  --state_path saved_models/small_rank_277_cifar10/deit_small_patch16_224.pth \
  --dataset cifar10
```
