In [1]:
import os
import timeit
import torch
import pytorch_lightning as pl

from datetime import datetime
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from face_recognition.models.MetricsCallback import MetricsCallback
from face_recognition.models.FaceNetPytorchLightning import LightningFaceNet
from face_recognition.models.FaceNet import FaceNetResnet
from face_recognition.data.datasets import LFWValidationDataset, TupleDataset, VGGTripletDataset
from face_recognition.utils.constants import MODEL_DIR, CHECKPOINTS_DIR

# Training image root. The training cell assigns this to `train_dir` and
# aborts if it is empty; init_datasets() uses the same path literal.
overfit_root = './face_recognition/data/images/vgg-cropped'
# Image root and pairs file handed to LFWValidationDataset in init_datasets().
# If either one is falsy, init_datasets() builds no validation loader
# (val_loader stays None).
lfw_root = './face_recognition/data/images/lfw_aligned'
pairs_txt = './face_recognition/data/images/pairs.txt'

In [2]:
# Hyperparameters handed to LightningFaceNet when the model is built in the
# training cell. How each key is consumed is defined by LightningFaceNet.
hparams = dict(
    margin=0.2,
    lr=0.0001,
    weight_decay=1e-5,
    optimizer='adam',
)

In [3]:
def get_dataloader(dataset, train=False, batch_size=256, num_workers=20, shuffle=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Map-style dataset to draw batches from.
        train: ``True`` when the loader feeds the training loop. Selects the
            logged phase name and, unless ``shuffle`` is given, turns on
            per-epoch shuffling.
        batch_size: Number of samples per batch (default 256).
        num_workers: Number of loader worker processes (default 20).
        shuffle: Explicit shuffle setting. ``None`` (default) means "shuffle
            exactly when ``train`` is true".

    Returns:
        The configured ``DataLoader``.
    """
    # Previously `train` only changed the log message, so training batches
    # were served in the same order every epoch.
    if shuffle is None:
        shuffle = train

    phase = "training" if train else 'validation'
    print(f"Initialize {phase} dataloader.")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=bool(shuffle),
        num_workers=num_workers,
    )


def init_datasets(train_dir='./face_recognition/data/images/vgg-cropped',
                  num_triplets=100000, val_fraction=0.2):
    """Build the training loader and, when LFW data is configured, the validation loader.

    Args:
        train_dir: Root directory passed to ``VGGTripletDataset``.
        num_triplets: Forwarded as the second positional argument of
            ``VGGTripletDataset`` (presumably the number of triplets to
            generate -- TODO confirm against the dataset class).
        val_fraction: Fraction of the training set that is split off and paired
            with the LFW set for validation. Must satisfy ``0 <= val_fraction < 1``.

    Returns:
        ``(train_loader, val_loader)``. ``val_loader`` is ``None`` when the
        module-level ``lfw_root`` or ``pairs_txt`` is falsy; in that case the
        whole training set is used for training.

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction!r}")

    # The same preprocessing is applied to the training and the LFW images.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.6068, 0.4517, 0.3800], std=[0.2492, 0.2173, 0.2082])
    ])

    train_set = VGGTripletDataset(train_dir, num_triplets, transform=transform)

    val_loader = None
    if lfw_root and pairs_txt:
        lfw_set = LFWValidationDataset(lfw_root, pairs_txt, transform=transform)

        # The hold-out size is derived from the *training* set, not from the
        # LFW set; the held-out part is then bundled with the LFW set.
        num_val = int(val_fraction * len(train_set))
        num_train = len(train_set) - num_val
        train_set, val_set = random_split(train_set, [num_train, num_val])

        val_loader = get_dataloader(TupleDataset(lfw_set, val_set))

    train_loader = get_dataloader(train_set, train=True)

    return train_loader, val_loader


In [None]:
# Training cell: builds the data loaders, model and Trainer, runs the fit,
# and saves the trained backbone weights into a timestamped run directory.
num_epochs = 30

# Fail fast, before any output directory is created, when no training data is
# configured.
# NOTE(review): train_dir is only validated here; init_datasets() below is
# called without it and relies on its own path.
train_dir = overfit_root
if not train_dir:
    raise ValueError('No training data specified.')

# One timestamped sub-directory per run. makedirs creates MODEL_DIR as well
# when it is missing and does not fail if a directory already exists.
model_dir = MODEL_DIR
time_stamp = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
subdir = os.path.join(model_dir, time_stamp)
os.makedirs(subdir, exist_ok=True)

train_loader, val_loader = init_datasets()

# NOTE(review): this callback is constructed but not attached to the Trainer
# below (the corresponding argument is commented out), so no 'val_acc'
# checkpoints are written by it -- confirm whether that is intended.
checkpoint_callback = ModelCheckpoint(
    filepath=CHECKPOINTS_DIR,
    verbose=True,
    monitor='val_acc',
    mode='max',
    save_top_k=1
)
logger = TensorBoardLogger('tb_logs', name='facesecure_training')
print("Initialize FaceNet + Resnet")
backbone = FaceNetResnet(pretrained=True)
model = LightningFaceNet(hparams, backbone)

trainer = pl.Trainer(
    # Use a single GPU when CUDA is available, otherwise run on CPU.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=num_epochs,
    logger=logger,
    #checkpoint_callback=checkpoint_callback,
    callbacks=[MetricsCallback()]
)

print("Begin Training.")
start = timeit.default_timer()
trainer.fit(model, train_loader, val_loader)
stop = timeit.default_timer()
print("Finished Training in", stop - start, "seconds")

# Persist the state dict of the wrapped network (model.model), not the
# Lightning module itself.
print("Save trained weights.")
model_name = os.path.join(subdir, time_stamp + '.pth')
torch.save(model.model.state_dict(), model_name)