In [1]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Cell 1: Load Model and Processor
from transformers import CLIPProcessor, CLIPModel
import torch

# Define the model ID
model_id = "openai/clip-vit-base-patch32"

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the processor and model
print(f"Loading model: {model_id}")
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id).to(device) # Move model to GPU if available

print("CLIP model and processor loaded successfully.")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Using device: cpu
Loading model: openai/clip-vit-base-patch32
CLIP model and processor loaded successfully.


In [3]:
# Cell 2: Load and Process Dataset
from datasets import load_dataset
from PIL import Image
import io

# Define the path to your training data
data_dir = "InsPLAD-det"
# Load the training dataset using 'imagefolder'
# This automatically loads images and assigns labels based on subdirectories (if any)
# We might not need the labels directly, but it's good structure.
# Set trust_remote_code=True if prompted or if it's needed for imagefolder
print("Loading dataset...")
train_dataset = load_dataset("imagefolder", data_dir=data_dir + "/train", split="train")
print(f"Dataset loaded: {train_dataset}")

Loading dataset...
Dataset loaded: Dataset({
    features: ['image'],
    num_rows: 7935
})


In [5]:
# --- Image Processing Function ---
def compute_embeddings(image_batch):
    """
    Preprocesses images using the CLIPProcessor and computes embeddings using the CLIPModel.
    Handles potential errors with corrupted images.
    """
    images = []
    valid_indices = [] # Keep track of which images were successfully processed

    # Pre-process images individually to handle potential loading errors
    for i, img_data in enumerate(image_batch['image']):
        try:
            # Ensure image is in RGB format
            image = img_data.convert("RGB")
            images.append(image)
            valid_indices.append(i)
        except Exception as e:
            print(f"Warning: Skipping image at index {i} due to error: {e}")
            # Optionally, you could try loading from bytes if path fails:
            # try:
            #     image = Image.open(io.BytesIO(img_data['bytes'])).convert("RGB")
            #     images.append(image)
            #     valid_indices.append(i)
            # except Exception as e_inner:
            #     print(f"Warning: Skipping image at index {i} due to error (bytes): {e_inner}")


    if not images: # If no images were valid in the batch
        return {"embeddings": []}

    # Process the valid images in a batch
    inputs = processor(text=None, images=images, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the correct device (GPU/CPU)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get image embeddings (disable gradient calculation for inference)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    # Move embeddings back to CPU for storage/further processing if needed
    embeddings = image_features.cpu().numpy()

    # We need to return embeddings corresponding to the original batch structure,
    # inserting None or a placeholder for failed images if necessary, but for simplicity
    # in mapping, we'll just return the embeddings for the successfully processed images.
    # However, datasets mapping usually expects the output to align with the input batch.
    # A more robust way handles this alignment, but let's try a simpler direct map first.
    # For now, we assume the map function handles potential discrepancies if lengths differ.
    # A safer approach would pad the results to match the original batch size.
    return {"embeddings": embeddings}

In [6]:
# --- Apply the function to the dataset ---
# Use batched=True for efficiency
# This might take a while depending on dataset size and hardware
print("Computing embeddings (this may take a while)...")
# Note: If you encounter issues with batching due to image errors,
# you might need to set batch_size=1, which will be slower.
train_dataset_with_embeddings = train_dataset.map(compute_embeddings, batched=True, batch_size=16) # Adjust batch_size as needed

print("Embeddings computed and added to the dataset:")
print(train_dataset_with_embeddings)

# You can access embeddings like this:
example_embedding = train_dataset_with_embeddings[0]['embeddings']
print(f"Example embedding shape: {example_embedding}")

Computing embeddings (this may take a while)...
Embeddings computed and added to the dataset:
Dataset({
    features: ['image', 'embeddings'],
    num_rows: 7935
})
Example embedding shape: [0.32368314266204834, -0.0077776312828063965, 0.267245888710022, -0.19161367416381836, 0.053290851414203644, -0.36559200286865234, 0.13718417286872864, 0.6847219467163086, -0.6552464365959167, -0.01993478834629059, 0.20937520265579224, -0.4695011377334595, 0.10200235247612, -0.0548650287091732, 0.03850864991545677, 0.16700479388237, -1.3750569820404053, -0.08511523902416229, -0.02725159376859665, -0.4032372832298279, -0.6758159399032593, -0.1930159330368042, -0.09288192540407181, 0.3543173372745514, 0.686080813407898, 0.39208200573921204, -0.011036578565835953, -0.05986612290143967, -0.2303474396467209, -0.3637860417366028, 0.17015764117240906, 0.0865301787853241, 0.06409253180027008, 0.03617081046104431, -0.20874415338039398, 0.18845264613628387, -0.4128386378288269, 0.44660210609436035, 0.68758946

In [7]:
example_embedding = train_dataset_with_embeddings[0]['embeddings']
print(f"Example embedding shape: {example_embedding}")

Example embedding shape: [0.32368314266204834, -0.0077776312828063965, 0.267245888710022, -0.19161367416381836, 0.053290851414203644, -0.36559200286865234, 0.13718417286872864, 0.6847219467163086, -0.6552464365959167, -0.01993478834629059, 0.20937520265579224, -0.4695011377334595, 0.10200235247612, -0.0548650287091732, 0.03850864991545677, 0.16700479388237, -1.3750569820404053, -0.08511523902416229, -0.02725159376859665, -0.4032372832298279, -0.6758159399032593, -0.1930159330368042, -0.09288192540407181, 0.3543173372745514, 0.686080813407898, 0.39208200573921204, -0.011036578565835953, -0.05986612290143967, -0.2303474396467209, -0.3637860417366028, 0.17015764117240906, 0.0865301787853241, 0.06409253180027008, 0.03617081046104431, -0.20874415338039398, 0.18845264613628387, -0.4128386378288269, 0.44660210609436035, 0.6875894665718079, 1.1724520921707153, -0.018593020737171173, 0.06301946938037872, 0.24696901440620422, 0.12024759501218796, -0.3570813238620758, 0.06043991446495056, 0.61401

In [9]:
# Keep track of file paths (useful for FAISS index later)
# Add the file paths as a new column
def add_filepath(example):
    # The 'image' field in datasets loaded with imagefolder contains the path
    # Check if 'path' attribute exists, common in newer versions or specific configurations
    if hasattr(example['image'], 'filename'):
         example['filepath'] = example['image'].filename
    # Fallback if filename attribute isn't directly available (might need adjustment based on dataset object structure)
    # elif isinstance(example['image'], dict) and 'path' in example['image']:
    #    example['filepath'] = example['image']['path']
    else:
         # If the structure is different, you might need to inspect example['image']
         # For PIL images loaded directly, the path might not be automatically stored this way.
         # load_dataset("imagefolder") usually stores it.
         print(f"Warning: Could not determine file path for an image. Image data type: {type(example['image'])}")
         example['filepath'] = None # Or handle as appropriate
    return example

train_dataset_with_embeddings = train_dataset_with_embeddings.map(add_filepath)

print("Filepaths added:")
print(train_dataset_with_embeddings)
# Check an example:
# print(f"Example filepath: {train_dataset_with_embeddings[0]['filepath']}")

Map: 100%|██████████| 7935/7935 [01:10<00:00, 113.19 examples/s]

Filepaths added:
Dataset({
    features: ['image', 'embeddings', 'filepath'],
    num_rows: 7935
})





In [10]:
# Cell 3: Create and Populate FAISS Index
import faiss
import numpy as np
import os

# --- Extract Embeddings and Filepaths ---
print("Extracting embeddings and filepaths...")
# Ensure embeddings are in a NumPy array of type float32, which FAISS prefers
embeddings = np.array(train_dataset_with_embeddings["embeddings"], dtype=np.float32)
filepaths = train_dataset_with_embeddings["filepath"]

# --- Check Dimensions ---
if embeddings.ndim == 1:
    # Handle case where map might have returned a list of lists instead of a 2D array
    # This can sometimes happen depending on batching/error handling nuances
    print("Warning: Embeddings array seems 1D, attempting to reshape or stack.")
    # Example check: are all elements arrays themselves?
    if all(isinstance(e, np.ndarray) for e in embeddings):
         embeddings = np.vstack(embeddings).astype(np.float32)
    else:
         raise ValueError("Embeddings are not in the expected 2D array format. Check the compute_embeddings function.")

if embeddings.shape[0] != len(filepaths):
    raise ValueError(f"Mismatch between number of embeddings ({embeddings.shape[0]}) and filepaths ({len(filepaths)}). Check data processing.")

print(f"Embeddings shape: {embeddings.shape}") # Should be (num_images, embedding_dim)
embedding_dim = embeddings.shape[1]

# --- Create FAISS Index ---
# Using IndexFlatL2 for exact search with L2 distance (Euclidean distance)
# Other index types exist for approximate nearest neighbors (faster, less exact)
print(f"Creating FAISS index (IndexFlatL2) with dimension {embedding_dim}...")
index = faiss.IndexFlatL2(embedding_dim)

# --- Add Embeddings to Index ---
print(f"Adding {embeddings.shape[0]} embeddings to the index...")
index.add(embeddings)
print(f"Index populated. Total vectors in index: {index.ntotal}")

# --- Save Index and Filepaths ---
index_filename = "faiss_clip_index.bin"
filepaths_filename = "filepaths.list" # Simple text file for filepaths

print(f"Saving FAISS index to {index_filename}...")
faiss.write_index(index, index_filename)

print(f"Saving filepaths to {filepaths_filename}...")
with open(filepaths_filename, 'w') as f:
    for path in filepaths:
        f.write(f"{path}\n")

print("Index and filepaths saved successfully.")

Extracting embeddings and filepaths...
Embeddings shape: (7935, 512)
Creating FAISS index (IndexFlatL2) with dimension 512...
Adding 7935 embeddings to the index...
Index populated. Total vectors in index: 7935
Saving FAISS index to faiss_clip_index.bin...
Saving filepaths to filepaths.list...
Index and filepaths saved successfully.


In [None]:
# Cell 4: Test Similarity Search

import faiss
import numpy as np
from PIL import Image
import random # To pick a random query image
import os

# --- Load Index and Filepaths ---
index_filename = "faiss_clip_index.bin"
filepaths_filename = "filepaths.list"

print(f"Loading FAISS index from {index_filename}...")
index = faiss.read_index(index_filename)
print(f"Index loaded. Total vectors: {index.ntotal}")

print(f"Loading filepaths from {filepaths_filename}...")
with open(filepaths_filename, 'r') as f:
    loaded_filepaths = [line.strip() for line in f.readlines()]
print(f"Loaded {len(loaded_filepaths)} filepaths.")

# --- Search Function ---
def find_similar_images(image_path, top_k=5):
    """
    Takes an image path, computes its embedding, searches the FAISS index,
    and returns the filepaths of the top_k most similar images.
    """
    print(f"\nSearching for images similar to: {image_path}")
    try:
        # Load and process the query image
        query_image = Image.open(image_path).convert("RGB")
        inputs = processor(text=None, images=query_image, return_tensors="pt", padding=True, truncation=True)

        # Move inputs to the correct device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get embedding
        with torch.no_grad():
            query_embedding = model.get_image_features(**inputs)

        # Move to CPU and ensure float32
        query_embedding = query_embedding.cpu().numpy().astype(np.float32)

        # FAISS Search
        # index.search returns distances (D) and indices (I) of neighbors
        distances, indices = index.search(query_embedding, top_k)

        # Retrieve filepaths using indices
        results = []
        print("Found similar images:")
        for i in range(top_k):
            neighbor_index = indices[0][i]
            distance = distances[0][i]
            filepath = loaded_filepaths[neighbor_index]
            results.append({"filepath": filepath, "distance": distance})
            print(f"  - Index: {neighbor_index}, Distance: {distance:.4f}, Path: {filepath}")

        return results

    except FileNotFoundError:
        print(f"Error: Query image not found at {image_path}")
        return []
    except Exception as e:
        print(f"An error occurred during search: {e}")
        return []

# --- Example Usage ---
# Pick a random image from the loaded filepaths list to use as a query
if loaded_filepaths:
    query_image_index = random.randint(0, len(loaded_filepaths) - 1)
    query_image_path = loaded_filepaths[query_image_index]

    # Make sure the path exists (sometimes paths in datasets can be relative/absolute issues)
    if os.path.exists(query_image_path):
         # Perform the search
         similar_images = find_similar_images(query_image_path, top_k=5)

         # Optional: Display the query image and top result (requires matplotlib)
         try:
             import matplotlib.pyplot as plt

            #  if similar_images:
            #      # Display Query Image
            #      plt.figure(figsize=(10, 5))
            #      plt.subplot(1, 2, 1)
            #      plt.imshow(Image.open(query_image_path))
            #      plt.title(f"Query Image:\n{os.path.basename(query_image_path)}")
            #      plt.axis('off')

            #      # Display Top Result
            #      top_result_path = similar_images[0]['filepath']
            #      if os.path.exists(top_result_path):
            #           plt.subplot(1, 2, 2)
            #           plt.imshow(Image.open(top_result_path))
            #           plt.title(f"Top Result (Dist: {similar_images[0]['distance']:.4f}):\n{os.path.basename(top_result_path)}")
            #           plt.axis('off')
            #      else:
            #           print(f"Warning: Top result image path not found: {top_result_path}")


            #      plt.tight_layout()
            #      plt.show()
                 
            #  else:
            #      print("No similar images found to display.")

         except ImportError:
             print("\nInstall matplotlib (pip install matplotlib) to display images.")
         except Exception as e:
             print(f"Could not display images due to error: {e}")
    else:
        print(f"Error: The randomly selected query image path does not exist: {query_image_path}")
        print("Please ensure the paths in filepaths.list are correct and accessible.")

else:
    print("No filepaths loaded, cannot run example search.")
