## Imports

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

import math
import numpy as np
import matplotlib.pyplot as plt

## Download dataset with tfds

In [2]:
# Slice the 'train' split of tf_flowers 80/20 into training and validation subsets,
# yielding (image, label) tuples and the DatasetInfo metadata.
(training_set, validation_set), dataset_info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True,
    download=True,
)

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



Downloading and preparing dataset tf_flowers/3.0.0 (download: 218.21 MiB, generated: Unknown size, total: 218.21 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.0...


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=5.0, style=ProgressStyle(descriptio…



Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.0. Subsequent calls will reuse this data.


In [3]:
# These counts come from the dataset metadata, so they describe the *whole* 'train'
# split -- i.e. before the 80/20 training/validation slicing done in the load cell.
print(f"Examples in 'train' split (before 80/20 split) : {dataset_info.splits['train'].num_examples}")
print(f"Number of classes : {dataset_info.features['label'].num_classes}")

Number of training examples : 3670
Number of classes : 5


In [4]:
# Number of label classes (sizes the Dense output layer).
num_classes = dataset_info.features['label'].num_classes
# Size of the full 'train' split (NOT the 80% training slice); only used to size the shuffle buffer.
num_examples = dataset_info.splits['train'].num_examples

## Reformat images and create batches

In [5]:
# Square input resolution fed to the MobileNet V2 hub module below (its URL ends in _224).
IMG_SIZE = 224
# Images per training/validation batch.
BATCH_SIZE = 32

# Fix the global TF seed so dataset shuffling and Dense-layer initialisation are
# repeatable across Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

In [6]:
def format_images(image, label):
  """Resize an image to IMG_SIZE x IMG_SIZE and divide its pixel values by 255.

  Args:
    image: image tensor from the dataset (presumably uint8 in [0, 255] -- confirm).
    label: class label; passed through unchanged.

  Returns:
    Tuple of (resized image scaled by 1/255, label).
  """
  resized = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return resized / 255.0, label

In [7]:
# Cache the raw decoded images, then shuffle *before* resizing so the shuffle buffer
# holds the raw images rather than the larger float32 IMG_SIZE x IMG_SIZE x 3 tensors
# produced by format_images. The resize runs in parallel (order is preserved), and
# prefetch overlaps input preparation with training.
train_batch = (
    training_set
    .cache()
    .shuffle(num_examples // 4)
    .map(format_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [8]:
# Validation pipeline: same preprocessing as training but no shuffling, so evaluation
# order is stable. The resize runs in parallel (order preserved) and batches are prefetched.
validation_batch = (
    validation_set
    .cache()
    .map(format_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Download a pre-trained feature extractor from TensorFlow Hub

In [11]:
# Pre-trained MobileNet V2 (ImageNet) *feature-vector* module from TF Hub.
URL = 'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4'

# Frozen feature extractor: its weights are not updated, only the Dense head added later is trained.
classifier = hub.KerasLayer(
    URL,
    trainable=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

## Define a model

In [13]:
# Transfer-learning model: frozen hub feature extractor followed by a trainable linear head.
model = tf.keras.Sequential([
    classifier,
    tf.keras.layers.Dense(num_classes),  # no activation: outputs raw logits
])

In [14]:
# The Dense head has no softmax, so the loss must be told it receives logits.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [15]:
EPOCHS = 10  # full passes over train_batch in model.fit

In [16]:
# Train the Dense head (the hub layer is frozen); the returned History object records per-epoch metrics.
history = model.fit(train_batch, epochs=EPOCHS, validation_data=validation_batch)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
