In [1]:
GROUND_TRUTH_DIR = "groundtruth"

import os
all_images = {}

for subdir in os.listdir(GROUND_TRUTH_DIR):
    subdir_path = os.path.join(GROUND_TRUTH_DIR, subdir)
    if os.path.isdir(subdir_path):
        for people in os.listdir(subdir_path):
            people_path = os.path.join(subdir_path, people)
            person_name = people.split("_")[0]
            if os.path.isdir(people_path):
                files = os.listdir(people_path)
                all_images.setdefault(person_name, []).extend(
                    [os.path.join(GROUND_TRUTH_DIR, subdir, people, f) for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                )

In [2]:
# ALIGNED_DIR = "aligned_faces"

# import os
# all_images = {}
# for people in os.listdir(ALIGNED_DIR):
#     people_path = os.path.join(ALIGNED_DIR, people)
#     person_name = people.split("_")[0]
#     if os.path.isdir(people_path):
#         files = os.listdir(people_path)
#         all_images.setdefault(person_name, []).extend(
#             [os.path.join(ALIGNED_DIR, people, f) for f in files]
#         )

In [3]:
# Create a pytorch Dataset to load the images
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class FaceDataset(Dataset):
    def __init__(self, images_dict, transform=None):
        self.images = []
        self.labels = []
        self.label_to_index = {label: idx for idx, label in enumerate(images_dict.keys())}
        self.transform = transform
        
        for label, img_paths in images_dict.items():
            for img_path in img_paths:
                self.images.append(img_path)
                self.labels.append(label)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        ind = self.label_to_index[label]

        image = Image.open(img_path).convert("RGB")
        # image = np.load(img_path)

        if self.transform:
            image = self.transform(image)

        return image, ind

In [None]:
import torch
from torchvision import transforms
from torchvision.transforms import ColorJitter
    
train_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.RandomApply([
        ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.1,
            hue=0.05,
        )
    ], p=0.8),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.ToTensor(), # [0, 1]
    # to [-1, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

val_tfms = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

from torch.utils.data import DataLoader, random_split
# Split dataset into train and validation sets
dataset = FaceDataset(all_images)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = val_tfms
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

  from .autonotebook import tqdm as notebook_tqdm
2025-12-19 15:05:40.136205: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-19 15:05:40.144298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766156740.153677 1455960 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766156740.156339 1455960 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766156740.163749 1455960 computation_placer.cc:177] computation placer already r

In [None]:
import torch
import torch.nn as nn
from facenet_pytorch import InceptionResnetV1
import math

