Official PyTorch Implementation of Self-emerging Token Labeling

Self-emerging Token Labeling (STL)

Fully Attentional Networks with Self-emerging Token Labeling
Bingyin Zhao, Zhiding Yu, Shiyi Lan, Yutao Cheng, Anima Anandkumar, Yingjie Lao and Jose M. Alvarez.
International Conference on Computer Vision, 2023.

Install

  1. Clone this repository and navigate to the STL folder:
git clone https://github.com/NVlabs/STL
cd STL
  2. Create a conda environment and install the dependencies:
conda create -n STL python=3.10 -y
conda activate STL
pip install --upgrade pip
pip install -r requirements.txt

Dataset

Download the clean ImageNet dataset and the ImageNet-C dataset, and organize them as follows:

/path/to/imagenet-C/
  clean/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
  corruption1/
    severity1/
      class1/
        img3.jpeg
      class2/
        img4.jpeg
    severity2/
      class1/
        img3.jpeg
      class2/
        img4.jpeg

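Before training or evaluating, it can help to sanity-check that the ImageNet-C root matches the layout above. The helper below is a hypothetical sketch: check_imagenet_c_layout is not part of the STL repository. It assumes the severity folders are named severity1, severity2, ... as in the tree. It prints how many severity levels each corruption folder contains and fails if clean/ is missing.

```shell
#!/usr/bin/env bash
# Hypothetical helper, not part of the STL repo: sanity-check an ImageNet-C root
# laid out as clean/<class>/... and <corruption>/severityN/<class>/...
check_imagenet_c_layout() {
  local root="$1" dir name n
  # The clean (uncorrupted) images must sit under clean/
  if [ ! -d "$root/clean" ]; then
    echo "missing directory: $root/clean" >&2
    return 1
  fi
  for dir in "$root"/*/; do
    name="$(basename "$dir")"
    [ "$name" = "clean" ] && continue
    # Count the severity folders (severity1, severity2, ...) of this corruption
    n=$(find "$dir" -mindepth 1 -maxdepth 1 -type d -name 'severity*' | wc -l)
    echo "$name: $((n)) severities"
  done
}

# Usage (placeholder path):
# check_imagenet_c_layout /path/to/imagenet-C
```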
For the other out-of-distribution benchmarks, we evaluate on ImageNet-A and ImageNet-R.

Pre-trained Models

FAN token labelers

| Model | #Params | Download |
| --- | --- | --- |
| FAN-T-Hybrid-Teacher | 8.0M | model |
| FAN-S-Hybrid-Teacher | 26.5M | model |
| FAN-B-Hybrid-Teacher | 50.9M | model |
| FAN-L-Hybrid-Teacher | 77.3M | model |

Student models trained with STL (Image resolution: 224x224)

| Model | Model Name | IN-1K | IN-C | IN-A | IN-R | #Params | Download |
| --- | --- | --- | --- | --- | --- | --- | --- |
| FAN-T-Hybrid-Student | fan_tiny_8_p4_hybrid_token | 79.9 | 58.2 | 23.7 | 42.5 | 8.0M | model |
| FAN-S-Hybrid-Student | fan_small_12_p4_hybrid_token | 83.4 | 65.5 | 38.2 | 51.8 | 26.5M | model |
| FAN-B-Hybrid-Student | fan_base_16_p4_hybrid_token | 84.5 | 68.2 | 42.6 | 55.3 | 50.9M | model |
| FAN-L-Hybrid-Student | fan_large_16_p4_hybrid_token | 84.7 | 68.8 | 46.1 | 56.6 | 77.3M | model |

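The --model flag in the training commands takes one of the model names listed in the student table. As a hypothetical convenience, a size letter can be mapped to that name. fan_model_name is not part of the repository; the mapping simply copies the table.

```shell
# Hypothetical helper, not part of the STL repo: map a size letter to the
# model name passed to --model, following the student table above.
fan_model_name() {
  case "$1" in
    T) echo fan_tiny_8_p4_hybrid_token ;;
    S) echo fan_small_12_p4_hybrid_token ;;
    B) echo fan_base_16_p4_hybrid_token ;;
    L) echo fan_large_16_p4_hybrid_token ;;
    *) echo "unknown size: $1 (expected T, S, B or L)" >&2; return 1 ;;
  esac
}

# Usage: model_name=$(fan_model_name S)   # fan_small_12_p4_hybrid_token
```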
Training

STL has two phases. In the first phase, we train a FAN token labeler to produce semantically meaningful token labels. In the second phase, we train a student model jointly on the self-emerging token labels generated by the FAN token labeler and the original image-level labels. Pre-trained token labelers are provided in Pre-trained Models, so to save time you can skip the first phase and train student models directly with them.

Train FAN token labelers

Train a FAN token labeler on a single node with 8 GPUs:

python3 -m torch.distributed.launch --nproc_per_node=8 \
train_token_labeler.py /PATH/TO/IMAGENET/ --model fan_small_12_p4_hybrid_token -b 128 --sched cosine --epochs 350 \
--opt adamw -j 16 --warmup-lr 1e-6 --warmup-epochs 5 \
--model-ema-decay 0.99992 --aa rand-m9-mstd0.5-inc1 --remode pixel \
--reprob 0.3 --lr 20e-4 --min-lr 1e-6 --weight-decay .05 --drop 0.0 \
--drop-path .25 --img-size 224 --mixup 0.8 --cutmix 1.0 \
--smoothing 0.1 \
--output /PATH/TO/SAVE/CKPT/ \
--amp --model-ema \
--token-label --cls-with-single-token-label

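torch.distributed.launch starts one process per GPU. Assuming -b is the per-process (per-GPU) batch size, as in the timm training scripts this repository builds on, the global batch size is b × nproc_per_node × nnodes. The hypothetical helper below (not part of the repo) does that arithmetic. For the single-node command above it gives 128 × 8 × 1 = 1024.

```shell
# Hypothetical helper: global batch size = per-GPU batch (-b)
#   x processes per node (--nproc_per_node) x number of nodes (--nnodes).
# Assumes -b is per GPU, as in timm-style training scripts.
global_batch_size() {
  local per_gpu="$1" gpus_per_node="$2" nodes="${3:-1}"
  echo $(( per_gpu * gpus_per_node * nodes ))
}

global_batch_size 128 8      # single-node command: prints 1024
global_batch_size 64 8 4     # multi-node command with 4 nodes (example): prints 2048
```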
Train a FAN token labeler on multiple nodes with 8 GPUs each:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=$rank_num \
	--node_rank=$rank_index --master_addr="$ip_address" --master_port=3349 \
	train_token_labeler.py  /PATH/TO/IMAGENET/ --model fan_base_16_p4_hybrid_token -b 64 --sched cosine --epochs 350 \
	--opt adamw -j 16 --warmup-lr 1e-6 --warmup-epochs 10  \
	--model-ema-decay 0.99992 --aa rand-m9-mstd0.5-inc1 --remode pixel \
	--reprob 0.3 --lr 40e-4 --min-lr 1e-6 --weight-decay .05 --drop 0.0 \
	--drop-path .35 --img-size 224 --mixup 0.8 --cutmix 1.0 \
	--smoothing 0.1 \
	--output /PATH/TO/SAVE/CKPT/ \
	--amp --model-ema \
	--token-label --cls-with-single-token-label

Or you can run the scripts directly:

cd scripts/fan_token_labeler/
bash fan_small_tl.sh

More details and scripts can be found in the folder scripts/fan_token_labeler/.
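The multi-node commands read three shell variables:

- $rank_num is passed to --nnodes, the number of nodes.
- $rank_index is passed to --node_rank, the index of the current node.
- $ip_address is passed to --master_addr.

The sketch below is a hypothetical pre-flight check, not part of the repository. It assumes the usual torch.distributed convention that node 0 is the master and that every node uses its address.

```shell
# Hypothetical pre-flight check for the variables used by the multi-node commands.
check_dist_env() {
  if [ -z "${rank_num:-}" ] || [ -z "${rank_index:-}" ] || [ -z "${ip_address:-}" ]; then
    echo "set rank_num, rank_index and ip_address first" >&2
    return 1
  fi
  # Both values must be non-negative integers
  case "$rank_num$rank_index" in
    *[!0-9]*) echo "rank_num and rank_index must be non-negative integers" >&2; return 1 ;;
  esac
  # --node_rank must lie in [0, --nnodes)
  if [ "$rank_index" -ge "$rank_num" ]; then
    echo "rank_index must be in [0, $rank_num)" >&2
    return 1
  fi
  echo "node $rank_index of $rank_num, master $ip_address:3349"
}

# Example for a two-node job (placeholder address); run on each node with its own index:
#   node 0 (master): rank_num=2 rank_index=0 ip_address=10.0.0.1
#   node 1:          rank_num=2 rank_index=1 ip_address=10.0.0.1
```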

Train student models with STL

Train a student model with STL on a single node with 8 GPUs:

python3 -m torch.distributed.launch --nproc_per_node=8 \
main.py  /PATH/TO/IMAGENET/ --model fan_small_12_p4_hybrid_token -b 128 --sched cosine --epochs 350 \
--opt adamw -j 16 --warmup-lr 1e-6 --warmup-epochs 5  \
--model-ema-decay 0.99992 --aa rand-m9-mstd0.5-inc1 --remode pixel \
--reprob 0.3 --lr 20e-4 --min-lr 1e-6 --weight-decay .05 --drop 0.0 \
--drop-path .25 --img-size 224 --mixup 0.8 --cutmix 1.0 \
--smoothing 0.1 \
--output  /PATH/TO/SAVE/CKPT/ \
--amp --model-ema \
--token-label --cls-weight 1.0 --dense-weight 1.0 \
--offline-model /PATH/TO/LOAD/CKPT/fan_token_labeler.pth.tar

Train a student model with STL on multiple nodes with 8 GPUs each:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=$rank_num \
	--node_rank=$rank_index --master_addr="$ip_address" --master_port=3349 \
	main.py  /PATH/TO/IMAGENET/ --model fan_base_16_p4_hybrid_token -b 64 --sched cosine --epochs 350 \
	--opt adamw -j 16 --warmup-lr 1e-6 --warmup-epochs 10  \
	--model-ema-decay 0.99992 --aa rand-m9-mstd0.5-inc1 --remode pixel \
	--reprob 0.3 --lr 40e-4 --min-lr 1e-6 --weight-decay .05 --drop 0.0 \
	--drop-path .35 --img-size 224 --mixup 0.8 --cutmix 1.0 \
	--smoothing 0.1 \
	--output /PATH/TO/SAVE/CKPT/ \
	--amp --model-ema \
	--token-label  --cls-weight 1.0 --dense-weight 1.0 \
	--offline-model /PATH/TO/LOAD/CKPT/fan_token_labeler.pth.tar

Or you can run the scripts directly:

cd scripts/fan_stl_student/
bash fan_small_stl.sh

More details and scripts can be found in the folder scripts/fan_stl_student/.
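The student commands set --cls-weight 1.0 --dense-weight 1.0. Reading the flag names, a plausible form of the student objective is a weighted sum of two terms: an image-level classification loss, and a dense token-level loss averaged over the N patch tokens. This is an assumption made for illustration, not a formula taken from the STL code or paper.

```latex
\mathcal{L} = w_{\text{cls}} \, H\!\left(y, \hat{y}_{\text{cls}}\right)
  + w_{\text{dense}} \, \frac{1}{N} \sum_{i=1}^{N} H\!\left(t_i, \hat{y}_i\right)
```

The symbols are:

- $y$ is the image-level label.
- $t_i$ is the token label the token labeler produces for patch token $i$.
- $\hat{y}_{\text{cls}}$ and $\hat{y}_i$ are the student's class-token and patch-token predictions.
- $H$ is a cross-entropy loss.
- $w_{\text{cls}}$ and $w_{\text{dense}}$ are the two flag values.

With both weights at 1.0, the two terms would count equally.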

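Evaluation

In the commands below, $model_name is a model name from the student table (e.g., fan_small_12_p4_hybrid_token) and $ckpt is the path to the checkpoint to evaluate.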

Evaluation on ImageNet-1K and ImageNet-C

bash scripts/imagenet_c_val.sh $model_name $ckpt

Evaluation on ImageNet-A

bash scripts/imagenet_a_val.sh $model_name $ckpt

Evaluation on ImageNet-R

bash scripts/imagenet_r_val.sh $model_name $ckpt
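The three evaluation scripts take the same two arguments, so one checkpoint can be evaluated on every benchmark in a loop. eval_all is a hypothetical wrapper, not part of the repository. It assumes it is run from the repository root, and its DRY_RUN switch is an illustrative addition that only prints the commands.

```shell
# Hypothetical wrapper (not part of the repo): run all three evaluation scripts
# for one model/checkpoint pair from the repository root.
# Set DRY_RUN=1 to print the commands instead of executing them.
eval_all() {
  local model_name="$1" ckpt="$2" bench
  local -a cmd
  for bench in c a r; do
    # imagenet_c_val.sh covers both ImageNet-1K and ImageNet-C
    cmd=(bash "scripts/imagenet_${bench}_val.sh" "$model_name" "$ckpt")
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "${cmd[*]}"
    else
      "${cmd[@]}" || return 1
    fi
  done
}

# Usage (placeholder checkpoint path):
# eval_all fan_small_12_p4_hybrid_token /PATH/TO/CKPT
```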

Fully Attentional Networks

STL is built upon Fully Attentional Networks (FAN). FAN is a family of general-purpose Vision Transformer backbones that are highly robust to unseen natural corruptions in various visual recognition tasks. If you are interested in the original FAN design, please refer to the official implementation of FAN.

License

Copyright © 2023, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC.

The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing.

Citation

If you find STL helpful for your research and applications, please cite it using the following BibTeX:

@inproceedings{zhao2023fully,
  title={Fully Attentional Networks with Self-emerging Token Labeling},
  author={Zhao, Bingyin and Yu, Zhiding and Lan, Shiyi and Cheng, Yutao and Anandkumar, Anima and Lao, Yingjie and Alvarez, Jose M},
  booktitle={IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}

Acknowledgement

This repository is built on the timm library and the FAN, DeiT, PVT, and SegFormer repositories.
