 Imports & constants

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, utils, datasets
from tensorflow.keras.applications import Xception
from tensorflow.keras.callbacks import EarlyStopping

# number of CIFAR-10 classes
NUM_CLASSES = 10


Load & preprocess data

In [2]:
# 1. Load CIFAR-10
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# normalize to [0,1]
x_train = x_train.astype("float32") / 255.0
x_test  = x_test.astype("float32") / 255.0

# resize to >=71×71 for Xception
x_train = tf.image.resize(x_train, (75, 75))
x_test  = tf.image.resize(x_test,  (75, 75))

# one-hot encode labels
y_train = utils.to_categorical(y_train, NUM_CLASSES)
y_test  = utils.to_categorical(y_test,  NUM_CLASSES)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step


Build the model

In [3]:
# 2. Base Xception (frozen)
base_model = Xception(
    include_top=False,
    weights="imagenet",
    input_shape=(75, 75, 3),
    pooling="avg"
)
base_model.trainable = False

# add custom head
input_layer = layers.Input(shape=(75, 75, 3))
x = base_model(input_layer, training=False)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
output_layer = layers.Dense(NUM_CLASSES, activation="softmax")(x)

model = models.Model(input_layer, output_layer)
model.summary()


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m83683744/83683744[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


Compile & train top layers

In [9]:
# 3. Compile
opt = optimizers.Adam(learning_rate=1e-3)
model.compile(
    loss="categorical_crossentropy",
    optimizer=opt,
    metrics=["accuracy"]
)

# 4. Train head only
history = model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=15,
    validation_data=(x_test, y_test),
    callbacks=[EarlyStopping(patience=3, restore_best_weights=True)],
    shuffle=True,
)


Epoch 1/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 143ms/step - accuracy: 0.8123 - loss: 0.6119 - val_accuracy: 0.7929 - val_loss: 1.4557
Epoch 2/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 121ms/step - accuracy: 0.9030 - loss: 0.3073 - val_accuracy: 0.8691 - val_loss: 0.4378
Epoch 3/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 120ms/step - accuracy: 0.9392 - loss: 0.1925 - val_accuracy: 0.8579 - val_loss: 0.5181
Epoch 4/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 121ms/step - accuracy: 0.9497 - loss: 0.1602 - val_accuracy: 0.8900 - val_loss: 0.3813
Epoch 5/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 121ms/step - accuracy: 0.9630 - loss: 0.1196 - val_accuracy: 0.8760 - val_loss: 0.4353
Epoch 6/15
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 120ms/step - accuracy: 0.9643 - loss: 0.1140 - val_accuracy: 0.8933 - val_loss: 0.4760
Epoc

 Unfreeze & fine-tune

In [10]:
# 5. Fine-tune entire model
base_model.trainable = True
opt_fine = optimizers.Adam(learning_rate=1e-5)

model.compile(
    loss="categorical_crossentropy",
    optimizer=opt_fine,
    metrics=["accuracy"]
)

history_fine = model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=10,
    validation_data=(x_test, y_test),
    callbacks=[EarlyStopping(patience=2, restore_best_weights=True)],
    shuffle=True,
)


Epoch 1/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 143ms/step - accuracy: 0.9723 - loss: 0.0860 - val_accuracy: 0.9277 - val_loss: 0.2363
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 121ms/step - accuracy: 0.9846 - loss: 0.0493 - val_accuracy: 0.9325 - val_loss: 0.2358
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 121ms/step - accuracy: 0.9895 - loss: 0.0342 - val_accuracy: 0.9350 - val_loss: 0.2399
Epoch 4/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 121ms/step - accuracy: 0.9913 - loss: 0.0258 - val_accuracy: 0.9362 - val_loss: 0.2476
