# Experiment 2: FAISS Retrieval

In [None]:
import os
import time

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import torch
import faiss
import numpy as np

from tqdm.notebook import tqdm
from torchvision import transforms
from PIL import Image
from ultralytics import YOLO

from face_alignment import align


In [None]:
# Read the dataset CSV into a DataFrame; later cells iterate over its rows.
csv_path = "dataset/IMDb-Face_clean_unique.csv"
df = pd.read_csv(csv_path)


In [None]:
# Small working subset: the first 1000 rows of the full table.
sample_df = df.iloc[:1000]


In [None]:
# Preprocessing applied to an aligned face before it is fed to the embedding
# model: convert to a tensor, then normalize all three channels with
# mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(_preprocess_steps)


In [None]:
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the pretrained EdgeFace embedding network from GitHub through
# torch.hub, place it on the selected device and switch it to eval mode.
model_face_embedding = torch.hub.load(
    'otroshi/edgeface',
    'edgeface_s_gamma_05',
    source='github',
    pretrained=True,
).to(device).eval()

print(f"Model is loaded on {device}")

In [None]:
# Instantiate the YOLO face detector from its local weights file.
yolo_weights = "yolov11s-face.pt"
model_yolo = YOLO(yolo_weights)


In [None]:
# for index, row in sample_df.iterrows():
#     image_path = os.path.join("dataset", "images", row["index"], row["image"])  # Construct the file path

#     # Load the image with OpenCV
#     image = cv2.imread(image_path)

#     if image is not None:
#         # Extract height and width from the row
#         height, width = map(int, row["height width"].split())  # Assuming height and width are stored as space-separated values

#         # Resize the image based on the given height and width
#         resized_image = cv2.resize(image, (width, height))

#         # Run YOLO inference
#         results = model_yolo(resized_image)  

#         # Extract bounding boxes
#         if len(results) > 0:
#             boxes = results[0].boxes  # Get detected bounding boxes

#             if len(boxes) > 0:
#                 # Extract CSV face coordinates
#                 x1_csv, y1_csv, x2_csv, y2_csv = map(int, row["rect"].split())

#                 # Compute CSV face center
#                 cx_csv = (x1_csv + x2_csv) / 2
#                 cy_csv = (y1_csv + y2_csv) / 2

#                 closest_box = None
#                 min_distance = float("inf")

#                 # Iterate over detected faces
#                 for box in boxes:
#                     x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

#                     # Compute center of detected face
#                     cx_det = (x1 + x2) / 2
#                     cy_det = (y1 + y2) / 2

#                     # Compute Euclidean distance
#                     distance = np.sqrt((cx_det - cx_csv) ** 2 + (cy_det - cy_csv) ** 2)

#                     # Update the closest face
#                     if distance < min_distance:
#                         min_distance = distance
#                         closest_box = (int(x1), int(y1), int(x2), int(y2))

#                 # Crop the closest face
#                 if closest_box:
#                     x1, y1, x2, y2 = closest_box

#                     margin = 30
#                     h, w, _ = resized_image.shape  # Get image dimensions

#                     # Clip coordinates to stay within image bounds
#                     x1 = max(0, x1 - margin)
#                     y1 = max(0, y1 - margin)
#                     x2 = min(w, x2 + margin)
#                     y2 = min(h, y2 + margin)

#                     face_image = resized_image[y1:y2, x1:x2]

#                     # Plot Image
#                     image_rgb = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
#                     plt.imshow(image_rgb)
#                     plt.axis("off")  # Hide axis
#                     plt.show()

#                     # Get Embedding
#                     # Convert the OpenCV image (BGR) to PIL image (RGB)
#                     pil_image = Image.fromarray(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB))
#                     aligned = align.get_aligned_face(None, pil_image) # align face

#                     # Check If alignment result good
#                     if aligned is not None:
                        
#                         # Plot Image
#                         plt.imshow(aligned)
#                         plt.axis("off")  # Hide axis
#                         plt.show()

#                         transformed_input = transform(aligned).unsqueeze(0).to(device) # preprocessing

#                         # extract embedding
#                         face_embedding = model_face_embedding(transformed_input)


In [None]:
# Build a FAISS index over the face embeddings of every row in `df`, then
# write the index to disk.
#
# Per row: load the image, resize it to the height/width recorded in the CSV,
# crop the CSV face rectangle, align the face, embed it and add the embedding
# to the index. Rows whose image cannot be read, whose crop is empty, or whose
# alignment fails are skipped.

# Exact (brute-force) index using L2 distance.
dimension = 512  # must equal the length of the vectors added below
faiss_index = faiss.IndexFlatL2(dimension)

# Parallel lists: entry i belongs to the i-th vector added to the index, so a
# FAISS result index can be mapped back to its image.
# NOTE: these lists live only in memory; only the index itself is saved below,
# so the query cell needs them from the same session.
image_paths = []
embeddings_list = []

