In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image
from pathlib import Path
from omegaconf import OmegaConf
import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback

# Notebook-wide display settings: 4-decimal NumPy printing and the ggplot look
# for all matplotlib figures.
np.set_printoptions(precision=4)
plt.style.use("ggplot")

In [None]:
SEED = 0
# Seed Python's `random`, NumPy and TensorFlow in a single call (equivalent to
# random.seed + np.random.seed + tf.random.set_seed), so that every source of
# randomness used in this notebook is reproducible. Previously Python's
# `random` module was left unseeded.
tf.keras.utils.set_random_seed(SEED)

### W&B Setup

In Kaggle, go to **Add-ons → Secrets**, add your W&B API key under the name `WANDB_API_KEY`, and tick the checkbox under **Attach to Notebook**.

In [None]:
# Authenticate with Weights & Biases before any run is started.
# NOTE(review): called without `key=`, so this relies on wandb finding the key on
# its own (e.g. via the WANDB_API_KEY environment variable) or prompting; confirm
# the Kaggle secret described above is actually exposed to it.
wandb.login()

In [None]:
# Data and experiment-tracking settings.
cfg = OmegaConf.create(
    {
        # Files are globbed as data_path + "*", so keep the trailing separator.
        "data_path": "/kaggle/input/utkface-cropped/UTKFace/",
        "img_size": (200, 200),  # NOTE(review): not referenced in the cells shown
        "target_size": (224, 224),  # (height, width) every image is resized to
        "n_channels": 3,
        "wandb_project": "UTKFace-Age-Regression",
        "wandb_group": "EfficientNet",
        "models_dir": "models",  # prefix for W&B model checkpoint paths
        "use_sample_weight": False,
        "use_tensorboard": False,  # TensorBoard callback instead of the W&B Keras callbacks
    }
)

## Model Training

In [None]:
# Model / training hyper-parameters (also logged to W&B as the run config).
model_cfg = OmegaConf.create(
    {
        # Backbone
        "architecture": "EfficientNetV2B0",
        # Optimisation
        "epochs": 100,
        "batch_size": 32,
        "lr_schedule": "ExponentialDecay",
        "initial_learning_rate": 1e-3,
        "decay_steps": 100000,
        "decay_rate": 0.96,
        "loss": "mean_absolute_error",
        "optimizer": "Adam",
        # Early stopping
        "early_stopping_patience": 5,
        "early_stopping_monitor": "val_mae",
        "early_stopping_mode": "min",
        # Augmentation strengths
        "random_translation": 0.1,
        "random_rotation": 0.1,
        "random_flip": "horizontal",
        # Pipeline switches
        "resize_and_rescale": False,  # not needed for efficient net
        "augment": False,  # augment in the tf.data pipeline
        "augment_gpu": False,  # augment inside the model (on the GPU)
    }
)

### Set up the training and test datasets with the recommended `tf.data` API

In [None]:
# List the image files in a FIXED random order.
# `list_files` shuffles by default and reshuffles on every iteration, so each
# pass over the data would yield the files in a different order. A positional
# take()/skip() train/test split would then mix train and test images across
# epochs, and separate passes (e.g. labels vs. predictions) would not line up.
# Instead, list deterministically and shuffle exactly once with a fixed seed.
# A random order is still needed because the age is the filename prefix, and
# the deterministic listing order may follow the file names.
dataset = tf.data.Dataset.list_files(cfg.data_path + "*", shuffle=False)
dataset = dataset.shuffle(
    buffer_size=dataset.cardinality(),  # full buffer: shuffling paths is cheap
    seed=SEED,
    reshuffle_each_iteration=False,
)


def process_path(file_path):
    """Load one image and its age label from a file path.

    The age is the first ``_``-separated field of the file name.

    Args:
        file_path: scalar string tensor with the path to a JPEG image.

    Returns:
        ``(image, label)``: ``image`` is a float32 tensor of shape
        ``(*cfg.target_size, cfg.n_channels)`` with pixel values not rescaled,
        and ``label`` is an int32 scalar.
    """
    # Read the age from the file name (last path component).
    filename = tf.strings.split(file_path, os.sep)[-1]
    label = tf.strings.to_number(
        tf.strings.split(filename, "_")[0], out_type=tf.dtypes.int32
    )

    # Read, decode and resize the image.
    raw = tf.io.read_file(file_path)
    image = tf.image.decode_jpeg(raw, channels=cfg.n_channels)
    image = tf.image.resize(image, [*cfg.target_size])
    # Pin the static shape so downstream Keras layers see a fully defined input.
    image.set_shape([*cfg.target_size, cfg.n_channels])
    return image, label


