In [1]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
import numpy as np
import faiss
import pickle
from concurrent.futures import ProcessPoolExecutor


In [2]:
# Folder whose entries are listed and passed to the feature extractor further down.
# The path is relative, so it resolves against the notebook's working directory.
image_folder = "data/images"

In [3]:
# Load pre-trained ResNet-50 model
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before switching.
model = resnet50(pretrained=True)

# Remove the final fully connected layer to use ResNet as a feature extractor
model = torch.nn.Sequential(*(list(model.children())[:-1]))

# Switch the *final* extractor to evaluation mode. Calling eval() after the
# re-wrap puts both the Sequential container and every child layer in eval
# mode, so `model.training` reflects how the network is actually used.
model.eval()

# Image preprocessing (same as used for ImageNet pretraining)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])




In [4]:
# Function to extract features from an image using ResNet-50
def extract_features(image_path):
    """Return the feature vector of a single image as a NumPy array.

    The image is opened, converted to RGB, run through the module-level
    ``preprocess`` pipeline and then through the module-level ``model`` with
    gradient tracking disabled.

    Args:
        image_path: Path of the image file to read.

    Returns:
        NumPy array holding the model output with all size-1 dimensions
        removed (presumably 2048 values for the truncated ResNet-50 — verify).

    Raises:
        Whatever ``PIL.Image.open`` / ``convert`` raise for a missing,
        unreadable or non-image file.
    """
    # The context manager guarantees the file handle is released even when
    # decoding fails; convert() returns an independent in-memory image, so it
    # stays usable after the source file has been closed.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert('RGB')

    batch = preprocess(rgb_image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        features = model(batch)
    return features.squeeze().numpy()  # Convert to NumPy array and remove batch dimension

In [8]:
from tqdm import tqdm
import os
import pickle

# Map every image in a batch to its extracted feature vector
def process_images_in_batches(image_paths):
    """Extract features for each path and key them by file name.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        dict mapping ``os.path.basename(path)`` to the value returned by
        ``extract_features(path)``. Paths that share a base name collapse
        into one entry, the later one winning.
    """
    return {
        os.path.basename(path): extract_features(path)
        for path in image_paths
    }

# List all image files in the folder
# os.listdir() returns entries in arbitrary order, and that order later decides
# which FAISS id each image receives — sort so repeated runs agree.
image_list = sorted(os.listdir(image_folder))

# Keep regular files only: a sub-directory inside the folder would otherwise
# make Image.open() raise and abort the whole (multi-hour) extraction run.
image_paths = [
    path
    for path in (os.path.join(image_folder, img_name) for img_name in image_list)
    if os.path.isfile(path)
]

# Step 1: Batch processing without parallel execution
batch_size = 500  # You can tune this for batch processing based on your system's resources

# Dictionary to store all image names and feature vectors
all_image_features = {}

# Calculate total batches for progress tracking (ceiling division)
total_batches = -(-len(image_paths) // batch_size)

# Process the images in batches with a progress bar
for i in tqdm(range(0, len(image_paths), batch_size), total=total_batches, desc="Processing batches"):
    batch = image_paths[i:i + batch_size]
    batch_features = process_images_in_batches(batch)  # Process the batch
    all_image_features.update(batch_features)  # Store the batch's results

# Save the image name to feature vector mapping as a pickle file
with open('image_feature_vectors.pkl', 'wb') as f:
    pickle.dump(all_image_features, f)

print("Feature extraction complete. All features saved to 'image_feature_vectors.pkl'.")


Processing batches:  12%|█▏        | 12/100 [25:15<3:04:44, 125.96s/it]

In [7]:
# Load the feature vectors from the pickle file
# NOTE(review): pickle.load can execute arbitrary code while deserializing —
# only point this at a file produced by this notebook itself.
pickle_file_path = 'image_feature_vectors.pkl'

with open(pickle_file_path, 'rb') as f:
    # Presumably the {image name: feature vector} dict written by the
    # extraction cell (same file name) — verify if the file came from elsewhere.
    image_features = pickle.load(f)

# Sanity check: report how many entries were read back.
print(f"Loaded {len(image_features)} image features from the pickle file.")


In [None]:
# Extract image names (keys) and feature vectors (values)
# Keys and values of a dict iterate in the same order, so row i of
# `feature_vectors` belongs to `image_names[i]`.
image_names = list(image_features.keys())  # List of image names (primary keys)

# Build the float32 matrix in a single step instead of np.array(...).astype(...),
# which would materialise the full matrix twice.
feature_vectors = np.asarray(list(image_features.values()), dtype=np.float32)

# Fail fast with a clear message: an empty mapping (or non-vector values) would
# otherwise only surface later as an IndexError on `feature_vectors.shape[1]`.
if feature_vectors.ndim != 2:
    raise ValueError(
        "Expected a 2-D (num_images, dimension) feature matrix, "
        f"got shape {feature_vectors.shape}; is the feature file empty or malformed?"
    )

print(f"Extracted {len(image_names)} image names and corresponding feature vectors.")


In [None]:
# Set FAISS parameters for IVF index creation
dimension = feature_vectors.shape[1]  # ResNet-50 outputs 2048-dimensional feature vectors
nlist = 100  # Number of clusters for IVF (you can tune this based on dataset size)
nprobe = 10  # Number of clusters to search (can be tuned for search performance)

# Create the IVF index with Flat L2 distance and enable HNSW for quantization
quantizer = faiss.IndexHNSWFlat(dimension, 32)  # Using HNSW for quantization
index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)  # IVF with L2 distance

# Apply the search-time setting to the index itself; without this assignment
# `nprobe` above is just an unused variable and searches fall back to the
# FAISS default number of probed clusters.
index.nprobe = nprobe

print("FAISS index and quantizer initialized.")


In [None]:
# Train the FAISS index (required for IVF)
# Runs the index's training step on the full feature matrix; this cell has to
# execute before the cell that adds vectors to the index.
index.train(feature_vectors)
print("FAISS index training complete.")


In [None]:
# Add vectors to the FAISS index and create mappings
# Make the cell idempotent: re-running it used to append the same vectors a
# second time (ids N..2N-1) while the mapping below only covers ids 0..N-1.
# reset() drops the stored vectors so the ids start from 0 again.
if index.ntotal:
    index.reset()

index.add(feature_vectors)  # Add all the feature vectors to the FAISS index

# index.add() assigns sequential ids in insertion order, so id i corresponds to
# row i of `feature_vectors`, i.e. to image_names[i].
faiss_id_to_image_name = dict(enumerate(image_names))

# Guard against the index and the mapping drifting apart.
assert index.ntotal == len(faiss_id_to_image_name), (
    f"index holds {index.ntotal} vectors but mapping has {len(faiss_id_to_image_name)} entries"
)

print(f"Added {len(image_names)} feature vectors to the FAISS index and created ID-to-image mappings.")
