# Chapter 4: Influential Classification Tools - Transfer Learning with Keras

In this last notebook covering Chapter 4, we will demonstrate how transfer learning can be achieved with Keras. More precisely, we will present how Keras Applications pre-trained on rich datasets can be reused for new tasks. Unlike Notebook [4-3](./ch4_notebook_3_resnet_from_keras_app.ipynb), where we instantiated a ResNet-50 from Keras-App with random parameters, this time we will ask Keras to fetch the parameters pre-trained on ImageNet for us. This will give us the opportunity to test two types of transfer learning, i.e., **_freezing_** or **_fine-tuning_** the feature-extractor layers.

In [1]:
import tensorflow as tf
import os
from matplotlib import pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
from tiny_imagenet import (
    tiny_imagenet, _training_augmentation_fn, 
    IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, NUM_CLASSES)

TINY_IMAGENET_ROOT_FOLDER = os.path.expanduser('~/datasets/tiny-imagenet-200/')

NUM_TRAINING_IMAGES = 500 * NUM_CLASSES
NUM_VAL_IMAGES = 50 * NUM_CLASSES
batch_size = 32
num_epochs = 30

train_steps_per_epoch = NUM_TRAINING_IMAGES // batch_size
val_steps_per_epoch = NUM_VAL_IMAGES // batch_size

# Like in previous notebooks, we actually resize the Tiny-ImageNet images to ImageNet commonly-used dimensions:
IMG_HEIGHT, IMG_WIDTH = 224, 224
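As a quick sanity check of these constants, here is the same arithmetic in plain Python. It assumes `NUM_CLASSES == 200`, i.e., Tiny-ImageNet's 200 classes with 500 training and 50 validation images each. Note that the floor division drops the last incomplete batch, so a few validation images are skipped at every epoch:

```python
# Assumption: Tiny-ImageNet has 200 classes, 500 training / 50 validation images per class.
NUM_CLASSES = 200
batch_size = 32

NUM_TRAINING_IMAGES = 500 * NUM_CLASSES  # 100,000 images
NUM_VAL_IMAGES = 50 * NUM_CLASSES        # 10,000 images

train_steps_per_epoch = NUM_TRAINING_IMAGES // batch_size  # 3125 steps
val_steps_per_epoch = NUM_VAL_IMAGES // batch_size         # 312 steps

# Validation images left out of each pass because of the floor division:
skipped_val_images = NUM_VAL_IMAGES - val_steps_per_epoch * batch_size  # 16

print(train_steps_per_epoch, val_steps_per_epoch, skipped_val_images)
```

The 3125 training steps match the `.../3125` progress bar in the training logs.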

## ResNet with Frozen Feature Extractor

In this first section, we will use the ResNet-50 from Keras Applications, pre-trained on ImageNet, as a feature extractor, and build a new classifier for Tiny-ImageNet on top (cf. Chapter 4). We will then illustrate the first transfer learning use case presented in the book, i.e., completely freezing the feature extractor and only training the new dense layers on top.

### Building a New Classifier from Pre-trained Keras Applications

We first build our model, a ResNet-50 solution to classify images among the 200 classes of Tiny-ImageNet.

To do so, we first use Keras Applications to instantiate a network with pre-trained weights, but without any top layers (i.e., without the final dense layers leading to predictions):

In [4]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense

resnet50_feature_extractor = tf.keras.applications.resnet50.ResNet50(
    include_top=False, weights='imagenet', 
    input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# resnet50_feature_extractor.summary()



As mentioned earlier, we then _freeze_ this feature extractor, i.e., we make the layers of this network non-trainable, as we want to preserve the knowledge this ResNet acquired by being trained on ImageNet, a much richer dataset.

However, while we want to preserve the feature-extracting layers (i.e., the convolutional layers making up most of ResNet), **we should be careful not to freeze some other layers, such as the regularization ones**. Layers like the _batch-normalization_ ones (added after most of the convolutions in ResNet architectures) have some trainable parameters (cf. Chapter 3) which tend to become too dataset-specific. It is often recommended not to freeze such layers and to let them adapt to the new task/dataset. Therefore, we check each layer's type before deciding whether to freeze it:

In [None]:
for layer in resnet50_feature_extractor.layers:
    if isinstance(layer, tf.keras.layers.Conv2D):
        layer.trainable = False
        print("Layer {}: trainable = False".format(layer.name))

Layer conv1: trainable = False
Layer res2a_branch2a: trainable = False
Layer res2a_branch2b: trainable = False
Layer res2a_branch2c: trainable = False
Layer res2a_branch1: trainable = False
Layer res2b_branch2a: trainable = False
Layer res2b_branch2b: trainable = False
Layer res2b_branch2c: trainable = False
Layer res2c_branch2a: trainable = False
Layer res2c_branch2b: trainable = False
Layer res2c_branch2c: trainable = False
Layer res3a_branch2a: trainable = False
Layer res3a_branch2b: trainable = False
Layer res3a_branch2c: trainable = False
Layer res3a_branch1: trainable = False
Layer res3b_branch2a: trainable = False
Layer res3b_branch2b: trainable = False
Layer res3b_branch2c: trainable = False
Layer res3c_branch2a: trainable = False
Layer res3c_branch2b: trainable = False
Layer res3c_branch2c: trainable = False
Layer res3d_branch2a: trainable = False
Layer res3d_branch2b: trainable = False
Layer res3d_branch2c: trainable = False
Layer res4a_branch2a: trainable = False
Layer res4a
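To make concrete what the batch-normalization layers we leave trainable actually learn, here is a minimal NumPy sketch of channels-last batch normalization at inference time. The function name is our own, and the `eps` default mirrors Keras' `BatchNormalization` default of `1e-3`. Only the per-channel scale `gamma` and offset `beta` are learned by backpropagation; the moving mean and variance are running statistics refreshed during training:

```python
import numpy as np

def batch_norm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Normalize a channels-last tensor with per-channel statistics, then scale and shift.

    gamma, beta: trainable per-channel parameters (learned by backpropagation).
    moving_mean, moving_var: non-trainable running statistics.
    """
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(8, 4, 4, 3))  # (batch, height, width, channels)

# Pretend the running statistics exactly match this batch:
moving_mean = x.mean(axis=(0, 1, 2))
moving_var = x.var(axis=(0, 1, 2))

# Trainable parameters: one value per channel.
gamma = np.array([1.0, 2.0, 0.5])
beta = np.array([0.0, 1.0, -1.0])

y = batch_norm_inference(x, gamma, beta, moving_mean, moving_var)
print(y.mean(axis=(0, 1, 2)))  # close to beta
print(y.std(axis=(0, 1, 2)))   # close to gamma (slightly smaller because of eps)
```

In Keras, these four arrays correspond to a `BatchNormalization` layer's weights, of which only `gamma` and `beta` are counted as trainable.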

We now add the trainable layers on top of this network, to make predictions from the features:

In [None]:
features = resnet50_feature_extractor.output
avg_pool = GlobalAveragePooling2D(data_format='channels_last')(features)
predictions = Dense(NUM_CLASSES, activation='softmax')(avg_pool)

resnet50_freeze = Model(resnet50_feature_extractor.input, predictions)
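For a 224×224 input, ResNet-50 without its top outputs feature maps of shape (7, 7, 2048). The NumPy sketch below mimics what the two new layers compute on such features, with random weights standing in for the trained ones. Assuming `NUM_CLASSES == 200`, it also shows that the new `Dense` layer holds 2048 × 200 + 200 = 409,800 parameters:

```python
import numpy as np

rng = np.random.default_rng(42)
batch_size, height, width, channels = 4, 7, 7, 2048  # ResNet-50 output for 224x224 inputs
num_classes = 200                                    # assumption: Tiny-ImageNet classes

# Stand-in for `resnet50_feature_extractor.output` (channels_last):
features = rng.standard_normal((batch_size, height, width, channels))

# GlobalAveragePooling2D: average each feature map over its spatial dimensions.
avg_pool = features.mean(axis=(1, 2))                # -> (batch_size, channels)

# Dense(num_classes, activation='softmax'): affine transform, then softmax.
kernel = rng.standard_normal((channels, num_classes)) * 0.01
bias = np.zeros(num_classes)
logits = avg_pool @ kernel + bias
logits -= logits.max(axis=1, keepdims=True)          # for numerical stability
predictions = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

num_dense_params = kernel.size + bias.size
print(avg_pool.shape, predictions.shape, num_dense_params)
```

Global average pooling is what keeps this head small: flattening the 7 × 7 × 2048 maps instead would feed 100,352 values into the dense layer, i.e., 49 times more kernel weights.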

### Preparing the Data

Once again, we reuse the functions we implemented in a previous [notebook](./ch4_notebook_1_data_preparation.ipynb) to set up the input pipelines:

In [None]:
train_images, train_labels, class_ids, class_readable_labels = tiny_imagenet(
    phase='train', shuffle=True, batch_size=batch_size, num_epochs=num_epochs, wrap_for_estimator=False,
    augmentation_fn=_training_augmentation_fn, root_folder=TINY_IMAGENET_ROOT_FOLDER,
    resize_to=[IMG_HEIGHT, IMG_WIDTH])

val_images, val_labels, _, _ = tiny_imagenet(
    phase='val', shuffle=False, batch_size=batch_size, num_epochs=None, wrap_for_estimator=False,
    augmentation_fn=None, root_folder=TINY_IMAGENET_ROOT_FOLDER,
    resize_to=[IMG_HEIGHT, IMG_WIDTH])

### Training the Network

Similarly, the training script itself is copy-pasted from previous notebooks (we invite our readers to check them if details are needed). Indeed, once the pre-trained weights are loaded and the desired layers are frozen, the resulting model can be trained like any other:

In [None]:
import functools

sparse_top_5_categorical_accuracy = functools.partial(
    tf.keras.metrics.sparse_top_k_categorical_accuracy, k=5)
sparse_top_5_categorical_accuracy.__name__ = 'sparse_top_5_categorical_accuracy'

optimizer = tf.keras.optimizers.SGD(momentum=0.9, nesterov=True)

model_dir = './models/resnet_keras_app_transfer_learning_freeze'
callbacks = [
    # Callback to log the graph, losses and metrics into TensorBoard:
    tf.keras.callbacks.TensorBoard(log_dir=model_dir, histogram_freq=0, write_graph=True),
    # Callback to save the model (e.g., every 5 epochs), specifying the epoch and val-loss in the filename:
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(model_dir, 'weights-epoch{epoch:02d}-loss{val_loss:.2f}.h5'), period=5)
]


resnet50_freeze.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy',
                        metrics=['sparse_categorical_accuracy', sparse_top_5_categorical_accuracy])
history_freeze = resnet50_freeze.fit(
    train_images, train_labels, epochs=num_epochs, steps_per_epoch=train_steps_per_epoch,
    validation_data=(val_images, val_labels), validation_steps=val_steps_per_epoch,
    verbose=1, callbacks=callbacks)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
 685/3125 [=====>........................] - ETA: 7:29 - loss: 0.3022 - sparse_categorical_accuracy: 0.9099 - sparse_top_5_categorical_accuracy: 0.9889
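The `sparse_top_5_categorical_accuracy` metric counts a prediction as correct whenever the true label is among the 5 highest-scoring classes. Here is a minimal NumPy sketch of that computation. The function `sparse_top_k_accuracy` is our own, not a Keras API, and its tie-breaking between equal scores may differ from TensorFlow's implementation:

```python
import numpy as np

def sparse_top_k_accuracy(y_true, y_pred, k=5):
    """Fraction of samples whose integer label is among the k highest predicted scores."""
    y_true = np.asarray(y_true).reshape(-1)
    # Indices of the k largest scores for each sample (their order does not matter):
    top_k = np.argsort(y_pred, axis=1)[:, -k:]
    hits = np.any(top_k == y_true[:, None], axis=1)
    return float(hits.mean())

y_pred = np.array([[0.1, 0.6, 0.3],
                   [0.5, 0.2, 0.3]])
y_true = [2, 1]
print(sparse_top_k_accuracy(y_true, y_pred, k=1))  # 0.0
print(sparse_top_k_accuracy(y_true, y_pred, k=2))  # 0.5
```

With `k=1`, this reduces to the usual (top-1) accuracy reported as `sparse_categorical_accuracy`.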

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(15, 10), sharex='col')
ax[0, 0].set_title("loss")
ax[0, 1].set_title("val-loss")
ax[1, 0].set_title("acc")
ax[1, 1].set_title("val-acc")
ax[2, 0].set_title("top5-acc")
ax[2, 1].set_title("val-top5-acc")

ax[0, 0].plot(history_freeze.history['loss'])
ax[0, 1].plot(history_freeze.history['val_loss'])
ax[1, 0].plot(history_freeze.history['sparse_categorical_accuracy'])
ax[1, 1].plot(history_freeze.history['val_sparse_categorical_accuracy'])
ax[2, 0].plot(history_freeze.history['sparse_top_5_categorical_accuracy'])
ax[2, 1].plot(history_freeze.history['val_sparse_top_5_categorical_accuracy'])

By carefully freezing the feature extractor, we achieved a new high in terms of accuracy! With ~75% top-1 / ~92% top-5 accuracy, we are now far above the original ~37% top-1 / ~64% top-5 accuracy obtained with the same model trained without transfer learning.

## ResNet with Fine-tuned Feature Extractor

In the following section, we will define the exact same ResNet-50 network with pre-trained layers. However, this time we will not completely freeze its feature-extractor component, in order to _fine-tune_ the last, higher-level convolutional layers. As we explained in Chapter 4, this fine-tuning can benefit the new classifier, which may learn to extract more task-relevant features _(fine-tuning is recommended only if the training dataset is big enough to avoid over-fitting)_.

### Building a New Classifier from Pre-trained Keras Applications

We start by building the same network:

In [None]:
resnet50_feature_extractor = tf.keras.applications.resnet50.ResNet50(
    include_top=False, weights='imagenet', 
    input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

features = resnet50_feature_extractor.output
avg_pool = GlobalAveragePooling2D(data_format='channels_last')(features)
predictions = Dense(NUM_CLASSES, activation='softmax')(avg_pool)

resnet50_finetune = Model(resnet50_feature_extractor.input, predictions)
# resnet50_finetune.summary()

However, the idea this time is to freeze the weights of the first layers and to retrain all the others. For instance, we will here freeze the first 3 _macro-blocks_ (see Notebook [4-2](./ch4_notebook_2_resnet_from_scratch.ipynb) for the definition) and fine-tune the 2 remaining ones while training the new final layers.

Note that in practice, deciding which layers to fine-tune may require several training runs to compare the performance of the corresponding models.

In [None]:
for layer in resnet50_finetune.layers:
    if 'res4' in layer.name:
        # Keras developers named the layers in their ResNet implementation to explicitly
        # identify which macro-block and block each layer belongs to.
        # If we reach a layer whose name contains 'res4', it means we reached
        # the 4th macro-block, i.e., we are done with the 3rd one:
        break
    if isinstance(layer, tf.keras.layers.Conv2D):
        layer.trainable = False
        print("Layer {}: trainable = False".format(layer.name))

### Preparing the Data

To restart the data iteration from the beginning, we re-instantiate the input pipelines (with the same parameters):

In [None]:
train_images, train_labels, _, _ = tiny_imagenet(
    phase='train', shuffle=True, batch_size=batch_size, num_epochs=num_epochs, wrap_for_estimator=False,
    augmentation_fn=_training_augmentation_fn, root_folder=TINY_IMAGENET_ROOT_FOLDER,
    resize_to=[IMG_HEIGHT, IMG_WIDTH])

val_images, val_labels, _, _ = tiny_imagenet(
    phase='val', shuffle=False, batch_size=batch_size, num_epochs=None, wrap_for_estimator=False,
    augmentation_fn=None, root_folder=TINY_IMAGENET_ROOT_FOLDER,
    resize_to=[IMG_HEIGHT, IMG_WIDTH])

### Training the Network

The training takes place as usual:

In [None]:
# We set a smaller learning rate for the fine-tuning:
optimizer = tf.keras.optimizers.SGD(lr=1e-4, decay=1e-6, momentum=0.9, nesterov=True)

model_dir = './models/resnet_keras_app_transfer_learning_finetune'
callbacks = [
    # Callback to log the graph, losses and metrics into TensorBoard:
    tf.keras.callbacks.TensorBoard(log_dir=model_dir, histogram_freq=0, write_graph=True),
    # Callback to save the model (e.g., every 5 epochs), specifying the epoch and val-loss in the filename:
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(model_dir, 'weights-epoch{epoch:02d}-loss{val_loss:.2f}.h5'), period=5)
]

resnet50_finetune.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', 
                          metrics=['sparse_categorical_accuracy', sparse_top_5_categorical_accuracy])
history_finetune = resnet50_finetune.fit(
    train_images, train_labels, epochs=num_epochs, steps_per_epoch=train_steps_per_epoch,
    validation_data=(val_images, val_labels), validation_steps=val_steps_per_epoch,
    verbose=1, callbacks=callbacks)
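With `decay=1e-6`, Keras' SGD optimizer shrinks the learning rate at every update following the time-based rule lr / (1 + decay × iterations). This argument has since been deprecated, so the formula is worth checking against the documentation of your TensorFlow version. The plain-Python sketch below (the function `time_based_decay` is ours) shows that over the 30 × 3125 updates of this training, the rate only decreases by roughly 8.6%. It therefore stays close to 1e-4, i.e., 100 times smaller than the `SGD` default of 0.01 used for the frozen model:

```python
def time_based_decay(initial_lr, decay, iterations):
    """Learning rate after `iterations` updates under Keras' legacy `decay` argument."""
    return initial_lr / (1.0 + decay * iterations)

initial_lr, decay = 1e-4, 1e-6
train_steps_per_epoch, num_epochs = 3125, 30

for epoch in (1, 10, 30):
    lr = time_based_decay(initial_lr, decay, epoch * train_steps_per_epoch)
    print('after epoch {:2d}: lr = {:.3e}'.format(epoch, lr))

final_lr = time_based_decay(initial_lr, decay, num_epochs * train_steps_per_epoch)
```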

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(15, 10), sharex='col') # add parameter `sharey='row'` for a more direct comparison
ax[0, 0].set_title("loss")
ax[0, 1].set_title("val-loss")
ax[1, 0].set_title("acc")
ax[1, 1].set_title("val-acc")
ax[2, 0].set_title("top5-acc")
ax[2, 1].set_title("val-top5-acc")

histories = {'freezing': history_freeze, 'fine-tuning': history_finetune}
lines, labels = [], []
for config_name in histories:
    history = histories[config_name]
    ax[0, 0].plot(history.history['loss'])
    ax[0, 1].plot(history.history['val_loss'])
    ax[1, 0].plot(history.history['sparse_categorical_accuracy'])
    ax[1, 1].plot(history.history['val_sparse_categorical_accuracy'])
    ax[2, 0].plot(history.history['sparse_top_5_categorical_accuracy'])
    line = ax[2, 1].plot(history.history['val_sparse_top_5_categorical_accuracy'])
    lines.append(line[0])
    labels.append(config_name)

fig.legend(lines, labels, loc='center right', borderaxespad=0.1)
plt.subplots_adjust(right=0.85)

In [None]:
best_val_acc = max(history_finetune.history['val_sparse_categorical_accuracy']) * 100
best_val_top5 = max(history_finetune.history['val_sparse_top_5_categorical_accuracy']) * 100

print('Best val acc:  {:2.2f}%'.format(best_val_acc))
print('Best val top5: {:2.2f}%'.format(best_val_top5))
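The cell above only reports the best scores of the fine-tuned model. To compare both configurations, a small helper (hypothetical, not part of Keras) can scan the `history` dictionaries returned by `fit()`, where each metric maps to a list holding one value per epoch. It reports the best value together with its epoch:

```python
def best_epoch(history_dict, metric='val_sparse_categorical_accuracy', mode='max'):
    """Return (1-based epoch, value) of the best score of `metric` in a Keras `History.history` dict.

    Use mode='max' for accuracies and mode='min' for losses.
    """
    values = history_dict[metric]
    pick = max if mode == 'max' else min
    best_index = pick(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]
```

With the `histories` dictionary built for the comparison plot, `for name, history in histories.items(): print(name, best_epoch(history.history))` would print one line per configuration.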