This is a PyTorch / PyTorch-XLA implementation of the paper An Inverse Scaling Law for CLIP Training. This repo is heavily based on OpenCLIP, with a few changes including a PyTorch-XLA-compatible implementation, uint8 data transfer, and some other minor modifications.
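One of the changes called out above is uint8 data transfer. As a rough illustration of the idea (our own sketch, not this repo's actual code), image batches can be shipped to the accelerator as uint8 and only converted to float and normalized on-device, which cuts host-to-device bandwidth roughly 4x compared with transferring float32 tensors:

```python
import torch

# Sketch of the uint8-transfer idea; variable names and values are illustrative.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def to_device_and_normalize(images_uint8: torch.Tensor, device: torch.device) -> torch.Tensor:
    images = images_uint8.to(device, non_blocking=True)  # transfer as uint8 (1 byte per value)
    images = images.float().div_(255.0)                  # convert and scale to [0, 1] on-device
    return (images - mean.to(device)) / std.to(device)   # normalize on-device
```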

Installation

A simple

pip install -r requirements.txt

is enough.

Note that this repo is compatible with both GPU and TPU. If you want to run the code on Google Cloud TPU, here are some documents you may find helpful: Google Cloud User's Guide and TIMM bits README
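On TPU the code runs through PyTorch-XLA. Launching is handled by the provided scripts, but for orientation, the basic torch_xla device idiom looks like this (a generic sketch with a placeholder module, not this repo's launch procedure):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                   # the current TPU core as a torch device
model = torch.nn.Linear(8, 2).to(device)   # placeholder module, for illustration only
x = torch.randn(4, 8).to(device)
y = model(x)
xm.mark_step()                             # materialize the lazily-traced XLA graph
```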

Data preparation

LAION-400M

You can download the LAION-400M dataset using the handy img2dataset tool. It supports both the webdataset and tfrecord formats.

To run this code on Google Cloud, we strongly recommend tfrecord over webdataset. Since LAION-400M is not an officially supported TFDS dataset, for your convenience we provide some self-implemented scripts (sorry, we know they are crudely written) for post-processing the downloaded tfrecord files here.
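For reference, a download call might look like the sketch below (the parameters are illustrative, not the exact settings we used; see the img2dataset documentation for the full LAION-400M recipe):

```python
from img2dataset import download

download(
    url_list="/path/to/laion400m-metadata",  # parquet files with URL/TEXT columns
    input_format="parquet",
    url_col="URL",
    caption_col="TEXT",
    output_format="tfrecord",                # or "webdataset"
    output_folder="/path/to/laion400m-data",
    image_size=256,
    processes_count=16,
    thread_count=64,
)
```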

ImageNet-1K

Download and extract ImageNet data from http://image-net.org/. The directory structure is the standard layout, and note that we only need the validation set:

/path/to/imagenet/
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg

We also support reading ImageNet-1K in the tfrecord format. Check the official doc for how to prepare the tfds dataset.
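For orientation, preparing the tfds version roughly follows the standard tensorflow_datasets flow for imagenet2012 (a sketch under our assumptions; the original ImageNet tar files must be placed in the TFDS manual download directory first, as described in the official doc):

```python
import tensorflow_datasets as tfds

# imagenet2012 cannot be downloaded automatically: place the original
# ImageNet tar files in the TFDS manual download directory before running.
builder = tfds.builder("imagenet2012", data_dir="/path/to/tensorflow_datasets")
builder.download_and_prepare()
```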

Usage

cd CLIPA/clipa_torch
pip install -v .

import torch
import torch.nn.functional as F
from PIL import Image
from clipa_torch.open_clip import create_model_and_transforms, get_tokenizer

model, _, preprocess = create_model_and_transforms('ViT-H-14-CL32-GAP-BigVision',
                                                   pretrained='/path/to/ckpt',
                                                   force_image_size=336,
                                                   image_mean=[0.485, 0.456, 0.406],
                                                   image_std=[0.229, 0.224, 0.225],
                                                   interpolation='bilinear',
                                                   square_resize_only=True,
                                                   )
tokenizer = get_tokenizer('ViT-H-14-CL32-GAP-BigVision')

image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

Training Instructions

We provide example scripts to reproduce our CLIPA results on an eight-GPU A100 machine.

For instance, to reproduce the CLIPA-L16(I37,T8) results, first run the pre-training script

bash scripts/exp/gpu/vit_l16/i37_t8_pretrain.sh

and fine-tune the pre-trained checkpoint with

bash scripts/exp/gpu/vit_l16/i37_t8_finetune.sh
  • Remember to change the dataset and checkpoint paths to your own. If you are reading from disk, you have to point to the ImageNet validation directory /path/to/imagenet/val!
  • Training takes ~3 days for pre-training and ~1 day for fine-tuning on an eight-GPU A100 machine.
  • Note that to ensure sufficient shuffling diversity, each worker maintains a TFDS shuffle buffer and prefetch buffer, which significantly increases CPU memory usage. If you observe CPU out-of-memory issues, try lowering the values of TFDS_PREFETCH_SIZE and TFDS_SHUFFLE_SIZE (see the sketch after this list).
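The sketch below shows where these buffers enter a tf.data pipeline; it is not this repo's actual input pipeline, and the buffer values are made up for illustration:

```python
import tensorflow_datasets as tfds

# Larger buffers give better shuffling diversity but cost more host (CPU) RAM
# per data worker; shrink them if the host runs out of memory.
TFDS_SHUFFLE_SIZE = 16_384   # illustrative value
TFDS_PREFETCH_SIZE = 8       # illustrative value

ds = tfds.load("imagenet2012", split="validation", data_dir="/path/to/tensorflow_datasets")
ds = ds.shuffle(TFDS_SHUFFLE_SIZE).batch(256).prefetch(TFDS_PREFETCH_SIZE)
```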

Testing Instructions

This repo was only used for evaluating zero-shot top-1 accuracy on ImageNet-1K. For evaluation on more datasets, check the amazing clip_benchmark. The CLIPA-v1 checkpoints from this repo should be readily applicable. As for the CLIPA-v2 checkpoints, we have provided some example testing scripts for your convenience. If you wish to use them in clip_benchmark, check our example testing scripts carefully and follow their instructions to add other CLIP models.
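For reference, zero-shot ImageNet-1K evaluation amounts to building a text classifier from class names and matching image embeddings against it. The sketch below assumes model, tokenizer, and preprocess were created as in the Usage section, uses torchvision's ImageFolder, and relies on a placeholder imagenet_classnames list; our actual scripts additionally use prompt ensembling and differ in detail.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Placeholder: in practice this is the full list of 1000 ImageNet class names,
# ordered to match ImageFolder's (sorted-directory) class indices.
imagenet_classnames = ["tench", "goldfish", "great white shark"]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# Build the zero-shot classifier: one normalized text embedding per class.
with torch.no_grad():
    text = tokenizer([f"a photo of a {name}" for name in imagenet_classnames]).to(device)
    classifier = F.normalize(model.encode_text(text), dim=-1)

loader = DataLoader(ImageFolder("/path/to/imagenet/val", transform=preprocess),
                    batch_size=64, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        feats = F.normalize(model.encode_image(images.to(device)), dim=-1)
        preds = (feats @ classifier.T).argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.numel()
print(f"zero-shot top-1 accuracy: {correct / total:.4f}")
```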

Model Weights

Here are CLIPA-v1 trained weights on LAION-400M with academic resources. All models are pre-trained for 6 epochs with reduced input token lengths and subsequently fine-tuned for 0.36 epoch with full input token lengths.

| Pretrained Model | Weights | zero-shot IN-1K |
| --- | --- | --- |
| CLIPA-B/16(I50,T16) | download | 63.2 |
| CLIPA-L/16(I17,T16) | download | 67.8 |
| CLIPA-L/16(I37,T8) | download | 69.3 |

Here are CLIPA-v2 weights trained on LAION-2B or DataComp-1B. These weights were trained with our JAX implementation and converted into PyTorch format, so slight performance variation is possible due to framework differences. Note that these converted weights are not open_clip compatible; use our example testing scripts for evaluation.

| Model | Data | Schedule | GPU Hours | Estimated Cost | zero-shot IN-1K | Model Weight |
| --- | --- | --- | --- | --- | --- | --- |
| H/14 | LAION-2B | 12.8B@84 + 512M@224 + 128M@336 | 8640 | $13616 | 79.1 | download |
| L/14 | DataComp-1B | 12.8B@84 + 512M@224 | 4008 | $6318 | 79.7 | download |
| L/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 4520 | $7124 | 80.3 | download |
| H/14 | DataComp-1B | 12.8B@70 + 512M@224 | 5920 | $9324 | 81.1 | download |
| H/14 | DataComp-1B | 12.8B@84 + 512M@224 | 7776 | $12247 | 81.5 | download |
| H/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 8640 | $13616 | 81.8 | download |
| bigG/14 | DataComp-1B | 12.8B@84 + 512M@224 | 21998 | $34646 | 82.7 | download |
| bigG/14 | DataComp-1B | 12.8B@84 + 512M@224 + 128M@336 | 23742 | $39056 | 83.0 | download |

Our CLIPA-v2 GPU hours are estimated on an eight-GPU 80GB A100 machine on Google Cloud. The corresponding training cost is estimated based on 80GB A100 cloud pricing.