# Display only the first few crops as a sanity check. Plotting every crop in
# a loop over the whole dataset would flood the notebook output.
max_previews = 5
previews_shown = 0

for _, row in tqdm(df.iterrows(), desc="Processing images", total=len(df)):
    image_path = os.path.join("dataset", "images", row["index"], row["image"])

    # cv2.imread returns None when the file is missing or unreadable.
    image = cv2.imread(image_path)
    if image is None:
        continue

    # Both columns hold space-separated integers.
    x1, y1, x2, y2 = map(int, row["rect"].split())
    height, width = map(int, row["height width"].split())

    # Resize to the CSV dimensions first (cv2.resize expects (width, height))
    # so the CSV rectangle is applied at the recorded size.
    resized_image = cv2.resize(image, (width, height))

    # Negative coordinates would wrap around in a NumPy slice; clamp them.
    # Values past the far edge are already clamped by slicing.
    x1, y1 = max(0, x1), max(0, y1)
    cropped_image = resized_image[y1:y2, x1:x2]
    if cropped_image.size == 0:
        # Degenerate rectangle: cv2.cvtColor raises on an empty array.
        continue

    # OpenCV images are BGR; convert once and reuse the RGB copy for both the
    # preview and the PIL image so the preview shows correct colours.
    cropped_rgb = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)

    if previews_shown < max_previews:
        plt.imshow(cropped_rgb)
        plt.axis("off")  # Hide axis
        plt.show()
        previews_shown += 1

    pil_image = Image.fromarray(cropped_rgb)
    aligned = align.get_aligned_face(None, pil_image)  # align face
    if aligned is None:
        continue

    transformed_input = transform(aligned).unsqueeze(0).to(device)  # preprocessing

    # Inference only: no_grad avoids building an autograd graph per image.
    with torch.no_grad():
        face_embedding = model_face_embedding(transformed_input).cpu().numpy().flatten()

    # FAISS expects float32 vectors.
    face_embedding = face_embedding.astype(np.float32, copy=False)

    # Add the embedding as a single-row batch and record its source image at
    # the same position.
    faiss_index.add(np.array([face_embedding]))
    image_paths.append(image_path)
    embeddings_list.append(face_embedding)

# Save the FAISS index to disk
faiss.write_index(faiss_index, 'face_embeddings.index')


In [None]:
# Query the saved FAISS index with one face image and display its nearest
# neighbours. Relies on `image_paths` from the index-building cell to map
# result indices back to image files.

# Load the FAISS index from disk
faiss_index = faiss.read_index('face_embeddings.index')

# Example: Query with a specific face embedding (e.g., first embedding in the list)
# query_embedding = embeddings_list[0]  # Let's use the first image's embedding as the query
image_path = "face_3.png"
input_image = cv2.imread(image_path)
if input_image is None:
    # cv2.imread returns None instead of raising; fail with a clear message
    # rather than a cryptic cv2.error from cvtColor.
    raise FileNotFoundError(f"Could not read query image at {image_path}")

# Convert once from OpenCV's BGR to RGB; reused for the plot and the PIL image.
input_image_rgb = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# Plot the image
plt.imshow(input_image_rgb)
plt.axis("off")  # Hide axis
plt.show()

pil_image = Image.fromarray(input_image_rgb)
aligned = align.get_aligned_face(None, pil_image)  # align face

# Stay None when alignment fails so the search below is skipped instead of
# running with a stale `face_embedding` left over from another cell.
query_embedding = None
if aligned is not None:
    transformed_input = transform(aligned).unsqueeze(0).to(device)  # preprocessing
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        face_embedding = model_face_embedding(transformed_input).cpu().numpy().flatten()
    # FAISS expects float32 vectors.
    query_embedding = face_embedding.astype(np.float32, copy=False)
else:
    print("invalid alignment")

k = 5  # Number of nearest neighbors you want to retrieve

if query_embedding is not None:
    # D holds the distances, I the indices of the nearest neighbors
    # (one row per query vector).
    D, I = faiss_index.search(np.array([query_embedding]), k)

    # Output the results
    print("Indices of nearest neighbors:", I)
    print("Distances to nearest neighbors:", D)

    # Plot images of the nearest neighbors
    for idx in I[0]:
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than k vectors;
            # image_paths[-1] would silently show the wrong image.
            continue
        if idx >= len(image_paths):
            print(f"No image path recorded for index {idx}")
            continue

        neighbor_path = image_paths[idx]
        neighbor_image = cv2.imread(neighbor_path)

        if neighbor_image is None:
            print(f"Failed to load image at {neighbor_path}")
            continue

        # Convert the image to RGB and plot it
        plt.imshow(cv2.cvtColor(neighbor_image, cv2.COLOR_BGR2RGB))
        plt.axis("off")  # Hide axis
        plt.show()
