UTKFace Age-Prediction Pipeline

Imports & Global Settings

In [1]:
# === Imports & Global Settings (TensorFlow-first, with safe fallbacks) ===
import os, sys, atexit, math, random, logging, pickle, gc
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from matplotlib.container import BarContainer

# -- TensorFlow + Keras (prefer tf.keras; fall back to standalone Keras only if needed)
try:
    import tensorflow as tf
    from tensorflow.keras import layers, models, regularizers, callbacks
    from tensorflow.keras.models import load_model
    USING_TF_KERAS = True
except ModuleNotFoundError:
    # Fallback (rare on TF 2.16+, but harmless)
    import tensorflow as tf  # keep TF for seeding and device control if available
    import keras
    from keras import layers, models, regularizers, callbacks
    from keras.models import load_model
    USING_TF_KERAS = False

# --- Reproducibility ---
# Seed every RNG this notebook draws from: stdlib `random`, numpy's legacy
# global RNG and (best-effort) TensorFlow.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
try:
    tf.random.set_seed(RANDOM_SEED)
except Exception:
    pass  # a runtime TF failure must not abort the notebook

# --- Plot style ---
sns.set_theme(style="whitegrid")
# Pink colour ramp shared by the plotting helpers below.
rosa_palette = [
    "#F9C5D5",
    "#F7A1C4",
    "#F48FB1",
    "#F06292",
    "#EC407A",
]

Data Loading

In [2]:
def load_utkface_dataset(image_dir="./data/images", max_images=24108, seed=None):
    """
    Load UTKFace images and parse the ``age_gender_race_datetime`` fields
    encoded in each file name.

    Parameters
    ----------
    image_dir : str
        Folder containing the image files.
    max_images : int or None, default 24108 (full UTKFace size)
        Maximum number of files considered after shuffling; ``None`` = all.
    seed : int or None, default None
        ``None`` shuffles with the global ``random`` module, so the order
        depends on its current state. An int uses a private
        ``random.Random(seed)``, giving an order independent of earlier draws.

    Returns
    -------
    images : list of np.ndarray
        RGB images (``cv2.imread`` followed by BGR->RGB conversion).
    labels : list of dict
        One dict per image with keys ``age``, ``gender``, ``race``,
        ``datetime`` and ``filename`` (source file name).
    df : pandas.DataFrame
        ``labels`` as a DataFrame.

    Files are skipped if their names do not parse, their fields are out of
    range, or OpenCV cannot read them.
    """
    images, labels = [], []
    # os.listdir order is filesystem-dependent: sort before shuffling so the
    # resulting order depends only on the RNG.
    all_files = sorted(os.listdir(image_dir))
    if seed is None:
        random.shuffle(all_files)
    else:
        random.Random(seed).shuffle(all_files)
    candidates = all_files if max_images is None else all_files[:max_images]

    for file_name in candidates:
        # Keep only the text before the first dot so that both
        # "25_0_1_20170116174525125.jpg" and "...125.jpg.chip.jpg" parse.
        name = file_name.split(".", 1)[0]
        parts = name.split("_")
        if len(parts) != 4 or any(p == "" for p in parts):
            continue
        try:
            age, gender, race = int(parts[0]), int(parts[1]), int(parts[2])
        except ValueError:
            continue
        date = parts[3]

        if not (0 <= age <= 116 and gender in (0, 1) and 0 <= race <= 4 and len(date) == 17):
            continue

        img_path = os.path.join(image_dir, file_name)
        img = cv2.imread(img_path)
        if img is None:
            print(f"Failed to read: {file_name}")
            continue

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if img.ndim != 3 or img.shape[2] != 3:
            print(f"Unexpected format for: {file_name}, shape={img.shape}")
            continue

        images.append(img)
        labels.append({"age": age, "gender": gender, "race": race,
                       "datetime": date, "filename": file_name})

    print(f"Loaded {len(images)} valid images out of {len(all_files)} files.")

    df = pd.DataFrame(labels)
    print("\n=== Dataset Summary ===")
    if df.empty:
        # Nothing to summarise; the statistics below would raise on an empty frame.
        print("No valid images found.")
        return images, labels, df

    print(f"Total images: {len(df)}")
    print(f"Age range: {df['age'].min()} - {df['age'].max()} (mean: {df['age'].mean():.1f})")
    print(f"Gender distribution:\n{df['gender'].value_counts()}")
    print(f"Race distribution:\n{df['race'].value_counts()}")
    print(f"Average image size (HxW): {np.mean([img.shape[:2] for img in images], axis=0).astype(int)}")

    return images, labels, df

# Load the data (first pass; its dataset summary is the output shown below)
images, labels, df = load_utkface_dataset(image_dir="./data/images")

Loaded 24099 valid images out of 24108 files.

=== Dataset Summary ===
Total images: 24099
Age range: 1 - 116 (mean: 33.0)
Gender distribution:
gender
0    12578
1    11521
Name: count, dtype: int64
Race distribution:
race
0    10220
1     4556
3     4027
2     3585
4     1711
Name: count, dtype: int64
Average image size (HxW): [662 637]


Visualisation Helpers

In [3]:
def plot_random_samples(images, labels, gender_map, race_map):
    """
    Show up to nine randomly chosen images in a 3x3 grid.

    Each panel is captioned with the age and with the gender/race codes
    translated through ``gender_map`` / ``race_map``. Codes missing from a
    map are shown raw. Indices are drawn with the global ``random`` module.
    """
    picks = random.sample(range(len(images)), min(9, len(images)))
    plt.figure(figsize=(10, 6))
    for slot, idx in enumerate(picks, start=1):
        meta = labels[idx]
        gender = meta.get("gender")
        race = meta.get("race")
        caption = (
            f'Age: {meta["age"]}\n'
            f'Gender: {gender_map.get(gender, gender)} '
            f'Race: {race_map.get(race, race)}'
        )
        plt.subplot(3, 3, slot)
        plt.imshow(images[idx])
        plt.axis("off")
        plt.title(caption, fontsize=10, color="#4A4A4A")
    plt.suptitle("Random Sample of UTKFace Images", fontsize=14, weight="bold")
    plt.tight_layout()
    plt.show()


def _count_bars(df, column, palette, figsize, title, xlabel):
    """Count plot of ``column`` with the count printed on top of every bar."""
    plt.figure(figsize=figsize)
    ax = sns.countplot(data=df, x=column, hue=column, palette=palette, legend=False)
    for container in ax.containers:
        if isinstance(container, BarContainer):
            ax.bar_label(container, fmt='%d', fontsize=10)
    plt.title(title, fontsize=14, weight='bold')
    plt.xlabel(xlabel); plt.ylabel("Count")
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.show()


def _age_boxplot(df, column, palette, figsize, title, xlabel):
    """Box-plot of ``age`` grouped by ``column`` (outliers shown)."""
    plt.figure(figsize=figsize)
    sns.boxplot(data=df, x=column, y="age", hue=column,
                palette=palette, legend=False, showfliers=True)
    plt.title(title, fontsize=14, weight='bold')
    plt.xlabel(xlabel); plt.ylabel("Age")
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.show()


def plot_distribution_charts(df):
    """
    Exploratory charts for the label DataFrame: age histogram (with KDE),
    gender and race counts, and age box-plots per gender and per race.
    ``df`` needs the columns ``age``, ``gender`` and ``race``.
    """
    gender_colors = ["#F9C5D5", "#F48FB1"]

    # Age histogram
    plt.figure(figsize=(8, 5))
    sns.histplot(df["age"], color="#F48FB1", kde=True)
    plt.title('Age Distribution', fontsize=14, weight='bold')
    plt.xlabel('Age'); plt.ylabel('Frequency')
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.show()

    _count_bars(df, "gender", gender_colors, (6, 5), "Gender Balance", "Gender")
    _count_bars(df, "race", rosa_palette, (7, 5), "Race Balance", "Race")
    _age_boxplot(df, "gender", gender_colors, (8, 5), "Age Distribution by Gender", "Gender")
    _age_boxplot(df, "race", rosa_palette, (9, 5), "Age Distribution by Race", "Race")


