# Transfer learning & fine-tuning

**Author:** Charles Raja R <br>
**Description:** Complete guide to transfer learning & fine-tuning in Keras.<br>
**Link:** https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/transfer_learning.ipynb#scrollTo=rBENYxlIsMCm

In [7]:
# Setup
import numpy as np
import keras
from keras import layers
import matplotlib.pyplot as plt
# import tensorflow_datasets as tfds

## Introduction

**Transfer learning** consists of taking features learned on one problem, and
leveraging them on a new, similar problem. For instance, features from a model that has
learned to identify raccoons may be useful to kick-start a model meant to identify tanukis.

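Transfer learning is usually done for tasks where your dataset has too little data to train a full-scale model from scratch.

The most common incarnation of transfer learning in the context of deep learning is the following workflow: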

1. Take layers from a previously trained model.
2. Freeze them, so as to avoid destroying any of the information they contain during
 future training rounds.
3. Add some new, trainable layers on top of the frozen layers. They will learn to turn
 the old features into predictions on a new dataset.
4. Train the new layers on your dataset.

A last, optional step is **fine-tuning**, which consists of unfreezing the entire model you obtained above (or part of it) and re-training it on the new data with a very low learning rate. This can potentially achieve meaningful improvements by incrementally adapting the pretrained features to the new data.
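The four steps plus optional fine-tuning can be sketched end to end as follows. This is a minimal sketch under stated assumptions: a tiny, randomly initialized `base_model` stands in for a real pretrained model (for example one from `keras.applications`) so that it runs offline, and `x_train`/`y_train` are hypothetical random data. Note that the model is recompiled after changing `trainable`, so the change takes effect.

```python
import numpy as np
import keras
from keras import layers

# Assumption: a small randomly initialized model stands in for a real
# pretrained base. Its weights are NOT actually pretrained.
base_model = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(8, 3, activation="relu"),
        layers.GlobalAveragePooling2D(),
    ],
    name="base_model",
)

# Steps 1-2: take the base layers and freeze them.
base_model.trainable = False

# Step 3: add a new, trainable classification head on top.
inputs = keras.Input(shape=(32, 32, 3))
# training=False keeps layers such as BatchNormalization in inference mode.
x = base_model(inputs, training=False)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

# Hypothetical toy dataset: random images with random binary labels.
x_train = np.random.random((16, 32, 32, 3)).astype("float32")
y_train = np.random.randint(0, 2, size=(16, 1)).astype("float32")

# Step 4: train only the new head; the base weights must not move.
frozen_before = [w.copy() for w in base_model.get_weights()]
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss="binary_crossentropy")
model.fit(x_train, y_train, epochs=1, verbose=0)
frozen_after = base_model.get_weights()

# Optional fine-tuning: unfreeze the base and retrain with a very low
# learning rate. Recompile so the change to `trainable` takes effect.
base_model.trainable = True
model.compile(optimizer=keras.optimizers.Adam(1e-5), loss="binary_crossentropy")
model.fit(x_train, y_train, epochs=1, verbose=0)

print("base unchanged during step 4:",
      all(np.array_equal(a, b) for a, b in zip(frozen_before, frozen_after)))
print("trainable weights after unfreezing:", len(model.trainable_weights))
```

The very low learning rate in the fine-tuning phase is the design choice that matters: large updates at this stage could quickly overwrite the features the base model had learned.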

In [5]:
import tensorflow_datasets as tfds

# List some available datasets to confirm connectivity
print(tfds.list_builders()[:5])


['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', 'ag_news_subset']


## Freezing layers: understanding the `trainable` attribute

Layers & models have three weight attributes:

- `weights` is the list of all weight variables of the layer.
- `trainable_weights` is the list of those that are meant to be updated (via gradient
 descent) to minimize the loss during training.
- `non_trainable_weights` is the list of those that aren't meant to be trained.
 Typically they are updated by the model during the forward pass.

**Example: the `Dense` layer has 2 trainable weights (kernel & bias)**

In [None]:
layer = keras.layers.Dense(3)
layer

<Dense name=dense, built=False>
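The cell above only creates the layer: as `built=False` indicates, it has no weights yet. A short sketch of what happens when the layer is built for inputs with 4 features (the input size 4 is an arbitrary choice for illustration):

```python
import keras

layer = keras.layers.Dense(3)
print(len(layer.weights))  # prints 0: the layer is not built yet

# Building with an input shape of (batch, 4) creates the kernel and bias.
layer.build((None, 4))
for w in layer.weights:
    # kernel: shape (4, 3); bias: shape (3,); both trainable
    print(w.name, tuple(w.shape), w.trainable)
```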

In general, all weights are trainable weights. The only built-in layer that has
non-trainable weights is the `BatchNormalization` layer. It uses non-trainable weights
 to keep track of the mean and variance of its inputs during training.
To learn how to use non-trainable weights in your own custom layers, see the
[guide to writing new layers from scratch](/guides/making_new_layers_and_models_via_subclassing/).

**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable
 weights**

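**What is Batch Normalization?**<br>
Batch normalization normalizes its inputs to zero mean and unit variance, using statistics computed over the current batch, and then applies a learnable scale (`gamma`) and offset (`beta`). During training it also maintains moving averages of the mean and variance (its two non-trainable weights), which it uses instead of batch statistics at inference time. This usually speeds up and stabilizes training and reduces sensitivity to weight initialization.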

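**What is a convolutional layer?**<br>
A convolutional layer slides small learnable filters (kernels) across an image to extract local features such as edges or textures. Because the same filter is reused at every position, it preserves the spatial layout of the input and needs far fewer parameters than a fully connected layer. Stacks of convolutional layers form the foundation of CNNs.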
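The shape and parameter arithmetic for the `Conv2D(32, (3, 3))` layer used in the model below can be checked directly. This sketch assumes Keras' defaults of `"valid"` padding and stride 1, and feeds a single all-zero 224×224 RGB image just to build the layer:

```python
import numpy as np
from keras import layers

conv = layers.Conv2D(32, (3, 3))  # 32 filters of size 3x3, "valid" padding
images = np.zeros((1, 224, 224, 3), dtype="float32")
features = conv(images)  # calling the layer builds its weights

# Valid padding, stride 1: 224 - 3 + 1 = 222 per spatial side.
print(tuple(features.shape))  # (1, 222, 222, 32)

# Each filter has 3 * 3 * 3 = 27 weights (height x width x input channels)
# plus one bias: 32 * (27 + 1) = 896 parameters.
print(conv.count_params())  # 896
```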

**What is an activation layer?**<br>
It applies a mathematical function (like ReLU) element-wise to the previous layer's output, introducing non-linearity. This allows the model to learn complex patterns instead of only linear relationships.

**What is a MaxPooling2D layer?**<br>
This layer reduces the spatial dimensions (height and width) of feature maps by keeping only the maximum value in each window. With a 2×2 window (the default), it halves the height and width while retaining the strongest activations.

**What is a Flatten layer?**<br>
It converts each sample's multi-dimensional feature map (height × width × channels) into a single long 1D vector, keeping the batch dimension. This "unrolling" prepares the data for the final Dense layers.

**What is a Dense layer?**<br>
A fully connected layer in which every input connects to every output neuron: it computes `activation(dot(inputs, kernel) + bias)`. In the model below, the final Dense layer combines all extracted features into a single probability for binary classification.

In [None]:
# Example code using batch normalization
from keras import models

model = models.Sequential([
    # Declare the input shape with keras.Input (passing `input_shape`
    # to Conv2D triggers a UserWarning in Keras 3)
    keras.Input(shape=(224, 224, 3)),

    # 1. Convolutional layer: 32 filters of size 3x3
    layers.Conv2D(32, (3, 3)),

    # 2. Batch Normalization (placed BEFORE the activation)
    layers.BatchNormalization(),
    # Why this layer?
    # Without it: the distribution of the Conv2D outputs shifts as the
    #  weights change during training, so the model may learn at an
    #  uneven pace.

    # With it: each channel is re-centered and re-scaled. This often
    #  allows a higher learning rate, making training faster and less
    #  likely to diverge.

    # 3. Activation
    layers.Activation('relu'),

    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(1, activation='sigmoid')
])



In [17]:
layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = True  # Keep the layer trainable (this is the default)

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 2
non_trainable_weights: 0


In [18]:
layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 0
non_trainable_weights: 2


In [None]:
shape=(3,)
shape

(3,)

In [26]:
np.random.random((2, 3))

array([[0.51906519, 0.81934024, 0.19999991],
       [0.72525594, 0.46474565, 0.52695146]])

In [27]:
np.random.random((2, 3)), np.random.random((2, 3))

(array([[0.023283  , 0.1121046 , 0.42043985],
        [0.80346184, 0.81025372, 0.73415878]]),
 array([[0.11021325, 0.52658641, 0.10283493],
        [0.0625319 , 0.71689214, 0.69571295]]))

In [28]:
keras.Input(shape=(3, ))

<KerasTensor shape=(None, 3), dtype=float32, sparse=False, ragged=False, name=keras_tensor_14>

**What is the ReLU activation function?**<br>
ReLU (Rectified Linear Unit) is an activation function that outputs its input if it is positive and zero otherwise, i.e. `max(0, x)`. It introduces non-linearity, and because its gradient does not shrink for positive inputs (unlike older saturating functions such as sigmoid and tanh), it helps avoid slow training.
<br>
<br>
**What is the Adam optimizer?**<br>
Adam is an optimizer that adapts the step size for each parameter. It keeps running averages of past gradients (momentum) and of past squared gradients (adaptive per-parameter scaling, as in RMSProp), which makes it fast, robust, and very popular.
<br>
<br>
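The update rule can be sketched in NumPy. This follows the textbook form of Adam (Algorithm 1 of the original paper) with Keras' default hyperparameters; Keras' own implementation places `epsilon` slightly differently, so treat `adam_step` as a hypothetical illustration rather than the library code. The usage shows a key property: on the first step, parameters with very different gradient scales move by roughly the same amount, the learning rate.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    # Running average of gradients (momentum, "first moment").
    m = beta_1 * m + (1 - beta_1) * grad
    # Running average of squared gradients ("second moment").
    v = beta_2 * v + (1 - beta_2) * grad**2
    # Bias correction: m and v start at 0, so early averages are too small.
    m_hat = m / (1 - beta_1**t)
    v_hat = v / (1 - beta_2**t)
    # Per-parameter step size scaled by the gradient history.
    param = param - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return param, m, v

params = np.array([0.0, 0.0])
grads = np.array([100.0, 0.01])  # very different gradient scales
m = np.zeros_like(params)
v = np.zeros_like(params)

params, m, v = adam_step(params, grads, m, v, t=1)
print(params)  # both parameters move by about -0.001
```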
**What are weights?**<br>
The learnable parameters of a layer (for example the kernel and bias of a `Dense` layer) that determine the strength of each connection. During training, the model adjusts these numbers to minimize the loss and improve prediction accuracy.
<br>
<br>
**What is an optimizer?**<br>
The algorithm that updates the model's weights using the gradients of the loss function. It acts like a guide, telling the model in which direction and by how much to change each parameter to reduce the loss.
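The simplest optimizer, plain gradient descent, fits in a few lines. This sketch uses a hypothetical one-weight model `prediction = w * x` with a mean squared error loss on toy data generated with the "true" weight 2; the optimizer repeatedly steps the weight against the gradient of the loss:

```python
import numpy as np

# Toy data generated with the "true" weight w = 2.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])

w = 0.0  # initial weight
learning_rate = 0.05
for step in range(100):
    pred = w * x
    # Gradient of mean((w*x - y)^2) with respect to w: mean(2 * (w*x - y) * x)
    grad = np.mean(2 * (pred - y) * x)
    # Gradient descent update: move the weight against the gradient.
    w -= learning_rate * grad

print(round(w, 4))  # approaches 2.0
```

Optimizers such as Adam replace the fixed `learning_rate * grad` step with a per-parameter step computed from the gradient history.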


In [31]:
# Make a model with 2 layers
layer1 = keras.layers.Dense(units=3, activation="relu")
# What is units?
# `units` is the number of neurons (output values) in the layer.
# Each unit computes one weighted sum of the inputs, so units=3 gives
# a 3-dimensional output that is passed to the next layer.
layer2 = keras.layers.Dense(units=3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3, )), layer1, layer2])

layer1.trainable = False
initial_layer1_weights_values = layer1.get_weights()

model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 999ms/step - loss: 0.1403


<keras.src.callbacks.history.History at 0x7add815415b0>

In [35]:
len(initial_layer1_weights_values)

2

In [34]:
initial_layer1_weights_values[0], initial_layer1_weights_values[1]

(array([[-0.06940031,  0.9313228 ,  0.10630417],
        [ 0.18633103,  0.28914762,  0.8375237 ],
        [ 0.9452481 , -0.12755728, -0.13433862]], dtype=float32),
 array([0., 0., 0.], dtype=float32))

In [36]:
final_layer1_weights_values = layer1.get_weights()

In [37]:
final_layer1_weights_values[0], final_layer1_weights_values[1]

(array([[-0.06940031,  0.9313228 ,  0.10630417],
        [ 0.18633103,  0.28914762,  0.8375237 ],
        [ 0.9452481 , -0.12755728, -0.13433862]], dtype=float32),
 array([0., 0., 0.], dtype=float32))

In [42]:
final_layer1_weights_values = layer1.get_weights()
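# np.testing.assert_allclose raises an AssertionError if the two arrays are not
# equal up to the desired tolerance. No error here means the frozen layer1's
# weights were not changed by training.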
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
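For contrast, the same experiment with the first layer left trainable should change its weights. This sketch uses new, hypothetical names (`dense_a`, `dense_b`, `demo_model`) so it does not overwrite the variables above, and it assumes no activation on `dense_a` so that dead ReLU units cannot zero out all of its gradients:

```python
import numpy as np
import keras

dense_a = keras.layers.Dense(3)  # left trainable (the default)
dense_b = keras.layers.Dense(3, activation="sigmoid")
demo_model = keras.Sequential([keras.Input(shape=(3,)), dense_a, dense_b])

initial = [w.copy() for w in dense_a.get_weights()]
demo_model.compile(optimizer="adam", loss="mse")
demo_model.fit(np.random.random((2, 3)), np.random.random((2, 3)), verbose=0)
final = dense_a.get_weights()

# After one Adam step, the trainable layer's weights should differ.
changed = any(not np.allclose(a, b) for a, b in zip(initial, final))
print("dense_a weights changed:", changed)
```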

In [41]:
# Deliberately failing example: 0 and 1 are not equal within the default tolerance
np.testing.assert_allclose(
    0, 1
)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: 1
Max relative difference among violations: 1.
 ACTUAL: array(0)
 DESIRED: array(1)

## Recursive setting of the `trainable` attribute

If you set `trainable = False` on a model or on any layer that has sublayers,
all children layers become non-trainable as well.
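A minimal sketch of this behavior with a hypothetical nested model: freezing the outer model also freezes the inner model and every layer inside it.

```python
import keras

inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
])
model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
])

model.trainable = False  # Freeze the outer model

# The setting propagates to every nested layer.
print(inner_model.trainable, inner_model.layers[0].trainable)  # False False
print(len(model.trainable_weights), len(model.non_trainable_weights))  # 0 4
```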