class FaceEmbeddingModel(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.backbone = InceptionResnetV1(
            pretrained="vggface2",
            classify=False
        )
        self.embedding_dim = embedding_dim

    def forward(self, x):
        emb = self.backbone(x)
        emb = nn.functional.normalize(emb, p=2, dim=1)
        return emb
    

class ArcFaceLoss(nn.Module):
    def __init__(self, embedding_dim, num_classes, s=64.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.randn(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        W = nn.functional.normalize(self.weight, p=2, dim=1)
        cosine = torch.matmul(embeddings, W.t())

        theta = torch.acos(torch.clamp(cosine, -1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta + self.m)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)

        output = cosine * (1 - one_hot) + target_logits * one_hot
        output *= self.s
        return nn.functional.cross_entropy(output, labels)
    
def embedding_var_loss(emb):
    return torch.var(emb, dim=0).mean()

In [None]:
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import defaultdict
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = len(set(dataset.labels))
batch_size = 64
num_epochs = 20

# -----------------------
# Model / loss / opt
# -----------------------
model = FaceEmbeddingModel().to(device)
criterion = ArcFaceLoss(
    embedding_dim=512,
    num_classes=num_classes,
    s=32.0,
    m=0.3
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)

# -----------------------
# Helper: embedding metrics
# -----------------------
def embedding_metrics(embeddings, labels):
    """
    Computes intra-class and inter-class cosine similarity.
    """
    embeddings = F.normalize(embeddings, dim=1)
    sim_matrix = embeddings @ embeddings.T

    labels = labels.unsqueeze(1)
    same = labels == labels.T
    diff = labels != labels.T

    # remove self-similarity
    eye = torch.eye(len(labels), dtype=torch.bool)
    same = same & ~eye.to(same.device)

    intra = sim_matrix[same]
    inter = sim_matrix[diff]

    return {
        "intra_mean": intra.mean().item() if len(intra) > 0 else float("nan"),
        "inter_mean": inter.mean().item() if len(inter) > 0 else float("nan"),
        "margin": (
            intra.mean() - inter.mean()
        ).item() if len(intra) > 0 and len(inter) > 0 else float("nan")
    }

# -----------------------
# Training loop
# -----------------------
best_val_loss = float("inf")
for p in model.backbone.parameters():
    p.requires_grad = False

for epoch in range(num_epochs):
    # -------- TRAIN --------
    model.train()
    train_loss = 0.0
    
    if epoch >= 8:
        for p in model.backbone.parameters():
            p.requires_grad = True
    elif epoch >= 4:
        for p in model.backbone.block8.parameters():
            p.requires_grad = True
    

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch} [train]"):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        embeddings = model(imgs)
        loss = criterion(embeddings, labels)
        loss += 1e-4 * embedding_var_loss(embeddings)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    # -------- VALIDATION --------
    model.eval()
    val_loss = 0.0
    all_embeddings = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch} [val]"):
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            embeddings = model(imgs)
            loss = criterion(embeddings, labels)
            loss += 1e-4 * embedding_var_loss(embeddings)

            val_loss += loss.item()
            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

    val_loss /= len(val_loader)

    all_embeddings = torch.cat(all_embeddings)
    all_labels = torch.cat(all_labels)

    metrics = embedding_metrics(all_embeddings, all_labels)
    if epoch > 10 and val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_face_reid_model.pth")

    # -------- LOG --------
    print(
        f"\nEpoch {epoch}\n"
        f"Train loss: {train_loss:.4f}\n"
        f"Val loss:   {val_loss:.4f}\n"
        f"Intra-class similarity: {metrics['intra_mean']:.4f}\n"
        f"Inter-class similarity: {metrics['inter_mean']:.4f}\n"
        f"Margin (↑ is good):     {metrics['margin']:.4f}\n"
    )


Epoch 0 [train]: 100%|██████████| 225/225 [00:04<00:00, 52.32it/s]
Epoch 0 [val]: 100%|██████████| 57/57 [00:01<00:00, 49.25it/s]



Epoch 0
Train loss: 12.6408
Val loss:   12.6436
Intra-class similarity: 0.2320
Inter-class similarity: -0.0206
Margin (↑ is good):     0.2526



Epoch 1 [train]: 100%|██████████| 225/225 [00:03<00:00, 62.02it/s]
Epoch 1 [val]: 100%|██████████| 57/57 [00:01<00:00, 47.66it/s]



Epoch 1
Train loss: 12.6341
Val loss:   12.6709
Intra-class similarity: 0.2310
Inter-class similarity: -0.0205
Margin (↑ is good):     0.2516



Epoch 2 [train]: 100%|██████████| 225/225 [00:03<00:00, 59.54it/s]
Epoch 2 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.19it/s]



Epoch 2
Train loss: 12.6452
Val loss:   12.6489
Intra-class similarity: 0.2330
Inter-class similarity: -0.0195
Margin (↑ is good):     0.2525



Epoch 3 [train]: 100%|██████████| 225/225 [00:03<00:00, 57.93it/s]
Epoch 3 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.73it/s]



Epoch 3
Train loss: 12.6358
Val loss:   12.6381
Intra-class similarity: 0.2295
Inter-class similarity: -0.0206
Margin (↑ is good):     0.2501



Epoch 4 [train]: 100%|██████████| 225/225 [00:03<00:00, 58.29it/s]
Epoch 4 [val]: 100%|██████████| 57/57 [00:01<00:00, 49.60it/s]



Epoch 4
Train loss: 7.2755
Val loss:   4.1918
Intra-class similarity: 0.6961
Inter-class similarity: -0.0634
Margin (↑ is good):     0.7595



Epoch 5 [train]: 100%|██████████| 225/225 [00:03<00:00, 58.02it/s]
Epoch 5 [val]: 100%|██████████| 57/57 [00:01<00:00, 49.23it/s]



Epoch 5
Train loss: 4.1971
Val loss:   3.2531
Intra-class similarity: 0.7964
Inter-class similarity: -0.0723
Margin (↑ is good):     0.8687



Epoch 6 [train]: 100%|██████████| 225/225 [00:03<00:00, 58.85it/s]
Epoch 6 [val]: 100%|██████████| 57/57 [00:01<00:00, 47.41it/s]



Epoch 6
Train loss: 3.5298
Val loss:   2.9249
Intra-class similarity: 0.8346
Inter-class similarity: -0.0741
Margin (↑ is good):     0.9087