def plot_age_bin_distribution(bins_train, bins_val, bins_test, AGE_BINS, rosa_palette=None):
    """
    Grouped bar charts comparing the train / val / test age-bin histograms.

    The first chart shows raw counts. The second shows within-split
    percentages, annotated on each bar.

    bins_* : per-sample bin indices (0 .. len(AGE_BINS)-2).
    AGE_BINS : bin edges, used only for the tick labels.
    rosa_palette : three bar colours (train, val, test); a pink default is used when None.
    """
    colors = ["#F9C5D5", "#F48FB1", "#EC407A"] if rosa_palette is None else rosa_palette
    n_bins = len(AGE_BINS) - 1
    tick_labels = [
        f"{AGE_BINS[i]}+" if i == n_bins - 1 else f"{AGE_BINS[i]}–{AGE_BINS[i+1]-1}"
        for i in range(n_bins)
    ]

    def _histogram(bin_ids):
        # counts for bins 0..n_bins-1, zero for bins absent from this split
        tally = Counter(bin_ids)
        return np.array([tally.get(k, 0) for k in range(n_bins)], dtype=np.int32)

    split_names = ("Train", "Val", "Test")
    width = 0.25
    shifts = (-width, 0, width)
    centers = np.arange(n_bins)

    def _draw(per_split, ylabel, title, annotate):
        plt.figure(figsize=(12,5))
        for k, (values, name, shift) in enumerate(zip(per_split, split_names, shifts)):
            plt.bar(centers + shift, values, width=width, label=name, color=colors[k])
        plt.xticks(centers, tick_labels, rotation=0, fontsize=11)
        plt.xlabel("Age Bins"); plt.ylabel(ylabel)
        plt.title(title, weight="bold")
        if annotate:
            for i in range(n_bins):
                for values, shift in zip(per_split, shifts):
                    plt.text(i + shift, values[i] + 0.5, f"{values[i]:.1f}%",
                             ha='center', va='bottom', fontsize=9)
        plt.legend(); plt.grid(True, linestyle='--', alpha=0.3)
        plt.tight_layout(); plt.show()

    counts = [_histogram(b) for b in (bins_train, bins_val, bins_test)]
    _draw(counts, "Count", "Age-bin Distribution (Counts)", annotate=False)

    shares = [c / c.sum() * 100 for c in counts]
    _draw(shares, "Share (%)", "Age-bin Distribution (Percent)", annotate=True)


def plot_avg_sample_weight_per_bin(train_weights, bins_train, AGE_BINS, bar_color="#EC407A"):
    """
    Bar chart of the mean sample weight in every age bin present in
    ``bins_train``. Each bar is annotated with its value.

    train_weights : array aligned with bins_train (e.g. from make_sample_weights).
    AGE_BINS : bin edges, used for the x tick labels.
    """
    bins_arr = np.array(bins_train)
    present = sorted(set(bins_train))
    last_bin = len(AGE_BINS) - 2
    avg_w = [train_weights[bins_arr == b].mean() for b in present]
    bin_labels = [
        f"{AGE_BINS[b]}+" if b >= last_bin else f"{AGE_BINS[b]}–{AGE_BINS[b+1]-1}"
        for b in present
    ]

    plt.figure(figsize=(10,5))
    plt.bar(bin_labels, avg_w, color=bar_color, alpha=0.8, edgecolor="white", linewidth=1.2)
    plt.xlabel("Age Bins"); plt.ylabel("Average Sample Weight")
    plt.title("Inverse-Frequency Sample Weights per Age Bin", weight="bold")
    plt.grid(axis="y", linestyle="--", alpha=0.3)
    # text offset is 2 % of the tallest bar (only needed when there are bars)
    pad = max(avg_w) * 0.02 if avg_w else 0.0
    for pos, val in enumerate(avg_w):
        plt.text(pos, val + pad, f"{val:.2f}", ha='center', va='bottom', fontsize=9)
    plt.tight_layout(); plt.show()


def plot_augmented_samples(batch_X, batch_y_raw, n_rows=2, n_cols=4, title="Augmented Sample Previews"):
    """
    Show augmented images in an ``n_rows`` x ``n_cols`` grid, captioned
    with their raw age.

    batch_X : sequence of images, displayed after clipping to [0, 1].
    batch_y_raw : ages in years, aligned with batch_X; rounded to the
        nearest integer for the caption.

    At most n_rows*n_cols images are shown. Panels without an image are
    blanked.
    """
    # squeeze=False: `axes` is always 2-D, so .ravel() also works for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False,
                             figsize=(n_cols*2.5, n_rows*2.5), constrained_layout=True)
    fig.suptitle(title, fontsize=14, weight="bold", y=1.02)
    flat_axes = axes.ravel()
    shown = 0
    for ax, img, age in zip(flat_axes, batch_X, batch_y_raw):
        ax.imshow(np.clip(img, 0, 1))
        # round, not truncate: de-standardised ages can be e.g. 24.99999
        ax.set_title(f"Age: {int(round(float(age)))}", fontsize=10, pad=4)
        ax.axis("off")
        shown += 1
    for ax in flat_axes[shown:]:
        ax.axis("off")
    plt.show()


def plot_training_history(history, title_suffix=""):
    """
    Side-by-side train/validation curves from a Keras ``History``.

    The left panel is the loss (labelled Huber), the right panel the MAE.
    ``history.history`` must contain ``loss``, ``val_loss``, ``mae`` and
    ``val_mae``.
    """
    curves = history.history
    train_color, val_color = "#F48FB1", "#F9C5D5"
    # (metric key, train label, val label, title prefix, y-axis label)
    panels = (
        ("loss", "Train Loss", "Val Loss", "Loss Curve", "Huber Loss"),
        ("mae", "Train MAE", "Val MAE", "MAE Curve", "MAE"),
    )

    plt.figure(figsize=(14,6))
    for pos, (key, train_label, val_label, heading, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.plot(curves[key], label=train_label, color=train_color, linewidth=2)
        plt.plot(curves[f"val_{key}"], label=val_label, color=val_color, linewidth=2)
        plt.title(f"{heading} {title_suffix}", weight='bold')
        plt.xlabel('Epoch'); plt.ylabel(ylabel)
        plt.legend(); plt.grid(True, linestyle='--', alpha=0.3)

    plt.tight_layout(); plt.show()

Augmentation

In [4]:
# Global objects, will be set from preprocessing.
# `rng` is used below through .random(), .uniform(), .integers() and
# .normal(), i.e. a numpy Generator; install it with set_augment_seed().
rng = None
NORMALIZE_01 = True

def set_augment_seed(random_generator, normalize=True):
    """Install the RNG and the output-normalisation flag used by every augmentation."""
    global rng, NORMALIZE_01
    rng = random_generator
    NORMALIZE_01 = normalize


def _require_rng():
    """Return the module RNG, or fail with a clear message if it was never set."""
    if rng is None:
        raise RuntimeError("Augmentation RNG is not set – call set_augment_seed(rng) first.")
    return rng


def random_hflip(img, p=0.5):
    """Mirror ``img`` (H, W, C) left-right with probability ``p``."""
    if _require_rng().random() < p:
        return np.ascontiguousarray(img[:, ::-1, :])
    return img


def random_rotate(img, max_deg=10):
    """Rotate about the centre by a uniform angle in [-max_deg, max_deg] degrees (reflected border)."""
    deg = float(_require_rng().uniform(-max_deg, max_deg))
    h, w = img.shape[:2]
    M = cv2.getRotationMatrix2D((w/2, h/2), deg, 1.0)
    return cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_REFLECT_101)


