In [1]:
import torch
import numpy as np
from tqdm import tqdm

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.dataloader import DataLoader

from preprocessor import AgeRecognitionPreprocessor
from dataset import AgeRecognitionDataset
from models import vit_l_16_age_recognizer, vit_b_16_age_recognizer
from loss import AgeRecognitionLoss

In [2]:
lr = 1e-3

IMAGE_DIR = './Cleaned/'
TRAINING_PAIRINGS = './training_data.csv'
BATCH_SIZE = 12
EPOCHES = 1000
DEVICE = 'cuda'

In [3]:
model = vit_b_16_age_recognizer().to(DEVICE)
loss_function = AgeRecognitionLoss().to(DEVICE)
preprocessor = AgeRecognitionPreprocessor()
dataset = AgeRecognitionDataset(triplet_csv_path=TRAINING_PAIRINGS, image_dir=IMAGE_DIR, preprocessor=preprocessor, kfolds=5, device=DEVICE)

In [4]:
optimizer = Adam(list(model.parameters()) + list(loss_function.parameters()), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHES, eta_min=1e-6)

In [5]:
for epoch in range(EPOCHES):
    for fold in range(dataset.kfolds):
        training_dataset, validation_dataset = dataset.kfold_cross_validation(fold)
        training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
        validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=True, num_workers=4)
        model.train()
        # index = 0
        for batch in tqdm(training_dataloader):
            # Batch shape: (N, Anchor-Positive-Negative, C, H, W)
            predictions = model.forward_features(batch)
            training_loss = loss_function(predictions)
            # print(f"Training loss for batch {index} : {training_loss}")
            # index += 1
        optimizer.step()
        # del(batch)

        # model.eval()
        # validation_loss = 0
        # for batch in tqdm(validation_dataloader):
        #     # Batch shape: (N, Anchor-Positive-Negative, C, H, W)
        #     predictions = model.forward_features(batch)
        #     validation_loss += loss_function(predictions)
        # validation_loss = validation_loss * BATCH_SIZE / len(validation_dataloader) 
        # print(f"Validation loss : {validation_loss}")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': training_loss,
        }, f'./Checkpoint/model_{epoch}.pt')



100%|██████████| 6646/6646 [41:35<00:00,  2.66it/s]  
100%|██████████| 6646/6646 [1:13:27<00:00,  1.51it/s]    
100%|██████████| 6646/6646 [41:33<00:00,  2.67it/s]   
100%|██████████| 6646/6646 [46:39<00:00,  2.37it/s]   
100%|██████████| 6646/6646 [40:12<00:00,  2.75it/s]
100%|██████████| 6646/6646 [59:43<00:00,  1.85it/s]    
100%|██████████| 6646/6646 [1:15:26<00:00,  1.47it/s]    
100%|██████████| 6646/6646 [40:43<00:00,  2.72it/s]
100%|██████████| 6646/6646 [40:43<00:00,  2.72it/s]
100%|██████████| 6646/6646 [40:09<00:00,  2.76it/s]  
  0%|          | 20/6646 [00:19<1:49:04,  1.01it/s]


KeyboardInterrupt: 