In [None]:
import os
import time
import tensorflow as tf
from tensorflow import keras
import tf_data_model

In [None]:
# for device
## Index into the list of physical GPUs; -1 selects the last GPU in that list.
device_index = -1
## When True, the global Keras policy is set to 'mixed_float16'.
MIXED_PRECISION_FLAG = True
## TF 2.6 does not support jit_compile in keras.Model.compile() yet, so this
## flag is NOT passed to compile(). Instead, when it is True, XLA auto-JIT is
## requested through the environment variable 'TF_XLA_FLAGS', e.g.
## os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'.
JIT_COMPILE_FLAG = True

# for dataset
dataset = 'cifar100'
## NOTE(review): dir_path is not referenced by any cell visible in this
## notebook -- confirm whether it is still needed.
dir_path = '/ssd'
resolution = 32
batch_size = 500

# for model
depth = 18
dropout_rate = 0.2 # [0.1, 0.2]

# for training
learning_rate = 1e-1
momentum = 0.9
epochs = 6 # presumably 90 for a full run -- TODO confirm
## TF 2.6 does not support weight_decay in keras.optimizers.SGD() yet, so it is
## not passed to the optimizer.
## NOTE(review): it is not passed to build_resnet() in this notebook either;
## presumably it has to be applied inside the model -- confirm.
weight_decay = 1e-4

# for learning rate scheduler
## Epoch indices (as passed by the LearningRateScheduler callback, 0-based) at
## which the learning rate is multiplied by gamma.
milestones = [2, 4] # presumably [30, 60] for a full run -- TODO confirm
gamma = 0.1

######## for testing: BS and LR are proportional (scale LR linearly with batch size)
#learning_rate *= batch_size / (32 * 8) # PyTorch uses batch size 32 with 8 GPUs

In [None]:
######## for testing: TF 2.6 XLA auto-JIT through the 'TF_XLA_FLAGS' environment variable.
## It must be set before any tensorflow function is called.
## For unknown reasons, '--tf_xla_cpu_global_jit' only supports the first GPU.
## Otherwise an error will result.
if JIT_COMPILE_FLAG:
    # can not use the condition 'len(tf.config.list_physical_devices('GPU')) == 1'
    # since it calls tf function...
    os.environ['TF_XLA_FLAGS'] = (
        '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
        if device_index == 0
        else '--tf_xla_auto_jit=2'
    )

# Optionally make 'mixed_float16' the global Keras policy and report its
# name and compute/variable dtypes.
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
        sep='\n',
    )

# Summary of the two acceleration switches, one item per line.
print(
    '----',
    f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}',
    f'JIT_COMPILE: {JIT_COMPILE_FLAG}',
    '----',
    sep='\n',
)

In [None]:
# Restrict TensorFlow to the single GPU chosen by `device_index` and let its
# memory allocation grow on demand.
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')

# Fail early with an actionable message: indexing an empty list (no GPU found)
# or using an out-of-range `device_index` would otherwise raise a bare
# "list index out of range". Negative indices count from the end, as for any
# Python list.
if not -len(physical_devices) <= device_index < len(physical_devices):
    raise IndexError(
        f'device_index={device_index} is out of range: '
        f'{len(physical_devices)} GPU(s) detected.'
    )

selected_device = physical_devices[device_index]
tf.config.set_visible_devices(selected_device, 'GPU')
tf.config.experimental.set_memory_growth(selected_device, True)
print(f'Using Device: {selected_device}')

In [None]:
# Build the data pipelines; the result is indexed by split name
# ('train' / 'val') when fitting the model.
dataloader = tf_data_model.load_cifar(
    resolution=resolution,
    batch_size=batch_size,
    dataset=dataset,
)

In [None]:
# Build the ResNet from the configured dataset, depth, dropout rate and
# input resolution.
model = tf_data_model.build_resnet(
    dataset=dataset,
    depth=depth,
    dropout_rate=dropout_rate,
    resolution=resolution,
)
# Uncomment to inspect the architecture:
#print(model.summary())

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Step-decay schedule: scale the learning rate at milestone epochs.

    Args:
        epoch: Index of the epoch about to start.
        lr: Learning rate currently in effect.
        milestones: Container of epoch indices at which the rate is decayed.
        gamma: Multiplicative decay factor applied at each milestone.

    Returns:
        ``lr * gamma`` if ``epoch`` is a milestone, otherwise ``lr`` unchanged.
    """
    if epoch not in milestones:
        return lr
    lr *= gamma
    return lr

class TimeCallback(keras.callbacks.Callback):
    """Keras callback that records how long each epoch takes.

    ``history`` is reset when training begins and then receives one
    ``time.perf_counter()`` difference (seconds) per finished epoch.
    """

    def on_train_begin(self, logs=None):
        """Start every training run with an empty list of durations."""
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the timer reading at the start of the epoch."""
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        """Append the time elapsed since the matching ``on_epoch_begin``."""
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

# Per-epoch timer; its `history` list is filled in while training runs.
time_callback = TimeCallback()

# The scheduler invokes the function with (epoch, current learning rate); the
# step-decay settings come from the module-level `milestones` and `gamma`.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)

In [None]:
# SGD with momentum. `weight_decay=weight_decay` is deliberately not passed:
# keras.optimizers.SGD() in TF 2.6 does not support it.
sgd_optimizer = keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
loss_fn = keras.losses.SparseCategoricalCrossentropy()

# `jit_compile=JIT_COMPILE_FLAG` is likewise omitted: keras.Model.compile() in
# TF 2.6 does not support it.
model.compile(optimizer=sgd_optimizer, loss=loss_fn, metrics=['accuracy'])

In [None]:
# Train the model; the timer callback is listed before the LR scheduler.
training_callbacks = [time_callback, lr_scheduler_callback]
logs = model.fit(
    dataloader['train'],
    validation_data=dataloader['val'],
    epochs=epochs,
    callbacks=training_callbacks,
    verbose='auto',
)

# Store the per-epoch durations under key 't' next to the Keras metrics.
logs.history['t'] = time_callback.history

In [None]:
# Display the recorded training history, including the 't' entry (per-epoch
# durations) added after fit().
logs.history