In [None]:
# ---------------- Cell 1: Imports & reproducibility ----------------
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

# Single seed shared by every random number generator used in this notebook.
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in turn so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

print("TensorFlow:", tf.__version__)


In [None]:
# ---------------- Cell 2: Configure dataset paths and params ----------------
# --- Backbone selection ---
# One of "mobilenet", "vgg16" or "resnet50".
BACKBONE = "mobilenet"

# --- Input pipeline parameters ---
# 224x224 is the usual input size for pretrained nets; use (32,32) for small nets.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
# Placeholder only: the data-loading cell assigns the real class count.
NUM_CLASSES = None

# --- Dataset locations (replace with your actual directories) ---
# TRAIN_DIR should contain one subfolder per class; VAL_DIR and TEST_DIR are optional.
TRAIN_DIR = "dataset/train/"
VAL_DIR = "dataset/val/"
TEST_DIR = "dataset/test/"


In [None]:
# ------------------ Safe CIFAR-10 fallback using tf.data (replacement) ------------------
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
AUTOTUNE = tf.data.AUTOTUNE

# Fall back to sensible defaults if the configuration cell was not run first.
if "IMG_SIZE" not in globals():
    IMG_SIZE = (224, 224)   # reduce to (128,128) or (96,96) if memory is still tight
if "BATCH_SIZE" not in globals():
    BATCH_SIZE = 32
if "SEED" not in globals():
    SEED = 42

print("Using CIFAR-10 fallback. Lazy resizing via tf.data to avoid OOM.")
(x_train_all, y_train_all), (x_test_all, y_test_all) = (
    tf.keras.datasets.cifar10.load_data()
)

# Flatten both label arrays to 1-D.
y_train_all, y_test_all = y_train_all.flatten(), y_test_all.flatten()

# Hold out 10% of the CIFAR training set as a stratified validation split.
X_tr, X_val, y_tr, y_val = train_test_split(
    x_train_all,
    y_train_all,
    test_size=0.1,
    stratify=y_train_all,
    random_state=SEED,
)

# Class name -> integer label, listed in label order.
class_indices = {
    class_name: class_id
    for class_id, class_name in enumerate(
        [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck',
        ]
    )
}
NUM_CLASSES = len(class_indices)

# Preprocessing functions (applied per-sample)
def preprocess_train(image, label):
    """Prepare one training sample: scale, resize, augment and one-hot encode.

    Applied per element through `Dataset.map`, so only one batch of resized
    images exists at a time instead of the whole resized dataset.

    Args:
        image: a single image tensor with 3 channels and pixel values in 0-255.
        label: integer class id for the image.

    Returns:
        (image, label): a float32 image of shape (IMG_SIZE[0], IMG_SIZE[1], 3)
        with values in [0, 1], and a one-hot vector of length NUM_CLASSES.
    """
    # convert to float [0,1]
    image = tf.cast(image, tf.float32) / 255.0
    # resize (done per image to avoid creating huge tensor at once)
    image = tf.image.resize(image, IMG_SIZE)
    # simple augmentation (random flip + slight random brightness)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.08)
    # random_brightness adds an offset without clipping, so bring values back
    # into [0, 1] to keep the same range as the validation pipeline.
    image = tf.clip_by_value(image, 0.0, 1.0)
    # ensure static shape is known to downstream layers
    image.set_shape([IMG_SIZE[0], IMG_SIZE[1], 3])
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

def preprocess_val(image, label):
    """Prepare one validation sample: scale to [0, 1], resize, one-hot encode.

    No augmentation is applied, so the output is deterministic for a given input.

    Returns:
        (image, label): a float32 image of shape (IMG_SIZE[0], IMG_SIZE[1], 3)
        and a one-hot vector of length NUM_CLASSES.
    """
    target_height, target_width = IMG_SIZE[0], IMG_SIZE[1]
    scaled = tf.cast(image, tf.float32) / 255.0
    resized = tf.image.resize(scaled, IMG_SIZE)
    # Pin the static shape so downstream layers know it.
    resized.set_shape([target_height, target_width, 3])
    return resized, tf.one_hot(label, NUM_CLASSES)

