In [None]:
import time
import tensorflow as tf
from tensorflow import keras

In [None]:
# Optimization Setting
# Index into tf.config.list_physical_devices('GPU'); -1 selects the last listed GPU.
device_index = -1
# When True, the mixed-precision cell installs the global 'mixed_float16' Keras policy.
MIXED_PRECISION_FLAG = False
# Forwarded to model.compile(jit_compile=...) to enable XLA compilation.
JIT_COMPILE_FLAG = False

# Dataloader Setting
batch_size = 100
validation_batch_size = 1000
# Seeds only the training-set shuffle; None leaves the shuffle order unseeded.
# Weight initialization and data augmentation are NOT seeded by this value.
seed = None
num_parallel_calls = tf.data.AUTOTUNE

# Training Setting
epochs = 5
## optimizer (SGD)
learning_rate = 1e-1
momentum = 0.9
weight_decay = 1e-4
## lr scheduler: multiply the learning rate by `gamma` at every positive
## multiple of `step_size` epochs.
## NOTE: with epochs (5) < step_size (25) the decay never triggers in this run.
step_size = 25
gamma = 0.2

In [None]:
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
if physical_devices:
    # Restrict TensorFlow to the selected GPU and let it allocate memory on
    # demand. Both settings must be applied before the GPU runtime is
    # initialized, i.e. before any op has executed on it.
    selected_device = physical_devices[device_index]
    tf.config.set_visible_devices(selected_device, 'GPU')
    tf.config.experimental.set_memory_growth(selected_device, True)
    print(f'Using device: {selected_device}')
else:
    # Indexing an empty device list would raise IndexError; fall back to CPU.
    print('No GPU found; running on CPU.')

In [None]:
# Optionally switch Keras to mixed precision: layers compute in float16 while
# their variables stay in float32.
# 'mixed_float16' is the policy for NVIDIA GPUs; on TPUs use 'mixed_bfloat16'.
# The policy is global and only affects layers created afterwards, so this
# cell must run before the model (and the preprocessing layers built inside
# load_cifar100's map functions) are constructed.
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(f'Policy: {policy.name}')
    print(f'Compute dtype: {policy.compute_dtype}')
    print(f'Variable dtype: {policy.variable_dtype}')

In [None]:
def load_cifar100(
    batch_size: int,
    validation_batch_size: int,
    seed: int = None,
    num_parallel_calls: int = tf.data.AUTOTUNE
):
    """Build batched CIFAR-100 input pipelines.

    Both splits are rescaled to [0, 1], standardized with per-channel CIFAR-100
    mean/variance, and cached after preprocessing. The training split is then
    shuffled with a buffer covering the whole split and augmented (random
    translation of up to 10% of height/width plus random horizontal flips).
    Because augmentation happens after the cache, every epoch sees freshly
    augmented images.

    Args:
        batch_size: batch size of the training pipeline.
        validation_batch_size: batch size of the test pipeline.
        seed: seed for the training shuffle; None leaves it unseeded.
        num_parallel_calls: parallelism of the ``map`` stages.

    Returns:
        ``{'train': Dataset, 'test': Dataset}``. Each dataset yields
        ``(images, labels)`` batches, with labels passed through unchanged from
        ``keras.datasets.cifar100.load_data()``.
    """
    # NOTE(review): the layers below are created inside the map functions, so
    # each traced pipeline (train/test) owns its own layer instances. Hoisting
    # them out would share layer state across separately traced graphs; verify
    # carefully before refactoring.
    def map_preprocessing(image):
        """Rescale pixel values to [0, 1] and standardize each channel."""
        # Per-channel statistics on the [0, 1] scale (applied after Rescaling).
        # For CIFAR-10 use instead:
        #   mean = [0.49137255, 0.48235294, 0.44666667]
        #   variance = [0.06103806, 0.05930657, 0.06841815]
        mean = [0.50705882, 0.48666667, 0.44078431]
        variance = [0.07153003, 0.06577716, 0.0762193 ]
        transform = keras.Sequential([
            keras.layers.Rescaling(1/255),
            keras.layers.Normalization(mean=mean, variance=variance)
        ])
        return transform(image)

    def map_augmentation(image):
        """Randomly translate and horizontally flip one (already normalized) image."""
        transform = keras.Sequential([
            # Constant fill uses Keras' default fill value 0, which after
            # normalization corresponds to the per-channel mean.
            keras.layers.RandomTranslation(
                height_factor=0.1,
                width_factor=0.1,
                fill_mode='constant'
            ),
            keras.layers.RandomFlip('horizontal')
        ])
        # Random preprocessing layers are the identity in inference mode.
        # Request training mode explicitly (as the Keras preprocessing guide does
        # for tf.data) rather than relying on how an unspecified `training`
        # argument gets resolved, e.g. a global learning phase of 0 would
        # otherwise silently disable augmentation.
        return transform(image, training=True)

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
    dataloader = {
        'train': (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                  .map(lambda x, y: (map_preprocessing(x), y),
                       num_parallel_calls=num_parallel_calls)
                  .cache()
                  .shuffle(buffer_size=len(x_train), seed=seed)
                  .map(lambda x, y: (map_augmentation(x), y),
                       num_parallel_calls=num_parallel_calls)
                  .batch(batch_size=batch_size)
                  .prefetch(buffer_size=tf.data.AUTOTUNE)),
        'test': (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                 .map(lambda x, y: (map_preprocessing(x), y),
                      num_parallel_calls=num_parallel_calls)
                 .cache()
                 .batch(batch_size=validation_batch_size)
                 .prefetch(buffer_size=tf.data.AUTOTUNE))
    }
    return dataloader

