# HCTransformers


PyTorch implementation for "Attribute Surrogates Learning and Spectral Tokens Pooling in Transformers for Few-shot Learning" [arXiv].

Code will be continuously updated.

*Figure: HCT network architecture.*

## Updates

**06/21/2024**

Updated the share links for the pretrained weights and extracted features.

**07/07/2022**

1. The dataset descriptions and guidelines have been updated.
2. Features extracted by the pretrained models on 𝒎𝒊𝒏𝒊ImageNet are also provided here.

**07/01/2022**

Provided download links for the pretrained weights and the evaluation command lines.

## Prerequisites

This codebase was developed with Python 3.8, PyTorch 1.9.0, torchvision 0.10.0, and CUDA 11.1. It has been tested on Ubuntu 20.04.
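To reproduce this environment, a minimal setup sketch (the wheel index follows PyTorch's archived install instructions for 1.9.0 with CUDA 11.1; adjust for your platform):

```
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
```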

## Pretrained weights

Pretrained weights on 𝒎𝒊𝒏𝒊ImageNet, 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet, CIFAR-FS, and FC100 are available now. Note that for 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet and FC100 there are only checkpoints for the first stage (without cascaded training). The 5-way 1-shot and 5-way 5-shot accuracies shown in the table are evaluated on the test split and are for reference only.

| dataset | 1-shot | 5-shot | download |
| --- | --- | --- | --- |
| 𝒎𝒊𝒏𝒊ImageNet | 71.16% | 84.60% | checkpoints_first, features_mini |
| 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet | 79.67% | 91.72% | - |
| FC100 | 48.27% | 66.42% | - |
| CIFAR-FS | 73.13% | 86.36% | - |

Pretrained weights for the cascaded-trained models on 𝒎𝒊𝒏𝒊ImageNet and CIFAR-FS are provided below. Note that the path to the first-stage pretrained weights must be specified when evaluating (see Evaluation).

| dataset | 1-shot | 5-shot | download |
| --- | --- | --- | --- |
| 𝒎𝒊𝒏𝒊ImageNet | 74.74% | 89.19% | checkpoints_pooling, features_mini |
| CIFAR-FS | 78.89% | 90.50% | - |
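The reported numbers follow the standard episodic protocol: repeatedly sample 5 classes, draw k support and several query images per class, classify the queries, and average accuracy over episodes. If you want to sanity-check the extracted features yourself, below is a minimal nearest-prototype sketch; the feature-file layout (a dict mapping class id to an [N, D] feature array) is an assumption for illustration, not the repository's exact evaluation code.

```python
import numpy as np

def episode_accuracy(feats_by_class, n_way=5, k_shot=1, n_query=15, rng=None):
    """One few-shot episode with nearest-prototype classification.

    feats_by_class: dict mapping class id -> [N, D] array of features
    (hypothetical layout; adapt to how you store the extracted features).
    """
    rng = rng or np.random.default_rng()
    classes = rng.choice(list(feats_by_class), size=n_way, replace=False)
    protos, queries, labels = [], [], []
    for label, c in enumerate(classes):
        feats = feats_by_class[c]
        idx = rng.choice(len(feats), size=k_shot + n_query, replace=False)
        support, query = feats[idx[:k_shot]], feats[idx[k_shot:]]
        protos.append(support.mean(axis=0))   # class prototype = support mean
        queries.append(query)
        labels.append(np.full(n_query, label))
    protos = np.stack(protos)                 # [n_way, D]
    queries = np.concatenate(queries)         # [n_way * n_query, D]
    labels = np.concatenate(labels)
    # Euclidean distance to each prototype; predict the nearest one.
    dists = ((queries[:, None, :] - protos[None, :, :]) ** 2).sum(-1)
    return (dists.argmin(axis=1) == labels).mean()

# Average over many episodes, e.g.:
# accs = [episode_accuracy(feats, k_shot=1) for _ in range(600)]
# print(f"5-way 1-shot: {np.mean(accs):.2%}")
```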

## Datasets

### 𝒎𝒊𝒏𝒊ImageNet

The 𝑚𝑖𝑛𝑖ImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. Its complexity is high because it uses ImageNet images, but it requires far fewer resources and less infrastructure than the full ImageNet dataset. In total, there are 100 classes with 600 color images per class. These 100 classes are split into 64, 16, and 20 classes for sampling meta-training, meta-validation, and meta-test tasks, respectively. To generate this dataset from ImageNet, you may use the 𝑚𝑖𝑛𝑖ImageNet tools repository.

Note that in our implementation images are resized to 480 × 480, because the data augmentation we use requires an image resolution greater than 224 to avoid distortions. Therefore, when generating 𝒎𝒊𝒏𝒊ImageNet, you should set --image_resize 0 to keep the original size, or --image_resize 480 as we did; an example invocation is sketched below.
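For example, with the 𝑚𝑖𝑛𝑖ImageNet tools repository the generation step looks roughly as follows (only --image_resize is documented here; the script name and the --imagenet_dir flag are taken from that repository and may have changed, so check its README):

```
python mini_imagenet_generator.py --imagenet_dir /path/to/imagenet --image_resize 480
```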

### 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet

The 𝑡𝑖𝑒𝑟𝑒𝑑ImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the 𝑡𝑖𝑒𝑟𝑒𝑑ImageNet tools repository.

As with 𝒎𝒊𝒏𝒊ImageNet, you should set --image_resize 0 to keep the original size, or --image_resize 480 as we did, when generating 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet.

## Training

We provide the training code for 𝒎𝒊𝒏𝒊ImageNet, 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet, and CIFAR-FS, extending the DINO repository (https://github.com/facebookresearch/dino).

### 1. Pre-train the First Transformer

To pre-train the first Transformer with attribute surrogates learning on 𝒎𝒊𝒏𝒊ImageNet from scratch with multiple GPUs, run:

```
python -m torch.distributed.launch --nproc_per_node=8 main_hct_first.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir
```

### 2. Train the Hierarchically Cascaded Transformers

To train the Hierarchically Cascaded Transformers with spectral token pooling on 𝒎𝒊𝒏𝒊ImageNet, run:

```
python -m torch.distributed.launch --nproc_per_node=8 main_hct_pooling.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir --pretrained_weights /path/to/pretrained_weights
```
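For intuition, spectral token pooling merges patch tokens by clustering them on a token-affinity graph and averaging each cluster into a single token, so later stages operate on fewer, semantically grouped tokens. The following is a conceptual sketch of that idea only; the affinity construction and the plain k-means step are assumptions for illustration, not the authors' implementation:

```python
import torch

def spectral_token_pooling(tokens, attn, n_clusters):
    """Pool [N, D] tokens to [n_clusters, D] via spectral clustering.

    tokens: [N, D] patch tokens; attn: [N, N] non-negative affinity
    (here assumed to come from symmetrized attention scores).
    """
    # Normalized graph Laplacian: L = I - D^{-1/2} A D^{-1/2}.
    a = (attn + attn.T) / 2
    d = a.sum(-1).clamp(min=1e-8).rsqrt()
    lap = torch.eye(len(a)) - d[:, None] * a * d[None, :]
    # Spectral embedding: eigenvectors of the smallest eigenvalues.
    _, eigvecs = torch.linalg.eigh(lap)
    emb = eigvecs[:, :n_clusters]
    # Simple k-means on the spectral embedding.
    centers = emb[torch.randperm(len(emb))[:n_clusters]]
    for _ in range(10):
        assign = torch.cdist(emb, centers).argmin(-1)
        for k in range(n_clusters):
            mask = assign == k
            if mask.any():
                centers[k] = emb[mask].mean(0)
    # Pool: average the tokens within each cluster.
    return torch.stack([
        tokens[assign == k].mean(0) if (assign == k).any() else tokens.mean(0)
        for k in range(n_clusters)
    ])
```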

## Evaluation

To evaluate the performance of the first Transformer on the 𝒎𝒊𝒏𝒊ImageNet 5-way 1-shot task, run:

```
python eval_hct_first.py --arch vit_small --server mini --partition test --checkpoint_key student --ckp_path /path/to/checkpoint_mini/ --num_shots 1
```

To evaluate the performance of the Hierarchically Cascaded Transformers on the 𝒎𝒊𝒏𝒊ImageNet 5-way 5-shot task, run:

```
python eval_hct_pooling.py --arch vit_small --server mini_pooling --partition val --checkpoint_key student --ckp_path /path/to/checkpoint_mini_pooling/ --pretrained_weights /path/to/pretrained_weights_of_first_stage --num_shots 5
```
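If you want to reuse the checkpoints outside the provided scripts, here is a minimal loading sketch. It assumes the DINO-style layout implied by --checkpoint_key, i.e. the file stores a dict with "student"/"teacher" state dicts whose keys may carry module./backbone. wrapper prefixes; this is an assumption, not a documented API of this repository:

```python
import torch

def load_backbone_state(ckpt_path, checkpoint_key="student"):
    """Extract a backbone state dict from a DINO-style checkpoint."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt[checkpoint_key] if checkpoint_key in ckpt else ckpt
    # Strip common wrapper prefixes so keys match a bare ViT module.
    return {k.replace("module.", "").replace("backbone.", ""): v
            for k, v in state.items()}

# Usage (strict=False tolerates heads that the backbone does not define):
# model.load_state_dict(load_backbone_state("checkpoint.pth"), strict=False)
```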

## License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

## Citation

If you find our code or paper useful for your research, please consider citing our work with the following BibTeX:

```
@inproceedings{he2022attribute,
  title={Attribute surrogates learning and spectral tokens pooling in transformers for few-shot learning},
  author={He, Yangji and Liang, Weihan and Zhao, Dongyang and Zhou, Hong-Yu and Ge, Weifeng and Yu, Yizhou and Zhang, Wenqiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9119--9129},
  year={2022}
}
```
