In [337]:
# https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/load_data/images.ipynb#scrollTo=3SDhbo8lOBQv

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pathlib

In [338]:
data_dir = pathlib.Path("./data/dogs-vs-cats/")

In [339]:
image_count = len(list(data_dir.glob('*/*.jpg')))

In [340]:
SEED = 0
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
STEPS_PER_EPOCH = np.ceil(image_count/BATCH_SIZE)
CLASS_NAMES = np.array(["cat", "dog"])

tf.random.set_seed(SEED)

In [341]:
def get_label(path):
    s_path = tf.strings.split(path, os.path.sep)[-1]
    label = tf.strings.split(s_path, ".")[0]
    return label == CLASS_NAMES

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])

def process_path(path):
    label = get_label(path)
    img = tf.io.read_file(path)
    img = decode_img(img)
    return img, label

In [345]:
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
labeled_ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [347]:
for image, label in labeled_ds.take(5):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())

Image shape:  (224, 224, 3)
Label:  [False  True]
Image shape:  (224, 224, 3)
Label:  [False  True]
Image shape:  (224, 224, 3)
Label:  [False  True]
Image shape:  (224, 224, 3)
Label:  [False  True]
Image shape:  (224, 224, 3)
Label:  [False  True]


In [352]:
def prepare_for_training(ds, batch_size=BATCH_SIZE, shuffle_buff_size=200):
    return ds.shuffle(shuffle_buff_size).batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [353]:
train_ds = prepare_for_training(labeled_ds)

In [354]:
next(iter(train_ds))

(<tf.Tensor: id=38000, shape=(32, 224, 224, 3), dtype=float32, numpy=
 array([[[[0.34259096, 0.33474782, 0.2759243 ],
          [0.35234594, 0.3445028 , 0.28567928],
          [0.3402311 , 0.33238795, 0.27356443],
          ...,
          [0.2041312 , 0.18515356, 0.16193925],
          [0.23707995, 0.17537439, 0.1785782 ],
          [0.24798661, 0.17498918, 0.1823282 ]],
 
         [[0.34752074, 0.3396776 , 0.28085408],
          [0.3529412 , 0.34509805, 0.28627452],
          [0.3402325 , 0.33238935, 0.27356583],
          ...,
          [0.20031373, 0.18133609, 0.15812178],
          [0.23432876, 0.18608469, 0.18608469],
          [0.27061248, 0.2109137 , 0.21331884]],
 
         [[0.35166317, 0.34382004, 0.2849965 ],
          [0.3506989 , 0.34285575, 0.28403223],
          [0.34075865, 0.3329155 , 0.274092  ],
          ...,
          [0.20042485, 0.18102238, 0.1580205 ],
          [0.22494625, 0.18317452, 0.18053097],
          [0.27540186, 0.22879831, 0.22778335]],
 
         ...