In [None]:
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np

# Transfer-learning baseline: a frozen ImageNet backbone plus a small dense
# head, trained on CIFAR-10 and evaluated once on the held-out test set.

# Choose a pre-trained model (change here for different models).
BASE_MODEL = 'VGG16'  # Options: 'VGG16', 'ResNet50', 'MobileNetV2'

NUM_CLASSES = 10            # CIFAR-10 has 10 classes
INPUT_SHAPE = (32, 32, 3)   # native CIFAR-10 resolution, channels last
EPOCHS = 5
BATCH_SIZE = 32
# Fraction of the *training* set Keras holds out for per-epoch validation.
# The test set is kept out of training entirely so the final score is unbiased.
VALIDATION_SPLIT = 0.1

# Each backbone is paired with the preprocessing its ImageNet weights expect:
# VGG16/ResNet50 convert RGB->BGR and zero-center each channel on 0-255 pixels,
# while MobileNetV2 scales pixels to [-1, 1]. Feeding plain [0, 1] pixels to a
# frozen backbone mismatches the distribution its features were trained on.
_BACKBONES = {
    'VGG16': (VGG16, tf.keras.applications.vgg16.preprocess_input),
    'ResNet50': (ResNet50, tf.keras.applications.resnet50.preprocess_input),
    'MobileNetV2': (MobileNetV2, tf.keras.applications.mobilenet_v2.preprocess_input),
}

# Validate the choice before downloading any data or weights.
if BASE_MODEL not in _BACKBONES:
    raise ValueError("Invalid BASE_MODEL. Choose from 'VGG16', 'ResNet50', or 'MobileNetV2'")
backbone_cls, preprocess_input = _BACKBONES[BASE_MODEL]

# Load the dataset (uint8 pixels in 0-255, integer class labels).
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Cast to float32 (half the memory of the float64 arrays that `/ 255.0` would
# produce) and apply the backbone's own preprocessing to the raw 0-255 values.
# `.astype` returns a fresh copy, so preprocess_input may safely modify it in place.
train_images = preprocess_input(train_images.astype('float32'))
test_images = preprocess_input(test_images.astype('float32'))

# One-hot encode labels to match the categorical_crossentropy loss below.
train_labels = to_categorical(train_labels, NUM_CLASSES)
test_labels = to_categorical(test_labels, NUM_CLASSES)

# Load the selected backbone without its ImageNet classifier head.
base_model = backbone_cls(weights='imagenet', include_top=False, input_shape=INPUT_SHAPE)

# Freeze the entire backbone so only the new head is trained. Set to True
# (typically together with a much lower learning rate) for fine-tuning.
base_model.trainable = False

# Custom classification head on top of the backbone's feature maps.
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(NUM_CLASSES, activation='softmax'),
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Validation comes from the tail of the training data (Keras takes the last
# VALIDATION_SPLIT fraction before shuffling), never from the test set.
history = model.fit(
    train_images,
    train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# Touch the test set exactly once, after training has finished.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc:.4f}")


Epoch 1/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 65s 41ms/step - accuracy: 0.4110 - loss: 1.6723 - val_accuracy: 0.5517 - val_loss: 1.2816
Epoch 2/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 63s 40ms/step - accuracy: 0.5421 - loss: 1.3075 - val_accuracy: 0.5695 - val_loss: 1.2145
Epoch 3/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 63s 40ms/step - accuracy: 0.5632 - loss: 1.2509 - val_accuracy: 0.5820 - val_loss: 1.1868
Epoch 4/5
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 63s 40ms/step - accuracy: 0.5763 - loss: 1.2083 - val_accuracy: 0.5862 - val_loss: 1.1754
Epoch 5/5
 991/1563 ━━━━━━━━━━━━━━━━━━━━ 19s 34ms/step - accuracy: 0.5831 - loss: 1.1946