# Transfer learning tutorial

Based on code from https://keras.io/guides/transfer_learning/
Python available here: https://github.com/keras-team/keras-io/blob/master/guides/transfer_learning.py 


In [None]:
import numpy as np

In [None]:
import tensorflow
from tensorflow import keras

In [None]:
import matplotlib.pyplot as plt


In [None]:
import tensorflow_datasets as tfds



Demonstrating trainable vs non-trainable weights

In [None]:
# Building the layer against an input shape (here: 4 input features)
# is what creates its weights.
layer = keras.layers.Dense(3)
layer.build((None, 4))

# Report how many weight tensors the layer holds, and how they split
# between trainable and non-trainable.
for caption, tensors in (
    ("weights:", layer.weights),
    ("trainable_weights:", layer.trainable_weights),
    ("non_trainable_weights:", layer.non_trainable_weights),
):
    print(caption, len(tensors))

Setting the layer to non-trainable

In [None]:

# Freeze the layer, then print the same three counts again to see how
# the weights are now classified.
layer.trainable = False

for caption, tensors in (
    ("weights:", layer.weights),
    ("trainable_weights:", layer.trainable_weights),
    ("non_trainable_weights:", layer.non_trainable_weights),
):
    print(caption, len(tensors))

In [None]:
# Turn off tensorflow_datasets' progress bars to keep notebook output compact.
tfds.disable_progress_bar()


Demonstrating that a frozen layer's weights don't change during training

In [None]:
# Two-layer model in which only the second layer is allowed to learn.
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
layer1.trainable = False  # Freeze the first layer before compiling

# Snapshot layer1's weights so they can be compared after training.
initial_layer1_weights_values = layer1.get_weights()

# Run one round of training on a tiny random dataset (2 samples, 3 features).
model.compile(optimizer="adam", loss="mse")
features = np.random.random((2, 3))
labels = np.random.random((2, 3))
model.fit(features, labels)

# Each of layer1's two weight arrays must be unchanged by training.
final_layer1_weights_values = layer1.get_weights()
for before, after in zip(
    initial_layer1_weights_values, final_layer1_weights_values
):
    np.testing.assert_allclose(before, after)

In [None]:
# Inner stack that gets nested inside a second Sequential model.
inner_layers = [
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
    keras.layers.Dense(3, activation="relu"),
]
inner_model = keras.Sequential(inner_layers)

# Outer model: the nested stack followed by a sigmoid output layer.
outer_layers = [
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
]
model = keras.Sequential(outer_layers)

# Freeze only the outer model ...
model.trainable = False

# ... and verify that the setting reached the nested model and its layers.
assert not inner_model.trainable  # All layers in `model` are now frozen
assert not inner_model.layers[0].trainable  # `trainable` is propagated recursively

# Two transfer learning approaches

The usual transfer learning workflow is summarised as follows:
* Instantiate a base model and load pre-trained weights into it.
* Freeze all layers in the base model by setting trainable = False.
* Create a new model on top of the output of one (or several) layers from the base model.
* Train your new model on your new dataset.
* Optional: fine-tune the model by unfreezing the base model's weights and continuing training on the new data.

Note that an alternative, more lightweight workflow could also be:
1. Instantiate a base model and load pre-trained weights into it.
1. Run your new dataset through it and record the output of one (or several) layers from the base model. This is called feature extraction.
1.  Use that output as input data for a new, smaller model.


basic transfer learning workflow demo, demonstrating the four steps

In [None]:
# Xception backbone for 150x150 RGB inputs, initialised from ImageNet
# and created without its classification head.
base_model = keras.applications.Xception(
    include_top=False,  # Do not include the ImageNet classifier at the top.
    input_shape=(150, 150, 3),
    weights="imagenet",  # Load weights pre-trained on ImageNet.
)

In [None]:
# Freeze the base model so that training the new head leaves the
# pre-trained weights untouched.
base_model.trainable = False

In [None]:
# Layers of the new head, created up front and wired together below.
pooling = keras.layers.GlobalAveragePooling2D()
classifier = keras.layers.Dense(1)  # Single unit: binary classification

inputs = keras.Input(shape=(150, 150, 3))
# Passing `training=False` makes the base model run in inference mode,
# which matters later when the model is fine-tuned.
x = base_model(inputs, training=False)
# Collapse features of shape `base_model.output_shape[1:]` into vectors.
x = pooling(x)
outputs = classifier(x)
model = keras.Model(inputs, outputs)

In [None]:
# Configure training: Adam with default settings, binary cross-entropy on
# raw logits (the Dense(1) output has no activation), and binary accuracy.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)


In [None]:
# if we had an additional image dataset to train on, we would then use it here as follows
# model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Once we have trained the new layers, we may wish to fine tune the whole model with the new dataset

In [None]:
# Make the base model's weights trainable again.
base_model.trainable = True

# After changing the `trainable` attribute of any inner layer, the model
# has to be recompiled so that the change is taken into account.
fine_tune_learning_rate = 1e-5  # Very low learning rate
model.compile(
    optimizer=keras.optimizers.Adam(fine_tune_learning_rate),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)


In [None]:
# Train end-to-end. Be careful to stop before you overfit!
## model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

#### Custom training loop

In more specialised cases, you may use a custom training loop rather than the standard `model.fit()` call. Transfer learning can easily be done in a custom training loop where required.

In [None]:
# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False


In [None]:
# New model on top of the frozen base: pooled base-model features feed a
# single-unit Dense output.
inputs = keras.Input(shape=(150, 150, 3))
x = keras.layers.GlobalAveragePooling2D()(base_model(inputs, training=False))
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Optimizer and loss used by the hand-written training loop.
optimizer = keras.optimizers.Adam()
# from_logits=True: the loss is given raw, un-activated model outputs.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


In [None]:
# Display the model's input and output tensors as the cell result.
model.input, model.output

In [None]:
# Small synthetic stand-in for a real labelled image dataset, so that this
# loop runs on its own: 8 random 150x150 RGB "images" with random binary
# labels, served in batches of 4. Replace it with your own batched
# `tf.data.Dataset` of (images, labels).
new_dataset = tensorflow.data.Dataset.from_tensor_slices(
    (
        np.random.random((8, 150, 150, 3)).astype("float32"),
        np.random.randint(0, 2, size=(8, 1)).astype("float32"),
    )
).batch(4)

# Iterate over the batches of the dataset. The loop variables are named
# so that they do not overwrite the `inputs` tensor used to build the model.
for batch_inputs, batch_targets in new_dataset:
    # Record the forward pass so gradients can be computed afterwards.
    with tensorflow.GradientTape() as tape:
        # Forward pass.
        predictions = model(batch_inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(batch_targets, predictions)

    # Get gradients of loss wrt the *trainable* weights only, so the
    # frozen base model is left untouched.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

## Transfer learning example - cats vs dogs

In [None]:

# Carve three slices out of the dataset's single "train" split:
# the first 40% for training, the next 10% for validation and the
# following 10% for test.
split_spec = ["train[:40%]", "train[40%:50%]", "train[50%:60%]"]
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    split=split_spec,
    as_supervised=True,  # Include labels
)


In [None]:
# Report the number of elements in each of the three datasets.
print("Number of training samples: %d" % train_ds.cardinality())
print("Number of validation samples: %d" % validation_ds.cardinality())
print("Number of test samples: %d" % test_ds.cardinality())

In [None]:

# Preview the first nine training images in a 3x3 grid, each titled with
# its integer label.
plt.figure(figsize=(10, 10))
for position, (image, label) in enumerate(train_ds.take(9), start=1):
    plt.subplot(3, 3, position)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

In [None]:
# Target image size; it matches the (150, 150, 3) model input used below.
size = (150, 150)


def resize_example(image, label):
    """Resize `image` to `size` and pass `label` through unchanged."""
    return tensorflow.image.resize(image, size), label


# Apply the same resizing to all three datasets.
train_ds = train_ds.map(resize_example)
validation_ds = validation_ds.map(resize_example)
test_ds = test_ds.map(resize_example)

In [None]:
# Display the number of elements in the training dataset as the cell result.
train_ds.cardinality()

In [None]:
batch_size = 32


def make_batched_pipeline(dataset):
    """Cache `dataset`, group it into batches of `batch_size`, and prefetch with a buffer of 10."""
    return dataset.cache().batch(batch_size).prefetch(buffer_size=10)


