In [1]:
import pandas as pd
import numpy as np
import faiss
import time
import open_clip
import torch
from PIL import Image

# Load dataset
csv_path = "G:/multimodal_ai/datasets/final_captions.csv"
df = pd.read_csv(csv_path).sample(1000, random_state=42)

# Load CLIP model
model, preprocess, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-B-32")

# Function to extract features
def extract_features(image_path, text):
    try:
        image = Image.open(image_path).convert("RGB")
        image = preprocess(image).unsqueeze(0)
        text_tokens = tokenizer([text])

        with torch.no_grad():
            image_features = model.encode_image(image).cpu().numpy()
            text_features = model.encode_text(text_tokens).cpu().numpy()

        combined_features = (image_features + text_features) / 2
        return combined_features.flatten()

    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None

# Extract and store embeddings
embeddings = []
image_paths = []
for _, row in df.iterrows():
    image_path = f"G:/multimodal_ai/datasets/images/{row['image']}"
    caption = row['caption']
    
    features = extract_features(image_path, caption)
    if features is not None:
        embeddings.append(features)
        image_paths.append(row['image'])

embeddings = np.array(embeddings, dtype=np.float32)
print(f"Extracted {len(embeddings)} embeddings")


  from .autonotebook import tqdm as notebook_tqdm


Extracted 1000 embeddings


In [49]:
# Generate a random query vector from dataset
query_vector = embeddings[0].reshape(1, -1)

# Measure retrieval time using brute-force similarity search
start_time = time.time()
similarities = np.dot(embeddings, query_vector.T)  # Compute cosine similarity
top_k_brute = similarities.flatten().argsort()[-5:][::-1]  # Top-5 indices
end_time = time.time()

brute_force_time = end_time - start_time
print(f"Brute-Force Retrieval Time: {brute_force_time:.6f} seconds")

Brute-Force Retrieval Time: 0.001992 seconds


In [50]:
# Create FAISS Index
faiss_index = faiss.IndexFlatL2(512)  # 512-D feature space
faiss_index.add(embeddings)

# Measure retrieval time using FAISS
start_time = time.time()
D, I = faiss_index.search(query_vector, 5)  # Search for top-5 matches
end_time = time.time()

faiss_time = end_time - start_time
print(f"FAISS Retrieval Time: {faiss_time:.6f} seconds")

FAISS Retrieval Time: 0.000982 seconds


In [51]:
improvement = ((brute_force_time - faiss_time) / brute_force_time) * 100
print(f"Retrieval Speed Improvement: {improvement:.2f}%")

Retrieval Speed Improvement: 50.70%
