In order to validate model checkpoints, we run inference on the validation part of the dataset.

In [1]:
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from sklearn.metrics import accuracy_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import callbacks, losses, metrics, models, optimizers
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Dropout, GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import plot_model, Sequence
from tqdm import tqdm

%matplotlib inline

from google.colab import drive
drive.mount('/content/drive')

def extract_target(file_name):
    age, gender = file_name.split('_')[:2]
    age, gender = float(age), int(gender)
    return age, gender

IMG_DIR = '/content/drive/MyDrive/age_gender_estimation/train_val'

train_img_names, val_img_names = train_test_split(os.listdir(IMG_DIR), test_size=0.25, random_state=42, shuffle=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
def run_inference(weights_path):
    model = load_model(weights_path)
    ages_correct, genders_correct = [], []
    ages_predicted, genders_predicted = [], []
    for img_name in tqdm(val_img_names):
        img = np.array(Image.open(os.path.join(IMG_DIR, img_name))) / 255
        img = img.reshape(1, 128, 128, 3)
        age, gender = extract_target(img_name)
        ages_correct.append(age), genders_correct.append(gender)
        model_pred = model.predict(img)
        age_pred = model_pred[0][0][0]
        gender_pred = model_pred[1][0][0]
        ages_predicted.append(age_pred), genders_predicted.append(round(gender_pred))
    print(f'\nAccuracy: {accuracy_score(genders_correct, genders_predicted)}')
    print(f'MAE: {mean_absolute_error(ages_correct, ages_predicted)}')

In [3]:
run_inference('/content/drive/MyDrive/age_gender_estimation/models/age_gender_A.h5')

100%|██████████| 1250/1250 [03:39<00:00,  5.70it/s]


Accuracy: 0.8544
MAE: 7.970671821308136





In [4]:
run_inference('/content/drive/MyDrive/age_gender_estimation/models/age_gender_B.h5')

100%|██████████| 1250/1250 [04:52<00:00,  4.27it/s]


Accuracy: 0.896
MAE: 5.932304301548005