# Build the input pipelines. Resizing happens lazily inside `map`, so the full
# set of resized images is never materialised at once.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_tr, y_tr))
    .shuffle(buffer_size=5000, seed=SEED)
    .map(preprocess_train, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation data is neither shuffled nor augmented.
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .map(preprocess_val, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

print("train_ds element spec:", train_ds.element_spec)
print("val_ds element spec  :", val_ds.element_spec)
print("NUM_CLASSES:", NUM_CLASSES)


In [None]:
# ---------------- Cell 4: Load pretrained backbone (without top) ----------------
from tensorflow.keras import layers, models

# Map each supported BACKBONE value to its Keras application constructor.
_BACKBONE_FACTORIES = {
    "mobilenet": keras.applications.MobileNetV2,
    "resnet50": keras.applications.ResNet50,
    "vgg16": keras.applications.VGG16,
}

if BACKBONE not in _BACKBONE_FACTORIES:
    raise ValueError("BACKBONE must be one of: mobilenet, resnet50, vgg16")

# Build the convolutional base with ImageNet weights and no classification top.
base_model = _BACKBONE_FACTORIES[BACKBONE](
    input_shape=(*IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)

# Show the backbone architecture.
base_model.summary()


In [None]:
# ---------------- Cell 5: Freeze lower layers (all for initial training) ----------------
# Stage 1: mark every backbone layer as non-trainable so only the new
# classifier head is updated during the first training phase.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

trainable_count = sum(
    1 for backbone_layer in base_model.layers if backbone_layer.trainable
)
print("Backbone trainable layers after freeze:", trainable_count)


In [None]:
# ---------------- Cell 6: Add custom classifier head ----------------
# Typical head: GlobalAveragePooling -> Dense -> Dropout -> Dense(num_classes, softmax)
inputs = keras.Input(shape=(*IMG_SIZE, 3))

# The data pipeline in this notebook yields pixels scaled to [0, 1], but each
# ImageNet backbone was pretrained with its own input convention. Convert inside
# the model so the pipeline (and plotting code that relies on [0, 1]) is unchanged.
# NOTE(review): assumes model inputs are in [0, 1]; confirm if a different
# data source (e.g. a custom generator) is plugged in.
if BACKBONE == "mobilenet":
    # MobileNetV2 expects inputs in [-1, 1]: x * 2 - 1.
    x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
elif BACKBONE == "resnet50":
    # ResNet50 preprocessing works on 0-255 RGB (converts to BGR, subtracts ImageNet means).
    x = layers.Rescaling(scale=255.0)(inputs)
    x = keras.applications.resnet50.preprocess_input(x)
elif BACKBONE == "vgg16":
    # VGG16 uses the same 0-255 based convention as ResNet50.
    x = layers.Rescaling(scale=255.0)(inputs)
    x = keras.applications.vgg16.preprocess_input(x)
else:
    # Unknown backbone: leave inputs untouched rather than guess a convention.
    x = inputs

x = base_model(x, training=False)   # pass training=False so BN layers run in inference mode while frozen
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(512, activation='relu')(x)
x = layers.Dropout(0.4)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

model = models.Model(inputs, outputs, name=f"{BACKBONE}_transfer")
model.summary()


In [None]:
# ---------------- Cell 7: Compile model (train only head) ----------------
# Head-only training stage: SGD with momentum (Adam would also work).
initial_lr = 1e-3
optimizer = keras.optimizers.SGD(momentum=0.9, learning_rate=initial_lr)

# Labels are one-hot encoded, hence categorical cross-entropy.
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=optimizer,
)


In [None]:
# ---------------- Cell 8: Callbacks ----------------
# Keep only the checkpoint with the lowest validation loss.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath="best_transfer_model.keras",
    monitor="val_loss",
    save_best_only=True,
)
# Stop after 5 epochs without val_loss improvement and restore the best weights.
earlystop_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)
# Halve the learning rate after 3 epochs without val_loss improvement.
reduce_lr_cb = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    patience=3,
    factor=0.5,
)

callbacks = [checkpoint_cb, earlystop_cb, reduce_lr_cb]


In [None]:
# ---------------- Fixed Cell 9 — Train classifier head ----------------
# Number of epochs for the frozen-backbone (head-only) stage.
EPOCHS_HEAD = 1

# Use whichever data source has been defined; tf.data datasets take priority
# over ImageDataGenerator generators.
if "train_ds" in locals():
    train_data, val_data = train_ds, val_ds
    data_source_msg = "Training on tf.data datasets (CIFAR-10 fallback)."
elif "train_gen" in locals():
    train_data, val_data = train_gen, val_gen
    data_source_msg = "Training on ImageDataGenerator generators (custom dataset)."
else:
    raise NameError("No training dataset found. Define train_ds/val_ds or train_gen/val_gen first.")
print(data_source_msg)

# Train the classifier head; the backbone layers are frozen at this point.
history_head = model.fit(
    train_data,
    epochs=EPOCHS_HEAD,
    validation_data=val_data,
    callbacks=callbacks,
    verbose=2,
)


In [None]:
# ---------------- Cell 10: Evaluate after head training ----------------
# Evaluate on the validation source chosen in the training cell (`val_data`),
# so this also works when generators are used instead of the tf.data datasets.
val_loss, val_acc = model.evaluate(val_data, verbose=0)
print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}")



