In [None]:
import os
import torch
import cv2
from typing import Optional, Tuple, List
from facenet_pytorch import MTCNN, InceptionResnetV1
import matplotlib.pyplot as plt

# Initialize the MTCNN module for face detection and the InceptionResnetV1 module for face embedding.
# Both are module-level objects built once when this cell/module runs and shared by the helpers below.
# keep_all=True: MTCNN returns every detected face rather than a single one (embedding_face uses faces[0]).
# image_size=160: side length, in pixels, of the square face crops MTCNN returns.
mtcnn = MTCNN(image_size=160, keep_all=True)
# pretrained="vggface2" selects weights trained on the VGGFace2 dataset; .eval() switches the
# network to inference mode (dropout disabled, batch-norm uses its stored statistics).
resnet = InceptionResnetV1(pretrained="vggface2").eval()


def convert_tensor_to_image(tensor: torch.Tensor):
    """Convert a channels-first image tensor into a displayable ``uint8`` array.

    The first axis is moved last (``(C, H, W)`` -> ``(H, W, C)``), then the values are
    min-max rescaled to 0-255. The ``uint8`` cast truncates rather than rounds.

    Args:
        tensor: A 3-D image tensor laid out channels-first, e.g. a face crop returned by
            ``embedding_face``. It may live on any device and may require grad.

    Returns:
        A NumPy ``uint8`` array of shape ``(H, W, C)``. A constant image (zero value
        range) maps to all zeros instead of producing NaNs.
    """
    # detach() drops autograd tracking and cpu() makes .numpy() legal for CUDA tensors.
    image = tensor.detach().cpu().permute(1, 2, 0).numpy()
    low = image.min()
    span = image.max() - low
    image = image - low
    # Guard the min-max division: a constant image would otherwise be 0/0 -> NaN.
    if span > 0:
        image = image / span
    return (image * 255).astype("uint8")


