In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from facenet_pytorch import InceptionResnetV1
import torch
import tqdm
import numpy as np
import faiss

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Keep the model on `device` so it matches the device input batches are sent to;
# a model left on the CPU fails with a device-mismatch error when CUDA is available.
embedding_model = InceptionResnetV1(pretrained='vggface2', classify=False).eval().to(device)
# Preprocessing: resize to 160x160, then map pixel values from [0, 1]
# (ToTensor) to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
def get_mean_embeddings(dataloader, dataset, model, device=None):
    """Average the model's embeddings of all images, per class.

    Args:
        dataloader: iterable of ``(images, labels)`` batches drawn from ``dataset``.
        dataset: object exposing ``classes``; each label is an index into it.
        model: torch embedding module, called on each image batch under
            ``torch.no_grad()``.
        device: device the image batches are moved to. Defaults to the device
            of the model's parameters so inputs and weights always match.

    Returns:
        dict mapping each class name, in ``dataset.classes`` order, to the mean
        of that class's embeddings as a numpy array. Code that stacks the
        values into an index relies on this order for row alignment.
    """
    if device is None:
        # Follow the model: sending inputs anywhere else raises a device mismatch.
        device = next(model.parameters()).device

    class_names = dataset.classes
    per_class = {name: [] for name in class_names}

    with torch.no_grad():
        for images, labels in tqdm.tqdm(dataloader):
            # One device->host copy per batch instead of one per image.
            batch_embeddings = model(images.to(device)).cpu().numpy()
            for emb, label in zip(batch_embeddings, labels.tolist()):
                per_class[class_names[label]].append(emb)

    return {name: np.mean(embs, axis=0) for name, embs in per_class.items()}

In [8]:
# Training identities: ImageFolder treats each sub-folder of the root as one class.
train_root = "C:/face_dataset/imdb_train_newindex"
dataset = datasets.ImageFolder(root=train_root, transform=transform)
# No shuffling, so batches follow the dataset's deterministic sample order.
dataloader = DataLoader(dataset, shuffle=False, batch_size=32)

In [None]:
# One mean embedding per training class; these become the rows of the index below.
mean_embeddings = get_mean_embeddings(
    dataloader=dataloader,
    dataset=dataset,
    model=embedding_model,
)

In [10]:
def save_embeddings(file_path, embeddings):
    """Write a ``{name: vector}`` mapping to a compressed ``.npz`` archive.

    The names go under the ``names`` key. The vectors, stacked row-wise in the
    same order, go under the ``embeddings`` key.
    """
    name_array = np.array([*embeddings])
    vector_matrix = np.stack([*embeddings.values()])
    np.savez_compressed(file_path, names=name_array, embeddings=vector_matrix)
