# Face detection and recognition training pipeline

The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. It mostly follows standard PyTorch training patterns.

In [None]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.utils import save_image
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import v2
import numpy as np
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F
from PIL import Image

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {device}')

In [None]:
# DataLoader worker processes: 8 normally, 0 (load in the main process) on Windows.
workers = 8 if os.name != 'nt' else 0

## Data Augmentation

In [None]:
# Deterministic preprocessing (no augmentation): image -> float32 tensor with
# pixel values kept in [0, 255] -> facenet_pytorch's fixed standardization.
# Per facenet_pytorch's implementation that standardization is
# (x - 127.5) / 128, which maps 0-255 pixels to roughly [-1, 1].
read_pipeline = v2.Compose([
    v2.ToImage(),
    # scale=False: fixed_image_standardization expects raw 0-255 values.
    # With scale=True the pixels would first be divided by 255 and every image
    # would collapse to values near -0.99 after standardization.
    v2.ToDtype(torch.float32, scale=False),
    fixed_image_standardization,
])

In [None]:
# Training-time augmentation, followed by the same tensor conversion and
# standardization as read_pipeline. Intended for PIL (0-255) input images.
transformations = v2.Compose([
    v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
    v2.ColorJitter(brightness=.5, hue=.3),
    # NOTE(review): Grayscale is applied to every sample (not RandomGrayscale),
    # so all augmented images come out grey -- confirm this is intended.
    v2.Grayscale(num_output_channels=3),
    v2.RandomPosterize(bits=2),
    v2.RandomPerspective(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToImage(),
    # scale=False keeps 0-255 values, which fixed_image_standardization expects;
    # scale=True would divide by 255 first and collapse the output near -0.99.
    v2.ToDtype(torch.float32, scale=False),
    fixed_image_standardization,
])

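#### Define run parameters and generate synthetic training images

The dataset should follow the VGGFace2/ImageNet-style directory layout (one sub-directory per identity). Modify the data paths used below to point to the dataset you wish to fine-tune on. The next cell writes augmented copies of the source images to `./data/train/Ian/`.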

In [None]:
# Augment the source image(s) in synthetic_dataset and write the results to
# disk as extra training data for the registered identity.
num_images_to_generate = 50
output_dir = "./data/train/Ian"
# save_image writes through PIL, which does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# Keep an unmodified copy as Ian0.jpg.
# NOTE(review): every item writes to the same file, so only the last one
# survives if synthetic_dataset holds more than one image.
for x, y in synthetic_dataset:
    save_image(x, f"{output_dir}/Ian0.jpg")

image_idx = 1
num_color = len(color_transformations)
for _ in range(num_images_to_generate):
    for x, y in synthetic_dataset:
        new_image = x
        # randint's `high` is exclusive: drawing from [0, num_color] makes the
        # extra value `num_color` mean "apply no colour transformation".
        color_transformation_idx = np.random.randint(low=0, high=num_color + 1)
        if color_transformation_idx < num_color:
            new_image = color_transformations[color_transformation_idx](new_image)

        new_image = posture_transformations(new_image)

        # Only keep samples that actually differ from the source image.
        if not torch.equal(x, new_image):
            save_image(new_image, f"{output_dir}/Ian{image_idx}.jpg")
            image_idx += 1

#### Load training and validation datasets

The Inception Resnet V1 module itself is created further below; see `help(InceptionResnetV1)` for more details.

In [None]:
# Image folders for the two splits, both read with the deterministic
# read_pipeline (no random augmentation).
# NOTE(review): the loaders defined below split train_dataset 80/20 and do not
# use val_dataset -- confirm which validation source is intended.
train_data_dir = "./Training"
train_dataset = datasets.ImageFolder(root=train_data_dir, transform=read_pipeline)

val_data_dir = "./data/val"
val_dataset = datasets.ImageFolder(root=val_data_dir, transform=read_pipeline)

#### Define optimizer, scheduler (currently disabled), train/validation split, and dataloaders

In [None]:
# Adam over every parameter of resnet; the MultiStepLR schedule is disabled.
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
# scheduler = MultiStepLR(optimizer, [5, 10])

# Shuffle all indices of train_dataset and hold out the last 20% for validation.
img_inds = np.arange(len(train_dataset))
np.random.shuffle(img_inds)
split_point = int(0.8 * len(train_dataset))
train_inds, val_inds = img_inds[:split_point], img_inds[split_point:]

# Both loaders read train_dataset, each restricted to its own index subset.
train_loader, val_loader = (
    DataLoader(
        train_dataset,
        num_workers=workers,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(subset),
    )
    for subset in (train_inds, val_inds)
)

#### Define loss and evaluation functions

In [None]:
# Classification loss plus the per-batch metrics handed to training.pass_epoch.
# NOTE(review): resnet is built with classify=False and its logits layer removed,
# so its output is an embedding rather than class logits -- confirm this loss fits.
loss_fn = torch.nn.CrossEntropyLoss()
metrics = dict(
    fps=training.BatchTimer(),
    acc=training.accuracy,
)

#### Train model

In [None]:
# TensorBoard writer. `iteration` and `interval` are ad-hoc attributes that
# training.pass_epoch presumably reads to throttle logging -- confirm in
# facenet_pytorch's training utilities.
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10


def _validate():
    """Run one evaluation pass of `resnet` over `val_loader`.

    No optimizer is passed, so no weight update is expected; torch.no_grad()
    avoids building the autograd graph for these batches.
    """
    resnet.eval()
    with torch.no_grad():
        training.pass_epoch(
            resnet, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )


# try/finally so the TensorBoard writer is flushed and closed even if training
# is interrupted or raises.
try:
    print('\n\nInitial')
    print('-' * 10)
    _validate()

    for epoch in range(epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        resnet.train()
        training.pass_epoch(
            resnet, loss_fn, train_loader, optimizer,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )

        _validate()
finally:
    writer.close()

### Define the triplet dataset
The MTCNN module is defined further below; see `help(MTCNN)` for more details.

In [None]:
class TripletDataset(torch.utils.data.Dataset):
    """Yield (anchor, positive, negative) triplets for one registered identity.

    The anchor and the positive are both built from the same identity image; with
    a random ``transform`` they are two independent augmentations of it. Negatives
    are taken index-by-index from ``negative_dataset``; any shuffling is left to
    the DataLoader.
    """

    def __init__(self, identity_image, negative_dataset, transform=None):
        """
        Args:
            identity_image: A single image of the registered person (e.g. a PIL
                image). Stored as ``self.anchor_image``.
            negative_dataset: An indexable dataset of other identities whose items
                are ``(image, label)`` pairs; only the image is used.
            transform: Optional callable applied separately to the anchor, the
                positive and the negative. When ``None`` the images are returned
                unchanged.
        """
        self.anchor_image = identity_image
        self.negative_dataset = negative_dataset
        self.transform = transform

    def __len__(self):
        # One triplet per negative sample.
        return len(self.negative_dataset)

    def __getitem__(self, idx):
        # Anchor and positive come from the same image; without a transform they
        # are identical (previously they were left unbound in that case).
        anchor = positive = self.anchor_image
        if self.transform is not None:
            anchor = self.transform(self.anchor_image)
            positive = self.transform(self.anchor_image)

        # Negative sample at `idx`; dataset items are (image, label).
        negative = self.negative_dataset[idx][0]
        if self.transform is not None:
            negative = self.transform(negative)

        return anchor, positive, negative

In [None]:
# VGGFace2-pretrained InceptionResnetV1 on `device`, built with classify=False.
resnet = InceptionResnetV1(pretrained='vggface2', classify=False).to(device)
# Drop the classification layer entirely.
resnet.logits = None

In [None]:
# Freeze the whole network, then re-enable gradients for the final linear
# projection and its batch-norm so only those layers are fine-tuned.
resnet.requires_grad_(False)
for head in (resnet.last_linear, resnet.last_bn):
    head.requires_grad_(True)

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image
from torch import nn

In [None]:
# Instantiate dataset and dataloader.
# Anchor image of the registered identity. Forward slashes work on both Windows
# and POSIX; .convert("RGB") guarantees three channels, as ImageFolder does.
anchor_image = Image.open("data/original/Ian.jpg").convert("RGB")
# Negatives are loaded as raw images (no transform) because TripletDataset
# applies `transformations` itself. Pre-applying read_pipeline would run
# ToImage/ToDtype/fixed_image_standardization twice on every negative.
negative_dataset = datasets.ImageFolder("./data/negative")
triplet_dataset = TripletDataset(anchor_image, negative_dataset, transform=transformations)
dataloader = DataLoader(triplet_dataset, batch_size=128, shuffle=True)

lr = 0.001
momentum = 0.9
weight_decay = 0.0001

optimizer = torch.optim.SGD(resnet.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)

# Define loss function
triplet_loss = nn.TripletMarginLoss(margin=0.5, p=2, eps=1e-7)

In [None]:
from scipy.spatial.distance import euclidean
num_epochs = 50

# Probe images used to track embedding distances at the end of every epoch.
# Forward slashes work on Windows and POSIX alike (the backslash literals
# relied on invalid escape sequences such as "\o" and "\I").
# .convert("RGB") guarantees three channels, as ImageFolder's loader does.
original_image = read_pipeline(Image.open("data/original/Ian.jpg").convert("RGB")).unsqueeze(0)
test_image1 = read_pipeline(Image.open("data/original/Ian01.jpg").convert("RGB")).unsqueeze(0)
test_image2 = read_pipeline(Image.open("./output.jpg").convert("RGB")).unsqueeze(0)

# Checkpoint name derived from the actual hyper-parameters so it cannot drift
# from them.
checkpoint_path = (
    f"./lr={lr}_batch_size={dataloader.batch_size}"
    f"_margin={triplet_loss.margin}_epochs={num_epochs}_glasses.pt"
)

for epoch in range(num_epochs):
    # Re-enter training mode every epoch: the probe evaluation at the end of
    # each epoch switches the model to eval mode, which would otherwise persist.
    resnet.train()
    running_loss = 0.0
    for anchor, positive, negative in dataloader:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        # Forward pass: embed all three images with the shared network.
        emb_anchor = resnet(anchor)
        emb_positive = resnet(positive)
        emb_negative = resnet(negative)

        # Compute loss
        loss = triplet_loss(emb_anchor, emb_positive, emb_negative)
        running_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): stepped once per batch, so MultiStepLR's milestones
        # count batches rather than epochs -- confirm this is intended.
        scheduler.step()

    # `epoch` is 0-based, so the final epoch is num_epochs - 1; comparing
    # against num_epochs itself would never fire and nothing would be saved.
    if epoch + 1 == num_epochs:
        torch.save(resnet.state_dict(), checkpoint_path)

    # Distances from the reference embedding to the two probe embeddings.
    resnet.eval()
    with torch.no_grad():
        emb = resnet(original_image.to(device)).cpu().numpy().squeeze()
        emb1 = resnet(test_image1.to(device)).cpu().numpy().squeeze()
        emb2 = resnet(test_image2.to(device)).cpu().numpy().squeeze()
        distance = euclidean(emb, emb1)
        print(distance)
        distance = euclidean(emb, emb2)
        print(distance)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader)}")

In [None]:
# Negative identities, read with the deterministic (non-augmenting) pipeline.
val_dataset = datasets.ImageFolder(root="./data/negative", transform=read_pipeline)

In [None]:
from scipy.spatial.distance import euclidean

# Embedding distance below which two faces are treated as the same person.
threshold = 0.7

# Reference image of the registered identity. It does not change inside the
# loop, so it is loaded and embedded once rather than once per sample.
# Forward slashes work on both Windows and POSIX.
test_image1 = read_pipeline(Image.open("data/original/Ian/Ian02.jpg").convert("RGB")).unsqueeze(0)
resnet.eval()
with torch.no_grad():
    emb1 = resnet(test_image1.to(device)).cpu().numpy().squeeze()

# Count samples of other identities that fall within the threshold, i.e.
# would be falsely accepted.
count = 0
for x, y in val_dataset:
    with torch.no_grad():
        emb2 = resnet(x.unsqueeze(0).to(device)).cpu().numpy().squeeze()

    # Compute similarity
    distance = euclidean(emb1, emb2)
    if distance < threshold:
        print(f"Distance: {distance}")
        count += 1

# Fraction of negatives correctly rejected over val_dataset.
print(f"Accuracy: {1 - count / len(val_dataset)}")

In [None]:
# Fraction of val_dataset samples that fell under the distance threshold.
float(count) / len(val_dataset)

In [None]:
# Report one minus the falsely-matched fraction of val_dataset.
print("Accuracy: {}".format(1 - count / len(val_dataset)))

# Unused

In [None]:
from scipy.spatial.distance import euclidean

# Test pair of images (both from the registered identity's folder, per the
# file names). Forward-slash paths work on both Windows and POSIX;
# .convert("RGB") guarantees three channels.
test_image1 = read_pipeline(Image.open("data/original/Ian/Ian02.jpg").convert("RGB")).unsqueeze(0)
test_image2 = read_pipeline(Image.open("data/original/Ian/Ian01.jpg").convert("RGB")).unsqueeze(0)

# Generate embeddings
resnet.eval()
with torch.no_grad():
    emb1 = resnet(test_image1.to(device)).cpu().numpy().squeeze()
    emb2 = resnet(test_image2.to(device)).cpu().numpy().squeeze()

# Compute similarity
distance = euclidean(emb1, emb2)
print(f"Distance: {distance}")

In [None]:
# MTCNN face detector/cropper with 160-pixel output crops, placed on `device`.
mtcnn = MTCNN(
    device=device,
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
)

In [None]:
# Mirror ./data/val into ./data/val_cropped: each ImageFolder sample's label is
# replaced by the destination path for that image's face crop.
dataset = datasets.ImageFolder("./data/val")
dataset.samples = [
    (src_path, src_path.replace("./data/val", "./data/val_cropped"))
    for src_path, _label in dataset.samples
]

# One image per batch; collate_pil presumably keeps PIL images and paths as
# plain lists instead of stacking them into tensors.
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=workers,
    collate_fn=training.collate_pil,
)

n_batches = len(loader)
for batch_no, (images, save_paths) in enumerate(loader, start=1):
    # Detect faces and have MTCNN save each crop to its mapped path.
    mtcnn(images, save_path=save_paths)
    print(f'\rBatch {batch_no} of {n_batches}', end='')

# Remove mtcnn to reduce GPU memory usage
del mtcnn

In [None]:
# Load a saved state dict. map_location="cpu" lets checkpoints written on a GPU
# load on CPU-only machines; load_state_dict later copies the tensors onto
# resnet's own device. weights_only=True restricts unpickling to tensors and
# plain containers.
state = torch.load("./model.pt", map_location="cpu", weights_only=True)

In [None]:
# Copy the loaded weights into the model (strict key matching by default).
resnet.load_state_dict(state_dict=state)

In [None]:
# Persist the current weights under a name recording the run's hyper-parameters.
torch.save(obj=resnet.state_dict(), f="./lr=0.001_batch_size=32_margin=0.5_epochs=50.pt")

In [None]:
# Persist the current weights to ./model2.pt.
torch.save(obj=resnet.state_dict(), f="./model2.pt")