In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
#download classification data
!curl -L "https://public.roboflow.com/ds/pmXBn8TqW8?key=ZZm6kgr3sf" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip

In [21]:
import os
#our the classes and images we want to test are stored in folders in the test set
train_class_names = os.listdir('./test/')
train_class_names.remove('_tokenization.txt')

In [22]:
#we auto generate some example tokenizations in Roboflow but you should edit this file to try out your own prompts
#CLIP gets a lot better with the right prompting!
#be sure the tokenizations are in the same order as your class_names above!
%cat ./test/_tokenization.txt

An example picture from the Flowers_Classification dataset depicting a daisy
An example picture from the Flowers_Classification dataset depicting a dandelion

In [23]:
#edit your prompts as you see fit here
# %%writefile ./test/_tokenization.txt
# An example picture from the flowers dataset depicting a daisy
# An example picture from the flowers dataset depicting a dandelion

In [24]:
train_candidate_captions = []
with open('./test/_tokenization.txt') as f:
    train_candidate_captions = f.read().splitlines()

In [25]:
#https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

In [26]:
import clip
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu" # If using GPU then use mixed precision training.
model, transform = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training

In [27]:
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class image_caption_dataset(Dataset):
    def __init__(self, df):

        self.images = df["image"].tolist()
        self.caption = df["caption"].tolist()

    def __len__(self):
        return len(self.caption)

    def __getitem__(self, idx):
        
        images = transform(Image.open(self.images[idx])) #preprocess from clip.load
        caption = self.caption[idx]
        return images,caption

In [28]:
import pandas as pd
import glob
from PIL import Image

imageList = []
captionList = []

for i, cls in enumerate(train_class_names):
    train_imgs = glob.glob('./train/' + cls + '/*.jpg')
    for img in train_imgs:
        imageList.append(img)
        captionList.append(train_candidate_captions[i])

listOfTuples = list(zip(imageList, captionList)) 
  
# Converting lists of tuples into pandas Dataframe. 
df = pd.DataFrame(listOfTuples,
                  columns = ['image', 'caption'])

dataset = image_caption_dataset(df)

In [29]:
BATCH_SIZE = 64
train_dataloader = DataLoader(dataset,batch_size = BATCH_SIZE, shuffle = True) #Define your own dataloader

In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

if device == "cpu":
    model.float()
else :
    clip.model.convert_weights(model) # Actually this line is unnecessary since clip by default already on float16

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2) #Params from paper

EPOCHS = 20
for epoch in tqdm(range(EPOCHS)):
    for batch in train_dataloader:
        optimizer.zero_grad()

        list_image,list_txt = batch #list_images is list of image in numpy array(np.uint8), or list of PIL images

        # images= torch.stack([preprocess(Image.fromarray(img)) for img in list_image],dim=0) # omit the Image.fromarray if the images already in PIL format, change this line to images=list_image if using preprocess inside the dataset class
        images = list_image.to(device)
        texts = clip.tokenize(list_txt).to(device)

        logits_per_image, logits_per_text = model(images, texts)
        if device == "cpu":
            ground_truth = torch.arange(list_image.shape[0]).half().to(device) # Did not use BATCH_SIZE as last batch might be incomplete and cause error
        else:
            ground_truth = torch.arange(list_image.shape[0]).long().to(device) # Did not use BATCH_SIZE as last batch might be incomplete and cause error

        total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
        total_loss.backward()
        
        if device == "cpu":
            optimizer.step()
        else :
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)


  0%|          | 0/20 [00:00<?, ?it/s][A

In [None]:
# torch.save(model, "trainedModel.pt")

In [None]:
# from google.colab import files
# files.download('/content/trainedModel.pt') 

In [48]:
# Testing
test_class_names = os.listdir('./test/')
test_class_names.remove('_tokenization.txt')

#edit your prompts as you see fit here
# %%writefile ./test/_tokenization.txt
# An example picture from the flowers dataset depicting a daisy
# An example picture from the flowers dataset depicting a dandelion

test_candidate_captions = []
with open('./test/_tokenization.txt') as f:
    test_candidate_captions = f.read().splitlines()

In [59]:
import torch
import clip
from PIL import Image
import glob

def argmax(iterable):
    return max(enumerate(iterable), key=lambda x: x[1])[0]

correct = []

#define our target classificaitons, you can should experiment with these strings of text as you see fit, though, make sure they are in the same order as your class names above
text = clip.tokenize(test_candidate_captions).to(device)

for cls in test_class_names:
    class_correct = []
    test_imgs = glob.glob('./test/' + cls + '/*.jpg')
    for img in test_imgs:
        #print(img)
        image = transform(Image.open(img)).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

            pred = test_class_names[argmax(list(probs)[0])]
            #print(pred)
            if pred == cls:
                correct.append(1)
                class_correct.append(1)
            else:
                correct.append(0)
                class_correct.append(0)
    
    print('accuracy on class ' + cls + ' is :' + str(sum(class_correct)/len(class_correct)))
print('accuracy on all is : ' + str(sum(correct)/len(correct)))

accuracy on class dandelion is :0.6095238095238096
accuracy on class daisy is :0.7792207792207793
accuracy on all is : 0.6813186813186813