def random_crop_and_resize(img, scale=(0.88, 1.0)):
    """Take a random crop covering ``scale`` of each side and resize it back to the input size."""
    gen = _require_rng()
    h, w = img.shape[:2]
    s = float(gen.uniform(scale[0], scale[1]))
    # at least one pixel, otherwise cv2.resize fails on an empty crop
    new_h, new_w = max(1, int(h*s)), max(1, int(w*s))
    max_y = max(h - new_h, 0)
    max_x = max(w - new_w, 0)
    y0 = int(gen.integers(0, max_y+1)) if max_y > 0 else 0
    x0 = int(gen.integers(0, max_x+1)) if max_x > 0 else 0
    crop = img[y0:y0+new_h, x0:x0+new_w, :]
    return cv2.resize(crop, (w, h), interpolation=cv2.INTER_AREA)


def random_brightness_contrast(img, b_lim=0.15, c_lim=0.15, p=0.8):
    """With probability ~p, scale contrast and shift brightness; expects and returns [0, 1] data."""
    gen = _require_rng()
    if gen.random() > p:
        return img
    brightness = float(gen.uniform(-b_lim, b_lim))
    contrast = 1.0 + float(gen.uniform(-c_lim, c_lim))
    out = img * contrast + brightness
    return np.clip(out, 0.0, 1.0)


def add_gaussian_noise(img, sigma=0.02, p=0.3):
    """With probability ~p, add N(0, sigma) noise; expects and returns [0, 1] data."""
    gen = _require_rng()
    if gen.random() > p:
        return img
    noise = gen.normal(0.0, sigma, img.shape).astype(np.float32)
    out = img + noise
    return np.clip(out, 0.0, 1.0)


def augment_once(img_uint8):
    """
    Full augmentation chain for one uint8 image: flip, rotate, crop+resize,
    conversion to float32, brightness/contrast, Gaussian noise.

    Returns float32 in [0, 1] when NORMALIZE_01 is True, else in [0, 255].
    """
    y = random_hflip(img_uint8, p=0.5)
    y = random_rotate(y, max_deg=10)
    y = random_crop_and_resize(y, scale=(0.88, 1.0))
    # The photometric ops clip to [0, 1], so always run them on [0, 1] data
    # and rescale afterwards; previously NORMALIZE_01=False clipped 0-255
    # pixels to [0, 1].
    y = y.astype(np.float32) / 255.0
    y = random_brightness_contrast(y, b_lim=0.15, c_lim=0.15, p=0.8)
    y = add_gaussian_noise(y, sigma=0.02, p=0.3)
    if not NORMALIZE_01:
        y = y * 255.0
    return y

Pre-processing

In [5]:
def preprocess_images_to_memmap(images, labels, target_size=(160,160),
                               save_X_path=None, save_y_path=None):
    """
    Resize every image and cache images and ages as on-disk memmaps.

    ``target_size`` is (width, height), as taken by ``cv2.resize``. The image
    memmap at ``save_X_path`` is float16 with shape (n, height, width, 3) and
    values scaled to [0, 1]. The age memmap at ``save_y_path`` is float32 with
    shape (n,), taken from ``labels[i]["age"]``.

    Existing files are reused only when their byte sizes match what ``n``
    images at ``target_size`` produce; otherwise they are rebuilt. New data
    is written to ``*.tmp`` files and moved into place only after a
    successful flush, so an interrupted run cannot leave a truncated cache
    that a later run would silently accept.

    Returns (save_X_path, save_y_path).
    Raises ValueError for missing paths, no images, an images/labels length
    mismatch, or a None image.
    """
    if save_X_path is None or save_y_path is None:
        raise ValueError("Must provide save_X_path and save_y_path")
    for out_path in (save_X_path, save_y_path):
        parent = os.path.dirname(out_path)
        if parent:  # dirname is "" for a bare file name; makedirs("") would fail
            os.makedirs(parent, exist_ok=True)

    n = len(images)
    w, h = target_size
    print(f"Writing memmap files ({w}x{h}) ... (n={n})")

    expected_X_bytes = n * h * w * 3 * np.dtype(np.float16).itemsize
    expected_y_bytes = n * np.dtype(np.float32).itemsize
    if os.path.exists(save_X_path) and os.path.exists(save_y_path):
        if (os.path.getsize(save_X_path) == expected_X_bytes
                and os.path.getsize(save_y_path) == expected_y_bytes):
            print(f"Memmap files already exist – skipping creation.")
            return save_X_path, save_y_path
        print("Existing memmap files have an unexpected size – rebuilding.")

    if n == 0:
        raise ValueError("No images to write")
    if len(labels) != n:
        raise ValueError(f"Got {n} images but {len(labels)} labels")

    tmp_X_path = save_X_path + ".tmp"
    tmp_y_path = save_y_path + ".tmp"
    X_mm = np.memmap(tmp_X_path, dtype=np.float16, mode="w+", shape=(n, h, w, 3))
    y_mm = np.memmap(tmp_y_path, dtype=np.float32, mode="w+", shape=(n,))
    try:
        for i, img in enumerate(images):
            if img is None:
                raise ValueError(f"Image {i} is None")
            resized = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
            X_mm[i] = (resized.astype(np.float32) / 255.0).astype(np.float16)
            y_mm[i] = float(labels[i]["age"])

            if (i+1) % 1000 == 0 or (i+1) == n:
                print(f" Processed {i+1}/{n} images", end="\r")
        X_mm.flush()
        y_mm.flush()
    finally:
        # release the mappings before the temp files are moved into place
        del X_mm, y_mm
        gc.collect()

    os.replace(tmp_X_path, save_X_path)
    os.replace(tmp_y_path, save_y_path)
    print("\nMemmap creation finished.")
    return save_X_path, save_y_path


