In [5]:
!git clone https://github.com/Godofnothing/CLIP_experimental
!git clone https://github.com/openai/CLIP

!pip install -q pytorch-lightning
!pip install -q ftfy regex
!wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

Cloning into 'CLIP_experimental'...
remote: Enumerating objects: 177, done.[K
remote: Counting objects: 100% (177/177), done.[K
remote: Compressing objects: 100% (106/106), done.[K
remote: Total 177 (delta 91), reused 136 (delta 52), pack-reused 0[K
Receiving objects: 100% (177/177), 36.69 KiB | 7.34 MiB/s, done.
Resolving deltas: 100% (91/91), done.
Cloning into 'CLIP'...
remote: Enumerating objects: 90, done.[K
remote: Total 90 (delta 0), reused 0 (delta 0), pack-reused 90[K
Unpacking objects: 100% (90/90), done.
[K     |████████████████████████████████| 808kB 9.6MB/s 
[K     |████████████████████████████████| 645kB 23.2MB/s 
[K     |████████████████████████████████| 112kB 38.5MB/s 
[K     |████████████████████████████████| 829kB 36.8MB/s 
[K     |████████████████████████████████| 276kB 35.4MB/s 
[K     |████████████████████████████████| 1.3MB 38.1MB/s 
[K     |████████████████████████████████| 296kB 34.7MB/s 
[K     |████████████████████████████████| 143kB 39.7MB/s 
[

In [6]:
!pip install -q kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!ls ~/.kaggle
!chmod 600 /root/.kaggle/kaggle.json

kaggle.json


In [7]:
!kaggle datasets download moltean/fruits
!unzip -q fruits.zip

Downloading fruits.zip to /content
 99% 753M/760M [00:30<00:00, 34.8MB/s]
100% 760M/760M [00:30<00:00, 26.4MB/s]


In [8]:
dataset_root = 'fruits-360'

In [73]:
import sys

import torch
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import torchvision.transforms as T
from torch.utils.data import DataLoader
from PIL import Image

from CLIP_experimental.src import CLIP_Lite, CLIP_Pro
from CLIP_experimental.src import TextTransformer
from CLIP_experimental.src import SimpleTokenizer
from CLIP_experimental.src import CLIPDataset
from CLIP_experimental.src import ClassificationVisualizer

from CLIP import clip

In [76]:
def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

In [77]:
templates = [
    'a low resolution photo of the {}, a type of fruit.',
    'a photo of {}, a type of fruit.',
    'a cropped photo of the {}, a type of fruit.',
    'a photo of a clean {}, a type of fruit.',
    'a close-up photo of a {}, a type of fruit.',
    'a photo of the nice {}, a type of fruit.',
    'a photo of the small {}, a type of fruit.',
    'a photo of the large {}, a type of fruit.',
    'itap of a {}, a type of fruit.'
    ]

In [78]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, image_transform = clip.load("ViT-B/32", jit=False)
model = model.to(dtype=torch.float32)

input_resolution = 224
context_length = model.context_length
vocab_size = model.vocab_size

tokenizer = SimpleTokenizer()
text_transformers = [TextTransformer(tokenizer, [template], context_length) for template in templates]

100%|███████████████████████████████████████| 354M/354M [00:03<00:00, 99.3MiB/s]
  "Argument interpolation should be of type InterpolationMode instead of int. "


In [79]:
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)

img_transform = T.Compose([
    T.Resize(input_resolution, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(input_resolution),
    T.ToTensor(),
    T.Normalize(MEAN, STD)                              
])

dataloaders = [DataLoader(CLIPDataset(
    f'{dataset_root}/Training', 
    image_transform=img_transform, 
    prompt_transform=text_transformer,
    return_indices=False), batch_size=32, shuffle=True) for text_transformer in text_transformers]

In [None]:
with torch.no_grad():
    tops1 = []
    tops5 = []
    for template_id, loader in enumerate(dataloaders):
        top1, top5, n = 0., 0., 0.
        for i, (images, text) in enumerate(tqdm(loader)):
            images = images.to(device)
            text = text[:,0,:]
            text = text.to(device)
            target = torch.arange(len(images), device=device)

            image_features = model.visual(images)
            text_features = model.encode_text(text)

            # normalize features
            image_features = image_features / (image_features.norm(dim=-1, keepdim=True) + 1e-6)
            text_features = text_features / (text_features.norm(dim=-1, keepdim=True) + 1e-6)

            logits = 100. * image_features @ text_features.T

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

        tops1.append((top1 / n) * 100)
        tops5.append((top5 / n) * 100)

    print(f'Template: {template_id}')
    print(f"Top-1 accuracy: {tops1[-1]:.2f}")
    print(f"Top-5 accuracy: {tops5[-1]:.2f}")

HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2116.0), HTML(value='')))


Template: 8
Top-1 accuracy: 37.44
Top-5 accuracy: 76.75


In [None]:
print(tops1)
print(tops5)

[38.28517402351829, 37.95869526679667, 39.40790639957454, 38.539266087573125, 40.01359097086805, 39.33404242746558, 40.21302369556225, 39.01495006795486, 37.4431247414761]
[77.52319328724222, 78.10080954913431, 78.33865153932518, 79.820362819831, 79.80411274596703, 78.3504697748626, 79.95774980795368, 78.22785558116173, 76.75057613898245]


In [None]:
np.arange(1, len(sorted_templates)+1)

array([1, 2, 3, 4, 5, 6, 7, 8, 9])

In [None]:
import pandas as pd
sorted_templates = [(templates[i], tops1[i], tops5[i]) for i in np.argsort(tops1)[::-1]]
df_templates = pd.DataFrame(sorted_templates, 
                            index=np.arange(1, len(sorted_templates)+1),
                            columns=['Template', 'Top-1 Accuracy', 'Top-5 Accuracy'])
df_templates

Unnamed: 0,Template,Top-1 Accuracy,Top-5 Accuracy
1,"a photo of the small {}, a type of fruit.",40.213024,79.95775
2,"a close-up photo of a {}, a type of fruit.",40.013591,79.804113
3,"a cropped photo of the {}, a type of fruit.",39.407906,78.338652
4,"a photo of the nice {}, a type of fruit.",39.334042,78.35047
5,"a photo of the large {}, a type of fruit.",39.01495,78.227856
6,"a photo of a clean {}, a type of fruit.",38.539266,79.820363
7,"a low resolution photo of the {}, a type of fr...",38.285174,77.523193
8,"a photo of {}, a type of fruit.",37.958695,78.10081
9,"itap of a {}, a type of fruit.",37.443125,76.750576
