clipa_jax

This repo contains the official JAX implementation of CLIPA from our paper: An Inverse Scaling Law for CLIP Training.

Data preparation

LAION-400M

You can download the LAION-400M dataset using the handy img2dataset tool.

We have only tested this JAX implementation on Google Cloud TPUs. To run this code, you must download the data in tfrecord format rather than the more common webdataset format. Since LAION-400M is not an officially supported TFDS dataset, we provide some self-implemented scripts (admittedly crudely written) for post-processing the downloaded tfrecord files here.
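For orientation when working with the downloaded shards, here is a minimal pure-Python sketch of the TFRecord on-disk framing (each record is a length-prefixed payload with two CRC32C checksums). This is illustrative only: it zeroes the checksums, so files it writes would be rejected by strict readers such as `tf.data`; use `tf.io.TFRecordWriter` for real output.

```python
import io
import struct


def write_record(f, data: bytes) -> None:
    # TFRecord framing: uint64 length, uint32 length-CRC, payload, uint32 data-CRC.
    # This sketch writes zeroed CRC32C checksums; real tfrecord writers
    # (e.g. tf.io.TFRecordWriter) compute and verify them.
    f.write(struct.pack("<Q", len(data)))
    f.write(struct.pack("<I", 0))  # masked CRC32C of the length field (omitted here)
    f.write(data)
    f.write(struct.pack("<I", 0))  # masked CRC32C of the payload (omitted here)


def read_records(f):
    # Yield raw payloads (serialized tf.Example protos in practice),
    # skipping the checksum fields without verifying them.
    while True:
        header = f.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        f.read(4)              # skip length CRC
        yield f.read(length)   # record payload
        f.read(4)              # skip data CRC


buf = io.BytesIO()
for payload in (b"example-0", b"example-1"):
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # [b'example-0', b'example-1']
```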

ImageNet-1K

Download and extract ImageNet data from http://image-net.org/. This JAX implementation only supports reading datasets in tfrecord format. Check the official TFDS documentation for how to prepare the dataset.
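Once TFDS has built the dataset, it writes sharded tfrecord files with a predictable naming pattern, which is handy for sanity-checking your data directory. A small sketch (the shard count 1024 below is only an example; TFDS picks its own shard count):

```python
def tfds_shard_name(dataset: str, split: str, shard: int, num_shards: int) -> str:
    # TFDS names tfrecord shards as "<dataset>-<split>.tfrecord-XXXXX-of-YYYYY",
    # with zero-padded 5-digit shard indices.
    return f"{dataset}-{split}.tfrecord-{shard:05d}-of-{num_shards:05d}"


print(tfds_shard_name("imagenet2012", "train", 0, 1024))
# imagenet2012-train.tfrecord-00000-of-01024
```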

Usage

First, you need to prepare the dataset and create a TPU VM instance. Refer to TPU_USAGE for details.

The configs folder contains the detailed configurations of our models and training setups.

First, change into the repository directory:

cd clipa_jax

To begin with, go to the scripts/ directory, which provides three scripts: set_up_env.sh, pre_training.sh, and fine_tuning.sh. Before executing them, make sure to specify the TPU VM instance information and dataset path at the top of each file.

Next, upload all necessary files to your TPU VM instance and set up the required environment by running:

bash scripts/set_up_env.sh

Then pre-training can be done by running:

bash scripts/pre_training.sh

After pre-training, you can fine-tune the model by running:

bash scripts/fine_tuning.sh

Pre-trained weights

| Model | Image | Text | Data | Schedule | Top-1 | Weights |
|-------|-------|------|------|----------|-------|---------|
| H-14  | 70    | 8    | LAION-2B    | 12.8B | 70.1 | weight |
| H-14  | 84    | 8    | LAION-2B    | 12.8B | 72.1 | weight |
| L-14  | 84    | 8    | DataCOMP-1B | 12.8B | 72.7 | weight |
| H-14  | 84    | 8    | DataCOMP-1B | 12.8B | 75.6 | weight |
| G-14  | 84    | 8    | DataCOMP-1B | 12.8B | 78.5 | weight |

Fine-tuned Results

| Model | Data | Schedule | GPU Hours | Estimated Cost | Zero-shot IN-1K | Weights |
|-------|------|----------|-----------|----------------|-----------------|---------|
| H/14 | LAION-2B    | 12.8B@84 + 512M@224             | 7,776  | $12,247 | 78.6 | download |
| H/14 | LAION-2B    | 12.8B@84 + 512M@224 + 128M@336  | 8,640  | $13,616 | 79.1 | download |
| L/14 | DataCOMP-1B | 12.8B@84 + 512M@224             | 4,008  | $6,318  | 79.7 | download |
| L/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336  | 4,520  | $7,124  | 80.3 | download |
| H/14 | DataCOMP-1B | 12.8B@70 + 512M@224             | 5,920  | $9,324  | 81.1 | download |
| H/14 | DataCOMP-1B | 12.8B@84 + 512M@224             | 7,776  | $12,247 | 81.5 | download |
| H/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336  | 8,640  | $13,616 | 81.8 | download |
| G/14 | DataCOMP-1B | 12.8B@84 + 512M@224             | 21,998 | $34,646 | 82.5 | download |
| G/14 | DataCOMP-1B | 12.8B@84 + 512M@224 + 128M@336  | 23,742 | $39,056 | 83.0 | download |

CLIPA-v2's GPU hours are estimated using an 8×A100 80GB GPU machine on Google Cloud, and the corresponding training cost is estimated based on the 80GB A100's cloud pricing.
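As a quick sanity check on the table above, dividing each estimated cost by its GPU hours yields a consistent implied per-GPU-hour price of roughly $1.57–1.58 (this rate is derived from the table's own numbers, not an official quote):

```python
# (GPU hours, estimated cost in USD) copied from a few rows of the table above.
rows = {
    "H/14 LAION-2B 12.8B@84 + 512M@224": (7776, 12247),
    "L/14 DataCOMP-1B 12.8B@84 + 512M@224": (4008, 6318),
    "G/14 DataCOMP-1B 12.8B@84 + 512M@224": (21998, 34646),
}
for name, (hours, cost) in rows.items():
    # The implied hourly rate is derived, not quoted from Google Cloud pricing.
    print(f"{name}: ${cost / hours:.2f} per GPU-hour")
```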