<a href="https://colab.research.google.com/github/JSJeong-me/CLIP-Zero-Shot-Classification/blob/main/10class_CLIP_Zero_Shot_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to use CLIP Zero-Shot on your own classification dataset

This notebook provides an example of how to benchmark CLIP's zero-shot classification performance on your own classification dataset.

[CLIP](https://openai.com/blog/clip/) is a zero-shot image classifier released by OpenAI that was trained on 400 million text/image pairs collected from the web. CLIP uses what it learned from those pairs to make predictions over a flexible set of candidate classification categories.

CLIP is zero-shot, which means **no training is required**.
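To make "no training" concrete: CLIP embeds the image and every candidate caption into the same vector space, then ranks the captions by cosine similarity to the image. The sketch below illustrates only that scoring step, in plain Python on made-up 3-dimensional embeddings. The vectors, the `zero_shot_probs` helper, and the scale factor of 100 are illustrative assumptions, not values produced by the real model.

```python
import math

def normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def zero_shot_probs(image_emb, text_embs, scale=100.0):
    """Softmax over scaled cosine similarities between one image and N captions."""
    image_emb = normalize(image_emb)
    logits = []
    for emb in text_embs:
        emb = normalize(emb)
        logits.append(scale * sum(a * b for a, b in zip(image_emb, emb)))
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(logit - top) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy embeddings (hypothetical values, for illustration only).
captions = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
text_embs = [[1.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.1, 1.0]]
image_emb = [0.9, 0.2, 0.1]  # points mostly along the first caption's direction

probs = zero_shot_probs(image_emb, text_embs)
best_caption = captions[probs.index(max(probs))]
print(best_caption)
```

Changing the set of categories only means changing the list of captions; nothing is retrained.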

Try it out on your own task here!

Be sure to experiment with various text prompts to unlock the richness of CLIP's pretraining procedure.
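One low-effort way to experiment is to generate the caption for every class from a template string and swap templates between runs. This is a minimal sketch: the `build_prompts` helper and the second and third templates are hypothetical examples, while the first template mirrors the wording of the auto-generated captions used in this notebook.

```python
def build_prompts(class_names, template):
    """Fill one caption per class, keeping the order of class_names."""
    return [template.format(label=name) for name in class_names]

templates = [
    "An example picture from the vit01 dataset depicting a {label}",
    "a photo of a {label}",
    "a hand-drawn sketch of a {label}",
]
example_classes = ["apple", "bench", "bicycle"]

for template in templates:
    print(build_prompts(example_classes, template))
```

Because the captions are built in the same order as the class list, the index of the best-scoring caption is also the index of the predicted class.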


![Roboflow Wordmark](https://i.imgur.com/dcLNMhV.png)


# Download and Install CLIP Dependencies

In [1]:
# install some dependencies; CLIP was released in PyTorch
import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

import numpy as np
import torch

print("Torch version:", torch.__version__)

CUDA version: 11.1
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu110
  Downloading https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp37-cp37m-linux_x86_64.whl (1156.8 MB)

In [2]:
#clone the CLIP repository
!git clone https://github.com/openai/CLIP.git
%cd CLIP

Cloning into 'CLIP'...
remote: Enumerating objects: 168, done.
remote: Counting objects: 100% (77/77), done.
remote: Compressing objects: 100% (43/43), done.
remote: Total 168 (delta 36), reused 53 (delta 29), pack-reused 91
Receiving objects: 100% (168/168), 8.92 MiB | 22.60 MiB/s, done.
Resolving deltas: 100% (76/76), done.
/content/CLIP


# Download Classification Data or Object Detection Data

We will download the [public flowers classification dataset](https://public.roboflow.com/classification/flowers_classification) from Roboflow. The data comes as folders split into train/valid/test, with a separate folder for each class label.
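Assuming the export follows that layout (a split folder containing one sub-folder per class, with `.jpg` images inside), the class list and image counts can be read straight off the directory tree. The `count_images` helper is a hypothetical name, and the demo builds a throwaway tree rather than touching real data.

```python
import tempfile
from pathlib import Path

def count_images(split_dir, pattern="*.jpg"):
    """Map each class folder under split_dir to its number of images."""
    split_dir = Path(split_dir)
    return {
        class_dir.name: len(list(class_dir.glob(pattern)))
        for class_dir in sorted(split_dir.iterdir())
        if class_dir.is_dir()  # skips plain files such as _tokenization.txt
    }

# Demo on a temporary directory tree that mimics the export layout.
with tempfile.TemporaryDirectory() as root:
    train = Path(root) / "train"
    for label, n_images in [("daisy", 2), ("dandelion", 3)]:
        (train / label).mkdir(parents=True)
        for i in range(n_images):
            (train / label / f"img_{i}.jpg").touch()
    (train / "_tokenization.txt").touch()
    counts = count_images(train)

print(counts)  # {'daisy': 2, 'dandelion': 3}
```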

You can easily download your own dataset from Roboflow in this format, too.

Roboflow can also convert an object detection dataset into CLIP text prompts, if you want to try that out.


To get your data into Roboflow, follow the [Getting Started Guide](https://blog.roboflow.ai/getting-started-with-roboflow/).

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd ./CLIP

[Errno 2] No such file or directory: './CLIP'
/content/CLIP


In [7]:
!pwd

/content/CLIP


In [None]:
!unzip ./vit01.v4i.clip.zip

In [6]:
!cp /content/CLIP/train/_tokenization.txt ./_token.org

In [8]:
import os
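# the classes and images we want to test are stored in per-class folders under ./train/
# note: os.listdir returns names in arbitrary order, which may differ from the order of the captions file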
class_names = os.listdir('./train/')
class_names.remove('_tokenization.txt')
class_names

['car',
 'cat',
 'dog',
 'bicycle',
 'four',
 'bench',
 'person',
 'fish',
 'apple',
 'cup']

In [9]:
#we auto generate some example tokenizations in Roboflow but you should edit this file to try out your own prompts
#CLIP gets a lot better with the right prompting!
#be sure the tokenizations are in the same order as your class_names above!
%cat ./train/_tokenization.txt

An example picture from the vit01 dataset depicting a apple
An example picture from the vit01 dataset depicting a bench
An example picture from the vit01 dataset depicting a bicycle
An example picture from the vit01 dataset depicting a car
An example picture from the vit01 dataset depicting a cat
An example picture from the vit01 dataset depicting a cup
An example picture from the vit01 dataset depicting a dog
An example picture from the vit01 dataset depicting a fish
An example picture from the vit01 dataset depicting a four
An example picture from the vit01 dataset depicting a person

Edit your prompts as you see fit here (the `%%writefile` magic has to be the first line of its cell):

In [None]:
%%writefile ./train/_tokenization.txt
An example picture from the flowers dataset depicting a daisy
An example picture from the flowers dataset depicting a dandelion

In [11]:
candidate_captions = []
with open('./train/_tokenization.txt') as f:
    candidate_captions = f.read().splitlines()

In [12]:
candidate_captions

['An example picture from the vit01 dataset depicting a apple',
 'An example picture from the vit01 dataset depicting a bench',
 'An example picture from the vit01 dataset depicting a bicycle',
 'An example picture from the vit01 dataset depicting a car',
 'An example picture from the vit01 dataset depicting a cat',
 'An example picture from the vit01 dataset depicting a cup',
 'An example picture from the vit01 dataset depicting a dog',
 'An example picture from the vit01 dataset depicting a fish',
 'An example picture from the vit01 dataset depicting a four',
 'An example picture from the vit01 dataset depicting a person']
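Note that these captions are in alphabetical order, whereas the `class_names` list printed earlier is in whatever order `os.listdir` returned. A prediction is an index into the caption list, so the two lists have to line up before that index is mapped back to a class name. With the lists as printed, index 3 is the "car" caption but `class_names[3]` is `'bicycle'`. The following is a minimal sketch of one way to enforce the alignment, assuming each caption ends with its class name (true for the auto-generated captions here); `align_class_names` is a hypothetical helper, not part of CLIP or Roboflow.

```python
def align_class_names(class_names, captions):
    """Return class names reordered so that result[i] is the class named by captions[i]."""
    if len(class_names) != len(captions):
        raise ValueError("number of captions and classes differ")
    aligned = []
    for caption in captions:
        matches = [name for name in class_names if caption.endswith(" " + name)]
        if len(matches) != 1:
            raise ValueError(f"no unique class for caption: {caption!r}")
        aligned.append(matches[0])
    return aligned

# Class names in the order printed earlier in this notebook.
listed_names = ['car', 'cat', 'dog', 'bicycle', 'four',
                'bench', 'person', 'fish', 'apple', 'cup']
# Captions in the order of the captions file (alphabetical).
captions = [
    f"An example picture from the vit01 dataset depicting a {name}"
    for name in sorted(listed_names)
]

aligned_names = align_class_names(listed_names, captions)
print(aligned_names)
```

For this dataset, `sorted(os.listdir('./train/'))` after removing `_tokenization.txt` would give the same order, but checking against the captions also catches a caption file that was edited out of order.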

In [None]:
%cd ./CLIP

/content/CLIP


# Run CLIP inference on your classification dataset
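The evaluation in this section reduces to a small amount of bookkeeping per image: take the index of the highest-probability caption, map it to a class name, and compare that name with the folder the image came from. The sketch below isolates that bookkeeping on hypothetical probability rows so it can be checked without loading the model; `per_class_accuracy` and the sample values are illustrative only.

```python
from collections import defaultdict

def per_class_accuracy(samples, class_names):
    """samples: iterable of (true_label, probs), where probs[i] scores class_names[i]."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for true_label, probs in samples:
        best_index = max(range(len(probs)), key=probs.__getitem__)
        pred = class_names[best_index]
        totals[true_label] += 1
        hits[true_label] += int(pred == true_label)
    return {label: hits[label] / totals[label] for label in totals}

# Hypothetical scores for three images (not real CLIP output).
demo_classes = ["apple", "bench", "car"]
samples = [
    ("apple", [0.8, 0.1, 0.1]),  # predicted apple: correct
    ("apple", [0.2, 0.7, 0.1]),  # predicted bench: wrong
    ("car", [0.1, 0.1, 0.8]),    # predicted car: correct
]

accuracy = per_class_accuracy(samples, demo_classes)
print(accuracy)
```

The result is only meaningful when `probs[i]` and `class_names[i]` refer to the same class, which is why the caption order matters.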

In [None]:
!ls -l

total 12
drwxr-xr-x 8 root root 4096 Oct 18 13:24 CLIP
-rw-r--r-- 1 root root  402 Oct 18 13:27 README.roboflow.txt
drwxr-xr-x 1 root root 4096 Oct  8 13:45 sample_data


In [25]:
import torch
import clip
from PIL import Image
import glob

def argmax(iterable):
    return max(enumerate(iterable), key=lambda x: x[1])[0]

device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)

correct = []

# Define the target classifications. Experiment with these strings as you see fit,
# but make sure they are in the same order as class_names above: the predicted
# caption index is looked up in class_names.
text = clip.tokenize(candidate_captions).to(device)

for cls in class_names:
    class_correct = []
    test_imgs = glob.glob('./train/' + cls + '/*.jpg')
    for img in test_imgs:
        print(img)
        image = transform(Image.open(img)).unsqueeze(0).to(device)
        with torch.no_grad():
            # a single forward pass scores the image against every caption
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

            pred = class_names[argmax(list(probs)[0])]
            print(pred)
            if pred == cls:
                correct.append(1)
                class_correct.append(1)
            else:
                correct.append(0)
                class_correct.append(0)

    print('accuracy on class ' + cls + ' is :' + str(sum(class_correct)/len(class_correct)))
    print('\n\n\n')
print('accuracy on all is : ' + str(sum(correct)/len(correct)))

./train/car/car_d_1_png.rf.a47342e8fabfe2fa75ee0c6c79dcdabc.jpg
bicycle
./train/car/car_p_4_png.rf.0bba57ad0314b8602a5d7660eb15e9fd.jpg
bicycle
./train/car/car_s_2_png.rf.3aec4e485bf07c3d8de94d3b5db9d9bb.jpg
bicycle
./train/car/car_d_1_png.rf.1f28515e9ebf30faf495dd636456d7e4.jpg
bicycle
./train/car/car_p_3_png.rf.5f533d47209d018def821c476ac080bb.jpg
bicycle
./train/car/car_d_3_PNG.rf.aa2a27f8e1fa503bb747da0c59226ca9.jpg
bicycle
./train/car/car_s_4_PNG.rf.58cce95838bf3f73190b4da842eed1a3.jpg
bicycle
./train/car/car_d_4_png.rf.fa3ef0f02517130ee5df7e609e6675f0.jpg
bicycle
./train/car/car_d_1_png.rf.e828243ce86480efb370959a7644c773.jpg
bicycle
./train/car/car_s_2_png.rf.123db7a5bd52f8b56737f6d72a243428.jpg
bicycle
./train/car/car_s_4_PNG.rf.4d1dae36097863d48bf161e0321ce43c.jpg
bicycle
./train/car/car_d_2_png.rf.f3dd14427f76a8320c0f819418775fa4.jpg
bicycle
./train/car/car_d_4_png.rf.2029921d58f20ba93d18a62c9dc261e2.jpg
bicycle
./train/car/car_p_2_PNG.rf.d1fe44a78b166f03a5e61b1283c3685d.jpg
