# Imports

In [None]:
pip install facenet_pytorch

Collecting facenet_pytorch
  Downloading facenet_pytorch-2.6.0-py3-none-any.whl.metadata (12 kB)
Collecting numpy<2.0.0,>=1.24.0 (from facenet_pytorch)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Pillow<10.3.0,>=10.2.0 (from facenet_pytorch)
  Downloading pillow-10.2.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Collecting torch<2.3.0,>=2.2.0 (from facenet_pytorch)
  Downloading torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchvision<0.18.0,>=0.17.0 (from facenet_pytorch)
  Downloading torchvision-0.17.2-cp312-cp312-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<2.3.0,>=2.2.0->facenet_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia

In [None]:
pip install tqdm seaborn




In [None]:
import torch
import numpy as np
import pandas as pd
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from facenet_pytorch import MTCNN, InceptionResnetV1

# Initialization

In [None]:
# Project layout and shared models used by every later cell.
# NOTE(review): the /content/drive/MyDrive prefix looks like a Google Drive mount
# in Colab — confirm the drive is mounted before running.
ROOT = Path("/content/drive/MyDrive/EECS5322_prj/Project 9 - attentive face recognition")
GALLERY_DIR = ROOT / "gallery"  # one sub-folder per identity, read by build_gallery
WILD_IMG_DIR = ROOT / "queries"  # query images, read by process_wild_set
GROUND_TRUTH_CSV = ROOT / "ground_truth.csv"  # contains filename, True_ID
# Output prefix: later cells append file names by plain string concatenation,
# so the trailing "/" matters.
save_path = "/content/drive/MyDrive/EECS5322_prj/"
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize MTCNN detector + FaceNet embedding model
mtcnn = MTCNN(image_size=160, margin=20, device=device)
# Pretrained on VGGFace2; eval() because the network is only used for inference here.
facenet = InceptionResnetV1(pretrained="vggface2").eval().to(device)


  0%|          | 0.00/107M [00:00<?, ?B/s]

In [None]:
# Load the ground-truth labels and show how many query images each True_ID has.
df = pd.read_csv(GROUND_TRUTH_CSV)
df['True_ID'].value_counts()


Unnamed: 0_level_0,count
True_ID,Unnamed: 1_level_1
7,12
0,7
3,7
1,5
4,5
6,3
5,2
2,1


In [None]:
# Display which device (CUDA or CPU) was selected in the configuration cell.
device

device(type='cpu')

# build gallery

In [None]:
def load_image(path):
    """Read an image from disk with OpenCV and return it in RGB channel order.

    Parameters
    ----------
    path : str or pathlib.Path
        Location of the image file; it is converted with ``str()`` before
        being handed to ``cv2.imread``.

    Returns
    -------
    The decoded image after a BGR -> RGB conversion.

    Raises
    ------
    FileNotFoundError
        If ``cv2.imread`` returns ``None``.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        # OpenCV decodes to BGR; the rest of the pipeline works on RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Cannot load image {path}")

In [None]:
def get_embedding(img):
    """Detect a face in ``img`` and return its unit-length FaceNet embedding.

    Uses the module-level ``mtcnn`` detector, ``facenet`` model and ``device``.

    Parameters
    ----------
    img :
        Image passed straight to ``mtcnn`` (callers in this notebook pass the
        RGB array produced by ``load_image``).

    Returns
    -------
    numpy.ndarray or None
        Flattened embedding divided by its L2 norm, or ``None`` when the
        detector returns ``None`` (no face found).
    """
    cropped = mtcnn(img)  # detector output; None means no face was found
    if cropped is None:
        return None

    # Add a leading batch axis before feeding the model.
    # NOTE(review): presumably a single cropped face tensor — confirm MTCNN is not in keep_all mode.
    batch = cropped.unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        raw = facenet(batch)
        vec = raw.cpu().numpy().flatten()

    # L2-normalize so distances between embeddings are comparable.
    return vec / np.linalg.norm(vec)

In [None]:
def build_gallery(gallery_dir):
    """Embed every gallery image and group the embeddings by identity.

    The gallery directory is expected to contain one sub-folder per person,
    named with that person's integer ID, each holding ``*.jpg`` images.

    Parameters
    ----------
    gallery_dir : str or pathlib.Path
        Root directory of the gallery.

    Returns
    -------
    dict[int, numpy.ndarray]
        Maps each person ID to an array stacking the embeddings of that
        person's images. The array is empty when no face was detected in any
        of the person's images.

    Notes
    -----
    Sub-folders whose name is not an integer (for example a stray
    ``.ipynb_checkpoints`` directory) are skipped with a warning, so they do
    not abort the whole build. Images without a detectable face are likewise
    skipped with a warning.
    """
    gallery = {}
    # Accept plain strings as well as Path objects, like process_wild_set does.
    person_dirs = sorted(d for d in Path(gallery_dir).iterdir() if d.is_dir())
    for person_dir in tqdm(person_dirs, desc="Building gallery"):
        try:
            pid = int(person_dir.name)  # folder name is the identity ID
        except ValueError:
            print(f"[Warning] Skipping non-numeric gallery folder {person_dir}")
            continue

        person_embs = []
        for img_path in sorted(person_dir.glob("*.jpg")):
            emb = get_embedding(load_image(img_path))
            if emb is None:
                print(f"[Warning] No face detected in {img_path}")
                continue
            person_embs.append(emb)
        gallery[pid] = np.array(person_embs)
    return gallery

In [None]:
# Embed all gallery images once up front; the printed keys are the identity IDs
# that build_gallery found on disk.
print("=== Building Gallery ===")
gallery = build_gallery(GALLERY_DIR)
print("Gallery identities:", list(gallery.keys()))

=== Building Gallery ===


Building gallery: 100%|██████████| 7/7 [04:40<00:00, 40.12s/it]

Gallery identities: [0, 1, 2, 3, 4, 5, 6]





In [None]:
# Number of identities in the gallery.
len(gallery)

7

In [None]:
# Number of embeddings stored for identity 1.
len(gallery[1])

5

# Process wild faces

## L2 distance

In [None]:
# Sanity check before matching: identity count, then embeddings stored for identity 0.
print(len(gallery))
print(len(gallery[0]))


7
5


In [None]:
### Calculate the L2 distances for comparing the wild faces with the gallery
def identify_face(query_emb, gallery, open_set_threshold):
    """Assign a query embedding to its nearest gallery identity by L2 distance.

    Parameters
    ----------
    query_emb : numpy.ndarray
        Embedding of the query face.
    gallery : dict
        Maps person ID -> array of that person's embeddings (one per row).
        Entries with no embeddings are ignored.
    open_set_threshold : float or None
        When given, a query whose nearest gallery distance exceeds this value
        is reported as "no match" (ID 7). ``None`` disables the rejection.

    Returns
    -------
    (int, float)
        Predicted ID and the similarity score, i.e. the negated distance to
        the nearest gallery embedding. The score of the nearest identity is
        returned even when the query is rejected. With nothing to compare
        against, the result is ``(7, -1e9)``.
    """
    NO_MATCH = 7
    winner, winner_score = NO_MATCH, -1e9

    for person_id, person_embs in gallery.items():
        if len(person_embs) == 0:
            continue
        # Distance from the query to each stored embedding of this person;
        # the closest one represents the person.
        nearest = np.linalg.norm(person_embs - query_emb, axis=1).min()
        candidate = -nearest
        # Strict comparison: on a tie the identity seen first is kept.
        if candidate > winner_score:
            winner, winner_score = person_id, candidate

    if open_set_threshold is None:
        return winner, winner_score

    # The score is a negated distance, so flip the sign back for the threshold test.
    if -winner_score > open_set_threshold:
        return NO_MATCH, winner_score
    return winner, winner_score


In [None]:
def process_wild_set(wild_img_dir, gallery, query_labels, open_set_threshold):
    """Identify every ``*.jpg`` query image and collect the results in a table.

    Parameters
    ----------
    wild_img_dir : str or pathlib.Path
        Directory holding the query images.
    gallery : dict
        Person ID -> embeddings, as produced by ``build_gallery``.
    query_labels : mapping
        Query file name -> ground-truth ID (0-6 for gallery, 7 = no match).
    open_set_threshold : float or None
        Forwarded unchanged to ``identify_face``.

    Returns
    -------
    pandas.DataFrame
        One row per query with columns ``filename``, ``true_id``, ``pred_id``,
        ``score`` and ``correct`` (1 when the prediction equals the label).
    """
    records = []
    query_paths = sorted(Path(wild_img_dir).glob("*.jpg"))

    for query_path in tqdm(query_paths, desc="Processing wild set"):
        name = query_path.name
        label = int(query_labels[name])  # 0-6 for gallery, 7 = no match

        query_emb = get_embedding(load_image(query_path))

        if query_emb is not None:
            guess, sim = identify_face(query_emb, gallery, open_set_threshold=open_set_threshold)
        else:
            # No face detected: report "no match".
            # NOTE(review): 0.0 is the best possible value on the negative-L2
            # scale, so these rows look like perfect matches in score-based
            # plots and the PR curve — confirm this placeholder is intended.
            guess, sim = 7, 0.0

        records.append({
            "filename": name,
            "true_id": label,
            "pred_id": guess,
            "score": sim,
            "correct": int(guess == label)
        })
    return pd.DataFrame(records)

In [None]:
print("=== Processing Wild Set ===")
# Map each query file name to its ground-truth ID.
ground_truth_df = pd.read_csv(GROUND_TRUTH_CSV)
query_labels = dict(zip(ground_truth_df['filename'], ground_truth_df['True_ID']))
# 0.95 is the L2-distance cutoff: a query whose nearest gallery embedding is
# farther than this is reported as "no match" (ID 7) by identify_face.
results = process_wild_set(WILD_IMG_DIR, gallery, query_labels, open_set_threshold=0.95)
# save_path already ends with "/", so plain concatenation yields a valid path.
results.to_csv(save_path + "baseline_results.csv", index=False)
print("Saved baseline_results.csv")

=== Processing Wild Set ===


Processing wild set: 100%|██████████| 42/42 [00:12<00:00,  3.32it/s]


Saved baseline_results.csv


### Evaluation

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import precision_recall_curve, average_precision_score

def evaluate_and_visualize(results_df, save_prefix="results", save_path = ""):
    """
    Compute open-set identification metrics and save diagnostic plots.

    Generates and saves:
    - Multi-class confusion matrix
    - Per-ID accuracy bar chart
    - Score distributions
    - Precision-recall curve
    - A CSV copy of the predictions

    Parameters
    ----------
    results_df : pandas.DataFrame
        One row per query with columns 'true_id', 'pred_id', 'score' and
        'correct' (1 when the prediction equals the label, else 0).
    save_prefix : str
        File-name prefix for every output.
    save_path : str
        Directory prefix. It is concatenated verbatim in front of
        ``save_prefix``, so it must end with a path separator and the
        directory must already exist.

    Returns
    -------
    accuracy : float
        Fraction of rows where pred_id == true_id.
    precision : float
        Mean of 'correct' over rows predicted as a gallery ID (pred_id != 7);
        0 when there are none.
    recall : float
        Mean of 'correct' over rows whose true ID is a gallery ID
        (true_id != 7); 0 when there are none.
    cm : pandas.DataFrame
        Confusion matrix (rows = true, columns = predicted) covering all of
        labels 0-7 plus any other label observed, with zeros for pairs that
        never occur.
    per_id_acc : pandas.Series
        Mean of 'correct' grouped by true_id.
    """
    out_stem = str(save_path) + str(save_prefix)

    # ------------------- Metrics -------------------
    y_true = results_df['true_id']
    y_pred = results_df['pred_id']

    labels = [0,1,2,3,4,5,6,7]  # 0-6 gallery, 7=no match
    accuracy = (y_true == y_pred).mean()
    predicted_known = results_df.loc[y_pred != 7, 'correct']
    truly_known = results_df.loc[y_true != 7, 'correct']
    precision = predicted_known.mean() if len(predicted_known) > 0 else 0
    recall = truly_known.mean() if len(truly_known) > 0 else 0

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision (gallery matches only): {precision:.4f}")
    print(f"Recall (gallery matches only): {recall:.4f}")

    # ------------------- Multi-class confusion matrix -------------------
    cm = pd.crosstab(y_true, y_pred, rownames=['True'], colnames=['Predicted'], dropna=False)
    # crosstab only contains labels that actually occur. Reindex to the full
    # label set so the matrix is square and the tick labels line up with the
    # cells even when an ID is never the truth or never predicted.
    all_labels = sorted(set(labels) | set(cm.index) | set(cm.columns))
    cm = (
        cm.reindex(index=all_labels, columns=all_labels, fill_value=0)
          .rename_axis(index='True', columns='Predicted')
    )
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=all_labels, yticklabels=all_labels)
    plt.title("Confusion Matrix (0-6: gallery, 7: no match)")
    plt.tight_layout()
    plt.savefig(out_stem + "_confusion_matrix.png")
    plt.close()

    # ------------------- Per-ID accuracy -------------------
    per_id_acc = results_df.groupby('true_id')['correct'].mean()
    plt.figure(figsize=(6,4))
    per_id_acc.plot(kind='bar', color='skyblue')
    plt.ylabel("Accuracy")
    plt.xlabel("Identity")
    plt.title("Per-ID Accuracy")
    plt.tight_layout()
    plt.savefig(out_stem + "_per_id_accuracy.png")
    plt.close()

    # ------------------- Score distributions -------------------
    plt.figure(figsize=(6,4))
    sns.histplot(results_df[results_df['correct']==1]['score'], color='green', label='Correct', kde=True)
    sns.histplot(results_df[results_df['correct']==0]['score'], color='red', label='Incorrect', kde=True)
    plt.xlabel("Similarity Score (Negative L2)")
    plt.ylabel("Count")
    plt.title("Score Distribution: Correct vs Incorrect")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_stem + "_score_distribution.png")
    plt.close()

    # ------------------- Precision-Recall Curve -------------------
    # Treats "prediction was correct" as the positive class and ranks rows by score.
    y_true_bin = results_df['correct'].values
    y_scores = results_df['score'].values
    precision_vals, recall_vals, thresholds = precision_recall_curve(y_true_bin, y_scores)
    ap = average_precision_score(y_true_bin, y_scores)

    plt.figure(figsize=(6,6))
    plt.plot(recall_vals, precision_vals, label=f"AP = {ap:.4f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_stem + "_pr_curve.png")
    plt.close()

    # ------------------- Save predictions -------------------
    results_df.to_csv(out_stem + "_predictions.csv", index=False)
    print(f"Predictions saved as {save_prefix}_predictions.csv")
    print(f"All diagrams saved with prefix {save_prefix}_*.png")

    return accuracy, precision, recall, cm, per_id_acc


In [None]:
print("=== Evaluating Precision–Recall ===")
# Pass the output directory as save_path (keyword). Passing it positionally
# lands in save_prefix, which produced files named "_confusion_matrix.png" etc.
# NOTE(review): the folder is called "1_0.8_threshold", but `results` above was
# computed with open_set_threshold=0.95 — confirm which name is intended.
eval_dir = Path(save_path) / "1_0.8_threshold"
eval_dir.mkdir(parents=True, exist_ok=True)
acc, prec, rec, cm, per_id_acc = evaluate_and_visualize(
    results, save_prefix="results", save_path=str(eval_dir) + "/"
)

=== Evaluating Precision–Recall ===
Accuracy: 0.7857
Precision (gallery matches only): 1.0000
Recall (gallery matches only): 0.7000
Predictions saved as /content/drive/MyDrive/EECS5322_prj/1_0.8_threshold/_predictions.csv
All diagrams saved with prefix /content/drive/MyDrive/EECS5322_prj/1_0.8_threshold/_*.png


No-threshold results:

=== Evaluating Precision–Recall ===

Accuracy: 0.6190

Precision (gallery matches only): 0.6341

Recall (gallery matches only): 0.8667

Every query was assigned to a gallery ID (0–6), even if it didn’t belong there.

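So recall is high, because no known query is ever rejected as “no match”; accuracy (0.6190) is also higher than with the 0.8 threshold below (0.5476), even though every “unknown” query is forced into a gallery ID and therefore counted as wrong.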

Precision was low because many predictions were wrong (false positives).

threshold 0.8

=== Evaluating Precision–Recall ===
Accuracy: 0.5476

Precision (gallery matches only): 1.0000

Recall (gallery matches only): 0.3667


accuracy and recall dropped but the precision increased

Queries not in the gallery are correctly predicted as 7.

The model avoids making wrong predictions for unknowns

- **Accuracy:** drops slightly because some predictions that were previously “forced” onto gallery IDs are now 7 — correct for unknowns, but counted as wrong when the threshold misclassifies a known gallery image as 7.

- **Recall:** drops because some gallery images may now be misclassified as 7 if the threshold is too strict (false negatives).

- **Precision:** increases because the model no longer predicts gallery IDs for unknowns → fewer false positives.

Threshold of 0.95

had the highest scores overall

Accuracy: 0.7857

Precision (gallery matches only): 1.0000

Recall (gallery matches only): 0.7000

# Find the Minimum distance (threshold)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_distance_distributions(results_df, save_path):
    """
    Save overlaid histograms of the nearest-gallery distance for known and
    unknown queries, as a visual aid for choosing an open-set threshold.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Needs 'score' (negated L2 distance) and 'true_id' columns. A
        'min_dist' column is added to this DataFrame in place.
    save_path : str
        Full path of the image file to write.
    """
    # Scores are stored as negated distances; flip the sign to get distances back.
    results_df['min_dist'] = -results_df['score']

    # True ID 7 marks queries that are not in the gallery.
    is_unknown = results_df['true_id'] == 7
    known = results_df.loc[~is_unknown, 'min_dist']
    unknown = results_df.loc[is_unknown, 'min_dist']

    fig, ax = plt.subplots(figsize=(8,6))
    groups = (
        (known, 'green', 'Known (gallery)'),
        (unknown, 'red', 'Unknown (not in gallery)'),
    )
    for distances, colour, tag in groups:
        sns.histplot(distances, color=colour, label=tag, kde=True, bins=30, ax=ax)

    ax.set_xlabel("Minimum L2 Distance to Gallery")
    ax.set_ylabel("Count")
    ax.set_title("Distance Distribution for Known vs Unknown Queries")
    # Dashed vertical lines mark each group's mean distance.
    ax.axvline(x=known.mean(), color='green', linestyle='--', label="Known Mean")
    ax.axvline(x=unknown.mean(), color='red', linestyle='--', label="Unknown Mean")
    ax.legend()
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Distance distribution plot saved as {save_path}")


In [None]:
# Re-run without a rejection threshold.
# NOTE: this rebinds `results`, replacing the thresholded run from the earlier
# cell; re-running the evaluation cell after this one gives different metrics.
results = process_wild_set(WILD_IMG_DIR, gallery, query_labels, open_set_threshold=None)

# Plot distance distributions to visualize separation.
# Join with Path so the result is correct whether or not save_path ends with "/"
# (plain concatenation with "/..." produced a double slash).
plot_distance_distributions(results, save_path=str(Path(save_path) / "distance_distributions.png"))


Processing wild set: 100%|██████████| 42/42 [00:10<00:00,  3.87it/s]


Distance distribution plot saved as /content/drive/MyDrive/EECS5322_prj//distance_distributions.png


The best threshold for the L2 distance is 0.95.