In [1]:
import os

import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from preprocessor import AgeRecognitionPreprocessor
from dataset import AgeRecognitionDataset
from models import vit_l_16_age_recognizer, vit_b_16_age_recognizer
from loss import AgeRecognitionLoss

In [2]:
# Hyper-parameters, paths and run-wide settings.
lr = 1e-3  # initial learning rate handed to the optimizer

IMAGE_DIR = './Cleaned/'
TRAINING_PAIRINGS = './training_data.csv'  # CSV passed to AgeRecognitionDataset as `triplet_csv_path`
BATCH_SIZE = 16
EPOCHES = 1000
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed NumPy and PyTorch so shuffling and weight initialisation are reproducible
# under Restart & Run All. torch.manual_seed also seeds every CUDA device.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Network and loss module, both placed on DEVICE.
model = vit_b_16_age_recognizer().to(DEVICE)
loss_function = AgeRecognitionLoss().to(DEVICE)

# Dataset read from the triplet CSV + image folder, split into 5 folds for the
# cross-validation loop below.
# NOTE(review): the dataset receives device=DEVICE; if it moves samples to CUDA inside
# __getitem__, DataLoader workers (num_workers > 0) may fail under the 'fork' start
# method — confirm against AgeRecognitionDataset.
preprocessor = AgeRecognitionPreprocessor()
dataset = AgeRecognitionDataset(
    triplet_csv_path=TRAINING_PAIRINGS,
    image_dir=IMAGE_DIR,
    preprocessor=preprocessor,
    kfolds=5,
    device=DEVICE,
)

In [4]:
# Optimise the network together with any parameters owned by the loss module.
trainable_parameters = [*model.parameters(), *loss_function.parameters()]
optimizer = Adam(trainable_parameters, lr=lr)

# Cosine decay of the learning rate from `lr` towards 1e-6 over EPOCHES scheduler
# steps (presumably one step per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHES, eta_min=1e-6)

In [5]:
# K-fold training loop with per-epoch checkpointing.
#
# NOTE(review): one `model` is trained on every fold in turn, so each fold's validation
# split has already been trained on during the other folds' passes (in this and earlier
# epochs). The reported validation loss is therefore optimistic; standard k-fold CV
# trains an independent model per fold. TODO decide which protocol is intended.

CHECKPOINT_DIR = './Checkpoint'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def train_one_pass(model, loss_function, optimizer, dataloader):
    """Run one optimisation pass over `dataloader` and return the mean batch loss.

    For every batch: clear gradients, forward, back-propagate, and apply an optimizer step.
    """
    model.train()
    loss_function.train()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc='train'):
        # Batch shape: (N, Anchor-Positive-Negative, C, H, W)
        optimizer.zero_grad()
        predictions = model.forward_features(batch)
        training_loss = loss_function(predictions)
        training_loss.backward()
        optimizer.step()
        total_loss += training_loss.item()
    return total_loss / max(len(dataloader), 1)


@torch.no_grad()
def evaluate(model, loss_function, dataloader):
    """Return the mean batch loss over `dataloader` without building autograd graphs.

    Assumes `loss_function` returns a per-batch mean — TODO confirm; if it returns a
    per-batch sum, divide by the number of samples instead.
    """
    model.eval()
    loss_function.eval()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc='validate'):
        # Batch shape: (N, Anchor-Positive-Negative, C, H, W)
        predictions = model.forward_features(batch)
        total_loss += loss_function(predictions).item()
    return total_loss / max(len(dataloader), 1)


best_validation_loss = float('inf')
for epoch in range(EPOCHES):
    fold_validation_losses = []
    for fold in range(dataset.kfolds):
        training_dataset, validation_dataset = dataset.kfold_cross_validation(fold)
        training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
        # Shuffling is unnecessary for evaluation.
        validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

        training_loss = train_one_pass(model, loss_function, optimizer, training_dataloader)
        validation_loss = evaluate(model, loss_function, validation_dataloader)
        fold_validation_losses.append(validation_loss)
        print(f"Epoch {epoch} fold {fold} : training loss {training_loss:.4f}, validation loss {validation_loss:.4f}")

    # The scheduler was built with T_max=EPOCHES, i.e. it expects one step per epoch.
    scheduler.step()

    epoch_validation_loss = float(np.mean(fold_validation_losses))
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        # The loss module's parameters are optimised too, so they must be restorable.
        'loss_function_state_dict': loss_function.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': epoch_validation_loss,
    }
    # torch.save needs a file path, not a directory. Overwrite `last.pt` every epoch and
    # keep `best.pt` for the lowest mean validation loss seen so far.
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, 'last.pt'))
    if epoch_validation_loss < best_validation_loss:
        best_validation_loss = epoch_validation_loss
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, 'best.pt'))



  0%|          | 0/4985 [00:06<?, ?it/s]


KeyboardInterrupt: 