In [16]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import warnings
# Silence only deprecation-style noise from the TF/Keras stack. A blanket
# warnings.filterwarnings('ignore') also hides UserWarnings, e.g. the one
# ImageDataGenerator.flow emits when the channel axis of the input array does
# not hold 1, 3 or 4 channels -- exactly the kind of data-format mistake we
# want to see rather than discover as a crash during training.
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

In [6]:
# Keras returns MNIST as two (images, labels) tuples: one train split, one test split.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split
# The last expression is the cell output: the image-array shapes of both splits.
x_train.shape, x_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


((60000, 28, 28), (10000, 28, 28))

In [9]:
# The arrays from mnist.load_data() are already (N, 28, 28), so reshaping them
# to (N, 28, 28) was a no-op. State that shape as an explicit check instead
# (it also fails fast if this cell is re-run after the stacking cell), then
# cast to float32 for the augmentation / preprocessing pipeline.
assert x_train.shape[1:] == (28, 28) and x_test.shape[1:] == (28, 28), 'expected raw (N, 28, 28) MNIST images'
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

In [17]:
# Turn the grayscale digits into 3-channel, channels-last images that an
# ImageNet ResNet50 can consume: (N, 28, 28) -> (N, 32, 32, 3).
# - axis=-1 puts the channel axis last, the default layout for both
#   ImageDataGenerator and ResNet50 (axis=1 would give channels-first
#   (N, 3, 28, 28), which the channels-last pipeline reads as 28 channels).
# - Keras documents 32x32 as ResNet50's minimum input size, so add 2 px of
#   zero padding on each side of the 28x28 digits.
# This cell overwrites x_*/y_* and is not idempotent: fail fast on a re-run
# instead of silently stacking / one-hot encoding twice.
assert x_train.ndim == 3 and x_test.ndim == 3 and y_train.ndim == 1 and y_test.ndim == 1, \
    'Re-run from the data-loading cell first.'
PAD_TO_32 = ((0, 0), (2, 2), (2, 2), (0, 0))  # (batch, height, width, channels)
x_train = np.pad(np.stack((x_train,) * 3, axis=-1), PAD_TO_32)
x_test = np.pad(np.stack((x_test,) * 3, axis=-1), PAD_TO_32)

# One-hot encode the integer labels; these must be paired with
# categorical_crossentropy (the sparse variant expects integer class ids).
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

In [12]:
# ResNet50's ImageNet weights were trained on inputs in the 0-255 range passed
# through resnet50.preprocess_input (RGB->BGR plus per-channel ImageNet mean
# subtraction), not on 0-1 rescaled pixels. Feeding the frozen backbone a
# different input scale degrades the features the Dense head learns from.
preprocess = keras.applications.resnet50.preprocess_input
BATCH_SIZE = 512

# Training data: random rotation/shear/zoom augmentation, then preprocessing.
train_generator = ImageDataGenerator(
    preprocessing_function=preprocess,
    rotation_range=40,
    shear_range=0.2,
    zoom_range=0.2,
    fill_mode='nearest',
)
# Validation data: same preprocessing, no augmentation.
val_generator = ImageDataGenerator(preprocessing_function=preprocess)

train_iterator = train_generator.flow(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)
# shuffle=False keeps the validation batch order stable across epochs.
val_iterator = val_generator.flow(x_test, y_test, batch_size=BATCH_SIZE, shuffle=False)

In [14]:
from tensorflow.keras.applications.resnet50 import ResNet50
from keras.models import Sequential
from keras.layers import Dense

In [19]:
# Frozen ImageNet ResNet50 feature extractor + small trainable classifier head.
# All layers come from tensorflow.keras (the `keras` name bound in the imports
# cell). Mixing standalone `keras` Sequential/Dense with a tensorflow.keras
# ResNet50 can fail when the two packages are different Keras versions.
backbone = ResNet50(include_top=False, pooling="avg", weights="imagenet")
# Freeze the pretrained weights so only the Dense head is trained.
backbone.trainable = False

model = keras.Sequential([
    backbone,                                      # -> (batch, 2048) pooled features
    keras.layers.Dense(512, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),  # one probability per digit class
])
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 dense_2 (Dense)             (None, 512)               1049088   
                                                                 
 dense_3 (Dense)             (None, 10)                5130      
                                                                 
Total params: 24,641,930
Trainable params: 1,054,218
Non-trainable params: 23,587,712
_________________________________________________________________


In [21]:
# Labels were one-hot encoded with to_categorical, so use dense categorical
# cross-entropy; sparse_categorical_crossentropy expects integer class ids
# and rejects (N, 10) one-hot targets.
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Train on the augmented iterator, validating on the held-out test split each epoch.
# Keep the returned History so per-epoch loss/accuracy can be inspected or
# plotted later instead of being discarded as a cell repr.
EPOCHS = 3
history = model.fit(train_iterator, epochs=EPOCHS, validation_data=val_iterator)