In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, datasets
from tensorflow.keras.applications import ResNet50

# Load MNIST through the Keras datasets API. load_data() yields two
# (images, labels) pairs: the training split first, then the test split.
(
    (train_images, train_labels),
    (test_images, test_labels),
) = datasets.mnist.load_data()

# Preprocess: resize to the ResNet50 input size (32x32 by default, the minimum
# Keras' ResNet50 accepts) and convert grayscale to 3 channels.
def preprocess(images, size=(32, 32)):
    """Turn a batch of grayscale images into ResNet50-ready RGB tensors.

    Args:
        images: Array-like of grayscale images WITHOUT a channel axis; one is
            appended here. Presumably shaped ``(N, H, W)`` with raw pixel
            values in ``[0, 255]`` (as MNIST's ``load_data`` returns) --
            confirm for other callers.
        size: Target ``(height, width)`` handed to ``tf.image.resize``.
            Defaults to ``(32, 32)`` to match the ResNet50 ``input_shape``
            used in this script.

    Returns:
        A ``float32`` tensor of shape ``(..., height, width, 3)`` transformed
        by ``tf.keras.applications.resnet50.preprocess_input``.
    """
    # tf.image.resize expects a trailing channel axis (..., H, W, C).
    images = tf.image.resize(images[..., tf.newaxis], size)
    # Replicate the single gray channel into R, G and B. This also asserts the
    # last dimension is 1, unlike a bare tf.repeat.
    images = tf.image.grayscale_to_rgb(images)
    # Values are still on the original pixel scale here; preprocess_input
    # applies the ImageNet normalisation the pretrained weights expect.
    return tf.keras.applications.resnet50.preprocess_input(images)

# Run both splits through the same preprocessing so train and test inputs
# share one format.
train_images, test_images = preprocess(train_images), preprocess(test_images)

# Backbone: ResNet50 with ImageNet weights and without its top classification
# layer. The input shape matches the 32x32x3 tensors built by preprocess().
base_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(32, 32, 3),
)
# Freeze every pretrained weight so only the new head below is trained.
base_model.trainable = False

# Classification head stacked on the frozen backbone: global average pooling
# collapses the feature map to one vector per image, then a 10-way softmax
# Dense layer produces probabilities for the 10 MNIST classes.
model = models.Sequential()
model.add(base_model)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(10, activation='softmax'))

# Compile and train.
# Validation now comes from a slice of the TRAINING data rather than the test
# set. Monitoring (or tuning against) test-set metrics during training would
# make the final test accuracy an optimistic, leaked estimate.
model.compile(
    optimizer='adam',
    # "sparse" variant: labels are integer class ids, not one-hot vectors.
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

history = model.fit(
    train_images,
    train_labels,
    epochs=5,
    batch_size=64,
    # Keras holds out the LAST 10% of the arrays (taken before shuffling)
    # as validation data and trains on the remaining 90%.
    validation_split=0.1,
)

# Score the trained model on the test split. evaluate() returns the loss
# followed by the single compiled metric (accuracy).
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)
print(f"\nTest Accuracy: {test_acc:.4f}")