Source: Keras guide — https://keras.io/guides/transfer_learning/#the-typical-transferlearning-workflow

Transfer learning & fine-tuning

In [49]:
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

Freezing layers: understanding the `trainable` attribute

In [50]:
# A freshly constructed Dense layer has no variables until it is built.
layer = layers.Dense(2)
print("weights:", len(layer.weights))

# Building for inputs of shape (batch, 3) creates the layer's variables
# (a (3, 2) kernel and a (2,) bias, as the displayed output shows).
layer.build((None, 3))

# All of a Dense layer's variables are trainable by default.
print(
    f"weights: {len(layer.weights)}\n"
    f"trainable_weights: {len(layer.trainable_weights)}\n"
    f"non_trainable_weights: {len(layer.non_trainable_weights)}"
)
layer.weights

weights: 0
weights: 2
trainable_weights: 2
non_trainable_weights: 0


[<tf.Variable 'kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[ 0.10018468, -0.5019272 ],
        [ 0.27510214,  0.8585675 ],
        [ 0.8921156 ,  0.74958456]], dtype=float32)>,
 <tf.Variable 'bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]

In [51]:
# Layer API reference: https://keras.io/api/layers/
#
# Instead of calling `build` explicitly, let the layer build itself the first
# time it is called on real data.
layer = layers.Dense(3, activation="relu")
print("weights:", len(layer.weights))  # nothing built yet

# A batch of 2 samples with 4 features each. Creating the input alone does
# not touch the layer, so its weight count is unchanged.
inputs = tf.random.uniform((2, 4))
print(inputs)
print("weights:", len(layer.weights))

# The first call infers the input width (4) and creates the kernel and bias.
outputs = layer(inputs)
print("weights:", len(layer.weights))

layer.weights

weights: 0
tf.Tensor(
[[0.588096   0.7298387  0.8849908  0.09413445]
 [0.70745265 0.44963872 0.459329   0.58028686]], shape=(2, 4), dtype=float32)
weights: 0
weights: 2


[<tf.Variable 'dense_41/kernel:0' shape=(4, 3) dtype=float32, numpy=
 array([[-0.48006186, -0.5630181 ,  0.7930138 ],
        [ 0.46786106, -0.80529165, -0.01191491],
        [ 0.20314252, -0.23659569,  0.1299913 ],
        [ 0.40866828,  0.39261127,  0.18158376]], dtype=float32)>,
 <tf.Variable 'dense_41/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]

In [52]:
# Unlike Dense, BatchNormalization owns some variables that are not trainable.
layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # create the variables for 4 input features

print(
    f"weights: {len(layer.weights)}\n"
    f"trainable_weights: {len(layer.trainable_weights)}\n"
    f"non_trainable_weights: {len(layer.non_trainable_weights)}"
)

weights: 4
trainable_weights: 2
non_trainable_weights: 2


In [53]:
# Freezing a layer does not delete its weights: they move from
# `trainable_weights` to `non_trainable_weights` and keep their values.
layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("len weights:", len(layer.weights))
# `layer.weights` is a list of variable objects; `get_weights()` returns
# their current values as plain NumPy arrays.
print("weights:", layer.weights)
print(layer.get_weights())
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

len weights: 2
weights: [<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
array([[-6.6014224e-01, -3.4047961e-03, -8.5234791e-01],
       [-4.9411654e-02, -5.6203139e-01,  2.8578508e-01],
       [ 8.9249194e-01, -3.0526870e-01, -8.3198309e-01],
       [-3.3195263e-01,  3.3527553e-01,  7.8141689e-05]], dtype=float32)>, <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]
[array([[-6.6014224e-01, -3.4047961e-03, -8.5234791e-01],
       [-4.9411654e-02, -5.6203139e-01,  2.8578508e-01],
       [ 8.9249194e-01, -3.0526870e-01, -8.3198309e-01],
       [-3.3195263e-01,  3.3527553e-01,  7.8141689e-05]], dtype=float32), array([0., 0., 0.], dtype=float32)]
trainable_weights: 0
non_trainable_weights: 2


In [54]:
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer (before the `compile()`/`fit()` calls below, so
# training only updates layer2).
layer1.trainable = False

# Keep copies of the weights for later reference; `get_weights()` returns
# NumPy arrays, so these snapshots are not affected by training.
initial_layer1_weights_values = layer1.get_weights()
initial_layer2_weights_values = layer2.get_weights()
print(initial_layer1_weights_values)

# Tiny random regression problem: 2 samples, 3 features -> 3 targets.
# A seeded generator makes the training data reproducible across re-runs
# (weight initialization itself is still unseeded).
data_rng = np.random.default_rng(0)
x_train = data_rng.random((2, 3))
y_train = data_rng.random((2, 3))

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train)

# Check that every weight array of layer1 is unchanged by training,
# without hard-coding how many arrays the layer has.
final_layer1_weights_values = layer1.get_weights()
assert len(final_layer1_weights_values) == len(initial_layer1_weights_values)
for weights_before, weights_after in zip(
    initial_layer1_weights_values, final_layer1_weights_values
):
    np.testing.assert_allclose(weights_before, weights_after)

# Contrast check: the trainable layer2 *did* change, so the assertion above
# is not passing merely because `fit()` left every weight untouched.
assert any(
    not np.allclose(before, after)
    for before, after in zip(initial_layer2_weights_values, layer2.get_weights())
), "expected the trainable layer2 to change during training"

[array([[-0.556417  , -0.6231663 , -0.9735708 ],
       [ 0.18122339, -0.563746  , -0.45962238],
       [-0.6202657 ,  0.40935802,  0.7165365 ]], dtype=float32), array([0., 0., 0.], dtype=float32)]


Recursive setting of the `trainable` attribute

In [55]:
inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

# `trainable` is propagated recursively: the nested model and every layer
# inside it are now frozen, so the outer model exposes no trainable weights.
assert not inner_model.trainable
assert all(not sublayer.trainable for sublayer in inner_model.layers)
assert not model.trainable_weights