This repo contains the official JAX implementation of CLIPA from our paper: An Inverse Scaling Law for CLIP Training.
You can download the LAION-400M dataset using the handy img2dataset tool.
We have only tested this JAX implementation on Google Cloud TPUs. To run this code, you must use the tfrecord format instead of the more common webdataset format.
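Since tfrecord output is required, you can ask img2dataset to emit TFRecord shards directly. A hedged sketch (the metadata path, output folder, and tuning flags below are placeholders to adapt to your setup):

```shell
# Sketch: download LAION-400M as TFRecord shards with img2dataset.
# "laion400m-meta" and "laion400m-data" are placeholder paths.
img2dataset \
  --url_list laion400m-meta \
  --input_format parquet \
  --url_col URL \
  --caption_col TEXT \
  --output_format tfrecord \
  --output_folder laion400m-data \
  --image_size 256 \
  --processes_count 16 \
  --thread_count 64
```

The key flag is `--output_format tfrecord`, which replaces img2dataset's default webdataset tar output with TFRecord shards that this codebase can read.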
Since LAION-400M is not an officially supported TFDS dataset, for your convenience we provide some self-implemented scripts (admittedly crudely written) for post-processing the downloaded tfrecord files here.
Download and extract the ImageNet data from http://image-net.org/. This JAX implementation only supports reading datasets in tfrecord format; check the official TFDS documentation for how to prepare the dataset.
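A minimal sketch of the TFDS preparation step, assuming `tensorflow-datasets` is installed and the ImageNet tars have been placed in the TFDS manual-download directory (both paths below are placeholders for your own directories/bucket):

```shell
# Sketch: build the imagenet2012 TFDS dataset from manually downloaded tars.
# Substitute your own manual_dir and data_dir (e.g. a GCS bucket for TPU access).
tfds build imagenet2012 \
  --manual_dir ~/tensorflow_datasets/downloads/manual \
  --data_dir gs://your-bucket/tensorflow_datasets
```

Writing the prepared dataset to a GCS bucket keeps it readable from the TPU VM.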
First, you need to prepare the dataset and create the TPU VM instance. Refer to TPU_USAGE for details.
The configs folder contains the detailed configuration of our model and training details.
First, change into the repo directory:

```shell
cd clipa_jax
```
Next, navigate to the `scripts/` directory, where you will find three scripts: `set_up_env.sh`, `pre_training.sh`, and `fine_tuning.sh`. Before executing them, make sure to specify your TPU VM instance information and dataset paths at the top of each file.
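The variables to fill in look roughly like the following (the names and values here are illustrative, not the actual placeholders in the scripts; match them against each file):

```shell
# Hypothetical configuration block -- adapt names to the real placeholders
# at the top of set_up_env.sh, pre_training.sh, and fine_tuning.sh.
TPU_NAME=my-clipa-tpu        # name of your TPU VM instance
ZONE=us-central2-b           # zone where the TPU VM was created
DATA_PATH=gs://your-bucket/laion400m-tfrecord   # pre-training data
IMAGENET_PATH=gs://your-bucket/tensorflow_datasets  # zero-shot eval data
```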
You can then upload all necessary files to your TPU VM instance and set up the required environment by running:

```shell
bash scripts/set_up_env.sh
```
Pre-training can then be launched by running:

```shell
bash scripts/pre_training.sh
```
After pre-training, you can fine-tune the model by running:

```shell
bash scripts/fine_tuning.sh
```
| Model | Image Size | Text Length | Data | Schedule (samples seen) | Zero-shot IN-1K Top-1 (%) | Weights |
|---|---|---|---|---|---|---|
| H-14 | 70 | 8 | LAION-2B | 12.8B | 70.1 | weight |
| H-14 | 84 | 8 | LAION-2B | 12.8B | 72.1 | weight |
| L-14 | 84 | 8 | DataCOMP-1B | 12.8B | 72.7 | weight |
| H-14 | 84 | 8 | DataCOMP-1B | 12.8B | 75.6 | weight |
| G-14 | 84 | 8 | DataCOMP-1B | 12.8B | 78.5 | weight |
| Model | Data | Schedule (samples seen @ image size) | GPU Hours | Estimated Cost | Zero-shot IN-1K Top-1 (%) | Weights |
|---|---|---|---|---|---|---|
| H/14 | LAION-2B | 12.8B@84 + 512M@224 | 7,776 | $12,247 | 78.6 | download |
| H/14 | LAION-2B | 12.8B@84 + 512M@224 + 128M@336 | 8,640 | $13,616 | 79.1 | download |
| L/14 | DataCOMP-1B | 12.8B@84 + 512M@224 | 4,008 | $6,318 | 79.7 | download |
| L/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336 | 4,520 | $7,124 | 80.3 | download |
| H/14 | DataCOMP-1B | 12.8B@70 + 512M@224 | 5,920 | $9,324 | 81.1 | download |
| H/14 | DataCOMP-1B | 12.8B@84 + 512M@224 | 7,776 | $12,247 | 81.5 | download |
| H/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336 | 8,640 | $13,616 | 81.8 | download |
| G/14 | DataCOMP-1B | 12.8B@84 + 512M@224 | 21,998 | $34,646 | 82.5 | download |
| G/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336 | 23,742 | $39,056 | 83.0 | download |
Our CLIPA-v2 GPU hours are estimated on an 8×A100 80GB GPU machine on Google Cloud, and the corresponding training cost is estimated based on 80GB A100 cloud pricing.