In [None]:
import os
import json
import tensorflow as tf
%load_ext tensorboard

# Multi-worker configuration

In [None]:
# Fail fast on machines without a GPU, then let TensorFlow allocate GPU memory
# on demand instead of reserving it all at start-up. Memory growth must be
# configured before the GPUs are initialized, i.e. before any op runs on them.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    # RuntimeError rather than SystemError: SystemError is reserved for
    # interpreter-internal faults, not for a missing device.
    raise RuntimeError('GPU device not found')
for gpu in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Cluster spec shared by every worker. Only `task.index` differs between
# machines: it is the position of this machine in the `worker` list below.
# Set WORKER_INDEX in each worker's environment (0 on 192.168.1.1,
# 1 on 192.168.1.16) instead of editing this cell. It defaults to 0.
WORKER_INDEX = int(os.environ.get('WORKER_INDEX', '0'))
tf_config = {
    'cluster': {
        'worker': ['192.168.1.1:12345', '192.168.1.16:12345']
    },
    'task': {'type': 'worker', 'index': WORKER_INDEX}
}
num_cluster_workers = len(tf_config['cluster']['worker'])
if not 0 <= WORKER_INDEX < num_cluster_workers:
    raise ValueError(
        f'WORKER_INDEX must be in [0, {num_cluster_workers}), got {WORKER_INDEX}'
    )
# Assignment overwrites any stale TF_CONFIG, so no separate pop() is needed.
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [None]:
# MultiWorkerMirroredStrategy parses TF_CONFIG when it is constructed, so the
# environment variable must already be set when this cell runs.
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=tf.distribute.experimental.CommunicationOptions(
        # CommunicationImplementation replaces the deprecated
        # CollectiveCommunication alias. RING selects TensorFlow's own
        # ring-based collectives rather than NCCL.
        implementation=tf.distribute.experimental.CommunicationImplementation.RING
    )
)

# Preparing data

In [None]:
# Every image is resized to this (height, width) when it is loaded.
IMAGE_SIZE = (224, 224)
EPOCHS = 10

# Datasets are batched with the *global* batch size: the per-worker batch
# size times the number of workers in the cluster spec.
NUM_WORKERS = len(tf_config['cluster']['worker'])
PER_WORKER_BATCH_SIZE = 32
GLOBAL_BATCH_SIZE = NUM_WORKERS * PER_WORKER_BATCH_SIZE

In [None]:
from pathlib import Path

from tensorflow.keras.layers import Rescaling

# Download and extract the flower photos (cached by Keras after the first run).
data_url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
data_root = tf.keras.utils.get_file('flower_photos', data_url, untar=True)

# Count images in pure Python. The previous `!find {data_root} -name *.jpg`
# left the glob unquoted, so the shell could expand it against the current
# directory before `find` saw it. It also required POSIX find/wc.
# rglob is recursive, like find, so nested extraction layouts still work.
data_length = sum(1 for _ in Path(data_root).rglob('*.jpg'))
if data_length == 0:
    raise FileNotFoundError(f'no .jpg images found under {data_root}')