Epoch 7 [train]: 100%|██████████| 225/225 [00:03<00:00, 60.05it/s]
Epoch 7 [val]: 100%|██████████| 57/57 [00:01<00:00, 49.91it/s]



Epoch 7
Train loss: 3.0923
Val loss:   2.8258
Intra-class similarity: 0.8462
Inter-class similarity: -0.0770
Margin (↑ is good):     0.9232



Epoch 8 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.04it/s]
Epoch 8 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.39it/s]



Epoch 8
Train loss: 1.7697
Val loss:   1.2333
Intra-class similarity: 0.8540
Inter-class similarity: -0.0745
Margin (↑ is good):     0.9286



Epoch 9 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.48it/s]
Epoch 9 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.29it/s]



Epoch 9
Train loss: 1.0893
Val loss:   1.2067
Intra-class similarity: 0.8471
Inter-class similarity: -0.0730
Margin (↑ is good):     0.9201



Epoch 10 [train]: 100%|██████████| 225/225 [00:07<00:00, 30.45it/s]
Epoch 10 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.58it/s]



Epoch 10
Train loss: 0.7830
Val loss:   1.1385
Intra-class similarity: 0.8509
Inter-class similarity: -0.0728
Margin (↑ is good):     0.9237



Epoch 11 [train]: 100%|██████████| 225/225 [00:07<00:00, 30.88it/s]
Epoch 11 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.88it/s]



Epoch 11
Train loss: 0.6304
Val loss:   1.1904
Intra-class similarity: 0.8340
Inter-class similarity: -0.0745
Margin (↑ is good):     0.9085



Epoch 12 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.01it/s]
Epoch 12 [val]: 100%|██████████| 57/57 [00:01<00:00, 48.85it/s]



Epoch 12
Train loss: 0.5473
Val loss:   1.3083
Intra-class similarity: 0.8252
Inter-class similarity: -0.0745
Margin (↑ is good):     0.8997



Epoch 13 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.22it/s]
Epoch 13 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.12it/s]



Epoch 13
Train loss: 0.4365
Val loss:   1.4859
Intra-class similarity: 0.8193
Inter-class similarity: -0.0689
Margin (↑ is good):     0.8882



Epoch 14 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.35it/s]
Epoch 14 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.07it/s]



Epoch 14
Train loss: 0.3446
Val loss:   1.3352
Intra-class similarity: 0.8118
Inter-class similarity: -0.0676
Margin (↑ is good):     0.8795



Epoch 15 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.01it/s]
Epoch 15 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.21it/s]



Epoch 15
Train loss: 0.2956
Val loss:   1.3782
Intra-class similarity: 0.8209
Inter-class similarity: -0.0750
Margin (↑ is good):     0.8960



Epoch 16 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.22it/s]
Epoch 16 [val]: 100%|██████████| 57/57 [00:01<00:00, 49.95it/s]



Epoch 16
Train loss: 0.2507
Val loss:   1.3932
Intra-class similarity: 0.8156
Inter-class similarity: -0.0746
Margin (↑ is good):     0.8901



Epoch 17 [train]: 100%|██████████| 225/225 [00:07<00:00, 31.23it/s]
Epoch 17 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.74it/s]



Epoch 17
Train loss: 0.1925
Val loss:   1.5206
Intra-class similarity: 0.8301
Inter-class similarity: -0.0664
Margin (↑ is good):     0.8964



Epoch 18 [train]: 100%|██████████| 225/225 [00:07<00:00, 30.69it/s]
Epoch 18 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.85it/s]



Epoch 18
Train loss: 0.2200
Val loss:   1.5492
Intra-class similarity: 0.7995
Inter-class similarity: -0.0652
Margin (↑ is good):     0.8646



Epoch 19 [train]: 100%|██████████| 225/225 [00:07<00:00, 30.12it/s]
Epoch 19 [val]: 100%|██████████| 57/57 [00:01<00:00, 50.73it/s]


Epoch 19
Train loss: 0.1468
Val loss:   1.5754
Intra-class similarity: 0.8034
Inter-class similarity: -0.0716
Margin (↑ is good):     0.8750






In [1]:
# save final model
torch.save(model.state_dict(), "face_reid_model.pth")

NameError: name 'torch' is not defined

In [8]:
# load best model
# model.load_state_dict(torch.load("final_face_reid_model.pth"))
model.load_state_dict(torch.load("best_face_reid_model.pth"))

