In [None]:
import os
import pathlib
import typing

import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
RNG_SEED = 42

BATCH_SIZE = 64
INPUT_SHAPE = (32, 32, 3)  # (height, width, channels) the images are loaded at
UPSCALED_SHAPE = (224, 224, 3)  # (height, width, channels) images are resized to before the model
CLASS_COUNT = 10
VERBOSE = 1

# Root of the dataset; it must contain "train" and "test" sub-directories
# (see benchmark_model). Override it with the CIFAR10_DATASET_DIR environment
# variable instead of editing the notebook; the default is the original path.
DATASET_DIR = pathlib.Path(
    os.environ.get("CIFAR10_DATASET_DIR", "/datasets/cifar10_train_test")
)

# enable XLA auto-clustering for the whole session
tf.config.optimizer.set_jit("autoclustering")

In [None]:
def load_train_partition(
    input_shape: tuple[int, int, int],
    upscaled_shape: tuple[int, int, int],
    batch_size: int,
    preprocessing_func: typing.Callable[[tf.Tensor], tf.Tensor],
    directory: pathlib.Path,
    rng_seed: int,
) -> tf.data.Dataset:
    """Load the training partition as a shuffled, batched, prefetched dataset.

    Images are read from ``directory`` at ``input_shape`` resolution with
    categorical (one-hot) labels, cached, reshuffled every epoch, upscaled to
    ``upscaled_shape``, passed through ``preprocessing_func`` one image at a
    time, and batched.

    Args:
        input_shape: (height, width, channels) the images are loaded at.
        upscaled_shape: (height, width, channels) the images are resized to.
        batch_size: images per batch; the final partial batch is dropped.
        preprocessing_func: per-image function applied after resizing.
        directory: directory in the layout ``image_dataset_from_directory``
            expects (one sub-directory per class).
        rng_seed: seed for the shuffle order.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    raw: tf.data.Dataset = tf.keras.utils.image_dataset_from_directory(
        directory=directory,
        batch_size=None,
        image_size=(input_shape[0], input_shape[1]),
        label_mode="categorical",
        shuffle=False,
        color_mode="rgb",
    )

    resizing_layer = tf.keras.layers.Resizing(
        height=upscaled_shape[0],
        width=upscaled_shape[1],
    )

    def _resize_and_preprocess(
        image: tf.Tensor, label: tf.Tensor
    ) -> tuple[tf.Tensor, tf.Tensor]:
        # Same per-image order as before: resize first, then model preprocessing.
        return preprocessing_func(resizing_layer(image)), label

    # Cache the small decoded images, not the upscaled ones. A float32
    # 224x224x3 image is ~600 KB versus ~12 KB at 32x32x3, and the full-size
    # shuffle buffer below holds another copy of every element. Resizing and
    # preprocessing now run after the shuffle, in parallel. The shuffle
    # permutation only depends on seed and element count, so order is unchanged.
    cached = raw.cache()

    return (
        cached.shuffle(
            buffer_size=cached.cardinality().numpy(),
            seed=rng_seed,
            reshuffle_each_iteration=True,
        )
        .map(_resize_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

In [None]:
def load_non_train_partition(
    input_shape: tuple[int, int, int],
    upscaled_shape: tuple[int, int, int],
    batch_size: int,
    preprocessing_func: typing.Callable[[tf.Tensor], tf.Tensor],
    directory: pathlib.Path,
) -> tf.data.Dataset:
    """Load an evaluation partition (e.g. test) as an ordered, batched dataset.

    Images are read from ``directory`` at ``input_shape`` resolution with
    categorical (one-hot) labels, in a fixed (unshuffled) order, cached,
    upscaled to ``upscaled_shape``, passed through ``preprocessing_func`` one
    image at a time, and batched.

    Args:
        input_shape: (height, width, channels) the images are loaded at.
        upscaled_shape: (height, width, channels) the images are resized to.
        batch_size: images per batch; the final partial batch is kept so
            every image is evaluated.
        preprocessing_func: per-image function applied after resizing.
        directory: directory in the layout ``image_dataset_from_directory``
            expects (one sub-directory per class).

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    raw: tf.data.Dataset = tf.keras.utils.image_dataset_from_directory(
        directory=directory,
        batch_size=None,
        image_size=(input_shape[0], input_shape[1]),
        label_mode="categorical",
        shuffle=False,
        color_mode="rgb",
    )

    resizing_layer = tf.keras.layers.Resizing(
        height=upscaled_shape[0],
        width=upscaled_shape[1],
    )

    def _resize_and_preprocess(
        image: tf.Tensor, label: tf.Tensor
    ) -> tuple[tf.Tensor, tf.Tensor]:
        # Same per-image order as before: resize first, then model preprocessing.
        return preprocessing_func(resizing_layer(image)), label

    # Cache the small decoded images rather than the upscaled ones (~12 KB vs
    # ~600 KB per float32 image at 32x32x3 vs 224x224x3). The parallel map keeps
    # element order because tf.data maps are deterministic by default.
    return (
        raw.cache()
        .map(_resize_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=False)
        .prefetch(tf.data.AUTOTUNE)
    )

