In [None]:
!pip install torch torchvision transformers datasets pillow matplotlib open_clip_torch

In [None]:
import os

# Download COCO dataset (if not already downloaded)
if not os.path.exists("val2017"):
    !wget http://images.cocodataset.org/zips/val2017.zip
    !unzip -q val2017.zip

# Path to images
image_dir = "val2017"
print("COCO dataset downloaded and extracted!")


In [None]:
import torch
import open_clip

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# create_model_and_transforms returns three values: the model, the training
# transform and the evaluation transform. Only the evaluation transform is
# needed for retrieval. "ViT-B-32" is open_clip's canonical model name.
model, _preprocess_train, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)

# Inputs are moved to `device` in later cells, so the model must live there too;
# eval() switches the model to inference mode.
model = model.to(device).eval()

# Load tokenizer for text queries
tokenizer = open_clip.get_tokenizer("ViT-B-32")

print("CLIP model loaded successfully!")


In [None]:
from PIL import Image
import os

# Function to preprocess and load images
def load_images(image_dir, num_images=1000):
    """Load up to ``num_images`` images from ``image_dir`` as RGB PIL images.

    Directory entries are taken in sorted filename order. Entries that are not
    regular files, or whose extension is not a common image extension, are
    skipped so that stray files or subdirectories do not abort the load.

    Args:
        image_dir: Directory containing the image files.
        num_images: Maximum number of images to load (``None`` loads all).

    Returns:
        A tuple ``(images, image_paths)`` where ``images`` is a list of RGB
        PIL images and ``image_paths`` is the list of the matching filenames
        (relative to ``image_dir``), in the same order.
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")
    image_paths = [
        name
        for name in sorted(os.listdir(image_dir))
        if name.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(image_dir, name))
    ][:num_images]

    images = []
    for name in image_paths:
        # convert() returns a new in-memory image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(os.path.join(image_dir, name)) as img:
            images.append(img.convert("RGB"))
    return images, image_paths

# Read the images from disk (load_images' default limit applies)
images, image_filenames = load_images(image_dir)

# Run every image through the CLIP preprocessing transform, then batch the
# results into a single tensor on the target device
image_tensors = torch.stack(list(map(preprocess, images))).to(device)

print(f"Loaded and preprocessed {len(images)} images.")


In [None]:
# Run the image encoder without gradient tracking, then scale every embedding
# in place to unit L2 norm along the feature dimension
with torch.no_grad():
    image_features = model.encode_image(image_tensors)
    image_features.div_(image_features.norm(dim=-1, keepdim=True))

print("Image embeddings generated!")


In [None]:
import torch.nn.functional as F

# Function to search for similar images
def search_images(query, top_k=5):
    """Rank the indexed images against a text query.

    The query is tokenized and encoded with the text encoder, scaled to unit
    norm, and compared with every row of the global ``image_features`` via a
    dot product (cosine similarity, given that ``image_features`` rows are
    unit-normalized).

    Args:
        query: Text to search with.
        top_k: Maximum number of results to return. It is capped at the number
            of indexed images, because ``topk`` raises when asked for more
            elements than exist.

    Returns:
        A tuple ``(indices, scores)``: ``indices`` is a list of row indices
        into ``image_features`` ordered from best to worst match, and
        ``scores`` is a tensor with the matching similarity scores.
    """
    # Encode the query text without tracking gradients
    text_tokens = tokenizer([query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)

    # Scale the text embedding to unit norm
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # One similarity score per indexed image
    similarity_scores = (text_features @ image_features.T).squeeze(0)

    # Never ask topk for more results than there are images
    k = max(0, min(top_k, similarity_scores.numel()))
    best = similarity_scores.topk(k)

    return best.indices.tolist(), best.values

# Try the search with a sample query
query = "a person riding a bicycle"
top_indices, scores = search_images(query)

# List the matched filenames with their similarity scores, best first
print("Top matching images:")
for rank, (idx, score) in enumerate(zip(top_indices, scores), start=1):
    print(f"{rank}. {image_filenames[idx]} (Score: {score:.4f})")


In [None]:
import matplotlib.pyplot as plt

# Function to display images
def display_results(query, top_indices, scores):
    """Show the retrieved images side by side, each titled with its score.

    Args:
        query: The text query; used as the figure title.
        top_indices: Indices into the global ``image_filenames`` list, one per
            image to show.
        scores: Similarity scores, indexable in parallel with ``top_indices``.

    Returns:
        None. Displays a matplotlib figure, or prints a short notice when
        there is nothing to display.
    """
    num_results = len(top_indices)
    if num_results == 0:
        # plt.subplots cannot build a grid with zero columns
        print(f"No results to display for query: {query}")
        return

    # squeeze=False keeps `axes` a 2-D array even when there is one result,
    # so the same indexing works for any number of images.
    fig, axes = plt.subplots(1, num_results, figsize=(15, 5), squeeze=False)

    for i, (ax, idx) in enumerate(zip(axes[0], top_indices)):
        # Close the image file once it has been handed to matplotlib
        with Image.open(os.path.join(image_dir, image_filenames[idx])) as img:
            ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"Score: {scores[i]:.4f}")

    plt.suptitle(f"Query: {query}", fontsize=14)
    plt.show()

# Show the images retrieved for the sample query
display_results(query=query, top_indices=top_indices, scores=scores)
