# Face detection and recognition training pipeline

The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard pytorch training patterns.

In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.utils import save_image
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import v2
import numpy as np
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F

In [2]:
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img, mode="RGB")
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

#### Determine if an nvidia GPU is available

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


## Data Augmentation

In [4]:
synthetic_data_dir = './data/original'

workers = 0 if os.name == 'nt' else 8

In [5]:
read_pipeline = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    fixed_image_standardization,
])

In [42]:
synthetic_dataset = datasets.ImageFolder(synthetic_data_dir, transform=read_pipeline)

In [6]:
posture_transformations = v2.Compose([
    v2.RandomPerspective(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
])
transformations = v2.Compose([
    v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
    v2.ColorJitter(brightness=.5, hue=.3),
    v2.Grayscale(num_output_channels=3),
    v2.RandomPosterize(bits=2),
    v2.RandomPerspective(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    fixed_image_standardization,
])

In [None]:
num_images_to_generate = 50
synthetic_images = []
image_idx = 1
for x, y in synthetic_dataset:
    output_directory = "./data/train/Ian/Ian0.jpg"
    save_image(x, output_directory)

for i in range(num_images_to_generate):
    for x, y in synthetic_dataset:
        new_image = x
        output_directory = "./data/train/Ian/Ian" + str(image_idx) + str(".jpg")
        color_transformation_idx = np.random.randint(low=0, high=4)
        color_transformation = None

        if color_transformation_idx != 4:
            color_transformation = color_transformations[color_transformation_idx]
            new_image = color_transformation(new_image)

        new_image = posture_transformations(new_image)

        if not torch.equal(x, new_image):
            save_image(new_image, output_directory)
            image_idx +=1 

#### Define run parameters

The dataset should follow the VGGFace2/ImageNet-style directory layout. Modify `data_dir` to the location of the dataset on wish to finetune on.

#### Define Inception Resnet V1 module

See `help(InceptionResnetV1)` for more details.

In [7]:
train_data_dir = "./Training"
train_dataset = datasets.ImageFolder(train_data_dir, transform=read_pipeline)

val_data_dir = "./data/val"
val_dataset = datasets.ImageFolder(val_data_dir, transform=read_pipeline)

FileNotFoundError: [WinError 3] The system cannot find the path specified: './Training'

#### Define optimizer, scheduler, dataset, and dataloader

In [None]:
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
# scheduler = MultiStepLR(optimizer, [5, 10])

img_inds = np.arange(len(train_dataset))
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.8 * len(train_dataset))]
val_inds = img_inds[int(0.8 * len(train_dataset)):]

train_loader = DataLoader(
    train_dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    train_dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)

#### Define loss and evaluation functions

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}

#### Train model

In [None]:
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10

print('\n\nInitial')
print('-' * 10)
resnet.eval()
training.pass_epoch(
    resnet, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device,
    writer=writer
)

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    resnet.train()
    training.pass_epoch(
        resnet, loss_fn, train_loader, optimizer,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

    resnet.eval()
    training.pass_epoch(
        resnet, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

writer.close()

### Define MTCNN module
### See help(MTCNN) for more details.

In [7]:
class TripletDataset(torch.utils.data.Dataset):
    def __init__(self, identity_image, negative_dataset, transform=None):
        """
        Args:
            anchor_image: A single image of the registered person (PIL or np.array).
            negative_dataset: A dataset with multiple identities.
            transform: Transformation for data augmentation.
        """
        self.anchor_image = anchor_image
        self.negative_dataset = negative_dataset
        self.transform = transform

    def __len__(self):
        return len(self.negative_dataset)

    def __getitem__(self, idx):
        # Generate anchor and positive (augmented anchor)
        if self.transform:
            anchor = self.transform(self.anchor_image)
            positive = self.transform(self.anchor_image)

        # Get a random negative sample from the public dataset
        negative = self.negative_dataset[idx][0]  # Assuming dataset returns (image, label)
        if self.transform:
            negative = self.transform(negative)

        return anchor, positive, negative

In [8]:
resnet = InceptionResnetV1(
    classify=False,
    pretrained='vggface2',
).to(device)
resnet.logits= None

In [17]:
torch.save(resnet.state_dict(), "./lr=0.001_batch_size=32_margin=0.5_epochs=50.pt")

In [11]:
state = torch.load("./lr=0.001_batch_size=16_margin=0.5_epochs=30.pt", weights_only=True)

In [12]:
resnet.load_state_dict(state)

<All keys matched successfully>

In [9]:
for param in resnet.parameters():
    param.requires_grad = False
for param in resnet.last_linear.parameters():
    param.requires_grad = True
for param in resnet.last_bn.parameters():
    param.requires_grad = True

In [10]:
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image
from torch import nn

In [11]:
# Instantiate dataset and dataloader
anchor_image = Image.open("./data/original/Ian/Ian01.jpg")
negative_dataset = datasets.ImageFolder("./data/test_cropped", transform=read_pipeline)
triplet_dataset = TripletDataset(anchor_image, negative_dataset, transform=transformations)
dataloader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)

# Define optimizer
optimizer = optim.Adam(resnet.parameters(), lr=0.001)

# Define loss function
triplet_loss = nn.TripletMarginLoss(margin=0.5, p=2, eps=1e-7)

In [12]:
num_epochs = 50
resnet.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for anchor, positive, negative in dataloader:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        # Forward pass
        emb_anchor = resnet(anchor)
        emb_positive = resnet(positive)
        emb_negative = resnet(negative)

        # Compute loss
        loss = triplet_loss(emb_anchor, emb_positive, emb_negative)
        running_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader)}")