In [None]:
def benchmark_model(
    model_constructor: typing.Callable[..., tf.keras.Model],
    preprocessing_func: typing.Callable[[tf.Tensor], tf.Tensor],
) -> float:
    """Train an architecture from scratch and return its test-set accuracy.

    The model built by ``model_constructor`` is randomly initialised
    (``weights=None``) with a ``CLASS_COUNT``-way softmax head. It gets
    on-the-fly augmentation (flip / rotation / translation) and is trained
    with RAdam wrapped in Lookahead on ``DATASET_DIR / "train"`` until the
    training loss stops improving. It is then evaluated on
    ``DATASET_DIR / "test"``.

    Args:
        model_constructor: a ``tf.keras.applications`` constructor such as
            ``tf.keras.applications.ResNet50``; called with keyword arguments.
        preprocessing_func: the ``preprocess_input`` matching that model.

    Returns:
        Test accuracy as reported by ``model.evaluate``.
    """
    # Start every benchmark from a clean Keras session and identical seeds, so
    # weight initialisation and the random augmentation layers are reproducible
    # and independent of which models were benchmarked earlier in the kernel.
    tf.keras.backend.clear_session()
    tf.keras.utils.set_random_seed(RNG_SEED)

    # `pooling` is not passed: Keras applications ignore it when include_top=True.
    base_model = model_constructor(
        include_top=True,
        weights=None,
        input_tensor=None,
        input_shape=UPSCALED_SHAPE,
        classes=CLASS_COUNT,
        classifier_activation="softmax",
    )

    # The batch dimension is left unspecified on purpose: the test partition
    # keeps its final partial batch, which a fixed batch_size=BATCH_SIZE input
    # would not match.
    model_input = tf.keras.Input(shape=UPSCALED_SHAPE)

    # Random augmentation layers; Keras applies them only in training mode.
    data_aug = tf.keras.layers.RandomFlip(mode="horizontal")(model_input)
    data_aug = tf.keras.layers.RandomRotation(factor=15.0 / 360)(data_aug)  # up to +/-15 degrees
    data_aug = tf.keras.layers.RandomTranslation(height_factor=0.1, width_factor=0.1)(data_aug)

    model_output = base_model(data_aug)
    model = tf.keras.Model(inputs=model_input, outputs=model_output)

    # RAdam wrapped in Lookahead (the "Ranger" combination).
    radam = tfa.optimizers.RectifiedAdam()
    ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)

    model.compile(
        optimizer=ranger,
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    train = load_train_partition(
        input_shape=INPUT_SHAPE,
        upscaled_shape=UPSCALED_SHAPE,
        batch_size=BATCH_SIZE,
        preprocessing_func=preprocessing_func,
        directory=DATASET_DIR / "train",
        rng_seed=RNG_SEED,
    )

    test = load_non_train_partition(
        input_shape=INPUT_SHAPE,
        upscaled_shape=UPSCALED_SHAPE,
        batch_size=BATCH_SIZE,
        preprocessing_func=preprocessing_func,
        directory=DATASET_DIR / "test",
    )

    # No validation split is held out, so early stopping watches the training
    # loss and restores the weights from its best epoch. epochs=999 is an upper
    # bound; EarlyStopping is what actually ends training.
    callbacks = [tf.keras.callbacks.EarlyStopping(patience=6, monitor="loss", restore_best_weights=True)]
    model.fit(train, epochs=999, callbacks=callbacks, verbose=VERBOSE)

    # evaluate returns [loss, accuracy] in compile order; only accuracy is reported.
    _, accuracy = model.evaluate(test, verbose=VERBOSE)
    return accuracy

In [None]:
models_and_funcs = {
    tf.keras.applications.ResNet152 : tf.keras.applications.resnet.preprocess_input,
    tf.keras.applications.ResNet50 : tf.keras.applications.resnet.preprocess_input,
    tf.keras.applications.VGG16 : tf.keras.applications.vgg16.preprocess_input,
    tf.keras.applications.VGG19 : tf.keras.applications.vgg19.preprocess_input,
}

pairs = iter(models_and_funcs.items())
results = {}

In [None]:
# ResNet152

model, preprocessing_func = next(pairs)
test_accuracy = benchmark_model(model, preprocessing_func)
results[model.__name__] = test_accuracy

In [None]:
# ResNet50

model, preprocessing_func = next(pairs)
test_accuracy = benchmark_model(model, preprocessing_func)
results[model.__name__] = test_accuracy

In [None]:
# VGG16

model, preprocessing_func = next(pairs)
test_accuracy = benchmark_model(model, preprocessing_func)
results[model.__name__] = test_accuracy

In [None]:
# VGG19

model, preprocessing_func = next(pairs)
test_accuracy = benchmark_model(model, preprocessing_func)
results[model.__name__] = test_accuracy

In [None]:
# summary: one line per benchmarked model, highest test accuracy first.
# Distinct loop names avoid rebinding module-level names such as `model`.
for model_name, test_acc in sorted(results.items(), key=lambda item: item[1], reverse=True):
    print(f"{model_name:<12} {test_acc:.4f}")