In [None]:
def make_resnet18(
    inputs=None,
    classes: int = 100
) -> keras.Model:
    """Build a CIFAR-style residual network made of basic blocks.

    Despite the name, the layout is 1 stem conv + 3 stages x 3 basic blocks x
    2 convs + 1 dense = 20 weighted layers (the CIFAR ResNet-20 layout), with
    16/32/64 filters per stage. Stages 2 and 3 start with a stride-2 block.

    Args:
        inputs: Keras input tensor. If omitted, a fresh ``keras.Input`` of
            shape (32, 32, 3) is created on every call.
        classes: number of output classes.

    Returns:
        A ``keras.Model`` mapping ``inputs`` to per-class softmax probabilities.
    """
    if inputs is None:
        # Created per call: an Input used as a signature default would be
        # evaluated once and shared by every model built without `inputs`.
        inputs = keras.Input(shape=(32, 32, 3))

    def basicblock(x, filters: int, conv_shortcut: bool = False):
        """Two 3x3 conv+BN layers with a residual connection and ReLU.

        With ``conv_shortcut`` the block downsamples by 2 (strided 3x3 conv),
        and the shortcut is projected by a strided 1x1 conv + BN so both paths
        have matching shapes before the Add.
        """
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(filters, 1, strides=2)(x)
            shortcut = keras.layers.BatchNormalization(epsilon=1.001e-5)(shortcut)
            x = keras.layers.Conv2D(filters, 3, strides=2, padding='same')(x)
        else:
            shortcut = x
            x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.Activation('relu')(x)
        return x

    filters = 16
    x = keras.layers.Conv2D(filters, 3, padding='same')(inputs)
    x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
    x = keras.layers.Activation('relu')(x)
    # Three stages of three blocks; every stage after the first doubles the
    # width and downsamples in its first block.
    for stage in range(3):
        if stage > 0:
            filters *= 2
        for block in range(3):
            x = basicblock(x, filters, stage > 0 and block == 0)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(classes)(x)
    # Keep the softmax in float32 so probabilities stay numerically stable under
    # the 'mixed_float16' policy; under the default float32 policy this is
    # identical to Dense(classes, activation='softmax').
    outputs = keras.layers.Activation('softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the CIFAR-100 train/test pipelines from the config-cell settings
# (positional order: batch_size, validation_batch_size, seed, num_parallel_calls).
dataloader = load_cifar100(
    batch_size,
    validation_batch_size,
    seed,
    num_parallel_calls,
)

In [None]:
# Instantiate the 100-class network and print its layer/parameter summary.
model = make_resnet18(classes=100)
model.summary()

In [None]:
def lr_schedule(epoch, lr, step_size: int = 25, gamma: float = 0.2):
    """Step-decay schedule for ``keras.callbacks.LearningRateScheduler``.

    Returns ``lr * gamma`` when ``epoch`` is a positive multiple of
    ``step_size``, otherwise ``lr`` unchanged. The scheduler passes in the
    optimizer's current rate, so decays compound across boundaries.
    """
    at_decay_boundary = epoch % step_size == 0 and epoch != 0
    return lr * gamma if at_decay_boundary else lr

class TimeCallback(keras.callbacks.Callback):
    """Record the wall-clock duration (seconds) of each epoch of ``fit``.

    ``history`` is reset to an empty list at the start of training, and one
    ``time.perf_counter`` delta is appended per epoch. Keras calls
    ``on_epoch_end`` after the validation pass, so each entry includes it.
    """

    def on_train_begin(self, logs=None):
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        # Start-of-epoch timestamp, read back in on_epoch_end.
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed_seconds = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed_seconds)

# Per-epoch timing callback (independent of the scheduler below).
time_callback = TimeCallback()
# Step-decay learning rate driven by the config-cell `step_size` / `gamma`.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, step_size=step_size, gamma=gamma)
)

In [None]:
# SGD with momentum and decoupled weight decay, configured in the config cell.
# Under the 'mixed_float16' policy, Model.compile wraps this optimizer in a
# LossScaleOptimizer automatically.
sgd_optimizer = keras.optimizers.SGD(
    learning_rate=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
model.compile(
    optimizer=sgd_optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
    jit_compile=JIT_COMPILE_FLAG,
)

In [None]:
# Train, evaluating on the test split at the end of every epoch.
training_callbacks = [lr_scheduler_callback, time_callback]
logs = model.fit(
    dataloader['train'],
    validation_data=dataloader['test'],
    epochs=epochs,
    callbacks=training_callbacks,
)
# Attach the per-epoch wall-clock times recorded by TimeCallback.
logs.history['t'] = time_callback.history

In [None]:
# Display the per-epoch metrics dict, including epoch times stored under 't'.
logs.history

In [None]:
# Plain-text run summary for copy/paste comparison between configurations.
# LearningRateScheduler logs the rate as 'lr' in Keras 2 but as
# 'learning_rate' in Keras 3; read whichever key this Keras version produced
# instead of raising KeyError.
lr_history = logs.history.get('lr', logs.history.get('learning_rate'))
print('----')
print(f'batch_size: {batch_size}')
print(f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}')
print(f'JIT_COMPILE: {JIT_COMPILE_FLAG}')
print(f'time: {logs.history["t"]}')
print(f'learning_rate: {lr_history}')
print(f'loss: {logs.history["loss"]}')
print(f'acc: {logs.history["accuracy"]}')
print(f'val_loss: {logs.history["val_loss"]}')
print(f'val_acc: {logs.history["val_accuracy"]}')
print('----')