In [1]:
import tensorflow as tf
from tensorflow import keras
import pathlib
import random
import time
AUTOTUNE = tf.data.experimental.AUTOTUNE
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

data_root = keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                 fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [2]:
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
# print(all_image_paths)
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
print(image_count)

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print(label_names)
label_to_index = dict((name, index) for index, name in enumerate(label_names))
print(label_to_index)

all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
print(len(all_image_labels))

BATCH_SIZE = 32
STEP_PER_EPOCH = tf.math.ceil(len(all_image_paths) / BATCH_SIZE).numpy()

3670
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
3670


In [0]:
def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    image = keras.applications.xception.preprocess_input(image)

    return image

def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)

In [4]:
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

for i in path_ds.take(3):
    print(i)

tf.Tensor(b'/root/.keras/datasets/flower_photos/dandelion/4632757134_40156d7d5b.jpg', shape=(), dtype=string)
tf.Tensor(b'/root/.keras/datasets/flower_photos/daisy/19280272025_57de24e940_m.jpg', shape=(), dtype=string)
tf.Tensor(b'/root/.keras/datasets/flower_photos/sunflowers/164668737_aeab0cb55e_n.jpg', shape=(), dtype=string)


In [5]:
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

for i in image_ds.take(3):
    print(i)

tf.Tensor(
[[[-0.33032215 -0.25973392 -0.97345936]
  [-0.3623249  -0.29838932 -0.9764706 ]
  [-0.40327382 -0.3326856  -0.9824055 ]
  ...
  [-0.5303088  -0.49893624 -0.9616813 ]
  [-0.55232996 -0.50527114 -0.97585934]
  [-0.5370448  -0.489986   -0.9605742 ]]

 [[-0.33604693 -0.2654587  -0.9721304 ]
  [-0.36319387 -0.2992583  -0.9773395 ]
  [-0.39733893 -0.3267507  -0.9764706 ]
  ...
  [-0.5531114  -0.5102893  -0.97875917]
  [-0.5263767  -0.48504275 -0.9441814 ]
  [-0.5537307  -0.5123967  -0.9715353 ]]

 [[-0.33816528 -0.26757705 -0.9703347 ]
  [-0.37016803 -0.30623245 -0.9843137 ]
  [-0.40504563 -0.3344574  -0.9841773 ]
  ...
  [-0.5551331  -0.5097725  -0.97696435]
  [-0.5521406  -0.5078304  -0.9661287 ]
  [-0.54169744 -0.49738723 -0.95568556]]

 ...

 [[-0.3333333  -0.38039213 -0.8666667 ]
  [-0.3257479  -0.37280673 -0.85908127]
  [-0.3271882  -0.36640388 -0.86863756]
  ...
  [-0.41744435 -0.37822866 -0.8174443 ]
  [-0.39058    -0.3513643  -0.8062663 ]
  [-0.39672518 -0.3575095  -0.812

In [6]:
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))

for i in label_ds.take(3):
    print(i)

tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)


In [8]:
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

for i in image_label_ds.take(3):
    print(i)

(<tf.Tensor: shape=(224, 224, 3), dtype=float32, numpy=
array([[[-0.33032215, -0.25973392, -0.97345936],
        [-0.3623249 , -0.29838932, -0.9764706 ],
        [-0.40327382, -0.3326856 , -0.9824055 ],
        ...,
        [-0.5303088 , -0.49893624, -0.9616813 ],
        [-0.55232996, -0.50527114, -0.97585934],
        [-0.5370448 , -0.489986  , -0.9605742 ]],

       [[-0.33604693, -0.2654587 , -0.9721304 ],
        [-0.36319387, -0.2992583 , -0.9773395 ],
        [-0.39733893, -0.3267507 , -0.9764706 ],
        ...,
        [-0.5531114 , -0.5102893 , -0.97875917],
        [-0.5263767 , -0.48504275, -0.9441814 ],
        [-0.5537307 , -0.5123967 , -0.9715353 ]],

       [[-0.33816528, -0.26757705, -0.9703347 ],
        [-0.37016803, -0.30623245, -0.9843137 ],
        [-0.40504563, -0.3344574 , -0.9841773 ],
        ...,
        [-0.5551331 , -0.5097725 , -0.97696435],
        [-0.5521406 , -0.5078304 , -0.9661287 ],
        [-0.54169744, -0.49738723, -0.95568556]],

       ...,

    