In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2

# Enable mixed precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Hyperparameters
num_epochs = 5
learning_rate = 0.001
batch_size = 32
resize_to = (160, 160)

# Load CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Build tf.data.Dataset
def preprocess(image, label):
    image = tf.image.resize(image, resize_to)
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(5000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Build model
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(160,160,3))
base_model.trainable = False  # Freeze

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax', dtype='float32')  # Output back to float32
])

# Compile
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train
history = model.fit(train_ds, epochs=num_epochs, validation_data=test_ds)

# Evaluate
test_loss, test_acc = model.evaluate(test_ds, verbose=2)
print(f'Test Accuracy: {test_acc * 100:.2f}%')


Epoch 1/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 26ms/step - accuracy: 0.6727 - loss: 0.9676 - val_accuracy: 0.8000 - val_loss: 0.5663
Epoch 2/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 17ms/step - accuracy: 0.7796 - loss: 0.6465 - val_accuracy: 0.8140 - val_loss: 0.5369
Epoch 3/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 17ms/step - accuracy: 0.7947 - loss: 0.5942 - val_accuracy: 0.8148 - val_loss: 0.5368
Epoch 4/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 19ms/step - accuracy: 0.8042 - loss: 0.5651 - val_accuracy: 0.8232 - val_loss: 0.5146
Epoch 5/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 17ms/step - accuracy: 0.8126 - loss: 0.5434 - val_accuracy: 0.8197 - val_loss: 0.5229
313/313 - 3s - 11ms/step - accuracy: 0.8197 - loss: 0.5229
Test Accuracy: 81.97%