# Batched counterparts of the three datasets; the unbatched ones are kept.
train_ds_batch = make_batched_pipeline(train_ds)
validation_ds_batch = make_batched_pipeline(validation_ds)
test_ds_batch = make_batched_pipeline(test_ds)

In [None]:
# Random augmentation pipeline: horizontal flips followed by small
# rotations (factor 0.1).
augmentation_layers = [
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
]
data_augmentation = keras.Sequential(augmentation_layers)


In [None]:
# Display the augmentation model object as the cell result.
data_augmentation

In [None]:
# Show nine independently augmented versions of the first training image.
for image, label in train_ds.take(1):
    print(image.shape)
    print(label)
    plt.figure(figsize=(10, 10))
    # The augmentation model is applied to a batch, so add a leading
    # batch axis once, outside the plotting loop.
    single_image_batch = tensorflow.expand_dims(image, 0)
    for position in range(1, 10):
        plt.subplot(3, 3, position)
        # training=True is passed explicitly -- presumably the random
        # transformations are skipped otherwise.
        augmented_image = data_augmentation(single_image_batch, training=True)
        # Drop the batch axis again and cast to integers for display.
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(label))
        plt.axis("off")

In [None]:
# Xception backbone for the cats-vs-dogs model.
base_model = keras.applications.Xception(
    include_top=False,  # Do not include the ImageNet classifier at the top.
    input_shape=(150, 150, 3),
    weights="imagenet",  # Load weights pre-trained on ImageNet.
)


In [None]:
# base_model.summary()

In [None]:
# Freeze the base_model so that only the layers added on top of it are
# updated during the first round of training.
base_model.trainable = False


In [None]:

# Create new model on top: start from a 150x150 RGB input tensor.
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation
# Alternative without augmentation:
# x = inputs

In [None]:

# The pre-trained Xception weights require inputs scaled from the
# (0, 255) range to (-1., +1.). The rescaling layer outputs
# `(inputs * scale) + offset`, so scale = 1 / 127.5 and offset = -1
# map 0 -> -1 and 255 -> +1.
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)


In [None]:
# Layers applied after the base model: pooling, then dropout as a regulariser.
pooling = keras.layers.GlobalAveragePooling2D()
dropout = keras.layers.Dropout(0.2)

# The base model contains batchnorm layers. Calling it with
# `training=False` keeps them in inference mode, even after the base
# model is unfrozen for fine-tuning.
base_features = base_model(x, training=False)
x = dropout(pooling(base_features))


In [None]:
# Single-unit Dense output without an activation; the loss configured at
# compile time uses from_logits=True accordingly.
outputs = keras.layers.Dense(1)(x)


In [None]:
# Tie the input tensor and the output tensor together into the full model.
model = keras.Model(inputs, outputs)


In [None]:

# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Number of epochs for the first training round (base model frozen).
epochs = 20

# Adam with default settings, binary cross-entropy on raw logits, and
# binary accuracy as the reported metric.
top_layer_optimizer = keras.optimizers.Adam()
model.compile(
    optimizer=top_layer_optimizer,
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)


In [None]:
# Display the (unbatched) training dataset object as the cell result.
train_ds

In [None]:
# Inspect the structure of the elements the training dataset yields.
# (The original cell was an unfinished `train_ds.` attribute access,
# which is a syntax error.)
train_ds.element_spec

In [None]:
%%time
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

In [None]:
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
# Print the summary again now that the base model has been unfrozen.
model.summary()



In [None]:
# The `trainable` flag of the base model was changed, so the model must
# be compiled again; fine-tuning uses a low learning rate.
fine_tune_optimizer = keras.optimizers.Adam(1e-5)
fine_tune_loss = keras.losses.BinaryCrossentropy(from_logits=True)
model.compile(
    optimizer=fine_tune_optimizer,
    loss=fine_tune_loss,
    metrics=[keras.metrics.BinaryAccuracy()],
)



In [None]:
# Fine-tune the whole model end-to-end for a few epochs.
epochs = 10
# Use the batched pipelines: the model expects a leading batch dimension,
# which the unbatched `train_ds` / `validation_ds` do not provide.
model.fit(train_ds_batch, epochs=epochs, validation_data=validation_ds_batch)