In [1]:
import os
import json
import tensorflow as tf
# Raise the TensorFlow logger threshold to ERROR so lower-severity records
# are not emitted by the cells below.
tf.get_logger().setLevel('ERROR')

# Cluster setup

In [None]:
# Cluster topology plus the role of *this* process. The whole mapping is
# serialized into TF_CONFIG below.
tf_config = {
    'cluster': {
        'worker': [
            '192.168.1.1:12345',
            '192.168.1.2:12345',
        ],
        'ps': [
            '192.168.1.3:12345',
            '192.168.1.4:12345',
        ],
        'chief': [
            '192.168.1.5:12345',
        ],
    },
    'task': {
        'type': 'chief',
        'index': 0,
    },
}

# Assigning a key overwrites any stale value, so both variables are
# (re)written in a single call.
os.environ.update({
    'TF_CONFIG': json.dumps(tf_config),
    # Allow reporting worker and ps failure to the coordinator
    'GRPC_FAIL_FAST': 'use_caller',
})

In [None]:
# Build the resolver and refuse to continue unless the task type it reports
# is 'chief'; the remaining cells are written for that role only.
resolver_module = tf.distribute.cluster_resolver
cluster_resolver = resolver_module.TFConfigClusterResolver()

running_as_chief = cluster_resolver.task_type == 'chief'
if not running_as_chief:
    raise SystemError('Machine is in wrong role')

## Instantiate a ParameterServerStrategy

In [None]:
# Partition variables into shards of at least 256 KiB each, using no more
# shards than there are parameter servers listed in tf_config.
partitioners = tf.distribute.experimental.partitioners
ps_count = len(tf_config['cluster']['ps'])
variable_partitioner = partitioners.MinSizePartitioner(
    min_shard_bytes=256 * 1024,
    max_shards=ps_count,
)

# The strategy under whose scope the model is later built and compiled.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=variable_partitioner,
)

replica_count = strategy.num_replicas_in_sync
print(f'Number of devices: {replica_count}')

# Path setup

In [2]:
# Root directories of the image splits (relative to the notebook's working
# directory). TRAIN_PATH and VALIDATE_PATH are read by the dataset functions
# below via flow_from_directory.
TRAIN_PATH = 'Dataset/Train'
VALIDATE_PATH = 'Dataset/Validate'
TEST_PATH = 'Dataset/Test'

In [3]:
# Every training artifact (checkpoints, saved models, figures, logs and
# backups) lives under this directory.
MODEL_PATH = 'Model'


def _model_artifact(filename):
    """Return the path of ``filename`` inside ``MODEL_PATH``."""
    return os.path.join(MODEL_PATH, filename)


# Stage 1 (transfer learning) artifacts.
BASE_MODEL_BEST = _model_artifact('base_model_best.hdf5')
BASE_MODEL_TRAINED = _model_artifact('base_model_trained.hdf5')
BASE_MODEL_FIG = _model_artifact('base_model_fig.jpg')
BASE_MODEL_LOG = _model_artifact('base_model_log')
BASE_MODEL_BACKUP = _model_artifact('base_model_backup')

# Stage 2 (fine tuning) artifacts.
FINE_TUNE_MODEL_BEST = _model_artifact('fine_tune_model_best.hdf5')
FINE_TUNE_MODEL_TRAINED = _model_artifact('fine_tune_model_trained.hdf5')
FINE_TUNE_MODEL_FIG = _model_artifact('fine_tune_model_fig.jpg')
FINE_TUNE_MODEL_LOG = _model_artifact('fine_tune_model_log')
FINE_TUNE_MODEL_BACKUP = _model_artifact('fine_tune_backup')

# Preparing data

In [4]:
# Number of target classes: width of the one-hot label batches and of the
# final softmax layer.
CLASSES = 30
# (height, width) every image is resized to by flow_from_directory.
IMAGE_SIZE = (300, 300)
# Global batch = per-replica batch scaled by the number of replicas the
# strategy keeps in sync; `strategy` must already exist when this cell runs.
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [5]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Validation and test images are only rescaled from [0, 255] to [0, 1].
validate_generator = ImageDataGenerator(rescale=1./255)
test_generator = ImageDataGenerator(rescale=1./255)

# Training images get the same rescaling plus random geometric augmentation.
train_generator = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