# Parallel map; tf.data keeps the element order deterministic by default.
labeled_dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Sanity check: shape of one decoded image and its age label.
sample_image, sample_label = next(iter(labeled_dataset.take(1)))
print("Image shape: ", sample_image.numpy().shape)
print("Label: ", sample_label.numpy())

In [None]:
# 3x3 grid of raw samples, each titled with its age label.
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for ax, (image, label) in zip(axes.flat, labeled_dataset.take(9)):
    ax.imshow(image.numpy().astype("int32"))
    ax.set_title(int(label))
    ax.axis("off")
plt.tight_layout()

In [None]:
def train_test_split(ds: "tf.data.Dataset", split: float = 0.8, shuffle_buffer: int = 1000):
    """Split ``ds`` into disjoint train and test datasets.

    The first ``int(len(ds) * split)`` elements become the training set and the
    rest the test set. Only the training part is shuffled, and only *after* the
    split, so no test element can end up in the training set.

    The split is positional, so it is a true hold-out split only if ``ds``
    yields its elements in the same order on every iteration (e.g. no upstream
    ``shuffle`` with ``reshuffle_each_iteration=True``).

    Note: this shadows ``sklearn.model_selection.train_test_split`` imported at
    the top of the notebook.

    Args:
        ds: dataset with a known, finite cardinality (``len(ds)`` must work).
        split: fraction of elements used for training, strictly between 0 and 1.
        shuffle_buffer: buffer size used to shuffle the training part.

    Returns:
        ``(train_ds, test_ds)``.

    Raises:
        ValueError: if ``split`` is not strictly between 0 and 1.
    """
    if not 0.0 < split < 1.0:
        raise ValueError(f"split must be strictly between 0 and 1, got {split}")

    n_total = len(ds)
    train_size = int(n_total * split)
    test_size = n_total - train_size

    # Split first, then shuffle. Shuffling before take() draws training samples
    # from beyond the split point, i.e. from the test set.
    train_ds = ds.take(train_size).shuffle(shuffle_buffer)
    test_ds = ds.skip(train_size)
    print(f"Train size: {train_size}")
    print(f"Test size: {test_size}")
    return train_ds, test_ds


# 80 % of the samples for training, the remaining 20 % for evaluation.
train_ds, test_ds = train_test_split(ds=labeled_dataset, split=0.8)

### Model Architecture

In [None]:
# Optional tf.data preprocessing (enabled by model_cfg.resize_and_rescale):
# resize to the model's input resolution and scale pixels to [0, 1].
# The target is cfg.target_size (previously a hard-coded 64x64, which did not
# match the model input built from cfg.target_size).
# NOTE(review): Keras' EfficientNetV2 models rescale inputs themselves by
# default, so the extra Rescaling is presumably meant for other backbones.
resize_and_rescale = tf.keras.Sequential(
    [Resizing(*cfg.target_size), Rescaling(1.0 / 255)],
    name="resize_and_rescale",
)

# Random geometric augmentations. They are applied either in the tf.data
# pipeline (model_cfg.augment) or inside the model (model_cfg.augment_gpu).
# Further augmentations (e.g. RandomBrightness) can be appended to the list.
data_augmentation = tf.keras.Sequential(
    [
        RandomRotation(factor=model_cfg.random_rotation),
        RandomTranslation(
            height_factor=model_cfg.random_translation,
            width_factor=model_cfg.random_translation,
        ),
        RandomFlip(mode=model_cfg.random_flip),
    ],
    name="data_augmentation",
)

In [None]:
def build_model(config):
    """Build and compile the transfer-learning age regressor.

    A frozen ImageNet backbone from ``keras.applications`` (chosen by name via
    ``config.architecture``) is followed by global average pooling, batch norm,
    dropout and a single ReLU output unit (presumably ReLU because ages are
    non-negative).

    Args:
        config: model config providing ``architecture``, ``augment_gpu``,
            ``initial_learning_rate``, ``decay_steps``, ``decay_rate`` and ``loss``.

    Returns:
        ``(model, base_model)``: the compiled model and its frozen backbone,
        returned separately so the backbone can be unfrozen for fine-tuning.

    Raises:
        ValueError: if ``config.architecture`` is not in ``keras.applications``.
    """
    inputs = keras.Input(shape=(*cfg.target_size, cfg.n_channels))

    # Look the backbone up by name so that `config.architecture` (logged to W&B)
    # actually determines the model instead of being ignored.
    try:
        backbone_cls = getattr(keras.applications, config.architecture)
    except AttributeError as err:
        raise ValueError(
            f"Unknown keras.applications architecture: {config.architecture!r}"
        ) from err
    base_model = backbone_cls(
        include_top=False,
        input_tensor=inputs,
        weights="imagenet",
    )
    base_model.trainable = False

    # Optionally augment inside the model, i.e. on the GPU.
    x = data_augmentation(inputs) if config.augment_gpu else inputs

    # training=False is very important if we unfreeze the base_model later on.
    # For the difference between training=False in the call and the trainable
    # attribute, see
    # https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute
    x = base_model(x, training=False)

    # Convert features of shape `base_model.output_shape[1:]` to vectors.
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    outputs = keras.layers.Dense(1, activation="relu")(x)
    model = keras.Model(inputs, outputs)

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=config.initial_learning_rate,
        decay_steps=config.decay_steps,
        decay_rate=config.decay_rate,
    )
    metrics = ["mae"]
    weighted_metrics = (
        [keras.metrics.MeanAbsoluteError(name="mae_weighted")]
        if cfg.use_sample_weight
        else None
    )
    model.compile(
        loss=config.loss,
        optimizer=Adam(learning_rate=lr_schedule),
        metrics=metrics,
        weighted_metrics=weighted_metrics,
    )
    return model, base_model

### Define Callbacks

In [None]:
# Stop once the monitored validation metric stops improving, and roll the model
# back to the weights of the best epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor=model_cfg.early_stopping_monitor,
    mode=model_cfg.early_stopping_mode,
    patience=model_cfg.early_stopping_patience,
    restore_best_weights=True,
    verbose=1,
)

In [None]:
# Implement model prediction visualization callback
class WandbClfEvalCallback(WandbEvalCallback):
    """Evaluation callback that logs a few predictions to Weights & Biases.

    Despite the (kept-for-compatibility) name, the task here is regression.
    When training starts, it logs the first ``n_samples`` images and labels of
    the first batch of ``validation_data``. After every epoch, it logs the
    model's predictions for exactly those images.
    """

    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, n_samples=8
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.data = validation_data

        if n_samples > model_cfg.batch_size:
            raise ValueError("n_samples must not exceed the batch size.")
        self.n_samples = n_samples
        # Images whose ground truth was logged. Predicting on these exact arrays
        # keeps every prediction row aligned with its ground-truth row, even if
        # `validation_data` yields a different first batch on each iteration.
        self._images = None

    def add_ground_truth(self, logs=None):
        """Log index, image and label of the first ``n_samples`` validation samples."""
        # TODO: sample weight support
        for images, labels in self.data.take(1).as_numpy_iterator():
            self._images = images[: self.n_samples]
            # zip stops after n_samples because self._images is truncated.
            for idx, (img, label) in enumerate(zip(self._images, labels)):
                self.data_table.add_data(idx, wandb.Image(img), label)

    def add_model_predictions(self, epoch, logs=None):
        """Log this epoch's prediction for every logged ground-truth row."""
        if self._images is None:  # no ground truth was logged
            return
        preds = self.model.predict(self._images, verbose=0)

        for idx in self.data_table_ref.get_index():
            row = self.data_table_ref.data[idx]
            self.pred_table.add_data(epoch, row[0], row[1], row[2], preds[idx][0])

#### Plot some augmented images