<All keys matched successfully>

# Visualise validation set

In [None]:
import torch
import numpy as np

model.eval()

embeddings = []
labels = []
with torch.no_grad():
    for imgs, lbls in val_loader:
        imgs = imgs.to(device)
        emb = model(imgs)
        embeddings.append(emb.cpu())
        labels.append(lbls)
embeddings = torch.cat(embeddings).numpy()
labels = torch.cat(labels).numpy()

Validation Accuracy: 0.0000


In [25]:
import umap

reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    metric="cosine",
    random_state=42
)

emb_2d = reducer.fit_transform(embeddings)



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [26]:
import plotly.express as px
import pandas as pd

index_to_label = {idx: label for label, idx in dataset.label_to_index.items()}

df = pd.DataFrame({
    "x": emb_2d[:, 0],
    "y": emb_2d[:, 1],
    "label": [index_to_label[idx] for idx in labels],
    "index": np.arange(len(labels))
})

fig = px.scatter(
    df,
    x="x",
    y="y",
    color="label",
    hover_data=["label", "index"],
    title="ArcFace Embedding Space (Validation Set)",
    width=900,
    height=700
)

fig.update_traces(marker=dict(size=6, opacity=0.75))
fig.update_layout(legend_title_text="Identity")

html_path = "arcface_embeddings.html"
fig.write_html(html_path)


# Build a KNN-classification

In [47]:
import torch
import torch.nn.functional as F

def build_embedding_gallery(model, loader, device="cuda"):
    model.eval()
    embs = []
    labels = []

    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device)
            emb = F.normalize(model(imgs), dim=1)
            embs.append(emb.cpu())
            labels.append(lbls.cpu())

    return torch.cat(embs), torch.cat(labels)

gallery_embs, gallery_labels = build_embedding_gallery(
    model, train_loader, device
)

In [56]:
import torch
import torch.nn.functional as F
from collections import Counter

class KNNFaceRecognizer:
    def __init__(self, embeddings, labels, 
                 index_to_label,
                 k=5):
        self.embeddings = F.normalize(embeddings, dim=1)
        self.labels = labels
        self.k = k
        self.index_to_label = index_to_label

    @torch.no_grad()
    def predict(self, embedding, sim_threshold=0.5):
        if not isinstance(embedding, torch.Tensor):
            emb = torch.tensor(embedding)
        else:
            emb = embedding.clone()

        emb = F.normalize(emb, dim=0)

        sims = torch.mv(self.embeddings, emb)
        topk = torch.topk(sims, self.k)

        top_labels = self.labels[topk.indices]
        vote = Counter(top_labels.tolist()).most_common(1)[0]
        pred_label, count = vote

        best_sim = topk.values[0].item()
        confidence = count / self.k

        if best_sim < sim_threshold:
            return None, {
                "best_sim": best_sim,
                "confidence": confidence,
                "neighbors": top_labels.tolist()
            }
        
        pred_label = self.index_to_label.get(pred_label, None)
        return pred_label, {
            "best_sim": best_sim,
            "confidence": confidence,
            "neighbors": top_labels.tolist()
        }

knn_model = KNNFaceRecognizer(
    embeddings=gallery_embs,
    labels=gallery_labels,
    k=5,
    index_to_label=index_to_label
)

torch.save({
    "embeddings": knn_model.embeddings,
    "labels": knn_model.labels,
    "k": knn_model.k,
    "index_to_label": index_to_label
}, "knn_face_model.pt")


In [58]:
# Measure accuracy on validation set
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for data in val_loader:
        images, labels = data
        outputs = model(images.to(device))
        outputs = F.normalize(outputs, dim=1).cpu()

        for i in range(outputs.size(0)):
            emb = outputs[i]
            label = labels[i].item()
            label = index_to_label[label]

            pred_label, info = knn_model.predict(
                emb,
                sim_threshold=0.7
            )

            if pred_label == label:
                correct += 1
            total += 1
accuracy = correct / total if total > 0 else 0
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

Validation Accuracy: 95.06%


In [None]:
# Save knn model
import pickle

In [46]:
# cut from 3:20 to 5:00 with ffmpeg, reduce the fps by half
video = "/mnt/ssd0/castle/castle_downloader/CASTLE2024/main/day1/Allie/09.mp4"
output_video = "trimmed_video.mp4"
!ffmpeg -ss 00:03:20 -to 00:05:00 -i {video} -c copy {output_video}

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab