In [None]:
# Install the two packages this notebook needs on top of the base image, then
# enable the `%watermark` magic used after the imports.
# NOTE(review): consider `%pip install` with pinned versions so the install is
# tied to this kernel's environment and is repeatable — confirm on the target image.
!pip install -q omegaconf watermark
%load_ext watermark

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf
import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback

# Use the ggplot look for every matplotlib figure in this notebook.
plt.style.use("ggplot")
# Print NumPy floats with 4 digits of precision.
np.set_printoptions(precision=4)
# Show the versions of the imported packages in the cell output.
%watermark --iversions

In [None]:
# Kaggle-specific helper for reading secrets attached to the notebook.
from kaggle_secrets import UserSecretsClient

# Fetch the Weights & Biases API key from the secret store so it never appears
# in the notebook source, then authenticate the wandb client with it.
user_secrets = UserSecretsClient()
api_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=api_key)

In [None]:
# Notebook-wide settings: where the data lives, image dimensions, the W&B
# project/group to log to, the checkpoint directory and the sample-weight switch.
# NOTE(review): target_size is not read by any of the cells visible here
# (resizing uses a literal 64x64) — confirm whether it is still needed.
cfg = OmegaConf.create(
    {
        "data_path": "/kaggle/input/utkface-cropped/UTKFace/",
        "img_size": (200, 200),
        "target_size": (200, 200),
        "wandb_project": "UTKFace-Age-Regression",
        "wandb_group": "Baseline",
        "models_dir": "models",
        "use_sample_weight": False,
    }
)

## Model Training

In [None]:
# Hyperparameters for this training run. The whole mapping is also sent to
# Weights & Biases as the run config, so every key here is logged with the run.
model_cfg = OmegaConf.create(
    {
        # Descriptive labels (logged only).
        "architecture": "Simple CNN",
        "lr_schedule": "ExponentialDecay",
        "optimizer": "Adam",
        # Training loop.
        "epochs": 50,
        "batch_size": 32,
        "loss": "mean_absolute_error",
        # Learning-rate schedule.
        "initial_learning_rate": 1e-3,
        "decay_steps": 100000,
        "decay_rate": 0.96,
        # Early stopping.
        "early_stopping_patience": 5,
        "early_stopping_monitor": "val_mae",
        "early_stopping_mode": "min",
        # Augmentation strengths.
        "random_translation": 0.1,
        "random_rotation": 0.15,
    }
)

### Set up the training and testing datasets using the recommended `tf.data` API

In [None]:
# List every file in the dataset directory. shuffle=False keeps the file order
# deterministic from run to run (list_files otherwise shuffles with no fixed
# seed); randomisation is left to the train/test split further down.
dataset = tf.data.Dataset.list_files(cfg.data_path + "*", shuffle=False)


def process_path(file_path):
    """Load one image and parse its age label from the file name.

    Args:
        file_path: scalar string tensor with the path of a JPEG file whose base
            name starts with ``<age>_``.

    Returns:
        Tuple ``(image, label)``: the decoded 3-channel image, with its static
        height/width taken from ``cfg.img_size``, and the age as an ``int32``
        scalar tensor.
    """
    # The age is the first "_"-separated field of the file's base name.
    filename = tf.strings.split(file_path, os.sep)[-1]
    label = tf.strings.split(filename, "_")[0]
    label = tf.strings.to_number(label, out_type=tf.dtypes.int32)

    # Read and decode the image.
    raw = tf.io.read_file(file_path)
    image = tf.image.decode_jpeg(raw, channels=3)

    # decode_jpeg leaves height and width unknown in the static shape; pin them
    # so downstream layers see a fully defined shape.
    height, width = cfg.img_size
    image.set_shape([height, width, 3])
    return image, label


# Decode files in parallel; map() keeps element order deterministic by default,
# so the resulting (image, label) sequence is unchanged.
labeled_dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Sanity check: pull a single example and show what the pipeline yields.
for img, label in labeled_dataset.take(1):
    img_array, age = img.numpy(), label.numpy()
    print("Image shape: ", img_array.shape)
    print("Label: ", age)