In [None]:
# Preview the augmentation pipeline: one training image, augmented 9 times.
for sample_image, _ in train_ds.take(1):
    fig, axes = plt.subplots(3, 3, figsize=(7, 7))
    single_batch = tf.expand_dims(sample_image, 0)
    for ax in axes.flat:
        augmented = data_augmentation(single_batch, training=True)
        ax.imshow(augmented[0].numpy().astype("int32"))
        ax.axis("off")
plt.tight_layout()

### Prepare datasets for training

In [None]:
AUTOTUNE = tf.data.AUTOTUNE


def prepare(
    ds: tf.data.Dataset,
    shuffle=False,
    augment=False,
    resize_and_rescale=True,
    batch_size=None,
    shuffle_buffer=1000,
    rescale_layer=None,
):
    """Turn an unbatched ``(image, label)`` dataset into a batched input pipeline.

    Args:
        ds: dataset of ``(image, label)`` pairs.
        shuffle: shuffle elements before batching (buffer ``shuffle_buffer``).
        augment: apply ``data_augmentation`` to every batch (training set only).
        resize_and_rescale: apply the resize/rescale preprocessing to each image.
            This boolean shadows the module-level ``resize_and_rescale`` layer of
            the same name, so the layer is looked up explicitly below. Calling
            the flag itself would raise a TypeError.
        batch_size: batch size; defaults to ``model_cfg.batch_size``.
        shuffle_buffer: shuffle buffer size.
        rescale_layer: preprocessing layer to apply; defaults to the
            module-level ``resize_and_rescale`` model.

    Returns:
        The batched, prefetched dataset.
    """
    if resize_and_rescale:
        layer = (
            rescale_layer
            if rescale_layer is not None
            else globals()["resize_and_rescale"]
        )
        ds = ds.map(lambda x, y: (layer(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(shuffle_buffer)

    # Batch all datasets.
    ds = ds.batch(model_cfg.batch_size if batch_size is None else batch_size)

    # Use data augmentation only on the training set.
    if augment:
        ds = ds.map(
            lambda x, y: (data_augmentation(x, training=True), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefetching on all datasets.
    return ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# NOTE: this cell rebinds train_ds/test_ds, so it is not idempotent. Re-running
# it would batch the already-batched datasets again; re-run the split cell first.
shared_prepare_kwargs = dict(resize_and_rescale=model_cfg.resize_and_rescale)
train_ds = prepare(
    train_ds, shuffle=True, augment=model_cfg.augment, **shared_prepare_kwargs
)
test_ds = prepare(test_ds, **shared_prepare_kwargs)

### Fit the model

In [None]:
def restore_model(
    run_id: str, version: int, entity: str = "moritzm00", project: str = None
):
    """Download a model checkpoint artifact from W&B and load it with Keras.

    Restores the artifact ``run_<run_id>_model:v<version>``. The artifact
    version does not equal the epoch in general. Use this if the kernel crashed;
    otherwise you can use ``wandb.restore(name=<model-name>)``.

    Args:
        run_id: id of the W&B run that logged the model.
        version: artifact version number.
        entity: W&B entity (user or team) that owns the project.
        project: W&B project; defaults to ``cfg.wandb_project``.

    Returns:
        The loaded ``tf.keras`` model.
    """
    project = cfg.wandb_project if project is None else project
    model_name = f"run_{run_id}_model:v{version}"
    artifact_name = f"{entity}/{project}/{model_name}"
    if wandb.run is not None:
        # Use the currently active run (not a notebook global that may be stale
        # or undefined) so the artifact is recorded as an input of that run.
        artifact = wandb.run.use_artifact(artifact_name, type="model")
    else:
        artifact = wandb.Api().artifact(artifact_name, type="model")
    # Load from the directory the artifact was actually downloaded to, instead
    # of a hard-coded Kaggle path.
    artifact_dir = artifact.download()
    return tf.keras.models.load_model(artifact_dir)

In [None]:
# A fresh id for this run. To resume an earlier run instead, put its id here
# (and consider resume="must", so W&B fails rather than starting a new run).
run_id = wandb.util.generate_id()
print("Run id is:", run_id)
resume = "allow"

# Derive tags and notes from the actual configuration, so the run metadata does
# not claim TensorBoard logging (or a backbone) that is not used.
run_tags = [model_cfg.architecture]
run_notes = f"{model_cfg.architecture} Model"
if cfg.use_tensorboard:
    run_tags.append("Tensorboard")
    run_notes += " with tensorboard logging"

run = wandb.init(
    id=run_id,
    project=cfg.wandb_project,
    group=cfg.wandb_group,
    config=OmegaConf.to_object(model_cfg),
    resume=resume,
    sync_tensorboard=cfg.use_tensorboard,
    tags=run_tags,
    notes=run_notes,
)

In [None]:
# Build the frozen-backbone model; keep a handle on the backbone for fine-tuning.
model, base_model = build_model(config=model_cfg)
model.summary()

In [None]:
# Callbacks shared by both training phases. With TensorBoard, metrics go to
# ./logs. Otherwise W&B receives the metrics, the best checkpoints (by the
# early-stopping metric) and a table of sample predictions.
if cfg.use_tensorboard:
    callbacks = [early_stopping, TensorBoard(log_dir="./logs")]
else:
    callbacks = [
        early_stopping,
        WandbMetricsLogger(),
        WandbModelCheckpoint(
            cfg.models_dir + "/model-{epoch:02d}-{val_mae:.2f}",
            monitor=model_cfg.early_stopping_monitor,
            save_best_only=True,
        ),
        WandbClfEvalCallback(
            validation_data=test_ds,
            data_table_columns=["idx", "image", "label"],
            pred_table_columns=["epoch", "idx", "image", "label", "pred"],
            n_samples=8,
        ),
    ]

In [None]:
%%wandb
model.fit(
    train_ds,
    epochs=model_cfg.epochs,
    validation_data=test_ds,
    callbacks=callbacks,
)

## Fine-tune the base model

In [None]:
%%wandb
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training

base_model.trainable = True
# or just some layers:
# for layer in model.layers[:-40]:
#     layer.trainable = True
    
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-5, # low learning rate
    decay_steps=model_cfg.decay_steps,
    decay_rate=0.99,
)
model.compile(
    optimizer=keras.optimizers.Adam(lr_schedule), 
    loss=model_cfg.loss,
    metrics=["mae"]
)
print(model.summary())
model.fit(
    train_ds,
    epochs=run.step + model_cfg.epochs,
    initial_epoch=run.step,
    validation_data=test_ds,
    callbacks=callbacks,
    use_multiprocessing=True,
    workers=4
)

In [None]:
# Mark the W&B run as finished (no further logging to it from this notebook).
run.finish()

In [None]:
# Ground-truth age of every sample in `labeled_dataset`, in iteration order.
labels = list(labeled_dataset.map(lambda image, label: label).as_numpy_iterator())

In [None]:
# Predict on the full dataset (train and test together). Labels and predictions
# are collected in the same pass, so they stay aligned even if the dataset
# order differs between iterations; this also refreshes `labels`.
pred_batches, label_batches = [], []
for image_batch, label_batch in labeled_dataset.batch(model_cfg.batch_size):
    pred_batches.append(model.predict_on_batch(image_batch))
    label_batches.append(label_batch.numpy())
preds = np.concatenate(pred_batches)
labels = np.concatenate(label_batches)

In [None]:
# Compare the distribution of predicted ages with the ground truth. Semi-transparent
# bars keep the predictions visible behind the ground-truth histogram.
fig, ax = plt.subplots()
ax.hist(preds.reshape(-1), bins=50, alpha=0.6, label="predictions")
ax.hist(labels, bins=50, alpha=0.6, label="ground truth")
ax.set(xlabel="age", ylabel="count", title="Predicted vs. true age distribution")
ax.legend()
plt.show()

In [None]:
# Shape of the prediction array (the model has a single output unit).
preds.shape

In [None]:
# One row per sample: true age and the flattened scalar prediction.
df = pd.DataFrame(data={"label": labels, "pred": preds.ravel()})

In [None]:
# Calibration check: mean prediction for every true age, against the ideal
# diagonal where prediction == label.
mean_pred_by_age = df.groupby("label")["pred"].mean()
ax = mean_pred_by_age.plot(label="mean prediction")
ax.plot(mean_pred_by_age.index, mean_pred_by_age.index, linestyle="--", label="ideal")
ax.set(xlabel="true age", ylabel="mean predicted age", title="Mean prediction per true age")
ax.legend();