In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from bokeh.io import output_notebook
from image_plotter import show_image, show_images
import numpy as np

In [2]:
# Tell Bokeh to render its plots inline in the notebook's output cells.
output_notebook()

## Tensorflow datasets
This is an external library that has a bunch of usual ML datasets packaged as `Dataset` objects. It has a helpful `DatasetInfo` which has all the metadata associated with the dataset. I can specify the data splits using strings. See [doc](https://www.tensorflow.org/datasets/splits#s3_slicing_api)

Usually, the dataset returned by the `tfds` library will yield a `dict` keyed by feature name (here `image` and `label`, i.e. the `X` and `y` of the problem). Most `keras` models want datasets in tuple form, `(X, y)`. It is often useful to define a mapper that does this conversion right off the bat.

In [3]:
# Load the tf_flowers dataset (stored under data_dir) along with its
# DatasetInfo metadata. The first return value is a dict of splits.
flowers, flowers_info = tfds.load(
    name="tf_flowers",
    data_dir="/data",
    with_info=True,
)
flowers



{'train': <_OptionsDataset shapes: {image: (None, None, 3), label: ()}, types: {image: tf.uint8, label: tf.int64}>}

In [4]:
# Display the dataset metadata: features, split sizes, supervised keys, citation.
flowers_info

tfds.core.DatasetInfo(
    name='tf_flowers',
    version=1.0.0,
    description='A large set of images of flowers',
    urls=['http://download.tensorflow.org/example_images/flower_photos.tgz'],
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    total_num_examples=3670,
    splits={
        'train': 3670,
    },
    supervised_keys=('image', 'label'),
    citation="""@ONLINE {tfflowers,
    author = "The TensorFlow Team",
    title = "Flowers",
    month = "jan",
    year = "2019",
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""",
    redistribution_info=,
)

In [5]:
# Human-readable class names of the "label" feature
# (presumably indexed by the integer label id -- confirm against tfds docs).
classes = flowers_info.features["label"].names
classes

['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']

In [6]:
# Peek at the first three raw elements to see what the dataset yields:
# each one is a dict keyed by feature name.
for sample in flowers["train"].take(3):
    print(type(sample))
    print(sample.keys())

<class 'dict'>
dict_keys(['image', 'label'])
<class 'dict'>
dict_keys(['image', 'label'])
<class 'dict'>
dict_keys(['image', 'label'])


In [7]:
def to_tpl(elem):
    """Convert a feature dict into an ``(image, label)`` tuple.

    Args:
        elem: Mapping that has ``"image"`` and ``"label"`` keys.

    Returns:
        The tuple ``(elem["image"], elem["label"])``.
    """
    image = elem["image"]
    label = elem["label"]
    return image, label

# Gather the first three (image, label) pairs produced by the tuple mapper
# and plot them with their class names.
images, labels = [], []

for pair in flowers["train"].map(to_tpl).take(3):
    print(type(pair))
    image, label = pair
    images.append(image)
    labels.append(label)

show_images(images, labels, classes)

<class 'tuple'>
<class 'tuple'>
<class 'tuple'>


## Image Pipeline
Each image in the `flowers` dataset is sized differently. Let's create a pipeline that resizes all the images to 192x192, normalizes them to the [0, 1] range, and then rescales them to the [-1, 1] range, because one of the layers we will use in this example expects pixel values in that range.

In [8]:
def resize(image, label, size=(192, 192)):
    """Resize an image to a fixed spatial size and pass the label through.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        label: Label that belongs to the image; returned unchanged.
        size: Target ``(height, width)``. Defaults to ``(192, 192)``, so
            existing two-argument uses such as ``ds.map(resize)`` behave
            exactly as before.

    Returns:
        Tuple ``(resized_image, label)``.
    """
    return tf.image.resize(image, size), label


# Preview the first three images after converting to tuples and resizing.
ds = flowers["train"].map(to_tpl).map(resize)

images = []
labels = []
for image, label in ds.take(3):
    images.append(image)
    labels.append(label)

show_images(images, labels, classes)

In [9]:
def normalize(image, label):
    """Divide pixel values by 255 and pass the label through.

    Args:
        image: Image values to scale.
        label: Label that belongs to the image; returned unchanged.

    Returns:
        Tuple ``(image / 255., label)``.
    """
    scaled = image / 255.
    return scaled, label


def rescale(image, label):
    """Apply the affine map ``2 * image - 1`` and pass the label through.

    For inputs in [0, 1] this yields values in [-1, 1].

    Args:
        image: Image values to rescale.
        label: Label that belongs to the image; returned unchanged.

    Returns:
        Tuple ``(2 * image - 1, label)``.
    """
    doubled = 2 * image
    return doubled - 1, label


# Full preprocessing chain: dict -> tuple, resize, scale to [0, 1], shift to [-1, 1].
ds = (
    flowers["train"]
    .map(to_tpl)
    .map(resize)
    .map(normalize)
    .map(rescale)
)

# Check the shape and value range of the first three preprocessed images.
for image, label in ds.take(3):
    pixels = image.numpy()
    print(image.shape)
    print(np.min(pixels), np.max(pixels))


(192, 192, 3)
-1.0 1.0
(192, 192, 3)
-1.0 1.0
(192, 192, 3)
-1.0 0.99025965


## Shuffling, Repeating, and Batching
Let's shuffle with a shuffle buffer of 512 images and use a batch size of 32. And because `keras` likes datasets that loop endlessly, let's set it up that way.

This time, let's parallelize the mappers.

In [10]:
# Sentinel that lets tf.data pick the parallelism / buffer size at runtime;
# passed to map() and prefetch() in the pipelines that follow.
auto = tf.data.experimental.AUTOTUNE

In [11]:
# Number of elements held in the shuffle buffer.
SHUFFLE_BUFFER = 512
# Number of examples per batch.
BATCH_SIZE = 32

In [12]:
# Training input pipeline: parallel preprocessing, then shuffle, repeat
# endlessly, emit full batches only, and prefetch ahead of the consumer.
ds = (
    flowers["train"]
    .map(to_tpl, num_parallel_calls=auto)
    .map(resize, num_parallel_calls=auto)
    .map(normalize, num_parallel_calls=auto)
    .map(rescale, num_parallel_calls=auto)
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=auto)
)

