In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from celeba_dataset import CelebADataset
from pympler import asizeof
from saad_face import SaadFace
from tqdm import tqdm
import torch.nn.functional as F


In [2]:
# Prefer Apple's Metal (MPS) backend when this PyTorch build/machine supports it; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: mps


In [None]:
# Locations of the CelebA images / identity annotations and the trained weights.
CELEBA_IMG_DIR = './data/celeba/img_align_celeba'
CELEBA_LABEL_FILE = './data/celeba/identity_CelebA.txt'
MODEL_WEIGHTS_PATH = 'models/model.pth'
EMBEDDING_DIM = 128

# Resize to 224x224, convert to a tensor, then shift/scale each channel with mean=std=0.5
# (ToTensor yields values in [0, 1], so this maps them to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 'embedding_gen' split: used below to build one reference embedding per identity.
dataset = CelebADataset(img_dir=CELEBA_IMG_DIR, label_file=CELEBA_LABEL_FILE, transform=transform, type='embedding_gen')
embedding_generation_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

# 'train' split: individual images that are later queried against the reference embeddings.
train_dataset = CelebADataset(img_dir=CELEBA_IMG_DIR, label_file=CELEBA_LABEL_FILE, transform=transform, type='train')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1)

model = SaadFace(embedding_dim=EMBEDDING_DIM)
# The checkpoint is only consumed by load_state_dict, so refuse to unpickle anything
# other than tensors/plain containers (weights_only requires PyTorch >= 1.13).
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()

In [8]:
# Build a dictionary mapping each identity label to the mean of the model's
# embeddings for the images the 'embedding_gen' loader yields for that label.
identity_embeddings = {}
# Pure inference: without no_grad every forward pass would build an autograd graph.
with torch.no_grad():
    for image_tensor, label in tqdm(embedding_generation_dataloader):
        # Drop the DataLoader's batch dimension (batch_size=1). Presumably what remains is
        # a stack of images for one identity -- TODO confirm against CelebADataset.
        image_tensor = torch.squeeze(image_tensor, 0)
        # Average the per-image embeddings along dim 0 into a single reference vector.
        embedding = torch.mean(model(image_tensor.to(device)), dim=0)
        # label is a length-1 batch; its single element is the key.
        # NOTE(review): if the loader yields the same label twice, the later batch overwrites the earlier one.
        identity_embeddings[label[0]] = embedding.cpu()

100%|██████████| 5641/5641 [01:58<00:00, 47.62it/s]


In [28]:
# Freeze the label -> embedding mapping into two aligned structures:
# row i of embeddings_matrix is the reference embedding for embedding_labels[i]
# (dict iteration order keeps keys and values in step).
embedding_labels = [*identity_embeddings]
embeddings_matrix = torch.stack([*identity_embeddings.values()])

In [34]:
# Number of identities that received a reference embedding.
num_identities = len(embedding_labels)
num_identities

5641

In [32]:
def find_nearest_label(input_vector: torch.Tensor, embeddings: torch.Tensor = None, labels: list = None):
    """Return the reference label whose embedding is most cosine-similar to ``input_vector``.

    Args:
        input_vector: 1-D query embedding of length D.
        embeddings: (num_refs, D) matrix of reference embeddings. Defaults to the
            notebook-level ``embeddings_matrix``.
        labels: Sequence of labels aligned with the rows of ``embeddings``. Defaults to
            the notebook-level ``embedding_labels``.

    Returns:
        ``(nearest_label, max_similarity)``, where ``max_similarity`` is a Python float.

    Raises:
        ValueError: If there are no reference embeddings, or the number of rows and
            labels differ.
    """
    # Fall back to the notebook globals so existing one-argument calls keep working.
    if embeddings is None:
        embeddings = embeddings_matrix
    if labels is None:
        labels = embedding_labels
    if len(labels) == 0:
        raise ValueError("No reference embeddings to compare against")
    if embeddings.shape[0] != len(labels):
        raise ValueError(
            f"embeddings has {embeddings.shape[0]} rows but {len(labels)} labels were given"
        )

    # Compare the query with every reference row at once; (1, D) broadcasts against (num_refs, D).
    query = input_vector.to(embeddings.device).unsqueeze(0)
    similarities = F.cosine_similarity(query, embeddings, dim=1)

    # Highest similarity = nearest identity.
    max_index = int(torch.argmax(similarities))
    return labels[max_index], similarities[max_index].item()

