<a href="https://colab.research.google.com/github/JSJeong-me/2021-K-Digital-Training/blob/main/Roboflow_CLIP_Zero_Shot_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to use CLIP Zero-Shot on your own classificaiton dataset

This notebook provides an example of how to benchmark CLIP's zero shot classification performance on your own classification dataset.

[CLIP](https://openai.com/blog/clip/) is a new zero shot image classifier relased by OpenAI that has been trained on 400 million text/image pairs across the web. CLIP uses these learnings to make predicts based on a flexible span of possible classification categories.

CLIP is zero shot, that means **no training is required**. 

Try it out on your own task here!

Be sure to experiment with various text prompts to unlock the richness of CLIP's pretraining procedure.




# Download and Install CLIP Dependencies

In [None]:
#installing some dependencies, CLIP was release in PyTorch
import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

import numpy as np
import torch
import os

print("Torch version:", torch.__version__)
os.kill(os.getpid(), 9)
#Your notebook process will restart after these installs

CUDA version: 11.0
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu110
[?25l  Downloading https://download.pytorch.org/whl/cu110/torch-1.7.1%2Bcu110-cp37-cp37m-linux_x86_64.whl (1156.8MB)
[K     |███████████████████████         | 834.1MB 1.4MB/s eta 0:03:44tcmalloc: large alloc 1147494400 bytes == 0x55ebe6168000 @  0x7f60f89d4615 0x55ebabe06cdc 0x55ebabee652a 0x55ebabe09afd 0x55ebabefafed 0x55ebabe7d988 0x55ebabe784ae 0x55ebabe0b3ea 0x55ebabe7d7f0 0x55ebabe784ae 0x55ebabe0b3ea 0x55ebabe7a32a 0x55ebabefbe36 0x55ebabe79853 0x55ebabefbe36 0x55ebabe79853 0x55ebabefbe36 0x55ebabe79853 0x55ebabefbe36 0x55ebabf7e3e1 0x55ebabede6a9 0x55ebabe49cc4 0x55ebabe0a559 0x55ebabe7e4f8 0x55ebabe0b30a 0x55ebabe793b5 0x55ebabe787ad 0x55ebabe0b3ea 0x55ebabe793b5 0x55ebabe0b30a 0x55ebabe793b5
[K     |█████████████████████████████▏  | 1055.7MB 1.3MB/s eta 0:01:20tcmalloc: large alloc 1434370048 bytes == 0x55ec2a7be000 @  0x7f60f89d4615 0x55ebabe06cdc 0x55ebab

In [1]:
#clone the CLIP repository
!git clone https://github.com/openai/CLIP.git
%cd CLIP

Cloning into 'CLIP'...
remote: Enumerating objects: 90, done.[K
remote: Total 90 (delta 0), reused 0 (delta 0), pack-reused 90[K
Unpacking objects: 100% (90/90), done.
/content/CLIP


# Download Classification Data or Object Detection Data

We will download the [public flowers classificaiton dataset](https://public.roboflow.com/classification/flowers_classification) from Roboflow. The data will come out as folders broken into train/valid/test splits and seperate folders for each class label.

You can easily download your own dataset from Roboflow in this format, too.

We made a conversion from object detection to CLIP text prompts in Roboflow, too, if you want to try that out.


To get your data into Roboflow, follow the [Getting Started Guide](https://blog.roboflow.ai/getting-started-with-roboflow/).

In [2]:
#download classification data
#replace with your link
!curl -L "https://public.roboflow.com/ds/vPLCmk4Knv?key=tCrKLQNpTi" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   887  100   887    0     0   1639      0 --:--:-- --:--:-- --:--:--  1636
100 60.9M  100 60.9M    0     0  34.0M      0  0:00:01  0:00:01 --:--:--  136M
Archive:  roboflow.zip
 extracting: README.dataset.txt      
 extracting: README.roboflow.txt     
   creating: test/
 extracting: test/_tokenization.txt  
   creating: test/daisy/
 extracting: test/daisy/10466290366_cc72e33532_jpg.rf.6ddc91cd5d4a6a683e567ccb37e5a089.jpg  
 extracting: test/daisy/10466558316_a7198b87e2_jpg.rf.7acf642b94dc98daa49482a12994ac4c.jpg  
 extracting: test/daisy/12193032636_b50ae7db35_n_jpg.rf.e6c4eeb71c56e793a0d85f6d979dbe20.jpg  
 extracting: test/daisy/1342002397_9503c97b49_jpg.rf.8fe6bdd23186b70f089bb0c5b89d314e.jpg  
 extracting: test/daisy/1354396826_2868631432_m_jpg.rf.409eee37613d16dbc71365cb5615327e.jpg  
 extracting: test/daisy/1374193928_a

In [3]:
!pwd

/content/CLIP


In [4]:
import os
#our the classes and images we want to test are stored in folders in the test set
class_names = os.listdir('./test/')
#class_names.remove('_tokenization.txt')
class_names

['dandelion', 'daisy', '_tokenization.txt']

In [5]:
#we auto generate some example tokenizations in Roboflow but you should edit this file to try out your own prompts
#CLIP gets a lot better with the right prompting!
#be sure the tokenizations are in the same order as your class_names above!
%cat ./test/_tokenization.txt

An example picture from the Flowers_Classification dataset depicting a daisy
An example picture from the Flowers_Classification dataset depicting a dandelion

In [6]:
#edit your prompts as you see fit here, be sure the classes are in teh same order as above
%%writefile ./test/_tokenization.txt
An example picture from the flowers dataset depicting a daisy
An example picture from the flowers dataset depicting a dandelion

Overwriting ./test/_tokenization.txt


In [7]:
candidate_captions = []
with open('./test/_tokenization.txt') as f:
    candidate_captions = f.read().splitlines()

# Run CLIP inference on your classification dataset

In [10]:
import torch
import clip
from PIL import Image
import glob

def argmax(iterable):
    return max(enumerate(iterable), key=lambda x: x[1])[0]

device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)

correct = []

#define our target classificaitons, you can should experiment with these strings of text as you see fit, though, make sure they are in the same order as your class names above
text = clip.tokenize(candidate_captions).to(device)

for cls in class_names:
    class_correct = []
    test_imgs = glob.glob('./test/' + cls + '/*.jpg')
    for img in test_imgs:
        #print(img)
        image = transform(Image.open(img)).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

            pred = class_names[argmax(list(probs)[0])]
            print(pred)
            if pred == cls:
                correct.append(1)
                class_correct.append(1)
            else:
                correct.append(0)
                class_correct.append(0)
    
    #print('accuracy on class ' + cls + ' is :' + str(sum(class_correct)/len(class_correct)))
#print('accuracy on all is : ' + str(sum(correct)/len(correct)))

daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
dandelion
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
daisy
dandelion
dandelion
daisy
dandelion
dandelion
dandelion
dandelion
daisy
dandelion
dandelion
dandelion
dandelion
daisy
dandelion
dandelion
daisy
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
daisy
daisy
dandelion
dandelion
dandelion
daisy
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dandelion
dand

In [None]:
#Hope you enjoyed!
#As always, happy inferencing
#Roboflow