In [None]:
# How to load an image dataset using tf.data

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

In [3]:
# Eager execution is required by the later cells, which call `.numpy()` on
# tensors and iterate datasets directly. The top-level switch only exists in
# TF 1.x (TF 2.x runs eagerly by default), so skip the call when it is absent.
if hasattr(tf, 'enable_eager_execution'):
    tf.enable_eager_execution()

In [4]:
# Shorthand for tf.data's autotune constant; handed to
# `map(num_parallel_calls=...)` and `prefetch(buffer_size=...)` further down.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [5]:
import IPython.display as display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Retrieve the images: fetch the flower photos archive, extract it
# (untar=True) and keep the resulting location as a pathlib.Path.
import pathlib
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        fname='flower_photos',
        origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True,
    )
)

In [None]:
# Display the path the dataset was extracted to.
data_dir

In [None]:
# Count every .jpg file that sits exactly one directory level below data_dir.
image_count = len(list(data_dir.glob('*/*.jpg')))
image_count

In [None]:
# One class per sub-directory of data_dir. Filtering on `is_dir()` (instead of
# excluding LICENSE.txt by name) keeps any other stray file from being treated
# as a class, and sorting makes the class order independent of the order in
# which the filesystem happens to list the directories.
CLASS_NAMES = np.array(sorted(item.name for item in data_dir.glob('*') if item.is_dir()))
CLASS_NAMES

In [None]:
# Preview the first three entries of the `roses` folder inline.
roses = list(data_dir.glob('roses/*'))
for image_path in roses[:3]:
    display.display(Image.open(str(image_path)))

In [None]:
# Load using keras.preprocessing
# (tf.keras.preprocessing)

# Rescaling by 1/255 maps uint8 pixel values to float32 in the [0,1] range.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# Define some parameters for the loader.
BATCH_SIZE = 32
# Both dimensions must be plain ints: they are combined into
# `(IMG_HEIGHT, IMG_WIDTH)` for the Keras loader and into a size list for
# `tf.image.resize`.
IMG_HEIGHT = 224
IMG_WIDTH = 224
# Number of batches needed to cover every image once (rounded up).
STEPS_PER_EPOCH = np.ceil(image_count/BATCH_SIZE)

In [None]:
# Build the Keras directory iterator: shuffled batches of BATCH_SIZE images
# resized to (IMG_HEIGHT, IMG_WIDTH). The class list is passed explicitly so
# labels are intended to line up with CLASS_NAMES (show_batch indexes it).
train_data_gen = image_generator.flow_from_directory(
    directory=str(data_dir),
    classes=list(CLASS_NAMES),
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [7]:
# inspect batch

def show_batch(image_batch, label_batch):
    """Plot up to 25 images of a batch on a 5x5 grid, titled with their class.

    Args:
        image_batch: indexable batch of images that `plt.imshow` can draw.
        label_batch: per-image label rows aligned with CLASS_NAMES; the
            position whose value equals 1 selects the class name.
    """
    plt.figure(figsize=(10,10))
    # A batch may hold fewer than 25 images (e.g. a trailing partial batch),
    # so never index past its end.
    for n in range(min(25, len(image_batch))):
        plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        # Boolean-mask CLASS_NAMES with the label row, then title-case the name.
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')

In [None]:
# Pull one batch from the Keras generator and plot it.
image_batch, label_batch = next(train_data_gen)
show_batch(image_batch, label_batch)

In [None]:
# Load using tf.data
# First create a dataset of the file paths matching <data_dir>/*/*.
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

In [None]:
# Print five of the file paths as a quick sanity check.
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
# Pure TensorFlow helper functions that together convert a file path into
# an (image_data, label) pair.
def get_label(file_path):
    """Compare the file's class directory name against CLASS_NAMES.

    The class directory is the path component directly above the file name.
    The result of the `==` comparison is returned and used as the label.
    """
    # Break the path into its components using the platform separator.
    path_components = tf.strings.split(file_path, os.path.sep)
    # The second-to-last component is the directory holding the file.
    class_dir = path_components[-2]
    return class_dir == CLASS_NAMES

In [None]:
def decode_img(img):
    """Decode JPEG bytes into a float32 image resized to IMG_HEIGHT x IMG_WIDTH.

    Args:
        img: the raw (compressed) contents of a JPEG file.

    Returns:
        The decoded 3-channel image, converted to float32 and resized.
    """
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the
    # [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # resize the image to the desired size; `tf.image.resize` expects the
    # size as [new_height, new_width]
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])

In [None]:
def process_path(file_path):
    """Turn one image file path into an `(image, label)` pair."""
    # The label comes from the path itself.
    target = get_label(file_path)
    # Read the file's raw contents and decode them into an image.
    pixels = decode_img(tf.io.read_file(file_path))
    return pixels, target

In [None]:
# use Dataset.map to create a dataset of image, label pairs

In [None]:
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
# Peek at a single element to check the image shape and the label.
for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())


In [None]:
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  """Cache, shuffle, repeat, batch and prefetch a dataset for training.

  Args:
    ds: the dataset to prepare.
    cache: falsy to skip caching; a string to cache to that filename (for
      preprocessing work on datasets that don't fit in memory); any other
      truthy value to call `cache()` without a filename.
    shuffle_buffer_size: buffer size passed to `shuffle`.

  Returns:
    The dataset after shuffling, repeating forever, batching into groups of
    BATCH_SIZE and prefetching.
  """
  if cache:
    # Only a string is forwarded as the cache filename.
    cache_args = (cache,) if isinstance(cache, str) else ()
    ds = ds.cache(*cache_args)

  # `prefetch` lets the dataset fetch batches in the background while the
  # model is training.
  return (
      ds.shuffle(buffer_size=shuffle_buffer_size)
      .repeat()
      .batch(BATCH_SIZE)
      .prefetch(buffer_size=AUTOTUNE)
  )

In [None]:
# Build the training pipeline with the defaults (cache=True).
train_ds = prepare_for_training(labeled_ds)

# Take one batch and plot it; show_batch is handed NumPy arrays.
image_batch, label_batch = next(iter(train_ds))
show_batch(image_batch.numpy(), label_batch.numpy())


In [None]:
# PERFORMANCE
import time
default_timeit_steps = 1000

def timeit(ds, steps=default_timeit_steps):
  """Pull `steps` batches from `ds` and print the elapsed time and throughput.

  A '.' is printed on every tenth batch as a progress marker. Throughput is
  computed as BATCH_SIZE * steps / elapsed seconds, i.e. every batch is
  counted as holding BATCH_SIZE images.
  """
  t0 = time.time()
  batches = iter(ds)
  for step in range(steps):
    next(batches)
    if step % 10 == 0:
      print('.',end='')
  print()
  elapsed = time.time() - t0

  print("{} batches: {} s".format(steps, elapsed))
  print("{:0.5f} Images/s".format(BATCH_SIZE*steps/elapsed))


In [None]:
# Let's compare the speed of the two data generators.
# First the `keras.preprocessing` generator:
timeit(train_data_gen)


In [None]:
# Then the `tf.data` pipeline (built with cache=True):
timeit(train_ds)


In [None]:
# A large part of the performance gain comes from the use of .cache.

In [None]:
# Same pipeline with the cache step skipped, to see what caching contributes.
uncached_ds = prepare_for_training(labeled_ds, cache=False)
timeit(uncached_ds)


In [None]:
# Passing a string makes prepare_for_training call `ds.cache(<that filename>)`.
filecache_ds = prepare_for_training(labeled_ds, cache="./flowers.tfcache")
timeit(filecache_ds)
