In [5]:
import torch
import torch.nn as nn
import AdaFace
from torchvision.models import resnet50
from torch.optim import SGD
from torch.utils.data import DataLoader
from digiface_dataset import DigiFaceDataset

In [6]:
# Hyperparameters (described in the DigiFace / SynFace papers)
learning_rate = 0.1
batch_size = 256
epochs = 40
embedding_size = 512    # not set in stone, but a common embedding size for face recognition tasks
lr_milestones = [24, 30, 36]  # 0-indexed epochs from which the learning rate is divided by 10
# NOTE(review): the original notebook used plain SGD. 0.9 / 5e-4 are the momentum / weight decay used by
# the official AdaFace training script (which also exempts BatchNorm params) — confirm against the paper.
momentum = 0.9
weight_decay = 5e-4
# Number of identities the AdaFace head keeps class centres for. 70722 is the value from the AdaFace
# README example. Every training label must lie in [0, num_classes), so set this to the number of
# identities in the DigiFace split being trained on. TODO confirm.
num_classes = 70722
seed = 0

torch.manual_seed(seed)  # reproducible weight init and DataLoader shuffling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create DataLoaders
train_dataset = DigiFaceDataset(root_dir=".", train=True, download=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type == "cuda",  # page-locked host memory speeds up copies to the GPU
)

# ResNet-50 backbone trained from scratch. `num_classes` sizes the final fully connected layer, so
# model.fc already outputs `embedding_size` features and needs no manual replacement. Omitting the
# deprecated `pretrained=` kwarg still means "no pretrained weights" in old and new torchvision.
model = resnet50(num_classes=embedding_size).to(device)

# `import AdaFace` binds a *module*, and calling a module raises TypeError. Resolve the head class from it.
# NOTE(review): assumes the module defines a class named `AdaFace` (as the AdaFace repo's head.py does).
adaface_head_cls = getattr(AdaFace, "AdaFace", AdaFace)

# AdaFace margin head, hyperparameters taken from the AdaFace GitHub README. Per the AdaFace paper:
#   m       - angular margin added for the target class,
#   h       - how strongly the feature norm (an image-quality proxy) adapts that margin,
#   s       - scale applied to the cosine logits,
#   t_alpha - momentum of the running mean/std of feature norms.
# NOTE(review): these meanings come from the paper, not from code visible here — verify against head.py.
adaface = adaface_head_cls(
    embedding_size=embedding_size, classnum=num_classes, m=0.4, h=0.333, s=64.0, t_alpha=0.01
).to(device)

criterion = torch.nn.CrossEntropyLoss()
# The head has its own learnable parameters (the class-centre matrix in the reference implementation).
# They must be optimized together with the backbone, or they never leave their random initialization.
optimizer = SGD(
    list(model.parameters()) + list(adaface.parameters()),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
# LR schedule from the SynFace paper (recommended by DigiFace, p.5; SynFace p.7). scheduler.step() is
# called once per epoch, so epochs 24-29 use lr/10, 30-35 use lr/100 and 36+ use lr/1000.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)

for epoch in range(epochs):
    for data, labels in train_loader:
        # Fail fast with a readable message rather than an opaque index error inside the head/loss.
        if labels.max() >= num_classes:
            raise ValueError(
                f"label {int(labels.max())} >= num_classes={num_classes}; "
                "set num_classes to the number of identities in the training set"
            )
        data = data.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        embeddings = model(data)

        # Split each embedding into its L2 norm and unit direction. The direction feeds the cosine logits.
        # The norm is passed separately to AdaFace, which (per the paper) uses it to adapt the margin.
        # clamp_min avoids 0/0 for an all-zero embedding.
        norms = torch.norm(embeddings, 2, -1, keepdim=True)
        normalized_embedding = embeddings / norms.clamp_min(1e-12)

        # Cosine logits with the AdaFace margin applied, as shown in the AdaFace README
        cosine_with_margin = adaface(normalized_embedding, norms, labels)
        loss = criterion(cosine_with_margin, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()  # advance the LR schedule once per completed epoch
    print(f"Epoch {epoch + 1}/{epochs} completed.")

KeyboardInterrupt: 