In [52]:
def find_nearest_label_euclidean(input_vector: torch.Tensor, embeddings: torch.Tensor = None, labels: list = None):
    """Return the reference label whose embedding is closest (L2 distance) to ``input_vector``.

    Args:
        input_vector: 1-D query embedding of length D.
        embeddings: (num_refs, D) matrix of reference embeddings. Defaults to the
            notebook-level ``embeddings_matrix``.
        labels: Sequence of labels aligned with the rows of ``embeddings``. Defaults to
            the notebook-level ``embedding_labels``.

    Returns:
        ``(nearest_label, min_distance)``, where ``min_distance`` is a Python float >= 0.

    Raises:
        ValueError: If there are no reference embeddings, or the number of rows and
            labels differ.
    """
    # Fall back to the notebook globals so existing one-argument calls keep working.
    if embeddings is None:
        embeddings = embeddings_matrix
    if labels is None:
        labels = embedding_labels
    if len(labels) == 0:
        raise ValueError("No reference embeddings to compare against")
    if embeddings.shape[0] != len(labels):
        raise ValueError(
            f"embeddings has {embeddings.shape[0]} rows but {len(labels)} labels were given"
        )

    # Row-wise L2 distance; the (1, D) query broadcasts against (num_refs, D).
    query = input_vector.to(embeddings.device).unsqueeze(0)
    distances = torch.norm(embeddings - query, dim=1)

    # Smallest distance = nearest identity.
    min_index = int(torch.argmin(distances))
    return labels[min_index], distances[min_index].item()

In [55]:
# Query each training image against the per-identity reference embeddings, print the
# nearest identity, and summarize top-1 accuracy. Supported methods: 'cosine', 'euclidean'.
distance_method = 'euclidean'
max_samples = None  # e.g. 500 for a quick check; None evaluates every image
num_correct = 0
num_evaluated = 0
with torch.no_grad():  # inference only -- no autograd graph needed
    for images, labels in train_dataloader:
        if max_samples is not None and num_evaluated >= max_samples:
            break
        label = labels[0]  # batch_size=1, so the batch holds a single label
        # (1, D) model output -> 1-D query vector on the CPU, where the reference matrix lives.
        embedding = torch.squeeze(model(images.to(device)).cpu(), 0)
        if distance_method == 'cosine':
            nearest_label, similarity = find_nearest_label(embedding)
            print(f"True label: {label}, Nearest label: {nearest_label}, Similarity: {similarity}")
        elif distance_method == 'euclidean':
            nearest_label, distance = find_nearest_label_euclidean(embedding)
            # This value is a distance (lower is closer), not a similarity.
            print(f"True label: {label}, Nearest label: {nearest_label}, Distance: {distance}")
        else:
            raise ValueError(f"Unknown distance_method: {distance_method!r}")
        num_evaluated += 1
        num_correct += int(nearest_label == label)

if num_evaluated:
    print(f"Top-1 accuracy: {num_correct}/{num_evaluated} = {num_correct / num_evaluated:.2%}")

        

True label: 6322, Nearest label: 1050, Similarity: 0.008085524663329124
True label: 6322, Nearest label: 9373, Similarity: 0.013651452027261257
True label: 6322, Nearest label: 1050, Similarity: 0.04406757280230522
True label: 6322, Nearest label: 7178, Similarity: 0.009626053273677826
True label: 6322, Nearest label: 10061, Similarity: 0.012456787750124931
True label: 1849, Nearest label: 3319, Similarity: 0.04280870035290718
True label: 1849, Nearest label: 2998, Similarity: 0.02322578988969326
True label: 1849, Nearest label: 8581, Similarity: 0.02117781899869442
True label: 1231, Nearest label: 3041, Similarity: 0.029993848875164986
True label: 1231, Nearest label: 6871, Similarity: 0.01548078004270792
True label: 8826, Nearest label: 6645, Similarity: 0.024024723097682
True label: 8826, Nearest label: 2089, Similarity: 0.014632211066782475
True label: 8826, Nearest label: 1127, Similarity: 0.024985607713460922
True label: 8826, Nearest label: 7897, Similarity: 0.028551040217280388

KeyboardInterrupt: 