## Input data

In [6]:
def train_dataset_fn(input_context):
    """Build the training input pipeline for one input pipeline (worker).

    Args:
        input_context: Context object handed in by ``DatasetCreator``
            (expected to be a ``tf.distribute.InputContext``). It supplies
            the per-replica batch size and the shard coordinates.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, one_hot_labels)`` float32
        batches drawn from ``TRAIN_PATH`` through ``train_generator``,
        sharded for this pipeline and prefetched.
    """
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    train_dataset = tf.data.Dataset.from_generator(
        lambda: train_generator.flow_from_directory(
            TRAIN_PATH,
            target_size=IMAGE_SIZE,
            batch_size=batch_size,
        ),
        output_types=(tf.float32, tf.float32),
        # The batch dimension is left open: the directory iterator emits a
        # shorter final batch whenever the file count is not a multiple of
        # batch_size, and a fixed leading dimension would reject it.
        output_shapes=([None, *IMAGE_SIZE, 3], [None, CLASSES]),
    ).shard(
        input_context.num_input_pipelines,
        input_context.input_pipeline_id,
    )
    # Deliberately no .cache(): the generator never terminates, so a cache
    # would grow without bound, and caching would also freeze the random
    # augmentations applied by train_generator.
    return train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [7]:
def validate_dataset_fn(input_context):
    """Build the validation input pipeline for one input pipeline (worker).

    Args:
        input_context: Context object handed in by ``DatasetCreator``
            (expected to be a ``tf.distribute.InputContext``). It supplies
            the per-replica batch size and the shard coordinates.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, one_hot_labels)`` float32
        batches drawn from ``VALIDATE_PATH`` through ``validate_generator``,
        sharded for this pipeline and prefetched.
    """
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    validate_dataset = tf.data.Dataset.from_generator(
        lambda: validate_generator.flow_from_directory(
            VALIDATE_PATH,
            target_size=IMAGE_SIZE,
            batch_size=batch_size,
        ),
        output_types=(tf.float32, tf.float32),
        # The batch dimension is left open: the directory iterator emits a
        # shorter final batch whenever the file count is not a multiple of
        # batch_size, and a fixed leading dimension would reject it.
        output_shapes=([None, *IMAGE_SIZE, 3], [None, CLASSES]),
    ).shard(
        input_context.num_input_pipelines,
        input_context.input_pipeline_id,
    )
    # Deliberately no .cache(): the generator never terminates, so an
    # in-memory cache would never complete and would grow without bound.
    return validate_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [30]:
from tensorflow.keras.utils.experimental import DatasetCreator