def embedding_face(image) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Detect faces in *image* and embed the first one with the module-level ``resnet``.

    Args:
        image: Input for the module-level ``mtcnn`` detector, e.g. a PIL image or an
            HxWx3 array. facenet_pytorch models are trained on RGB data, so arrays read
            with ``cv2.imread`` (BGR) should be converted before calling this.
            ``None`` (what ``cv2.imread`` returns for an unreadable file) is accepted
            and treated as "no face".

    Returns:
        ``(embedding, face)``: the embedding tensor produced by ``resnet`` for the first
        detected face and that face crop tensor. Returns ``(None, None)`` when *image* is
        ``None`` or no face is detected.
    """
    if image is None:
        # Not a valid detector input; report "no face" instead of failing inside MTCNN.
        return None, None

    faces, _probs = mtcnn(image, return_prob=True)
    if faces is None or len(faces) == 0:
        return None, None

    face = faces[0]
    # Pure inference: skip autograd bookkeeping so no computation graph is retained.
    with torch.no_grad():
        embedding = resnet(face.unsqueeze(0))
    return embedding, face


def _read_rgb(path: str):
    """Read *path* with OpenCV and return it in RGB channel order, or ``None`` if unreadable."""
    image = cv2.imread(path)
    if image is None:
        return None
    # OpenCV decodes to BGR; the facenet_pytorch models expect RGB input.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def find_most(target: str, candidates: List[str]):
    """Compute the cosine similarity between the face in *target* and each candidate image.

    Args:
        target: Path to the reference image; it must contain a detectable face.
        candidates: Paths of images to compare against the target.

    Returns:
        A list of floats with the same length and order as *candidates*. Entries for
        candidates that cannot be read or have no detectable face are ``float("nan")``,
        so the scores stay aligned with their image paths.

    Raises:
        ValueError: If the target image cannot be read or no face is found in it.
    """
    embedding_target = None
    target_image = _read_rgb(target)
    if target_image is not None:
        embedding_target, _ = embedding_face(target_image)
    if embedding_target is None:
        raise ValueError("Target face embedding could not be computed.")

    # Placeholders keep positional alignment with `candidates` for faceless/unreadable images.
    similarities = [float("nan")] * len(candidates)
    found_indices = []
    found_embeddings = []
    for index, candidate in enumerate(candidates):
        image = _read_rgb(candidate)
        if image is None:
            continue
        embedding_candidate, _ = embedding_face(image)
        if embedding_candidate is not None:
            found_indices.append(index)
            found_embeddings.append(embedding_candidate)

    scores = find_most_similar(embedding_target, found_embeddings)
    for index, score in zip(found_indices, scores):
        similarities[index] = score
    return similarities


def find_most_similar(embedding: torch.Tensor, candidates: List[torch.Tensor]):
    """Score each candidate embedding against *embedding* by cosine similarity.

    Args:
        embedding: The reference embedding.
        candidates: Embeddings to compare with the reference, each shaped so that
            ``cosine_similarity(embedding, candidate)`` yields a single value.

    Returns:
        A list of Python floats, one per candidate, in the same order.
    """
    cosine = torch.nn.functional.cosine_similarity
    return [cosine(embedding, other).item() for other in candidates]


def plot_face_and_similarities(
    faces: List[str],
    similarities: List[float],
    columns: int = 5,
):
    """Show each image in a grid, titled with its similarity score.

    Args:
        faces: Image file paths, read with OpenCV.
        similarities: Scores paired positionally with *faces*; as with ``zip``, extra
            items in the longer list are ignored.
        columns: Number of grid columns. Rows are added as needed, so any number of
            images fits (the previous layout was a fixed 2x5 grid).

    Raises:
        ValueError: If *columns* is less than 1.
    """
    if columns < 1:
        raise ValueError("columns must be a positive integer")

    pairs = list(zip(faces, similarities))
    if not pairs:
        return

    rows = (len(pairs) + columns - 1) // columns
    # squeeze=False keeps `axs` 2-D even for a single row, so indexing is uniform.
    _, axs = plt.subplots(rows, columns, figsize=(3 * columns, 3 * rows), squeeze=False)
    for index, ax in enumerate(axs.flat):
        ax.axis("off")
        if index >= len(pairs):
            continue  # Unused cell in the last row: leave it blank.
        face, similarity = pairs[index]
        image = cv2.imread(face, cv2.IMREAD_COLOR)
        if image is not None:
            # OpenCV loads BGR; imshow interprets 3-channel arrays as RGB.
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title(f"Similarity: {similarity:.2f}")

    plt.tight_layout()
    plt.show()


def _list_images(directory: str) -> List[str]:
    """Return full paths of .jpg/.jpeg/.png files directly inside *directory*, sorted by name."""
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    ]


def main() -> None:
    """Compare a reference face from datasets/faces/me against all dataset images and plot the scores.

    Expects ``datasets/faces/{me,ginting,momota}`` under the current working directory.
    The first image (by sorted file name) in ``me`` is the target. Every image from
    ``ginting``, ``momota`` and ``me`` (in that order) is a candidate, so the target
    is compared with itself too.
    """
    faces_root = os.path.join(os.getcwd(), "datasets", "faces")
    me_directory = os.path.join(faces_root, "me")
    ginting_directory = os.path.join(faces_root, "ginting")
    momota_directory = os.path.join(faces_root, "momota")

    if not all(
        os.path.isdir(directory)
        for directory in (me_directory, ginting_directory, momota_directory)
    ):
        print("One or more directories do not exist. Please check the paths.")
        return

    # Filtered + sorted listing: deterministic, and never a non-image such as a stray .DS_Store.
    me_images = _list_images(me_directory)
    if not me_images:
        print(f"No .jpg/.jpeg/.png images found in {me_directory}.")
        return

    target_image = me_images[0]
    print(target_image)

    candidates = [
        *_list_images(ginting_directory),
        *_list_images(momota_directory),
        *me_images,
    ]

    similarities = find_most(target_image, candidates)
    plot_face_and_similarities(
        faces=candidates,
        similarities=similarities,
    )


ImportError: cannot import name 'is_directory' from 'PIL._util' (/Users/hinsun/Workspace/Software/Agrismart/.venv/lib/python3.12/site-packages/PIL/_util.py)