def load_and_split_from_memmap(X_path, y_path, n_samples,
                               target_size=(160,160), normalize_01=True,
                               random_seed=42, age_bins=None,
                               save_split_info=True,
                               split_info_path="./dataset_split/dataset_split_info.pkl"):
    """
    Load the memmaps written by ``preprocess_images_to_memmap`` and split them.

    The split is stratified on age bins into 70 % train, 15 % val and
    15 % test. Ages are standardised with train-set statistics, and all
    objects are returned in a dict.

    NOTE: the image splits are built by integer ("fancy") indexing of the
    memmap, which numpy always materialises as an in-memory COPY. They are
    not lazy views of the file, so budget RAM for the whole float16 dataset.

    Parameters
    ----------
    X_path, y_path : str
        Image (float16) and age (float32) memmap files.
    n_samples : int
        Number of rows in both memmaps.
    target_size : (width, height)
        Size used when the image memmap was written.
    normalize_01 : bool
        Passed through unchanged as ``"NORMALIZE_01"``.
    random_seed : int
        Seeds the returned ``rng`` Generator and both ``train_test_split`` calls.
    age_bins : list or None
        Stratification bin edges; default [0,5,12,18,30,45,60,80,200].
    save_split_info, split_info_path
        If set, pickle the split indices, the age statistics, all ages, the
        bins and the seed to ``split_info_path``.

    Returns
    -------
    dict with keys rng, NORMALIZE_01, AGE_BINS, bins_{train,val,test},
    X_{train,val,test}, y_{train,val,test}, age_mean, age_std,
    y_{train,val,test}_std.

    Raises
    ------
    ValueError if the train ages have zero (or non-finite) standard deviation.
    """

    # Drop arrays left in the notebook namespace by an earlier run so their
    # memory can be freed before new copies are created (safety).
    for name in ["X","X_tmp","X_train","X_val","X_test",
                 "dbg_X","dbg_y","dbg_w"]:
        globals().pop(name, None)
    gc.collect()

    rng = np.random.default_rng(random_seed)
    NORMALIZE_01 = normalize_01
    AGE_BINS = age_bins or [0,5,12,18,30,45,60,80,200]

    # read-only memmaps; layout is (n, height, width, 3)
    w, h = target_size
    X_all = np.memmap(X_path, dtype=np.float16, mode="r",
                      shape=(n_samples, h, w, 3))
    y_all = np.memmap(y_path, dtype=np.float32, mode="r",
                      shape=(n_samples,))

    ages = np.array(y_all, dtype=np.float32)          # small – fits in RAM
    # 0-based bin index per sample, used only for stratification
    bins_idx = np.digitize(ages, AGE_BINS, right=False) - 1
    idx_all = np.arange(len(ages))

    # stratified splits: 15 % test, then 0.15/0.85 of the remainder (= 15 % of all) for val
    idx_tmp, idx_test, ages_tmp, ages_test, bins_tmp, bins_test = train_test_split(
        idx_all, ages, bins_idx,
        test_size=0.15, random_state=random_seed, stratify=bins_idx)

    val_ratio = 0.15 / (1.0 - 0.15)
    idx_train, idx_val, ages_train, ages_val, bins_train, bins_val = train_test_split(
        idx_tmp, ages_tmp, bins_tmp,
        test_size=val_ratio, random_state=random_seed, stratify=bins_tmp)

    # fancy indexing -> in-memory copies (see docstring)
    X_train = X_all[idx_train]
    X_val   = X_all[idx_val]
    X_test  = X_all[idx_test]

    y_train = ages_train
    y_val   = ages_val
    y_test  = ages_test

    # age standardisation (train stats only, to avoid leaking val/test info)
    age_mean = y_train.mean()
    age_std  = y_train.std()
    if not np.isfinite(age_std) or age_std == 0:
        raise ValueError(f"Cannot standardise ages: train-set std is {age_std}")
    y_train_std = (y_train - age_mean) / age_std
    y_val_std   = (y_val   - age_mean) / age_std
    y_test_std  = (y_test  - age_mean) / age_std

    print(f"Age mean: {age_mean:.2f}, std: {age_std:.2f}")
    print("Train/Val/Test shapes:", X_train.shape, X_val.shape, X_test.shape)
    print("Bin counts (train):", Counter(bins_train))
    print("Bin counts (val):",   Counter(bins_val))
    print("Bin counts (test):",  Counter(bins_test))

    # save split info (for later evaluation)
    if save_split_info:
        split_info = {
            "idx_train": idx_train,
            "idx_val":   idx_val,
            "idx_test":  idx_test,
            "age_mean":  float(age_mean),
            "age_std":   float(age_std),
            "y_all":     ages,
            "AGE_BINS":  AGE_BINS,
            "RANDOM_SEED": random_seed
        }
        parent = os.path.dirname(split_info_path)
        if parent:  # makedirs("") would fail for a bare file name
            os.makedirs(parent, exist_ok=True)
        with open(split_info_path, "wb") as f:
            pickle.dump(split_info, f)
        print(f"Split info saved → {split_info_path}")

    return {
        "rng": rng,
        "NORMALIZE_01": NORMALIZE_01,
        "AGE_BINS": AGE_BINS,
        "bins_train": bins_train,
        "bins_val":   bins_val,
        "bins_test":  bins_test,
        "X_train": X_train,
        "X_val":   X_val,
        "X_test":  X_test,
        "y_train": y_train,
        "y_val":   y_val,
        "y_test":  y_test,
        "age_mean": age_mean,
        "age_std":  age_std,
        "y_train_std": y_train_std,
        "y_val_std":   y_val_std,
        "y_test_std":  y_test_std
    }

Model Definition

In [6]:
def build_model(input_shape=(160,160,3)):
    """
    CNN with BatchNorm, LeakyReLU, SpatialDropout, L2 regularisation.
    Output: single linear neuron (standardised age).

    Architecture:
    - Three double-conv blocks (32 / 64 / 128 filters) and one single-conv
      block (256). Each is followed by 2x2 max-pooling and SpatialDropout2D
      (0.25 / 0.3 / 0.35 / 0.4).
    - Global average pooling.
    - Dense head 256 -> 64 -> 32 (ReLU, L2 1e-3, BatchNorm, Dropout
      0.5 / 0.4 / 0.4).
    - A float32 linear output unit.

    Compiled with Adam(lr=5e-4), Huber loss and metrics [mae, mse], in that
    order.

    Parameters
    ----------
    input_shape : tuple
        (height, width, channels) of the input images.

    Returns
    -------
    The compiled Keras ``Sequential`` model.
    """
    def _leaky_relu():
        # Keras 3 names the slope `negative_slope` (`alpha` is deprecated
        # there); Keras 2 only knows `alpha` and rejects unknown keyword
        # arguments with TypeError.
        try:
            return layers.LeakyReLU(negative_slope=0.1)
        except TypeError:
            return layers.LeakyReLU(alpha=0.1)

    def _conv_unit(filters):
        """Conv3x3 (same padding) -> BatchNorm -> LeakyReLU(0.1)."""
        return [layers.Conv2D(filters, (3,3), padding='same'),
                layers.BatchNormalization(),
                _leaky_relu()]

    def _downsample(rate):
        """2x2 max-pooling followed by SpatialDropout2D(rate)."""
        return [layers.MaxPooling2D((2,2)), layers.SpatialDropout2D(rate)]

    def _dense_unit(units, rate):
        """ReLU Dense with L2(1e-3) -> BatchNorm -> Dropout(rate)."""
        return [layers.Dense(units, activation='relu',
                             kernel_regularizer=regularizers.l2(1e-3)),
                layers.BatchNormalization(),
                layers.Dropout(rate)]

    # An explicit Input layer replaces `input_shape=` on the first Conv2D,
    # which Keras 3 deprecates for Sequential models.
    stack = [layers.Input(shape=input_shape)]
    for filters, rate in ((32, 0.25), (64, 0.3), (128, 0.35)):    # blocks 1-3
        stack += _conv_unit(filters) + _conv_unit(filters) + _downsample(rate)
    stack += _conv_unit(256) + _downsample(0.4)                     # block 4

    stack.append(layers.GlobalAveragePooling2D())
    for units, rate in ((256, 0.5), (64, 0.4), (32, 0.4)):          # dense head
        stack += _dense_unit(units, rate)
    stack.append(layers.Dense(1, activation='linear', dtype='float32'))

    model = models.Sequential(stack)

    opt = tf.keras.optimizers.Adam(learning_rate=5e-4)
    model.compile(optimizer=opt,
                  loss=tf.keras.losses.Huber(),
                  metrics=[tf.keras.metrics.MeanAbsoluteError(name='mae'), 'mse'])
    return model

Training

