<a href="https://www.kaggle.com/code/aisuko/zero-shot-object-detection?scriptVersionId=164771326" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Traditionally, models used for [object detection](https://www.kaggle.com/code/aisuko/object-detection) require labeled image datasets for training and are limited to detecting the set of classes present in the training data. Zero-shot object detection, supported by the OWL-ViT model, takes a different approach. OWL-ViT is an open-vocabulary object detector: it can detect objects in images based on free-text queries without the need to fine-tune the model on labeled datasets. OWL-ViT leverages multi-modal representations to perform open-vocabulary detection. It combines [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) with lightweight object classification and localization heads. **Open-vocabulary detection is achieved by embedding free-text queries with the text encoder of CLIP and using them as input to the object classification and localization heads, which associate images with their corresponding textual descriptions, while the ViT image encoder processes image patches as inputs.** The authors of OWL-ViT first trained CLIP from scratch and then fine-tuned OWL-ViT end to end on standard object detection datasets using a bipartite matching loss. With this approach, the model can detect objects based on textual descriptions without prior training on labeled datasets for those objects.
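
To make the matching idea concrete, here is a toy sketch in plain Python. It is not OWL-ViT's implementation: the three-dimensional vectors are made up, and the helper names `cosine_similarity` and `match_patches_to_queries` are hypothetical. It only illustrates how each image patch embedding can be scored against every text query embedding, with the best-scoring query taken as the patch's label.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def match_patches_to_queries(patch_embeddings, query_embeddings):
    """For each patch, return the (label, similarity) of the closest query."""
    matches = []
    for patch in patch_embeddings:
        scored = [
            (label, cosine_similarity(patch, embedding))
            for label, embedding in query_embeddings.items()
        ]
        matches.append(max(scored, key=lambda pair: pair[1]))
    return matches

# Made-up embeddings: a real model learns these with its text and image encoders.
toy_queries = {"cat": [1.0, 0.0, 0.0], "remote control": [0.0, 1.0, 0.0]}
toy_patches = [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1]]

toy_matches = match_patches_to_queries(toy_patches, toy_queries)
for label, similarity in toy_matches:
    print(f"{label}: {similarity:.2f}")
```

In the real model a separate localization head also predicts a bounding box for each patch, which this toy leaves out.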

Here we will use OWL-ViT for:
- Detecting objects based on text prompts
- Batch object detection
- Image-guided object detection

In [None]:
%%capture
!pip install transformers==4.35.2

# Zero-Shot Object Detection Pipeline

In [None]:
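import torch
from transformers import pipeline

model_checkpoint="google/owlvit-base-patch32"
# Run on the first GPU when one is available, otherwise fall back to the CPU (device=-1)
device=0 if torch.cuda.is_available() else -1
detector=pipeline(model=model_checkpoint, task="zero-shot-object-detection", device=device)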

# Loading Image

In [None]:
import skimage
import numpy as np
from PIL import Image

image=skimage.data.astronaut()
image=Image.fromarray(np.uint8(image)).convert("RGB")
image

Pass the image and the candidate object labels to look for to the pipeline. Here we pass the image object directly; other suitable options include a local path to an image or an image URL. We also pass text descriptions for all the items we want to query the image for.

In [None]:
predictions=detector(image, candidate_labels=["human face", "rocket", "nasa badge", "star-spangled banner"],)
predictions
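
The pipeline returns a list of dictionaries with `score`, `label`, and `box` keys, which is the shape the drawing code in the next cell relies on. As a minimal sketch, the hypothetical helper `filter_predictions` below keeps only the confident detections before drawing; the sample values are made up for illustration and are not real model output.

```python
def filter_predictions(predictions, min_score=0.2, labels=None):
    """Keep predictions at or above min_score (optionally only for the
    given labels), sorted from most to least confident."""
    kept = [
        p for p in predictions
        if p["score"] >= min_score and (labels is None or p["label"] in labels)
    ]
    return sorted(kept, key=lambda p: p["score"], reverse=True)

# Made-up predictions in the same shape as the pipeline output.
sample_predictions = [
    {"score": 0.12, "label": "rocket", "box": {"xmin": 350, "ymin": 0, "xmax": 470, "ymax": 290}},
    {"score": 0.36, "label": "human face", "box": {"xmin": 180, "ymin": 70, "xmax": 270, "ymax": 180}},
    {"score": 0.28, "label": "nasa badge", "box": {"xmin": 130, "ymin": 350, "xmax": 200, "ymax": 420}},
]

confident = filter_predictions(sample_predictions, min_score=0.2)
print([p["label"] for p in confident])
```

With the real output, `filter_predictions(predictions, min_score=0.2)` could be used in place of `predictions` in the drawing loop.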

In [None]:
from PIL import ImageDraw

draw=ImageDraw.Draw(image)

for prediction in predictions:
    box=prediction["box"]
    label=prediction["label"]
    score=prediction["score"]
    
    xmin,ymin, xmax,ymax=box.values()
    draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
    draw.text((xmin,ymin), f"{label}: {round(score,2)}", fill="white")

image

# Text-prompted zero-shot object detection by hand

In [None]:
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model=AutoModelForZeroShotObjectDetection.from_pretrained(model_checkpoint)
processor=AutoProcessor.from_pretrained(model_checkpoint)

In [None]:
import requests

url="https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640"
im=Image.open(requests.get(url, stream=True).raw)
im

In [None]:
text_queries=["hat", "book", "sunglasses", "camera"]
inputs=processor(text=text_queries, images=im, return_tensors="pt")

The processor resizes images before feeding them to the model, so we need to use the `post_process_object_detection()` method to make sure the predicted bounding boxes have the correct coordinates relative to the original image:

In [None]:
import torch

with torch.no_grad():
    outputs=model(**inputs)
    target_sizes=torch.tensor([im.size[::-1]])
    results=processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]

draw=ImageDraw.Draw(im)

scores=results["scores"].tolist()
labels=results["labels"].tolist()
boxes=results["boxes"].tolist()

for box, score, label in zip(boxes, scores, labels):
    xmin, ymin, xmax,ymax=box
    draw.rectangle((xmin, ymin,xmax, ymax), outline="red", width=1)
    draw.text((xmin,ymin), f"{text_queries[label]}:{round(score,2)}", fill="white")

im
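
The rescaling step can be sketched in plain Python. This is a simplified illustration, assuming the model predicts boxes in normalized `(center_x, center_y, width, height)` format, as the OWL-ViT documentation describes for `pred_boxes`; the helper names `center_to_corners` and `rescale_box` are hypothetical and this is not the library's code. It also shows why `im.size[::-1]` is used: PIL reports `(width, height)`, while `target_sizes` expects `(height, width)`.

```python
def center_to_corners(box):
    """Convert (center_x, center_y, width, height) to (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

def rescale_box(normalized_box, target_size):
    """Scale a normalized center-format box to pixel corner coordinates.
    target_size is (height, width), the order post-processing expects."""
    height, width = target_size
    xmin, ymin, xmax, ymax = center_to_corners(normalized_box)
    return (xmin * width, ymin * height, xmax * width, ymax * height)

pil_size = (640, 480)         # PIL's Image.size is (width, height)
target_size = pil_size[::-1]  # reversed to (height, width)

rescaled = rescale_box((0.5, 0.5, 0.25, 0.5), target_size)
print(rescaled)
```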

# Batch processing

We can pass multiple sets of images and text queries to search for different (or the same) objects in several images. Let's use the astronaut image and the picnic image from the previous section together. For batch processing, we pass the text queries to the processor as a nested list and the images as a list of PIL images, PyTorch tensors, or NumPy arrays.

In [None]:
im_batch=Image.open(requests.get(url, stream=True).raw)

images=[image, im_batch]

text_queries=[
    ["human face", "rocket", "nasa badge", "star-spangled banner"],
    ["hat","book","sunglasses", "camera"],
]

inputs=processor(text=text_queries, images=images, return_tensors="pt")

Previously we passed the single image's size as a tensor for post-processing. With several images, we can pass a list of tuples instead. Here we create predictions for the two examples and visualize the second one (`image_idx=1`).

In [None]:
with torch.no_grad():
    outputs=model(**inputs)
    target_sizes=[x.size[::-1] for x in images]
    results=processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)
    
image_idx=1
draw=ImageDraw.Draw(images[image_idx])

scores=results[image_idx]["scores"].tolist()
labels=results[image_idx]["labels"].tolist()
boxes=results[image_idx]["boxes"].tolist()

for box, score, label in zip(boxes, scores, labels):
    xmin, ymin, xmax, ymax=box
    draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
    draw.text((xmin, ymin), f"{text_queries[image_idx][label]}:{round(score,2)}", fill="white")

images[image_idx]
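
In a batch, each label index refers to the query list of its own image, which is why the loop above indexes `text_queries[image_idx][label]`. The sketch below shows one way to flatten batched results into readable records. The helper `results_to_records` is hypothetical, and the sample results are made-up plain lists standing in for the tensors after `.tolist()`.

```python
def results_to_records(results, text_queries):
    """Flatten per-image results into one list of readable records.
    Label indices are resolved against the queries of the same image."""
    records = []
    for image_idx, (result, queries) in enumerate(zip(results, text_queries)):
        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            records.append({
                "image_idx": image_idx,
                "label": queries[label],
                "score": round(score, 2),
                "box": [round(v, 1) for v in box],
            })
    return records

# Made-up results for two images, already converted to plain lists.
sample_queries = [["human face", "rocket"], ["hat", "book"]]
sample_results = [
    {"scores": [0.357], "labels": [0], "boxes": [[180.2, 70.9, 271.4, 178.0]]},
    {"scores": [0.112, 0.205], "labels": [1, 0], "boxes": [[10.0, 20.0, 50.0, 60.0], [5.0, 5.0, 40.0, 30.0]]},
]

records = results_to_records(sample_results, sample_queries)
for record in records:
    print(record)
```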

# Image-guided object detection

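In addition to text queries, OWL-ViT supports image-guided object detection: we can use an image query to find similar objects in a target image. Here the target image shows two cats on a couch, and the query is an image of a single cat: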

In [None]:
url="http://images.cocodataset.org/val2017/000000039769.jpg"
image_target=Image.open(requests.get(url, stream=True).raw)

query_url="http://images.cocodataset.org/val2017/000000524280.jpg"
query_image=Image.open(requests.get(query_url, stream=True).raw)

In [None]:
import matplotlib.pyplot as plt

# Take a quick look at the images
fig, ax=plt.subplots(1,2)
ax[0].imshow(image_target)
ax[1].imshow(query_image)

In the preprocessing step, instead of text queries, we pass the query image through the `query_images` argument:

In [None]:
inputs=processor(images=image_target, query_images=query_image, return_tensors="pt")

In [None]:
with torch.no_grad():
    outputs=model.image_guided_detection(**inputs)
    target_sizes=torch.tensor([image_target.size[::-1]])
    results=processor.post_process_image_guided_detection(outputs=outputs, target_sizes=target_sizes)[0]

draw=ImageDraw.Draw(image_target)

scores=results["scores"].tolist()
boxes=results["boxes"].tolist()

# Image-guided results carry no labels, so iterate over boxes and scores only
for box, score in zip(boxes, scores):
    xmin, ymin, xmax, ymax=box
    draw.rectangle((xmin, ymin, xmax, ymax), outline="white", width=4)

image_target
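
Image-guided detection can return several overlapping boxes for the same object. To my understanding, `post_process_image_guided_detection()` accepts `threshold` and `nms_threshold` arguments to control this. The sketch below is a generic greedy non-maximum suppression in plain Python, shown only to illustrate the idea; it is not the library's implementation, and the helper names `iou` and `greedy_nms` are hypothetical.

```python
def iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def greedy_nms(boxes, scores, iou_threshold=0.3):
    """Return indices of boxes to keep, visiting boxes from highest to
    lowest score and dropping any that overlap an already kept box."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    kept = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in kept):
            kept.append(i)
    return kept

# Made-up boxes: the first two overlap heavily, the third is separate.
sample_boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
sample_scores = [0.9, 0.8, 0.7]

kept_indices = greedy_nms(sample_boxes, sample_scores, iou_threshold=0.3)
print(kept_indices)
```

With real output, the `boxes` and `scores` lists from the cell above could be passed in the same way.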