In [None]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds

from typing import (
    Callable,
    Dict,
    Optional,
    Tuple,
)

In [None]:
# Download and prepare the dataset as TFRecords on the local machine
# (tfds.load writes under ~/tensorflow_datasets by default).
DATASET_NAME = "<dataset>"  # TODO: replace with the real TFDS dataset name
# with_info=True makes tfds.load return a (dataset, DatasetInfo) pair, so unpack it
# instead of silently binding the tuple to train_data.
train_data, train_info = tfds.load(DATASET_NAME, split='train', with_info=True)

In [3]:
def get_preprocessed_dataset(
    filepaths: str,
    feature_spec: dict,
    is_training: bool,
    image_size: int,
    batch_size: int = 2,
    block_length: int = 2,
    max_seq_len: int = 12,
    subsample_size: Optional[int] = None,
    shuffle_buffer_size: int = 100,
    tpu: bool = False,
    prefetch_dataset_buffer_size: int = 8 * 1024 * 1024,
    num_classes: Optional[int] = None,
    use_bfloat16: bool = False,
    deterministic: bool = False,
    multi_gpu: bool = False,
) -> tf.data.Dataset:
    """Build a batched (image, label) tf.data pipeline from TFRecord files.

    TPU performance guidance applied here:
      1. Avoid tf.data.AUTOTUNE, which can use unrestricted host memory; a fixed
         num_parallel_calls is used instead.
      2. When reading from GCS, the TFRecord read buffer is overridden to 127MB
         regardless of prefetch_dataset_buffer_size, so limit the number of open files.
      3. Keep TFRecord shards between 100-200MB when reading from GCS.
      4. Avoid AutoShardPolicy.DATA, which creates unnecessary buffers; use DEFAULT.

    Args:
        filepaths: file path or glob pattern passed to tf.data.Dataset.list_files.
        feature_spec: feature dict for tf.io.parse_single_example; parse_single_example
            reads its "image" and "label" entries.
        is_training: when True (and not deterministic) the file list repeats
            indefinitely and records are shuffled.
        image_size: side length, in pixels, that images are resized to.
        batch_size: number of examples per emitted batch.
        block_length: consecutive records pulled from each file by interleave.
        max_seq_len: accepted for interface compatibility; unused.
        subsample_size: if positive, keep only the first N interleaved records.
        shuffle_buffer_size: buffer size for the training shuffle.
        tpu: accepted for interface compatibility; unused.
        prefetch_dataset_buffer_size: read buffer size (bytes) for each TFRecordDataset.
        num_classes: accepted for interface compatibility; unused.
        use_bfloat16: emit reduced-precision images (float16 when multi_gpu,
            otherwise bfloat16) instead of float32.
        deterministic: keep file and element order fixed; also disables the
            training repeat and shuffle.
        multi_gpu: apply AutoShardPolicy.DEFAULT and pick float16 for reduced precision.

    Returns:
        A tf.data.Dataset yielding (image, label) batches, prefetched one batch ahead.
    """
    # Shuffle the file order unless deterministic output is requested.
    files = tf.data.Dataset.list_files(filepaths, shuffle=not deterministic)

    if is_training and not deterministic:
        files = files.repeat()

    def prefetch_dataset(filename):
        # One buffered TFRecord reader per file (see guidance 2 about GCS).
        return tf.data.TFRecordDataset(
            filename, buffer_size=prefetch_dataset_buffer_size
        )

    ds = files.interleave(
        prefetch_dataset,
        block_length=block_length,
        num_parallel_calls=2,  # fixed value instead of AUTOTUNE (guidance 1)
        deterministic=bool(deterministic),
    )
    if subsample_size and subsample_size > 0:
        ds = ds.take(subsample_size)
    if is_training and not deterministic:
        ds = ds.shuffle(shuffle_buffer_size)

    if multi_gpu or not is_training:
        # DEFAULT rather than DATA sharding (guidance 4).
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DEFAULT
        )
        ds = ds.with_options(options)

    if use_bfloat16:
        img_dtype = tf.float16 if multi_gpu else tf.bfloat16
    else:
        img_dtype = tf.float32

    ds = ds.map(
        lambda x: parse_single_example(x, feature_spec, image_size, img_dtype),
        num_parallel_calls=2,  # fixed value instead of AUTOTUNE (guidance 1)
        # None defers to the dataset's tf.data.Options determinism setting.
        deterministic=True if deterministic else None,
    )

    # Batch first, then prefetch, so whole batches are prepared while the model
    # consumes the previous one. One prefetched batch buffers about batch_size
    # elements, the same footprint as the old prefetch-before-batch; a fixed
    # value is used instead of AUTOTUNE (guidance 1).
    return ds.batch(batch_size).prefetch(1)

