In [None]:
import time
import tensorflow as tf
from tensorflow import keras
import keras_cv

### Add RandAugment

In [None]:
# Optimization Setting
# device_index selects which entry of the detected GPU list is made visible
# (-1 = the last one); the two flags toggle mixed precision and XLA JIT
# compilation further down in the notebook.
device_index = -1
MIXED_PRECISION_FLAG = False
JIT_COMPILE_FLAG = False

# Dataloader Setting
# batch_size / drop_remainder apply to the training pipeline only;
# validation_batch_size is used for the test pipeline.
batch_size = 512
validation_batch_size = 1000
drop_remainder = True
# Parallelism of the tf.data `map` stages.
num_parallel_calls = tf.data.AUTOTUNE
# True: RandAugment; False: RandomTranslation (see load_cifar100).
rand_augment = True

# Model Setting
# resizing: side length images are resized to (also the model input size).
# n: number of residual blocks per stack.
# rate: dropout rate applied before the classifier.
# classes: number of output classes.
resizing = 32
n = 3
rate = 0.2
classes = 100

# Training Setting
# Trailing comments repeat the values (presumably the baseline settings to
# restore after experimenting).
epochs = 160 # 160
## optimizer (SGD) hyper-parameters
learning_rate = 1e-1
momentum = 0.9
weight_decay = 1e-4 # 1e-4
## lr scheduler: multiply the learning rate by `gamma` at each milestone epoch
milestones = [80, 120] # [80, 120]
gamma = 0.1

In [None]:
# Restrict TensorFlow to the single GPU chosen by `device_index` and enable
# on-demand memory growth on it. Falls back to the CPU when no GPU is present
# instead of failing with an IndexError on the empty device list.
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
if not physical_devices:
    print('No GPU detected; TensorFlow will run on the CPU.')
else:
    selected_device = physical_devices[device_index]
    try:
        tf.config.set_visible_devices(selected_device, 'GPU')
        tf.config.experimental.set_memory_growth(selected_device, True)
    except RuntimeError as error:
        # Device visibility / memory growth can only be changed before the
        # TensorFlow runtime is initialised (e.g. this cell was re-run
        # without restarting the kernel).
        print(f'Device configuration left unchanged: {error}')
    else:
        print(f'Using device: {selected_device}')

In [None]:
# Optionally switch the global Keras dtype policy to mixed precision.
# Author's note: 'mixed_bfloat16' is meant for TPUs; on NVIDIA GPUs use
# 'mixed_float16'.
# NOTE(review): some recent GPUs may also support bfloat16 — confirm against
# the TensorFlow mixed-precision documentation before relying on this.
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(f'Policy: {policy.name}')
    print(f'Compute dtype: {policy.compute_dtype}')
    print(f'Variable dtype: {policy.variable_dtype}')

In [None]:
def load_cifar100(
    batch_size: int,
    validation_batch_size: int,
    resizing: int = 32,
    drop_remainder: bool = False,
    num_parallel_calls: int = tf.data.AUTOTUNE,
    rand_augment: bool = False
):
    """Build the CIFAR-100 training and test `tf.data` pipelines.

    Training pipeline: resize -> cache -> shuffle (full-dataset buffer) ->
    random augmentation + horizontal flip + rescale to [0, 1] + per-channel
    normalization -> batch -> prefetch. Only the deterministic resize is
    cached, so the random augmentation is re-applied on every pass.

    Test pipeline: resize + rescale + normalization -> cache -> batch ->
    prefetch.

    Args:
        batch_size: Training batch size.
        validation_batch_size: Batch size of the test pipeline.
        resizing: Height and width the images are resized to.
        drop_remainder: Whether the last incomplete training batch is dropped.
        num_parallel_calls: Parallelism of the `map` stages.
        rand_augment: If True, augment with `keras_cv` RandAugment; otherwise
            with a random translation of up to 12.5% in each direction.

    Returns:
        A dict with the keys 'train' and 'test', each holding a
        `tf.data.Dataset` of `(image, label)` batches.
    """
    # Per-channel statistics applied after rescaling the pixels by 1/255.
    # CIFAR-10 : mean = [0.491, 0.482, 0.447], variance = [0.061, 0.059, 0.068]
    # CIFAR-100: mean = [0.507, 0.487, 0.441], variance = [0.072, 0.066, 0.076]
    mean = [0.507, 0.487, 0.441]
    variance = [0.072, 0.066, 0.076]

    def resize_only(image, size):
        """Deterministic part of the training preprocessing (safe to cache)."""
        resizer = keras.Sequential([
            keras.layers.Resizing(height=size, width=size)
        ])
        return resizer(image)

    def augment_and_normalize(image, use_rand_augment):
        """Random augmentation followed by rescaling and normalization."""
        pipeline = tf.keras.Sequential()
        if not use_rand_augment:
            augmentation = keras.layers.RandomTranslation(
                height_factor=0.125,
                width_factor=0.125,
                fill_mode='constant'
            )
        else:
            # RandAugment operates on the raw 0-255 pixel range, hence it
            # runs before the Rescaling layer.
            augmentation = keras_cv.layers.RandAugment(
                value_range=(0, 255),
                augmentations_per_image=2,
                magnitude=0.2
            )
        pipeline.add(augmentation)
        for layer in (
            keras.layers.RandomFlip('horizontal'),
            keras.layers.Rescaling(1/255),
            keras.layers.Normalization(mean=mean, variance=variance),
        ):
            pipeline.add(layer)
        return pipeline(image)

    def preprocess_test(image, size):
        """Deterministic preprocessing of evaluation images."""
        pipeline = keras.Sequential([
            keras.layers.Resizing(height=size, width=size),
            keras.layers.Rescaling(1/255),
            keras.layers.Normalization(mean=mean, variance=variance)
        ])
        return pipeline(image)

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_ds = train_ds.map(
        lambda image, label: (resize_only(image, resizing), label),
        num_parallel_calls=num_parallel_calls
    )
    train_ds = train_ds.cache()
    train_ds = train_ds.shuffle(buffer_size=len(x_train))
    train_ds = train_ds.map(
        lambda image, label: (augment_and_normalize(image, rand_augment), label),
        num_parallel_calls=num_parallel_calls
    )
    train_ds = train_ds.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_ds = test_ds.map(
        lambda image, label: (preprocess_test(image, resizing), label),
        num_parallel_calls=num_parallel_calls
    )
    test_ds = test_ds.cache()
    test_ds = test_ds.batch(batch_size=validation_batch_size)
    test_ds = test_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

    return {'train': train_ds, 'test': test_ds}

In [None]:
def build_cifar_resnet(
    resizing: int = 32,
    n: int = 3,
    rate: float = 0.2,
    classes: int = 100
) -> keras.Model:
    """Build a CIFAR-style residual network with the Keras functional API.

    Layout: a 3x3 stem convolution with 16 filters, followed by three stacks
    of `n` basic residual blocks. The second and third stack double the
    number of filters and halve the spatial resolution in their first block.
    The network ends with global average pooling, dropout and a softmax
    classifier. Counting the stem, the two convolutions of every block and
    the classifier gives 6n + 2 weighted layers (projection shortcuts
    excluded).

    Args:
        resizing: Height and width of the (square, 3-channel) input images.
        n: Number of residual blocks per stack.
        rate: Dropout rate applied before the classifier.
        classes: Number of output classes.

    Returns:
        An uncompiled `keras.Model` mapping images to class probabilities.
    """
    def basic_block(x, filters: int, conv_shortcut: bool = False):
        """Residual block made of two 3x3 convolutions, each with batch norm.

        With `conv_shortcut` the first convolution uses stride 2 and the
        shortcut is a strided 1x1 convolution + batch norm so that both
        branches have matching shapes; otherwise the shortcut is the identity.
        """
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(
                filters, 1, strides=2, kernel_initializer='he_normal'
            )(x)
            shortcut = keras.layers.BatchNormalization(epsilon=1.001e-5)(shortcut)
            x = keras.layers.Conv2D(
                filters, 3, strides=2, padding='same', kernel_initializer='he_normal'
            )(x)
        else:
            shortcut = x
            x = keras.layers.Conv2D(
                filters, 3, padding='same', kernel_initializer='he_normal'
            )(x)
        x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(x)
        x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.Activation('relu')(x)
        return x

    def basic_stack(x, filters: int, conv_shortcut: bool = False):
        """Chain `n` basic blocks; returns the output and the filter count.

        When `conv_shortcut` is set, the filter count is doubled and the
        first block downsamples; the remaining blocks keep the shape.
        """
        for i in range(n):
            if i == 0 and conv_shortcut:
                filters *= 2
                x = basic_block(x, filters, conv_shortcut)
            else:
                x = basic_block(x, filters)
        return x, filters

    inputs = keras.Input(shape=(resizing, resizing, 3))
    filters = 16
    # Stem.
    x = keras.layers.Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(inputs)
    x = keras.layers.BatchNormalization(epsilon=1.001e-5)(x)
    x = keras.layers.Activation('relu')(x)
    # Three stacks; the last two downsample and double the width.
    x, filters = basic_stack(x, filters)
    x, filters = basic_stack(x, filters, True)
    x, filters = basic_stack(x, filters, True)
    # Classification head.
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(rate)(x)
    # Pin the classifier to float32: under a 'mixed_float16' global policy a
    # float16 softmax is numerically unstable. With the default float32
    # policy this is identical to omitting `dtype`.
    outputs = keras.layers.Dense(classes, activation='softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the training and test pipelines from the settings in the
# configuration cell; `dataloader` is a dict with the keys 'train' and 'test'.
dataloader = load_cifar100(
    batch_size=batch_size,
    validation_batch_size=validation_batch_size,
    resizing=resizing,
    drop_remainder=drop_remainder,
    num_parallel_calls=num_parallel_calls,
    rand_augment=rand_augment
)

In [None]:
# Instantiate the (uncompiled) residual network from the model settings.
model = build_cifar_resnet(resizing=resizing, n=n, rate=rate, classes=classes)
# Uncomment to print the layer-by-layer summary:
#model.summary()

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.2):
    """Step learning-rate schedule.

    Args:
        epoch: Index of the epoch that is about to start.
        lr: Learning rate currently in use.
        milestones: Collection of epoch indices at which the rate is decayed.
        gamma: Multiplicative decay factor applied at a milestone.

    Returns:
        `lr` unchanged, or `lr` multiplied by `gamma` when `epoch` is one of
        the milestones.
    """
    if epoch not in milestones:
        return lr
    lr *= gamma
    return lr

class TimeCallback(keras.callbacks.Callback):
    """Keras callback that records the wall-clock duration of each epoch.

    After training, `history` holds one duration in seconds per epoch,
    measured with `time.perf_counter()` between `on_epoch_begin` and
    `on_epoch_end`.
    """

    def on_train_begin(self, logs=None):
        # Start from an empty record every time training starts.
        self.history = list()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started.
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        # Store the seconds elapsed since the matching on_epoch_begin.
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

# Bind the configured milestones and gamma to lr_schedule so that it matches
# the (epoch, lr) signature LearningRateScheduler is given here.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda x, y: lr_schedule(x, y, milestones=milestones, gamma=gamma)
)
# Collects per-epoch wall-clock durations in `time_callback.history`.
time_callback = TimeCallback()

In [None]:
# Compile with SGD (momentum + weight decay) and sparse categorical
# cross-entropy on integer labels. The loss is created with its default
# arguments; the model's last layer already applies softmax.
# `jit_compile` toggles XLA compilation via JIT_COMPILE_FLAG.
model.compile(
    optimizer=keras.optimizers.SGD(
        learning_rate=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay
    ),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
    jit_compile=JIT_COMPILE_FLAG
)

In [None]:
# Optional warm-up for training-time measurements (currently disabled).
# The snippet is wrapped in a string literal, so it is NOT executed: it would
# run `fit` with a callback that sets `stop_training` after the first batch.
# Remove the surrounding quotes to enable it.
'''
class StopCallBack(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        self.model.stop_training = True
stop_call = StopCallBack()
model.fit(dataloader['train'], verbose=2, callbacks=[stop_call])
'''
pass

In [None]:
# Train for `epochs` epochs with the step LR schedule and per-epoch timing.
# Note that the test split is what is passed as validation data.
logs = model.fit(
    dataloader['train'],
    epochs=epochs,
    verbose=2,
    callbacks=[lr_scheduler_callback, time_callback],
    validation_data=dataloader['test']
)
# Attach the per-epoch durations (seconds) to the history under the key 't'.
logs.history['t'] = time_callback.history

In [None]:
# Display the full training history dict (per-epoch metrics plus 't').
logs.history

In [None]:
# Report the settings and the metrics of one epoch of the training history;
# `index` picks the epoch (-1 = the last one).
index = -1
print(
    '----',
    'PRINT RESULTS',
    f'batch_size: {batch_size}',
    f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}',
    f'JIT_COMPILE: {JIT_COMPILE_FLAG}',
    f'time: {logs.history["t"][index]}',
    f'learning_rate: {logs.history["lr"][index]}',
    f'loss: {logs.history["loss"][index]}',
    f'acc: {logs.history["accuracy"][index]}',
    f'val_loss: {logs.history["val_loss"][index]}',
    f'val_acc: {logs.history["val_accuracy"][index]}',
    '----',
    sep='\n'
)

In [None]:
# bs = 512
#Epoch 160/160
#97/97 - 20s - loss: 1.2973 - accuracy: 0.6332 - val_loss: 1.2738 - val_accuracy: 0.6409 - lr: 1.0000e-03 - 20s/epoch - 205ms/step

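# bs = 256 (NOTE: the log below shows 390 steps/epoch, which matches a batch size of 128 on the 50,000-image training set; 256 would give 195 steps — verify this label)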
#Epoch 160/160
#390/390 - 20s - loss: 1.1997 - accuracy: 0.6605 - val_loss: 1.2067 - val_accuracy: 0.6594 - lr: 1.0000e-03 - 20s/epoch - 51ms/step