In [9]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
import PIL.Image as Image

In [10]:
# Spatial size (height, width) every image is resized to; the MobileNet base
# below is built for this size with 3 colour channels appended.
image_shape = (224,) * 2

In [11]:
# Download the flower photos archive into the Keras cache and extract it;
# get_file returns the local path that flow_from_directory reads from below.
data_root = tf.keras.utils.get_file(
    'flower_photos', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True
)
# Stream images from the class-per-subfolder directory tree, resized to image_shape.
# The ImageNet MobileNet weights expect inputs prepared by MobileNet's own
# preprocess_input (pixels scaled to [-1, 1]); a plain rescale=1/255 would give
# [0, 1] and feed the pretrained convolutions mis-scaled data.
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet.preprocess_input
)
img_data = img_generator.flow_from_directory(str(data_root), target_size=image_shape)

# Pull a single batch to confirm the image and (one-hot) label shapes.
img_batch, label_batch = next(img_data)
print(img_batch.shape)
print(label_batch.shape)

Found 3670 images belonging to 5 classes.
(32, 224, 224, 3)
(32, 5)


In [21]:
# Pretrained MobileNet convolutional base with ImageNet weights.
# include_top=False drops MobileNet's own classification head, so the model
# outputs a spatial feature map instead of ImageNet class scores. (MobileNet's
# `dropout` argument only affects that removed head, so it is not passed.)
headless_model = tf.keras.applications.MobileNet(
    input_shape=image_shape + (3,),
    alpha=1.0,
    depth_multiplier=1,
    include_top=False,
    weights='imagenet',
)
# Freeze the pretrained convolution weights so only the new head is trained.
headless_model.trainable = False

# Attach a new classifier head after the frozen convolutional base.
model = tf.keras.Sequential([
    headless_model,
    # The base emits a spatial feature map per image (7x7x1024 for 224x224
    # input); GlobalAveragePooling2D averages over the spatial grid to give a
    # flat 1024-vector that the Dense layers can consume.
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    # One softmax output per class found by flow_from_directory.
    tf.keras.layers.Dense(img_data.num_classes, activation='softmax'),
])

In [17]:
# Run the sample batch through the frozen MobileNet base (inference mode) to
# inspect the feature-map shape the classifier head will receive.
features_img_batch = headless_model(img_batch, training=False)
features_img_batch.shape

TensorShape([32, 7, 7, 1024])

In [20]:
# Labels from flow_from_directory are one-hot (see the (32, 5) label batch
# above) and the head ends in softmax, so categorical cross-entropy is the
# matching loss; accuracy is tracked for monitoring.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)