In [None]:
# ---------------- Cell 11: Fine-tune - unfreeze top layers of backbone ----------------
# Fine-tuning strategy: keep the lower part of the backbone frozen and unfreeze
# only the layers closest to the output (here the last 20% of layers).
# Alternatives: unfreeze a fixed number of layers, or select layers by name.

pct_unfreeze = 0.2
total_layers = len(base_model.layers)
num_to_unfreeze = int(total_layers * pct_unfreeze)
# Index of the first layer that becomes trainable.
cutoff = total_layers - num_to_unfreeze

for frozen_layer in base_model.layers[:cutoff]:
    frozen_layer.trainable = False
for tuned_layer in base_model.layers[cutoff:]:
    tuned_layer.trainable = True

print("Total backbone layers:", total_layers)
print(
    "Unfrozen layers (trainable):",
    sum(1 for backbone_layer in base_model.layers if backbone_layer.trainable),
)

# Changing `trainable` requires a re-compile; use a lower learning rate for fine-tuning.
fine_tune_lr = 1e-4
optimizer_finetune = keras.optimizers.SGD(momentum=0.9, learning_rate=fine_tune_lr)
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=optimizer_finetune,
)


In [None]:
# ---------------- Cell 12: Continue training (fine-tuning) ----------------
# Number of epochs for the fine-tuning stage (top backbone layers unfrozen).
EPOCHS_FINETUNE = 2

# Reuse the data sources selected in the head-training cell (`train_data` /
# `val_data`) so fine-tuning works for both tf.data datasets and generators.
history_finetune = model.fit(
    train_data,
    validation_data=val_data,
    epochs=EPOCHS_FINETUNE,
    callbacks=callbacks,
    verbose=2
)


In [None]:
# ---------------- Cell 13: Final evaluation on validation / test set ----------------
# Final evaluation, in order of preference:
#   1. a directory-based test set (needs TEST_DIR on disk AND a `val_datagen`
#      defined by a custom-dataset cell; this notebook does not define one itself),
#   2. the held-out CIFAR-10 test split loaded alongside the training data,
#   3. the validation data, as a last resort.
if os.path.isdir(TEST_DIR) and "val_datagen" in globals():
    test_gen = val_datagen.flow_from_directory(
        TEST_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False,
    )
else:
    test_gen = None

if test_gen is not None:
    test_loss, test_acc = model.evaluate(test_gen, verbose=0)
    print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")
elif "x_test_all" in globals() and "y_test_all" in globals():
    # Same deterministic preprocessing as validation (no augmentation, no shuffle).
    test_ds = (
        tf.data.Dataset.from_tensor_slices((x_test_all, y_test_all))
        .map(preprocess_val, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )
    test_loss, test_acc = model.evaluate(test_ds, verbose=0)
    print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")
else:
    val_loss, val_acc = model.evaluate(val_data, verbose=0)
    print(f"Validation loss (final): {val_loss:.4f}, Validation accuracy (final): {val_acc:.4f}")


In [None]:
# ---------------- Cell 14: Visualize some predictions ----------------
# Take the first batch from the validation dataset.
val_iter = iter(val_ds)
x_batch, y_batch = next(val_iter)

# Predict and convert probabilities / one-hot labels to class indices.
preds = model.predict(x_batch)
pred_labels = np.argmax(preds, axis=1)
true_labels = np.argmax(y_batch, axis=1)

# Build the index -> name lookup once, ordered by the index values themselves
# rather than by dict insertion order.
class_names = [
    name for name, _ in sorted(class_indices.items(), key=lambda item: item[1])
]

# Plot up to the first 12 images with their true and predicted labels.
# (matplotlib.pyplot is already imported as `plt` in the first cell.)
plt.figure(figsize=(12,8))
for i in range(min(12, x_batch.shape[0])):
    plt.subplot(3,4,i+1)
    plt.imshow(x_batch[i])
    plt.title(f"True: {class_names[true_labels[i]]}\nPred: {class_names[pred_labels[i]]}")
    plt.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# ---------------- Cell 15: Save final model and tips ----------------
# Persist the fine-tuned model in the native Keras format.
final_model_path = "transfer_finetuned_model.keras"
model.save(final_model_path)
print(f"Saved model to {final_model_path}")

# Tips:
# - If validation accuracy is poor: try stronger augmentation, more training epochs,
#   or use a deeper backbone (ResNet50 / EfficientNet) and larger IMG_SIZE.
# - For very large datasets, use tf.data pipeline for performance instead of ImageDataGenerator.
# - If using VGG16 with small inputs: Keras requires at least 32x32, but the ImageNet weights were trained at 224x224, so prefer 224x224 for VGG.