In [4]:
def parse_single_example(
    example_proto: tf.Tensor,
    feature_spec: dict,
    image_size,
    image_dtype,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Decode one serialized example into a normalized (image, label) pair.

    Args:
        example_proto: serialized tf.train.Example (string tensor).
        feature_spec: feature dict for tf.io.parse_single_example; must define
            "image" (JPEG-encoded bytes) and "label".
        image_size: target height and width in pixels.
        image_dtype: dtype of the returned image (e.g. tf.float32, tf.bfloat16).

    Returns:
        (image, label): image has shape [image_size, image_size, 3], values in
        [0, 1] and dtype image_dtype; label is parsed_features["label"] as parsed.
    """
    parsed_features = tf.io.parse_single_example(
        serialized=example_proto,
        features=feature_spec,
    )
    image = tf.io.decode_jpeg(parsed_features["image"], channels=3)
    # tf.image.resize accepts a single 3-D image directly and returns float32.
    image = tf.image.resize(
        image, [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC
    )
    # Bicubic interpolation can overshoot the source range; clamp back to [0, 255]
    # so the normalized image stays within [0, 1].
    image = tf.clip_by_value(image, 0.0, 255.0)
    # Pin the static shape for downstream layers.
    image = tf.reshape(image, [image_size, image_size, 3])
    # Normalize in float32 first, then down-cast, so bfloat16/float16 do not lose
    # precision in the division.
    image = tf.cast(image / 255.0, image_dtype)

    return image, parsed_features["label"]

In [5]:
# Glob matching every training shard that tfds.load wrote locally (plain string: no interpolation needed).
train_fps = "/home/jupyter/tensorflow_datasets/<datapath>/train.tfrecord*"

In [6]:
# Parsing schema for each serialized example: encoded image bytes under "image"
# and an integer label under "label"; missing features fall back to "" and -1.
FEATURE_SPEC = dict(
    image=tf.io.FixedLenFeature(shape=(), dtype=tf.string, default_value=""),
    label=tf.io.FixedLenFeature(shape=(), dtype=tf.int64, default_value=-1),
)

In [7]:
img_size: int = 300  # side length (pixels) every image is resized to

In [8]:
# TPU performance guidance: when each TFRecord holds many records, a smaller
# batch size keeps buffers small. batch_size was lowered from 4 to 1 because a
# very large number of files is open at once for prefetching.
# NOTE(review): prefetch_dataset_buffer_size is forwarded as TFRecordDataset's
# buffer_size (a byte count), so 4 means a 4-byte read buffer — confirm intended.
training_dataset = get_preprocessed_dataset(
    filepaths=train_fps,
    feature_spec=FEATURE_SPEC,
    image_size=img_size,
    is_training=True,
    deterministic=False,
    multi_gpu=True,
    use_bfloat16=False,
    tpu=False,
    batch_size=1,
    block_length=2,
    shuffle_buffer_size=4,
    prefetch_dataset_buffer_size=4,
    subsample_size=None,
    max_seq_len=0,
    num_classes=2,
)

In [9]:
# For simplicity the data was not split properly (the focus is the data pipeline):
# validation reuses the first training shard, so it overlaps the training data.
valid_fps = "/home/jupyter/tensorflow_datasets/<datapath>/train.tfrecord-00000-of-00016"

In [10]:
# TPU performance guidance: when each TFRecord holds many records, a smaller
# batch size keeps buffers small. batch_size was lowered from 4 to 1 because a
# very large number of files is open at once for prefetching.
# Validation reads valid_fps (the shard selected for validation); the previous
# version mistakenly read train_fps. With is_training=False the pipeline neither
# repeats nor shuffles, so shuffle_buffer_size has no effect here.
validation_dataset = get_preprocessed_dataset(
        filepaths=valid_fps,
        feature_spec=FEATURE_SPEC,
        is_training=False,
        batch_size=1,
        block_length=2,
        prefetch_dataset_buffer_size=4,
        max_seq_len=0,
        subsample_size=None,
        shuffle_buffer_size=4,
        tpu=False,
        image_size=img_size,
        num_classes=2,
        use_bfloat16=False,
        deterministic=False,
        multi_gpu=True,
    )

In [11]:
def create_model(image_size: int = 300):
    """Build a small binary-classification CNN.

    Args:
        image_size: height/width of the square RGB input. It must match the
            image_size used by the input pipeline; the default 300 is the
            previously hard-coded value.

    Returns:
        An uncompiled tf.keras Sequential model: five Conv2D(3x3, relu) +
        MaxPooling2D(2) stages, a 512-unit dense layer and one sigmoid output.
    """
    layers = [tf.keras.layers.Input(shape=(image_size, image_size, 3))]
    # Same conv/pool stack as before, in the same order.
    for filters in (16, 32, 64, 64, 64):
        layers.append(tf.keras.layers.Conv2D(filters, (3, 3), activation='relu'))
        layers.append(tf.keras.layers.MaxPooling2D(2, 2))
    layers += [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ]
    return tf.keras.models.Sequential(layers)

In [None]:
# Model creation and compilation happen inside the MirroredStrategy scope so the
# model's variables are created under that distribution strategy.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = create_model()
    model.compile(
        optimizer='Adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

In [13]:
# Repeat both pipelines so steps_per_epoch / validation_steps bound each epoch,
# then wrap each one for the strategy (training first, then validation).
distributed_training_dataset, distributed_validation_dataset = (
    strategy.experimental_distribute_dataset(pipeline.repeat())
    for pipeline in (training_dataset, validation_dataset)
)

In [14]:
# Must equal the batch_size given to get_preprocessed_dataset: the tf.data
# pipeline does the batching, so this value only sets how many steps make up one
# pass over the training examples. The old value 4 disagreed with the pipeline's
# batch of 1, so each "epoch" covered only a quarter of the data.
batch_size = 1
num_train_examples = 25000  # presumably the training-split size — TODO confirm
steps_per_epoch = num_train_examples // batch_size
validation_steps = 390  # NOTE(review): chosen independently of batch_size — confirm

In [None]:
# batch_size is not passed: the inputs are (distributed) datasets that already
# yield batches, and Keras documents that batch_size must not be specified for them.
model.fit(
    distributed_training_dataset,
    epochs=5,
    validation_data=distributed_validation_dataset,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)