In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load CIFAR-10 as (train, test) splits and unpack each into images and labels.
cifar_train, cifar_test = cifar10.load_data()
x_train, y_train = cifar_train
x_test, y_test = cifar_test
del cifar_train, cifar_test  # only the unpacked arrays are used from here on

In [None]:
# Human-readable CIFAR-10 label names; class_names[i] names label id i
# (used below to turn the predicted index into a name).
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# Cast to float32 before dividing: `uint8_array / 255.0` yields float64, which
# doubles the memory of the whole dataset, and Keras' default float type is
# float32 anyway.
# NOTE: this rebinds x_train/x_test, so re-running this cell on its own would
# divide by 255 a second time; re-run from the data-loading cell instead.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

In [None]:
# Upscale both splits to 96x96 so they match the input_shape given to
# MobileNetV2 in the next cell. tf.image.resize returns eager tensors, so the
# whole upscaled dataset is held in memory.
# NOTE(review): assuming 32x32 CIFAR-10 inputs, this stores 9x as many pixels;
# consider resizing per batch in a tf.data pipeline if RAM is tight.
x_train, x_test = [tf.image.resize(images, (96, 96)) for images in (x_train, x_test)]

In [None]:
# Load the MobileNetV2 backbone with ImageNet weights and without its
# classification top (include_top=False), so a custom head can be attached.
# input_shape matches the 96x96 resize applied to the data above.
# Freezing is done separately in the next cell.
base_model = MobileNetV2(
    input_shape=(96, 96, 3),
    include_top=False,
    weights='imagenet',
)

In [None]:
# Freeze the pretrained backbone so the first training run only updates the
# new classification head (it is unfrozen again later for fine-tuning).
base_model.trainable = False

In [None]:
# Add a custom classification head on top of the pretrained backbone.
#
# Built with the functional API rather than Sequential so the backbone can be
# called with training=False. That keeps its BatchNormalization layers in
# inference mode (frozen moving statistics) even after
# `base_model.trainable = True` is set for fine-tuning, so the pretrained
# features are not disturbed by per-batch statistics.
inputs = tf.keras.Input(shape=base_model.input_shape[1:])

# MobileNetV2's ImageNet weights expect pixels in [-1, 1], but the data was
# scaled to [0, 1] earlier. Map it here (x * 2 - 1). Doing this inside the
# model keeps the stored images in [0, 1], so plt.imshow still displays them.
features = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
features = base_model(features, training=False)
features = GlobalAveragePooling2D()(features)
features = Dense(128, activation='relu')(features)
outputs = Dense(10, activation='softmax')(features)  # 10 classes in CIFAR-10

model = Model(inputs, outputs)

In [None]:
# Compile for training the head: Adam with default settings, sparse
# categorical cross-entropy (labels are integer class ids, not one-hot), and
# accuracy as the reported metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# d) Train the classifier head only; the backbone is still frozen here.
# NOTE(review): the test split doubles as validation data, so the val_* numbers
# are the same data as the final test accuracy reported later.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=2,
)

In [None]:
# e) Fine-tune: unfreeze the whole backbone, then recompile with a much smaller
# learning rate (1e-5) so the pretrained weights are only nudged. The model is
# recompiled after changing `trainable`, as Keras expects.
# NOTE(review): if the model calls base_model in training mode (e.g. a plain
# Sequential wrapper), its BatchNormalization layers also start updating their
# statistics from here on; verify it is called with training=False.
base_model.trainable = True

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy'],
)

In [None]:
# Fine-tuning pass: one epoch with the backbone unfrozen and the low-LR
# optimizer compiled in the previous cell.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=1,
)

In [None]:
# Score the fine-tuned model on the test split. evaluate() yields the loss
# followed by the compiled metrics; 'accuracy' is the only metric.
loss, acc = model.evaluate(x=x_test, y=y_test)
print("Final Test Accuracy :", acc)

In [None]:
# Visual sanity check: show the first test image, then compare the model's
# prediction for it with its true label.
fig, ax = plt.subplots()
ax.imshow(np.asarray(x_test[0]))
ax.set_title("Sample Test Image")
ax.axis('off')
plt.show()

# Slicing with [:1] keeps the batch axis, so predict() gets shape
# (1, H, W, C) without hard-coding the image size.
img = np.asarray(x_test[:1])
pred = model.predict(img)
pred_class = int(np.argmax(pred))
true_class = int(np.squeeze(y_test[0]))

print("Predicted Class Index :", pred_class)
print("Predicted Class Name  :", class_names[pred_class])
print("True Class Name       :", class_names[true_class])
# Summarise the input instead of dumping every pixel value into the output.
print("Input shape / min / max:", img.shape, float(img.min()), float(img.max()))