In [7]:
def train_model(model, train_gen, val_gen, y_train_std, y_val_std,
                batch_size=16, epochs=50,
                save_path="./best_model/best_model.keras"):
    """
    Fit ``model`` on generator batches, with checkpointing, LR-on-plateau
    and early stopping.

    Parameters
    ----------
    model : compiled Keras model that reports a ``mae`` metric.
    train_gen : endless iterator of (x, y) or (x, y, w) batches
        (e.g. train_batch_generator).
    val_gen : validation data.
        A one-shot iterator (anything with ``__next__`` that is its own
        iterator, e.g. val_batch_iterator) would be exhausted after the
        first epoch, leaving later epochs without validation data. It is
        therefore drained once (at most one pass of batches) into in-memory
        arrays. Anything else (tuple of arrays, tf.data dataset, ...) is
        passed to ``fit`` unchanged.
    y_train_std, y_val_std : targets; only their lengths are used, to size
        epochs.
    batch_size : int
        Must equal the batch size the generators were built with.
    epochs : int
    save_path : str
        Where ModelCheckpoint writes the model with the lowest ``val_mae``.
        The parent directory is created if needed.

    Returns
    -------
    The Keras ``History`` object.
    """
    import itertools

    n_train, n_val = len(y_train_std), len(y_val_std)
    # ceil: the generators emit a final partial batch, so one Keras epoch
    # equals exactly one pass over the data.
    steps_per_epoch = max(1, math.ceil(n_train / batch_size))
    validation_steps = math.ceil(n_val / batch_size)

    validation_data = val_gen
    if hasattr(val_gen, "__next__") and iter(val_gen) is val_gen:
        batches = list(itertools.islice(val_gen, validation_steps))
        # (x, y) or (x, y, w) -> tuple of concatenated arrays; Keras re-iterates arrays every epoch
        validation_data = tuple(np.concatenate(part) for part in zip(*batches)) if batches else None
        validation_steps = None

    print(f"Training samples : {n_train}")
    print(f"Validation samples: {n_val}")
    print(f"Steps/epoch : {steps_per_epoch}  |  Validation steps: "
          f"{validation_steps if validation_steps is not None else 'full set'}")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    ckpt = callbacks.ModelCheckpoint(save_path,
                                     save_best_only=True,
                                     monitor="val_mae",
                                     mode="min")
    rlr  = callbacks.ReduceLROnPlateau(monitor="val_loss",
                                       factor=0.5, patience=7,
                                       min_lr=1e-6, verbose=1)
    es   = callbacks.EarlyStopping(monitor="val_loss",
                                   patience=15,
                                   restore_best_weights=True)

    history = model.fit(train_gen,
                        steps_per_epoch=steps_per_epoch,
                        validation_data=validation_data,
                        validation_steps=validation_steps,
                        epochs=epochs,
                        callbacks=[ckpt, rlr, es],
                        verbose=1)
    return history

Utilities

In [8]:
# requires augment_once from the augmentation cell
def make_sample_weights(bins):
    """
    Inverse-frequency sample weights.

    Each sample in class ``b`` gets ``n / (K * count[b])``, where ``n`` is
    the number of samples and ``K`` the number of distinct classes present.
    Every class therefore carries the same total weight, and the weights
    average to 1.

    Returns a float32 array aligned with ``bins``.
    """
    freq = Counter(bins)
    total, n_classes = len(bins), len(freq)
    per_sample = (total / (n_classes * freq[b]) for b in bins)
    return np.fromiter(per_sample, dtype=np.float32, count=total)

def train_batch_generator(X, y, batch_size, weights=None, shuffle=True):
    """
    Endless generator of augmented training batches.

    Each pass over the data is (optionally) reshuffled with numpy's global
    RNG. Every image is converted back to uint8, run through ``augment_once``
    (defined in the augmentation cell) and stacked into a float32 batch.

    Parameters
    ----------
    X : array-like, shape (n, H, W, 3)
        Images scaled to [0, 1] (the memmap written by preprocess_images_to_memmap).
    y : array-like, shape (n,)
        Targets (standardised ages in this pipeline).
    batch_size : int, > 0
    weights : array-like (n,) or None
        Per-sample weights; when given, batches are (x, y, w).
    shuffle : bool

    Yields
    ------
    (bx, by) or (bx, by, bw). The last batch of every pass may be smaller.

    Raises
    ------
    ValueError (on first ``next``) for empty data or a non-positive batch
    size. Either would otherwise loop forever without yielding.
    """
    n = len(y)
    if n == 0:
        raise ValueError("train_batch_generator: empty dataset")
    if batch_size <= 0:
        raise ValueError(f"train_batch_generator: batch_size must be > 0, got {batch_size}")
    order = np.arange(n)
    while True:
        if shuffle:
            np.random.shuffle(order)
        for start in range(0, n, batch_size):
            sel = order[start:start + batch_size]
            bx = np.empty((len(sel),) + X.shape[1:], dtype=np.float32)
            for i, j in enumerate(sel):
                # Round (not truncate): float16 storage makes v/255*255 land
                # just below the integer (e.g. 199.97), and plain astype
                # would shift those pixels down by one.
                img = np.clip(np.rint(np.asarray(X[j], dtype=np.float32) * 255.0),
                              0, 255).astype(np.uint8)
                bx[i] = augment_once(img)
            by = y[sel]
            if weights is not None:
                yield bx, by, weights[sel]
            else:
                yield bx, by

def val_batch_iterator(X, y, batch_size, normalize=True):
    """
    Single, in-order pass over ``(X, y)`` in consecutive batches (no augmentation).

    Images are cast to float32. With ``normalize=False`` they are also
    multiplied by 255, i.e. [0, 1] data is mapped back to [0, 255]. The final
    batch may be smaller. This is a one-shot generator: once exhausted, it
    yields nothing more.
    """
    for lo in range(0, len(y), batch_size):
        window = slice(lo, lo + batch_size)
        batch_x = X[window].astype(np.float32)
        yield (batch_x if normalize else batch_x * 255.0), y[window]

def evaluate_model_on_test(model, X_test, y_test_std, age_mean, age_std, history=None, normalize=True, verbose=True):
    """
    Evaluate ``model`` on the test split, in standardised units and in years.

    Parameters
    ----------
    model : compiled Keras model whose ``evaluate`` returns [loss, mae, mse]
        (the metric order set in build_model).
    X_test : array-like of images scaled to [0, 1].
    y_test_std : standardised targets.
    age_mean, age_std : statistics used to standardise the targets.
    history : optional Keras History; only used for the verbose summary.
    normalize : bool
        Must match the training setting. ``False`` rescales the [0, 1]
        inputs to [0, 255], exactly as val_batch_iterator does. Previously
        this parameter was silently ignored.
    verbose : passed to ``model.evaluate``; also toggles the printed report.

    Returns
    -------
    dict of floats:
        test_loss, test_mae, test_mse : in standardised units.
        raw_mae, raw_mse : in years.
    """
    X_test_array = np.asarray(X_test, dtype=np.float32)
    if not normalize:
        X_test_array = X_test_array * 255.0
    y_test_array = np.asarray(y_test_std, dtype=np.float32).ravel()

    test_loss, test_mae, test_mse = model.evaluate(X_test_array, y_test_array, verbose=verbose)
    y_pred_std = np.asarray(model.predict(X_test_array, verbose=0)).ravel()

    # back to years
    y_pred_raw = y_pred_std * age_std + age_mean
    y_test_raw = y_test_array * age_std + age_mean
    err = y_pred_raw - y_test_raw
    raw_mae = np.mean(np.abs(err))
    raw_mse = np.mean(err ** 2)

    if verbose:
        print(f"\n--- Test Set Evaluation ---")
        print(f"Standardized MAE: {test_mae:.4f}")
        print(f"Standardized MSE: {test_mse:.4f}")
        print(f"Raw MAE: {raw_mae:.2f} years")
        print(f"Raw MSE: {raw_mse:.2f}")
        if history is not None:
            print(f"\nTraining epochs: {len(history.history['loss'])}")
            print(f"Min val MAE: {min(history.history['val_mae']):.4f}")
    return {
        'test_loss': float(test_loss),
        'test_mae': float(test_mae),
        'test_mse': float(test_mse),
        'raw_mae': float(raw_mae),
        'raw_mse': float(raw_mse),
    }


**MAIN PIPELINE**

In [None]:
# -------------------------------------------------
# 1. Load raw images
# -------------------------------------------------
# Release any copy loaded by an earlier cell first. The raw images are large
# (average ~660x640 RGB per the summary above), and holding two full copies
# at once can exhaust RAM.
for _stale_name in ("images", "labels", "df"):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()
# Reseed so the loader's shuffle (which decides the memmap row order) does
# not depend on how many random draws earlier cells made.
random.seed(RANDOM_SEED)
images, labels, df = load_utkface_dataset("./data/images")

