# Transfer Learning and Fine-tuning with Keras Hub
This notebook demonstrates how to use **transfer learning** and **fine-tuning** with pre-trained models from [Keras Hub](https://keras.io/keras_hub/) to achieve better image classification accuracy than training a model from scratch. This is a common and effective strategy when you have a small dataset, limited computational resources, or insufficient time to train a large model from the beginning.

**Learning Objectives**
- **Transfer Learning**: Learn how to leverage a pre-trained model as a feature extractor.
- **Fine-tuning**: Understand how to fine-tune a pre-trained model's weights for a specific task.
- **Learning Rate Scheduling**: Discover how a learning rate schedule can be used to achieve stable and effective fine-tuning.

## Setup

In [None]:
import os
import warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
warnings.filterwarnings("ignore")

In [None]:
import pathlib

import IPython.display as display
import keras
import keras_hub
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Sequential
from keras.layers import (
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPooling2D,
    Softmax,
)
from PIL import Image

## Exploring the data

As usual, let's take a look at the data before we start building our model. We'll be using a Creative Commons-licensed flower photo dataset of 3670 images falling into 5 categories: 'daisy', 'roses', 'dandelion', 'sunflowers', and 'tulips'.

The [keras.utils.get_file](https://keras.io/api/utils/python_utils/#getfile-function) call below downloads the dataset to the local Keras cache. To browse the files in a terminal, copy the `cd` command printed by the cell below.

In [None]:
data_dir = keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True,
)

# Print data path
print("cd", data_dir)

We can use Python's built-in [pathlib](https://docs.python.org/3/library/pathlib.html) module to get a sense of this unstructured data.
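
To see why the image glob needs three path components (`*/*/*.jpg`) while the class glob needs two (`*/*`), here is a small sketch that builds a throwaway directory tree. It assumes the archive extracts to `data_dir/flower_photos/<class>/<image>.jpg`, as the glob patterns in the next cell imply; the class subset and file names are made up for illustration.

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    photos = root / "flower_photos"
    # Hypothetical tree: <root>/flower_photos/<class>/<image>.jpg
    for cls, n_images in {"daisy": 2, "roses": 3}.items():
        (photos / cls).mkdir(parents=True)
        for i in range(n_images):
            (photos / cls / f"img_{i}.jpg").touch()
    # A LICENSE.txt sits next to the class folders.
    (photos / "LICENSE.txt").touch()

    # Three components: archive folder / class folder / image file.
    image_count = len(list(root.glob("*/*/*.jpg")))

    # Two components: archive folder / (class folder or LICENSE.txt).
    class_names = sorted(
        item.name for item in root.glob("*/*") if item.name != "LICENSE.txt"
    )

print(image_count, class_names)  # 5 ['daisy', 'roses']
```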

In [None]:
data_dir_path = pathlib.Path(data_dir)

image_count = len(list(data_dir_path.glob("*/*/*.jpg")))
print("There are", image_count, "images.")

CLASS_NAMES = np.array(
    [
        item.name
        for item in data_dir_path.glob("*/*")
        if item.name != "LICENSE.txt"
    ]
)
print("These are the available classes:", CLASS_NAMES)

Let's display some images so we can see what our model will be trying to learn.

In [None]:
roses = list(data_dir_path.glob("*/roses/*"))

for image_path in roses[:3]:
    display.display(Image.open(str(image_path)))

## Building the Dataset
For training a machine learning model, the data must be organized and loaded efficiently. Here we use [keras.utils.image_dataset_from_directory](https://keras.io/api/data_loading/image/#imagedatasetfromdirectory-function) to create a `tf.data.Dataset` object.

This function assumes the data is stored in the following directory structure:

```
main_directory/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
```

The utility automatically handles:
- File Discovery: It scans the directory structure to find all image files.
- Labeling: It infers the class labels from the subdirectory names (e.g., roses, tulips).
- Splitting: It divides the data into training (80%) and validation (20%) subsets as specified by `validation_split`, and returns both subsets (`subset="both"`).
- Resizing: It resizes all images to a consistent size (224x224 pixels) to be compatible with the model architecture we'll use later.
- Batching: It groups the images into batches for efficient training, with a batch size of 32.

This process ensures the data is ready for model consumption without the need for manual loops or complex file handling.
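
The labeling and splitting rules can be sketched in plain Python. This is a simplified model rather than Keras' actual implementation: it assumes class indices follow the alphabetical order of the subdirectory names, that `label_mode="categorical"` yields one-hot vectors, and that the validation subset receives `int(validation_split * num_files)` images while the rest go to training. The helper names are hypothetical.

```python
def infer_labels(class_dirs):
    """Map each class folder name to an integer index, in alphabetical order."""
    return {name: idx for idx, name in enumerate(sorted(class_dirs))}


def one_hot(index, num_classes):
    """Categorical label: 1.0 at the class index, 0.0 everywhere else."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


def split_sizes(num_files, validation_split):
    """Number of (training, validation) images for a given split fraction."""
    num_val = int(validation_split * num_files)
    return num_files - num_val, num_val


classes = ["daisy", "roses", "dandelion", "sunflowers", "tulips"]
class_to_index = infer_labels(classes)
print(class_to_index)
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
print(one_hot(class_to_index["roses"], len(classes)))
# [0.0, 0.0, 1.0, 0.0, 0.0]
print(split_sizes(3670, 0.2))  # (2936, 734)
```

Note that the index order comes from sorting, not from the order in which the folders happen to be listed on disk.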

In [None]:
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 32
# Hold out 20% of the images for validation.
VALIDATION_SPLIT = 0.2

SEED = 999

In [None]:
train_ds, eval_ds = keras.utils.image_dataset_from_directory(
    f"{data_dir}/flower_photos",
    seed=SEED,
    validation_split=VALIDATION_SPLIT,
    subset="both",
    batch_size=BATCH_SIZE,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    label_mode="categorical",
)

## Simple CNN Model

These flower photos are much larger than the handwriting images in MNIST. After resizing, each image has 8 times as many pixels per axis **and** three color channels instead of one, making the information here nearly 200 times larger!
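
The arithmetic behind that comparison, assuming 28×28 grayscale MNIST digits and the 224×224 RGB inputs produced by the resizing step above:

```python
mnist_values = 28 * 28 * 1      # height x width x channels
flower_values = 224 * 224 * 3   # IMG_HEIGHT x IMG_WIDTH x RGB channels

print(224 // 28)                     # 8 (times as many pixels per axis)
print(flower_values, mnist_values)   # 150528 784
print(flower_values / mnist_values)  # 192.0
```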

We've also added a `Rescaling` layer to rescale the pixel range from [0, 255] to [0, 1]. We could apply this operation in the data pipeline instead, but incorporating it in the model itself avoids training-serving skew: the same preprocessing is applied wherever the model is used.
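
As a rough sketch of what the layer computes (Keras documents `Rescaling` as `inputs * scale + offset`, with `offset` defaulting to 0), here is the same transformation in NumPy on a few made-up 8-bit pixel values. The `rescale` helper is hypothetical and only mimics the layer.

```python
import numpy as np


def rescale(pixels, scale=1.0 / 255, offset=0.0):
    """Mimic keras.layers.Rescaling: multiply by scale, then add offset."""
    return pixels.astype("float32") * scale + offset


raw = np.array([0, 51, 255], dtype="uint8")  # example 8-bit pixel values
print(rescale(raw))  # approximately [0.0, 0.2, 1.0]
```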

How do our current techniques stand up? Copy your best model architecture over from the <a href="2_mnist_cnn.ipynb">MNIST models lab</a> and see how well it does after training for 10 epochs.

In [None]:
nclasses = len(CLASS_NAMES)
hidden_layer_1_neurons = 400
hidden_layer_2_neurons = 100
dropout_rate = 0.25
num_filters_1 = 64
kernel_size_1 = 3
pooling_size_1 = 2
num_filters_2 = 32
kernel_size_2 = 3
pooling_size_2 = 2

layers = [
    # Declare the input shape with an explicit Input as the first entry
    # (height, width, channels), as Keras 3 recommends for Sequential models.
    keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    keras.layers.Rescaling(scale=1.0 / 255),
    Conv2D(num_filters_1, kernel_size=kernel_size_1, activation="relu"),
    MaxPooling2D(pooling_size_1),
    Conv2D(num_filters_2, kernel_size=kernel_size_2, activation="relu"),
    MaxPooling2D(pooling_size_2),
    Flatten(),
    Dense(hidden_layer_1_neurons, activation="relu"),
    Dense(hidden_layer_2_neurons, activation="relu"),
    Dropout(dropout_rate),
    Dense(nclasses),
    Softmax(),
]

cnn_model = Sequential(layers)
cnn_model.compile(
    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)

In [None]:
history = cnn_model.fit(
    train_ds,
    epochs=10,
    validation_data=eval_ds,
)

In [None]:
ACCURACY_COLS = ["accuracy", "val_accuracy"]
_ = pd.DataFrame(history.history)[ACCURACY_COLS].plot()

## Transfer Learning with Keras Hub

If your model is like mine, it learns a little, doing somewhat better than random guessing, but *ugh*, it's overfitting! Because our model is too shallow for this dataset, it may be relying on very low-level features (like colors) and failing to capture the overall semantics.

To overcome the limitations of the simple CNN, we introduce transfer learning. 

We'll leverage a pre-trained image model, MobileNetV3 from Keras Hub, and repurpose it for our task. It is a good fit here because it is lightweight and designed for efficient inference on mobile devices.

We can retrieve a pre-trained feature extractor by passing a preset identifier to `keras_hub.models.Backbone.from_preset()`. See the [Keras Hub presets page](https://keras.io/keras_hub/presets/) for the available models.

Here we set `backbone.trainable` to `False` for transfer learning, so the millions of weights in the pre-trained model are not updated during training. 

In [None]:
backbone = keras_hub.models.Backbone.from_preset(
    "mobilenet_v3_large_100_imagenet_21k",
)
backbone.trainable = False

backbone.summary()

Note that "Trainable params" is 0 for the frozen backbone.

Instead, we add a new classification head on top (global max pooling, dropout, and a single `Dense` layer) and train only this head to classify the flower images.

This approach leverages the powerful, pre-learned features of the backbone to reach a strong performance baseline quickly, with a much lower risk of overfitting the small dataset than training all of the weights.
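
To make the shapes concrete, here is a NumPy sketch of what the head computes on top of the backbone. It assumes the backbone emits a feature map of shape `(batch, 7, 7, 960)` for 224×224 inputs; that shape is an assumption for illustration, so check `backbone.summary()` for the real one. Dropout is omitted because it is inactive at inference time, and the weights are random rather than trained.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels, nclasses = 2, 7, 7, 960, 5  # assumed shapes

# Stand-in for the frozen backbone's output feature map.
features = rng.normal(size=(batch, height, width, channels))

# GlobalMaxPooling2D: keep the largest activation of each channel.
pooled = features.max(axis=(1, 2))  # shape (batch, channels)

# Dense layer with softmax activation (random, untrained weights).
kernel = rng.normal(scale=0.01, size=(channels, nclasses))
bias = np.zeros(nclasses)
logits = pooled @ kernel + bias
exp = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
probs = exp / exp.sum(axis=1, keepdims=True)

print(pooled.shape, probs.shape)  # (2, 960) (2, 5)
```

Only `kernel` and `bias` correspond to trainable weights in the transfer learning setup; everything that produces `features` stays fixed.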

In [None]:
transfer_model = keras.Sequential(
    [
        keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        keras.layers.Rescaling(scale=1.0 / 255),
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.2),
        keras.layers.Dense(
            nclasses,
            activation="softmax",
            kernel_regularizer=keras.regularizers.l2(0.0001),
        ),
    ]
)
transfer_model.summary()

In [None]:
transfer_model.compile(
    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)

In [None]:
history = transfer_model.fit(
    train_ds,
    epochs=10,
    validation_data=eval_ds,
)

In [None]:
ACCURACY_COLS = ["accuracy", "val_accuracy"]
_ = pd.DataFrame(history.history)[ACCURACY_COLS].plot()

Alright, looking better! Still, there's room to improve.

## Fine-tuning MobileNet
Following the transfer learning step, let's see how fine-tuning works. Here we keep the backbone trainable and train it together with our classification head.

Note that we load a fresh copy of the backbone and **don't** set `backbone.trainable` to `False`, so all of its weights are updated during training.

In [None]:
backbone = keras_hub.models.Backbone.from_preset(
    "mobilenet_v3_large_100_imagenet_21k",
)

# Leave trainable True for fine-tuning
# backbone.trainable = False

finetune_model = keras.Sequential(
    [
        keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        keras.layers.Rescaling(scale=1.0 / 255),
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.2),
        keras.layers.Dense(
            nclasses,
            activation="softmax",
            kernel_regularizer=keras.regularizers.l2(0.0001),
        ),
    ]
)
finetune_model.summary()

In [None]:
finetune_model.compile(
    optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)

history = finetune_model.fit(
    train_ds,
    epochs=10,
    validation_data=eval_ds,
)

In [None]:
ACCURACY_COLS = ["accuracy", "val_accuracy"]
_ = pd.DataFrame(history.history)[ACCURACY_COLS].plot()

While the training accuracy is higher than in the transfer learning case, the validation accuracy stays about the same.

Because we are now updating millions of MobileNet parameters, the model easily overfits this small dataset.

Furthermore, Adam's default learning rate (0.001) may be large enough to disrupt the knowledge MobileNet acquired during pre-training. Choosing the learning rate carefully is crucial for effective fine-tuning.

## Learning Rate Schedule for Fine-tuning

Now, let's add an additional technique for fine-tuning.
Since fine-tuning is very sensitive to the learning rate, it is important to use a carefully designed **learning rate schedule**.


A common learning rate schedule for fine-tuning has two phases:
- **Warmup Phase**: The learning rate starts small and increases gradually, easing the model into the new task without disrupting the pre-trained weights while the new classification head is still untrained.
- **Decay Phase**: The learning rate is then gradually reduced, allowing finer updates and smoother convergence.
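
The two phases can be written as a small pure-Python function. This is a sketch modeled on the behavior documented for `keras.optimizers.schedules.CosineDecay` with warmup: a linear ramp from the initial rate to `warmup_target` over `warmup_steps`, then a cosine decay over `decay_steps` further steps down to `alpha * warmup_target`. The function name and its arguments are our own, not part of Keras.

```python
import math


def warmup_cosine_lr(step, warmup_steps, decay_steps, warmup_target,
                     initial_lr=0.0, alpha=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Warmup: interpolate linearly from initial_lr to warmup_target.
        return initial_lr + (warmup_target - initial_lr) * step / warmup_steps
    # Decay: steps since warmup ended, clipped so the rate stays at its floor.
    t = min(step - warmup_steps, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return warmup_target * ((1.0 - alpha) * cosine + alpha)


# 10 warmup steps up to 1e-3, then 90 steps of cosine decay back to 0.
print([round(warmup_cosine_lr(s, 10, 90, 1e-3), 6) for s in (0, 5, 10, 55, 100)])
# [0.0, 0.0005, 0.001, 0.0005, 0.0]
```

The rate peaks exactly when warmup ends and reaches half the peak at the midpoint of the decay phase.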

Here we build this schedule with `keras.optimizers.schedules.CosineDecay`, which supports an optional linear warmup through its `warmup_target` and `warmup_steps` arguments.
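
Before building the schedule, it helps to work out the step budget. The sketch below assumes 2,936 training images (our own calculation: 3,670 images minus `int(0.2 * 3670)` held out for validation) and counts the final partial batch as a step. One detail worth knowing: `CosineDecay` counts `decay_steps` after the warmup ends, so for the cosine phase to finish exactly at the last training step, `decay_steps` should be the total number of steps minus the warmup steps.

```python
import math

num_train_images = 2936  # assumed: 3670 - int(0.2 * 3670)
batch_size = 32
epochs = 10

steps_per_epoch = math.ceil(num_train_images / batch_size)  # last batch is partial
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1 * num_train_steps)  # 10% of training spent warming up
decay_steps = num_train_steps - num_warmup_steps  # cosine phase fills the rest

print(steps_per_epoch, num_train_steps, num_warmup_steps, decay_steps)
# 92 920 92 828
```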

In [None]:
epochs = 10
steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1 * num_train_steps)

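warmup_tgt = 0.001
# CosineDecay counts decay_steps after the warmup ends, so subtract the warmup
# steps to make the whole schedule span exactly num_train_steps.
scheduler = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=num_train_steps - num_warmup_steps,
    warmup_target=warmup_tgt,
    warmup_steps=num_warmup_steps,
)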

optimizer = keras.optimizers.Adam(learning_rate=scheduler)

In [None]:
# Evaluate the schedule at every optimizer step and convert each value to a float.
plt.plot([float(scheduler(step)) for step in range(num_train_steps)])
plt.title("Learning Rate Schedule")
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()

In [None]:
backbone = keras_hub.models.Backbone.from_preset(
    "mobilenet_v3_large_100_imagenet_21k",
)

ft_lr_schedule_model = keras.Sequential(
    [
        keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        keras.layers.Rescaling(scale=1.0 / 255),
        backbone,
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dropout(rate=0.2),
        keras.layers.Dense(
            nclasses,
            activation="softmax",
            kernel_regularizer=keras.regularizers.l2(0.0001),
        ),
    ]
)

Let's train the same model using the new learning rate scheduler.

In [None]:
ft_lr_schedule_model.compile(
    optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
)

history = ft_lr_schedule_model.fit(
    train_ds,
    epochs=epochs,
    validation_data=eval_ds,
)

In [None]:
ACCURACY_COLS = ["accuracy", "val_accuracy"]
_ = pd.DataFrame(history.history)[ACCURACY_COLS].plot()

Thanks to the new learning rate schedule, we're now seeing better training and validation performance, as well as a more stable learning curve.

Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.