In [None]:
import tensorflow as tf

In [None]:
# Load CIFAR-10: image arrays plus integer class labels for each split.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Sanity-check what later cells rely on:
# - the array shapes;
# - raw uint8 pixels (the preprocessing cell expects [0, 255] values);
# - labels that index the 10-unit softmax output trained with
#   SparseCategoricalCrossentropy.
assert X_train.shape == (50000, 32, 32, 3)
assert X_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)
assert X_train.dtype == X_test.dtype == 'uint8', "expected raw uint8 pixels"
assert 0 <= y_train.min() and y_train.max() < 10, "train labels outside [0, 10)"
assert 0 <= y_test.min() and y_test.max() < 10, "test labels outside [0, 10)"

In [None]:
def preprocess_image_input(input_images):
  """Cast images to float32 and apply Xception's input preprocessing.

  Args:
    input_images: NumPy array of images, presumably raw [0, 255] pixels
      (the CIFAR-10 loader returns them that way).

  Returns:
    `tf.keras.applications.xception.preprocess_input` applied to a float32
    copy of `input_images`; per the Keras docs this scales pixels to [-1, 1].
  """
  # astype() returns a new array, so the caller's array is left untouched.
  as_float = input_images.astype('float32')
  return tf.keras.applications.xception.preprocess_input(as_float)

# preprocess_input rescales whatever values it is given, so running this cell a
# second time would rescale already-preprocessed arrays and silently corrupt
# them. Only convert arrays that still hold raw uint8 pixels, which makes
# re-running the cell harmless.
if X_train.dtype == 'uint8':
  X_train = preprocess_image_input(X_train)
if X_test.dtype == 'uint8':
  X_test = preprocess_image_input(X_test)

In [None]:
# Define the input layer: CIFAR-10 images are 32x32 RGB.
inputs = tf.keras.layers.Input(shape=(32,32,3), name="input_layer")

# Upsample 32x32 -> 224x224 (32 * 7 = 224) inside the model, one batch at a time.
# Xception's default input size is 299x299, but with include_top=False it
# accepts any spatial size of at least 71x71, so 224x224 is valid.
resize = tf.keras.layers.UpSampling2D(size=(7,7), name="resize_layer")(inputs)

# Xception backbone with pre-trained ImageNet weights and no classification head.
# NOTE(review): the backbone is left trainable, so the first fit() updates all of
# its pretrained weights with Adam's default learning rate (1e-3). Consider
# freezing it for the first phase or using a smaller learning rate.
feature_extractor = tf.keras.applications.Xception(input_shape=(224, 224, 3),
                                                   include_top=False,
                                                   weights='imagenet')
features = feature_extractor(resize)

# Classification head. GlobalAveragePooling2D already produces a
# (batch, channels) tensor, so no Flatten layer is needed before the Dense layers.
x = tf.keras.layers.GlobalAveragePooling2D(name="global_avg_pool_layer")(features)
x = tf.keras.layers.Dense(1024, activation="relu", name="dense_layer_1", kernel_initializer="he_normal")(x)
x = tf.keras.layers.Dense(512, activation="relu", name="dense_layer_2", kernel_initializer="he_normal")(x)
# Softmax probabilities over the 10 classes. This matches the loss below, since
# SparseCategoricalCrossentropy defaults to from_logits=False and takes
# integer labels.
classification_output = tf.keras.layers.Dense(10, activation="softmax", name="output_layer")(x)

# Define the model
model = tf.keras.Model(inputs=inputs, outputs=classification_output, name="my_model")

# Compile the model
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

# Print the model summary
model.summary()

In [None]:
# Train the model
# Phase 1: train the whole network, backbone included, for 5 epochs.
# NOTE(review): the CIFAR-10 test split doubles as validation data here, so the
# reported val_* metrics are not an untouched held-out estimate.
history = model.fit(X_train, y_train,
                    validation_data=(X_test, y_test),
                    epochs=5)

In [None]:
# List every top-level layer of the model next to its `trainable` flag.
for lyr in model.layers:
  print(f"{lyr} {lyr.trainable}")

In [None]:
# Number of layers inside the nested Xception backbone. It is looked up by its
# default Keras model name instead of by position (model.layers[2]), so the
# cell cannot silently pick the wrong layer if the layers in front of it change.
len(model.get_layer("xception").layers)

In [None]:
for layer in model.layers[2].layers[:100]:
  layer.trainable = False # Freeze the layers

In [None]:
# Phase 2: train for 2 more epochs after the `trainable` flags were changed in
# the previous cell. The freeze only takes effect if the model was recompiled
# after those flags changed.
# The result is kept under its own name so the phase-1 `history` is preserved.
history_finetune = model.fit(
    X_train, y_train, epochs=2, validation_data=(X_test, y_test))