In [1]:
import os
import json
import h5py
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
%load_ext tensorboard

# Cluster setup

In [2]:
tf_config = {
    'cluster': {
        'worker': ['192.168.1.1:12345', '192.168.1.2:12345'],
        'ps': ['192.168.1.3:12345', '192.168.1.4:12345'],
        'chief': ['192.168.1.5:12345']
    },
    'task': {'type': 'chief', 'index': 0}
}
os.environ.pop('TF_CONFIG', None)
os.environ['TF_CONFIG'] = json.dumps(tf_config)

# Allow reporting worker and ps failure to the coordinator
os.environ['GRPC_FAIL_FAST'] = 'use_caller'

## Instantiate a ParameterServerStrategy

In [3]:
variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes = (256 << 10),
        max_shards = len(tf_config['cluster']['ps'])
    )
)
strategy = tf.distribute.experimental.ParameterServerStrategy(
    tf.distribute.cluster_resolver.TFConfigClusterResolver(),
    variable_partitioner = variable_partitioner
)
strategy

<tensorflow.python.distribute.parameter_server_strategy_v2.ParameterServerStrategyV2 at 0x13378310c10>

# Path setup

In [4]:
TRAIN_PATH = 'Dataset/Train'
VALIDATE_PATH = 'Dataset/Validate'
TEST_PATH = 'Dataset/Test'

In [5]:
MODEL_PATH = 'Model'
MODEL_CKPT = os.path.join(MODEL_PATH, 'ckpt-{epoch}')
MODEL_TRAINED = os.path.join(MODEL_PATH, 'model.hdf5')
MODEL_BACKUP = os.path.join(MODEL_PATH, 'backup')

# Preparing data

In [6]:
CLASSES = 30
IMAGE_SIZE = (224, 224)
PER_WORKER_BATCH_SIZE = 32
NUM_WORKERS = len(tf_config['cluster']['worker'])
GLOBAL_BATCH_SIZE = PER_WORKER_BATCH_SIZE * NUM_WORKERS
EPOCHS = 3

In [7]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_generator = ImageDataGenerator(
    rescale = 1./255,
    rotation_range = 40, 
    width_shift_range = 0.2, 
    height_shift_range = 0.2,
    shear_range = 0.2,
    zoom_range = 0.2,
    horizontal_flip = True
)

## Input data

In [8]:
def train_dataset_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    train_dataset = tf.data.Dataset.from_generator(
        lambda: train_generator.flow_from_directory(
            TRAIN_PATH, 
            target_size = IMAGE_SIZE, 
            batch_size = batch_size
        ), 
        output_types = (tf.float32, tf.float32), 
        output_shapes = ([batch_size, *IMAGE_SIZE, 3], [batch_size, CLASSES])
    ).shard(
        input_context.num_input_pipelines, 
        input_context.input_pipeline_id
    ).cache()
    return train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [9]:
from tensorflow.keras.utils.experimental import DatasetCreator
train_dataset = DatasetCreator(train_dataset_fn)
num_train = !find {TRAIN_PATH} -type f | wc -l
num_train = int(num_train[0])
print(f'Found {num_train} files')

Found 17581 files


# Model implement

In [10]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model

## Define the model

In [11]:
def build_and_compile_model():
    base_model = MobileNetV2(
        input_shape = IMAGE_SIZE + (3,), 
        include_top = False,
        weights = None
    )
    
    x = preprocess_input(base_model.output)
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.2)(x)
    outputs = Dense(CLASSES, activation='softmax')(x)
    
    model = Model(inputs=base_model.input, outputs=outputs)
    model.compile(
        optimizer = 'adam', 
        loss = 'categorical_crossentropy', 
        metrics = ['accuracy']
    )
    return model

## Callbacks

In [12]:
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint 
from tensorflow.keras.callbacks import Callback, LearningRateScheduler
from tensorflow.keras.callbacks.experimental import BackupAndRestore

In [13]:
def decay(epoch):
    if epoch < 3: return 1e-3
    elif epoch >= 3 and epoch < 7: return 1e-4
    return 1e-5

In [14]:
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f'\nLearning rate for epoch {epoch + 1} is {model.optimizer.lr.numpy()}')

In [15]:
callbacks = [
    TensorBoard(log_dir='./logs'),
    BackupAndRestore(backup_dir=MODEL_BACKUP),
    ModelCheckpoint(filepath=MODEL_CKPT, save_weights_only=True, verbose=1),
    LearningRateScheduler(decay),
    PrintLR()
]
!rm -rf logs

## Training

In [16]:
with strategy.scope(): 
    model = build_and_compile_model()

history = model.fit(
    train_dataset,
    epochs = EPOCHS,
    steps_per_epoch = num_train // (GLOBAL_BATCH_SIZE * NUM_WORKERS),
    # steps_per_epoch = num_train // GLOBAL_BATCH_SIZE,
    # callbacks = callbacks,
    # verbose = 1, # not allowed with ParameterServerStrategy
)
model.save(MODEL_TRAINED)

Epoch 1/3
Found 17581 images belonging to 30 classes.
Found 17581 images belonging to 30 classes.
137/137 - 128s - loss: 3.8640 - accuracy: 0.0466 - 128s/epoch - 934ms/step
Epoch 2/3
137/137 - 110s - loss: 3.3802 - accuracy: 0.0520 - 110s/epoch - 803ms/step
Epoch 3/3
137/137 - 108s - loss: 3.3887 - accuracy: 0.0589 - 108s/epoch - 791ms/step


In [None]:
model.save(MODEL_TRAINED)
%tensorboard --logdir=logs