def load_embeddings(file_path):
    """Load a ``{name: vector}`` mapping stored as ``names`` / ``embeddings`` in an ``.npz``.

    Returns:
        dict mapping each stored name, converted to a native Python scalar
        (e.g. ``str``), to its embedding row, in the order the names were saved.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open.
    # Close it once the arrays have been read into memory; otherwise the file
    # stays locked, which on Windows prevents deleting or overwriting it.
    with np.load(file_path) as saved_data:
        names = saved_data['names']
        matrix = saved_data['embeddings']
    # .tolist() turns numpy scalars (np.str_, np.int64, ...) into plain Python values.
    return dict(zip(names.tolist(), matrix))
def create_faiss_index(embeddings: dict) -> faiss.IndexFlatIP:
    """Build an exact inner-product FAISS index over L2-normalised vectors.

    Vectors are added in the dict's iteration order, so index row ``i`` holds
    the ``i``-th value of ``embeddings``. The stored vectors have unit length,
    so for a unit-length query the inner product equals cosine similarity.
    """
    vectors = np.stack([*embeddings.values()]).astype('float32')
    # astype made a private copy, so normalising in place leaves the caller's arrays untouched.
    faiss.normalize_L2(vectors)
    flat_index = faiss.IndexFlatIP(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index

In [14]:
faiss_index_path = "../imdb_dataset/embeddings.faiss"
# Build the index from the training means, persist it, then read it back so
# the evaluation below uses exactly what is stored on disk.
faiss_index = create_faiss_index(mean_embeddings)
faiss.write_index(faiss_index, faiss_index_path)
loaded_index = faiss.read_index(faiss_index_path)

In [13]:
# Held-out identities: ImageFolder treats each sub-folder of the root as one class.
test_root = "C:/face_dataset/imdb_test_newindex"
test_dataset = datasets.ImageFolder(root=test_root, transform=transform)
# No shuffling, so batches follow the dataset's deterministic sample order.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [None]:
test_embeddings_path = 'C:/face_dataset/new_test_embeddings.npz'
# Per-class mean embeddings of the test images, cached to disk for later sessions.
test_embeddings = get_mean_embeddings(test_dataloader, test_dataset, model=embedding_model)
save_embeddings(test_embeddings_path, test_embeddings)

In [17]:
# Reload the cached test embeddings so the expensive extraction cell can be skipped.
test_embeddings = load_embeddings(file_path='C:/face_dataset/new_test_embeddings.npz')

In [None]:
def top_accuracy_faiss(loaded_index, test_embeddings, top_len=5, index_names=None):
    """Top-k identification accuracy of test embeddings against a FAISS index.

    Each test embedding is L2-normalised and searched against ``loaded_index``.
    A test identity counts as found when its true index row is among the
    ``top_len`` retrieved neighbours.

    Args:
        loaded_index: FAISS index holding one embedding per identity
            (e.g. built by ``create_faiss_index``).
        test_embeddings: dict ``{name: vector}``. Missed names are reported
            via ``int(name)``, so they must be numeric strings/ints.
        top_len: number of neighbours (k) retrieved per query.
        index_names: optional sequence giving the identity name of each index
            row, in row order. When omitted, test entry ``i`` is assumed to be
            the same identity as index row ``i``. That only holds if the index
            was built from exactly the same classes in the same order.

    Returns:
        ``(accuracy, missed)``. ``accuracy`` is the fraction of test identities
        found in the top ``top_len``. ``missed`` lists the ``int`` names that
        were not found, in test order.

    Raises:
        ValueError: if ``test_embeddings`` is empty (accuracy is undefined).
    """
    import warnings

    if not test_embeddings:
        raise ValueError("test_embeddings is empty; accuracy is undefined")

    test_names = list(test_embeddings.keys())
    # astype returns a fresh copy, so in-place normalisation can't touch the caller's arrays.
    test_matrix = np.stack(list(test_embeddings.values())).astype('float32')
    faiss.normalize_L2(test_matrix)

    _, neighbour_rows = loaded_index.search(test_matrix, top_len)

    if index_names is None:
        if loaded_index.ntotal != len(test_names):
            warnings.warn(
                f"index has {loaded_index.ntotal} rows but there are {len(test_names)} "
                "test identities; matching test entry i to index row i is probably "
                "wrong - pass index_names",
                stacklevel=2,
            )
        true_rows = np.arange(len(test_names))
        known = np.ones(len(test_names), dtype=bool)
    else:
        row_of_name = {str(name): row for row, name in enumerate(index_names)}
        known = np.array([str(name) in row_of_name for name in test_names])
        # Names absent from the index get a placeholder row; `known` masks them out.
        true_rows = np.array([row_of_name.get(str(name), 0) for name in test_names])

    # Hit = the true row appears anywhere in that query's results. FAISS pads
    # results with -1 when top_len exceeds the index size; -1 never matches a row.
    hits = (neighbour_rows == true_rows[:, None]).any(axis=1) & known

    missed_celebs = [int(name) for name, hit in zip(test_names, hits) if not hit]
    return float(hits.mean()), missed_celebs

In [None]:
# Single source of truth for k, so the printed label always matches the search depth
# (previously the search used k=10 while the label said "Top-5").
TOP_K = 10
accuracy, missing_celebs = top_accuracy_faiss(loaded_index, test_embeddings, TOP_K)
print(f"Top-{TOP_K} Accuracy: {accuracy:.2%}")

Top-5 Accuracy: 89.26%


In [None]:
import pandas as pd

celebs_df = pd.read_csv('C:/face_dataset/celeb_names.csv')
# Rows whose `id` is one of the identities not retrieved in the top-k search.
is_missing = celebs_df['id'].isin(missing_celebs)
missing_df = celebs_df.loc[is_missing].copy()

In [None]:
# Missed celebrities whose `number_of_images` exceeds 50.
missing_df.loc[missing_df['number_of_images'].gt(50)]

In [None]:
import matplotlib.pyplot as plt

# Number of missed celebrities for each distinct `number_of_images` value,
# ordered by that value (the bar chart's x-axis). Computed once and reused.
missed_per_image_count = missing_df['number_of_images'].value_counts().sort_index()

if missed_per_image_count.empty:
    # pandas cannot draw a bar chart of an empty Series.
    print("No missed celebrities to plot.")
else:
    fig, ax = plt.subplots(figsize=(20, 12))
    missed_per_image_count.plot(kind='bar', ax=ax)
    # pandas places bars at x = 0..n-1 in index order, so enumerate yields each
    # bar's x position. The label offset scales with the tallest bar instead of
    # a fixed +5 data units, which floated far above small bars.
    label_offset = 0.01 * missed_per_image_count.max()
    for position, count in enumerate(missed_per_image_count):
        ax.text(position, count + label_offset, str(count), ha='center', va='bottom')
    ax.set(
        title="Missed celebrities by number_of_images",
        xlabel="number_of_images",
        ylabel="number of missed celebrities",
    )
    plt.show()