# **Implementing ResNet From Scratch**

### Data Input Pipeline

In [0]:
import collections
import functools
import glob
import math
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow.keras.regularizers as regulizers
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

from tensorflow.keras.layers import (Input, Activation, Dense, Flatten, Conv2D,
                                     MaxPooling2D,GlobalAveragePooling2D, 
                                     AveragePooling2D, BatchNormalization, add)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import img_to_array, load_img

In [0]:
# Expose only the CUDA device with index 2 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [0]:
# Build the CIFAR-100 dataset through TensorFlow Datasets, then print the
# metadata object attached to the builder.
cifar_builder = tfds.builder('cifar100')
cifar_builder.download_and_prepare()

print('{}'.format(cifar_builder.info))

In [14]:
# Print the readable class names stored in the `label` feature.
print(cifar_builder.info.features["label"].names)

['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree',

In [15]:
# Input size expected by the network and training hyper-parameters:
input_shape = [224, 224, 3]
batch_size, num_epochs = 32, 300

# Dataset metadata: number of classes and number of samples in each split.
num_classes = cifar_builder.info.features['label'].num_classes
num_train_images = cifar_builder.info.splits['train'].num_examples
num_valid_images = cifar_builder.info.splits['test'].num_examples

# The TFDS "test" split is used as the validation set in this notebook:
train_cifar_dataset = cifar_builder.as_dataset(split=tfds.Split.TRAIN)
val_cifar_dataset = cifar_builder.as_dataset(split=tfds.Split.TEST)



In [0]:
# Repeat the training samples indefinitely rather than `num_epochs` times.
# Each epoch consumes `ceil(num_train_images / batch_size)` full batches, which
# is slightly more than one pass over the data whenever the batch size does not
# divide the number of images. A dataset repeated exactly `num_epochs` times
# would therefore be exhausted before the last epoch completes. The length of
# the training is bounded by `epochs`/`steps_per_epoch` passed to `fit()`.
train_cifar_dataset = train_cifar_dataset.repeat().shuffle(10000)

In [0]:
def _prepare_data_fn(features, input_shape, augment=False):
    """
    Convert a TFDS feature dictionary into an `(image, label)` tuple for Keras.
    The image is converted to float32 and resized to the expected dimensions;
    random transformations are optionally applied beforehand.
    :param features:    Feature dictionary with the keys 'image' and 'label'
    :param input_shape: Shape expected by the models (height, width, channels)
    :param augment:     Flag to apply some random augmentations to the images
    :return:            (Augmented) image, label
    """
    target_shape = tf.convert_to_tensor(input_shape)

    label = features['label']
    # Float conversion (integer pixel values are rescaled to [0., 1.]):
    image = tf.image.convert_image_dtype(features['image'], tf.float32)

    if not augment:
        # Deterministic path: plain resize to (height, width).
        return tf.image.resize(image, target_shape[:2]), label

    # Random horizontal flip:
    image = tf.image.random_flip_left_right(image)

    # Random brightness/saturation changes, clipped back to the valid range:
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Random up-scaling (same factor for both axes, drawn in [1, 1.4)),
    # followed by a random crop back to the expected size:
    zoom_factor = tf.random.uniform([1], minval=1., maxval=1.4, dtype=tf.float32)
    zoomed_size = tf.cast(tf.cast(target_shape[:2], tf.float32) * zoom_factor,
                          tf.int32)
    image = tf.image.resize(image, zoomed_size)
    image = tf.image.random_crop(image, target_shape)

    return image, label

In [0]:
# Training pre-processing: same function as for validation, with augmentation.
prepare_data_fn_for_train = functools.partial(
    _prepare_data_fn, input_shape=input_shape, augment=True)

# Pre-process each sample, group the samples into batches, and prepare one
# batch ahead while the current one is being consumed:
train_cifar_dataset = (train_cifar_dataset
                       .map(prepare_data_fn_for_train)
                       .batch(batch_size)
                       .prefetch(1))

In [0]:
# Validation pre-processing: resize only, no random augmentation.
prepare_data_fn_for_val = functools.partial(_prepare_data_fn,
                                            input_shape=input_shape)

# The validation set is repeated indefinitely; the number of batches used per
# evaluation is bounded by `validation_steps` when calling `fit()`.
val_cifar_dataset = val_cifar_dataset.repeat()
val_cifar_dataset = val_cifar_dataset.map(prepare_data_fn_for_val,
                                          num_parallel_calls=4)
val_cifar_dataset = val_cifar_dataset.batch(batch_size)
val_cifar_dataset = val_cifar_dataset.prefetch(1)

In [0]:
# Number of batches needed to cover each split once (last batch may be partial):
train_steps_per_epoch, val_steps_per_epoch = (
    math.ceil(num_images / batch_size)
    for num_images in (num_train_images, num_valid_images)
)

### ResNet Implementation

In [0]:
def _res_conv(filters,
              kernel_size=3,
              padding='same',
              strides=1,
              use_relu=True,
              use_bias=False,
              name='cbr',
              kernel_initializer='he_normal',
              kernel_regularizer=regulizers.l2(1e-4)):
    """
    Return a layer block chaining a convolution, a batch normalization and
    (optionally) a ReLU activation.
    :param filters:                 Number of filters.
    :param kernel_size:             Kernel size.
    :param padding:                 Convolution padding.
    :param strides:                 Convolution strides.
    :param use_relu:                Flag to apply ReLU activation at the end.
    :param use_bias:                Flag to use bias or not in Conv layer.
    :param name:                    Name prefix for the layers
                                    (suffixed with `_c`, `_bn`, `_r`).
    :param kernel_initializer:      Kernel initialisation method name.
    :param kernel_regularizer:      Kernel regularizer.
    :return:                        Callable layer block
    """
    conv_options = dict(filters=filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding=padding,
                        use_bias=use_bias,
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=kernel_regularizer)

    def layer_fn(x):
        out = Conv2D(name=f'{name}_c', **conv_options)(x)
        # Normalization over the last (channel) axis:
        out = BatchNormalization(axis=-1, name=f'{name}_bn')(out)
        if not use_relu:
            return out
        return Activation("relu", name=f'{name}_r')(out)

    return layer_fn

In [0]:
def _merge_with_shortcut(kernel_initializer='he_normal',
                         kernel_regularizer=regulizers.l2(1e-4),
                         name='block'):
    """
    Return a layer block which adds an input tensor to the residual tensor
    computed from it by another branch.
    :param kernel_initializer:      Kernel initialisation method name.
    :param kernel_regularizer:      Kernel regularizer.
    :param name:                    Name prefix for the shortcut convolution.
    :return:                        Callable layer block taking `(x, x_residual)`
    """

    def layer_fn(x, x_residual):
        input_dims = tf.keras.backend.int_shape(x)
        residual_dims = tf.keras.backend.int_shape(x_residual)

        shortcut = x
        if input_dims != residual_dims:
            # The residual branch changed the tensor shape: project `x` with a
            # 1x1 convolution so that both tensors can be summed. The strides
            # are the (vertical, horizontal) down-scaling ratios.
            strides = tuple(
                int(round(input_dims[axis] / residual_dims[axis]))
                for axis in (1, 2)
            )
            projection = Conv2D(filters=residual_dims[3],
                                kernel_size=(1, 1),
                                padding="valid",
                                strides=strides,
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer,
                                name=name + '_shortcut_c')
            shortcut = projection(x)

        return add([shortcut, x_residual])

    return layer_fn

In [0]:
def _residual_block_basic(filters,
                          kernel_size=3,
                          strides=1,
                          use_bias=False,
                          name='res_basic',
                          kernel_initializer='he_normal',
                          kernel_regularizer=regulizers.l2(1e-4)):
    """
    Return a basic residual layer block: two conv/batch-norm units whose
    output is summed with the block input, followed by a ReLU.
    :param filters:                 Number of filters.
    :param kernel_size:             Kernel size.
    :param strides:                 Convolution strides (1st convolution only).
    :param use_bias:                Flag to use bias or not in Conv layer.
    :param name:                    Name prefix for the layers.
    :param kernel_initializer:      Kernel initialisation method name.
    :param kernel_regularizer:      Kernel regularizer.
    :return:                        Callable layer block
    """
    # Options shared by the two convolution units of the block:
    shared_options = dict(filters=filters,
                          kernel_size=kernel_size,
                          padding='same',
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer)

    def layer_fn(x):
        first_unit = _res_conv(strides=strides, use_relu=True,
                               name=name + '_cbr_1', **shared_options)
        # No ReLU here: the activation is applied after the merge.
        second_unit = _res_conv(strides=1, use_relu=False,
                                name=name + '_cbr_2', **shared_options)
        x_residual = second_unit(first_unit(x))

        merge_fn = _merge_with_shortcut(kernel_initializer,
                                        kernel_regularizer,
                                        name=name)
        return Activation('relu')(merge_fn(x, x_residual))

    return layer_fn

In [0]:
def _residual_block_bottleneck(filters,
                               kernel_size=3,
                               strides=1,
                               use_bias=False,
                               name='res_bottleneck',
                               kernel_initializer='he_normal',
                               kernel_regularizer=regulizers.l2(1e-4)):
    """
    Return a residual layer block with bottleneck (1x1 conv, KxK conv, then
    1x1 conv with 4 times more filters), recommended for deep ResNets
    (depth > 34).

    :param filters:                 Number of filters.
    :param kernel_size:             Kernel size of the middle convolution.
    :param strides:                 Convolution strides (1st convolution only).
    :param use_bias:                Flag to use bias or not in Conv layer.
    :param name:                    Name prefix for the layers.
    :param kernel_initializer:      Kernel initialisation method name.
    :param kernel_regularizer:      Kernel regularizer.
    :return:                        Callable layer block
    """
    # Options shared by the three convolution units of the block:
    shared_options = dict(use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer)

    def layer_fn(x):
        reduce_unit = _res_conv(filters=filters, kernel_size=1,
                                padding='valid', strides=strides,
                                use_relu=True, name=name + '_cbr1',
                                **shared_options)
        conv_unit = _res_conv(filters=filters, kernel_size=kernel_size,
                              padding='same', strides=1,
                              use_relu=True, name=name + '_cbr2',
                              **shared_options)
        # No ReLU on the last unit: the activation is applied after the merge.
        expand_unit = _res_conv(filters=filters * 4, kernel_size=1,
                                padding='valid', strides=1,
                                use_relu=False, name=name + '_cbr3',
                                **shared_options)
        x_residual = expand_unit(conv_unit(reduce_unit(x)))

        merge_fn = _merge_with_shortcut(kernel_initializer,
                                        kernel_regularizer,
                                        name=name)
        return Activation('relu')(merge_fn(x, x_residual))

    return layer_fn

In [0]:
def _residual_macroblock(block_fn, filters, repetitions=3, kernel_size=3,
                         strides_1st_block=1, use_bias=False,
                         kernel_initializer='he_normal',
                         kernel_regularizer=regulizers.l2(1e-4),
                         name='res_macroblock'):
    """
    Return a layer block chaining `repetitions` residual blocks.
    :param block_fn:               Block layer method to be used.
    :param filters:                Number of filters.
    :param repetitions:            Number of times the block should be repeated.
    :param kernel_size:            Kernel size.
    :param strides_1st_block:      Convolution strides for the 1st block
                                   (the following blocks use strides of 1).
    :param use_bias:               Flag to use bias or not in Conv layer.
    :param kernel_initializer:     Kernel initialisation method name.
    :param kernel_regularizer:     Kernel regularizer.
    :param name:                   Name prefix; block `i` is named `{name}_{i}`.
    :return:                       Callable layer block
    """

    def layer_fn(x):
        output = x
        for index in range(repetitions):
            build_block = block_fn(
                filters=filters,
                kernel_size=kernel_size,
                strides=1 if index else strides_1st_block,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                name=f"{name}_{index}")
            output = build_block(output)
        return output

    return layer_fn

In [0]:
# NOTE(review): this cell defines `_residual_macroblock` a second time; the
# name was already defined in the previous cell, and this later definition is
# the one in effect afterwards. One of the two cells can likely be removed.
def _residual_macroblock(block_fn, filters, repetitions=3, kernel_size=3, 
                         strides_1st_block=1, use_bias=False,
                         kernel_initializer='he_normal', 
                         kernel_regularizer=regulizers.l2(1e-4),
                         name='res_macroblock'):
    """
    Return a layer block, composed of a repetition of `N` residual blocks.
    :param block_fn:               Block layer method to be used.
    :param repetitions:            Number of times the block should be repeated.
    :param filters:                Number of filters.
    :param kernel_size:            Kernel size.
    :param strides_1st_block:      Convolution strides for the 1st block.
    :param use_bias:               Flag to use bias or not in Conv layer.
    :param kernel_initializer:     Kernel initialisation method name.
    :param kernel_regularizer:     Kernel regularizer.
    :param name:                   Name prefix; block `i` is named `{name}_{i}`.
    :return:                       Callable layer block
    """

    def layer_fn(x):
        for i in range(repetitions):
            block_name = f"{name}_{i}" 
            # Only the first block of the macro-block may use strides != 1:
            strides = strides_1st_block if i == 0 else 1
            x = block_fn(filters=filters, kernel_size=kernel_size, 
                         strides=strides, use_bias=use_bias,
                         kernel_initializer=kernel_initializer, 
                         kernel_regularizer=kernel_regularizer,
                         name=block_name)(x)
        return x

    return layer_fn

In [0]:
def ResNet(input_shape, num_classes=1000, block_fn=_residual_block_basic,
           repetitions=(2, 2, 2, 2), use_bias=False,
           kernel_initializer='he_normal',
           kernel_regularizer=regulizers.l2(1e-4)):
    """
    Build a ResNet model for classification.
    :param input_shape:             Input shape (e.g. (224, 224, 3))
    :param num_classes:             Number of classes to predict
    :param block_fn:                Block layer method to be used.
    :param repetitions:             List of repetitions for each macro-blocks
                                    the network should contain.
    :param use_bias:                Flag to use bias or not in Conv layer.
    :param kernel_initializer:      Kernel initialisation method name.
    :param kernel_regularizer:      Kernel regularizer.
    :return:                        ResNet model.
    """
    # Stem: 7x7 strided convolution followed by a strided max-pooling.
    inputs = Input(shape=input_shape)
    stem = _res_conv(filters=64, kernel_size=7, strides=2, use_relu=True,
                     use_bias=use_bias,
                     kernel_initializer=kernel_initializer,
                     kernel_regularizer=kernel_regularizer)(inputs)
    features = MaxPooling2D(pool_size=3, strides=2, padding='same')(stem)

    # Chain of residual macro-blocks. The number of filters doubles after each
    # macro-block (capped at 1024). Every macro-block but the first one halves
    # the spatial size (the max-pooling already did so before the first one).
    filters = 64
    for stage, num_blocks in enumerate(repetitions):
        build_stage = _residual_macroblock(
            block_fn=block_fn,
            repetitions=num_blocks,
            name=f"block_{stage}",
            filters=filters,
            strides_1st_block=1 if stage == 0 else 2,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer)
        features = build_stage(features)
        filters = min(filters * 2, 1024)

    # Classification head: average pooling over the whole remaining feature
    # map, then a dense soft-max layer.
    spatial_dims = tf.keras.backend.int_shape(features)[1:3]
    pooled = AveragePooling2D(pool_size=spatial_dims, strides=1)(features)
    flattened = Flatten()(pooled)
    predictions = Dense(units=num_classes,
                        kernel_initializer=kernel_initializer,
                        activation='softmax')(flattened)

    return Model(inputs=inputs, outputs=predictions)

In [0]:
def ResNet18(input_shape, num_classes=1000, use_bias=True,
             kernel_initializer='he_normal', kernel_regularizer=None):
    """Build a ResNet made of basic blocks repeated (2, 2, 2, 2) times."""
    return ResNet(input_shape=input_shape,
                  num_classes=num_classes,
                  block_fn=_residual_block_basic,
                  repetitions=(2, 2, 2, 2),
                  use_bias=use_bias,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer)


def ResNet34(input_shape, num_classes=1000, use_bias=True,
             kernel_initializer='he_normal', kernel_regularizer=None):
    """Build a ResNet made of basic blocks repeated (3, 4, 6, 3) times."""
    return ResNet(input_shape=input_shape,
                  num_classes=num_classes,
                  block_fn=_residual_block_basic,
                  repetitions=(3, 4, 6, 3),
                  use_bias=use_bias,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer)


def ResNet50(input_shape, num_classes=1000, use_bias=True,
             kernel_initializer='he_normal', kernel_regularizer=None):
    """
    Build a ResNet made of bottleneck blocks repeated (3, 4, 6, 3) times,
    i.e. the same layout as ResNet34 with the basic blocks replaced by
    bottleneck ones.
    """
    return ResNet(input_shape=input_shape,
                  num_classes=num_classes,
                  block_fn=_residual_block_bottleneck,
                  repetitions=(3, 4, 6, 3),
                  use_bias=use_bias,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer)


def ResNet101(input_shape, num_classes=1000, use_bias=True,
             kernel_initializer='he_normal', kernel_regularizer=None):
    """Build a ResNet made of bottleneck blocks repeated (3, 4, 23, 3) times."""
    return ResNet(input_shape=input_shape,
                  num_classes=num_classes,
                  block_fn=_residual_block_bottleneck,
                  repetitions=(3, 4, 23, 3),
                  use_bias=use_bias,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer)


def ResNet152(input_shape, num_classes=1000, use_bias=True,
             kernel_initializer='he_normal', kernel_regularizer=None):
    """Build a ResNet made of bottleneck blocks repeated (3, 8, 36, 3) times."""
    return ResNet(input_shape=input_shape,
                  num_classes=num_classes,
                  block_fn=_residual_block_bottleneck,
                  repetitions=(3, 8, 36, 3),
                  use_bias=use_bias,
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer)

In [29]:
# Instantiate the 50-layer variant for our input size / number of classes,
# and print its layer-by-layer description:
resnet50 = ResNet50(input_shape, num_classes=num_classes)
resnet50.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
cbr_c (Conv2D)                  (None, 112, 112, 64) 9472        input_1[0][0]                    
__________________________________________________________________________________________________
cbr_bn (BatchNormalization)     (None, 112, 112, 64) 256         cbr_c[0][0]                      
__________________________________________________________________________________________________
cbr_r (Activation)              (None, 112, 112, 64) 0           cbr_bn[0][0]                     
______________________________________________________________________________________________

In [0]:
# Metrics: top-1 and top-5 accuracy for integer (sparse) labels. Their names
# ('acc', 'top5_acc') are the keys under which Keras reports them.
accuracy_metric = tf.metrics.SparseCategoricalAccuracy(name='acc')
top5_accuracy_metric = tf.metrics.SparseTopKCategoricalAccuracy(k=5,
                                                                name='top5_acc')

# Adam optimizer with its default parameters:
optimizer = tf.keras.optimizers.Adam()

resnet50.compile(loss='sparse_categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=[accuracy_metric, top5_accuracy_metric])

In [0]:
# ANSI escape sequences used to format the console logs.
# Colours:
log_begin_red = '\033[91m'
log_begin_blue = '\033[94m'
log_begin_green = '\033[92m'
# Text styles:
log_begin_bold = '\033[1m'
log_begin_underline = '\033[4m'
# Reset to the default formatting:
log_end_format = '\033[0m'

class SimpleLogCallback(tf.keras.callbacks.Callback):
    """ Keras callback for simple, denser console logs."""

    def __init__(self, metrics_dict, num_epochs='?', log_frequency=1,
                 metric_string_template=
                 '\033[1m[[name]]\033[0m = \033[94m{[[value]]:5.3f}\033[0m'):
        """
        Initialize the Callback.
        :param metrics_dict:            Dictionary containing mappings for 
                                        metrics names/keys e.g. {"accuracy": 
                                        "acc", "val. accuracy": "val_acc"}
        :param num_epochs:              Number of training epochs
        :param log_frequency:           Log frequency (in epochs)
        :param metric_string_template:  (opt.) String template to print each 
                                        metric; `[[name]]` is replaced by the
                                        displayed name and `[[value]]` by the
                                        positional index of the metric value
        """
        super().__init__()

        self.metrics_dict = collections.OrderedDict(metrics_dict)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency

        # We build a format string to later print the metrics, 
        # (e.g. "Epoch 0/9: loss = 1.00; val-loss = 2.00").
        # Positional fields 0 and 1 are the epoch index and the number of
        # epochs, so the metric values start at position 2:
        metric_templates = [
            metric_string_template
            .replace('[[name]]', metric_name)
            .replace('[[value]]', str(position))
            for position, metric_name in enumerate(self.metrics_dict, start=2)
        ]
        self.log_string_template = ('Epoch {0:2}/{1}: '
                                    + '; '.join(metric_templates))

    def on_train_begin(self, logs=None):
        """Print a message when the training starts."""
        print(f"Training: {log_begin_red}start{log_end_format}")

    def on_train_end(self, logs=None):
        """Print a message when the training ends."""
        print(f"Training: {log_begin_green}end{log_end_format}")

    def on_epoch_end(self, epoch, logs=None):
        """
        Print the selected metrics, at the configured frequency and for the
        last epoch.
        :param epoch:   Index of the epoch which just ended (Keras passes a
                        0-based index)
        :param logs:    Dictionary of metric values for this epoch
        """
        # `None` default instead of a shared mutable `{}`:
        logs = {} if logs is None else logs

        # The epoch index is 0-based, so the last epoch has index
        # `num_epochs - 1`. Comparing `epoch + 1` (rather than subtracting
        # from `num_epochs`) keeps the test valid when `num_epochs` is left
        # to its non-numeric default '?'.
        is_last_epoch = (epoch + 1) == self.num_epochs

        if (epoch - 1) % self.log_frequency == 0 or is_last_epoch:
            values = [logs[self.metrics_dict[metric_name]] 
                      for metric_name in self.metrics_dict]
            print(self.log_string_template.format(epoch, self.num_epochs, 
                                                  *values))

In [0]:
# Metrics to log, as (displayed name, key in the Keras logs) pairs, in
# display order:
metrics_to_print = collections.OrderedDict((
    ("loss", "loss"),
    ("v-loss", "val_loss"),
    ("acc", "acc"),
    ("v-acc", "val_acc"),
    ("top5-acc", "top5_acc"),
    ("v-top5-acc", "val_top5_acc"),
))

callback_simple_log = SimpleLogCallback(metrics_dict=metrics_to_print,
                                        num_epochs=num_epochs,
                                        log_frequency=2)

In [0]:
model_dir = './models/resnet_from_scratch'
callbacks = [
    # Callback to interrupt the training if the validation loss/metrics 
    # stops improving for some epochs:
    tf.keras.callbacks.EarlyStopping(patience=8, monitor='val_acc',
                                     restore_best_weights=True),
    # Callback to log the graph, losses and metrics into TensorBoard:
    tf.keras.callbacks.TensorBoard(log_dir=model_dir, histogram_freq=0, 
                                   write_graph=True),
    # Callback to save the model (e.g., every 5 epochs), specifying the epoch 
    # and val-loss in the filename:
    tf.keras.callbacks.ModelCheckpoint(os.path.join(model_dir, 
                     'weights-epoch{epoch:02d}-loss{val_loss:.2f}.h5'), 
                     save_freq=5),
    # Log callback:
    callback_simple_log 
]

In [0]:
history = resnet50.fit(train_cifar_dataset,  
                       epochs=num_epochs, steps_per_epoch=train_steps_per_epoch,
                       validation_data=(val_cifar_dataset), 
                       validation_steps=val_steps_per_epoch,
                       verbose=0)

In [0]:
# One subplot per logged quantity: training curves in the left column,
# validation curves in the right one.
fig, ax = plt.subplots(3, 2, figsize=(15, 10), sharex='col')

for (row, col), title, history_key in [
        ((0, 0), "loss", 'loss'),
        ((0, 1), "val-loss", 'val_loss'),
        ((1, 0), "acc", 'acc'),
        ((1, 1), "val-acc", 'val_acc'),
        ((2, 0), "top5-acc", 'top5_acc'),
        ((2, 1), "val-top5-acc", 'val_top5_acc')]:
    ax[row, col].set_title(title)
    ax[row, col].plot(history.history[history_key])

In [0]:
# Best validation accuracies reached over the epochs, as percentages:
val_history = history.history
best_val_acc = 100 * max(val_history['val_acc'])
best_val_top5 = 100 * max(val_history['val_top5_acc'])

print(f'Best val acc:  {best_val_acc:2.2f}%')
print(f'Best val top5: {best_val_top5:2.2f}%')

In [0]:
def load_image(image_path, size):
    """
    Read an image file and return it as an array.
    :param image_path:  Path of the image
    :param size:        Target size, passed to `load_img` as `target_size`
    :return             Image array, with its values divided by 255
    """
    loaded = load_img(image_path, target_size=size)
    pixels = img_to_array(loaded)
    return pixels / 255.


def process_predictions(class_probabilities, class_readable_labels, k=5):
    """
    Process a batch of predictions from our estimator.
    :param class_probabilities:     Prediction results returned by the Keras 
                                    classifier for a batch of data (one row
                                    of class probabilities per sample)
    :param class_readable_labels:   List of readable-class labels, for display
    :param k:                       Number of top predictions to consider
                                    (capped to the number of classes)
    :return                         Readable labels and probabilities for the 
                                    predicted classes; for each sample, both
                                    are ordered from the most to the least
                                    probable class
    """
    topk_labels, topk_probabilities = [], []
    for sample_probabilities in class_probabilities:
        sample_probabilities = np.asarray(sample_probabilities)

        # Ranking the classes by decreasing probability (a stable sort keeps
        # the lowest class index first in case of ties), and keeping the
        # first k. Slicing also handles k larger than the number of classes:
        ranked_classes = np.argsort(-sample_probabilities, kind='stable')
        topk_classes = ranked_classes[:max(k, 0)]

        # Getting the corresponding labels and probabilities:
        topk_labels.append([class_readable_labels[predicted]
                            for predicted in topk_classes])
        topk_probabilities.append(sample_probabilities[topk_classes])

    return topk_labels, topk_probabilities


def display_predictions(images, topk_labels, topk_probabilities):
    """
    Plot a batch of predictions: each image is drawn in a square grid, with a
    horizontal bar chart of its predicted classes underneath.
    :param images:                  Batch of input images
    :param topk_labels:             String labels of predicted classes
    :param topk_probabilities:      Probabilities for each class
    """
    num_images = len(images)
    # Smallest square grid able to contain all the images:
    grid_side = int(np.ceil(np.sqrt(num_images)))

    figure = plt.figure(figsize=(13, 10))
    outer_grid = gridspec.GridSpec(grid_side, grid_side)

    for index in range(num_images):
        image = images[index]
        probabilities = topk_probabilities[index]
        # Shortening the labels to better fit in the plot (text before the
        # first comma, 20 characters at most):
        short_labels = [label.split(',')[0][:20]
                        for label in topk_labels[index]]

        # Each grid cell is split in 3 rows: 2 for the image, 1 for the chart.
        cell_grid = gridspec.GridSpecFromSubplotSpec(
            3, 1, subplot_spec=outer_grid[index], hspace=0.1)

        # Drawing the input image:
        ax_img = figure.add_subplot(cell_grid[:2])
        ax_img.axis('off')
        ax_img.imshow(image)
        ax_img.autoscale(tight=True)

        # Plotting a bar chart for the predictions, without any frame:
        ax_pred = figure.add_subplot(cell_grid[2])
        for side in ('top', 'right', 'bottom', 'left'):
            ax_pred.spines[side].set_visible(False)
        bar_positions = np.arange(len(short_labels))
        ax_pred.barh(bar_positions, probabilities, align='center')
        ax_pred.set_yticks(bar_positions)
        ax_pred.set_yticklabels(short_labels)
        # First entry at the top of the chart:
        ax_pred.invert_yaxis()

    plt.tight_layout()
    plt.show()

In [0]:
# Load every file found in the `res` folder, resized to the network input size:
test_filenames = glob.glob(os.path.join('res', '*'))
test_image_size = input_shape[:2]
test_images = np.asarray(
    [load_image(filename, size=test_image_size) for filename in test_filenames]
)
print(f'Test Images: {test_images.shape}')

In [0]:
image_batch = test_images[:16]

# The model was trained on CIFAR images, which are stored at a much lower
# resolution and were scaled up to the network input size, hence with strong
# artifacts. To test on images of comparable quality, the test images are
# first shrunk to the original CIFAR size, then scaled back up to the
# expected input size:
cifar_original_image_size = cifar_builder.info.features['image'].shape[:2]
image_batch_low_quality = tf.image.resize(
    tf.image.resize(image_batch, cifar_original_image_size),
    input_shape[:2])

predictions = resnet50.predict_on_batch(image_batch_low_quality)
print(f'Predicted class probabilities: {predictions.shape}')

# Readable labels / probabilities of the top predictions, displayed next to
# the original (full-quality) images:
class_readable_labels = cifar_builder.info.features["label"].names
top5_labels, top5_probabilities = process_predictions(predictions,
                                                      class_readable_labels)

display_predictions(image_batch, top5_labels, top5_probabilities)