In [None]:
import math
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)

import torch
import torch.utils.data as data

from age_regression import AgeRegressionModel
from age_regression import AllAgeFacesDataset
from age_regression import denormalize_image

In [None]:
def _as_int_ages(ages: torch.Tensor) -> np.ndarray:
    """Round an age tensor to the nearest integer and return it as a NumPy int32 array."""
    # Tensor.int() on its own truncates toward zero (24.9 -> 24), biasing ages
    # downward; round first. `.float()` keeps round() valid for integer tensors.
    return ages.float().round().int().numpy()


def create_df(dataset_path: str, model_path: str, batch_size: int = 16, num_workers: int = 3):
    """Run the age-regression model over a dataset and collect inputs, targets and predictions.

    Despite the name, no DataFrame is built here; later cells assemble one from
    the returned Series.

    Args:
        dataset_path: directory handed to ``AllAgeFacesDataset`` (augmentation disabled).
        model_path: file holding an ``AgeRegressionModel`` state dict.
        batch_size: number of samples per inference batch.
        num_workers: number of ``DataLoader`` worker processes.

    Returns:
        ``(max_age, images, gt_ages, pred_ages)`` where ``max_age`` is
        ``dataset.max_age``, ``images`` is every input batch concatenated along
        axis 0 as a NumPy array, and ``gt_ages`` / ``pred_ages`` are integer
        ``pd.Series`` named ``'gt_ages'`` / ``'pred_ages'`` taken from column 0
        of the targets / model outputs (rounded to the nearest integer).

    Raises:
        ValueError: if the dataset yields no samples.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'using {device} for inference')

    dataset = AllAgeFacesDataset(dataset_path, use_augmentation=False)

    model = AgeRegressionModel()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.eval().to(device)

    # Pinned host memory only speeds up host -> GPU copies; it is pointless on CPU.
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                                 pin_memory=(device == 'cuda'))

    image_batches, gt_batches, pred_batches = [], [], []

    # Inference only: disable autograd for the whole loop.
    with torch.no_grad():
        for batch_images, batch_gt_ages in dataloader:
            image_batches.append(batch_images.numpy())
            batch_pred_ages = model(batch_images.to(device, non_blocking=True)).cpu()
            pred_batches.append(_as_int_ages(batch_pred_ages))
            gt_batches.append(_as_int_ages(batch_gt_ages))

    if not image_batches:
        raise ValueError(f'no samples found in dataset at {dataset_path!r}')

    images = np.concatenate(image_batches, axis=0)
    # NOTE(review): `[:, 0]` assumes targets and outputs are (batch, 1) — confirm against the model.
    gt_ages = pd.Series(np.concatenate(gt_batches, axis=0)[:, 0], name='gt_ages')
    pred_ages = pd.Series(np.concatenate(pred_batches, axis=0)[:, 0], name='pred_ages')

    return dataset.max_age, images, gt_ages, pred_ages

In [None]:
# Paths can be overridden with environment variables so the notebook runs on
# other machines; the defaults are the original author's local paths.
dataset_path = os.environ.get('AGE_DATASET_PATH', '/home/will/code/datasets/faces/val')
model_path = os.environ.get(
    'AGE_MODEL_PATH',
    '/home/will/code/AgeRegression/pretrained/model_age_regression_resnext101_20.pth',
)

In [None]:
# Run inference once; every later cell works from these four results.
max_age, images, gt_ages, pred_ages = create_df(dataset_path=dataset_path, model_path=model_path)

In [None]:
def display_images(images, gt_ages, pred_ages, num_cols: int = 4,
                   output_path: str = 'age_regression.png'):
    """Show a grid of images titled with their ground-truth and predicted ages.

    The figure is saved to ``output_path`` and then displayed. Nothing is drawn
    when ``images`` is empty.

    Args:
        images: sequence of images accepted by ``denormalize_image``.
        gt_ages: ground-truth age for each image.
        pred_ages: predicted age for each image.
        num_cols: number of columns in the grid; rows are added as needed.
        output_path: file the figure is written to.
    """
    if len(images) == 0:
        return

    num_rows = math.ceil(len(images) / num_cols)
    fig = plt.figure(figsize=(5 * num_cols, 5 * num_rows))

    for idx, (image, gt_age, pred_age) in enumerate(zip(images, gt_ages, pred_ages)):
        ax = fig.add_subplot(num_rows, num_cols, idx + 1)
        # Reverses the last axis of the denormalized image — presumably a BGR -> RGB
        # channel swap; TODO confirm against denormalize_image's output layout.
        ax.imshow(denormalize_image(image)[..., ::-1])
        # grid(None) would merely toggle the seaborn grid; switch it off explicitly.
        ax.grid(False)
        ax.set_title(f'g.t. age: {gt_age:.1f}\npred age: {pred_age:.1f}')

    fig.tight_layout()
    # Save this figure explicitly instead of relying on pyplot's "current figure".
    fig.savefig(output_path)
    plt.show()

display_images(images[:16], gt_ages[:16], pred_ages[:16])

In [None]:
# Predicted vs. ground-truth age. Data is passed by keyword (seaborn >= 0.12
# rejects positional x/y), and the grid gets its own name so `plt`
# (matplotlib.pyplot) is not shadowed for cells that are re-run later.
joint_grid = sns.jointplot(x=gt_ages, y=pred_ages, height=7, xlim=(0, max_age), ylim=(0, max_age))
joint_grid.savefig('joint_plot.png')

In [None]:
# Histogram of ground-truth ages. The Axes is bound to `gt_ax`, not `plt`, so
# matplotlib.pyplot stays usable if earlier cells are re-run.
# NOTE: distplot is deprecated in seaborn >= 0.11 (histplot replaces it); kept
# here for compatibility with the older seaborn versions this notebook targets.
gt_ax = sns.distplot(gt_ages, kde=False)
gt_ax.set_xlim(0, max_age)
gt_ax.set_title('g.t. ages')
gt_ax.figure.savefig('gt_ages_dist.png')

In [None]:
# Histogram of predicted ages. The Axes is bound to `pred_ax`, not `plt`, so
# matplotlib.pyplot stays usable if earlier cells are re-run.
# NOTE: distplot is deprecated in seaborn >= 0.11 (histplot replaces it); kept
# here for compatibility with the older seaborn versions this notebook targets.
pred_ax = sns.distplot(pred_ages, kde=False)
pred_ax.set_xlim(0, max_age)
pred_ax.set_title('pred ages')
pred_ax.figure.savefig('pred_ages_dist.png')

In [None]:
# Per-sample absolute error. The bar plot's default mean estimator turns it into
# the mean absolute error of each ground-truth age bucket.
df = pd.DataFrame({'gt_ages': gt_ages, 'pred_ages': pred_ages})
df['mae'] = (df['gt_ages'] - df['pred_ages']).abs()

AGE_BIN_WIDTH = 5

def _age_group(age: int, width: int = AGE_BIN_WIDTH) -> str:
    """Label the bucket containing ``age``; buckets cover ``[start, start + width)``."""
    start = width * (age // width)
    return f'{start} - {start + width}'

df['age_group'] = df['gt_ages'].apply(_age_group)
age_group_order = [_age_group(age) for age in range(0, max_age + 1, AGE_BIN_WIDTH)]

# `height` replaces catplot's deprecated `size` argument. The grid is bound to
# its own name so `plt` (matplotlib.pyplot) is not shadowed.
mae_grid = sns.catplot(x='age_group', y='mae', kind='bar', data=df, height=7, order=age_group_order)
mae_grid.set_xticklabels(rotation=45)
mae_grid.savefig('mae_for_age.png')