def _count_files(root):
    """Return how many files exist anywhere below ``root`` (0 if it is missing).

    Pure-Python stand-in for ``find <root> -type f | wc -l`` so the cell does
    not depend on IPython shell magics, a POSIX shell, or space-free paths.
    NOTE(review): symlinks to files are counted here, unlike ``find -type f``.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(root))


# Wrap the dataset functions so Model.fit can build the pipelines itself.
train_dataset = DatasetCreator(train_dataset_fn)
validate_dataset = DatasetCreator(validate_dataset_fn)

# Sample counts are used to derive steps_per_epoch / validation_steps.
# NOTE(review): every file is counted, not only images - confirm that the
# dataset folders contain nothing else.
num_train = _count_files(TRAIN_PATH)
num_validate = _count_files(VALIDATE_PATH)

# Model implementation

In [None]:
# Epoch budget for stage 1 (transfer learning) and stage 2 (fine tuning).
INITIAL_EPOCHS = 15
FINE_TUNE_EPOCHS = 15
# Stage 2 passes this as `epochs`, i.e. the absolute index of the final epoch.
TOTAL_EPOCHS = INITIAL_EPOCHS + FINE_TUNE_EPOCHS
# Index into pretrained_model.layers: layers before it stay frozen during
# fine tuning, layers from it onward become trainable.
# NOTE(review): 516 presumably targets ResNet152V2's layer list - confirm.
FINE_TUNE_AT = 516

In [None]:
from tensorflow.keras.applications.resnet_v2 import ResNet152V2
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model

## Callbacks

In [None]:
from tensorflow.keras.callbacks import TensorBoard, BackupAndRestore

# Stage 1 (transfer learning): TensorBoard logging and training-state backup.
base_tensorboard = TensorBoard(log_dir=BASE_MODEL_LOG)
base_backup = BackupAndRestore(backup_dir=BASE_MODEL_BACKUP)

# Stage 2 (fine tuning): same pair of callbacks with their own directories.
fine_tune_tensorboard = TensorBoard(log_dir=FINE_TUNE_MODEL_LOG)
# Spelled `fine_tune_backup` (previously `fine_tunee_backup`) so it matches
# the name the fine-tuning fit call passes in its callbacks list.
fine_tune_backup = BackupAndRestore(backup_dir=FINE_TUNE_MODEL_BACKUP)

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Abort training once val_loss has failed to improve for 3 epochs in a row.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)

# Keep only the best-scoring model of each stage on disk.
base_checkpointer = ModelCheckpoint(
    filepath=BASE_MODEL_BEST,
    save_best_only=True,
    verbose=1,
)
fine_tune_checkpointer = ModelCheckpoint(
    filepath=FINE_TUNE_MODEL_BEST,
    save_best_only=True,
    verbose=1,
)

## Stage 1: Transfer learning

In [None]:
# Build and compile the stage 1 model inside the strategy scope so that its
# variables are created under the ParameterServerStrategy.
with strategy.scope():
    # ResNet152V2 is the backbone this notebook imports; the previous
    # `ResNet50` reference was never imported and raised NameError.
    pretrained_model = ResNet152V2(weights='imagenet', include_top=False)

    # New classification head on top of the convolutional features.
    last_output = pretrained_model.output
    x = GlobalAveragePooling2D()(last_output)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.2)(x)
    outputs = Dense(CLASSES, activation='softmax')(x)
    model = Model(inputs=pretrained_model.input, outputs=outputs)

    # Stage 1 trains the new head only: freeze every backbone layer.
    for layer in pretrained_model.layers:
        layer.trainable = False

    model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

In [None]:
# Stage 1 training run. Step counts are derived from the file counts, and
# the final weights are saved regardless of which epoch scored best.
history = model.fit(
    train_dataset,
    epochs=INITIAL_EPOCHS,
    steps_per_epoch=num_train // GLOBAL_BATCH_SIZE,
    validation_data=validate_dataset,
    validation_steps=num_validate // GLOBAL_BATCH_SIZE,
    callbacks=[base_tensorboard, base_backup, base_checkpointer, early_stopping],
    verbose=1,
)
model.save(BASE_MODEL_TRAINED)

## Stage 2: Fine tuning

In [None]:
from tensorflow.keras.optimizers import SGD

# Stage 2: unfreeze the backbone from FINE_TUNE_AT onward, keep everything
# before it frozen, then recompile so the new trainable flags take effect.
with strategy.scope():
    backbone_layers = pretrained_model.layers
    for frozen_layer in backbone_layers[:FINE_TUNE_AT]:
        frozen_layer.trainable = False
    for tuned_layer in backbone_layers[FINE_TUNE_AT:]:
        tuned_layer.trainable = True

    # Small learning rate with momentum for the fine-tuning updates.
    fine_tune_optimizer = SGD(learning_rate=1e-4, momentum=0.9)
    model.compile(
        optimizer=fine_tune_optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

In [None]:
# Continue epoch numbering right after the last epoch stage 1 actually ran.
# history.epoch holds the indices of completed epochs, so the next one is
# the last index + 1; an empty history means stage 1 ran no epoch at all.
fine_tune_initial_epoch = history.epoch[-1] + 1 if history.epoch else 0

# Stage 2 training run. `epochs` is the absolute index of the final epoch,
# hence TOTAL_EPOCHS rather than FINE_TUNE_EPOCHS.
history_fine = model.fit(
    train_dataset,
    validation_data=validate_dataset,
    validation_steps=num_validate // GLOBAL_BATCH_SIZE,
    steps_per_epoch=num_train // GLOBAL_BATCH_SIZE,
    epochs=TOTAL_EPOCHS,
    initial_epoch=fine_tune_initial_epoch,
    callbacks=[
        fine_tune_tensorboard,
        fine_tune_backup,
        fine_tune_checkpointer,
        early_stopping,
    ],  # this comma was missing, which made the whole call a SyntaxError
    verbose=1,
)
model.save(FINE_TUNE_MODEL_TRAINED)