This repo contains the official JAX implementation of CRATE-α from our paper: Scaling White-Box Transformers for Vision.
We propose CRATE-α, featuring strategic yet minimal modifications to the sparse coding block in the CRATE architecture design, and a light training recipe designed to improve the scalability of CRATE.
One layer of the CRATE-α model architecture. MSSA (Multi-head Subspace Self-Attention) represents the compression block, and ODL (Overcomplete Dictionary Learning) represents the sparse coding block.
Left: We demonstrate how modifications to the components enhance the performance of the CRATE model on ImageNet-1K. Right: We compare the FLOPs and accuracy on ImageNet-1K of our methods with ViT (Dosovitskiy et al., 2020) and CRATE (Yu et al., 2023). CRATE is trained only on ImageNet-1K, while ours and ViT are pre-trained on ImageNet-21K.
Visualization of segmentation on COCO val2017 (Lin et al., 2014) with MaskCut (Wang et al., 2023). Top row: our supervised model effectively identifies the main objects in the image. Middle row: CRATE; compared with it, ours achieves better segmentation along object boundaries. Bottom row: supervised ViT fails to identify the main objects in most images. Failed images are marked with a red box.
Models (Base) | ImageNet-1K accuracy (%) | Models (Large) | ImageNet-1K accuracy (%) |
---|---|---|---|
CRATE-α-B/32 | 76.5 | CRATE-α-L/32 | 80.2 |
CRATE-α-B/16 | 81.2 | CRATE-α-L/14 | 83.9 |
CRATE-α-B/8 | 83.2 | CRATE-α-L/8 | 85.1 |
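The suffix in each model name (/32, /16, /14, /8) follows the ViT naming convention, where it denotes the patch size in pixels. As a rough illustration of why smaller patches cost more compute, the sketch below counts patch tokens per image. It assumes a 224×224 input resolution, which is not stated here, and `num_patch_tokens` is a hypothetical helper, not part of this repo.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# Assumed resolution of 224; the patch sizes come from the model names above.
for patch in (32, 16, 14, 8):
    print(f"patch size {patch}: {num_patch_tokens(224, patch)} tokens")
```

Under this assumption, halving the patch size quadruples the number of tokens each layer has to process.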
You can download model weights from the following link: Model Weights
Our experiments are conducted on TPUs. To learn how to gain access to and set up TPU machines, see this brief doc in CLIPA.
To set up the environment, run the following script:
bash scripts/env/setup_env.sh
We provide scripts for pre-training on ImageNet-21K and fine-tuning on ImageNet-1K.
To start pre-training on ImageNet-21K, run:
bash scripts/in1k/pre_training_in21k.sh
To start fine-tuning on ImageNet-1K, run:
bash scripts/in1k/fine_tuning_in1k.sh
We provide scripts for pre-training and fine-tuning on Datacomp1B.
To start pre-training on Datacomp1B, run:
bash scripts/clipa/pre_train.sh
To start fine-tuning on Datacomp1B, run:
bash scripts/clipa/fine_tune.sh
To increase accessibility, we have converted the weights from JAX to PyTorch. We provide four models: CRATE-α-B/16, CRATE-α-L/14, CRATE-α-CLIPA-L/14, and CRATE-α-CLIPA-H/14. You can use the PyTorch code to reproduce the results from our paper.
You can download the ImageNet-1K validation set using the following commands:
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
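Large downloads can be truncated, so it may be worth checking the validation archive before running evaluation. The sketch below counts the JPEG files in a tar archive using only the standard library; the ILSVRC2012 validation set contains 50,000 images. `count_jpegs` is a hypothetical helper, not part of this repo.

```python
import tarfile

EXPECTED_VAL_IMAGES = 50_000  # size of the ILSVRC2012 validation set


def count_jpegs(tar_path: str) -> int:
    """Count regular files with a .jpeg/.jpg extension inside a tar archive."""
    count = 0
    with tarfile.open(tar_path, "r") as archive:
        for member in archive:
            if member.isfile() and member.name.lower().endswith((".jpeg", ".jpg")):
                count += 1
    return count


# Example usage (path assumed to be the file downloaded above):
# assert count_jpegs("ILSVRC2012_img_val.tar") == EXPECTED_VAL_IMAGES
```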
For the PyTorch environment, the recommended dependencies are as follows:
pip install torch==2.0.0
pip install torchvision==0.15.0
pip install transformers==4.40.2
pip install open-clip-torch==2.24.0
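To confirm that an existing environment matches these recommendations, a small check like the following may help. It is a minimal sketch using the standard-library `importlib.metadata`; `find_mismatches` is a hypothetical helper, and the version pins are copied from the commands above.

```python
from importlib import metadata
from typing import Dict, Optional

# Versions taken from the pip commands above.
RECOMMENDED = {
    "torch": "2.0.0",
    "torchvision": "0.15.0",
    "transformers": "4.40.2",
    "open-clip-torch": "2.24.0",
}


def find_mismatches(recommended: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for each mismatch."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in recommended.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "+cu118" when comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


for name, installed in find_mismatches(RECOMMENDED).items():
    print(f"{name}: recommended {RECOMMENDED[name]}, found {installed or 'not installed'}")
```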
Model | PyTorch Accuracy | JAX (Paper) Accuracy | PyTorch Weights |
---|---|---|---|
CRATE-α-B/16 | 81.2 | 81.2 | Download |
CRATE-α-L/14 | 83.9 | 83.9 | Download |
CRATE-α-CLIPA-L/14 | 69.8 | 69.8 | Download |
CRATE-α-CLIPA-H/14 | 72.3 | 72.3 | Download |
Weights for the PyTorch models are available for download. Use the links provided in the table above.
To run the evaluation code, specify the paths to the checkpoint and the ImageNet validation set in the eval_in1k_cls.py file.
python torch_inference/eval_in1k_cls.py
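The accuracy figures in the table above are assumed here to be top-1 accuracy, which is the usual ImageNet-1K metric; this is an assumption rather than something the script's output is known to state. As an illustration of that metric only, the following pure-Python sketch (a hypothetical `top1_accuracy` helper, not the repo's evaluation code) computes it from per-class scores and labels.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for scores, label in zip(logits, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == label:
            correct += 1
    return 100.0 * correct / len(labels)


# Toy example with two classes and four samples.
toy_logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
toy_labels = [1, 0, 0, 0]
print(top1_accuracy(toy_logits, toy_labels))
```

In practice the same computation would typically run on batched tensors rather than Python lists.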
For the CLIPA PyTorch version, we refer to CLIP.
To run the evaluation code, specify the paths to the checkpoint and the ImageNet validation set in the eval_in1k.py and clipa_model.py files. The default model is CRATE-α-CLIPA-L/14.
python torch_inference/eval_in1k.py
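For readers unfamiliar with CLIP-style evaluation, the sketch below illustrates zero-shot classification in general terms: an image embedding is compared against one text embedding per class by cosine similarity, and the most similar class is predicted. It assumes the CLIPA evaluation follows this usual protocol, which is not spelled out here; `zero_shot_predict` and the toy embeddings are hypothetical and unrelated to the repo's code.

```python
import math
from typing import List, Sequence


def l2_normalize(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def zero_shot_predict(image_embedding: Sequence[float],
                      class_text_embeddings: Sequence[Sequence[float]]) -> int:
    """Return the index of the class whose text embedding is most similar."""
    image = l2_normalize(image_embedding)
    similarities = [
        sum(a * b for a, b in zip(image, l2_normalize(text)))
        for text in class_text_embeddings
    ]
    return max(range(len(similarities)), key=similarities.__getitem__)


# Toy 2-D embeddings: the image points mostly along the first axis,
# so it matches the second class, whose text embedding does the same.
print(zero_shot_predict([1.0, 0.1], [[0.0, 1.0], [2.0, 0.0]]))
```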
The repo is built on big_vision and CLIPA. Many thanks to the open-source community for these awesome works!
We are also very grateful that this work is supported by a gift from Open Philanthropy, the TPU Research Cloud (TRC) program, and the Google Cloud Research Credits program.
@article{yang2024cratealpha,
title = {Scaling White-Box Transformers for Vision},
author = {Yang, Jinrui and Li, Xianhang and Pai, Druv and Zhou, Yuyin and Ma, Yi and Yu, Yaodong and Xie, Cihang},
journal = {arXiv preprint arXiv:2405.20299},
year = {2024}
}
If you have any questions, please feel free to raise an issue or contact us directly: jyang347@ucsc.edu.