In [None]:
import torch
import clip
from PIL import Image
import faiss
import numpy as np
import os

# Pick the compute device once; the encode_* helpers move their inputs onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ViT-L/14 CLIP model together with the image transform that encode_images applies.
model, preprocess = clip.load("ViT-L/14", device=device)

def encode_images(image_paths, batch_size=32):
    """Embed images with CLIP and return L2-normalised float32 features.

    Args:
        image_paths: Sequence (or iterable) of image file paths; they are
            encoded in the given order, so row i of the result belongs to
            the i-th path.
        batch_size: Number of images sent through the model per forward
            pass. Larger values are faster on GPU but use more memory.

    Returns:
        ``np.ndarray`` of dtype float32 with one unit-norm row per image.

    Raises:
        ValueError: if ``image_paths`` is empty or ``batch_size`` < 1.
    """
    image_paths = list(image_paths)
    if not image_paths:
        # np.vstack([]) would fail later with an unhelpful message.
        raise ValueError("encode_images: image_paths is empty; nothing to encode")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    feature_batches = []
    with torch.no_grad():
        for start in range(0, len(image_paths), batch_size):
            pixel_tensors = []
            for img_path in image_paths[start:start + batch_size]:
                # The context manager releases the file handle once
                # preprocess has read the pixels (the original leaked it).
                with Image.open(img_path) as img:
                    pixel_tensors.append(preprocess(img))
            images = torch.stack(pixel_tensors).to(device)
            features = model.encode_image(images)
            features = torch.nn.functional.normalize(features, p=2, dim=1)
            feature_batches.append(features.cpu().numpy())
    return np.vstack(feature_batches).astype('float32')

def encode_text(text_queries):
    """Turn text queries into CLIP text embeddings.

    Args:
        text_queries: Query strings, tokenised with ``clip.tokenize``.

    Returns:
        float32 NumPy array (on the CPU) with one L2-normalised embedding
        per query.
    """
    token_ids = clip.tokenize(text_queries).to(device)
    with torch.no_grad():
        raw_embeddings = model.encode_text(token_ids)
        unit_embeddings = torch.nn.functional.normalize(raw_embeddings, p=2, dim=1)
    host_array = unit_embeddings.cpu().numpy()
    return host_array.astype('float32')

def create_faiss_index(features):
    """Build an exact (flat) inner-product FAISS index over ``features``.

    Args:
        features: 2-D array of shape ``(n, d)``. It is converted to the
            C-contiguous float32 layout FAISS requires, so e.g. float64
            arrays are accepted instead of failing inside ``index.add``.

    Returns:
        A ``faiss.IndexFlatIP`` containing all n vectors. On L2-normalised
        rows the inner-product score equals cosine similarity.
    """
    vectors = np.ascontiguousarray(features, dtype='float32')
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index

def save_faiss_index(index, file_path):
    """Write ``index`` to disk with ``faiss.write_index``.

    Args:
        index: FAISS index to serialise.
        file_path: Destination path as ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``). It is converted with ``os.fspath``
            because the SWIG-generated FAISS bindings may reject non-str
            paths.
    """
    faiss.write_index(index, os.fspath(file_path))

def load_faiss_index(file_path):
    """Read a FAISS index previously written with ``faiss.write_index``.

    Args:
        file_path: Source path as ``str`` or any ``os.PathLike``. It is
            converted with ``os.fspath`` because the SWIG-generated FAISS
            bindings may reject non-str paths.

    Returns:
        The deserialised FAISS index.
    """
    return faiss.read_index(os.fspath(file_path))

def query_index(index, query_features, image_paths, top_k=3):
    """Search ``index`` and map each hit back to its image path.

    Args:
        index: Object exposing ``search(x, k) -> (scores, ids)`` (a FAISS
            index), whose row i was built from ``image_paths[i]``.
        query_features: Query array of shape ``(n_queries, d)``; a single
            1-D query vector of length ``d`` is also accepted. It is
            converted to contiguous float32 as FAISS requires.
        image_paths: Paths in the same order used to build ``index``.
        top_k: Maximum number of hits returned per query.

    Returns:
        One list per query of ``(image_path, score)`` tuples, in the order
        returned by ``index.search``. A list may be shorter than ``top_k``
        when the index holds fewer vectors than requested.
    """
    queries = np.ascontiguousarray(np.atleast_2d(query_features), dtype='float32')
    scores, ids = index.search(queries, top_k)
    results = []
    for row_ids, row_scores in zip(ids, scores):
        # FAISS pads missing hits with id -1; without this filter,
        # image_paths[-1] would silently be reported as a match.
        results.append([
            (image_paths[idx], score)
            for idx, score in zip(row_ids, row_scores)
            if idx >= 0
        ])
    return results

In [None]:
KEYFRAME_DIR = "./results/keyframes/"
INDEX_PATH = "image_index.faiss"
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Collect every keyframe image. Sorting makes the row order of the index
# reproducible: the saved FAISS file stores only vectors, so row i can only
# be mapped back to a file through this exact, deterministically ordered list
# (raw os.walk order depends on the filesystem).
image_paths = sorted(
    os.path.join(root, file)
    for root, _dirs, files in os.walk(KEYFRAME_DIR)
    for file in files
    if file.lower().endswith(IMAGE_EXTENSIONS)
)
if not image_paths:
    raise FileNotFoundError(f"No {IMAGE_EXTENSIONS} images found under {KEYFRAME_DIR!r}")

image_features_np = encode_images(image_paths)
index = create_faiss_index(image_features_np)

# Save index
save_faiss_index(index, INDEX_PATH)

# Load index (round-trip check: the reloaded index must cover every image)
index = load_faiss_index(INDEX_PATH)
assert index.ntotal == len(image_paths), "index size does not match image_paths"

# Query
text_queries = ["a white and brown cat"]
text_features_np = encode_text(text_queries)
results = query_index(index, text_features_np, image_paths, top_k=3)

for query, query_results in zip(text_queries, results):
    print(f"\nQuery: {query}")
    for rank, (path, score) in enumerate(query_results, start=1):
        print(f"Rank {rank}: Image = {path}, Similarity = {score:.4f}")