This notebook shows how to build TFRecords from a custom image classification dataset and how to use the TFRecords to train a deep learning model in `tf.keras`. This notebook has the following major sections:
- Writing TFRecords
- Loading in the TFRecords
- Model building
- Training models with TFRecords

Acknowledgements: [Martin Görner](https://twitter.com/martin_gorner) & his amazing [tutorial notebook](https://nbviewer.jupyter.org/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/keras_flowers_gputputpupod_tf2.1.ipynb). 

## Initial setup and imports

In [1]:
# Select the TensorFlow 2.x runtime.
# `%tensorflow_version` is a Colab line magic, so this cell will only work in Colab.
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
# Imports
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from imutils import paths
import tensorflow as tf
import numpy as np
import pathlib
import re

In [4]:
# Confirm which TensorFlow version was actually loaded
print(tf.__version__)

2.0.0


## Data gathering and inspection

In [5]:
# Fetch the flowers archive and untar it; `flowers` is the path that
# `get_file` returns for the extracted dataset
flowers = tf.keras.utils.get_file(
    fname='flower_photos',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True,
)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [6]:
# List the dataset folder: one sub-directory per flower class (five classes),
# plus a LICENSE.txt file that is not a class
!ls {flowers}

daisy  dandelion  LICENSE.txt  roses  sunflowers  tulips


In [0]:
# Class names, matching the dataset's sub-directory names. The position of a
# name in this list is the integer label written into the TFRecords.
CLASSES = [
    'daisy',
    'dandelion',
    'roses',
    'sunflowers',
    'tulips',
]

In [8]:
# Count every image file found under the dataset directory
total_data = sum(1 for _ in paths.list_images(flowers))
total_data

3670

In [9]:
# Gather all the image paths: <dataset dir>/<class dir>/<file>.
# A fixed seed is passed so the file order is repeatable between runs.
dataset = tf.data.Dataset.list_files(
    str(pathlib.Path(flowers, '*', '*')),
    seed=666,
)

# Peek at a few entries
for filename in dataset.take(5):
    print(filename.numpy())

b'/root/.keras/datasets/flower_photos/tulips/14087425312_2b5846b570_n.jpg'
b'/root/.keras/datasets/flower_photos/sunflowers/3846717708_ea11383ed8.jpg'
b'/root/.keras/datasets/flower_photos/sunflowers/244074259_47ce6d3ef9.jpg'
b'/root/.keras/datasets/flower_photos/dandelion/19812060274_c432f603db.jpg'
b'/root/.keras/datasets/flower_photos/sunflowers/3062794421_295f8c2c4e.jpg'


Note that the above paths are byte-strings, not text strings.

## Prepare helper functions for writing TFRecords

In [0]:
# Function to read the image from the path,
# parse its label, convert the pixel values to float,
# and resize the image
def parse_image(filename):
    """Load one image file and derive its label from the parent directory.

    Args:
        filename: scalar string tensor holding a '/'-separated path whose
            second-to-last component is the class name.

    Returns:
        (image, label): `image` is a float32 tensor resized to 128x128 with
        3 channels; `label` is the class-directory name as a string tensor.
    """
    # .../<class name>/<file name> -> the class name is the parent directory
    parts = tf.strings.split(filename, '/')
    label = parts[-2]

    image = tf.io.read_file(filename)
    # Force 3 channels so every decoded image has the same shape; otherwise a
    # grayscale JPEG would break batching further down the pipeline
    image = tf.image.decode_jpeg(image, channels=3)
    # Integer pixel values are rescaled to floats in [0, 1] by this conversion
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [128, 128])
    return (image, label)

