In this notebook we will learn how to take [full advantage of a TPU](https://www.tensorflow.org/guide/tpu) by training a Keras model with a
[TPU strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy?version=nightly).

The training code connects to the TPU by setting up a [Cluster Resolver](https://www.tensorflow.org/api_docs/python/tf/distribute/cluster_resolver/TPUClusterResolver).

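Run the command below in Cloud Shell to set up the TPU. Whatever name you pass to `--name` must also be passed as `--tpu_address` when launching training at the end of this notebook. (Cloud resource names generally allow only lowercase letters, digits, and hyphens; if `my_tpu` is rejected, use a name such as `my-tpu` in both places.)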
`ctpu up --zone=us-central1-b  --tf-version=2.1 --name=my_tpu`

`ctpu up` should automatically SSH you into the Compute Engine VM that drives the TPU. Alternatively, that VM can be accessed from the [Compute Engine Interface](https://console.cloud.google.com/compute/instances).

In [2]:
%%writefile tpu_models/trainer/util.py
import tensorflow as tf

# Input image geometry fed to the model: height x width x (RGB) channels.
IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNELS = 3

# Input-pipeline settings.
BATCH_SIZE = 32
# The shuffle buffer holds 10 batches' worth of examples; the factor of 10
# is a magic number tuned for local training of this dataset.
SHUFFLE_BUFFER = BATCH_SIZE * 10
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Order matters: a label's one-hot position is its index in this list.
CLASS_NAMES = ['roses', 'sunflowers', 'tulips', 'dandelion', 'daisy']

# Evaluation sizing. Floor division means up to BATCH_SIZE - 1 images past
# the last full batch are not covered by VALIDATION_STEPS.
VALIDATION_IMAGES = 370
VALIDATION_STEPS = VALIDATION_IMAGES // BATCH_SIZE

# Augmentation: images are resized 10 px larger in each dimension
# ([height, width]) and then randomly cropped back to the model size.
CROP_SCALING = [IMG_HEIGHT + 10, IMG_WIDTH + 10]
MAX_DELTA = 63.0 / 255.0  # brightness jitter bound on the [0, 1] pixel scale
CONTRAST_LOWER, CONTRAST_UPPER = 0.2, 1.8  # random contrast factor range


def decode_img(img, reshape_dims):
    """Decode JPEG bytes into a float32 image resized to ``reshape_dims``.

    Args:
        img: string tensor holding the encoded JPEG bytes of one image.
        reshape_dims: target size given as [height, width], the order
            ``tf.image.resize`` expects.

    Returns:
        A float32 image tensor with IMG_CHANNELS channels whose pixel values
        were scaled to [0, 1] (by ``convert_image_dtype``) before resizing.
    """
    decoded = tf.image.decode_jpeg(img, channels=IMG_CHANNELS)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize(as_float, reshape_dims)
    return resized


def decode_csv(csv_row):
    """Parse one ``<image path>,<class name>`` CSV line.

    Args:
        csv_row: string tensor holding one line of the CSV file.

    Returns:
        An ``(image_bytes, label)`` tuple: the raw (still encoded) contents
        of the image file, and a float32 one-hot vector over CLASS_NAMES.
        A class name that is not in CLASS_NAMES yields an all-zero vector.
    """
    # String defaults make both columns parse as strings; the placeholder
    # values are only substituted for empty fields.
    record_defaults = ["path", "flower"]
    filename, label_string = tf.io.decode_csv(csv_row, record_defaults)
    image_bytes = tf.io.read_file(filename=filename)
    # Element-wise comparison against every class name gives a boolean
    # one-hot vector. Cast it explicitly so losses and metrics receive a
    # float target instead of relying on their implicit casting.
    label = tf.cast(tf.math.equal(CLASS_NAMES, label_string), tf.float32)
    return image_bytes, label


def read_and_preprocess(image_bytes, label, random_augment=False):
    """Decode an image and optionally apply random training augmentation.

    Args:
        image_bytes: encoded JPEG bytes for one image.
        label: the label tensor; returned unchanged.
        random_augment: if True, resize to CROP_SCALING, take a random
            IMG_HEIGHT x IMG_WIDTH crop, then randomly flip and jitter
            brightness/contrast (clipping the result back to [0, 1]).
            If False, simply resize to IMG_HEIGHT x IMG_WIDTH.

    Returns:
        An ``(image, label)`` tuple where image is float32 with shape
        [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS].
    """
    if not random_augment:
        # tf.image.resize takes [height, width]; the original passed
        # [width, height], which only worked because both are 224.
        return decode_img(image_bytes, [IMG_HEIGHT, IMG_WIDTH]), label

    img = decode_img(image_bytes, CROP_SCALING)
    img = tf.image.random_crop(img, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, MAX_DELTA)
    img = tf.image.random_contrast(img, CONTRAST_LOWER, CONTRAST_UPPER)
    # Brightness/contrast jitter can push pixels outside [0, 1]; clip them
    # so augmented images stay in the same range as unaugmented ones.
    img = tf.clip_by_value(img, 0.0, 1.0)
    return img, label


def read_and_preprocess_with_augment(image_bytes, label):
    """Training-time variant of read_and_preprocess with augmentation on.

    Kept as a two-argument function so it can be handed directly to
    ``tf.data.Dataset.map``.
    """
    return read_and_preprocess(image_bytes, label, True)


def load_dataset(csv_of_filenames, training=True, drop_remainder=True):
    """Build a batched tf.data pipeline of (image, label) pairs from a CSV.

    Args:
        csv_of_filenames: path (or list of paths) to CSV file(s) whose lines
            are ``<image path>,<class name>``.
        training: if True, apply random augmentation, shuffle, and repeat
            indefinitely; otherwise make a single un-augmented pass.
        drop_remainder: drop a final partial batch so every batch has a
            static size of BATCH_SIZE, as TPUs require static shapes.
            Training repeats forever, so it never has a partial batch.
            Evaluation is only consumed for VALIDATION_STEPS full batches,
            so the default does not change which examples are used.

    Returns:
        A prefetched ``tf.data.Dataset`` of batches.
    """
    # cache() keeps the decoded CSV rows (including the raw, still-encoded
    # image bytes) in memory after the first pass.
    dataset = (
        tf.data.TextLineDataset(filenames=csv_of_filenames)
        .map(decode_csv, num_parallel_calls=AUTOTUNE)
        .cache())

    if training:
        dataset = (
            dataset
            .map(read_and_preprocess_with_augment, num_parallel_calls=AUTOTUNE)
            .shuffle(SHUFFLE_BUFFER)
            .repeat(count=None))
    else:
        # Single pass: no augmentation, no shuffling.
        dataset = (
            dataset
            .map(read_and_preprocess, num_parallel_calls=AUTOTUNE)
            .repeat(count=1))

    dataset = dataset.batch(batch_size=BATCH_SIZE,
                            drop_remainder=drop_remainder)
    # Prefetch prepares the next set of batches while current batch is in use.
    return dataset.prefetch(buffer_size=AUTOTUNE)


Overwriting tpu_models/trainer/util.py


In [37]:
%%writefile tpu_models/trainer/model.py
import os
import shutil

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense, Dropout
import tensorflow_hub as hub

from . import util

NCLASSES = len(util.CLASS_NAMES)
LEARNING_RATE = 0.0001
DROPOUT = .2


def build_model(output_dir, learning_rate=None, l2_strength=LEARNING_RATE):
    """Build and compile a Keras image classifier.

    The model is a frozen (``trainable=False``) MobileNetV2 feature-vector
    module from TF Hub, followed by dropout and a softmax layer over the
    NCLASSES classes.

    Args:
        output_dir: unused; accepted for call-site compatibility.
        learning_rate: Adam learning rate. If None (the default), the
            ``'adam'`` string is passed to compile, i.e. Keras' default
            Adam learning rate is used.
        l2_strength: L2 regularization factor for the classifier kernel.

    Returns:
        The compiled ``tf.keras.Sequential`` model.

    NOTE(review): the module-level LEARNING_RATE constant is only used as
    the default ``l2_strength``. It does not set the optimizer's learning
    rate. Confirm which of the two was intended.
    """
    module_selection = "mobilenet_v2_100_224"
    module_handle = (
        "https://tfhub.dev/google/imagenet/{}/feature_vector/4"
        .format(module_selection))

    model = tf.keras.Sequential([
        hub.KerasLayer(module_handle, trainable=False),
        tf.keras.layers.Dropout(rate=DROPOUT),
        tf.keras.layers.Dense(
            NCLASSES,
            activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(l2_strength))
    ])
    # Leading None leaves the batch dimension unspecified.
    model.build((None, util.IMG_HEIGHT, util.IMG_WIDTH, util.IMG_CHANNELS))

    optimizer = 'adam'
    if learning_rate is not None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'])
    return model


def train_and_evaluate(
    model, num_epochs, steps_per_epoch, train_data, eval_data, output_dir):
    """Fit an already-compiled model and optionally export it.

    Args:
        model: compiled Keras model to train.
        num_epochs: number of training epochs.
        steps_per_epoch: training batches drawn per epoch.
        train_data: training dataset passed to ``model.fit``.
        eval_data: validation dataset; util.VALIDATION_STEPS batches of it
            are evaluated per epoch.
        output_dir: if truthy, TensorBoard logs are written here and the
            trained model is saved to ``<output_dir>/keras_export`` in
            SavedModel ('tf') format. If falsy, neither happens.

    Returns:
        The History object returned by ``model.fit``.
    """
    callbacks = [TensorBoard(log_dir=output_dir)] if output_dir else []

    history = model.fit(
        train_data,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_data,
        validation_steps=util.VALIDATION_STEPS,
        callbacks=callbacks)

    if output_dir:
        model.save(os.path.join(output_dir, 'keras_export'), save_format='tf')
    return history


Overwriting tpu_models/trainer/model.py


In [38]:
%%writefile tpu_models/trainer/task.py
import argparse
import json
import os
import sys

import tensorflow as tf

from . import model
from . import util


def _parse_arguments(argv):
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--epochs',
        help='The number of epochs to train',
        type=int, default=10)
    parser.add_argument(
        '--steps_per_epoch',
        help='The number of steps per epoch to train',
        type=int, default=500)
    parser.add_argument(
        '--train_path',
        help='The path to the training data',
        type=str, default="gs://cloud-ml-data/img/flower_photos/train_set.csv")
    parser.add_argument(
        '--eval_path',
        help='The path to the evaluation data',
        type=str, default="gs://cloud-ml-data/img/flower_photos/eval_set.csv")
    parser.add_argument(
        '--tpu_address',
        help='The path to the evaluation data',
        type=str, required=True)
    parser.add_argument(
        '--job-dir',
        help='Directory where to save the given model',
        type=str, default='tpu_models/')
    return parser.parse_known_args(argv)


def main():
    """Parse arguments, connect to the TPU, and train and export the model."""
    args = _parse_arguments(sys.argv[1:])[0]

    # Connect to and initialize the TPU before creating the strategy.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=args.tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    # The model (and its variables) is created inside the strategy scope so
    # it is placed on the TPU replicas.
    with strategy.scope():
        train_data = util.load_dataset(args.train_path)
        eval_data = util.load_dataset(args.eval_path, training=False)
        image_model = model.build_model(args.job_dir)

    model.train_and_evaluate(
        image_model, args.epochs, args.steps_per_epoch,
        train_data, eval_data, args.job_dir)
    print("done!")


if __name__ == '__main__':
    main()


Overwriting tpu_models/trainer/task.py


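Use the cell below to copy the model code from this notebook to GCS. Replace `gs://ddetering-experimental` with a bucket you own. The `gsutil rm` step first clears out any previous copy, including stale `__pycache__` files.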

In [39]:
!gsutil rm -r gs://ddetering-experimental/tpu_models
!gsutil cp -r tpu_models gs://ddetering-experimental/tpu_models

Removing gs://ddetering-experimental/tpu_models/trainer/__init__.py#1584499666535273...
Removing gs://ddetering-experimental/tpu_models/trainer/__pycache__/__init__.cpython-35.pyc#1584499667482608...
Removing gs://ddetering-experimental/tpu_models/trainer/__pycache__/model.cpython-35.pyc#1584499667673468...
Removing gs://ddetering-experimental/tpu_models/trainer/__pycache__/task.cpython-35.pyc#1584499667108376...
/ [4 objects]                                                                   
==> NOTE: You are performing a sequence of gsutil operations that may
run significantly faster if you instead use gsutil -m rm ... Please
see the -m section under "gsutil help options" for further information
about when gsutil -m can be advantageous.

Removing gs://ddetering-experimental/tpu_models/trainer/__pycache__/util.cpython-35.pyc#1584499667299167...
Removing gs://ddetering-experimental/tpu_models/trainer/model.py#1584499666355621...
Removing gs://ddetering-experimental/tpu_models/trainer/t

Run the command below on the TPU's Compute Engine VM to copy the model code from GCS onto the VM.

In [None]:
gsutil cp -r gs://ddetering-experimental/tpu_models .

Run the command below on the TPU's Compute Engine VM to kick off training of the model. The output directory must be a GCS storage bucket, because TPUs have a [restricted local file system](https://cloud.google.com/tpu/docs/troubleshooting#cannot_use_local_filesystem). The first few epochs will be slow while the TensorFlow graph is built, but the rest will be very fast. Wheeee!!

The UI is a little slow to display the epoch runs; it takes about 10 minutes before training really starts flying.

In [None]:
python3 -m tpu_models.trainer.task \
    --tpu_address=my_tpu \
    --job-dir=gs://ddetering-experimental/flowers_tpu_$(date -u +%y%m%d_%H%M%S)