In [None]:
# -------------------------------------------------
# 2. Visualise raw data
# -------------------------------------------------
# UTKFace integer codes -> readable names (codes are the dict keys)
gender_map = dict(enumerate(["Male", "Female"]))
race_map   = dict(enumerate(["White", "Black", "Asian", "Indian", "Other"]))

plot_random_samples(images, labels, gender_map=gender_map, race_map=race_map)
plot_distribution_charts(df)

In [None]:
# -------------------------------------------------
# 3. Memmap preprocessing
# -------------------------------------------------
target_size = (192, 192)          # (width, height), as used by cv2.resize
memmap_dir  = "./data/memmap"
size_tag    = f"{target_size[0]}x{target_size[1]}"
X_path = os.path.join(memmap_dir, f"X_resized_{size_tag}.dat")
y_path = os.path.join(memmap_dir, f"y_resized_{size_tag}.dat")

n_samples = len(images)

# A cache written for a different number of images (e.g. after the image
# folder changed) would otherwise be silently reused and misread. Detect it
# by byte size: float16 images (n, h, w, 3) and float32 ages (n,).
expected_X_bytes = n_samples * target_size[1] * target_size[0] * 3 * np.dtype(np.float16).itemsize
expected_y_bytes = n_samples * np.dtype(np.float32).itemsize
cache_ok = (os.path.exists(X_path) and os.path.exists(y_path)
            and os.path.getsize(X_path) == expected_X_bytes
            and os.path.getsize(y_path) == expected_y_bytes)

if not cache_ok:
    for stale_path in (X_path, y_path):
        if os.path.exists(stale_path):
            print(f"Removing stale memmap cache: {stale_path}")
            os.remove(stale_path)
    preprocess_images_to_memmap(images, labels,
                                target_size=target_size,
                                save_X_path=X_path,
                                save_y_path=y_path)
else:
    print("Memmap files already exist – skipping creation.")

preproc = load_and_split_from_memmap(
    X_path=X_path, y_path=y_path, n_samples=n_samples,
    target_size=target_size, normalize_01=True)

# unpack
rng           = preproc["rng"]
NORMALIZE_01  = preproc["NORMALIZE_01"]
AGE_BINS      = preproc["AGE_BINS"]
bins_train    = preproc["bins_train"]
bins_val      = preproc["bins_val"]
bins_test     = preproc["bins_test"]
X_train, X_val, X_test = preproc["X_train"], preproc["X_val"], preproc["X_test"]
y_train, y_val, y_test = preproc["y_train"], preproc["y_val"], preproc["y_test"]
y_train_std = preproc["y_train_std"]
y_val_std   = preproc["y_val_std"]
y_test_std  = preproc["y_test_std"]
age_mean    = preproc["age_mean"]
age_std     = preproc["age_std"]

# visualise stratified splits
plot_age_bin_distribution(bins_train, bins_val, bins_test, AGE_BINS)

In [None]:
# -------------------------------------------------
# 4. Augmentation seed
# -------------------------------------------------
# Hand the split's seeded Generator and normalisation flag to the augmentation code.
set_augment_seed(random_generator=rng, normalize=NORMALIZE_01)

In [None]:
# -------------------------------------------------
# 5. Sample weighting
# -------------------------------------------------
# Inverse-frequency weights per age bin, so rare age groups are not drowned out.
train_weights = make_sample_weights(bins_train)
print("Example weights (first 10):", train_weights[:10])
plot_avg_sample_weight_per_bin(train_weights, bins_train, AGE_BINS=AGE_BINS)

In [None]:
# -------------------------------------------------
# 6. Batch generators
# -------------------------------------------------
batch_size = 16
train_gen = train_batch_generator(X_train, y_train_std,
                                  batch_size=batch_size,
                                  weights=train_weights)
val_gen   = val_batch_iterator(X_val, y_val_std,
                               batch_size=batch_size)

# Preview one augmented batch. The generator shuffles, so this batch is NOT
# y_train[:batch_size]; take the raw ages from the batch's own standardised
# targets (rounded, because the originals are whole years).
dbg_batch = next(train_gen)
dbg_X, dbg_y_std = dbg_batch[0], dbg_batch[1]
dbg_y_raw = np.rint(np.asarray(dbg_y_std, dtype=np.float64) * age_std + age_mean)
plot_augmented_samples(dbg_X, dbg_y_raw)

In [None]:
# -------------------------------------------------
# 7. Build & compile model
# -------------------------------------------------
# target_size is (width, height) and the memmap rows are (height, width, 3),
# so the input shape is (height, width, 3). This makes no difference for
# square sizes, but is required for non-square ones.
model = build_model(input_shape=(target_size[1], target_size[0], 3))
model.summary()

In [None]:
# -------------------------------------------------
# 8. Train
# -------------------------------------------------
history = train_model(
    model, train_gen, val_gen,
    y_train_std=y_train_std,
    y_val_std=y_val_std,
    batch_size=batch_size,
    epochs=50,
)

plot_training_history(history)

In [None]:
# -------------------------------------------------
# 9. Test evaluation (standardised + raw)
# -------------------------------------------------
test_metrics = evaluate_model_on_test(
    model, X_test, y_test_std,
    age_mean=age_mean, age_std=age_std,
    history=history, verbose=True,
)

print("\nFinal test metrics dict:")
print(test_metrics)

Evaluation

In [None]:
# ---------- evaluation configuration ----------
IMAGE_DIR, MODEL_PATH, OUTPUT_DIR = "data/images", "best_model/best_model.keras", "evaluation"
# NOTE(review): differs from the training seed (42) and overwrites the
# global RANDOM_SEED from here on.
RANDOM_SEED  = 84
TARGET_SIZE  = (192,192)      # (width, height); must match the memmap that is loaded
NORMALIZE_01 = True
BATCH_SIZE   = 32
AGE_BINS     = [0,5,12,18,30,45,60,80,200]
SHOW_WORST_N = 8

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------- logging tee ----------
log_path = os.path.join(OUTPUT_DIR, "run_log.txt")

# If this cell already ran in this kernel, undo that run first. Otherwise
# the "original" stdout captured below would be the old Tee, every line
# would be written twice, and the previous log handle would leak.
_previous_cleanup = globals().get("_cleanup")
if callable(_previous_cleanup):
    _previous_cleanup()
    atexit.unregister(_previous_cleanup)

_log_file = open(log_path, "w", encoding="utf-8")

class Tee:
    """File-like object that forwards every write/flush to several streams.

    Each stream is best-effort: an error on one (e.g. a closed log file) is
    ignored, so printing to the console keeps working.
    """
    def __init__(self, *streams):
        self.streams = streams

    def write(self, data):
        for s in self.streams:
            try:
                s.write(data)
            except Exception:
                pass
        return len(data)

    def flush(self):
        for s in self.streams:
            try:
                s.flush()
            except Exception:
                pass

    def __getattr__(self, name):
        # Delegate anything else (encoding, isatty, fileno, ...) to the
        # primary stream, so code that inspects sys.stdout keeps working.
        if name == "streams":
            raise AttributeError(name)
        return getattr(self.streams[0], name)

_original_stdout = sys.stdout
_original_stderr = sys.stderr
sys.stdout = Tee(_original_stdout, _log_file)
sys.stderr = Tee(_original_stderr, _log_file)

def _cleanup():
    """Restore the real stdout/stderr and close the log file; safe to call more than once."""
    sys.stdout = _original_stdout
    sys.stderr = _original_stderr
    if not _log_file.closed:
        _log_file.flush()
        _log_file.close()
atexit.register(_cleanup)