In [0]:
# Decode and resize every listed file with `parse_image`.
# AUTOTUNE makes it easier to make the parallelization dynamic.
# NOTE: this rebinds `dataset` (paths -> (image, label) pairs), so the cell
# must run exactly once after the cell that lists the files.
dataset = dataset.map(parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [0]:
# Images are stored as byte-strings
def _bytestring_feature(list_of_bytestrings):
    """Wrap a list of byte-strings in a `tf.train.Feature` (bytes_list)."""
    bytes_list = tf.train.BytesList(value=list_of_bytestrings)
    return tf.train.Feature(bytes_list=bytes_list)

# Classes are stored as integers
def _int_feature(list_of_ints):
    """Wrap a list of integers in a `tf.train.Feature` (int64_list)."""
    int64_list = tf.train.Int64List(value=list_of_ints)
    return tf.train.Feature(int64_list=int64_list)

In [0]:
# Function that prepares a record for the tfrecord file
# A record contains the image and its label
def to_tfrecord(img_bytes, label):
    """Build a `tf.train.Example` holding one encoded image and its class index.

    Args:
        img_bytes: the encoded image as a byte-string.
        label: class name, either as `bytes` (what `.numpy()` yields for
            string tensors) or as `str`.

    Returns:
        A `tf.train.Example` with the features "image" (bytes) and
        "class" (int64, the position of the label in `CLASSES`).

    Raises:
        ValueError: if `label` is not one of the names in `CLASSES`.
    """
    def _as_text(value):
        # Bring byte-strings and text strings to a common representation
        return value.decode("utf-8") if isinstance(value, bytes) else str(value)

    # Compare text with text: a bytes label never equals the str entries of
    # CLASSES, which previously made every record silently end up as class 0.
    class_names = [_as_text(name) for name in CLASSES]
    label_text = _as_text(label)
    if label_text not in class_names:
        raise ValueError(
            "Unknown label {!r}; expected one of {}".format(label, class_names))
    class_num = class_names.index(label_text)

    feature = {
      "image": _bytestring_feature([img_bytes]),
      "class": _int_feature([class_num]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

In [0]:
# We need to convert the image to byte strings
def recompress_image(image, label):
    """Re-encode an image tensor as JPEG bytes, passing the label through.

    Args:
        image: image tensor. Float images are expected to hold values in
            [0, 1] (as produced by `parse_image`); uint8 images are used as-is.
        label: returned unchanged.

    Returns:
        (jpeg_bytes, label): the JPEG-encoded image as a string tensor and
        the untouched label.
    """
    # A plain tf.cast would truncate [0, 1] floats to 0 or 1 and store almost
    # black images. convert_image_dtype rescales to the full 0-255 range
    # instead; `saturate` clips values that fall slightly outside [0, 1].
    image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    image = tf.image.encode_jpeg(image, optimize_size=True, chroma_downsampling=False)
    return (image, label)

In [0]:
# Encode every image as JPEG in parallel, then group the records in batches
# of 32; each batch is later written out as one TFRecord shard
dataset = (
    dataset
    .map(recompress_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(32)
)

In [16]:
# Prepare tfrecords: every batch of the dataset is written as one shard file
for shard, (image, label) in enumerate(dataset):
    # Convert the batch to NumPy once per shard instead of once per record
    image_batch = image.numpy()
    label_batch = label.numpy()
    shard_size = image_batch.shape[0]
    # The file name carries the shard index and the number of records,
    # e.g. flowers-00-32.tfrec; the record count is parsed back when loading
    filename = "flowers-" + "{:02d}-{}.tfrec".format(shard, shard_size)

    with tf.io.TFRecordWriter(filename) as out_file:
        for img_bytes, class_name in zip(image_batch, label_batch):
            example = to_tfrecord(img_bytes, class_name)
            out_file.write(example.SerializeToString())
        print("Wrote file {} containing {} records".format(filename, shard_size))

Wrote file flowers-00-32.tfrec containing 32 records
Wrote file flowers-01-32.tfrec containing 32 records
Wrote file flowers-02-32.tfrec containing 32 records
Wrote file flowers-03-32.tfrec containing 32 records
Wrote file flowers-04-32.tfrec containing 32 records
Wrote file flowers-05-32.tfrec containing 32 records
Wrote file flowers-06-32.tfrec containing 32 records
Wrote file flowers-07-32.tfrec containing 32 records
Wrote file flowers-08-32.tfrec containing 32 records
Wrote file flowers-09-32.tfrec containing 32 records
Wrote file flowers-10-32.tfrec containing 32 records
Wrote file flowers-11-32.tfrec containing 32 records
Wrote file flowers-12-32.tfrec containing 32 records
Wrote file flowers-13-32.tfrec containing 32 records
Wrote file flowers-14-32.tfrec containing 32 records
Wrote file flowers-15-32.tfrec containing 32 records
Wrote file flowers-16-32.tfrec containing 32 records
Wrote file flowers-17-32.tfrec containing 32 records
Wrote file flowers-18-32.tfrec containing 32 r

## Loading TFRecords

In [0]:
# Function to read the TFRecords, segregate the images and labels
def read_tfrecord(example):
    """Parse one serialized record into an (image, class_label) pair.

    Args:
        example: a serialized example carrying an "image" (string) and a
            "class" (int64) feature.

    Returns:
        (image, class_label): a float32 image of shape [128, 128, 3] whose
        decoded pixel values were divided by 255, and the label as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # JPEG bytes -> float pixels, then pin the static shape
    pixels = tf.image.decode_jpeg(parsed['image'], channels=3)
    pixels = tf.cast(pixels, tf.float32) / 255.0
    pixels = tf.reshape(pixels, [128, 128, 3])

    return (pixels, tf.cast(parsed['class'], tf.int32))

In [0]:
# Load the TFRecords and create tf.data.Dataset
def load_dataset(filenames):
    """Create a dataset of parsed (image, class_label) pairs from TFRecord files.

    Args:
        filenames: list of TFRecord file paths.

    Returns:
        A `tf.data.Dataset` in which every element was decoded by
        `read_tfrecord`.
    """
    # TFRecordDataset takes the file list directly (no intermediate dataset of
    # file names is needed); `num_parallel_reads=16` requests parallel reads.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=16)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return dataset

In [0]:
# We need this to derive steps
def count_data_items(filenames):
    """Sum the record counts encoded in TFRecord shard file names.

    Shard names look like `flowers-<shard>-<count>.tfrec`; the digits that sit
    between a '-' and a '.' in the file name are taken as the record count.

    Args:
        filenames: iterable of shard paths (text strings, '/'-separated).

    Returns:
        The total number of records as a Python int (0 for an empty input).

    Raises:
        ValueError: if a file name carries no `-<count>.` part.
    """
    # Compile once instead of once per file name
    count_pattern = re.compile(r"-([0-9]+)\.")

    total = 0
    for filename in filenames:
        # Look at the file name only, so digits inside directory names
        # (e.g. "run-2020.v1/") are never mistaken for the record count
        basename = filename.rsplit("/", 1)[-1]
        match = count_pattern.search(basename)
        if match is None:
            raise ValueError(
                "Cannot read the record count from file name {!r}".format(filename))
        total += int(match.group(1))
    return total

In [0]:
# Shuffle (training only), repeat, batch and prefetch the dataset
def batch_dataset(filenames, batch_size, train):
    """Build a batched input pipeline plus the number of steps per epoch.

    Args:
        filenames: list of TFRecord shard paths.
        batch_size: number of examples per batch.
        train: when truthy, the examples are shuffled before repeating.

    Returns:
        (dataset, steps): the repeated, batched and prefetched dataset, and
        the number of whole batches contained in one pass over the files.
    """
    pipeline = load_dataset(filenames)
    num_items = count_data_items(filenames)

    # Only the training stream is shuffled
    if train:
        pipeline = pipeline.shuffle(buffer_size=1000)

    # Both streams repeat, which is why the step count is returned as well
    pipeline = pipeline.repeat()
    pipeline = pipeline.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    steps = num_items // batch_size
    return (pipeline, steps)

In [22]:
# Collect every TFRecord shard written to the current working directory
tfrecord_pattern = "*.tfrec"
filenames = tf.io.gfile.glob(tfrecord_pattern)
# Show the first ten matches (the listing below is not in shard order)
filenames[:10]

['./flowers-22-32.tfrec',
 './flowers-57-32.tfrec',
 './flowers-65-32.tfrec',
 './flowers-16-32.tfrec',
 './flowers-94-32.tfrec',
 './flowers-18-32.tfrec',
 './flowers-71-32.tfrec',
 './flowers-46-32.tfrec',
 './flowers-84-32.tfrec',
 './flowers-112-32.tfrec']

## Model building

In [23]:
BATCH_SIZE = 64

# Hold out 20% of the shard files (rounded down) for validation; the split is
# made on whole files, not on individual images
split = len(filenames) - int(len(filenames) * 0.2)
train_filenames, valid_filenames = filenames[:split], filenames[split:]

# Build the two input pipelines together with their step counts
training_dataset, steps_per_epoch = batch_dataset(
    train_filenames, BATCH_SIZE, train=True)
validation_dataset, validation_steps = batch_dataset(
    valid_filenames, BATCH_SIZE, train=False)

print("TRAINING   IMAGES: ", count_data_items(train_filenames), ", STEPS PER EPOCH: ", steps_per_epoch)
print("VALIDATION IMAGES: ", count_data_items(valid_filenames), ", STEPS PER EPOCH: ", validation_steps)

TRAINING   IMAGES:  2934 , STEPS PER EPOCH:  45
VALIDATION IMAGES:  736 , STEPS PER EPOCH:  11


In [0]:
def get_training_model():
    """Build and compile a VGG16-based classifier for the five flower classes.

    The VGG16 convolutional base (ImageNet weights, no top) takes 128x128x3
    inputs and is frozen; only the new head (Flatten -> Dense(512, relu) ->
    Dropout(0.5) -> Dense(5, softmax)) is trainable.

    Returns:
        A compiled `Model` using SGD with momentum and sparse categorical
        cross-entropy, i.e. it expects integer class labels.
    """
    base_model = VGG16(weights="imagenet", include_top=False,
        input_tensor=Input(shape=(128, 128, 3)))

    # Classification head stacked on top of the convolutional features
    head = base_model.output
    head = Flatten(name="flatten")(head)
    head = Dense(512, activation="relu")(head)
    head = Dropout(0.5)(head)
    head = Dense(5, activation="softmax")(head)

    model = Model(inputs=base_model.input, outputs=head)

    # Freeze the pre-trained base so that only the head is updated
    for layer in base_model.layers:
        layer.trainable = False

    # `learning_rate` is the TF 2.x argument name; `lr` is a deprecated alias
    opt = SGD(learning_rate=1e-4, momentum=0.9)
    model.compile(loss="sparse_categorical_crossentropy", optimizer=opt,
        metrics=["accuracy"])
    return model

## Model training

In [25]:
# Build the model and train it for 5 epochs. Both datasets repeat, so the
# step counts tell Keras where an epoch (and a validation pass) ends.
model = get_training_model()
model.fit(
    training_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
)

Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
Train for 45 steps, validate for 11 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f3d20160b00>

## Explore more:

- https://www.tensorflow.org/tutorials/load_data/tfrecord
- https://codelabs.developers.google.com/codelabs/keras-flowers-data/
- https://medium.com/ymedialabs-innovation/how-to-use-tfrecord-with-datasets-and-iterators-in-tensorflow-with-code-samples-ffee57d298af