In [14]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import Input, Flatten, Dense, UpSampling2D, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.losses import binary_crossentropy
from keras.activations import relu, softmax
from keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input,decode_predictions
from keras.datasets import mnist

In [15]:
#I am using mnist digit dataset. This is a grey-scale image. Reason to use this dataset is to learn to process the
#gray-scale data for pre-trained model.
# Dataset detail: https://keras.io/api/datasets/mnist/ 
(X_train,y_train), (X_val,y_val) = mnist.load_data()
X_train.shape

(60000, 28, 28)

In [16]:
def transfer_learning():
    input_image = Input(shape=(28,28,3))
    
    resize = UpSampling2D(size=(8,8))(input_image) 
    #ResNet50 was trained on (224,224,3) image originally, so to make the balance, upsampling the mnist data.
    
    pretrained_model = ResNet50(input_shape=(224,224,3),
                                include_top=True,
                                weights='imagenet'
                               )(resize)
    
    pretrained_model.trainable = False
    
#     downstream = GlobalAveragePooling2D()(pretrained_model)
    flatten = Flatten()(pretrained_model)
    dense_1 = Dense(1024, activation=relu)(flatten)
    dense_2 = Dense(1024, activation=relu)(dense_1)
    output = Dense(10, activation=softmax)(dense_2)
    
    model = Model(input_image, output)
    
    return model
    
    

In [17]:
model = transfer_learning()

In [21]:
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
def data_augmentation(images):
    
    return images

In [19]:
def preprocess_img(images):
    images = images.astype('float32')
    print(images.shape)
    images = images.reshape((images.shape[0], images.shape[1], images.shape[2], 1))
    images = images.repeat(3, -1)     
    print(images.shape)    
    return images
X_train = preprocess_img(X_train)
X_val = preprocess_img(X_val)

(60000, 28, 28)
(60000, 28, 28, 3)
(10000, 28, 28)
(10000, 28, 28, 3)


In [22]:
history = model.fit(X_train, y_train,epochs=5, verbose=1, validation_data=(X_val, y_val))

Epoch 1/5


2022-10-10 18:44:40.981785: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 564480000 exceeds 10% of free system memory.


  26/1875 [..............................] - ETA: 5:59:26 - loss: 2.0285 - accuracy: 0.2728

KeyboardInterrupt: 

In [None]:
def pre_trained_model(image_data):
    '''
    Feature Extraction part
    '''
    ptr = ResNet50(input_shape=(224,224,3),
                   include_top=False, 
                   weights='imagenet')(image_data)
    
    ptr.trainable = False
    return ptr

In [None]:
def classifier(ptr):
    flatten = Flatten()(ptr)
    dense_1 = Dense(1024, activation=relu)(flatten)
    output = Dense(10, activation=softmax)(dense_1)
    
    return output

In [None]:
def final_model(input_images):
    resizing = UpSampling2D(size=(8,8))(input_images)
    pretrained = pre_trained_model(resizing)
    classifier_output = classifier(pretrained)
    
    return classifier_output

In [None]:
def model_compile():
    input_img = Input(shape=(28,28,1))
    
    img_conc = tf.keras.layers.Concatenate()([input_img, input_img, input_img])    
    classifier_output = final_model(input_img)
    model = Model(inputs=input_img, outputs=classifier_output)
    
    model.compile(optimizer=Adam(), loss=binary_crossentropy, metrics=['accuracy'])
    
    return model

model = model_compile()

In [None]:
def preprocess_image(input_images):
    input_images = input_images.astype('float32')
    output_ims = tf.keras.applications.resnet50.preprocess_input(input_images)
    return output_ims

train_X = preprocess_image(X_train)
valid_X = preprocess_image(X_test)

In [None]:
# this will take around 20 minutes to complete
EPOCHS = 4
history = model.fit(train_X, training_labels, epochs=EPOCHS, validation_data = (valid_X, validation_labels), batch_size=64)