In [1]:
import torch
import clip
from PIL import Image
import faiss
import numpy as np
import os
from tqdm import tqdm 


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

def encode_images(image_paths):
    image_features_list = []
    with torch.no_grad():
        for img_path in tqdm(image_paths, desc="Encoding images"):
            image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
            image_features = model.encode_image(image)
            image_features = torch.nn.functional.normalize(image_features, p=2, dim=1)
            image_features_list.append(image_features.cpu().numpy())
    return np.vstack(image_features_list).astype('float32')

def encode_text(text_queries):
    text = clip.tokenize(text_queries).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)
        text_features = torch.nn.functional.normalize(text_features, p=2, dim=1)
    return text_features.cpu().numpy().astype('float32')

def create_faiss_index(features):
    d = features.shape[1]
    index = faiss.IndexFlatIP(d)  # Inner product index (cosine similarity on normalized vectors)
    index.add(features)
    return index

def save_faiss_index(index, file_path):
    faiss.write_index(index, file_path)

def load_faiss_index(file_path):
    return faiss.read_index(file_path)

def query_index(index, query_features, image_paths, top_k=3):
    D, I = index.search(query_features, top_k)
    
    results = []
    for i, query in enumerate(query_features):
        query_results = []
        for rank, (idx, score) in enumerate(zip(I[i], D[i])):
            query_results.append((image_paths[idx], score))
        results.append(query_results)
    return results

In [4]:
'''Create FAISS index for images'''

image_paths = []
print("Scanning for image files...")
for root, dirs, files in os.walk("../../SBDresults/keyframes/"):
    for file in files:
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_paths.append(os.path.join(root, file))

print(f"Found {len(image_paths)} images. Starting encoding...")
image_features_np = encode_images(image_paths)

print("Creating FAISS index...")
index = create_faiss_index(image_features_np)

print("Saving FAISS index...")
save_faiss_index(index, "image_index.faiss")

Scanning for image files...
Found 14994 images. Starting encoding...


Encoding images: 100%|██████████| 14994/14994 [3:43:18<00:00,  1.12it/s]  

Creating FAISS index...
Saving FAISS index...





In [3]:
'''Use a loaded FAISS index to query images'''

image_paths = []
for root, dirs, files in os.walk("../SBDresults/keyframes/"):
    for file in files:
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_paths.append(os.path.join(root, file))

index = load_faiss_index("image_index.faiss")

# Query
text_queries = ["guitar"]
text_features_np = encode_text(text_queries)
results = query_index(index, text_features_np, image_paths, top_k=10)

for i, query in enumerate(text_queries):
    print(f"\nQuery: {query}")
    for rank, (path, score) in enumerate(results[i]):
        print(f"Rank {rank + 1}: Image = {path}, Similarity = {score:.4f}")


Query: guitar
Rank 1: Image = ../SBDresults/keyframes/00102/00102_24.jpg, Similarity = 0.2484
Rank 2: Image = ../SBDresults/keyframes/00016/00016_45.jpg, Similarity = 0.2438
Rank 3: Image = ../SBDresults/keyframes/00055/00055_75.jpg, Similarity = 0.2417
Rank 4: Image = ../SBDresults/keyframes/00150/00150_54.jpg, Similarity = 0.2378
Rank 5: Image = ../SBDresults/keyframes/00179/00179_25.jpg, Similarity = 0.2374
Rank 6: Image = ../SBDresults/keyframes/00052/00052_6.jpg, Similarity = 0.2364
Rank 7: Image = ../SBDresults/keyframes/00052/00052_32.jpg, Similarity = 0.2364
Rank 8: Image = ../SBDresults/keyframes/00087/00087_11.jpg, Similarity = 0.2361
Rank 9: Image = ../SBDresults/keyframes/00052/00052_11.jpg, Similarity = 0.2358
Rank 10: Image = ../SBDresults/keyframes/00179/00179_34.jpg, Similarity = 0.2357
