<a href="https://colab.research.google.com/github/PsorTheDoctor/deep-neural-nets/blob/main/clip_zero_shot_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning CLIP for zero-shot object detection

In [28]:
!pip install -q transformers einops ftfy captum
!pip install -q git+https://github.com/openai/CLIP.git

import numpy as np
import json
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import glob

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

import clip
from transformers import CLIPProcessor, CLIPModel

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [3]:
# Mount Google Drive at /content/drive; the dataset path used later in this
# notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [32]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# `device` stays a plain string because the training loop compares it to 'cpu'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the ViT-B/32 CLIP checkpoint together with its image preprocessing callable.
# NOTE(review): jit=False presumably requests a regular (non-TorchScript) module
# so the weights can be fine-tuned — confirm against the CLIP documentation.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [5]:
class CeresDataset(Dataset):
  """Images from one folder, each paired with the tokenized caption 'ceres'.

  Only files matching `*.jpg` directly inside `img_folder` are collected.
  Items are produced with the module-level `preprocess` callable, so that
  name must be defined before the dataset is indexed.
  """

  def __init__(self, img_folder):
    """Collect the image paths and tokenize the shared caption once.

    Args:
      img_folder: directory (str or path-like) that holds the .jpg files.
    """
    root = Path(img_folder)
    self.img_folder = root
    self.img_paths = [path for path in root.glob('*.jpg')]
    # clip.tokenize is given a one-element list; keep its single entry.
    tokens = clip.tokenize(['ceres'])
    self.label = tokens[0]

  def __len__(self):
    """Return the number of collected image paths."""
    return len(self.img_paths)

  def __getitem__(self, idx):
    """Return `(preprocessed image, caption tokens)` for position `idx`."""
    path = self.img_paths[idx]
    return preprocess(Image.open(path)), self.label

In [6]:
# Build the dataset from the .jpg files in the mounted Drive folder.
dataset = CeresDataset('/content/drive/MyDrive/ceres-logo-images/')
# batch size must be larger than 1
# NOTE(review): drop_last is not passed, so the final batch of an epoch is
# presumably allowed to be smaller than 16 (possibly a single sample) —
# confirm against the dataset size.
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [7]:
def convert_model_to_fp32(model):
  """Cast every parameter of `model`, and its gradient when present, to float32 in place.

  Args:
    model: object whose `parameters()` yields tensors exposing `.data` and `.grad`.

  Parameters whose `.grad` is None (no backward pass has populated them yet, or
  they did not take part in the last forward pass) still have their data cast;
  only the missing gradient is skipped instead of raising AttributeError.
  """
  for p in model.parameters():
    p.data = p.data.float()
    # A parameter without a gradient has nothing to convert.
    if p.grad is not None:
      p.grad.data = p.grad.data.float()

# Adam over every model parameter, plus one cross-entropy criterion for each
# direction of the logits returned by the model.
opt = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.2)
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

n_epochs = 5
losses = []  # one float per optimisation step, across all epochs
for epoch in range(n_epochs):
  pbar = tqdm(train_dataloader, total=len(train_dataloader))
  for batch in pbar:
    opt.zero_grad()

    images, texts = batch
    images, texts = images.to(device), texts.to(device)

    logits_per_img, logits_per_txt = model(images, texts)
    # Targets are 0..B-1, i.e. sample i is treated as the match for text i.
    # NOTE(review): CeresDataset returns the same caption tokens for every
    # sample, so all texts in a batch are identical while these targets assume
    # distinct image-text pairs. The logged loss stays at about 0.693 in every
    # epoch — confirm that this objective can actually learn anything.
    ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
    img_term = loss_img(logits_per_img, ground_truth)
    txt_term = loss_txt(logits_per_txt, ground_truth)
    total_loss = (img_term + txt_term) / 2
    losses.append(float(total_loss))

    total_loss.backward()
    # Off the CPU, parameters and gradients are cast to float32 for the
    # optimiser step; clip.model.convert_weights then presumably restores the
    # reduced-precision weights — TODO confirm.
    if device != 'cpu':
      convert_model_to_fp32(model)
    opt.step()
    if device != 'cpu':
      clip.model.convert_weights(model)

    pbar.set_description(f'Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f}')

Epoch 0/5, Loss: 0.6978: 100%|██████████| 7/7 [01:29<00:00, 12.76s/it]
Epoch 1/5, Loss: 0.6943: 100%|██████████| 7/7 [00:01<00:00,  4.80it/s]
Epoch 2/5, Loss: 0.6968: 100%|██████████| 7/7 [00:01<00:00,  4.78it/s]
Epoch 3/5, Loss: 0.6934: 100%|██████████| 7/7 [00:01<00:00,  4.84it/s]
Epoch 4/5, Loss: 0.6934: 100%|██████████| 7/7 [00:01<00:00,  4.83it/s]
