<a href="https://colab.research.google.com/github/aneeshcheriank/Advanced-vision/blob/main/Multi_task_learning_Age_estimation_and_Gender_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import the relevant packages
import torch
import torch.nn as nn
import numpy as np, cv2, pandas as pd, glob, time
import matplotlib.pyplot as plt
%matplotlib inline
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models, datasets

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print('available device: {}'.format(device))

available device: cuda


In [2]:
# fetch the dataset
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate through google.colab first, then hand the resulting
# application-default credentials to PyDrive. The statements are
# order-dependent: the credentials are fetched only after authenticate_user().
auth.authenticate_user()
gauth=GoogleAuth()
gauth.credentials=GoogleCredentials.get_application_default()
# Module-level Drive client; getFile_from_drive below reads this global.
drive = GoogleDrive(gauth)

def getFile_from_drive(file_id, name):
  """Download the Google Drive file with id `file_id` to the local path `name`.

  Relies on the module-level `drive` client created during authentication.
  """
  remote_file = drive.CreateFile({'id': file_id})
  remote_file.GetContentFile(name)

getFile_from_drive(
  '1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86',
  'fairface-img-margin025-trainval.zip'
)
getFile_from_drive(
  '1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-',
  'fairface-label-train.csv'
)
getFile_from_drive(
  '1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ',
  'fairface-label-val.csv'
)

!unzip -qq fairface-img-margin025-trainval.zip

In [3]:
# Read the train / validation label tables downloaded above.
trn_df, val_df = map(
  pd.read_csv,
  ('fairface-label-train.csv', 'fairface-label-val.csv'),
)
trn_df.head()

Unnamed: 0,file,age,gender,race,service_test
0,train/1.jpg,59,Male,East Asian,True
1,train/2.jpg,39,Female,Indian,False
2,train/3.jpg,11,Female,Black,False
3,train/4.jpg,26,Female,Indian,True
4,train/5.jpg,26,Female,Indian,True


In [8]:
# constants
IMAGE_SIZE = 224

class GenderAgeClass(Dataset):
  """Dataset over a label table with `file`, `age` and `gender` columns.

  Each item is the raw RGB image plus its age and an is-female flag. Resizing,
  normalisation, target scaling and the move to `device` all happen per batch
  in `collate_fn`, which must be passed to the DataLoader.
  """

  # Ages are divided by this value so the regression target lies in [0, 1]
  # (max age: 80), matching a sigmoid output.
  MAX_AGE = 80

  def __init__(self, df, tfms=None):
    """Store the label table.

    Args:
      df: DataFrame with `file`, `age` and `gender` columns.
      tfms: accepted for backward compatibility and stored, but not applied
        anywhere in this class.
    """
    self.df = df
    self.tfms = tfms
    # Per-channel mean / std applied after pixels are scaled to [0, 1].
    self.normalize = transforms.Normalize(
      [0.49, 0.48, 0.47],
      [0.22, 0.23, 0.21]
    )

  def __len__(self):
    """Number of rows in the label table."""
    return len(self.df)

  def __getitem__(self, ix):
    """Return (RGB image array, age, is_female) for row `ix`.

    Raises:
      FileNotFoundError: if the image file cannot be read.
    """
    f = self.df.iloc[ix].squeeze()
    file = f.file
    gen = f.gender == 'Female'
    age = f.age
    im = cv2.imread(file)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than inside cvtColor.
    if im is None:
      raise FileNotFoundError('could not read image: {}'.format(file))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, age, gen

  def preprocess_image(self, im):
    """Resize an HxWx3 image, convert to CHW, normalise, and add a batch dim.

    Returns:
      Float tensor of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    """
    im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
    im = torch.tensor(im).permute(2, 0, 1)
    im = self.normalize(im/255.)
    return im[None]

  def collate_fn(self, batch):
    """Preprocess images, ages and genders.

    Args:
      batch: iterable of (image, age, is_female) items from `__getitem__`.

    Returns:
      (images, ages, genders) tensors on `device`; ages are scaled by
      MAX_AGE and genders are 1.0 for female, 0.0 otherwise.
    """
    ims, ages, genders = [], [], []
    for im, age, gender in batch:
      im = self.preprocess_image(im)
      ims.append(im)

      # to scale down the age (max age: 80)
      ages.append(float(int(age)/self.MAX_AGE))
      genders.append(float(gender))

    ages, genders = [torch.tensor(x).to(device).float()\
          for x in [ages, genders]]
    ims = torch.cat(ims).to(device)

    return ims, ages, genders

# Create the datasets
trn = GenderAgeClass(trn_df)
val = GenderAgeClass(val_df)

