<a href="https://colab.research.google.com/github/MarieLvsq/MachineLearning/blob/master/CIFAR10_TransferLearning_ResNet50V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ✅ CIFAR-10 Transfer Learning with ResNet50V2 on Google Colab

import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Load CIFAR-10 dataset.
# Seed Python/NumPy/TensorFlow RNGs so shuffling, augmentation, dropout and head
# initialisation are reproducible across Restart & Run All.
tf.keras.utils.set_random_seed(42)
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Keep raw [0, 255] pixel values: resnet_v2.preprocess_input (applied in the tf.data
# pipelines) expects [0, 255] and maps it to [-1, 1] itself (x / 127.5 - 1).
# Dividing by 255 here first would squash every image into roughly [-1, -0.992].
# float32 (instead of the float64 produced by `/ 255.0`) also halves memory use.
x_train_full = x_train_full.astype("float32")
x_test = x_test.astype("float32")
# Hold out a stratified 5,000-image validation split from the training images.
x_train, x_val, y_train, y_val = train_test_split(
    x_train_full, y_train_full, test_size=5000, stratify=y_train_full, random_state=42
)
# Drop the trailing singleton label axis so labels are 1-D for
# sparse_categorical_crossentropy.
y_train = y_train.squeeze()
y_val = y_val.squeeze()
y_test = y_test.squeeze()

# Data augmentation and resizing.
# Random, label-preserving perturbations; only the training pipeline applies these.
_augment_layers = [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomTranslation(0.1, 0.1),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomContrast(0.1),
]
data_augment = tf.keras.Sequential(_augment_layers)

# Upsample images to the 224x224 input size of the ResNet50V2 backbone.
resize = tf.keras.layers.Resizing(224, 224)

# Let tf.data pick parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# tf.data pipelines
def preprocess_train(x, y):
    """Augment, resize and scale one training example for ResNet50V2.

    Augmentation runs on the native-resolution image *before* upsampling, so the
    random flip/translate/rotate/contrast ops process far fewer pixels than at
    224x224. The ops are resolution-relative (fractions / factors), so the
    augmentation distribution is unchanged.
    `training=True` is passed explicitly so the random layers stay active no
    matter what Keras call context this map function runs under.

    Args:
        x: image tensor. resnet_v2.preprocess_input expects pixel values in [0, 255].
        y: label, passed through unchanged.

    Returns:
        Tuple of (preprocessed image, y).
    """
    x = data_augment(x, training=True)
    x = resize(x)
    x = tf.keras.applications.resnet_v2.preprocess_input(x)
    return x, y

def preprocess_val(x, y):
    """Deterministic evaluation preprocessing: resize + ResNetV2 scaling, no augmentation.

    Args:
        x: image tensor.
        y: label, returned unchanged.

    Returns:
        Tuple of (preprocessed image, y).
    """
    scaled = tf.keras.applications.resnet_v2.preprocess_input(resize(x))
    return scaled, y

def _make_pipeline(images, labels, map_fn, shuffle_buffer=None):
    """Slice in-memory arrays into a mapped, batched, prefetched tf.data pipeline.

    Shuffling (when `shuffle_buffer` is given) happens before the map, so only
    raw examples sit in the shuffle buffer.
    """
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.map(map_fn, num_parallel_calls=AUTOTUNE)
    return ds.batch(128).prefetch(AUTOTUNE)


train_ds = _make_pipeline(x_train, y_train, preprocess_train, shuffle_buffer=10000)
val_ds = _make_pipeline(x_val, y_val, preprocess_val)

# Build the model: ImageNet-pretrained ResNet50V2 backbone + small classification head.
inp = tf.keras.Input(shape=(224, 224, 3))
backbone = tf.keras.applications.ResNet50V2(include_top=False, weights="imagenet", pooling="avg")
# training=False keeps the backbone's BatchNormalization layers in inference mode,
# even after layers are unfrozen for fine-tuning, so the pretrained moving
# statistics are not overwritten. Gradients still flow to any unfrozen weights.
x = backbone(inp, training=False)
x = tf.keras.layers.Dense(512, activation="relu")(x)
x = tf.keras.layers.Dropout(0.5)(x)
out = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs=inp, outputs=out)

# Phase 1: train only the classifier head on top of a frozen ResNet50V2.
# Find the backbone by type rather than by position: it is the only nested Model
# among the layers (Input, ResNet50V2, Dense, Dropout, Dense).
backbone_layer = next(layer for layer in model.layers if isinstance(layer, tf.keras.Model))
backbone_layer.trainable = False

model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

history_frozen = model.fit(train_ds, validation_data=val_ds, epochs=10)
model.save("frozen_model.h5")  # Save after phase 1; the fine-tuning cell reloads this file.

# Plot Phase 1
fig, ax = plt.subplots()
ax.plot(history_frozen.history["accuracy"], label="Train Acc (Frozen)")
ax.plot(history_frozen.history["val_accuracy"], label="Val Acc (Frozen)")
ax.set(title="Phase 1: frozen ResNet50V2 backbone", xlabel="Epochs", ylabel="Accuracy")
ax.legend()
plt.show()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94668760/94668760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step
Epoch 1/10
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m546s[0m 1s/step - accuracy: 0.1815 - loss: 2.2224 - val_accuracy: 0.3186 - val_loss: 1.8920
Epoch 2/10
[1m344/352[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m11s[0m 1s/step - accuracy: 0.2799 - loss: 1.9550

In [None]:
# Phase 2: Fine-tuning

# Reload the phase-1 checkpoint and unfreeze the top 50 backbone layers.
model = tf.keras.models.load_model("frozen_model.h5")
# The backbone is the only nested Model inside `model`.
base_model = next(layer for layer in model.layers if isinstance(layer, tf.keras.Model))

# The backbone was saved with trainable=False. In tf.keras 2.x a container whose own
# `trainable` is False reports no trainable weights, even if individual sublayers
# are flipped back to True. Unfreeze the whole backbone first (this recurses into
# all sublayers), then re-freeze everything below the top 50 layers.
base_model.trainable = True
for layer in base_model.layers[:-50]:
    layer.trainable = False
# Keep BatchNormalization frozen: a BN layer with trainable=False runs in inference
# mode, so its pretrained moving mean/variance are preserved during fine-tuning.
for layer in base_model.layers[-50:]:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Sanity check: fine-tuning is a no-op if the backbone exposes nothing to train.
assert base_model.trainable_weights, "backbone has no trainable weights after unfreezing"

# Re-compile (required for `trainable` changes to take effect) with a lower learning rate.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Callbacks: both monitor their default metric, val_loss.
# Stop after 5 stagnant epochs (rolling back to the best weights) and halve the
# learning rate after 3 stagnant epochs.
early_stop = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3)
callbacks = [early_stop, lr_on_plateau]

# Fine-tune
history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=callbacks,
)

# Plot Phase 2 (the epoch axis restarts at 0 for the fine-tuning run).
fig, ax = plt.subplots()
ax.plot(history_fine.history["accuracy"], label="Train Acc (Fine-Tuned)")
ax.plot(history_fine.history["val_accuracy"], label="Val Acc (Fine-Tuned)")
ax.set(title="Phase 2: fine-tuning top ResNet50V2 layers", xlabel="Epochs", ylabel="Accuracy")
ax.legend()
plt.show()

# Evaluate on the held-out test set with the same deterministic preprocessing as validation.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess_val, num_parallel_calls=AUTOTUNE)  # parallel map, matching val_ds
    .batch(128)
    .prefetch(AUTOTUNE)
)
loss, acc = model.evaluate(test_ds)
print(f"Test Accuracy: {acc:.4f}")