CAS-ViT: Convolutional Additive Self-attention Vision Transformers for Efficient Mobile Applications
📌 Official Implementation of our proposed method CAS-ViT.
Comparison of diverse self-attention mechanisms. (a) is the classical multi-head self-attention in ViT. (b) is the separable self-attention in MobileViTv2, which reduces the attention map from a matrix to a vector. (c) is the efficient additive attention in SwiftFormer, which builds feature associations using only Q and K. (d) is our proposed convolutional additive self-attention.
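For intuition, the sketch below shows the additive idea behind (d): rather than forming a softmax(QKᵀ) matrix, Q and K are each passed through a lightweight convolutional context mapping, the two context maps are summed, and the result modulates V, so the cost stays linear in the number of tokens. The layer choices here (depthwise 3x3 convolutions with sigmoid gates) are illustrative assumptions, not the released implementation; see the paper and the model code for the exact formulation.

import torch
import torch.nn as nn

class ConvAdditiveSelfAttentionSketch(nn.Module):
    """Illustrative sketch of a convolutional additive token mixer.

    Q and K each go through a lightweight convolutional context mapping;
    the two context maps are combined additively (no QK^T matrix) and used
    to modulate V. The exact gating in CAS-ViT may differ -- this is a
    sketch, not the released module.
    """

    def __init__(self, dim):
        super().__init__()
        self.q = nn.Conv2d(dim, dim, 1)
        self.k = nn.Conv2d(dim, dim, 1)
        self.v = nn.Conv2d(dim, dim, 1)
        # Assumed context mappings: depthwise 3x3 conv + sigmoid gate.
        self.ctx_q = nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1, groups=dim), nn.Sigmoid())
        self.ctx_k = nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1, groups=dim), nn.Sigmoid())
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):  # x: (B, C, H, W)
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn = self.ctx_q(q) + self.ctx_k(k)  # additive similarity, no attention matrix
        return self.proj(attn * v)

if __name__ == "__main__":
    y = ConvAdditiveSelfAttentionSketch(64)(torch.randn(1, 64, 14, 14))
    print(y.shape)  # torch.Size([1, 64, 14, 14])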
Upper: Illustration of the classification backbone network; four stages downsample the original image to 1/4, 1/8, 1/16, and 1/32 of its resolution. Lower: Block architecture with N$_i$ blocks stacked in each stage.
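The upper figure can be summarized as the skeleton below: a stem reduces the input to 1/4 resolution and each later stage halves it again while stacking N$_i$ blocks, yielding features at strides 4, 8, 16, and 32. The channel widths, depths, and placeholder block contents are assumptions for illustration only; the real configurations are those in the Model Zoo.

import torch
import torch.nn as nn

def _stage(dim_in, dim_out, depth, downsample):
    """One stage: optional 2x downsampling followed by `depth` placeholder blocks."""
    layers = [nn.Conv2d(dim_in, dim_out, 3, stride=2, padding=1)] if downsample else []
    layers += [nn.Sequential(
        nn.Conv2d(dim_out, dim_out, 3, padding=1, groups=dim_out),  # token mixer (placeholder)
        nn.Conv2d(dim_out, dim_out, 1),                             # channel MLP (placeholder)
    ) for _ in range(depth)]
    return nn.Sequential(*layers)

class FourStageBackboneSketch(nn.Module):
    def __init__(self, dims=(48, 96, 192, 384), depths=(2, 2, 4, 2)):
        super().__init__()
        # Stem: two stride-2 convs bring the input to 1/4 resolution.
        self.stem = nn.Sequential(
            nn.Conv2d(3, dims[0], 3, stride=2, padding=1),
            nn.Conv2d(dims[0], dims[0], 3, stride=2, padding=1),
        )
        self.stages = nn.ModuleList([
            _stage(dims[0], dims[0], depths[0], downsample=False),  # 1/4
            _stage(dims[0], dims[1], depths[1], downsample=True),   # 1/8
            _stage(dims[1], dims[2], depths[2], downsample=True),   # 1/16
            _stage(dims[2], dims[3], depths[3], downsample=True),   # 1/32
        ])

    def forward(self, x):
        x = self.stem(x)
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)  # strides 4, 8, 16, 32 relative to the input
        return feats

if __name__ == "__main__":
    feats = FourStageBackboneSketch()(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # [(1, 48, 56, 56), (1, 96, 28, 28), (1, 192, 14, 14), (1, 384, 7, 7)]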
You can download the pretrained weights and configs from Model Zoo.
torch==1.8.0
torchvision==0.9.1
timm==0.5.4
mmcv-full==1.5.3
mmdet==2.24
mmsegmentation==0.24
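A quick way to confirm your environment matches the pinned versions above (the mm* packages are only needed for the detection and segmentation experiments):

# Quick environment check against the versions pinned above.
import torch, torchvision, timm

print("torch", torch.__version__)              # expected: 1.8.0
print("torchvision", torchvision.__version__)  # expected: 0.9.1
print("timm", timm.__version__)                # expected: 0.5.4

# mmcv-full / mmdet / mmsegmentation are only needed for detection / segmentation.
for pkg in ("mmcv", "mmdet", "mmseg"):
    try:
        print(pkg, __import__(pkg).__version__)
    except ImportError:
        print(pkg, "not installed")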
Download the ImageNet-1K dataset and organize it as follows:
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── ILSVRC2012_val_00000293.JPEG
│ ├── ILSVRC2012_val_00002138.JPEG
│ ├── ......
Images are loaded from the list file ./classification/data/imagenet1k/train.txt.
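Because images are read from a list file rather than by scanning the folder tree, the .txt lists must exist before training. The sketch below generates train.txt from the layout above; the assumed line format '<relative path> <class index>' is an illustration only, so verify it against the dataset code under ./classification.

import os

def write_image_list(split_dir, out_path):
    """Write one '<relative path> <class index>' line per image under split_dir.

    NOTE: the list format here is an assumption for illustration; check the
    dataset loader under ./classification for the format it actually expects.
    """
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    with open(out_path, "w") as f:
        for c in classes:
            for name in sorted(os.listdir(os.path.join(split_dir, c))):
                if name.lower().endswith((".jpeg", ".jpg", ".png")):
                    f.write(f"{c}/{name} {class_to_idx[c]}\n")

if __name__ == "__main__":
    # Example paths -- adjust to your ImageNet location.
    write_image_list("imagenet/train", "./classification/data/imagenet1k/train.txt")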
Download the pretrained weights from the Model Zoo and run the following command to evaluate on the ImageNet-1K dataset.
MODEL=rcvit_m # model to evaluate: rcvit_{xs, s, m, t}
python main.py --model ${MODEL} --eval True --resume <path to checkpoint> --input_size 384 --data_path <path to imagenet>
The CAS-ViT-M checkpoint should give:
* Acc@1 81.430 Acc@5 95.664 loss 0.907
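To poke at a checkpoint outside of main.py, something like the sketch below can help. The checkpoint key names, the example file paths, and the preprocessing details (bicubic resize + center crop at 384) are assumptions; main.py remains the authoritative evaluation entry point.

import torch
from PIL import Image
from torchvision import transforms

# 1) Inspect the downloaded checkpoint. Key names such as 'model' or
#    'model_ema' vary between repos, so print them before deciding what to load.
ckpt = torch.load("cas_vit_m.pth", map_location="cpu")  # path is an example
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))

# 2) Common ImageNet eval preprocessing at the 384 input size used above
#    (bicubic resize + center crop; the repo's exact eval transform may differ).
preprocess = transforms.Compose([
    transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
x = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)

# 3) Build rcvit_m the same way main.py does (left as a placeholder here),
#    load the state dict, and run a forward pass:
# model = <create rcvit_m as in main.py>
# model.load_state_dict(ckpt["model"])
# model.eval()
# with torch.no_grad():
#     print(model(x).softmax(-1).topk(5))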
On a single machine with 8 GPUs, run the following command to train:
python -m torch.distributed.launch --nproc_per_node 8 main.py \
--data_path <path to imagenet> \
--output_dir <output dir> \
--model rcvit_m \
--lr 6e-3 --batch_size 128 --drop_path 0.1 \
--model_ema True --model_ema_eval True \
--use_amp True --multi_scale_sampler
On a single machine with 8 GPUs, run the following command to fine-tune:
python -m torch.distributed.launch --nproc_per_node 8 main.py \
--data_path <path to imagenet> \
--output_dir <output dir> \
--finetune <path to model weights> \
--input_size 384 --epoch 30 --batch_size 64 \
--lr 5e-5 --min_lr 5e-5 --weight_decay 0.05 \
--drop_path 0 --model_ema True \
--model_ema_eval True --use_amp True \
--auto_resume False --multi_scale_sampler
Prepare COCO according to the guidelines in MMDetection.
To evaluate CAS-ViT + RetinaNet on COCO val 2017 on a single machine with 8 GPUs, run the following command:
python -m torch.distributed.launch --nproc_per_node 8 test.py \
<config path> \
<checkpoint file> \
--launcher pytorch
To train CAS-ViT-M + RetinaNet on COCO train2017 on a single machine with 8 GPUs, run the following command:
python -m torch.distributed.launch --nproc_per_node 8 train.py \
<config path> --launcher pytorch
Prepare ADE20K according to the guidelines in MMSegmentation.
To evaluate CAS-ViT + Semantic FPN on ADE20K on a single machine with 8 GPUs, run the following command:
python -m torch.distributed.launch --nproc_per_node 8 tools/test.py \
<config path> \
<checkpoint file> \
--launcher pytorch
To train CAS-ViT-M + Semantic FPN on ADE20K on a single machine with 8 GPUs, run the following command:
python -m torch.distributed.launch --nproc_per_node 8 tools/train.py \
<config path> --launcher pytorch
@article{zhang2024cas,
title={CAS-ViT: Convolutional Additive Self-attention Vision Transformers for Efficient Mobile Applications},
author={Zhang, Tianfang and Li, Lei and Zhou, Yang and Liu, Wentao and Qian, Chen and Ji, Xiangyang},
journal={arXiv preprint arXiv:2408.03703},
year={2024}
}
Our code is built on ConvNeXt, EdgeNeXt, PoolFormer, MMDetection, and MMSegmentation. Thanks for their public repositories and excellent contributions!