Epoch [1/50], Loss: 0.47318948546181555
Epoch [2/50], Loss: 0.47262860813002655
Epoch [3/50], Loss: 0.46285642661910126
Epoch [4/50], Loss: 0.4626386777502327
Epoch [5/50], Loss: 0.4524418338773331
Epoch [6/50], Loss: 0.42317332780879474
Epoch [7/50], Loss: 0.3633413582248388
Epoch [8/50], Loss: 0.32769138315593566
Epoch [9/50], Loss: 0.2805519400947336
Epoch [10/50], Loss: 0.10482616023416969
Epoch [11/50], Loss: 0.03927253020684356
Epoch [12/50], Loss: 0.0259538392248397
Epoch [13/50], Loss: 0.0208788334104066
Epoch [14/50], Loss: 0.01705126832846714
Epoch [15/50], Loss: 0.01620827327315934
Epoch [16/50], Loss: 0.01404169663687907
Epoch [17/50], Loss: 0.013913731378669612
Epoch [18/50], Loss: 0.012544564171671723
Epoch [19/50], Loss: 0.012009199875168467
Epoch [20/50], Loss: 0.010323869827054981
Epoch [21/50], Loss: 0.010322317962449241
Epoch [22/50], Loss: 0.010226985435159023
Epoch [23/50], Loss: 0.010204120917056782
Epoch [24/50], Loss: 0.008735930994782903
Epoch [25/50], Loss: 0.

In [13]:
from scipy.spatial.distance import euclidean

# Test pair of images
test_image1 = read_pipeline(Image.open("data\original\Ian\Ian01.jpg")).unsqueeze(0)
test_image2 = read_pipeline(Image.open("data\val_cropped\Ian\image0 (1).jpg")).unsqueeze(0)

# Generate embeddings
resnet.eval()
with torch.no_grad():
    emb1 = resnet(test_image1.to(device)).cpu().numpy().squeeze()
    emb2 = resnet(test_image2.to(device)).cpu().numpy().squeeze()

# Compute similarity
distance = euclidean(emb1, emb2)
print(f"Distance: {distance}")

Distance: 0.0


In [18]:
val_dataset = datasets.ImageFolder("./data/test_cropped", transform=read_pipeline)

In [None]:
from scipy.spatial.distance import euclidean

# Test pair of images
test_image1 = read_pipeline(Image.open("data\original\Ian\Ian01.jpg")).unsqueeze(0)
test_image2 = read_pipeline(Image.open("data\original\Ian\Ian01.jpg")).unsqueeze(0)

# Generate embeddings
resnet.eval()
with torch.no_grad():
    emb1 = resnet(test_image1.to(device)).cpu().numpy().squeeze()
    emb2 = resnet(test_image2.to(device)).cpu().numpy().squeeze()

# Compute similarity
distance = euclidean(emb1, emb2)
print(f"Distance: {distance}")

In [32]:
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

In [33]:
dataset = datasets.ImageFolder("./data/val")
dataset.samples = [
    (p, p.replace("./data/val", "./data/val_cropped"))
        for p, _ in dataset.samples
]
        
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=1,
    collate_fn=training.collate_pil
)

for i, (x, y) in enumerate(loader):
    mtcnn(x, save_path=y)
    print('\rBatch {} of {}'.format(i + 1, len(loader)), end='')
    
# Remove mtcnn to reduce GPU memory usage
del mtcnn

Batch 5 of 5