In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import chromadb
import numpy as np
import os

In [15]:
# Load the pretrained CLIP checkpoint. The model and its processor are taken
# from the same checkpoint so image/text preprocessing matches what the model expects.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)


In [17]:
def load_image(image_path):
    """Read an image file fully into memory and return it as an RGB PIL image.

    ``Image.open`` is lazy: it keeps the file handle open until the pixel data
    is actually read. Opening inside ``with`` and converting before leaving the
    block forces the read and closes the handle right away, which matters
    because this is called once per file in the ingestion loop. Converting to
    RGB also gives CLIP's 3-channel preprocessing a consistent input whatever
    the file's mode (grayscale, RGBA, palette, ...).

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        A new, fully loaded ``PIL.Image.Image`` in mode "RGB".
    """
    with Image.open(image_path) as image:
        # convert() returns a new, loaded image, so it stays valid after the
        # source image is closed on exiting the with-block.
        return image.convert("RGB")

def compute_image_embeddings(image, normalize=True):
    """Embed a single image with the module-level CLIP ``model`` and ``processor``.

    Args:
        image: A PIL image (e.g. as returned by ``load_image``).
        normalize: If True (default), scale the embedding to unit L2 norm.
            CLIP is trained with a cosine-similarity objective, while Chroma
            collections use L2 distance unless configured otherwise. With
            unit-norm stored vectors, L2 ranking matches cosine ranking. Pass
            False to get the raw projection output.

    Returns:
        A 1-D torch tensor: the first (and only) row of ``get_image_features``.
    """
    # The processor handles resizing, cropping and normalization of pixels.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    embedding = image_features[0]
    if normalize:
        # normalize() guards against division by zero with a small eps.
        embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
    return embedding


def compute_text_embedding(text, normalize=True):
    """Embed a single text query with the module-level CLIP ``model`` and ``processor``.

    Args:
        text: Query string.
        normalize: If True (default), scale the embedding to unit L2 norm, so
            that L2-distance search ranks results the same way as cosine
            similarity, the measure CLIP is trained with. Pass False for the
            raw projection output.

    Returns:
        A 1-D torch tensor: the row of ``get_text_features`` for ``text``.
    """
    # truncation=True: CLIP's text encoder has a fixed context length
    # (77 tokens for the OpenAI checkpoints); longer queries would otherwise fail.
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        text_features = model.get_text_features(**inputs)

    embedding = text_features[0]
    if normalize:
        embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
    return embedding


def tensor_to_list(embedding_tensor):
    """Return ``embedding_tensor`` in a plain Python / NumPy form.

    Lists and NumPy arrays are handed back as-is (same object). Anything else,
    such as a torch tensor, is converted with its ``.tolist()`` method.
    """
    if isinstance(embedding_tensor, (list, np.ndarray)):
        return embedding_tensor
    return embedding_tensor.tolist()


In [None]:
# On-disk Chroma store; with no path argument PersistentClient uses its default path.
client = chromadb.PersistentClient()

# Reuse the "image_embeddings" collection when it already exists, otherwise create it.
collection = client.get_or_create_collection(name="image_embeddings")

# Number of records currently in the collection (shown as the cell output).
collection.count()

In [19]:
def store_image_in_db(image_path, image_id, metadata=None):
    """Embed one image file and write it to the module-level Chroma ``collection``.

    Args:
        image_path: Path to the image file. It is also stored as the record's
            document so query results point back to the file.
        image_id: Unique record ID. Writing an existing ID replaces that
            record, so re-running the ingestion cell is idempotent.
        metadata: Optional dict of extra attributes (e.g. tags). Omitted when
            None or empty, because some Chroma versions reject empty metadata
            dicts ("Expected metadata to be a non-empty dict").
    """
    image = load_image(image_path)
    embedding = tensor_to_list(compute_image_embeddings(image))

    # upsert instead of add: add() does not overwrite an existing ID (it is
    # ignored, or raises in older Chroma), so a re-run would keep stale vectors.
    collection.upsert(
        ids=[image_id],
        embeddings=[embedding],
        documents=[image_path],
        metadatas=[metadata] if metadata else None,
    )
    print(f"Image {image_id} stored successfully.")

In [None]:
# Ingest every image file found directly inside `directory`.
directory = 'images/'

# Only files with these extensions are embedded. Anything else in the folder
# (e.g. .DS_Store, notes) would make PIL raise when it tries to open it.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff'}

# sorted() gives a reproducible processing order; os.listdir order is arbitrary.
for filename in sorted(os.listdir(directory)):
    file_path = os.path.join(directory, filename)
    has_image_extension = os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS

    # Skip sub-directories and non-image files.
    if os.path.isfile(file_path) and has_image_extension:
        print(f"Processing file: {file_path}")
        store_image_in_db(file_path, filename, metadata={"type": "image"})


In [21]:
# Validation: querying with an already-indexed image should retrieve that
# image itself among the top hits. Its stored embedding was produced the same
# way, assuming the collection was populated with the current embedding code.
test_image_name = '00000001_020.jpg'
test_image_path = os.path.join(directory, test_image_name)

image = load_image(test_image_path)
embedding_tensor = compute_image_embeddings(image)
embedding = tensor_to_list(embedding_tensor)

image_results = collection.query(query_embeddings=[embedding], n_results=3)
# Membership rather than rank 0, so exact-duplicate files cannot break the check.
assert test_image_name in image_results["ids"][0], image_results["ids"]

image_results


In [22]:
# Embed a free-text query into CLIP's joint text/image space; `embedding` is
# what the next cell sends to the collection.
query_text = 'Dog '
embedding_tensor = compute_text_embedding(text=query_text)
embedding = tensor_to_list(embedding_tensor=embedding_tensor)

In [None]:
# Fetch the TOP_K stored images nearest to the current query `embedding`.
TOP_K = 3
results = collection.query(query_embeddings=[embedding], n_results=TOP_K)

results