## Creating the model
Our goal is to create this model:
MobileNet --> Average Pool --> Softmax

We'll use MobileNet as an embedding layer. Let's create that first and check that it works.

In [13]:
# MobileNetV2 backbone without its top (classification) layers, taking
# 192x192 three-channel inputs to match the resize step of the pipeline.
mobile_net = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=(192, 192, 3),
)
# Freeze the backbone so that only layers stacked on top of it get trained.
mobile_net.trainable = False

In [14]:
# Pull a single batch out of the pipeline to sanity-check the layers with.
# Unpacking into dedicated names avoids the None-sentinel loop and keeps the
# `images` / `labels` lists built earlier from being rebound to tensors.
image_batch, label_batch = next(iter(ds.take(1)))
image_batch.shape

TensorShape([32, 192, 192, 3])

In [15]:
# Sanity check: run the frozen MobileNet on one batch and inspect the shape
# of the feature map it produces.
tp = mobile_net(image_batch)
tp.shape

TensorShape([32, 6, 6, 1280])

MobileNet seems to be working fine. Let's build the full model.

In [16]:
# Assemble the classifier one layer at a time:
# frozen MobileNet features -> global average pool -> softmax over the classes.
model = tf.keras.Sequential()
model.add(mobile_net)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(len(classes), activation="softmax"))
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenetv2_1.00_192 (Model) (None, 6, 6, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________


Let's do a single iteration and see if it works.

In [17]:
# Smoke test: one forward pass of the full model on a single batch.
# NOTE: the final Dense layer applies softmax, so despite the variable name
# these values are per-class probabilities rather than raw logits.
logits_batch = model(image_batch)
print(type(logits_batch), logits_batch.shape)

<class 'tensorflow.python.framework.ops.EagerTensor'> (32, 5)


Now set the optimizer and the loss function.

In [18]:
# Training configuration: Adam with its default settings, cross-entropy on
# integer class labels, and accuracy reported as a metric.
model.compile(
    metrics=["accuracy"],
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
)

Before we can use this model, we need to know the number of batches per epoch.

In [19]:
# Total number of examples in the "train" split, read from the dataset metadata.
ds_size = flowers_info.splits["train"].num_examples
ds_size

3670

In [20]:
# Number of whole batches in one pass over the data. Because the dataset
# repeats endlessly, this is what defines the length of an epoch for fit().
num_batches = ds_size // BATCH_SIZE
num_batches

114

In [21]:
# Train for 3 epochs of num_batches steps each; steps_per_epoch is given
# explicitly because the dataset repeats without end.
model.fit(ds, epochs=3, steps_per_epoch=num_batches)

Epoch 1/3


W0815 18:56:19.391416 4386481600 deprecation.py:323] From /Users/avilay/venvs/ai/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x13ec15ba8>

## Caching
One way to speed up the training is to cache the dataset. Calling `cache()` as-is will cache the entire dataset in memory. If that is not feasible, we can also use a file cache.

In [22]:
# Same preprocessing as the uncached pipeline, but the preprocessed elements
# are cached in memory. cache() sits before shuffle()/repeat(), so only the
# map results are stored and shuffling still happens after the cache.
ds = (
    flowers["train"]
    .map(to_tpl, num_parallel_calls=auto)
    .map(resize, num_parallel_calls=auto)
    .map(normalize, num_parallel_calls=auto)
    .map(rescale, num_parallel_calls=auto)
    .cache()
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=auto)
)

# Fresh pooling + softmax head on top of the shared, frozen MobileNet.
model = tf.keras.Sequential()
model.add(mobile_net)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(len(classes), activation="softmax"))

model.compile(
    metrics=["accuracy"],
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
)

In [23]:
# Train the freshly built model on the pipeline that caches in memory.
model.fit(ds, epochs=3, steps_per_epoch=num_batches)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x166385898>

In [24]:
# Same pipeline again, this time caching the preprocessed elements to files
# at ./cache.tf-data instead of in memory.
# NOTE(review): cache files left over from an earlier run may be picked up
# again -- confirm, and remove them if the preprocessing steps change.
ds = (
    flowers["train"]
    .map(to_tpl, num_parallel_calls=auto)
    .map(resize, num_parallel_calls=auto)
    .map(normalize, num_parallel_calls=auto)
    .map(rescale, num_parallel_calls=auto)
    .cache("./cache.tf-data")
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=auto)
)

# Fresh pooling + softmax head on top of the shared, frozen MobileNet.
model = tf.keras.Sequential()
model.add(mobile_net)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(len(classes), activation="softmax"))

model.compile(
    metrics=["accuracy"],
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
)

In [25]:
# Train the freshly built model on the pipeline backed by the file cache.
model.fit(ds, epochs=3, steps_per_epoch=num_batches)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x16a2ea588>