In [None]:
def get_dataset(batch_size, subset, validation_split=0.2, seed=123):
    """Build a batched, normalized, cached tf.data pipeline for one flower split.

    Images are read from ``data_root`` with Keras' directory loader. The loader
    splits the files into training/validation subsets using
    ``validation_split`` and ``seed``, so both subsets must be requested with
    the same values to stay disjoint.

    Args:
        batch_size: Images per batch (the global batch size when training
            under a ``tf.distribute`` strategy).
        subset: Either ``'training'`` or ``'validation'``.
        validation_split: Fraction of the images reserved for validation.
        seed: Seed shared by the directory split and the training shuffle.

    Returns:
        ``(dataset, length)``. ``dataset`` yields ``(images, labels)`` batches
        with pixels rescaled to [0, 1]. ``length`` is the number of images
        (not batches) in the subset.

    Raises:
        NameError: If ``subset`` is not ``'training'`` or ``'validation'``.
    """
    # Mirror Keras' split arithmetic: it reserves int(split * n) files for
    # validation and keeps the remainder for training. int(n * 0.8) could be
    # off by one from that remainder.
    num_val = int(validation_split * data_length)
    if subset == 'validation':
        length, shuffle = num_val, False
    elif subset == 'training':
        length, shuffle = data_length - num_val, True
    else:
        raise NameError("subset must be 'training' or 'validation'")

    dataset = tf.keras.utils.image_dataset_from_directory(
        str(data_root),
        validation_split=validation_split,
        subset=subset,
        image_size=IMAGE_SIZE,
        batch_size=batch_size,
        seed=seed,
    )
    # TF Hub image models expect float pixels in [0, 1], not [-1, 1].
    normalization_layer = Rescaling(1. / 255)
    dataset = dataset.map(
        lambda x, y: (normalization_layer(x), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # Cache BEFORE shuffling. With shuffle().cache(), the order seen in the
    # first epoch is written to the cache and replayed in every later epoch.
    dataset = dataset.cache()
    if shuffle:
        # The dataset is already batched, so the buffer counts batches.
        # shuffle() requires a buffer of at least one element.
        num_batches = max(1, (length + batch_size - 1) // batch_size)
        # A fixed seed gives every worker the same batch order. This matters
        # because tf.distribute presumably shards this in-memory pipeline by
        # element, which only yields disjoint shards when the order is
        # deterministic.
        dataset = dataset.shuffle(buffer_size=num_batches, seed=seed)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset, length

In [None]:
# Build both splits with the global batch size. The training split is
# requested first, then the validation split.
(ds_train, num_train), (ds_val, num_val) = (
    get_dataset(GLOBAL_BATCH_SIZE, split_name)
    for split_name in ('training', 'validation')
)

# Model implementation

## Define the model

In [None]:
import tensorflow_hub as hub
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import SparseCategoricalCrossentropy

In [None]:
def build_and_compile_model(num_classes=5):
    """Create and compile the MobileNetV2 transfer-learning classifier.

    Build it under ``strategy.scope()`` so its variables are created as
    distributed variables.

    Args:
        num_classes: Number of output logits (the flowers dataset has 5 classes).

    Returns:
        A compiled ``Sequential`` model that outputs unnormalized logits.
    """
    # Use the feature_vector variant, which ends at the pooled image features.
    # The classification variant already emits 1001 ImageNet logits, so the
    # new head was stacked on top of another classifier.
    feature_extractor_layer = hub.KerasLayer(
        'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
        input_shape=IMAGE_SIZE + (3,),
        trainable=True,  # fine-tune the backbone as well as the new head
    )
    model = Sequential([feature_extractor_layer, Dense(num_classes)])
    model.compile(
        optimizer='adam',
        # Dense has no activation, so the loss applies the softmax itself.
        loss=SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model

In [None]:
# Build the model inside the strategy scope so its variables are created as
# distributed variables. Printing the summary needs no scope.
with strategy.scope():
    model = build_and_compile_model()
model.summary()

## Callbacks

In [None]:
def decay(epoch):
    """Step learning-rate schedule for ``LearningRateScheduler``.

    Returns 1e-3 while ``epoch < 3``, 1e-4 while ``epoch < 7`` and 1e-5
    afterwards.
    """
    # (exclusive upper epoch bound, learning rate), checked in order.
    for boundary, rate in ((3, 1e-3), (7, 1e-4)):
        if epoch < boundary:
            return rate
    return 1e-5

In [None]:
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    """Print the optimizer's learning rate at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Read the model Keras attached to this callback (self.model), not the
        # notebook-global `model`. Use `learning_rate`, not the deprecated
        # `lr` alias.
        current_lr = self.model.optimizer.learning_rate
        print(f'\nLearning rate for epoch {epoch + 1} is {current_lr.numpy()}')

In [None]:
from tensorflow.keras.callbacks import (
    EarlyStopping,
    LearningRateScheduler,
    TensorBoard,
)

callbacks = [
    # Event files for the %tensorboard magic (it reads the `logs` directory).
    TensorBoard(log_dir='./logs'),
    # Stop once val_loss has not improved for 3 consecutive epochs.
    EarlyStopping(monitor='val_loss', patience=3, verbose=1),
    # Apply the step schedule from `decay` at each epoch.
    LearningRateScheduler(decay),
    # Report the rate that was used for the epoch.
    PrintLR(),
]

# Training

In [None]:
history = model.fit(
    ds_train, 
    validation_data = ds_val,
    callbacks = callbacks,
    epochs = EPOCHS, 
    verbose = 1,
)
%tensorboard --logdir=logs