<a href="https://colab.research.google.com/drive/1DW3BwSZEPl8JyArZPv5J6p3PQad7XiAy?usp=drive_link" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install and import libraries

Install and import the necessary libraries, including `transformers` and `torch`.


In [None]:
%pip install transformers torch kagglehub pillow scikit-learn plotly

Import the installed libraries.

In [None]:
import transformers
import torch

## Load CLIP model and processor

Download and load the CLIP model and its corresponding processor locally.

In [None]:
model_name = "openai/clip-vit-base-patch32"
model = transformers.CLIPModel.from_pretrained(model_name)
processor = transformers.CLIPProcessor.from_pretrained(model_name)

## Download dataset and load images

Download the dataset with `kagglehub`, collect the image file paths, and load each image with PIL so that embeddings can be created for it.

In [None]:
from PIL import Image
import os
import kagglehub

# Download the dataset and get the path
dataset_path = kagglehub.dataset_download("gunhcolab/object-detection-dataset-standard-52card-deck")

# Assuming the images are in a directory named 'train' within the downloaded dataset
image_dir = os.path.join(dataset_path, 'train') # Update with the correct image directory

image_paths = []
# Walk through the nested directories to find image files
for root, _, files in os.walk(image_dir):
    for f in files:
        # Compare in lowercase so that files such as 'card.JPG' are not skipped
        if f.lower().endswith(('.jpg', '.jpeg', '.png')):
            image_paths.append(os.path.join(root, f))

# Load all images
images = []
loaded_image_paths = [] # Keep track of paths for successfully loaded images
for p in image_paths:
    try:
        img = Image.open(p).convert("RGB")
        images.append(img)
        loaded_image_paths.append(p) # Add path only if image is loaded successfully
    except Exception as e:
        print(f"Error loading image {p}: {e}")

print(f"Loaded {len(images)} images.")
# Update image_paths to only include successfully loaded images
image_paths = loaded_image_paths


## Create a search function

Define a function that takes a query (a text string or a PIL image), embeds it with the local CLIP model and processor, computes the cosine similarity against the precomputed image embeddings, and returns the similarity scores and paths of the most similar images. The function only needs `image_features` when it is called, so it can be defined here before the embeddings are created in the next section.

In [None]:
import torch.nn.functional as F

def search_images(query, image_features, model, processor, image_paths, top_k=5, batch_size=64):
    """
    Searches for images based on a text or image query using CLIP embeddings with batch processing.

    Args:
        query: The search query (string for text or PIL Image for image).
        image_features: The precomputed embeddings of the images.
        model: The CLIP model.
        processor: The CLIP processor.
        image_paths: A list of paths to the images.
        top_k: The number of top similar images to return.
        batch_size: The batch size for calculating similarity.

    Returns:
        A list of tuples containing the similarity score and the path to the image.
    """
    if isinstance(query, str):
        # Process the text query
        text_inputs = processor(text=query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            query_features = model.get_text_features(**text_inputs)
    elif isinstance(query, Image.Image):
        # Process the image query
        image_inputs = processor(images=query, return_tensors="pt")
        with torch.no_grad():
            query_features = model.get_image_features(pixel_values=image_inputs.pixel_values)
    else:
        raise ValueError("Query must be a string (text) or a PIL Image.")

    # Normalize query features
    query_features = F.normalize(query_features, p=2, dim=1)

    similarity_scores = []
    # Calculate similarity in batches
    for i in range(0, image_features.shape[0], batch_size):
        batch_image_features = image_features[i:i + batch_size]
        # Normalize batch image features
        batch_image_features = F.normalize(batch_image_features, p=2, dim=1)
        batch_scores = (query_features @ batch_image_features.T).squeeze(0)
        similarity_scores.append(batch_scores)

    similarity_scores = torch.cat(similarity_scores, dim=0)


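    # Get the top k results, never requesting more than the number of embeddings
    top_k = min(top_k, similarity_scores.shape[0])
    top_k_scores, top_k_indices = torch.topk(similarity_scores, top_k)

    # Pair each score with its path; image_paths must be in the same order as image_features
    results = [(top_k_scores[i].item(), image_paths[top_k_indices[i].item()]) for i in range(top_k)]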

    return results
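
The ranking step inside `search_images` comes down to cosine similarity followed by a top-k selection. The following is a minimal pure-Python sketch of that idea on toy two-dimensional vectors. The helper names `l2_normalize` and `top_k_by_cosine` and the toy data are illustrative assumptions, not part of the notebook or of any library.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def top_k_by_cosine(query, candidates, k):
    """Return (score, index) pairs for the k candidates most similar to query."""
    q = l2_normalize(query)
    scored = []
    for idx, cand in enumerate(candidates):
        c = l2_normalize(cand)
        # The dot product of two unit vectors is their cosine similarity
        scored.append((sum(a * b for a, b in zip(q, c)), idx))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    # Clamp k so that asking for more results than candidates is harmless
    return scored[:min(k, len(scored))]


# Toy data: candidate 1 points the same way as the query, candidate 0 is orthogonal
toy_query = [1.0, 0.0]
toy_embeddings = [[0.0, 2.0], [3.0, 0.0], [1.0, 1.0]]
ranked = top_k_by_cosine(toy_query, toy_embeddings, k=2)
```

Because both sides are normalized first, the length of a vector does not affect its rank; only its direction does. That is why `[3.0, 0.0]` scores highest here despite being longer than the query.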

## Create image embeddings


Process the loaded images with the CLIP processor and generate their embeddings with the local CLIP model.

In [None]:
# Generate embeddings for a subset of loaded images
# Cap the count at the number of images actually loaded so the reported total is accurate
num_images_to_process = min(100, len(images))
inputs = processor(images=images[:num_images_to_process], return_tensors="pt")
with torch.no_grad():
    image_features = model.get_image_features(pixel_values=inputs.pixel_values)

print(f"Image embeddings created for {num_images_to_process} images.")
print("Shape of image embeddings:", image_features.shape)
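
The cell above embeds a fixed-size subset in a single forward pass. To embed the whole dataset without holding every preprocessed image in memory at once, one option is to process the images in batches. The sketch below shows only the batching logic with a hypothetical helper, `iter_batches`; the commented lines indicate how it might be combined with the notebook's `processor` and `model`, and the batch size of 32 is an arbitrary assumption.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of items, each with at most batch_size elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Demonstration on plain integers; the last batch is shorter than the others
demo_batches = list(iter_batches(list(range(7)), batch_size=3))

# Hypothetical use with the notebook's objects (not executed here):
# chunks = []
# for batch in iter_batches(images, 32):
#     inputs = processor(images=batch, return_tensors="pt")
#     with torch.no_grad():
#         chunks.append(model.get_image_features(pixel_values=inputs.pixel_values))
# image_features = torch.cat(chunks, dim=0)
```

If all images are embedded this way, the later cells that slice with `num_images_to_process` would need that value set to `len(images)` so that labels, paths, and embeddings stay aligned.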

Inspect the contents of the downloaded dataset directory to confirm the correct image directory and file extensions.

In [None]:
import os
import kagglehub

# Get the path to the downloaded dataset
dataset_path = kagglehub.dataset_download("gunhcolab/object-detection-dataset-standard-52card-deck")

# List the contents of the downloaded dataset directory
print(f"Contents of the dataset directory: {os.listdir(dataset_path)}")

# If there's a subdirectory for images, list its contents as well
# Replace 'train' with the actual subdirectory name if different
image_subdir = os.path.join(dataset_path, 'train')
if os.path.exists(image_subdir):
    print(f"Contents of the image subdirectory ('train'): {os.listdir(image_subdir)[:20]}") # Print only the first 20 items to avoid flooding the output
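
Listing directories shows the layout but not which file extensions are present. As a rough sketch, the hypothetical helper `count_extensions` below tallies extensions under a directory tree. It is demonstrated on a temporary directory with made-up file names; pointing it at `dataset_path` is an assumption about how it would be used in this notebook.

```python
import os
import tempfile
from collections import Counter


def count_extensions(root_dir):
    """Count files under root_dir by lowercased extension ('' if there is none)."""
    counts = Counter()
    for _, _, files in os.walk(root_dir):
        for name in files:
            counts[os.path.splitext(name)[1].lower()] += 1
    return counts


# Demonstration on a throwaway directory tree with invented file names
with tempfile.TemporaryDirectory() as tmp:
    class_dir = os.path.join(tmp, "train", "ace")
    os.makedirs(class_dir)
    for name in ("a.JPG", "b.jpg", "notes.txt"):
        open(os.path.join(class_dir, name), "w").close()
    demo_counts = count_extensions(tmp)

# Hypothetical use in this notebook:
# print(count_extensions(dataset_path).most_common())
```

Lowercasing the extension groups `.JPG` and `.jpg` together, which makes it easier to see whether the extension filter used when loading images covers every image file in the dataset.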

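## Define the visualization function

Define `display_projections`, which draws a 3D Plotly scatter plot of the projected embeddings and shows the corresponding image when a point is clicked.

In [None]: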
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

from typing import Dict
from pathlib import Path
from IPython.display import display, HTML  # importing these from IPython.core.display is deprecated


def display_projections(
    labels: np.ndarray,
    projections: np.ndarray,
    image_paths: np.ndarray,
    image_data_uris: Dict[str, str],
    show_legend: bool = False,
    show_markers_with_text: bool = True
) -> None:
    # Create a separate trace for each unique label
    unique_labels = np.unique(labels)
    traces = []
    for unique_label in unique_labels:
        mask = labels == unique_label
        customdata_masked = image_paths[mask]
        trace = go.Scatter3d(
            x=projections[mask][:, 0],
            y=projections[mask][:, 1],
            z=projections[mask][:, 2],
            mode='markers+text' if show_markers_with_text else 'markers',
            text=labels[mask],
            customdata=customdata_masked,
            name=str(unique_label),
            marker=dict(size=8),
            hovertemplate="<b>class: %{text}</b><br>path: %{customdata}<extra></extra>"
        )
        traces.append(trace)

    # Create the 3D scatter plot
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        width=1000,
        height=1000,
        showlegend=show_legend
    )

    # Convert the chart to an HTML div string and add an ID to the div
    plotly_div = fig.to_html(full_html=False, include_plotlyjs=False, div_id="scatter-plot-3d")

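    # Serialize the data URIs as JSON so they embed as a valid JavaScript object literal
    import json
    image_data_uris_json = json.dumps(image_data_uris)

    # JavaScript that shows the clicked point's image in the fixed panel
    javascript_code = f"""
    <script>
        function displayImage(imagePath) {{
            var imageElement = document.getElementById('image-display');
            var placeholderText = document.getElementById('placeholder-text');
            var imageDataURIs = {image_data_uris_json};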
            imageElement.src = imageDataURIs[imagePath];
            imageElement.style.display = 'block';
            placeholderText.style.display = 'none';
        }}

        // Get the Plotly chart element by its ID
        var chartElement = document.getElementById('scatter-plot-3d');

        // Add a click event listener to the chart element
        chartElement.on('plotly_click', function(data) {{
            var customdata = data.points[0].customdata;
            displayImage(customdata);
        }});
    </script>
    """

    # Create an HTML template including the chart div and JavaScript code
    html_template = f"""
    <!DOCTYPE html>
    <html>
        <head>
            <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
            <style>
                #image-container {{
                    position: fixed;
                    top: 0;
                    left: 0;
                    width: 200px;
                    height: 200px;
                    padding: 5px;
                    border: 1px solid #ccc;
                    background-color: white;
                    z-index: 1000;
                    box-sizing: border-box;
                    display: flex;
                    align-items: center;
                    justify-content: center;
                    text-align: center;
                }}
                #image-display {{
                    width: 100%;
                    height: 100%;
                    object-fit: contain;
                }}
            </style>
        </head>
        <body>
            {plotly_div}
            <div id="image-container">
                <img id="image-display" src="" alt="Selected image" style="display: none;" />
                <p id="placeholder-text">Click on a data entry to display an image</p>
            </div>
            {javascript_code}
        </body>
    </html>
    """

    # Display the HTML template in the Jupyter Notebook
    display(HTML(html_template))
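
When a Python dictionary such as `image_data_uris` is embedded in a `<script>` tag, the text that lands in the page has to be a valid JavaScript literal. A small sketch of the difference between Python's `repr` and `json.dumps` follows; the dictionary contents are invented for illustration.

```python
import json

# Invented example entry: a path containing an apostrophe and a short fake payload
demo_uris = {"cards/king's_club.png": "data:image/png;base64,AAAA"}

# repr() produces Python syntax; it often resembles JavaScript but is not guaranteed to be
as_repr = repr(demo_uris)

# json.dumps() produces JSON, which is also a valid JavaScript object literal
as_json = json.dumps(demo_uris)

# Embedding the JSON text in a script tag, as an f-string template would do
script = f"<script>var imageDataURIs = {as_json};</script>"
```

The JSON text round-trips through `json.loads`, whereas the `repr` text mixes quote styles and would not parse as JSON. Neither approach escapes a literal `</script>` inside a value, so this sketch assumes the keys and values are ordinary file paths and base64 strings.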

## Prepare data for visualization

Extract labels from image paths, reduce dimensionality of image embeddings, and create image data URIs.

In [None]:
from sklearn.decomposition import PCA
import base64
from io import BytesIO
import numpy as np

# 1. Extract labels from image paths for the processed images
# Assuming the label is the directory name immediately preceding the image file
labels = np.array([os.path.basename(os.path.dirname(p)) for p in image_paths[:num_images_to_process]]) # Use only the processed image_paths

# 2. Reduce dimensionality of image embeddings using PCA
# Reduce to 3 components for 3D visualization.
pca = PCA(n_components=3)
projections = pca.fit_transform(image_features.cpu().numpy())

# 3. Create image data URIs
image_data_uris = {}
# Create a data URI for each processed image, keyed by its file path
for i, img in enumerate(images[:num_images_to_process]): # Use only the processed images
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    image_data_uris[image_paths[i]] = f"data:image/png;base64,{img_str}"

print("Data prepared for visualization.")
print("Labels created:", labels[:5])
print("Projections shape:", projections.shape)
print("Image data URIs created for", len(image_data_uris), "images.")
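
A data URI is just the file's bytes, base64-encoded and prefixed with a MIME type. The sketch below illustrates the encoding and its inverse using only the standard library. The helper names `to_data_uri` and `from_data_uri` are hypothetical, and the eight-byte PNG signature stands in for a real image.

```python
import base64


def to_data_uri(raw_bytes, mime_type="image/png"):
    """Encode raw bytes as a base64 data URI."""
    encoded = base64.b64encode(raw_bytes).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def from_data_uri(uri):
    """Split a base64 data URI back into (mime_type, raw_bytes)."""
    header, encoded = uri.split(",", 1)
    mime_type = header[len("data:"):].split(";", 1)[0]
    return mime_type, base64.b64decode(encoded)


# The eight-byte PNG file signature, used here in place of a full image
png_signature = b"\x89PNG\r\n\x1a\n"
demo_uri = to_data_uri(png_signature)
decoded_mime, decoded_bytes = from_data_uri(demo_uri)
```

Base64 encodes every three bytes as four characters, so each embedded image grows by roughly a third. With many large images the generated HTML can become heavy, in which case saving smaller thumbnails before encoding may be worth considering.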

Now that `labels`, `projections`, and `image_data_uris` are defined, we can call the `display_projections` function to visualize the image embeddings.

In [None]:
display_projections(
    labels=labels,
    projections=projections,
    image_paths=np.array(image_paths[:num_images_to_process]), # Pass only the processed image_paths
    image_data_uris=image_data_uris
)

## Perform search

Call `search_images` with a sample text query and the generated image features to retrieve the most similar images. The same function also accepts a PIL image as the query.


In [None]:
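# Perform a text-based search
query_text = "a playing card with a club"
# Pass only the paths that have embeddings so indices and paths stay aligned
search_results = search_images(query_text, image_features, model, processor, image_paths[:num_images_to_process])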

print(f"Search results for '{query_text}':")
for score, path in search_results:
    print(f"Similarity: {score:.4f}, Image Path: {path}")

## Display results

Load the images from the paths returned by the search function and display them alongside their similarity scores.

In [None]:
from IPython.display import display

print(f"Displaying top {len(search_results)} search results:")
for score, path in search_results:
    try:
        img = Image.open(path)
        print(f"Similarity: {score:.4f}")
        display(img)
    except Exception as e:
        print(f"Error displaying image {path}: {e}")

## Summary

This notebook built a pipeline that generates image embeddings with a locally loaded CLIP model and retrieves images that match a text query. Along the way it installed the required libraries, loaded the CLIP model and processor, downloaded and loaded images from a Kaggle dataset, embedded a subset of them, visualized the embeddings in 3D, defined a search function, ran a text-based search, and displayed the top results.