In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
import cv2

### Load and prepare data

In [None]:
# Fetch both MNIST splits through Keras, then unpack each into image and label arrays.
mnist_train, mnist_test = mnist.load_data()
training_images, training_labels = mnist_train
test_images, test_labels = mnist_test

In [None]:
# Prepare MNIST for the ImageNet-pretrained VGG16 used below. Both splits go through
# the same pipeline, so it lives in one helper instead of copy-pasted lines.
def to_vgg16_input(images, size=(32, 32)):
    """Turn a stack of grayscale images into VGG16-ready 3-channel inputs.

    Parameters
    ----------
    images : np.ndarray
        Grayscale images stacked along axis 0. Assumed to hold raw 0-255 pixel
        values (uint8, as mnist.load_data is documented to return) -- TODO confirm
        if another image source is ever used.
    size : tuple of int
        Target (width, height) for cv2.resize; must match the VGG16 input_shape.

    Returns
    -------
    np.ndarray
        float32 array of shape (N, height, width, 3) transformed with VGG16's own
        preprocess_input (RGB->BGR plus ImageNet channel-mean subtraction), i.e. the
        input convention the pretrained weights expect. Plain [0, 1] scaling would
        feed the frozen convolutional base out-of-distribution values.
    """
    resized = np.array([cv2.resize(img, size) for img in images])
    # Replicate the single gray channel into three identical channels.
    rgb = np.repeat(resized[..., np.newaxis], 3, axis=-1)
    # preprocess_input expects 0-255 values and may overwrite float inputs in place;
    # handing it the integer array yields a fresh float32 array instead. float32 also
    # halves memory compared with the float64 that dividing by 255.0 would produce.
    return tf.keras.applications.vgg16.preprocess_input(rgb)


training_images_rgb = to_vgg16_input(training_images)
test_images_rgb = to_vgg16_input(test_images)

# Cheap invariants: exactly one 32x32x3 tensor per source image.
assert training_images_rgb.shape == (len(training_images), 32, 32, 3)
assert test_images_rgb.shape == (len(test_images), 32, 32, 3)

### Build the model

In [None]:
# One-hot encode the integer class labels for categorical_crossentropy.
# The ndim guard keeps this cell idempotent: without it, re-running the cell would
# encode the already one-hot labels again ((N, 10) -> (N, 10, 10)) and break fit.
NUM_CLASSES = 10
if training_labels.ndim == 1:
    training_labels = to_categorical(training_labels, num_classes=NUM_CLASSES)
if test_labels.ndim == 1:
    test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

# Load the ImageNet-pretrained VGG16 convolutional base without its classification
# head; input_shape matches the 32x32, 3-channel images built in the data-prep cell.
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))


In [None]:
# Seed Python, NumPy and TensorFlow so the new head's weight initialisation and
# fit's shuffling repeat across Restart & Run All (GPU kernels may still add
# some nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Freeze the pretrained convolutional base. Setting `trainable` on the nested model
# propagates to every layer inside it, so only the new Dense head is trained.
base_model.trainable = False

# Stack a small classification head on top of the frozen base.
model = Sequential([
    base_model,
    Flatten(),
    Dense(256, activation='relu'),
    Dense(10, activation='softmax')  # Output layer with 10 classes
])

# Labels are one-hot encoded, hence categorical (not sparse) cross-entropy.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Train the classification head. Validation uses the last VALIDATION_SPLIT fraction
# of the training data, so the test split stays untouched until the final
# evaluation cell. Monitoring on the test set would leak it into model selection
# (e.g. choosing EPOCHS) and inflate the reported test accuracy.
EPOCHS = 10
BATCH_SIZE = 64
VALIDATION_SPLIT = 0.1
history = model.fit(
    training_images_rgb,
    training_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

### Evaluate the model

In [None]:
# Score the trained model on the MNIST test split. model.evaluate returns the loss
# followed by each compiled metric; here the only metric is accuracy.
test_loss, test_acc = model.evaluate(x=test_images_rgb, y=test_labels)
print("Test Accuracy:", test_acc)