In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Fetch the Stanford Dogs train/test splits as (image, label) pairs
# (as_supervised=True). The first download takes about 5-6 minutes in Google Colab.
train_data, test_data = tfds.load(
    name="stanford_dogs",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
)

In [2]:
# Make a function for preprocessing images
def preprocess_img(image, label, img_shape=(224, 224)):
    """Resize one image to ``img_shape`` and cast it to float32.

    Args:
        image: Image tensor to resize.
        label: Label for the image; returned unchanged.
        img_shape: Target (height, width) passed to ``tf.image.resize``.

    Returns:
        Tuple ``(float32 resized image, label)``.
    """
    resized = tf.image.resize(image, img_shape)
    return tf.cast(resized, tf.float32), label

In [3]:
# Build the tf.data input pipelines.
# Shuffle individual examples *before* batching. Shuffling after `.batch()`
# only reorders whole batches (examples are never re-mixed across batches),
# and a 1000-element buffer of 32-image batches can hold up to 32,000
# preprocessed images in memory. `prefetch` overlaps input preparation with
# training, and `num_parallel_calls` preprocesses several images in parallel.
# NOTE: this cell rebinds train_data/test_data, so re-run the loading cell
# before running it again.
AUTOTUNE = tf.data.AUTOTUNE

train_data = (
    train_data
    .map(map_func=preprocess_img, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(batch_size=32)
    .prefetch(buffer_size=AUTOTUNE)
)
test_data = (
    test_data
    .map(map_func=preprocess_img, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=32)
    .prefetch(buffer_size=AUTOTUNE)
)

In [4]:
class Custom_Model_Architecture(tf.keras.Model):
    """Frozen EfficientNetB0 feature extractor with a new softmax classifier head.

    The ImageNet-pretrained EfficientNetB0 backbone is frozen, so only the
    final Dense layer's weights are trainable.

    Args:
        num_classes: Number of output classes (default 120).
        image_shape: Shape of one input image as (height, width, channels).
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=120, image_shape=(224, 224, 3), **kwargs):
        super().__init__(**kwargs)

        # include_top=False drops EfficientNet's own ImageNet classification head.
        # Keras' EfficientNet rescales internally and expects raw [0, 255] pixels.
        base_model = tf.keras.applications.EfficientNetB0(include_top=False)
        base_model.trainable = False  # freeze base model layers

        self.complete_model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=image_shape, name="input_layer"),
            base_model,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(num_classes),
            # Softmax pinned to float32 so the output probabilities stay in
            # float32 even if a mixed-precision policy is enabled.
            tf.keras.layers.Activation("softmax", dtype=tf.float32),
        ])

    def call(self, input_batch, training=None):
        """Return per-class softmax probabilities for ``input_batch``.

        ``training`` is forwarded to the nested Sequential so Keras'
        train/inference mode reaches every sub-layer explicitly.
        """
        return self.complete_model(input_batch, training=training)

model = Custom_Model_Architecture()

# The labels are integer class ids rather than one-hot vectors,
# hence the *sparse* categorical cross-entropy loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)



In [5]:
# Fit the model. No callbacks are registered yet; add them to `callbacks`
# (e.g. tf.keras.callbacks.ModelCheckpoint) as needed.
# NOTE: the test split doubles as validation data here, so validation metrics
# are not an independent held-out estimate.
history = model.fit(train_data, epochs=3, validation_data=test_data, callbacks=[])

# NOTE(review): the directory name mentions "mixed_precision", but no
# mixed-precision policy is set in the visible cells - confirm or rename.
SAVE_DIR = "07_efficientnetb0_feature_extract_model_mixed_precision"
model.save(SAVE_DIR)
# To reload the saved model later:
# loaded_saved_model = tf.keras.models.load_model(SAVE_DIR)

Epoch 1/3