# Each loader collates with its own dataset's collate_fn, so the validation
# pipeline does not silently depend on the training dataset object.
train_loader = DataLoader(
  trn, batch_size=128,
  shuffle=True, drop_last=True,
  collate_fn=trn.collate_fn
)
# NOTE(review): drop_last=True also discards the final partial validation
# batch, so validation metrics cover slightly fewer samples than len(val).
test_loader = DataLoader(
  val, batch_size=128,
  drop_last=True,
  collate_fn=val.collate_fn
)

# Sanity check: pull one batch and show the image / age / gender shapes.
a, b, c = next(iter(train_loader))
print(a.shape, b.shape, c.shape)

torch.Size([128, 3, 224, 224]) torch.Size([128]) torch.Size([128])


In [9]:
# model
model = models.vgg16(pretrained=True)
# freeze the parameters (every weight that exists at this point)
model.requires_grad_(False)

# Replace the pooling stage with a small trainable block. It is created after
# the freeze, so its conv weights keep requires_grad=True.
# Presumably sized for 224x224 inputs: a 7x7x512 feature map -> conv3 (5x5)
# -> maxpool2 (2x2) -> flatten = 2048 features, the width the head expects.
trainable_pool = [
  nn.Conv2d(512, 512, 3),
  nn.MaxPool2d(2),
  nn.ReLU(),
  nn.Flatten(),
]
model.avgpool = nn.Sequential(*trainable_pool)

class ageGenderClassifier(nn.Module):
  """Two-headed classifier: a shared MLP trunk feeding an age and a gender head.

  Input is a (batch, 2048) feature tensor. Both heads end in a sigmoid, so
  each output is a (batch, 1) tensor with values in [0, 1].
  """

  def __init__(self):
    super().__init__()
    # Shared trunk: 2048 -> 512 -> 128 -> 64, dropout after the first two layers.
    trunk_layers = [
      nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.4),
      nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.4),
      nn.Linear(128, 64), nn.ReLU(),
    ]
    self.intermediate = nn.Sequential(*trunk_layers)

    # Task heads: one sigmoid unit each on top of the 64-d shared features.
    self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
    self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

  def forward(self, x):
    """Return (gender, age) predictions -- note gender comes first."""
    shared = self.intermediate(x)
    age = self.age_classifier(shared)
    gender = self.gender_classifier(shared)
    return gender, age

# overwrite the classifier module with the new class
model.classifier = ageGenderClassifier()
model = model.to(device)

# One criterion per task: binary cross-entropy for gender, L1 (MAE) for age.
# They are bundled in (gender, age) order, which is how the train / validate
# helpers unpack them.
gender_criterion, age_criterion = nn.BCELoss(), nn.L1Loss()
loss_fn = (gender_criterion, age_criterion)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# train on a single batch
def train_batch(data, model, optimizer, criteria):
  """Run one optimisation step on a batch and return the combined loss.

  Args:
    data: (images, ages, genders) batch; ages and genders are 1-D float
      tensors with one entry per image.
    model: module returning (pred_gender, pred_age) for a batch of images.
    optimizer: optimizer that updates the model's parameters.
    criteria: (gender_criterion, age_criterion) pair, in that order.

  Returns:
    Scalar tensor: gender loss + age loss for this batch.
  """
  model.train()
  ims, age, gender = data
  optimizer.zero_grad()
  pred_gender, pred_age = model(ims)

  gender_criterion, age_criterion = criteria
  # Flatten (batch, 1) predictions to (batch,) so they match the 1-D targets.
  # A bare squeeze() would also drop the batch dim when batch size is 1,
  # which makes BCELoss raise on the shape mismatch.
  gender_loss = gender_criterion(pred_gender.reshape(-1), gender)
  age_loss = age_criterion(pred_age.reshape(-1), age)

  # loss calculation: both tasks are weighted equally
  total_loss = gender_loss + age_loss
  total_loss.backward()
  optimizer.step()
  return total_loss

def validate_batch(data, model, criteria):
  """Evaluate one batch without gradient tracking.

  Args:
    data: (images, ages, genders) batch; ages and genders are 1-D float
      tensors with one entry per image.
    model: module returning (pred_gender, pred_age) for a batch of images.
    criteria: (gender_criterion, age_criterion) pair, in that order.

  Returns:
    (total_loss, gender_acc, age_mae) where
      total_loss is the scalar gender loss + age loss,
      gender_acc is the COUNT of correct gender predictions in the batch,
      age_mae is the SUM of absolute age errors in the batch, in the same
        units as `age`.
    The caller divides the last two by the number of samples seen.
  """
  model.eval()
  img, age, gender = data
  with torch.no_grad():
    pred_gender, pred_age = model(img)

  # Flatten (batch, 1) predictions to (batch,) once, up front. Without this,
  # `age - pred_age` broadcasts (batch,) against (batch, 1) into a
  # (batch, batch) matrix and the summed error is inflated.
  pred_gender = pred_gender.reshape(-1)
  pred_age = pred_age.reshape(-1)

  gender_criterion, age_criterion = criteria
  gender_loss = gender_criterion(pred_gender, gender)
  age_loss = age_criterion(pred_age, age)

  total_loss = gender_loss + age_loss
  # A probability above 0.5 counts as the positive class (target 1.0).
  gender_acc = ((pred_gender > 0.5) == gender).float().sum()
  age_mae = torch.abs(age - pred_age).float().sum()
  return total_loss, gender_acc, age_mae

