# ArcFace Facial Recognition Model Training

End-to-end notebook for ArcFace-based facial recognition, from data loading and preprocessing to training, evaluation, and model saving.

## Directory & Config Setup
Specify dataset and output directory paths.

In [None]:
import os

# torch is imported here (not only in the Imports cell below) because DEVICE
# is computed in this cell; without it, running the cells top to bottom
# raises NameError on `torch`.
import torch

# Adjust paths as needed
DATA_ROOT = '../../dataset/images/train/'    # directory containing the training images
LABEL_ROOT = '../../dataset/labels/train/'   # directory containing the training labels
MODEL_SAVE_PATH = './arcface_model.pt'       # where the trained state_dict is written
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import cv2
import numpy as np
import matplotlib.pyplot as plt
# Optionally, use a package such as 'arcface-pytorch' for the model head, if it is available

## Dataset Class & DataLoader
Define a custom Dataset for your data layout.

In [None]:
class FaceDataset(Dataset):
    """Dataset that loads every file in ``img_dir`` as an RGB image.

    Args:
        img_dir: Directory whose entries are treated as image files.
        label_dir: Directory holding the labels. It is stored on the instance
            but not read yet; see the placeholder label in ``__getitem__``.
        transform: Optional callable applied to the RGB image array.

    Each item is an ``(image, label)`` pair.
    """

    def __init__(self, img_dir, label_dir, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.img_names = sorted(os.listdir(img_dir))
        self.transform = transform

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise OSError(f'Could not read image file: {img_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image)
        # Dummy label: Modify to read actual label from file
        label = 0
        return image, label
# Preprocessing pipeline: convert the image array to a tensor, then resize it
# to 112x112.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(112, 112)),
        # Add other augmentations if needed
    ]
)
# Training dataset and a shuffled loader that yields batches of 32.
train_set = FaceDataset(img_dir=DATA_ROOT, label_dir=LABEL_ROOT, transform=transform)
train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)

## ArcFace Model Definition
Use a ResNet-18 backbone with an ArcFace head (implemented below; a packaged implementation can be substituted).

In [None]:
class ArcFaceHead(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    Computes the cosine similarity between L2-normalised inputs and
    L2-normalised class weights. When labels are supplied, the margin ``m``
    is added to the angle of each sample's target class. The result is
    multiplied by the scale ``s``.

    Args:
        in_features: Size of each input embedding.
        out_features: Number of classes.
        s: Scale applied to the returned logits.
        m: Angular margin (radians) added to the target-class angle.
        eps: Distance kept from +/-1 when clamping cosines before ``acos``.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.5, eps=1e-7):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.s = s
        self.m = m
        self.eps = eps

    def forward(self, x, labels=None):
        """Return scaled logits with one row per sample and one column per class.

        Args:
            x: Batch of embeddings, one row per sample.
            labels: Integer class indices, one per sample, or ``None``.
                With ``None`` (inference) no margin is applied and the output
                is the scaled cosine similarity.
        """
        x_norm = nn.functional.normalize(x, p=2, dim=1)
        w_norm = nn.functional.normalize(self.fc.weight, p=2, dim=1)
        logits = torch.matmul(x_norm, w_norm.t())
        if labels is not None:
            # The derivative of acos is unbounded at +/-1, so clamping to
            # exactly [-1, 1] can produce inf/NaN gradients. Stay eps inside.
            theta = torch.acos(
                torch.clamp(logits, -1.0 + self.eps, 1.0 - self.eps)
            )
            target_logit = torch.cos(theta + self.m)
            # One-hot mask selecting each sample's target-class column.
            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, labels.view(-1, 1), 1.0)
            logits = logits * (1 - one_hot) + target_logit * one_hot
        # Out-of-place scaling: never mutate a tensor autograd may still need.
        return logits * self.s
# ResNet-18 backbone used as the feature extractor.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is for compatibility with older installs.
backbone = models.resnet18(pretrained=True)
# Read the embedding width BEFORE replacing the classifier: nn.Identity has no
# `in_features`, so checking afterwards always fell back to a hard-coded 512.
embedding_dim = backbone.fc.in_features
# Drop the classification layer so the backbone outputs the pooled features.
backbone.fc = nn.Identity()
num_classes = 7 # angry, disgust, fear, happy, neutral, sad, surprised
arcface_head = ArcFaceHead(embedding_dim, num_classes)
# The Sequential groups both parts for .to(), .parameters(), train()/eval() and
# state_dict(); the training and evaluation cells call `backbone` and
# `arcface_head` separately.
model = nn.Sequential(backbone, arcface_head).to(DEVICE)

## Training Loop

In [None]:
# Cross-entropy on the head's logits; Adam updates every parameter in `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for the current epoch
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Forward pass: backbone features, then the head's logits.
        features = backbone(images)
        logits = arcface_head(features, labels)
        loss = criterion(logits, labels)

        # Clear old gradients, backpropagate, and take one optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Report the mean batch loss for this epoch.
    print(f'Epoch {epoch+1}/{num_epochs} Loss: {running_loss/len(train_loader):.4f}')

## Evaluation

In [None]:
# Accuracy over the batches of `train_loader`.
# NOTE: this evaluates on the training data; use a held-out loader for an
# estimate of generalisation.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        features = backbone(images)
        # Pass None instead of the labels: with labels the head applies the
        # angular margin to the true class, so the prediction would depend on
        # the ground truth it is being scored against.
        logits = arcface_head(features, None)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
# An empty loader would otherwise raise ZeroDivisionError.
if total:
    print(f'Accuracy: {correct/total * 100:.2f}%')
else:
    print('Accuracy: n/a (no samples were evaluated)')

## Model Saving

In [None]:
# Save only the state_dict (backbone + head parameters and buffers), not the
# pickled module object.
torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)
print(f'Model saved to {MODEL_SAVE_PATH}')