## Cat Classification with CLIP Example

In [1]:
import torch
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [4]:
def get_common_objects_and_concepts():
    """Curated list of common objects, animals, concepts, emotions, etc."""
    categories = {
        'objects': [
            'chair', 'table', 'car', 'bicycle', 'bottle', 'cup', 'fork', 'knife', 'spoon',
            'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'pizza',
            'donut', 'cake', 'bed', 'toilet', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'book', 'clock', 'scissors', 'teddy bear', 'hair dryer',
            'toothbrush', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'man', 'woman'
        ],
        'animals': [
            'person', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
            'giraffe', 'bird', 'chicken', 'duck', 'eagle', 'owl', 'fish', 'shark', 'whale',
            'dolphin', 'turtle', 'frog', 'snake', 'spider', 'bee', 'butterfly', 'lion',
            'tiger', 'fox', 'wolf', 'rabbit', 'hamster', 'mouse', 'rat'
        ],
        'clothing': [
            'hat', 'cap', 'helmet', 'glasses', 'sunglasses', 'shirt', 't-shirt', 'sweater',
            'jacket', 'coat', 'dress', 'skirt', 'pants', 'jeans', 'shorts', 'shoes',
            'sneakers', 'boots', 'sandals', 'socks', 'tie', 'scarf', 'gloves', 'belt',
            'watch', 'ring', 'necklace', 'earrings', 'bracelet'
        ]
    }
    
    # Flatten all categories
    all_concepts = []
    for category, items in categories.items():
        if category == "objects":
            items = ['A photo of a ' + w for w in items]
        all_concepts.extend(items)
    
    return all_concepts, categories

In [5]:
words, categories=get_common_objects_and_concepts()

In [6]:
image = Image.open('n02123045_1955.jpg')

inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    image_features = model.get_image_features(**inputs)
    # Normalize the embeddings
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

In [7]:
all_text_features=[]
for i, word in enumerate(words):
    inputs=processor(text='A photo of a '+ word, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        all_text_features.append(text_features)
all_text_features=torch.concat(all_text_features)
text_features

tensor([[-1.7887e-02, -6.2129e-03,  1.6517e-02, -2.9591e-02, -1.6887e-02,
          1.6373e-02,  2.9903e-02, -6.8541e-02, -9.1934e-03,  6.7381e-02,
         -1.8423e-02,  2.4983e-03,  2.1080e-02, -3.8571e-02, -3.2232e-02,
         -2.2684e-02,  1.6661e-02, -4.1826e-04, -4.1781e-02, -2.3613e-02,
          5.0989e-02, -3.5312e-02, -1.6488e-03, -1.7573e-02, -5.7819e-03,
          3.2350e-02, -3.3370e-02,  5.5010e-02, -2.9318e-02,  3.0652e-02,
         -8.2050e-04,  5.5223e-02, -6.2733e-03, -9.8345e-03, -6.0178e-02,
          1.2874e-02, -3.5781e-02,  2.0373e-02, -1.7494e-02, -9.8844e-03,
          4.1935e-03, -6.2362e-02,  5.0253e-02, -1.1455e-02,  8.2908e-03,
          2.1451e-02,  2.0287e-02,  3.2606e-02, -3.5070e-02,  2.7776e-02,
         -5.2976e-03,  2.7021e-02,  6.6977e-02, -4.6803e-02,  1.2262e-02,
         -2.7150e-04, -7.8685e-03,  1.4277e-02,  1.1502e-02, -5.3555e-03,
          5.2074e-02, -4.8163e-02,  4.6797e-04, -2.8965e-02,  1.6307e-02,
         -4.8863e-02, -2.6552e-02,  2.

In [8]:
all_text_features.shape

torch.Size([110, 512])

In [9]:
similarities = torch.cosine_similarity(image_features, all_text_features, dim=1)

In [10]:
similarities.shape

torch.Size([110])

In [11]:
top_k = 200
top_k_indices = similarities.argsort(descending=True)[:top_k]

results = []
for idx in top_k_indices:
    emb = all_text_features[idx].cpu().tolist()
    emb_short = [round(emb[0], 2), round(emb[1], 2), round(emb[-1], 2)]
    results.append({
        'text': words[idx],
        'similarity': similarities[idx].item(),
        'embedding': emb_short
    })

import pandas as pd

df_results = pd.DataFrame(results)
df_results.to_csv('cat_similarity.csv', index=False)


In [12]:
df_results

Unnamed: 0,text,similarity,embedding
0,cat,0.278386,"[0.01, 0.01, 0.0]"
1,tiger,0.239558,"[-0.0, 0.03, 0.01]"
2,socks,0.235312,"[-0.01, 0.01, -0.02]"
3,mouse,0.235208,"[-0.04, 0.02, -0.0]"
4,boots,0.234423,"[-0.0, 0.02, -0.02]"
...,...,...,...
105,A photo of a car,0.161206,"[0.03, 0.04, -0.01]"
106,A photo of a pizza,0.159154,"[0.03, 0.02, -0.05]"
107,A photo of a broccoli,0.156919,"[0.05, 0.04, -0.03]"
108,A photo of a snowboard,0.156050,"[0.05, 0.02, -0.04]"