# train the model
import time
# Per-epoch history, stored as plain Python floats so it can be plotted
# directly (CUDA tensors cannot be handed to matplotlib).
val_gender_accuracies = []
val_age_maes = []
train_losses = []
val_losses = []

n_epochs = 5
best_test_loss = float('inf')
start = time.time()
# loop through epochs and reinitialize the train and test loss
# values at the start of each epoch
for epoch in range(n_epochs):
  epoch_train_loss, epoch_test_loss = 0, 0
  val_age_mae, val_gender_acc, ctr = 0, 0, 0

  for data in train_loader:
    loss = train_batch(data, model, optimizer, loss_fn)
    epoch_train_loss += loss.item()

  for data in test_loader:
    loss, gender_acc, age_mae = validate_batch(data, model, loss_fn)
    epoch_test_loss += loss.item()
    # validate_batch returns per-batch SUMS; accumulate them over the whole
    # validation set and normalise by the sample count afterwards.
    val_age_mae += age_mae.item()
    val_gender_acc += gender_acc.item()
    ctr += len(data[0])

  val_age_mae /= ctr
  val_gender_acc /= ctr
  epoch_train_loss /= len(train_loader)
  epoch_test_loss /= len(test_loader)

  elapsed = time.time()-start
  best_test_loss = min(best_test_loss, epoch_test_loss)
  # Remaining time = average epoch duration so far * epochs still to run.
  remaining = (n_epochs-(epoch+1))*elapsed/(epoch+1)
  print('{}/{} ({:.2f}s - {:.2f}s remaining)'.format(
    epoch+1, n_epochs, elapsed, remaining
  ))

  info = f'''Epoch: {epoch+1:03d}
    \tTrain Loss: {epoch_train_loss:.3f}
    \tTest Loss: {epoch_test_loss:.3f}
    \tBest Test Loss: {best_test_loss:.4f}'''
  info += f'\nGender Accuracy: {val_gender_acc*100:.2f}%\tAge MAE:\
  {val_age_mae:.2f}\n'
  print(info)

  train_losses.append(epoch_train_loss)
  val_losses.append(epoch_test_loss)
  val_gender_accuracies.append(val_gender_acc)
  val_age_maes.append(val_age_mae)

In [None]:
epochs = np.arange(1,(n_epochs+1))
# float() accepts plain numbers as well as 0-d tensors (including CUDA ones),
# so the history can be plotted however it was recorded.
gender_accuracies = [float(v) for v in val_gender_accuracies]
age_maes = [float(v) for v in val_age_maes]

# Wide figure: the two panels sit side by side (1 row x 2 columns).
plt.figure(figsize=(20, 5))

plt.subplot(121)
plt.plot(epochs, gender_accuracies, 'bo')
plt.title('Validation Gender Accuracy')
plt.xlabel('Epochs'); plt.ylabel('Accuracy')

plt.subplot(122)
plt.plot(epochs, age_maes, 'r')
plt.title('Validation Age Mean-Absolute-Error')
plt.xlabel('Epochs'); plt.ylabel('MAE')

plt.show()

## Make a prediction on a random image

In [None]:
# get the image
!wget https://www.dropbox.com/s/6kzr8168e9kpjkf/5_9.JPG

# read the image and preprocess it
# wget saves the file into the current working directory, so read it
# relative to that instead of a hard-coded absolute path.
im = cv2.imread('5_9.JPG')
# cv2.imread returns None (no exception) when the file cannot be read.
if im is None:
  raise FileNotFoundError('could not read image: 5_9.JPG')
# The dataset feeds RGB images to the model during training, so convert
# OpenCV's BGR layout before preprocessing; the same array is shown below.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Inference only: disable dropout and skip gradient tracking.
model.eval()
with torch.no_grad():
  gender, age = model(trn.preprocess_image(im).to(device))
pred_gender = gender.to('cpu').detach().numpy()
pred_age = age.to('cpu').detach().numpy()

# plot the image
plt.imshow(im)
# Decode the raw sigmoid outputs: gender > 0.5 means 'Female' (the positive
# class in the dataset), and age is rescaled by 80, the divisor applied to
# the age targets in collate_fn.
plt.title('Predicted gender: {}; Predicted age: {}'.format(
    'Female' if pred_gender[0][0] > 0.5 else 'Male',
    int(pred_age[0][0]*80)
))