logger = logging.getLogger("evaluate_logger")
logger.setLevel(logging.INFO)
# Guard so re-running the cell does not attach duplicate handlers.
if not logger.handlers:
    fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s",
                            "%Y-%m-%d %H:%M:%S")
    fh = logging.FileHandler(log_path, encoding="utf-8")
    # Log to the real stdout, not the Tee, so messages are not also copied into _log_file.
    ch = logging.StreamHandler(_original_stdout)
    for handler in (fh, ch):
        handler.setFormatter(fmt)
        logger.addHandler(handler)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
sns.set_theme(style="whitegrid")
logger.info("Evaluation start – log → %s", log_path)

# ---------- helper ----------
def parse_filename_meta(fname):
    """
    Parse age, gender and race from a UTKFace-style file name
    (``age_gender_race_...``).

    Returns a dict with keys filename (as given), age, gender and race
    (ints). Returns None if there are fewer than three ``_``-separated fields
    or the first three are not integers. The values are not range-checked.
    """
    base = os.path.splitext(fname)[0]
    parts = base.split("_")
    if len(parts) < 3:
        return None
    try:
        age, gender, race = int(parts[0]), int(parts[1]), int(parts[2])
    except ValueError:
        # only parse failures mean "not a labelled file"; a bare except
        # would also swallow KeyboardInterrupt/SystemExit
        return None
    return {"filename": fname, "age": age, "gender": gender, "race": race}

# ---------- 1. load model ----------
logger.info("Loading model %s", MODEL_PATH)
model = load_model(MODEL_PATH)
# One dummy forward pass so the model is fully built before real predictions.
_dummy = np.zeros((1, TARGET_SIZE[1], TARGET_SIZE[0], 3), dtype=np.float32)
if NORMALIZE_01:
    _dummy /= 255.0
model.predict(_dummy, verbose=0)

# ---------- 2. load test memmap ----------
# NOTE: pickle can execute code on load – only open split files you created.
with open("dataset_split/dataset_split_info.pkl", "rb") as f:
    split_info = pickle.load(f)
idx_test, age_mean, age_std = (split_info[k] for k in ("idx_test", "age_mean", "age_std"))

size_tag = f"{TARGET_SIZE[0]}x{TARGET_SIZE[1]}"
X_path = f"data/memmap/X_resized_{size_tag}.dat"
y_path = f"data/memmap/y_resized_{size_tag}.dat"

# the three splits partition the memmap rows, so their sizes add up to n
n_samples = sum(len(split_info[k]) for k in ("idx_train", "idx_val", "idx_test"))
X_all = np.memmap(X_path, dtype=np.float16, mode="r",
                  shape=(n_samples, TARGET_SIZE[1], TARGET_SIZE[0], 3))
y_all = np.memmap(y_path, dtype=np.float32, mode="r", shape=(n_samples,))

X_test = X_all[idx_test].astype(np.float32)
y_real = np.array(y_all[idx_test], dtype=np.float32)
y_std  = (y_real - age_mean) / age_std
logger.info("Test set loaded – %d images", X_test.shape[0])

# reconstruct metadata (for CSV & bias analysis)
# NOTE(review): idx_test indexes memmap rows. Those rows were written in the
# shuffled, filtered order produced by load_utkface_dataset, not in sorted
# directory order. The lookup below is therefore only right if those orders
# coincide. The age cross-check at the end reports a mismatch instead of
# silently producing a misaligned CSV.
_sorted_files = sorted(os.listdir(IMAGE_DIR))   # list once, not once per test image
_placeholder = {"filename": None, "age": np.nan, "gender": np.nan, "race": np.nan}
meta_list = []
for i in idx_test:
    meta = parse_filename_meta(_sorted_files[i]) if i < len(_sorted_files) else None
    # one row per test image, so meta_df stays aligned with y_real / y_pred
    meta_list.append(meta if meta else dict(_placeholder))
meta_df = pd.DataFrame(meta_list)

if not meta_df.empty:
    _age_agrees = np.isclose(meta_df["age"].to_numpy(dtype=float), y_real)
    if not _age_agrees.all():
        logger.warning("Filename age disagrees with the memmap label for %d/%d test images – "
                       "metadata columns (filename/gender/race) are likely misaligned; "
                       "do not use them for bias analysis.",
                       int((~_age_agrees).sum()), len(_age_agrees))

# ---------- 3. predict ----------
logger.info("Predicting …")
y_pred_std = model.predict(X_test, batch_size=BATCH_SIZE, verbose=1).flatten()
y_pred = y_pred_std * age_std + age_mean      # back to years

mae = mean_absolute_error(y_real, y_pred)
mse = mean_squared_error(y_real, y_pred)
rmse = math.sqrt(mse)
r2  = r2_score(y_real, y_pred)

logger.info("=== Metrics (real ages) ===")
for _fmt, _value in (("MAE  : %.4f years", mae), ("MSE  : %.4f", mse),
                     ("RMSE : %.4f", rmse), ("R²   : %.4f", r2)):
    logger.info(_fmt, _value)

# ---------- 4. CSV ----------
# One row per test image: metadata + true/predicted ages (years and standardised) + errors.
results_df = meta_df.assign(
    true_age_real=y_real,
    true_age_std=y_std,
    pred_age_real=y_pred,
    pred_age_std=y_pred_std,
    error_real=y_pred - y_real,
    abs_error_real=lambda d: np.abs(d["error_real"]),
)
csv_path = os.path.join(OUTPUT_DIR, "test_predictions.csv")
results_df.to_csv(csv_path, index=False)
logger.info("Saved predictions → %s", csv_path)

# ---------- 5. Plot helpers ----------
def savefig(pth):
    """Save the current figure to ``pth`` (150 dpi, tight bbox), log it and close the figure."""
    plt.savefig(pth, bbox_inches="tight", dpi=150)
    logger.info("Saved plot → %s", pth)
    plt.close()

# ---- scatter ----
# Square axis range 0 .. (largest true/predicted age + 5); also reused by the hexbin plot.
lims = [0, max(y_real.max(), y_pred.max()) + 5]
fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(y_real, y_pred, s=20, alpha=0.5, edgecolor="w")
ax.plot(lims, lims, "k--", alpha=0.7, label="Perfect")  # identity line = perfect prediction
ax.set_xlim(lims)
ax.set_ylim(lims)
ax.set_xlabel("True Age")
ax.set_ylabel("Predicted Age")
ax.set_title("Predicted vs True Age")
ax.legend()
ax.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "pred_vs_true_scatter.png"))

# ---- hexbin ----
# Density view of the same predicted-vs-true pairs; hexes with zero samples are hidden.
fig, ax = plt.subplots(figsize=(7, 7))
hb = ax.hexbin(y_real, y_pred, gridsize=50, cmap="Reds", mincnt=1)
cb = fig.colorbar(hb)
cb.set_label("Counts")
ax.plot(lims, lims, "k--", alpha=0.7)
ax.set_xlim(lims)
ax.set_ylim(lims)
ax.set_xlabel("True Age")
ax.set_ylabel("Predicted Age")
ax.set_title("Predicted vs True (density)")
ax.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "pred_vs_true_density.png"))

# ---- residuals ----
# Signed error per image (prediction minus truth): positive = over-estimate.
residuals = y_pred - y_real
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(y_real, residuals, alpha=0.5, s=20, edgecolor="w")
ax.axhline(0, color="k", linestyle="--", alpha=0.6)
ax.set_xlabel("True Age")
ax.set_ylabel("Residual (Pred-True)")
ax.set_title("Residuals vs True Age")
ax.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "residuals_vs_age.png"))

# ---- residual histogram ----
# Distribution of the signed residuals, with a KDE overlay.
fig, ax = plt.subplots(figsize=(8, 5))
sns.histplot(residuals, bins=40, kde=True, color="#FF8A65", ax=ax)
ax.set_xlabel("Residual")
ax.set_ylabel("Count")
ax.set_title("Residual Distribution")
ax.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "residual_histogram.png"))

