In [1]:
# Model configuration
# Hugging Face checkpoint used for both the text and the image encoder.
MODEL_NAME: str = "openai/clip-vit-base-patch32"
# Token limit for search queries; longer queries are truncated by the processor.
# (Presumably chosen to match CLIP's text context window — confirm if the model changes.)
TEXT_MAX_LENGTH: int = 77

# Data file paths
# HDF5 file of precomputed embeddings, resolved relative to the working directory.
EMBEDDINGS_FILE: str = "../data/flickr8k_embeddings.h5"

# Search configuration
# Number of results returned when a caller does not pass top_k.
DEFAULT_TOP_K: int = 5

# Imports
import os
import sys
import logging
import torch
import h5py
import numpy as np
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Initialize logging: INFO and above go to stdout as "LEVEL: message".
# force=True removes any handlers already attached to the root logger so this
# configuration takes effect even if logging was configured earlier in the session.
logging.basicConfig(
    stream=sys.stdout,
    format="%(levelname)s: %(message)s",
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(name=__name__)
logger.info("Image search interface - logging initialized")

# Device Detection and Model Loading
def get_device():
    """Choose the torch device for inference and log what was found.

    Returns:
        ``"cuda"`` when torch reports an available CUDA device, otherwise ``"cpu"``.
        On CUDA, the name and total memory of device 0 are also logged.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    logger.info(f"Device: {device}")
    if has_cuda:
        gpu_name = torch.cuda.get_device_name(0)
        logger.info(f"GPU: {gpu_name}")
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        logger.info(f"VRAM: {total_bytes / 1e9:.1f} GB")
    return device

def load_model(model_name, device):
    """Load a CLIP checkpoint together with its matching processor.

    Args:
        model_name: Hugging Face model id or local path passed to ``from_pretrained``.
        device: Torch device string the model is moved to (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        ``(model, processor)``: the ``CLIPModel`` on ``device`` and the ``CLIPProcessor``
        used to turn raw text/images into model inputs.
    """
    logger.info(f"Loading {model_name}...")
    clip_model = CLIPModel.from_pretrained(model_name)
    clip_model = clip_model.to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_name)
    logger.info("Model loaded successfully")
    return clip_model, clip_processor

# Load embeddings from file
def load_embeddings(filepath: str) -> dict:
    """Load precomputed embeddings and their labels from an HDF5 file.

    Args:
        filepath: Path to an HDF5 file containing the datasets ``image_embeddings``,
            ``text_embeddings``, ``image_names`` and ``captions``, plus an
            ``images_path`` file attribute.

    Returns:
        Dict with keys:
          - ``image_embeddings`` / ``text_embeddings``: arrays read fully into memory;
          - ``image_names`` / ``captions``: lists of ``str``;
          - ``metadata``: the file's attributes as a plain dict.
        Row ``i`` of every collection describes the same (image, caption) pair; the
        search functions rely on this shared indexing.

    Raises:
        ValueError: if required datasets are missing, their row counts disagree,
            or the ``images_path`` attribute is absent or empty.
        FileNotFoundError: if the ``images_path`` directory does not exist.
    """
    logger.info(f"Loading embeddings from {filepath}...")

    def _decode_strings(values):
        # h5py returns stored strings as bytes; tolerate already-decoded str as well.
        return [v.decode('utf-8') if isinstance(v, bytes) else str(v) for v in values]

    parallel_keys = ('image_embeddings', 'text_embeddings', 'image_names', 'captions')

    with h5py.File(filepath, 'r') as f:
        # Verify required data is present, naming exactly what is missing.
        missing = set(parallel_keys).difference(f.keys())
        if missing:
            raise ValueError(
                f"Missing required datasets in {filepath}: {', '.join(sorted(missing))}"
            )

        # Load all the embedding data into memory so the file can be closed.
        embeddings_data = {
            'image_embeddings': f['image_embeddings'][:],
            'text_embeddings': f['text_embeddings'][:],
            'image_names': _decode_strings(f['image_names'][:]),
            'captions': _decode_strings(f['captions'][:]),
            'metadata': dict(f.attrs)
        }

    # Both search functions index all four collections with one row index, so a
    # length mismatch would silently pair the wrong captions/images (or raise later).
    row_counts = {key: len(embeddings_data[key]) for key in parallel_keys}
    if len(set(row_counts.values())) != 1:
        raise ValueError(f"Inconsistent row counts in {filepath}: {row_counts}")

    # Check that image folder exists
    images_path = embeddings_data['metadata'].get('images_path', '')
    if not images_path:
        raise ValueError("Images path not found in metadata")

    if not os.path.exists(images_path):
        raise FileNotFoundError(f"Images directory not found: {images_path}")

    logger.info(f"Loaded {len(embeddings_data['image_names'])} embedding pairs")
    return embeddings_data

# Text to image search
def search_images_by_text_diverse(
    query_text: str,
    embeddings_data: dict,
    model,
    processor,
    device,
    top_k: int = DEFAULT_TOP_K
    ) -> list:
    """Search for images using text, returning each image at most once.

    Every stored row pairs an image with one of its captions, so the same image
    appears in several rows. Rows are ranked by similarity to the query. Each image
    keeps only its best-scoring row (the lowest row index wins ties), and the first
    ``top_k`` distinct images are returned.

    Args:
        query_text: Natural-language query; truncated to ``TEXT_MAX_LENGTH`` tokens.
        embeddings_data: Dict from ``load_embeddings`` (uses ``image_embeddings``,
            ``image_names`` and ``captions``, all indexed by the same row).
        model: CLIP model providing ``get_text_features``.
        processor: CLIP processor used to tokenize the query.
        device: Torch device the tokenized inputs are moved to.
        top_k: Maximum number of distinct images to return; values <= 0 yield ``[]``.

    Returns:
        List of dicts with ``image_name``, ``caption`` (best-matching caption),
        ``similarity_score`` (float), ``index`` (row index) and ``rank`` (1-based),
        ordered by descending similarity.
    """
    top_k = int(top_k)
    if top_k <= 0:
        return []

    # Convert query text to embeddings
    text_inputs = processor(
        text=query_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=TEXT_MAX_LENGTH
        ).to(device)

    with torch.no_grad():
        query_embedding = model.get_text_features(**text_inputs)
    query_embedding = query_embedding.cpu().numpy()
    # L2-normalize the query. The dot product below is a cosine similarity only if
    # the stored image embeddings are normalized too (assumed — verify at generation).
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    # asarray avoids copying the full embedding matrix on every query.
    image_embeddings = np.asarray(embeddings_data['image_embeddings'])
    similarities = (image_embeddings @ query_norm.T).ravel()

    image_names = embeddings_data['image_names']
    captions = embeddings_data['captions']

    # Stable sort on the negated scores: descending similarity, and equal scores keep
    # ascending row order, so an image's first occurrence is its best (earliest) row.
    order = np.argsort(-similarities, kind='stable')

    results = []
    seen_images = set()
    for idx in order:
        image_name = image_names[idx]
        if image_name in seen_images:
            continue
        seen_images.add(image_name)
        results.append({
            'image_name': image_name,
            'caption': captions[idx],
            'similarity_score': float(similarities[idx]),
            'index': int(idx),
            'rank': len(results) + 1
        })
        if len(results) == top_k:
            break

    return results

# Image to text search - find similar captions for an image
def search_text_by_image(
    query_image: Image.Image,
    embeddings_data: dict,
    model,
    processor,
    device,
    top_k: int = DEFAULT_TOP_K
) -> list:
    """Find stored captions whose embeddings are most similar to an uploaded image.

    Args:
        query_image: PIL image to search with.
        embeddings_data: Dict from ``load_embeddings`` (uses ``text_embeddings``,
            ``captions`` and ``image_names``, all indexed by the same row).
        model: CLIP model providing ``get_image_features``.
        processor: CLIP processor used to preprocess the image.
        device: Torch device the preprocessed inputs are moved to.
        top_k: Maximum number of captions to return; values <= 0 yield ``[]``.

    Returns:
        List of dicts with ``caption``, ``image_name`` (source image of that caption),
        ``similarity_score`` (float), ``rank`` (1-based) and ``index`` (row index),
        ordered by descending similarity. Equal scores keep ascending row order.
    """
    top_k = int(top_k)
    if top_k <= 0:
        return []

    # Convert image to model inputs. ``padding`` is a tokenizer option and is not
    # passed for image-only input.
    image_inputs = processor(
        images=query_image,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        query_embedding = model.get_image_features(**image_inputs)
    query_embedding = query_embedding.cpu().numpy()
    # L2-normalize the query. The dot product below is a cosine similarity only if
    # the stored text embeddings are normalized too (assumed — verify at generation).
    query_norm = query_embedding / np.linalg.norm(query_embedding)

    # asarray avoids copying the full embedding matrix on every query.
    text_embeddings = np.asarray(embeddings_data['text_embeddings'])
    similarities = (text_embeddings @ query_norm.T).ravel()

    # Descending by similarity with deterministic tie order (stable sort).
    top_indices = np.argsort(-similarities, kind='stable')[:top_k]

    return [
        {
            'caption': embeddings_data['captions'][idx],
            'image_name': embeddings_data['image_names'][idx],
            'similarity_score': float(similarities[idx]),
            'rank': rank,
            'index': int(idx)
        }
        for rank, idx in enumerate(top_indices, start=1)
    ]

# Set up the system: choose a device, load CLIP, then load the precomputed index.
device = get_device()
model, processor = load_model(model_name=MODEL_NAME, device=device)

embeddings_data = load_embeddings(filepath=EMBEDDINGS_FILE)
# Image directory recorded in the HDF5 metadata; the Gradio handlers join
# result image names onto it to build file paths.
images_path = embeddings_data["metadata"]["images_path"]

logger.info("System initialization complete - ready for bidirectional search")

INFO: Image search interface - logging initialized
INFO: Device: cuda
INFO: GPU: NVIDIA GeForce RTX 3060 Ti
INFO: VRAM: 8.2 GB
INFO: Loading openai/clip-vit-base-patch32...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


INFO: Model loaded successfully
INFO: Loading embeddings from ../data/flickr8k_embeddings.h5...
INFO: Loaded 40455 embedding pairs
INFO: System initialization complete - ready for bidirectional search


In [2]:
# Gradio interface

def format_results_for_gallery(results: list, images_path: str, caption_preview_len: int = 100):
    """Transform text→image search results into ``gr.Gallery`` items plus a Markdown summary.

    Args:
        results: Result dicts with ``image_name``, ``caption``, ``similarity_score``
            and ``rank`` keys (as produced by ``search_images_by_text_diverse``).
        images_path: Directory that contains the image files named in ``results``.
        caption_preview_len: Maximum caption length shown under a gallery thumbnail.
            Longer captions are cut and suffixed with ``...``; shorter ones are shown
            unchanged.

    Returns:
        ``(gallery_data, result_text)``: a list of ``(image_path, label)`` tuples and a
        Markdown string. Results whose image file does not exist are skipped with a
        warning and are not counted in the header.
    """
    gallery_data = []
    detail_sections = []

    for result in results:
        image_path = os.path.join(images_path, result['image_name'])
        if not os.path.exists(image_path):
            logger.warning(f"Image not found: {image_path}")
            continue

        caption = result['caption']
        if len(caption) > caption_preview_len:
            preview = caption[:caption_preview_len] + "..."
        else:
            preview = caption
        gallery_data.append((image_path, f"Rank {result['rank']}: {preview}"))
        detail_sections.append(
            f"**Rank {result['rank']}** (Similarity: {result['similarity_score']:.3f})\n"
            f"Image: `{result['image_name']}`\n"
            f"Caption: _{caption}_\n\n"
        )

    # Count only what is actually displayed, so skipped (missing) images don't inflate it.
    header = f"Found {len(gallery_data)} semantically similar images:\n\n"
    return gallery_data, header + "".join(detail_sections)

def perform_text_to_image_search(query_text: str, num_results: int = 5):
    """Gradio handler for the text → image tab.

    Args:
        query_text: Raw text from the search box; may be empty, whitespace, or ``None``.
        num_results: Number of images requested (from the slider); coerced to ``int``.

    Returns:
        ``(gallery_data, markdown)`` for the Gallery and Markdown outputs. On empty
        input, no matches, or failure, the gallery is empty and the Markdown carries
        a user-facing message.
    """
    query = (query_text or "").strip()
    if not query:
        return [], "Please enter a search query."

    try:
        logger.info(f"Searching for: '{query}'")
        results = search_images_by_text_diverse(
            query_text=query,
            embeddings_data=embeddings_data,
            model=model,
            processor=processor,
            device=device,
            top_k=int(num_results)
        )

        if not results:
            return [], f"No results found for query: '{query}'"

        gallery_data, result_text = format_results_for_gallery(results, images_path)
        logger.info(f"Query completed - {len(results)} results returned")
        return gallery_data, result_text

    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app,
        # while keeping the full traceback in the log for debugging.
        error_msg = f"Search failed: {str(e)}"
        logger.exception(error_msg)
        return [], error_msg

def format_text_results(results: list) -> str:
    """Render image→text matches as a Markdown string.

    Args:
        results: Result dicts with ``rank``, ``similarity_score``, ``caption`` and
            ``image_name`` keys (as produced by ``search_text_by_image``).

    Returns:
        A header line followed by one section per result, or a fallback message
        when ``results`` is empty.
    """
    if not results:
        return "No similar text descriptions found."

    sections = [f"Found {len(results)} semantically related descriptions:\n\n"]
    sections.extend(
        f"**Rank {entry['rank']}** (Similarity: {entry['similarity_score']:.3f})\n"
        f"Caption: _{entry['caption']}_\n"
        f"Source Image: `{entry['image_name']}`\n\n"
        for entry in results
    )
    return "".join(sections)

def perform_image_to_text_search(image, num_results: int = 5):
    """Gradio handler for the image → text tab.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        num_results: Number of captions requested (from the slider); coerced to ``int``.

    Returns:
        Markdown string with the ranked captions, or a user-facing message for
        missing input, no matches, or failure.
    """
    if image is None:
        return "Please upload an image to search for similar text descriptions."

    try:
        logger.info("Processing uploaded image for text search...")

        # Search for similar text descriptions
        results = search_text_by_image(
            query_image=image,
            embeddings_data=embeddings_data,
            model=model,
            processor=processor,
            device=device,
            top_k=int(num_results)
        )

        if not results:
            return "No semantically similar text descriptions found."

        formatted_results = format_text_results(results)
        logger.info(f"Image analysis completed - {len(results)} text matches found")
        return formatted_results

    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app,
        # while keeping the full traceback in the log for debugging.
        error_msg = f"Image analysis failed: {str(e)}"
        logger.exception(error_msg)
        return error_msg

# Create the web interface
# Two-tab Gradio app: text -> images (gallery + Markdown details) and
# image -> text (Markdown list of captions). Components are laid out in the
# order they are created inside the nested `with` blocks, so statement order
# here determines the on-screen layout.

with gr.Blocks(theme=gr.themes.Soft(), title="Image Search Engine") as multimodal_demo:
    gr.Markdown("# 🔍 Image Search Engine")
    # NOTE(review): the subtitle mentions only text search, but the app also has an image -> text tab.
    gr.Markdown("*Search images using natural language descriptions*")
    gr.Markdown("---")

    with gr.Tabs():
        # Tab 1: text query -> ranked gallery of distinct images.
        with gr.TabItem("📝 Text → 🖼️ Images", elem_id="text2image"):
            gr.Markdown("### Search for images")

            # Query box (wide) next to the result-count slider (narrow).
            with gr.Row():
                with gr.Column(scale=3):
                    text_input = gr.Textbox(
                        label="Search Query",
                        placeholder="Enter your search query (e.g., 'a dog playing in the park')",
                        lines=2
                    )
                with gr.Column(scale=1):
                    # Integer 1..20; passed to the handler as num_results (top_k).
                    num_results = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Results Count"
                    )

            text_btn = gr.Button("🔍 Search Images", variant="primary", size="lg")

            # Receives the (image_path, label) list from format_results_for_gallery.
            with gr.Row():
                with gr.Column():
                    image_gallery = gr.Gallery(
                        label="Search Results",
                        show_label=True,
                        elem_id="gallery",
                        columns=3,
                        rows=2,
                        object_fit="contain",
                        height="auto"
                    )

            # Markdown summary (rank, similarity, file name, full caption) or an error/status message.
            with gr.Row():
                result_details = gr.Markdown(
                    value="Enter a search query above to find relevant images.",
                    label="Search Results"
                )

        # Tab 2: uploaded image -> ranked list of stored captions.
        with gr.TabItem("🖼️ Image → 📝 Text", elem_id="image2text"):
            gr.Markdown("### Find similar text descriptions")
            gr.Markdown("*Upload an image to discover semantically related text descriptions from our dataset.*")

            with gr.Row():
                with gr.Column(scale=3):
                    # type="pil" makes Gradio hand the handler a PIL image (or None when empty).
                    image_input = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=400
                    )
                with gr.Column(scale=1):
                    # Integer 1..20; passed to the handler as num_results (top_k).
                    image_num_results = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Results Count"
                    )

            image_btn = gr.Button("🔍 Find Similar Text", variant="primary", size="lg")

            # Markdown produced by format_text_results, or an error/status message.
            with gr.Row():
                image_output = gr.Markdown(
                    value="Upload an image above to find semantically similar text descriptions.",
                    label="Similar Text Descriptions"
                )

    # Event Bindings
    # The button and Enter in the query box run the same text -> image handler.
    text_btn.click(
        fn=perform_text_to_image_search,
        inputs=[text_input, num_results],
        outputs=[image_gallery, result_details]
    )

    # Allow Enter key to trigger search
    text_input.submit(
        fn=perform_text_to_image_search,
        inputs=[text_input, num_results],
        outputs=[image_gallery, result_details]
    )

    # Image -> text search runs only on button click (not automatically on upload).
    image_btn.click(
        fn=perform_image_to_text_search,
        inputs=[image_input, image_num_results],
        outputs=image_output
    )

# Launch the interface: display it inline in the notebook and serve it only on
# the local URL (no public share link).
logger.info("Starting image search interface...")
multimodal_demo.launch(share=False, inline=True)


INFO: Starting image search interface...
* Running on local URL:  http://127.0.0.1:7861
INFO: HTTP Request: GET http://127.0.0.1:7861/gradio_api/startup-events "HTTP/1.1 200 OK"
INFO: HTTP Request: HEAD http://127.0.0.1:7861/ "HTTP/1.1 200 OK"
* To create a public link, set `share=True` in `launch()`.




INFO: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