In [None]:
# Shuffle once and split 80/20 into train and test sets. The fixed seed makes
# the shuffle itself repeatable; the split is fully reproducible only if
# labeled_dataset also yields its elements in a stable order.
train_ds, test_ds = tf.keras.utils.split_dataset(
    labeled_dataset, left_size=0.8, shuffle=True, seed=42
)
len(train_ds), len(test_ds)

### Model Architecture

In [None]:
def build_model(config):
    """Build and compile the baseline CNN age regressor.

    The network takes 64x64 RGB inputs and stacks three groups of two 3x3,
    128-filter ReLU convolutions, each followed by max pooling. Dropout (25%)
    follows the first two groups and precedes the single-unit ReLU output.

    Args:
        config: object exposing ``initial_learning_rate``, ``decay_steps``,
            ``decay_rate`` and ``loss``.

    Returns:
        The compiled ``tf.keras.Sequential`` model.

    Note:
        Whether a weighted MAE metric is attached is read from the
        notebook-level ``cfg.use_sample_weight``, not from ``config``.
    """

    def conv_pair():
        # Two identical 3x3 convolutions with 128 filters each.
        return [Conv2D(128, kernel_size=(3, 3), activation="relu") for _ in range(2)]

    stack = [Input(shape=(64, 64, 3))]
    for _ in range(2):
        stack += [*conv_pair(), MaxPooling2D(), Dropout(0.25)]
    stack += [
        *conv_pair(),
        MaxPooling2D(),
        Flatten(),
        Dropout(0.25),
        # ReLU keeps the predicted age non-negative.
        Dense(1, activation="relu"),
    ]
    model = tf.keras.Sequential(stack)

    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=config.initial_learning_rate,
        decay_steps=config.decay_steps,
        decay_rate=config.decay_rate,
    )

    # The weighted metric is only tracked when sample weights are in use.
    if cfg.use_sample_weight:
        weighted_metrics = [keras.metrics.MeanAbsoluteError(name="mae_weighted")]
    else:
        weighted_metrics = None

    model.compile(
        optimizer=Adam(learning_rate=schedule),
        loss=config.loss,
        metrics=["mae"],
        weighted_metrics=weighted_metrics,
    )
    return model

### Define Callbacks

In [None]:
# Stop training once the monitored metric has not improved for `patience`
# epochs, and roll the model back to the weights from its best epoch.
early_stopping = EarlyStopping(
    monitor=model_cfg.early_stopping_monitor,
    mode=model_cfg.early_stopping_mode,
    patience=model_cfg.early_stopping_patience,
    restore_best_weights=True,
    verbose=1,
)

In [None]:
# Implement model prediction visualization callback
class WandbClfEvalCallback(WandbEvalCallback):
    """Evaluation callback that logs age predictions to Weights & Biases.

    The first batch of the validation data is logged as the ground-truth table;
    the model's predictions for that same batch are appended to the prediction
    table each time ``add_model_predictions`` is called.

    The class keeps its ``Clf`` name for backward compatibility even though the
    logged values are regression targets.
    """

    def __init__(self, validation_data, data_table_columns, pred_table_columns):
        """Store the validation data and forward the table schemas to the base class.

        Args:
            validation_data: batched dataset yielding ``(images, labels)``. Its
                first batch must come out in the same order on every iteration,
                otherwise predictions will not line up with the ground truth.
            data_table_columns: column names for the ground-truth table.
            pred_table_columns: column names for the prediction table.
        """
        super().__init__(data_table_columns, pred_table_columns)

        self.data = validation_data

    def add_ground_truth(self, logs=None):
        """Add one row (index, image, label) per example of the first batch."""
        # TODO: sample weight support
        for images, labels in self.data.take(1).as_numpy_iterator():
            for idx, (img, label) in enumerate(zip(images, labels)):
                self.data_table.add_data(idx, wandb.Image(img), label)

    def add_model_predictions(self, epoch, logs=None):
        """Append the current predictions for the logged batch to the prediction table."""
        # Only the first batch was logged as ground truth, so only that batch
        # needs predicting; running the whole validation set here would waste a
        # full inference pass on every call.
        preds = self.model.predict(self.data.take(1), verbose=0)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            # The model has a single output unit, hence the [0].
            pred = preds[idx][0]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )

In [None]:
# Deterministic preprocessing applied to every image: resize to 64x64 and
# scale pixel values by 1/255.
resize_and_rescale = tf.keras.Sequential(
    layers=[Resizing(height=64, width=64), Rescaling(scale=1.0 / 255)]
)

# Random augmentation: rotation followed by translation, with strengths taken
# from model_cfg. The same translation factor is used on both axes.
data_augmentation = tf.keras.Sequential(
    layers=[
        RandomRotation(factor=model_cfg.random_rotation),
        RandomTranslation(
            height_factor=model_cfg.random_translation,
            width_factor=model_cfg.random_translation,
        ),
    ]
)

In [None]:
# Shorthand for tf.data's autotune sentinel; prepare() passes it as the map
# parallelism and the prefetch buffer size.
AUTOTUNE = tf.data.AUTOTUNE


def prepare(ds, shuffle=False, augment=False):
    """Turn a dataset of (image, label) pairs into batched, prefetched input.

    Steps, in order: resize and rescale every image, optionally shuffle with a
    1000-element buffer, batch by ``model_cfg.batch_size``, optionally augment
    each batch, then prefetch.

    Args:
        ds: dataset yielding ``(image, label)`` pairs.
        shuffle: shuffle the examples before batching.
        augment: apply ``data_augmentation`` to each batch (intended for the
            training set only).

    Returns:
        The transformed dataset.
    """

    def _resize_rescale(image, label):
        return resize_and_rescale(image), label

    def _augment(images, labels):
        # training=True forces the random layers to be active inside map().
        return data_augmentation(images, training=True), labels

    pipeline = ds.map(_resize_rescale, num_parallel_calls=AUTOTUNE)
    if shuffle:
        pipeline = pipeline.shuffle(1000)
    pipeline = pipeline.batch(model_cfg.batch_size)
    if augment:
        pipeline = pipeline.map(_augment, num_parallel_calls=AUTOTUNE)
    return pipeline.prefetch(buffer_size=AUTOTUNE)

In [None]:
# Training data is shuffled and augmented; test data only goes through the
# resize/rescale, batch and prefetch steps.
# NOTE: both names are rebound to the prepared pipelines, so re-running this
# cell without re-running the split cell would apply prepare() a second time.
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

### Fit the model

In [None]:
# Start the Weights & Biases run for this experiment. The hyperparameters in
# model_cfg are converted to a plain Python container and attached as the
# run config.
run = wandb.init(
    config=OmegaConf.to_object(model_cfg),
    project=cfg.wandb_project,
    group=cfg.wandb_group,
    notes="Baseline CNN Model with image augmentation (tf.data API)",
    tags=["Baseline", "Image Augmentation", "tf.data API"],
)

In [None]:
# Build and compile the CNN from the run's hyperparameters, then print its
# layer-by-layer summary.
model = build_model(model_cfg)
model.summary()

In [None]:
# W&B callbacks: metric logging, model checkpointing into cfg.models_dir, and
# the per-batch prediction table defined above.
# NOTE(review): save_best_only is not passed to WandbModelCheckpoint — confirm
# whether `monitor` is meant to select only the best checkpoint.
metrics_logger = WandbMetricsLogger()
model_checkpoint = WandbModelCheckpoint(
    cfg.models_dir, monitor=model_cfg.early_stopping_monitor
)
prediction_logger = WandbClfEvalCallback(
    validation_data=test_ds,
    data_table_columns=["idx", "image", "label"],
    pred_table_columns=["epoch", "idx", "image", "label", "pred"],
)

callbacks = [early_stopping, metrics_logger, model_checkpoint, prediction_logger]

In [None]:
%%wandb
model.fit(
    train_ds,
    epochs=model_cfg.epochs,
    validation_data=test_ds,
    callbacks=callbacks,
    use_multiprocessing=True,
    workers=4
)

In [None]:
# Close the W&B run started by wandb.init above.
run.finish()