# Finetuner X CLIP Benchmark

@bo_wangbo
@fissoreg

In this Colab notebook, we'll try to use [Finetuner](https://github.com/jina-ai/finetuner) to fine-tune the CLIP model on `Flickr8k`, and compare the retrieval metrics produced by the fine-tuned model against pre-trained zero-shot results produced from CLIP Benchmark.

*NOTE: Finetuner is a cloud-based training platform. You need to log in, after which Finetuner allocates computational resources automatically and for free.*

**Please Consider [Switching to a GPU Runtime](https://medium.com/@oribarel/getting-the-most-out-of-your-google-colab-2b0585f82403) for faster evaluation!**


In [1]:
!pip install "finetuner[full]"
# Our fork of CLIP Benchmark fixes some minor issues in the data builder and adjusts the evaluator so that it accepts two models.
# When fine-tuning CLIP, Finetuner unwraps the CLIP model into two models and saves them individually.
!pip install kaggle
!pip install git+https://github.com/bwanglzu/CLIP_benchmark.git

Collecting transformers==4.20.1
  Using cached transformers-4.20.1-py3-none-any.whl (4.4 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.31.0.dev0
    Uninstalling transformers-4.31.0.dev0:
      Successfully uninstalled transformers-4.31.0.dev0
Successfully installed transformers-4.20.1
Collecting git+https://github.com/bwanglzu/CLIP_benchmark.git
  Cloning https://github.com/bwanglzu/CLIP_benchmark.git to /tmp/pip-req-build-50ezrzc1
  Running command git clone -q https://github.com/bwanglzu/CLIP_benchmark.git /tmp/pip-req-build-50ezrzc1
Building wheels for collected packages: clip-benchmark
  Building wheel for clip-benchmark (setup.py) ... done
  Created wheel for clip-benchmark: filename=clip_benchmark-0.1.0-py2.py3-none-any.whl size=48107 sha256=3f2c7ad4c0cb5f534efda9c5df86933239a84519f2bd3b2ba25477f499a770f6
  Stored in directory: /tmp/pip-ephem-wheel-cache-6_qkzxnr/wheels/39/c8/6a/f2

## Preparing the training data

CLIP Benchmark comes with a dataset `builder` that does much of the work of assembling training data. However, for Finetuner, we need to convert it into Jina DocArray format.

We will use:

1. The `captions.txt` file that comes with CLIP Benchmark, which lists all Flickr8k images together with their captions.
2. The Karpathy split reused by CLIP Benchmark, which divides `Flickr8k` into a training set and a test set. The test set includes 5000 image-caption annotations.

We will build our training set by loading all images and then excluding the test set images.
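The exclusion step can be sketched on a few in-memory lines before running it on the real files. The file contents below are hypothetical stand-ins for `captions.txt` and the Karpathy test file, and `image_names` is an illustrative helper, not part of CLIP Benchmark. Storing the test names in a `set` keeps each membership check cheap.

```python
# Hypothetical stand-ins for captions.txt and the Karpathy test annotation file.
full_lines = [
    "image,caption",
    "a.jpg,A dog runs across a field .",
    "a.jpg,A brown dog, mid-stride .",
    "b.jpg,Two children play in the sand .",
    "c.jpg,A climber scales a rock wall .",
]
test_lines = [
    "image,caption",
    "b.jpg,Two children play in the sand .",
]


def image_names(lines):
    """Return the image name (text before the first comma) of every non-header line."""
    return [line.split(",", 1)[0] for line in lines[1:]]


# A set gives constant-time membership checks, unlike a list.
test_images = set(image_names(test_lines))

# Keep only the rows whose image is not part of the test set.
train_rows = [
    line for line in full_lines[1:]
    if line.split(",", 1)[0] not in test_images
]

print(len(train_rows))
```

Splitting on the first comma only (`split(",", 1)`) matters because captions may themselves contain commas.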



In [2]:
import os
from clip_benchmark.datasets.builder import build_dataset

# Please fill in your Kaggle credentials here. You can find your Kaggle
# user name and API key in your Kaggle personal settings.
# CLIP Benchmark uses Kaggle to download the Flickr8k dataset.
os.environ['KAGGLE_USERNAME'] = 'your-kaggle-username'
os.environ['KAGGLE_KEY'] = 'your-kaggle-key'

build_dataset(dataset_name='flickr8k', annotation_file=None, download=True)



Dataset Flickr
    Number of datapoints: 1000
    Root location: root

In [8]:
# NOTE: root_dir points to a local path; change it to match your environment.
root_dir = '/home/xz306/root/'
full_annotation = root_dir + 'captions.txt'
test_annotation = root_dir + 'flickr8k_test_karpathy.txt'

all_imgs = []
test_imgs = []
with open(full_annotation, 'r') as f:
    next(f) # exclude the header line
    for idx, item in enumerate(f.readlines()):
        all_imgs.append(item.split(',', 1)[0])

with open(test_annotation, 'r') as f:
    next(f) # exclude the header line
    for idx, item in enumerate(f.readlines()):
        test_imgs.append(item.split(',', 1)[0])

print(f'Size of the full image set is {len(all_imgs)}')
print(f'Size of the test image set is {len(test_imgs)}')

Size of the full image set is 40455
Size of the test image set is 5000


Now we will convert the downloaded images into `DocumentArray` format like this:

```python
from docarray import Document, DocumentArray

pairs = DocumentArray()
pair_1 = Document(chunks=[
    Document(uri='your-image.jpg', modality='image'),  # image chunk
    Document(content='the text descriptor', modality='text'),  # text chunk
])
# build pair_2, pair_3, ... in the same way, then collect them all:
pairs.extend([pair_1])
```

In [9]:
from tqdm import tqdm
from docarray import Document, DocumentArray

train = DocumentArray()
with open(full_annotation, 'r') as f:
    next(f) # exclude the header line
    for idx, line in tqdm(enumerate(f.readlines())):
        url, txt = line.split(',', 1)
        if url in test_imgs:  # do not include test images into training set
            continue
        img_chunk = Document(uri=root_dir + url, modality='image')
        txt_chunk = Document(content=txt, modality='text')
        img_chunk.load_uri_to_image_tensor(224, 224)
        img_chunk.pop('uri')
        pair = Document(chunks=[img_chunk, txt_chunk])
        train.append(pair)
        if idx == 5000: # we only use a subset to train
            break

print(f'The size of the training data is {len(train)}')

5000it [02:59, 27.88it/s] 

The size of the training data is 4376





The Flickr8k dataset contains 8,000 images, each with 5 descriptive texts, or 40,000 image-text pairs in total.

+ The training set has ~35000 image-text pairs.
+ The test set has ~5000 image-text pairs.
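
The round figures above follow from simple arithmetic, sketched below. The variable names are illustrative, and the counts are the approximate ones quoted in the text; the annotation file loaded earlier reported 40455 lines, so the real totals are slightly higher.

```python
# Approximate figures quoted in the text (not exact file counts).
num_images = 8_000
captions_per_image = 5

total_pairs = num_images * captions_per_image  # all image-text pairs
test_pairs = 5_000                             # pairs held out for testing
train_pairs = total_pairs - test_pairs         # pairs left for training

print(total_pairs, train_pairs)
```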

## Start Fine-tuning

Now that we have prepared the training and test data, the next step is to start the fine-tuning job using Finetuner. Finetuner takes a pre-trained model from a third-party library, such as `open_clip`, and then jointly optimizes the image encoder and the text encoder with the `CLIPLoss` function.
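
To give an intuition for what jointly optimizing both encoders means, here is a minimal pure-Python sketch of the symmetric contrastive objective that CLIP-style losses are generally based on. It is not Finetuner's `CLIPLoss` implementation; the function names, the toy embeddings, and the temperature of 0.07 are assumptions made for illustration.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over a list of floats."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def clip_style_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric cross-entropy over the image-text similarity matrix.

    Row i of image_emb and row i of text_emb are assumed to be a matching pair.
    """
    imgs = [normalize(v) for v in image_emb]
    txts = [normalize(v) for v in text_emb]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity between image i and text j
    logits = [
        [sum(a * b for a, b in zip(img, txt)) / temperature for txt in txts]
        for img in imgs
    ]
    # image -> text: each image should pick its own text (the diagonal)
    loss_i2t = -sum(log_softmax(logits[i])[i] for i in range(n)) / n
    # text -> image: each text should pick its own image
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_t2i = -sum(log_softmax(columns[j])[j] for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2


# Toy embeddings: matching pairs point in the same direction ...
aligned = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
# ... versus pairs whose texts have been swapped.
shuffled = clip_style_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])

print(aligned < shuffled)
```

The loss is low when each image is most similar to its own caption and high when it is more similar to another one, which pushes both encoders toward a shared embedding space.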

Finetuner will also reserve a cloud GPU for you for free.

In [10]:
import finetuner

finetuner.login()

In [None]:
# Note: we have already pushed the training set below to the cloud and made the dataset public, so you don't have to push it again.
# train.push('finetuner-flickr8k-demo', public=True, show_progress=True)
# finetuner.delete_run('clip-run')

In [12]:
run = finetuner.fit(
    # Model name changed from 'ViT-B-32::openai' to 'openai/clip-vit-base-patch32',
    # because Finetuner reported that no model named 'ViT-B-32' exists.
    model='openai/clip-vit-base-patch32',  # CLIP ViT-B/32 pre-trained by OpenAI
    train_data='finetuner-flickr8k-demo', # the dataset we prepared has been pushed to the cloud in the prev section
    run_name='clip-run',
    loss='CLIPLoss', # use CLIPLoss for fine-tuning CLIP model
    epochs=5,
    batch_size=64,
    learning_rate=1e-6,
    device='cuda',
)

In [13]:
# takes around 10 minutes to finish
for log_entry in run.stream_logs():
    print(log_entry)

Output()

[23:46:01] INFO     Starting finetuner training run ...                                                  __main__.py:350
DEBUG    Found Jina AI Cloud authentication token                                             __main__.py:362
DEBUG    Running in online mode                                                               __main__.py:363
INFO     Reading config ...                                                                   __main__.py:370
DEBUG    Reading config from stream                                                           __main__.py:382
INFO     Parsing config ...                                                                   __main__.py:385
INFO     Config loaded 📜                                                                     __main__.py:389
INFO     Run name: clip-run                                                                   __main__.py:391
INFO     Experiment name: default                                                             __main__.py:392


## Inference

After fine-tuning is finished, your fine-tuned model is saved in the cloud as an `artifact`. An `artifact` contains the model weights, and some metadata such as evaluation metrics and hyper-parameters.

To download your artifact, call `run.save_artifact()`.

Since CLIP actually consists of two models that are fine-tuned in parallel, the artifact contains two models: a text encoder and an image encoder. To encode data with them, call `finetuner.get_model()` with `select_model` set to either `clip-text` or `clip-vision` to access each of CLIP's constituent models individually.

In [14]:
artifact = run.save_artifact('clip-model')

clip_txt_encoder = finetuner.get_model(artifact=artifact, select_model='clip-text')
clip_img_encoder = finetuner.get_model(artifact=artifact, select_model='clip-vision')

Output()

NVIDIA GeForce RTX 3080 with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the NVIDIA GeForce RTX 3080 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



With these two models and Finetuner, you can encode your image and text data with:

```python
data = DocumentArray([Document(content='some text to encode')])
finetuner.encode(model=clip_txt_encoder, data=data)
```
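
Once both encoders produce embeddings, retrieval quality can be measured by ranking images for each text. The sketch below computes recall@k with cosine similarity on toy vectors. It is a generic illustration rather than the CLIP Benchmark implementation; the function names and embeddings are hypothetical, and each text is assumed to match the image at the same index.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def recall_at_k(text_embs, image_embs, k):
    """Fraction of texts whose matching image (same index) is in the top-k results."""
    hits = 0
    for i, text in enumerate(text_embs):
        # Rank all image indices by similarity to this text, best first.
        ranked = sorted(
            range(len(image_embs)),
            key=lambda j: cosine(text, image_embs[j]),
            reverse=True,
        )
        if i in ranked[:k]:
            hits += 1
    return hits / len(text_embs)


# Toy 2-d embeddings; in practice these would come from the two encoders.
text_embs = [[1.0, 0.3], [0.2, 1.0], [0.9, 0.5]]
image_embs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.2]]

print(recall_at_k(text_embs, image_embs, k=1))
print(recall_at_k(text_embs, image_embs, k=2))
```

In this toy data the first text is slightly closer to the third image than to its own, so it counts as a miss at k=1 but as a hit at k=2.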

In order to use CLIP Benchmark, we must provide a PyTorch file rather than a Finetuner inference runtime. The code below is a hack to overcome this problem.

In [None]:
!unzip clip-model/clip-run.zip # as mentioned above, the artifact is saved as a zip file containing the weights and some metadata.

In [17]:

import torch
from _finetuner.models.builders import OpenCLIPVisionBuilder, OpenCLIPTextBuilder

clip_vision = OpenCLIPVisionBuilder(descriptor='ViT-B-32::openai').build()
# NOTE: the directory below is a local path; change it to match your environment.
clip_vision.load_state_dict(torch.load(f'/home/xz306/{run.name}/models/clip-vision/model.pt'))

clip_text = OpenCLIPTextBuilder(descriptor='ViT-B-32::openai').build()
clip_text.load_state_dict(torch.load(f'/home/xz306/{run.name}/models/clip-text/model.pt'))

RuntimeError: Error(s) in loading state_dict for OpenCLIPVisionModel:
	Missing key(s) in state_dict: "_model.positional_embedding", "_model.text_projection", "_model.logit_scale", "_model.visual.class_embedding", "_model.visual.positional_embedding", "_model.visual.proj", "_model.visual.conv1.weight", "_model.visual.ln_pre.weight", "_model.visual.ln_pre.bias", "_model.visual.transformer.resblocks.0.ln_1.weight", "_model.visual.transformer.resblocks.0.ln_1.bias", "_model.visual.transformer.resblocks.0.attn.in_proj_weight", "_model.visual.transformer.resblocks.0.attn.in_proj_bias", "_model.visual.transformer.resblocks.0.attn.out_proj.weight", "_model.visual.transformer.resblocks.0.attn.out_proj.bias", "_model.visual.transformer.resblocks.0.ln_2.weight", "_model.visual.transformer.resblocks.0.ln_2.bias", "_model.visual.transformer.resblocks.0.mlp.c_fc.weight", "_model.visual.transformer.resblocks.0.mlp.c_fc.bias", "_model.visual.transformer.resblocks.0.mlp.c_proj.weight", "_model.visual.transformer.resblocks.0.mlp.c_proj.bias", "_model.visual.transformer.resblocks.1.ln_1.weight", "_model.visual.transformer.resblocks.1.ln_1.bias", "_model.visual.transformer.resblocks.1.attn.in_proj_weight", "_model.visual.transformer.resblocks.1.attn.in_proj_bias", "_model.visual.transformer.resblocks.1.attn.out_proj.weight", "_model.visual.transformer.resblocks.1.attn.out_proj.bias", "_model.visual.transformer.resblocks.1.ln_2.weight", "_model.visual.transformer.resblocks.1.ln_2.bias", "_model.visual.transformer.resblocks.1.mlp.c_fc.weight", "_model.visual.transformer.resblocks.1.mlp.c_fc.bias", "_model.visual.transformer.resblocks.1.mlp.c_proj.weight", "_model.visual.transformer.resblocks.1.mlp.c_proj.bias", "_model.visual.transformer.resblocks.2.ln_1.weight", "_model.visual.transformer.resblocks.2.ln_1.bias", "_model.visual.transformer.resblocks.2.attn.in_proj_weight", "_model.visual.transformer.resblocks.2.attn.in_proj_bias", "_model.visual.transformer.resblocks.2.attn.out_proj.weight", "_model.visual.transformer.resblocks.2.attn.out_proj.bias", 
"_model.visual.transformer.resblocks.2.ln_2.weight", "_model.visual.transformer.resblocks.2.ln_2.bias", "_model.visual.transformer.resblocks.2.mlp.c_fc.weight", "_model.visual.transformer.resblocks.2.mlp.c_fc.bias", "_model.visual.transformer.resblocks.2.mlp.c_proj.weight", "_model.visual.transformer.resblocks.2.mlp.c_proj.bias", "_model.visual.transformer.resblocks.3.ln_1.weight", "_model.visual.transformer.resblocks.3.ln_1.bias", "_model.visual.transformer.resblocks.3.attn.in_proj_weight", "_model.visual.transformer.resblocks.3.attn.in_proj_bias", "_model.visual.transformer.resblocks.3.attn.out_proj.weight", "_model.visual.transformer.resblocks.3.attn.out_proj.bias", "_model.visual.transformer.resblocks.3.ln_2.weight", "_model.visual.transformer.resblocks.3.ln_2.bias", "_model.visual.transformer.resblocks.3.mlp.c_fc.weight", "_model.visual.transformer.resblocks.3.mlp.c_fc.bias", "_model.visual.transformer.resblocks.3.mlp.c_proj.weight", "_model.visual.transformer.resblocks.3.mlp.c_proj.bias", "_model.visual.transformer.resblocks.4.ln_1.weight", "_model.visual.transformer.resblocks.4.ln_1.bias", "_model.visual.transformer.resblocks.4.attn.in_proj_weight", "_model.visual.transformer.resblocks.4.attn.in_proj_bias", "_model.visual.transformer.resblocks.4.attn.out_proj.weight", "_model.visual.transformer.resblocks.4.attn.out_proj.bias", "_model.visual.transformer.resblocks.4.ln_2.weight", "_model.visual.transformer.resblocks.4.ln_2.bias", "_model.visual.transformer.resblocks.4.mlp.c_fc.weight", "_model.visual.transformer.resblocks.4.mlp.c_fc.bias", "_model.visual.transformer.resblocks.4.mlp.c_proj.weight", "_model.visual.transformer.resblocks.4.mlp.c_proj.bias", "_model.visual.transformer.resblocks.5.ln_1.weight", "_model.visual.transformer.resblocks.5.ln_1.bias", "_model.visual.transformer.resblocks.5.attn.in_proj_weight", "_model.visual.transformer.resblocks.5.attn.in_proj_bias", "_model.visual.transformer.resblocks.5.attn.out_proj.weight", 
"_model.visual.transformer.resblocks.5.attn.out_proj.bias", "_model.visual.transformer.resblocks.5.ln_2.weight", "_model.visual.transformer.resblocks.5.ln_2.bias", "_model.visual.transformer.resblocks.5.mlp.c_fc.weight", "_model.visual.transformer.resblocks.5.mlp.c_fc.bias", "_model.visual.transformer.resblocks.5.mlp.c_proj.weight", "_model.visual.transformer.resblocks.5.mlp.c_proj.bias", "_model.visual.transformer.resblocks.6.ln_1.weight", "_model.visual.transformer.resblocks.6.ln_1.bias", "_model.visual.transformer.resblocks.6.attn.in_proj_weight", "_model.visual.transformer.resblocks.6.attn.in_proj_bias", "_model.visual.transformer.resblocks.6.attn.out_proj.weight", "_model.visual.transformer.resblocks.6.attn.out_proj.bias", "_model.visual.transformer.resblocks.6.ln_2.weight", "_model.visual.transformer.resblocks.6.ln_2.bias", "_model.visual.transformer.resblocks.6.mlp.c_fc.weight", "_model.visual.transformer.resblocks.6.mlp.c_fc.bias", "_model.visual.transformer.resblocks.6.mlp.c_proj.weight", "_model.visual.transformer.resblocks.6.mlp.c_proj.bias", "_model.visual.transformer.resblocks.7.ln_1.weight", "_model.visual.transformer.resblocks.7.ln_1.bias", "_model.visual.transformer.resblocks.7.attn.in_proj_weight", "_model.visual.transformer.resblocks.7.attn.in_proj_bias", "_model.visual.transformer.resblocks.7.attn.out_proj.weight", "_model.visual.transformer.resblocks.7.attn.out_proj.bias", "_model.visual.transformer.resblocks.7.ln_2.weight", "_model.visual.transformer.resblocks.7.ln_2.bias", "_model.visual.transformer.resblocks.7.mlp.c_fc.weight", "_model.visual.transformer.resblocks.7.mlp.c_fc.bias", "_model.visual.transformer.resblocks.7.mlp.c_proj.weight", "_model.visual.transformer.resblocks.7.mlp.c_proj.bias", "_model.visual.transformer.resblocks.8.ln_1.weight", "_model.visual.transformer.resblocks.8.ln_1.bias", "_model.visual.transformer.resblocks.8.attn.in_proj_weight", "_model.visual.transformer.resblocks.8.attn.in_proj_bias", 
"_model.visual.transformer.resblocks.8.attn.out_proj.weight", "_model.visual.transformer.resblocks.8.attn.out_proj.bias", "_model.visual.transformer.resblocks.8.ln_2.weight", "_model.visual.transformer.resblocks.8.ln_2.bias", "_model.visual.transformer.resblocks.8.mlp.c_fc.weight", "_model.visual.transformer.resblocks.8.mlp.c_fc.bias", "_model.visual.transformer.resblocks.8.mlp.c_proj.weight", "_model.visual.transformer.resblocks.8.mlp.c_proj.bias", "_model.visual.transformer.resblocks.9.ln_1.weight", "_model.visual.transformer.resblocks.9.ln_1.bias", "_model.visual.transformer.resblocks.9.attn.in_proj_weight", "_model.visual.transformer.resblocks.9.attn.in_proj_bias", "_model.visual.transformer.resblocks.9.attn.out_proj.weight", "_model.visual.transformer.resblocks.9.attn.out_proj.bias", "_model.visual.transformer.resblocks.9.ln_2.weight", "_model.visual.transformer.resblocks.9.ln_2.bias", "_model.visual.transformer.resblocks.9.mlp.c_fc.weight", "_model.visual.transformer.resblocks.9.mlp.c_fc.bias", "_model.visual.transformer.resblocks.9.mlp.c_proj.weight", "_model.visual.transformer.resblocks.9.mlp.c_proj.bias", "_model.visual.transformer.resblocks.10.ln_1.weight", "_model.visual.transformer.resblocks.10.ln_1.bias", "_model.visual.transformer.resblocks.10.attn.in_proj_weight", "_model.visual.transformer.resblocks.10.attn.in_proj_bias", "_model.visual.transformer.resblocks.10.attn.out_proj.weight", "_model.visual.transformer.resblocks.10.attn.out_proj.bias", "_model.visual.transformer.resblocks.10.ln_2.weight", "_model.visual.transformer.resblocks.10.ln_2.bias", "_model.visual.transformer.resblocks.10.mlp.c_fc.weight", "_model.visual.transformer.resblocks.10.mlp.c_fc.bias", "_model.visual.transformer.resblocks.10.mlp.c_proj.weight", "_model.visual.transformer.resblocks.10.mlp.c_proj.bias", "_model.visual.transformer.resblocks.11.ln_1.weight", "_model.visual.transformer.resblocks.11.ln_1.bias", "_model.visual.transformer.resblocks.11.attn.in_proj_weight", 
"_model.visual.transformer.resblocks.11.attn.in_proj_bias", "_model.visual.transformer.resblocks.11.attn.out_proj.weight", "_model.visual.transformer.resblocks.11.attn.out_proj.bias", "_model.visual.transformer.resblocks.11.ln_2.weight", "_model.visual.transformer.resblocks.11.ln_2.bias", "_model.visual.transformer.resblocks.11.mlp.c_fc.weight", "_model.visual.transformer.resblocks.11.mlp.c_fc.bias", "_model.visual.transformer.resblocks.11.mlp.c_proj.weight", "_model.visual.transformer.resblocks.11.mlp.c_proj.bias", "_model.visual.ln_post.weight", "_model.visual.ln_post.bias", "_model.transformer.resblocks.0.ln_1.weight", "_model.transformer.resblocks.0.ln_1.bias", "_model.transformer.resblocks.0.attn.in_proj_weight", "_model.transformer.resblocks.0.attn.in_proj_bias", "_model.transformer.resblocks.0.attn.out_proj.weight", "_model.transformer.resblocks.0.attn.out_proj.bias", "_model.transformer.resblocks.0.ln_2.weight", "_model.transformer.resblocks.0.ln_2.bias", "_model.transformer.resblocks.0.mlp.c_fc.weight", "_model.transformer.resblocks.0.mlp.c_fc.bias", "_model.transformer.resblocks.0.mlp.c_proj.weight", "_model.transformer.resblocks.0.mlp.c_proj.bias", "_model.transformer.resblocks.1.ln_1.weight", "_model.transformer.resblocks.1.ln_1.bias", "_model.transformer.resblocks.1.attn.in_proj_weight", "_model.transformer.resblocks.1.attn.in_proj_bias", "_model.transformer.resblocks.1.attn.out_proj.weight", "_model.transformer.resblocks.1.attn.out_proj.bias", "_model.transformer.resblocks.1.ln_2.weight", "_model.transformer.resblocks.1.ln_2.bias", "_model.transformer.resblocks.1.mlp.c_fc.weight", "_model.transformer.resblocks.1.mlp.c_fc.bias", "_model.transformer.resblocks.1.mlp.c_proj.weight", "_model.transformer.resblocks.1.mlp.c_proj.bias", "_model.transformer.resblocks.2.ln_1.weight", "_model.transformer.resblocks.2.ln_1.bias", "_model.transformer.resblocks.2.attn.in_proj_weight", "_model.transformer.resblocks.2.attn.in_proj_bias", 
"_model.transformer.resblocks.2.attn.out_proj.weight", "_model.transformer.resblocks.2.attn.out_proj.bias", "_model.transformer.resblocks.2.ln_2.weight", "_model.transformer.resblocks.2.ln_2.bias", "_model.transformer.resblocks.2.mlp.c_fc.weight", "_model.transformer.resblocks.2.mlp.c_fc.bias", "_model.transformer.resblocks.2.mlp.c_proj.weight", "_model.transformer.resblocks.2.mlp.c_proj.bias", "_model.transformer.resblocks.3.ln_1.weight", "_model.transformer.resblocks.3.ln_1.bias", "_model.transformer.resblocks.3.attn.in_proj_weight", "_model.transformer.resblocks.3.attn.in_proj_bias", "_model.transformer.resblocks.3.attn.out_proj.weight", "_model.transformer.resblocks.3.attn.out_proj.bias", "_model.transformer.resblocks.3.ln_2.weight", "_model.transformer.resblocks.3.ln_2.bias", "_model.transformer.resblocks.3.mlp.c_fc.weight", "_model.transformer.resblocks.3.mlp.c_fc.bias", "_model.transformer.resblocks.3.mlp.c_proj.weight", "_model.transformer.resblocks.3.mlp.c_proj.bias", "_model.transformer.resblocks.4.ln_1.weight", "_model.transformer.resblocks.4.ln_1.bias", "_model.transformer.resblocks.4.attn.in_proj_weight", "_model.transformer.resblocks.4.attn.in_proj_bias", "_model.transformer.resblocks.4.attn.out_proj.weight", "_model.transformer.resblocks.4.attn.out_proj.bias", "_model.transformer.resblocks.4.ln_2.weight", "_model.transformer.resblocks.4.ln_2.bias", "_model.transformer.resblocks.4.mlp.c_fc.weight", "_model.transformer.resblocks.4.mlp.c_fc.bias", "_model.transformer.resblocks.4.mlp.c_proj.weight", "_model.transformer.resblocks.4.mlp.c_proj.bias", "_model.transformer.resblocks.5.ln_1.weight", "_model.transformer.resblocks.5.ln_1.bias", "_model.transformer.resblocks.5.attn.in_proj_weight", "_model.transformer.resblocks.5.attn.in_proj_bias", "_model.transformer.resblocks.5.attn.out_proj.weight", "_model.transformer.resblocks.5.attn.out_proj.bias", "_model.transformer.resblocks.5.ln_2.weight", "_model.transformer.resblocks.5.ln_2.bias", 
"_model.transformer.resblocks.5.mlp.c_fc.weight", "_model.transformer.resblocks.5.mlp.c_fc.bias", "_model.transformer.resblocks.5.mlp.c_proj.weight", "_model.transformer.resblocks.5.mlp.c_proj.bias", "_model.transformer.resblocks.6.ln_1.weight", "_model.transformer.resblocks.6.ln_1.bias", "_model.transformer.resblocks.6.attn.in_proj_weight", "_model.transformer.resblocks.6.attn.in_proj_bias", "_model.transformer.resblocks.6.attn.out_proj.weight", "_model.transformer.resblocks.6.attn.out_proj.bias", "_model.transformer.resblocks.6.ln_2.weight", "_model.transformer.resblocks.6.ln_2.bias", "_model.transformer.resblocks.6.mlp.c_fc.weight", "_model.transformer.resblocks.6.mlp.c_fc.bias", "_model.transformer.resblocks.6.mlp.c_proj.weight", "_model.transformer.resblocks.6.mlp.c_proj.bias", "_model.transformer.resblocks.7.ln_1.weight", "_model.transformer.resblocks.7.ln_1.bias", "_model.transformer.resblocks.7.attn.in_proj_weight", "_model.transformer.resblocks.7.attn.in_proj_bias", "_model.transformer.resblocks.7.attn.out_proj.weight", "_model.transformer.resblocks.7.attn.out_proj.bias", "_model.transformer.resblocks.7.ln_2.weight", "_model.transformer.resblocks.7.ln_2.bias", "_model.transformer.resblocks.7.mlp.c_fc.weight", "_model.transformer.resblocks.7.mlp.c_fc.bias", "_model.transformer.resblocks.7.mlp.c_proj.weight", "_model.transformer.resblocks.7.mlp.c_proj.bias", "_model.transformer.resblocks.8.ln_1.weight", "_model.transformer.resblocks.8.ln_1.bias", "_model.transformer.resblocks.8.attn.in_proj_weight", "_model.transformer.resblocks.8.attn.in_proj_bias", "_model.transformer.resblocks.8.attn.out_proj.weight", "_model.transformer.resblocks.8.attn.out_proj.bias", "_model.transformer.resblocks.8.ln_2.weight", "_model.transformer.resblocks.8.ln_2.bias", "_model.transformer.resblocks.8.mlp.c_fc.weight", "_model.transformer.resblocks.8.mlp.c_fc.bias", "_model.transformer.resblocks.8.mlp.c_proj.weight", "_model.transformer.resblocks.8.mlp.c_proj.bias", 
"_model.transformer.resblocks.9.ln_1.weight", "_model.transformer.resblocks.9.ln_1.bias", "_model.transformer.resblocks.9.attn.in_proj_weight", "_model.transformer.resblocks.9.attn.in_proj_bias", "_model.transformer.resblocks.9.attn.out_proj.weight", "_model.transformer.resblocks.9.attn.out_proj.bias", "_model.transformer.resblocks.9.ln_2.weight", "_model.transformer.resblocks.9.ln_2.bias", "_model.transformer.resblocks.9.mlp.c_fc.weight", "_model.transformer.resblocks.9.mlp.c_fc.bias", "_model.transformer.resblocks.9.mlp.c_proj.weight", "_model.transformer.resblocks.9.mlp.c_proj.bias", "_model.transformer.resblocks.10.ln_1.weight", "_model.transformer.resblocks.10.ln_1.bias", "_model.transformer.resblocks.10.attn.in_proj_weight", "_model.transformer.resblocks.10.attn.in_proj_bias", "_model.transformer.resblocks.10.attn.out_proj.weight", "_model.transformer.resblocks.10.attn.out_proj.bias", "_model.transformer.resblocks.10.ln_2.weight", "_model.transformer.resblocks.10.ln_2.bias", "_model.transformer.resblocks.10.mlp.c_fc.weight", "_model.transformer.resblocks.10.mlp.c_fc.bias", "_model.transformer.resblocks.10.mlp.c_proj.weight", "_model.transformer.resblocks.10.mlp.c_proj.bias", "_model.transformer.resblocks.11.ln_1.weight", "_model.transformer.resblocks.11.ln_1.bias", "_model.transformer.resblocks.11.attn.in_proj_weight", "_model.transformer.resblocks.11.attn.in_proj_bias", "_model.transformer.resblocks.11.attn.out_proj.weight", "_model.transformer.resblocks.11.attn.out_proj.bias", "_model.transformer.resblocks.11.ln_2.weight", "_model.transformer.resblocks.11.ln_2.bias", "_model.transformer.resblocks.11.mlp.c_fc.weight", "_model.transformer.resblocks.11.mlp.c_fc.bias", "_model.transformer.resblocks.11.mlp.c_proj.weight", "_model.transformer.resblocks.11.mlp.c_proj.bias", "_model.token_embedding.weight", "_model.ln_final.weight", "_model.ln_final.bias". 
	Unexpected key(s) in state_dict: "_projection.weight", "_model.embeddings.class_embedding", "_model.embeddings.position_ids", "_model.embeddings.patch_embedding.weight", "_model.embeddings.position_embedding.weight", "_model.pre_layrnorm.weight", "_model.pre_layrnorm.bias", "_model.encoder.layers.0.self_attn.k_proj.weight", "_model.encoder.layers.0.self_attn.k_proj.bias", "_model.encoder.layers.0.self_attn.v_proj.weight", "_model.encoder.layers.0.self_attn.v_proj.bias", "_model.encoder.layers.0.self_attn.q_proj.weight", "_model.encoder.layers.0.self_attn.q_proj.bias", "_model.encoder.layers.0.self_attn.out_proj.weight", "_model.encoder.layers.0.self_attn.out_proj.bias", "_model.encoder.layers.0.layer_norm1.weight", "_model.encoder.layers.0.layer_norm1.bias", "_model.encoder.layers.0.mlp.fc1.weight", "_model.encoder.layers.0.mlp.fc1.bias", "_model.encoder.layers.0.mlp.fc2.weight", "_model.encoder.layers.0.mlp.fc2.bias", "_model.encoder.layers.0.layer_norm2.weight", "_model.encoder.layers.0.layer_norm2.bias", "_model.encoder.layers.1.self_attn.k_proj.weight", "_model.encoder.layers.1.self_attn.k_proj.bias", "_model.encoder.layers.1.self_attn.v_proj.weight", "_model.encoder.layers.1.self_attn.v_proj.bias", "_model.encoder.layers.1.self_attn.q_proj.weight", "_model.encoder.layers.1.self_attn.q_proj.bias", "_model.encoder.layers.1.self_attn.out_proj.weight", "_model.encoder.layers.1.self_attn.out_proj.bias", "_model.encoder.layers.1.layer_norm1.weight", "_model.encoder.layers.1.layer_norm1.bias", "_model.encoder.layers.1.mlp.fc1.weight", "_model.encoder.layers.1.mlp.fc1.bias", "_model.encoder.layers.1.mlp.fc2.weight", "_model.encoder.layers.1.mlp.fc2.bias", "_model.encoder.layers.1.layer_norm2.weight", "_model.encoder.layers.1.layer_norm2.bias", "_model.encoder.layers.2.self_attn.k_proj.weight", "_model.encoder.layers.2.self_attn.k_proj.bias", "_model.encoder.layers.2.self_attn.v_proj.weight", "_model.encoder.layers.2.self_attn.v_proj.bias", 
"_model.encoder.layers.2.self_attn.q_proj.weight", "_model.encoder.layers.2.self_attn.q_proj.bias", "_model.encoder.layers.2.self_attn.out_proj.weight", "_model.encoder.layers.2.self_attn.out_proj.bias", "_model.encoder.layers.2.layer_norm1.weight", "_model.encoder.layers.2.layer_norm1.bias", "_model.encoder.layers.2.mlp.fc1.weight", "_model.encoder.layers.2.mlp.fc1.bias", "_model.encoder.layers.2.mlp.fc2.weight", "_model.encoder.layers.2.mlp.fc2.bias", "_model.encoder.layers.2.layer_norm2.weight", "_model.encoder.layers.2.layer_norm2.bias", "_model.encoder.layers.3.self_attn.k_proj.weight", "_model.encoder.layers.3.self_attn.k_proj.bias", "_model.encoder.layers.3.self_attn.v_proj.weight", "_model.encoder.layers.3.self_attn.v_proj.bias", "_model.encoder.layers.3.self_attn.q_proj.weight", "_model.encoder.layers.3.self_attn.q_proj.bias", "_model.encoder.layers.3.self_attn.out_proj.weight", "_model.encoder.layers.3.self_attn.out_proj.bias", "_model.encoder.layers.3.layer_norm1.weight", "_model.encoder.layers.3.layer_norm1.bias", "_model.encoder.layers.3.mlp.fc1.weight", "_model.encoder.layers.3.mlp.fc1.bias", "_model.encoder.layers.3.mlp.fc2.weight", "_model.encoder.layers.3.mlp.fc2.bias", "_model.encoder.layers.3.layer_norm2.weight", "_model.encoder.layers.3.layer_norm2.bias", "_model.encoder.layers.4.self_attn.k_proj.weight", "_model.encoder.layers.4.self_attn.k_proj.bias", "_model.encoder.layers.4.self_attn.v_proj.weight", "_model.encoder.layers.4.self_attn.v_proj.bias", "_model.encoder.layers.4.self_attn.q_proj.weight", "_model.encoder.layers.4.self_attn.q_proj.bias", "_model.encoder.layers.4.self_attn.out_proj.weight", "_model.encoder.layers.4.self_attn.out_proj.bias", "_model.encoder.layers.4.layer_norm1.weight", "_model.encoder.layers.4.layer_norm1.bias", "_model.encoder.layers.4.mlp.fc1.weight", "_model.encoder.layers.4.mlp.fc1.bias", "_model.encoder.layers.4.mlp.fc2.weight", "_model.encoder.layers.4.mlp.fc2.bias", "_model.encoder.layers.4.layer_norm2.weight", 
"_model.encoder.layers.4.layer_norm2.bias", "_model.encoder.layers.5.self_attn.k_proj.weight", "_model.encoder.layers.5.self_attn.k_proj.bias", "_model.encoder.layers.5.self_attn.v_proj.weight", "_model.encoder.layers.5.self_attn.v_proj.bias", "_model.encoder.layers.5.self_attn.q_proj.weight", "_model.encoder.layers.5.self_attn.q_proj.bias", "_model.encoder.layers.5.self_attn.out_proj.weight", "_model.encoder.layers.5.self_attn.out_proj.bias", "_model.encoder.layers.5.layer_norm1.weight", "_model.encoder.layers.5.layer_norm1.bias", "_model.encoder.layers.5.mlp.fc1.weight", "_model.encoder.layers.5.mlp.fc1.bias", "_model.encoder.layers.5.mlp.fc2.weight", "_model.encoder.layers.5.mlp.fc2.bias", "_model.encoder.layers.5.layer_norm2.weight", "_model.encoder.layers.5.layer_norm2.bias", "_model.encoder.layers.6.self_attn.k_proj.weight", "_model.encoder.layers.6.self_attn.k_proj.bias", "_model.encoder.layers.6.self_attn.v_proj.weight", "_model.encoder.layers.6.self_attn.v_proj.bias", "_model.encoder.layers.6.self_attn.q_proj.weight", "_model.encoder.layers.6.self_attn.q_proj.bias", "_model.encoder.layers.6.self_attn.out_proj.weight", "_model.encoder.layers.6.self_attn.out_proj.bias", "_model.encoder.layers.6.layer_norm1.weight", "_model.encoder.layers.6.layer_norm1.bias", "_model.encoder.layers.6.mlp.fc1.weight", "_model.encoder.layers.6.mlp.fc1.bias", "_model.encoder.layers.6.mlp.fc2.weight", "_model.encoder.layers.6.mlp.fc2.bias", "_model.encoder.layers.6.layer_norm2.weight", "_model.encoder.layers.6.layer_norm2.bias", "_model.encoder.layers.7.self_attn.k_proj.weight", "_model.encoder.layers.7.self_attn.k_proj.bias", "_model.encoder.layers.7.self_attn.v_proj.weight", "_model.encoder.layers.7.self_attn.v_proj.bias", "_model.encoder.layers.7.self_attn.q_proj.weight", "_model.encoder.layers.7.self_attn.q_proj.bias", "_model.encoder.layers.7.self_attn.out_proj.weight", "_model.encoder.layers.7.self_attn.out_proj.bias", "_model.encoder.layers.7.layer_norm1.weight", 
"_model.encoder.layers.7.layer_norm1.bias", "_model.encoder.layers.7.mlp.fc1.weight", "_model.encoder.layers.7.mlp.fc1.bias", "_model.encoder.layers.7.mlp.fc2.weight", "_model.encoder.layers.7.mlp.fc2.bias", "_model.encoder.layers.7.layer_norm2.weight", "_model.encoder.layers.7.layer_norm2.bias", "_model.encoder.layers.8.self_attn.k_proj.weight", "_model.encoder.layers.8.self_attn.k_proj.bias", "_model.encoder.layers.8.self_attn.v_proj.weight", "_model.encoder.layers.8.self_attn.v_proj.bias", "_model.encoder.layers.8.self_attn.q_proj.weight", "_model.encoder.layers.8.self_attn.q_proj.bias", "_model.encoder.layers.8.self_attn.out_proj.weight", "_model.encoder.layers.8.self_attn.out_proj.bias", "_model.encoder.layers.8.layer_norm1.weight", "_model.encoder.layers.8.layer_norm1.bias", "_model.encoder.layers.8.mlp.fc1.weight", "_model.encoder.layers.8.mlp.fc1.bias", "_model.encoder.layers.8.mlp.fc2.weight", "_model.encoder.layers.8.mlp.fc2.bias", "_model.encoder.layers.8.layer_norm2.weight", "_model.encoder.layers.8.layer_norm2.bias", "_model.encoder.layers.9.self_attn.k_proj.weight", "_model.encoder.layers.9.self_attn.k_proj.bias", "_model.encoder.layers.9.self_attn.v_proj.weight", "_model.encoder.layers.9.self_attn.v_proj.bias", "_model.encoder.layers.9.self_attn.q_proj.weight", "_model.encoder.layers.9.self_attn.q_proj.bias", "_model.encoder.layers.9.self_attn.out_proj.weight", "_model.encoder.layers.9.self_attn.out_proj.bias", "_model.encoder.layers.9.layer_norm1.weight", "_model.encoder.layers.9.layer_norm1.bias", "_model.encoder.layers.9.mlp.fc1.weight", "_model.encoder.layers.9.mlp.fc1.bias", "_model.encoder.layers.9.mlp.fc2.weight", "_model.encoder.layers.9.mlp.fc2.bias", "_model.encoder.layers.9.layer_norm2.weight", "_model.encoder.layers.9.layer_norm2.bias", "_model.encoder.layers.10.self_attn.k_proj.weight", "_model.encoder.layers.10.self_attn.k_proj.bias", "_model.encoder.layers.10.self_attn.v_proj.weight", "_model.encoder.layers.10.self_attn.v_proj.bias", 
"_model.encoder.layers.10.self_attn.q_proj.weight", "_model.encoder.layers.10.self_attn.q_proj.bias", "_model.encoder.layers.10.self_attn.out_proj.weight", "_model.encoder.layers.10.self_attn.out_proj.bias", "_model.encoder.layers.10.layer_norm1.weight", "_model.encoder.layers.10.layer_norm1.bias", "_model.encoder.layers.10.mlp.fc1.weight", "_model.encoder.layers.10.mlp.fc1.bias", "_model.encoder.layers.10.mlp.fc2.weight", "_model.encoder.layers.10.mlp.fc2.bias", "_model.encoder.layers.10.layer_norm2.weight", "_model.encoder.layers.10.layer_norm2.bias", "_model.encoder.layers.11.self_attn.k_proj.weight", "_model.encoder.layers.11.self_attn.k_proj.bias", "_model.encoder.layers.11.self_attn.v_proj.weight", "_model.encoder.layers.11.self_attn.v_proj.bias", "_model.encoder.layers.11.self_attn.q_proj.weight", "_model.encoder.layers.11.self_attn.q_proj.bias", "_model.encoder.layers.11.self_attn.out_proj.weight", "_model.encoder.layers.11.self_attn.out_proj.bias", "_model.encoder.layers.11.layer_norm1.weight", "_model.encoder.layers.11.layer_norm1.bias", "_model.encoder.layers.11.mlp.fc1.weight", "_model.encoder.layers.11.mlp.fc1.bias", "_model.encoder.layers.11.mlp.fc2.weight", "_model.encoder.layers.11.mlp.fc2.bias", "_model.encoder.layers.11.layer_norm2.weight", "_model.encoder.layers.11.layer_norm2.bias", "_model.post_layernorm.weight", "_model.post_layernorm.bias". 

Now we can run CLIP Benchmark on the two encoders:

In [None]:
"""Console script for clip_benchmark.
Code copied from CLIP Benchmark with minor adjustments to run in Colab.
"""
import torch
import open_clip
from pprint import pprint

from clip_benchmark.datasets.builder import build_dataset, get_dataset_collate_fn
from clip_benchmark.metrics import zeroshot_retrieval



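device = "cuda" if torch.cuda.is_available() else "cpu"
# clip_vision and clip_text are the two encoders loaded earlier in this notebook
image_encoder = clip_vision.to(device)
text_encoder = clip_text.to(device)
# only the image preprocessing transform is used here; the open_clip model itself is discarded
_, _, transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')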
dataset = build_dataset(
    dataset_name='flickr8k',
    root='root',
    transform=transform,
    split='test',
    annotation_file=None,
    download=True,
)
collate_fn = get_dataset_collate_fn('flickr8k')

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64,
    shuffle=False, num_workers=4,
    collate_fn=collate_fn
)

metrics = zeroshot_retrieval.evaluate(
    image_encoder,
    text_encoder,
    dataloader,
    open_clip.tokenizer.tokenize,
    recall_k_list=[5],
    device=device,
    amp=True
)
dump = {
    "dataset": 'flickr8k',
    "model": 'ViT-B-32',
    "pretrained": 'openai',
    "task": 'finetuned',
    "metrics": metrics
}
pprint(dump)

16it [00:16,  1.01s/it]


{'dataset': 'flickr8k',
 'metrics': {'image_retrieval_recall@5': 0.8537999987602234,
             'text_retrieval_recall@5': 0.9100000262260437},
 'model': 'ViT-B-32',
 'pretrained': 'openai',
 'task': 'finetuned'}


## Results: Pre-Trained Zero-Shot vs Fine-Tuned

The CLIP Benchmark maintainers have published benchmarking results for a wide variety of models and configurations in [this CSV file](https://github.com/LAION-AI/CLIP_benchmark/blob/main/benchmark/benchmark.csv).

For simplicity, we show the comparison below:

+ `image_retrieval_recall@5`: use each text query to retrieve the 5 most similar images, and count the query as a hit if a matching image is among them.
+ `text_retrieval_recall@5`: use each image as the query to retrieve the 5 most similar texts, and count the query as a hit if a matching text is among them.
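
The two metrics above can be sketched in a few lines of NumPy. This is a minimal illustration, not the CLIP Benchmark implementation: the function name `recall_at_k`, the toy similarity scores, and the caption-to-image assignment are all assumptions made for the example. It treats a query as a hit when at least one matching candidate appears among its `k` highest-scoring candidates.

```python
import numpy as np


def recall_at_k(scores, positives, k=5):
    """Fraction of queries with at least one match among the top-k candidates.

    scores:    (num_queries, num_candidates) similarity matrix.
    positives: boolean matrix of the same shape, True where the candidate
               matches the query.
    """
    # Indices of the k highest-scoring candidates for every query.
    topk = np.argsort(-scores, axis=1)[:, :k]
    # A query is a hit if any of its top-k candidates is a match.
    hits = np.take_along_axis(positives, topk, axis=1).any(axis=1)
    return float(hits.mean())


# Toy data (hypothetical): 4 captions (rows) scored against 3 images (columns).
scores = np.array([
    [0.9, 0.2, 0.1],  # caption 0, describes image 0
    [0.3, 0.6, 0.1],  # caption 1, describes image 0 (ranked wrongly)
    [0.2, 0.8, 0.3],  # caption 2, describes image 1
    [0.1, 0.4, 0.7],  # caption 3, describes image 2
])
positives = np.zeros_like(scores, dtype=bool)
positives[[0, 1, 2, 3], [0, 0, 1, 2]] = True

# Image retrieval: captions are the queries, images are the candidates.
image_recall = recall_at_k(scores, positives, k=1)
# Text retrieval: transpose so that images are the queries.
text_recall = recall_at_k(scores.T, positives.T, k=1)
print(image_recall, text_recall)
```

With `k=1`, three of the four captions rank their own image first, and every image ranks one of its own captions first. The notebook itself uses `k=5` through `recall_k_list=[5]`.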


| model           | dataset  | image recall@5 (zero-shot) | text recall@5 (zero-shot) | image recall@5 (fine-tuned) | text recall@5 (fine-tuned) |
|-----------------|----------|----------------------------|---------------------------|-----------------------------|----------------------------|
| ViT-B-32#openai | flickr8k | 0.5319737792015076         | 0.6991719007492065        | 0.8537999987602234          | 0.9100000262260437         |

We also ran more extensive experiments with three models on three datasets. These are the results:


| model                            | dataset       | image recall@5 (zero-shot) | text recall@5 (zero-shot) | image recall@5 (fine-tuned) | text recall@5 (fine-tuned) |
|----------------------------------|---------------|----------------------------|---------------------------|-----------------------------|----------------------------|
| ViT-B-32#openai                  | flickr8k      | 0.5319737792015076         | 0.6991719007492065        | 0.8651999831199646          | 0.9079999923706055         |
| ViT-B-16-plus-240                | flickr8k      | 0.6441478133201599         | 0.7916203141212463        | 0.8784000277519226          | 0.9200000166893005         |
| ViT-B-32-quickgelu#laion400m_e32 | flickr8k      | 0.5787171125411987         | 0.7392163872718811        | 0.849399983882904           | 0.9020000100135803         |
| ViT-B-32#openai                  | flickr30k     | 0.8338000178337097         | 0.9490000009536743        | 0.9016000032424927          | 0.9480000138282776         |
| ViT-B-16-plus-240                | flickr30k     | 0.8894000053405762         | 0.9710000157356262        | 0.9169999957084656          | 0.9710000157356262         |
| ViT-B-32-quickgelu#laion400m_e32 | flickr30k     | 0.8546000123023987         | 0.9409999847412109        | 0.8715999722480774          | 0.9290000200271606         |
| ViT-B-32#openai                  | coco captions | 0.5584565997123718         | 0.748199999332428         | 0.6546581387519836          | 0.7454000115394592         |
| ViT-B-16-plus-240                | coco captions | 0.6620951890945435         | 0.8101999759674072        | 0.7120751738548279          | 0.8136000037193298         |
| ViT-B-32-quickgelu#laion400m_e32 | coco captions | 0.6084766387939453         | 0.7675999999046326        | 0.6713714599609375          | 0.7635999917984009         |

Our Finetuner hyper-parameters were: `learning_rate: 1e-6`, `epochs: 5`, `optimizer: Adam`.
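
To make the nine-row results table above easier to read, the sketch below computes the absolute change in recall@5 from zero-shot to fine-tuned for each row. The numbers are copied from that table and rounded to four decimals; the variable names (`results`, `gains`) are only illustrative.

```python
# (model, dataset) -> (image zero-shot, text zero-shot, image fine-tuned, text fine-tuned)
# Values copied from the results table, rounded to four decimals.
results = {
    ("ViT-B-32#openai", "flickr8k"): (0.5320, 0.6992, 0.8652, 0.9080),
    ("ViT-B-16-plus-240", "flickr8k"): (0.6441, 0.7916, 0.8784, 0.9200),
    ("ViT-B-32-quickgelu#laion400m_e32", "flickr8k"): (0.5787, 0.7392, 0.8494, 0.9020),
    ("ViT-B-32#openai", "flickr30k"): (0.8338, 0.9490, 0.9016, 0.9480),
    ("ViT-B-16-plus-240", "flickr30k"): (0.8894, 0.9710, 0.9170, 0.9710),
    ("ViT-B-32-quickgelu#laion400m_e32", "flickr30k"): (0.8546, 0.9410, 0.8716, 0.9290),
    ("ViT-B-32#openai", "coco captions"): (0.5585, 0.7482, 0.6547, 0.7454),
    ("ViT-B-16-plus-240", "coco captions"): (0.6621, 0.8102, 0.7121, 0.8136),
    ("ViT-B-32-quickgelu#laion400m_e32", "coco captions"): (0.6085, 0.7676, 0.6714, 0.7636),
}

# Absolute gain = fine-tuned recall minus zero-shot recall.
gains = {
    key: (img_ft - img_zs, txt_ft - txt_zs)
    for key, (img_zs, txt_zs, img_ft, txt_ft) in results.items()
}

for (model, dataset), (img_gain, txt_gain) in gains.items():
    print(f"{model:34s} {dataset:14s} image {img_gain:+.4f}  text {txt_gain:+.4f}")
```

In these numbers, image recall@5 increases in every row, with the largest gains on flickr8k. Text recall@5 also rises clearly on flickr8k, while on flickr30k and coco captions it moves by at most about 0.012 in either direction.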