<a href="https://colab.research.google.com/github/Ofredy/CNN-Age-Predictor-App/blob/main/workshop/week4/AgePredictorTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
# System imports
import os
import time

# Library imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import cv2
import pandas as pd

!pip install onnx
import onnx

Collecting onnx
  Downloading onnx-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.9/15.9 MB[0m [31m46.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.16.0


In [None]:
# DATASET CONFIGS
BATCH_SIZE = 32
INPUT_CHANNELS = 3
IMG_SIZE = 244
PATH_TO_FOLDER = ""
TRAIN_CSV = "fairface-label-train.csv"
VAL_CSV = "fairface-label-val.csv"

# CNN CONFIGS
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MOBILENET_V3_AVG_POOL_OUT_SIZE = 960
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
LAST_LAYER_TO_FREEZE = 171
TRAIN_AGE_PRED_ONNX_PATH = "age_predictor.onnx"

In [None]:
# Collecting data to train the Age Predictor
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

def getFile_from_drive( file_id, name ):
  downloaded = drive.CreateFile({'id': file_id})
  downloaded.GetContentFile(name)

getFile_from_drive('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getFile_from_drive('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getFile_from_drive('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')

!unzip -qq fairface-img-margin025-trainval.zip



In [None]:
# Dataset class

class AgeDataset(Dataset):

    def __init__(self, data_frame):

        self.data_frame = data_frame

        # Normalizing transform for images
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, index):

        # Parsing csv file for information
        f = self.data_frame.iloc[index].squeeze()
        file = os.path.join(PATH_TO_FOLDER, f.file)
        age = f.age
        img = cv2.imread(file)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img, age

    def preprocess_image(self, img):

        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = torch.tensor(img).permute(2, 0, 1)
        img = self.normalize(img/255)

        return img[None]

    def collate_fn(self, batch):

        imgs, ages = [], []

        for img, age in batch:
            img = self.preprocess_image(img)
            imgs.append(img)

            ages.append(float(int(age)/80))

        ages = torch.tensor(ages).to(DEVICE).float()

        imgs = torch.cat(imgs).to(DEVICE)

        return imgs, ages

In [None]:
# AgePredictor Model

class AgePredictor(nn.Module):

    def __init__(self):

        super().__init__()

        # Loading in the mobilenet_v3 model
        self.mobilenet_v3_age_predictor = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)

        # Freeze parameters so we don't backprop through them
        for param in self.mobilenet_v3_age_predictor.parameters():
            param.requires_grad = False

        # Modifying the classifier to instead predict age
        self.mobilenet_v3_age_predictor.classifier = nn.Sequential(
            nn.Linear(MOBILENET_V3_AVG_POOL_OUT_SIZE, MOBILENET_V3_AVG_POOL_OUT_SIZE//2),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(MOBILENET_V3_AVG_POOL_OUT_SIZE//2, MOBILENET_V3_AVG_POOL_OUT_SIZE//4),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(MOBILENET_V3_AVG_POOL_OUT_SIZE//4, 1),
            nn.Sigmoid()
        )

        self.mobilenet_v3_age_predictor.to(DEVICE)

    def forward(self, input_image):

        # Inputting the input_image and getting our age prediction
        return self.mobilenet_v3_age_predictor(input_image)

    def save_model_onnx(self):

        torch_input = torch.randn(1, INPUT_CHANNELS, IMG_SIZE, IMG_SIZE).to('cpu')
        torch.onnx.export(self.mobilenet_v3_age_predictor.to('cpu'),
                          torch_input,
                          f=TRAIN_AGE_PRED_ONNX_PATH)

In [None]:
# Training functions

def train_batch(mobilenet_v3_age_predictor, age_criterion, optimizer, data):

    mobilenet_v3_age_predictor.train()

    img, age = data

    optimizer.zero_grad()

    # Forward prop of the model
    predicted_age = mobilenet_v3_age_predictor(img)

    # Backward prop of the model
    age_loss = age_criterion(predicted_age.squeeze(), age)

    age_loss.backward()

    optimizer.step()

    return age_loss

def val_batch(mobilenet_v3_age_predictor, age_criterion, data):

    mobilenet_v3_age_predictor.eval()

    img, age = data

    with torch.no_grad():
        predicted_age = mobilenet_v3_age_predictor(img)

    age_loss = age_criterion(predicted_age.squeeze(), age)

    # No forward, this function is used to make sure we are not overfitting our training data
    age_mae = torch.abs(age-predicted_age).float().sum()

    return age_loss, age_mae

def train_age_predictor(mobilenet_v3_age_predictor, age_criterion, optimizer, train_data_loader, val_data_loader):

    starting_epoch = 0

    print("Began training age_predictor, starting_epoch: %d, will train until num_epochs: %d" % (starting_epoch, starting_epoch+NUM_EPOCHS))

    start_time = time.time()

    train_losses, val_losses = [], []
    val_age_maes = []

    n_epochs = NUM_EPOCHS

    for epoch in range(starting_epoch, starting_epoch+n_epochs):

        epoch_train_loss, epoch_val_loss = 0, 0
        val_age_mae, ctr = 0, 0

        # Training batch
        for _, data in enumerate(train_data_loader):
            loss = train_batch(mobilenet_v3_age_predictor, age_criterion, optimizer, data)
            epoch_train_loss += loss.item()

        # Validation batch
        for _, data in enumerate(val_data_loader):
            loss, age_mae = val_batch(mobilenet_v3_age_predictor, age_criterion, data)
            epoch_val_loss += loss.item()
            val_age_mae += age_mae
            ctr += len(data[0])

        # Avergaging the epoch results
        epoch_train_loss /= len(train_data_loader)
        epoch_val_loss /= len(val_data_loader)
        val_age_mae /= ctr

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)
        val_age_maes.append(val_age_mae)

        time_elasped = time.time() - start_time

        epoch_age_pred_weight_path = os.path.join("{}age_predictor_weights.pt".format(epoch+1))

        # Saving the weights of the network
        torch.save(mobilenet_v3_age_predictor.state_dict(), epoch_age_pred_weight_path)

        print('{}/{} ({:.2f}s - {:.2f}s remaining)'.format(epoch+1, starting_epoch+n_epochs, time.time()-start_time, (starting_epoch-epoch)*(time_elasped/(epoch+1))))
        print("train_loss: %.3f, val_loss: %.3f, val_age_mae:%.3f" % (epoch_train_loss, epoch_val_loss, val_age_mae))


In [None]:
train_dataframe = pd.read_csv(TRAIN_CSV)
val_dataframe = pd.read_csv(VAL_CSV)

train = AgeDataset(train_dataframe)
val = AgeDataset(val_dataframe)

# Making the dataloaders with the AgeDatasets
train_data_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=train.collate_fn)
val_data_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=val.collate_fn)

age_predictor = AgePredictor()

# Defining our Loss function and the optimizer
age_criterion = nn.L1Loss()
optimizer = optim.Adam(age_predictor.parameters(), lr=LEARNING_RATE)

train_age_predictor(age_predictor, age_criterion, optimizer, train_data_loader, val_data_loader)

Began training age_predictor, starting_epoch: 0, will train until num_epochs: 5
1/5 (413.28s - 0.00s remaining)
train_loss: 0.127, val_loss: 0.116, val_age_mae:6.130
2/5 (824.03s - -411.99s remaining)
train_loss: 0.118, val_loss: 0.116, val_age_mae:6.150
3/5 (1233.40s - -822.23s remaining)
train_loss: 0.115, val_loss: 0.112, val_age_mae:6.081
4/5 (1641.09s - -1230.76s remaining)
train_loss: 0.113, val_loss: 0.111, val_age_mae:6.008
5/5 (2054.38s - -1643.45s remaining)
train_loss: 0.112, val_loss: 0.111, val_age_mae:6.098


In [13]:
age_predictor.save_model_onnx()