# Mount Google Drive

In [None]:
# Mount the user's Google Drive at /content/drive (Colab will prompt for authorisation).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Switch the notebook's working directory to the project root on the mounted Drive,
# so the `src.*` imports and relative paths such as "./wiki" resolve.
%cd /content/drive/MyDrive/github/age_gender_prediction/

# Age and Gender Prediction

In [None]:
from src.dataset import WikiDataset
from src.model import AgeGenderPredictor
from torch.utils.data import DataLoader

import torch
import time

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# torch.device type strings are lowercase: "CPU" is rejected with a RuntimeError,
# so the fallback must be spelled "cpu" for CPU-only runtimes to work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device used:", device)

In [None]:
# --- Paths ---------------------------------------------------------------
img_dir = "./wiki"                # passed to WikiDataset
checkpoint_dir = "./checkpoints"  # passed to model.save_model

# --- Data loading --------------------------------------------------------
batch_size, batch_shuffle = 64, True

# --- Optimisation --------------------------------------------------------
lr = 1e-4
n_epochs = 100

# --- Logging / checkpointing cadence -------------------------------------
print_freq = 1_000  # compared against the running count of *samples* seen in training
save_freq = 10      # checkpoint cadence, in epochs

In [None]:
# Build the training split (train=True) of WikiDataset from the images under img_dir.
# NOTE(review): the training loop indexes data['age'] and data['gender'], so each
# batch is assumed to be a dict with those keys — confirm against src/dataset.py.
dataset = WikiDataset(img_dir, train=True)

In [None]:
# Batch the dataset for training; size and shuffling come from the config cell.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=batch_shuffle,
)

In [None]:
# Instantiate the age/gender predictor on the selected device, passing `lr`
# (presumably the optimizer learning rate — confirm in src/model.py).
# NOTE(review): the optimizer appears to be owned by the model, since the training
# loop calls model.optimize_parameters() rather than stepping an optimizer itself.
model = AgeGenderPredictor(lr, device=device)

In [None]:
# Main training loop.
#
# Progress is counted in *samples* (the counters grow by the batch size), so
# print_freq is a sample count, whereas save_freq is an epoch count.
total_iterations = 0  # samples processed across all epochs
train_start_time = time.time()

n_print = 1  # index of the next print_freq-sample logging threshold

for epoch in range(n_epochs):
    start_time = time.time()

    epoch_iter = 0  # samples processed in the current epoch

    for data in dataloader:
        current_batch_size = len(data['age'])

        # Forward pass; the model then performs its own backward/optimizer step.
        age, gender = model(data)
        model.optimize_parameters()

        total_iterations += current_batch_size
        epoch_iter += current_batch_size

        if total_iterations >= print_freq * n_print:
            # Diagnostic metrics only: evaluated without autograd (no graph is
            # built) and only on batches that are actually logged.
            # NOTE(review): assumes `age`/`gender` have the same shape as their
            # targets; e.g. (B, 1) vs (B,) would silently broadcast to (B, B).
            with torch.no_grad():
                age_mse_loss = torch.mean((age - data['age'].to(device)) ** 2).item()
                gender_mse_loss = torch.mean((gender - data['gender'].to(device)) ** 2).item()

            time_taken = time.time() - train_start_time

            print("--------------------E%d-----------------------" % (epoch + 1))
            print("Current Iteration: %05d | Epoch Iteration: %05d" % (total_iterations, epoch_iter))
            print("Current Time Taken: %07ds | Current Epoch Running Time: %07ds" % (time_taken, time.time() - start_time))
            print("Age CE Loss: %.7f | Gender CE Loss: %.7f" % (model.age_loss, model.gender_loss))
            print("Age MSE (batch): %.7f | Gender MSE (batch): %.7f" % (age_mse_loss, gender_mse_loss))
            # Advance to the first threshold above the current count, so a batch
            # larger than print_freq cannot leave the logging counter lagging.
            n_print = total_iterations // max(print_freq, 1) + 1

    # Checkpoint only after the epoch has fully finished, so the checkpoint
    # tagged epoch N contains all N epochs of training. save_freq <= 0 disables
    # periodic checkpoints (the final model is still saved below).
    if save_freq > 0 and (epoch + 1) % save_freq == 0:
        print("Saving models...")
        model.save_model(checkpoint_dir, epoch + 1)

print(f"Total time taken: {time.time() - train_start_time}")
print("Saving trained model ...")
model.save_model(checkpoint_dir, epoch="trained")