# ---- MAE per age bin ----
n_age_bins = len(AGE_BINS) - 1
# With right=False, np.digitize returns i+1 for AGE_BINS[i] <= age < AGE_BINS[i+1],
# hence the -1 to get a 0-based bin index.
bin_idx = np.digitize(y_real, AGE_BINS, right=False) - 1
# The last bin is labelled open-ended ("<edge>+"), but digitize sends ages at or
# above AGE_BINS[-1] to index n_age_bins, i.e. outside every bin. Fold them into
# the last bin so its bar covers what its label promises.
bin_idx = np.where(bin_idx >= n_age_bins, n_age_bins - 1, bin_idx)
bin_labels = [f"{AGE_BINS[i]}–{AGE_BINS[i+1]-1}" if i<len(AGE_BINS)-2 else f"{AGE_BINS[i]}+"
              for i in range(len(AGE_BINS)-1)]
abs_err = np.abs(y_pred - y_real)  # per-image absolute error, computed once
# Empty bins get NaN so they show as missing bars rather than a misleading 0.
bin_mae = [np.mean(abs_err[bin_idx == i]) if np.any(bin_idx == i) else np.nan
           for i in range(n_age_bins)]

plt.figure(figsize=(10,5))
sns.barplot(x=bin_labels, y=bin_mae, palette="rocket")
plt.xlabel("Age Bin"); plt.ylabel("MAE (years)")
plt.title("MAE per Age Bin")
plt.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "mae_per_age_bin.png"))

# ---- worst predictions ----
# Row positions (shared by results_df and X_test) of the SHOW_WORST_N largest
# absolute errors, largest first.
worst_idx = np.argsort(-results_df["abs_error_real"].values)[:SHOW_WORST_N]
grid_cols = (SHOW_WORST_N + 1) // 2  # two rows of thumbnails
plt.figure(figsize=(12, 6))
for panel_no, row in enumerate(worst_idx, start=1):
    ax = plt.subplot(2, grid_cols, panel_no)
    # Clip only when NORMALIZE_01 is set.
    # NOTE(review): otherwise pixel values are shown as stored — confirm range.
    face = np.clip(X_test[row], 0, 1) if NORMALIZE_01 else X_test[row]
    ax.imshow(face)
    true = results_df.at[row, "true_age_real"]
    pred = results_df.at[row, "pred_age_real"]
    err = results_df.at[row, "abs_error_real"]
    ax.set_title(f"T:{true:.1f} / P:{pred:.1f}\nΔ={err:.1f}", fontsize=10)
    ax.axis("off")
plt.suptitle("Hardest Predictions", weight="bold")
savefig(os.path.join(OUTPUT_DIR, "hardest_predictions.png"))

# ---- summary txt ----
# Plain-text copy of the headline test metrics (real-age scale).
summary_path = os.path.join(OUTPUT_DIR, "results_summary.txt")
summary_lines = [
    "=== Test Set Summary (Real Ages) ===\n",
    f"Images : {len(y_real)}\n",
    f"MAE    : {mae:.6f} years\n",
    f"MSE    : {mse:.6f}\n",
    f"RMSE   : {rmse:.6f}\n",
    f"R²     : {r2:.6f}\n",
]
with open(summary_path, "w", encoding="utf-8") as f:
    f.writelines(summary_lines)
logger.info("Summary → %s", summary_path)

# ---- demographic bias ----
gender_map = {0:"Male",1:"Female"}
race_map   = {0:"White",1:"Black",2:"Asian",3:"Indian",4:"Other"}
# results_df was built from meta_df.copy(), so gender/race columns parsed from the
# filenames are already on it. If the parser did not produce one of them, add an
# all-NaN column so the string mapping and group-bys below still run instead of
# raising KeyError. Such rows end up under the label "nan".
demo_df = results_df.copy()
for demo_col in ("gender", "race"):
    if demo_col not in demo_df.columns:
        demo_df[demo_col] = np.nan
# Map numeric codes to readable names; unmapped codes fall back to their string form.
demo_df["gender_str"] = demo_df["gender"].map(gender_map).fillna(demo_df["gender"].astype(str))
demo_df["race_str"]   = demo_df["race"].map(race_map).fillna(demo_df["race"].astype(str))

def _report_group_errors(df, group_col, stats_csv, plot_png, *, figsize, palette,
                         title, xlabel, xtick_rotation=None):
    """Summarise and plot absolute age error per demographic group.

    Groups ``df`` by ``group_col`` and computes three columns from
    ``abs_error_real``: MAE (mean), RMSE (square root of the mean squared error)
    and Count. The table is written to ``OUTPUT_DIR/stats_csv``. A bar chart of
    MAE per group is saved to ``OUTPUT_DIR/plot_png``.

    Returns:
        The per-group stats DataFrame with columns ``group_col``, MAE, RMSE, Count.
    """
    stats = df.groupby(group_col).agg(
        MAE=("abs_error_real", "mean"),
        RMSE=("abs_error_real", lambda x: math.sqrt(np.mean(x**2))),
        Count=("abs_error_real", "count"),
    ).reset_index()
    stats.to_csv(os.path.join(OUTPUT_DIR, stats_csv), index=False)

    plt.figure(figsize=figsize)
    sns.barplot(data=stats, x=group_col, y="MAE", palette=palette)
    plt.title(title); plt.xlabel(xlabel); plt.ylabel("MAE (years)")
    if xtick_rotation is not None:
        plt.xticks(rotation=xtick_rotation)
    savefig(os.path.join(OUTPUT_DIR, plot_png))
    return stats


gender_stats = _report_group_errors(
    demo_df, "gender_str", "gender_error_stats.csv", "mae_by_gender.png",
    figsize=(6, 5), palette=["#F9C5D5", "#F48FB1"],
    title="MAE by Gender", xlabel="Gender",
)
race_stats = _report_group_errors(
    demo_df, "race_str", "race_error_stats.csv", "mae_by_race.png",
    figsize=(8, 5), palette="mako",
    title="MAE by Race", xlabel="Race", xtick_rotation=15,
)

# ---- calibration per bin ----
# Use the AGE_BINS edges and bin_labels, but with the top edge replaced by +inf.
# With right=False, pd.cut otherwise leaves ages >= AGE_BINS[-1] as NaN (and the
# group-by drops them), even though the last label is open-ended ("<edge>+").
calib_edges = list(AGE_BINS[:-1]) + [np.inf]
demo_df["age_bin"] = pd.cut(demo_df["true_age_real"], bins=calib_edges,
                            include_lowest=True, right=False,
                            labels=bin_labels)
# observed=False keeps empty bins (their means are NaN and are simply not drawn).
calib = demo_df.groupby("age_bin", observed=False).agg(
    true_mean=("true_age_real","mean"),
    pred_mean=("pred_age_real","mean"),
    count=("true_age_real","count")
).reset_index()

max_age = max(calib["true_mean"].max(), calib["pred_mean"].max()) + 5
plt.figure(figsize=(7,6))
# Marker size and colour both encode how many samples fall in the bin.
calib_points = plt.scatter(calib["true_mean"], calib["pred_mean"],
                           s=np.clip(calib["count"]*2,10,300),
                           c=calib["count"], cmap="Reds", alpha=0.8,
                           edgecolors="w", linewidths=0.5)
plt.plot([0,max_age],[0,max_age],"k--",alpha=0.6)
plt.xlabel("Mean True Age (bin)"); plt.ylabel("Mean Predicted Age (bin)")
plt.title("Calibration per Age Bin")
cb = plt.colorbar(calib_points); cb.set_label("Samples")
plt.grid(True, linestyle="--", alpha=0.3)
savefig(os.path.join(OUTPUT_DIR, "calibration_age_bin.png"))

# Closing banner naming the directory that holds every artefact written above.
print("\n=== Evaluation finished – all files in", OUTPUT_DIR, "===\n", sep=" ", end="\n")