## Cat Classification with CLIP Example

In [9]:
import torch
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [10]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the later `.to(device)` calls still work on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
# Load the preprocessing pipeline and the pretrained CLIP weights from the
# same checkpoint, then move the model onto the selected device.
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)

In [12]:
def get_common_objects_and_concepts():
    """Return a curated vocabulary of labels to score an image against.

    The vocabulary covers objects, animals, clothing, emotions, abstract
    concepts, colors, styles and activities.

    Labels are returned bare (e.g. ``'chair'``), with no prompt template
    applied. The caller is responsible for wrapping them in a prompt such as
    ``'A photo of a ' + label``; baking the template into some categories
    here would make a caller that adds it uniformly apply it twice.

    A few labels appear in more than one category (e.g. ``'mouse'``,
    ``'tie'``, ``'orange'``); they are intentionally not de-duplicated, so
    the flat list keeps one entry per category occurrence.

    Returns:
        A tuple ``(all_concepts, categories)`` where ``categories`` maps a
        category name to its list of labels and ``all_concepts`` is the
        concatenation of those lists in the dict's insertion order.
    """
    categories = {
        'objects': [
            'chair', 'table', 'car', 'bicycle', 'bottle', 'cup', 'fork', 'knife', 'spoon',
            'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'pizza',
            'donut', 'cake', 'bed', 'toilet', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'book', 'clock', 'scissors', 'teddy bear', 'hair dryer',
            'toothbrush', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'man', 'woman'
        ],
        'animals': [
            'person', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
            'giraffe', 'bird', 'chicken', 'duck', 'eagle', 'owl', 'fish', 'shark', 'whale',
            'dolphin', 'turtle', 'frog', 'snake', 'spider', 'bee', 'butterfly', 'lion',
            'tiger', 'fox', 'wolf', 'rabbit', 'hamster', 'mouse', 'rat'
        ],
        'clothing': [
            'hat', 'cap', 'helmet', 'glasses', 'sunglasses', 'shirt', 't-shirt', 'sweater',
            'jacket', 'coat', 'dress', 'skirt', 'pants', 'jeans', 'shorts', 'shoes',
            'sneakers', 'boots', 'sandals', 'socks', 'tie', 'scarf', 'gloves', 'belt',
            'watch', 'ring', 'necklace', 'earrings', 'bracelet'
        ],
        'emotions': [
            'happy', 'sad', 'angry', 'surprised', 'excited', 'calm', 'peaceful', 'joyful',
            'melancholy', 'nostalgic', 'anxious', 'confident', 'mysterious', 'dramatic',
            'romantic', 'energetic', 'serene', 'tense', 'playful', 'serious'
        ],
        'abstract_concepts': [
            'freedom', 'justice', 'peace', 'war', 'love', 'hate', 'beauty', 'ugliness',
            'truth', 'lie', 'innovation', 'tradition', 'progress', 'chaos', 'order',
            'simplicity', 'complexity', 'elegance', 'roughness', 'sophistication'
        ],
        'colors': [
            'red', 'blue', 'green', 'yellow', 'orange', 'purple', 'pink', 'brown',
            'black', 'white', 'gray', 'silver', 'gold', 'cyan', 'magenta', 'lime',
            'navy', 'maroon', 'olive', 'aqua'
        ],
        'styles': [
            'modern', 'vintage', 'classic', 'contemporary', 'abstract', 'realistic',
            'minimalist', 'ornate', 'rustic', 'elegant', 'casual', 'formal', 'artistic',
            'professional', 'creative', 'traditional', 'futuristic', 'retro'
        ],
        'activities': [
            'running', 'walking', 'jumping', 'dancing', 'singing', 'reading', 'writing',
            'cooking', 'eating', 'sleeping', 'working', 'playing', 'studying', 'exercising',
            'swimming', 'flying', 'driving', 'riding', 'climbing', 'surfing'
        ]
    }

    # Flatten all categories into one list, preserving category order.
    all_concepts = [label for labels in categories.values() for label in labels]

    return all_concepts, categories

In [13]:
# Flat list of every concept string, plus the per-category mapping it came from.
words, categories = get_common_objects_and_concepts()

In [19]:
# Load the query image and embed it with the CLIP image encoder.
image = Image.open('n02123045_1955.jpg')

inputs = processor(images=image, return_tensors="pt")
inputs = inputs.to(device)
with torch.no_grad():
    image_features = model.get_image_features(**inputs)
    # Rescale the embedding to unit L2 norm along the feature axis.
    feature_norm = image_features.norm(dim=-1, keepdim=True)
    image_features = image_features / feature_norm

In [20]:
# Embed every entry of `words` with the CLIP text encoder, one prompt at a time.
# `words` is expected to hold bare labels: an entry that already begins with
# the 'A photo of a ' template would end up carrying it twice.
all_text_features = []
with torch.no_grad():  # inference only; entered once rather than per prompt
    for word in words:
        inputs = processor(text='A photo of a ' + word, return_tensors="pt", padding=True).to(device)
        text_features = model.get_text_features(**inputs)
        # Rescale each embedding to unit L2 norm along the feature axis.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        all_text_features.append(text_features)
# Join the per-prompt results along the first dimension: one row per prompt.
all_text_features = torch.concat(all_text_features)

In [21]:
# Display the shape of the concatenated text embeddings (one row per prompt).
all_text_features.shape

torch.Size([208, 512])

In [22]:
# Cosine similarity between the image embedding and each text embedding,
# computed along the feature axis (one score per prompt).
similarities = torch.nn.functional.cosine_similarity(
    image_features,
    all_text_features,
    dim=1,
)

In [23]:
# Display the shape of the similarity vector (one score per prompt).
similarities.shape

torch.Size([208])

In [24]:
# Rank the prompts by similarity to the image, best first, and keep at most
# `top_k` of them.
top_k = 100
ranked_indices = similarities.argsort(descending=True)
top_k_indices = ranked_indices[:top_k]

# Pair each retained index with its label and its score as a Python float.
results = [
    {
        'text': words[idx],
        'similarity': similarities[idx].item()
    }
    for idx in top_k_indices
]
results

[{'text': 'cat', 'similarity': 0.2783862054347992},
 {'text': 'olive', 'similarity': 0.2613104581832886},
 {'text': 'gray', 'similarity': 0.2560223340988159},
 {'text': 'surprised', 'similarity': 0.2504279315471649},
 {'text': 'black', 'similarity': 0.2491314709186554},
 {'text': 'serious', 'similarity': 0.24542485177516937},
 {'text': 'cyan', 'similarity': 0.24468094110488892},
 {'text': 'angry', 'similarity': 0.24427530169487},
 {'text': 'elegant', 'similarity': 0.2428973913192749},
 {'text': 'white', 'similarity': 0.24277672171592712},
 {'text': 'tiger', 'similarity': 0.23955842852592468},
 {'text': 'green', 'similarity': 0.23955202102661133},
 {'text': 'maroon', 'similarity': 0.23935534060001373},
 {'text': 'beauty', 'similarity': 0.23908817768096924},
 {'text': 'magenta', 'similarity': 0.23754103481769562},
 {'text': 'brown', 'similarity': 0.23747514188289642},
 {'text': 'socks', 'similarity': 0.23531195521354675},
 {'text': 'mouse', 'similarity': 0.23520846664905548},
 {'text': '