In [1]:
pip install torch torchvision faiss-cpu open-clip-torch transformers scikit-learn


Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting open-clip-torch
  Downloading open_clip_torch-2.32.0-py3-none-any.whl.metadata (31 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch

# Final: hybrid text + image search (CLIP + ViT embeddings, FAISS indices)

In [40]:
import torch
import open_clip
import faiss
import numpy as np
from torchvision import transforms
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
from PIL import Image
import os
from typing import List, Tuple, Optional, Union
import json

class HybridSearchSystem:
    """Similarity search over a set of captioned images.

    Three FAISS inner-product indices are kept side by side, one per query type:

    * ``text_index``   - CLIP text embedding of each description.
    * ``image_index``  - CLIP image embedding concatenated with the output of a
      torchvision ViT-B/16, re-normalised.
    * ``hybrid_index`` - image vector concatenated with the text vector,
      re-normalised.

    Every stored and query vector is scaled to unit L2 norm, so the inner
    product returned by FAISS is a cosine similarity in ``[-1, 1]``.

    NOTE: the attribute/method names and log messages say "DINOv2", but the
    model actually loaded is torchvision's ``vit_b_16`` with
    ``IMAGENET1K_V1`` weights, and its forward output is used as-is as the
    embedding (the recorded run reports 512 text dims and 1512 image dims).
    Swapping in a different backbone changes the index dimensions and makes
    previously saved databases unusable.
    """

    def __init__(self, device=None):
        """Load both models and start with an empty database.

        Args:
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # CLIP provides the shared text/image embedding space.
        print("Loading CLIP model...")
        self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
            'ViT-B-32', pretrained='laion2b_s34b_b79k'
        )
        self.clip_tokenizer = open_clip.get_tokenizer('ViT-B-32')
        self.clip_model = self.clip_model.to(self.device)
        self.clip_model.eval()

        # Second image encoder (torchvision ViT-B/16 standing in for DINOv2).
        # The log text is kept as-is because existing recorded output shows it.
        print("Loading DINOv2 model...")
        dino_weights = ViT_B_16_Weights.IMAGENET1K_V1
        self.dino_model = vit_b_16(weights=dino_weights).to(self.device)
        self.dino_model.eval()
        # Built once here instead of on every embed_dino_image() call.
        self.dino_preprocess = dino_weights.transforms()

        # One FAISS index per query type; all three are None until a database
        # is built or loaded.
        self.text_index = None
        self.image_index = None
        self.hybrid_index = None

        # Row i of every index corresponds to image_paths[i] / text_descriptions[i].
        self.image_paths = []
        self.text_descriptions = []
        self.id_map = []

        print(f"Models loaded successfully on {self.device}")

    @staticmethod
    def _fuse(*parts: torch.Tensor) -> torch.Tensor:
        """Concatenate 1-D vectors and rescale the result to unit L2 norm."""
        joined = torch.cat(parts)
        return joined / joined.norm()

    @staticmethod
    def _as_query(vector: torch.Tensor) -> np.ndarray:
        """Turn a 1-D tensor into the (1, dim) float32 array FAISS searches with."""
        return vector.numpy().reshape(1, -1).astype('float32')

    @staticmethod
    def _build_flat_ip_index(vectors: List[np.ndarray]):
        """Stack vectors into a float32 matrix and index it with IndexFlatIP.

        Returns:
            Tuple (index, dimension).
        """
        matrix = np.stack(vectors).astype('float32')
        dimension = matrix.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(matrix)
        return index, dimension

    def embed_text(self, text: str) -> torch.Tensor:
        """Return the unit-norm CLIP text embedding of a single string (CPU tensor)."""
        tokens = self.clip_tokenizer([text]).to(self.device)
        with torch.no_grad():
            # Batch of one: drop the batch axis to get a 1-D vector.
            embedding = self.clip_model.encode_text(tokens).squeeze(0).cpu()
        return embedding / embedding.norm()

    def embed_clip_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Return the unit-norm CLIP image embedding, or None if the image cannot be processed.

        Failures are reported with print() and signalled by a None return
        rather than an exception, so callers can skip bad files.
        """
        try:
            # Context manager closes the file handle; convert("RGB") mirrors
            # embed_dino_image so both encoders receive a 3-channel image.
            with Image.open(image_path) as img:
                pixels = self.clip_preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.clip_model.encode_image(pixels).squeeze(0).cpu()
            return embedding / embedding.norm()
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    def embed_dino_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Return the unit-norm output of the ViT-B/16 model, or None on failure.

        Despite the name, this runs the torchvision ``vit_b_16`` loaded in
        __init__ and uses its forward output directly as the vector.
        """
        try:
            with Image.open(image_path) as img:
                pixels = self.dino_preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.dino_model(pixels).squeeze(0).cpu()
            return embedding / embedding.norm()
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

    def embed_hybrid_image(self, image_path: str) -> Optional[torch.Tensor]:
        """Return the unit-norm concatenation of the CLIP and ViT image vectors.

        Returns None if either encoder fails on the image.
        """
        clip_vec = self.embed_clip_image(image_path)
        if clip_vec is None:
            # Skip the second model: the image is already known to be unusable
            # and the failure has already been reported once.
            return None

        dino_vec = self.embed_dino_image(image_path)
        if dino_vec is None:
            return None

        return self._fuse(clip_vec, dino_vec)

    def build_database(self, image_paths: List[str], text_descriptions: List[str]):
        """
        Build the text, image and hybrid indices from parallel lists.

        Items whose image cannot be embedded are skipped (with a printed
        notice). The instance's existing database is only replaced once all
        three new indices have been built, so a failed build leaves the
        previous state untouched.

        Args:
            image_paths: List of paths to images
            text_descriptions: List of text descriptions corresponding to images

        Raises:
            ValueError: If the two lists differ in length, or if no item
                could be embedded.
        """
        if len(image_paths) != len(text_descriptions):
            raise ValueError("Number of images and descriptions must match")

        total = len(image_paths)
        print(f"Building database with {total} items...")

        text_embeddings = []
        image_embeddings = []
        hybrid_embeddings = []
        kept_paths = []
        kept_descriptions = []

        for i, (img_path, desc) in enumerate(zip(image_paths, text_descriptions)):
            print(f"Processing item {i+1}/{total}: {os.path.basename(img_path)}")

            img_embedding = self.embed_hybrid_image(img_path)
            if img_embedding is None:
                # Same 1-based numbering as the progress line above.
                print(f"Skipping item {i+1} due to image processing error")
                continue

            # Only embed the description once the image is known to be usable.
            text_embedding = self.embed_text(desc)

            text_embeddings.append(text_embedding.numpy())
            image_embeddings.append(img_embedding.numpy())
            hybrid_embeddings.append(self._fuse(img_embedding, text_embedding).numpy())

            kept_paths.append(img_path)
            kept_descriptions.append(desc)

        if not text_embeddings:
            raise ValueError("No valid embeddings could be created")

        # Build everything into locals first, then publish in one go.
        text_index, text_dimension = self._build_flat_ip_index(text_embeddings)
        image_index, image_dimension = self._build_flat_ip_index(image_embeddings)
        hybrid_index, hybrid_dimension = self._build_flat_ip_index(hybrid_embeddings)

        self.text_index = text_index
        self.image_index = image_index
        self.hybrid_index = hybrid_index
        self.image_paths = kept_paths
        self.text_descriptions = kept_descriptions
        self.id_map = list(range(len(kept_paths)))

        print(f"Database built successfully with {len(text_embeddings)} items")
        print(f"Text embedding dimension: {text_dimension}")
        print(f"Image embedding dimension: {image_dimension}")
        print(f"Hybrid embedding dimension: {hybrid_dimension}")

    def search(self, image_path: Optional[str] = None,
               text_query: Optional[str] = None,
               k: int = 5,
               similarity_threshold: float = 0.0) -> List[Tuple[str, str, float]]:
        """
        Search the index that matches the kind of query supplied.

        Text only -> text index; image only -> image index; both -> hybrid index.

        Args:
            image_path: Path to query image (optional)
            text_query: Text query (optional)
            k: Maximum number of results to return (must be >= 1)
            similarity_threshold: Minimum similarity score to keep a result.
                Scores are inner products of unit vectors, i.e. cosine
                similarities in [-1, 1].

        Returns:
            List of tuples (image_path, description, similarity_score), in
            the order returned by the index. May hold fewer than k entries.

        Raises:
            ValueError: If neither query is given, the database has not been
                built/loaded, k < 1, or the query image cannot be processed.
        """
        if not image_path and not text_query:
            raise ValueError("Must provide either image_path or text_query")

        if self.text_index is None:
            raise ValueError("Database not built. Call build_database() first.")

        if k < 1:
            raise ValueError("k must be a positive integer")

        if text_query and not image_path:
            # Text-only search
            print(f"Processing query text: '{text_query}'")
            index = self.text_index
            query_vector = self.embed_text(text_query)

        elif image_path and not text_query:
            # Image-only search
            print(f"Processing query image: {os.path.basename(image_path)}")
            index = self.image_index
            query_vector = self.embed_hybrid_image(image_path)
            if query_vector is None:
                raise ValueError("Could not process query image")

        else:
            # Hybrid search (image + text)
            print(f"Processing hybrid query - Image: {os.path.basename(image_path)}, Text: '{text_query}'")
            index = self.hybrid_index
            img_embedding = self.embed_hybrid_image(image_path)
            if img_embedding is None:
                raise ValueError("Could not process query image")
            # Same construction as the stored hybrid vectors in build_database.
            query_vector = self._fuse(img_embedding, self.embed_text(text_query))

        similarities, indices = index.search(self._as_query(query_vector), k)

        results = []
        for similarity, idx in zip(similarities[0], indices[0]):
            if idx < 0:
                # FAISS pads with label -1 when the index holds fewer than k
                # vectors; as a Python list index that would silently alias
                # the last stored item.
                continue
            if similarity >= similarity_threshold:
                position = int(idx)
                results.append((
                    self.image_paths[position],
                    self.text_descriptions[position],
                    float(similarity)
                ))

        return results

    def find_best_match(self, image_path: Optional[str] = None,
                       text_query: Optional[str] = None,
                       similarity_threshold: float = 0.7) -> Optional[Tuple[str, str, float]]:
        """
        Return the single top hit if it reaches the similarity threshold.

        Args:
            image_path: Path to query image (optional)
            text_query: Text query (optional)
            similarity_threshold: Minimum similarity score for a valid match

        Returns:
            Tuple (image_path, description, similarity_score) or None if no good match

        Raises:
            ValueError: Propagated from search().
        """
        results = self.search(image_path, text_query, k=1, similarity_threshold=similarity_threshold)
        return results[0] if results else None

    def save_database(self, filepath: str):
        """Write the three indices and the metadata to disk.

        Files written: ``{filepath}_text.faiss``, ``{filepath}_image.faiss``,
        ``{filepath}_hybrid.faiss`` and ``{filepath}_metadata.json``.

        Raises:
            ValueError: If no database has been built or loaded.
        """
        if self.text_index is None:
            raise ValueError("No database to save")

        faiss.write_index(self.text_index, f"{filepath}_text.faiss")
        faiss.write_index(self.image_index, f"{filepath}_image.faiss")
        faiss.write_index(self.hybrid_index, f"{filepath}_hybrid.faiss")

        metadata = {
            'image_paths': self.image_paths,
            'text_descriptions': self.text_descriptions,
            'id_map': self.id_map
        }

        with open(f"{filepath}_metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        print(f"Database saved to {filepath}")

    def load_database(self, filepath: str):
        """Load the indices and metadata written by save_database().

        Everything is read and cross-checked before any attribute is
        replaced, so a failed load leaves the current database intact.

        Raises:
            ValueError: If the indices and metadata do not all describe the
                same number of items.
        """
        text_index = faiss.read_index(f"{filepath}_text.faiss")
        image_index = faiss.read_index(f"{filepath}_image.faiss")
        hybrid_index = faiss.read_index(f"{filepath}_hybrid.faiss")

        with open(f"{filepath}_metadata.json", 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        image_paths = metadata['image_paths']
        text_descriptions = metadata['text_descriptions']
        id_map = metadata['id_map']

        # search() maps index rows to metadata by position, so every part must
        # have the same length.
        sizes = {
            text_index.ntotal,
            image_index.ntotal,
            hybrid_index.ntotal,
            len(image_paths),
            len(text_descriptions),
        }
        if len(sizes) != 1:
            raise ValueError(
                f"Inconsistent database at {filepath}: index/metadata sizes {sorted(sizes)}"
            )

        self.text_index = text_index
        self.image_index = image_index
        self.hybrid_index = hybrid_index
        self.image_paths = image_paths
        self.text_descriptions = text_descriptions
        self.id_map = id_map

        print(f"Database loaded from {filepath}")


# Example usage
def main():
    """Demo run: index four captioned images, then issue four sample queries.

    The queries are a text-only search, an image-only search, a combined
    image + text search, and a thresholded best-match lookup. The two
    image-based queries are skipped when the query image file is absent.
    Any exception is caught and printed as ``Error: ...``.
    """
    def report(hits):
        # One ranked line per hit: file name, caption, score to 3 decimals.
        for rank, (path, caption, score) in enumerate(hits, 1):
            print(f"  {rank}. {os.path.basename(path)} - {caption} (Score: {score:.3f})")

    engine = HybridSearchSystem()

    # Example database - replace with your actual paths and descriptions
    catalog = [
        ('/kaggle/input/images/image1.jpg', 'red sports car in parking lot'),
        ('/kaggle/input/images/image2.jpeg', 'blue ocean waves crashing on beach'),
        ('/kaggle/input/images/image3.png', 'historic stone building with arched windows'),
        ('/kaggle/input/images2/image4.jpg', 'blue travel suitcase with front pocket'),
    ]
    paths = [path for path, _ in catalog]
    captions = [caption for _, caption in catalog]

    try:
        engine.build_database(paths, captions)

        rule = "=" * 50
        print("\n" + rule)
        print("SEARCH EXAMPLES")
        print(rule)

        # Text only
        print("\n1. Text-only query:")
        report(engine.search(text_query="blue suitcase", k=3))

        # Image only
        print("\n2. Image-only query:")
        query_image = "/kaggle/input/images2/query.jpg"  # Replace with actual path
        if os.path.exists(query_image):
            report(engine.search(image_path=query_image, k=3))

        # Image and text together
        print("\n3. Hybrid query (image + text):")
        if os.path.exists(query_image):
            report(engine.search(image_path=query_image, text_query="travel luggage", k=3))

        # Single best hit, subject to a minimum score
        print("\n4. Best match with threshold:")
        top = engine.find_best_match(text_query="blue suitcase", similarity_threshold=0.7)

        if not top:
            print("  No good match found above threshold")
        else:
            top_path, top_caption, top_score = top
            print(f"  Best match: {os.path.basename(top_path)} - {top_caption} (Score: {top_score:.3f})")

        # Save database (optional)
        # engine.save_database("my_search_database")

    except Exception as err:
        print(f"Error: {err}")

# Run the demo when this code is the entry point; the recorded output below
# shows the call executed when this notebook cell ran.
if __name__ == "__main__":
    main()


Loading CLIP model...
Loading DINOv2 model...
Models loaded successfully on cuda
Building database with 4 items...
Processing item 1/4: image1.jpg
Processing item 2/4: image2.jpeg
Processing item 3/4: image3.png
Processing item 4/4: image4.jpg
Database built successfully with 4 items
Text embedding dimension: 512
Image embedding dimension: 1512
Hybrid embedding dimension: 2024

SEARCH EXAMPLES

1. Text-only query:
Processing query text: 'blue suitcase'
  1. image4.jpg - blue travel suitcase with front pocket (Score: 0.922)
  2. image2.jpeg - blue ocean waves crashing on beach (Score: 0.535)
  3. image3.png - historic stone building with arched windows (Score: 0.366)

2. Image-only query:
Processing query image: query.jpg
  1. image4.jpg - blue travel suitcase with front pocket (Score: 0.832)
  2. image3.png - historic stone building with arched windows (Score: 0.145)
  3. image1.jpg - red sports car in parking lot (Score: 0.053)

3. Hybrid query (image + text):
Processing hybrid query 

In [46]:
# Example usage
def main():
    """Demo run with a different query image, text query and threshold.

    This re-defines the ``main`` from the earlier cell. It indexes the same
    four captioned images, then runs a text-only search ("grey suitcase"),
    an image-only search, a combined image + text search, and a thresholded
    (0.8) best-match lookup using both image and text.

    All three image-based steps are skipped, with a printed notice, when the
    query image file does not exist. Any exception is caught and printed as
    ``Error: ...``.
    """
    def print_results(results):
        # One ranked line per hit: file name, description, score to 3 decimals.
        for rank, (img_path, desc, score) in enumerate(results, 1):
            print(f"  {rank}. {os.path.basename(img_path)} - {desc} (Score: {score:.3f})")

    # Initialize the search system
    search_system = HybridSearchSystem()

    # Example database - replace with your actual paths and descriptions
    image_paths = [
        '/kaggle/input/images/image1.jpg',
        '/kaggle/input/images/image2.jpeg',
        '/kaggle/input/images/image3.png',
        '/kaggle/input/images2/image4.jpg'
    ]

    text_descriptions = [
        'red sports car in parking lot',
        'blue ocean waves crashing on beach',
        'historic stone building with arched windows',
        'blue travel suitcase with front pocket'
    ]

    query_image = "/kaggle/input/dinodino/q2.jpg"  # Replace with actual path

    try:
        search_system.build_database(image_paths, text_descriptions)

        # Example queries
        print("\n" + "="*50)
        print("SEARCH EXAMPLES")
        print("="*50)

        # Checked once; steps 2-4 all need the file.
        has_query_image = os.path.exists(query_image)
        skipped_note = f"  Skipped: query image not found at {query_image}"

        # Query with text only
        print("\n1. Text-only query:")
        print_results(search_system.search(text_query="grey suitcase", k=3))

        # Query with image only
        print("\n2. Image-only query:")
        if has_query_image:
            print_results(search_system.search(image_path=query_image, k=3))
        else:
            print(skipped_note)

        # Query with both image and text
        print("\n3. Hybrid query (image + text):")
        if has_query_image:
            print_results(search_system.search(
                image_path=query_image,
                text_query="travel luggage",
                k=3
            ))
        else:
            print(skipped_note)

        # Find best match with threshold. Guarded like steps 2 and 3: without
        # the check a missing file surfaces as a generic "Error: ..." from
        # the handler below.
        print("\n4. Best match with threshold:")
        if has_query_image:
            best_match = search_system.find_best_match(
                image_path=query_image,
                text_query="travel luggage",
                similarity_threshold=0.8
            )

            if best_match:
                img_path, desc, score = best_match
                print(f"  Best match: {os.path.basename(img_path)} - {desc} (Score: {score:.3f})")
            else:
                print("  No good match found above threshold")
        else:
            print(skipped_note)

        # Save database (optional)
        # search_system.save_database("my_search_database")

    except Exception as e:
        print(f"Error: {e}")


# Run the demo when this code is the entry point; the recorded output below
# shows the call executed when this notebook cell ran.
if __name__ == "__main__":
    main()


Loading CLIP model...
Loading DINOv2 model...
Models loaded successfully on cuda
Building database with 4 items...
Processing item 1/4: image1.jpg
Processing item 2/4: image2.jpeg
Processing item 3/4: image3.png
Processing item 4/4: image4.jpg
Database built successfully with 4 items
Text embedding dimension: 512
Image embedding dimension: 1512
Hybrid embedding dimension: 2024

SEARCH EXAMPLES

1. Text-only query:
Processing query text: 'grey suitcase'
  1. image4.jpg - blue travel suitcase with front pocket (Score: 0.783)
  2. image3.png - historic stone building with arched windows (Score: 0.412)
  3. image2.jpeg - blue ocean waves crashing on beach (Score: 0.366)

2. Image-only query:
Processing query image: q2.jpg
  1. image4.jpg - blue travel suitcase with front pocket (Score: 0.717)
  2. image1.jpg - red sports car in parking lot (Score: 0.063)
  3. image2.jpeg - blue ocean waves crashing on beach (Score: 0.054)

3. Hybrid query (image + text):
Processing hybrid query - Image: q2