In [None]:
import keras_cv
import tensorflow as tf
from time import perf_counter
from tensorflow import keras
#import keras_core as keras

In [None]:
# ---- Optimization ----
device_index = -1              # which detected GPU to use (index into the GPU list; -1 = last)
MIXED_PRECISION_FLAG = True    # enable the Keras 'mixed_float16' global policy
JIT_COMPILE_FLAG = True        # passed to model.compile(jit_compile=...) (XLA)

# Reproducibility: set an int to seed the Python, NumPy and TensorFlow RNGs via
# keras.utils.set_random_seed. None keeps the previous, unseeded behaviour.
# (GPU kernels may still be nondeterministic even when seeded.)
SEED = None
if SEED is not None:
    keras.utils.set_random_seed(SEED)

# ---- Data ----
resizing = 32                  # side length images are resized to (options noted: 16, 32)
batch_size = 500
validation_batch_size = 1000
drop_remainder = False
num_parallel_calls = tf.data.AUTOTUNE

# ---- Model ----
n = 3                          # basic blocks per stage for build_cifar_resnet
rate = 0.2                     # dropout before the classifier (options noted: 0.1, 0.2)
classes = 100

# ---- RandAugment ----
augmentations_per_image = 2
magnitude = 0.2                # options noted: 0.1, 0.2
rand_augment = False           # RandAugment failed when mixed precision was enabled

# ---- Training ----
epochs = 160
learning_rate = 1e-1           # not scaled with batch size (alternative: * batch_size / 128)
momentum = 0.9
weight_decay = 1e-4

# ---- LR schedule: multiply the lr by `gamma` at every epoch listed in `milestones` ----
milestones = [80, 120]
gamma = 0.1

In [None]:
# NOTE(review): `stage_schedule` is not referenced anywhere else in this notebook.
# It looks like a plan for staged input size / augmentation strength, but
# 'milestones' has 5 entries while the per-stage lists ('size', 'magnitude', 'p')
# have 2 and 'len' is 2 — confirm the intended pairing before using it.
stage_schedule = dict(
    milestones=[40, 80, 100, 120, 140],
    counter=0,
    len=2,
    size=[16, 32],
    magnitude=[5, 10],
    p=[0.1, 0.2],
)

In [None]:
# Restrict TensorFlow to a single GPU (chosen by `device_index`) and let its
# memory grow on demand instead of reserving all of it up front.
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
if not physical_devices:
    # Indexing an empty list would raise IndexError; run on the CPU instead.
    print('No GPU found; running on CPU.')
else:
    selected_device = physical_devices[device_index]
    try:
        # Visible devices and memory growth can only be configured before the GPU
        # runtime is initialised; re-running this cell in a live kernel may raise
        # RuntimeError, which is reported instead of aborting the notebook.
        tf.config.set_visible_devices(selected_device, 'GPU')
        tf.config.experimental.set_memory_growth(selected_device, True)
        print(f'Using device: {selected_device}')
    except RuntimeError as err:
        print(err)

In [None]:
# Mixed precision: layers compute in float16 while variables are kept in float32.
# 'mixed_float16' is the usual choice on NVIDIA GPUs; 'mixed_bfloat16' is aimed
# mainly at TPUs (some recent GPUs support bfloat16 as well).
# keras_core equivalent: keras.mixed_precision.set_dtype_policy('mixed_float16'),
# inspected with keras.mixed_precision.dtype_policy().
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print('\n'.join([
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
    ]))

In [None]:
def load_cifar100(
    batch_size: int = 128,
    validation_batch_size: int = 128,
    resizing: int = 32,
    augmentations_per_image: int = 2,
    magnitude: float = 0.2,
    drop_remainder: bool = False,
    num_parallel_calls: int = tf.data.AUTOTUNE,
    rand_augment: bool = False
):
    '''Build the CIFAR-100 train/test input pipelines with tf.data.

    Train: resize -> cache -> shuffle -> augment (keras_cv RandAugment when
    `rand_augment`, otherwise random translation) -> random horizontal flip ->
    scale to [0, 1] -> per-channel normalisation -> batch -> prefetch.
    Test:  resize -> scale to [0, 1] -> normalise -> cache -> batch -> prefetch.

    Per-channel statistics of [0, 1]-scaled pixels:
        cifar-10:  mean = [0.491, 0.482, 0.447], variance = [0.061, 0.059, 0.068]
        cifar-100: mean = [0.507, 0.487, 0.441], variance = [0.072, 0.066, 0.076]

    Args:
        batch_size: training batch size.
        validation_batch_size: test batch size.
        resizing: height/width images are resized to.
        augmentations_per_image: RandAugment ops per image (only if `rand_augment`).
        magnitude: RandAugment magnitude (only if `rand_augment`).
        drop_remainder: drop the last partial training batch.
        num_parallel_calls: parallelism of the `map` stages.
        rand_augment: use RandAugment instead of random translation.

    Returns:
        dict with keys 'train' and 'test', each a tf.data.Dataset of
        (image, label) batches. Downloads CIFAR-100 via keras.datasets on
        first use.
    '''
    mean = [0.507, 0.487, 0.441]
    variance = [0.072, 0.066, 0.076]

    # NOTE(review): under a global 'mixed_float16' policy these preprocessing layers
    # also compute in float16, which is presumably why RandAugment fails with mixed
    # precision; constructing them with dtype='float32' may help — untested.

    def map_train_before_cache(image, label):
        # Deterministic resize only, so the cached tensors contain no augmentation.
        transform = keras.Sequential([
            keras.layers.Resizing(height=resizing, width=resizing)
        ])
        return transform(image), label

    def map_train_after_cache(image, label):
        transform = keras.Sequential()
        if rand_augment:
            transform.add(
                keras_cv.layers.RandAugment(
                    value_range=(0, 255),
                    augmentations_per_image=augmentations_per_image,
                    magnitude=magnitude
                )
            )
        else:
            transform.add(
                keras.layers.RandomTranslation(
                    height_factor=0.125,
                    width_factor=0.125,
                    fill_mode='constant'
                )
            )
        transform.add(keras.layers.RandomFlip('horizontal'))
        transform.add(keras.layers.Rescaling(1/255))
        transform.add(keras.layers.Normalization(mean=mean, variance=variance))
        # Request training behaviour explicitly: random preprocessing layers pass
        # inputs through unchanged in inference mode, and this map runs outside
        # Model.fit, so there is no training context to inherit.
        return transform(image, training=True), label

    def map_test(image, label):
        transform = keras.Sequential([
            keras.layers.Resizing(height=resizing, width=resizing),
            keras.layers.Rescaling(1/255),
            keras.layers.Normalization(mean=mean, variance=variance)
        ])
        return transform(image), label

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

    # tf.data unpacks each (image, label) tuple into the map function's arguments.
    train_ds = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .map(map_train_before_cache, num_parallel_calls=num_parallel_calls)
        .cache()
        .shuffle(buffer_size=len(x_train))
        .map(map_train_after_cache, num_parallel_calls=num_parallel_calls)
        .batch(batch_size=batch_size, drop_remainder=drop_remainder)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )
    test_ds = (
        tf.data.Dataset.from_tensor_slices((x_test, y_test))
        .map(map_test, num_parallel_calls=num_parallel_calls)
        .cache()
        .batch(batch_size=validation_batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )
    return {'train': train_ds, 'test': test_ds}

In [None]:
def build_cifar_resnet(
    resizing: int = 32,
    n: int = 3,
    rate: float = 0.2,
    classes: int = 100
) -> keras.Model:
    '''Build a CIFAR-style ResNet with 6n + 2 weighted layers.

    Layout: 3x3 stem conv (16 filters) -> three stages of `n` basic blocks with
    16, 32 and 64 filters -> global average pooling -> dropout -> dense -> softmax.
    The first block of stages two and three halves the spatial resolution and
    uses a strided 1x1 conv + BN projection shortcut. Regularisation is expected
    to come from the optimizer's weight decay; no kernel regularizer is attached.

    Args:
        resizing: height/width of the square RGB input images.
        n: number of basic blocks per stage.
        rate: dropout rate applied before the classifier.
        classes: number of output classes.

    Returns:
        An uncompiled keras.Model that outputs float32 class probabilities.
    '''

    def basic_block(x: keras.Input, filters: int, conv_shortcut: bool = False):
        '''conv3x3-BN-ReLU-conv3x3-BN, add the shortcut, then ReLU.

        With `conv_shortcut`, the first conv has stride 2 and the shortcut is a
        stride-2 1x1 conv + BN so both branches have matching shapes.
        '''
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(
                filters, 1, strides=2, use_bias=False,
                kernel_initializer='he_normal',
            )(x)
            shortcut = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(shortcut)
            x = keras.layers.Conv2D(
                filters, 3, strides=2, padding='same', use_bias=False,
                kernel_initializer='he_normal',
            )(x)
        else:
            shortcut = x
            x = keras.layers.Conv2D(
                filters, 3, padding='same', use_bias=False,
                kernel_initializer='he_normal',
            )(x)
        x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2D(
            filters, 3, padding='same', use_bias=False,
            kernel_initializer='he_normal',
        )(x)
        x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.Activation('relu')(x)
        return x

    def basic_stack(x: keras.Input, filters: int, conv_shortcut: bool = False):
        '''Stack `n` basic blocks; with `conv_shortcut`, the first block doubles
        `filters` and downsamples. Returns (output tensor, filter count).
        '''
        for i in range(n):
            if i == 0 and conv_shortcut:
                filters *= 2
                x = basic_block(x, filters, conv_shortcut)
            else:
                x = basic_block(x, filters)
        return x, filters

    inputs = keras.Input(shape=(resizing, resizing, 3))
    filters = 16
    x = keras.layers.Conv2D(
        filters, 3, padding='same', use_bias=False,
        kernel_initializer='he_normal',
    )(inputs)
    x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
    x = keras.layers.Activation('relu')(x)
    x, filters = basic_stack(x, filters)
    x, filters = basic_stack(x, filters, True)
    x, filters = basic_stack(x, filters, True)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(rate)(x)
    x = keras.layers.Dense(classes)(x)
    # Keep the softmax in float32: under a 'mixed_float16' policy a float16 softmax
    # is numerically unstable. With a float32 policy this is equivalent to
    # Dense(classes, activation='softmax').
    outputs = keras.layers.Activation('softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
def build_old_cifar_resnet(
    resizing: int = 32,
    filters: int = 64,
    repeat: list = [2, 2, 2, 2], # [2, 2, 2, 2] for 18-layer, [3, 4, 6, 3] for 34-layer
    rate: float = 0.2,
    classes: int = 100
) -> keras.Model:
    '''Build a ResNet with configurable stage depths for small (CIFAR-sized) inputs.

    Layout: 3x3 stem conv with `filters` channels -> one stage per entry of
    `repeat` -> global average pooling -> dropout -> dense -> softmax. The first
    stage keeps the resolution and width; every later stage doubles the filters
    and halves the resolution in its first block (strided 1x1 conv + BN shortcut).

    Note: earlier versions overwrote `filters` with 16, so the argument had no
    effect. Pass filters=16 to reproduce models built that way.

    Args:
        resizing: height/width of the square RGB input images.
        filters: channels of the stem and first stage.
        repeat: number of basic blocks in each stage (read only, never mutated).
        rate: dropout rate applied before the classifier.
        classes: number of output classes.

    Returns:
        An uncompiled keras.Model that outputs float32 class probabilities.
    '''

    def basic_block(x: keras.Input, filters: int, conv_shortcut: bool = False):
        '''conv3x3-BN-ReLU-conv3x3-BN, add the shortcut, then ReLU.

        With `conv_shortcut`, the first conv has stride 2 and the shortcut is a
        stride-2 1x1 conv + BN so both branches have matching shapes.
        '''
        if conv_shortcut:
            shortcut = keras.layers.Conv2D(
                filters, 1, strides=2, use_bias=False,
                kernel_initializer='he_normal',
            )(x)
            shortcut = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(shortcut)
            x = keras.layers.Conv2D(
                filters, 3, strides=2, padding='same', use_bias=False,
                kernel_initializer='he_normal',
            )(x)
        else:
            shortcut = x
            x = keras.layers.Conv2D(
                filters, 3, padding='same', use_bias=False,
                kernel_initializer='he_normal',
            )(x)
        x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2D(
            filters, 3, padding='same', use_bias=False,
            kernel_initializer='he_normal',
        )(x)
        x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.Activation('relu')(x)
        return x

    def basic_stack(x: keras.Input, filters: int, n: int, conv_shortcut: bool = False):
        '''Stack `n` basic blocks; with `conv_shortcut`, the first block doubles
        `filters` and downsamples. Returns (output tensor, filter count).
        '''
        for i in range(n):
            if i == 0 and conv_shortcut:
                filters *= 2
                x = basic_block(x, filters, conv_shortcut)
            else:
                x = basic_block(x, filters)
        return x, filters

    inputs = keras.Input(shape=(resizing, resizing, 3))
    # The stem uses the caller's `filters` (it used to be hard-coded to 16).
    x = keras.layers.Conv2D(
        filters, 3, padding='same', use_bias=False,
        kernel_initializer='he_normal',
    )(inputs)
    x = keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)
    x = keras.layers.Activation('relu')(x)
    for stage_index, depth in enumerate(repeat):
        # Only stages after the first downsample and widen.
        x, filters = basic_stack(x, filters, depth, stage_index > 0)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(rate)(x)
    x = keras.layers.Dense(classes)(x)
    # Keep the softmax in float32: under a 'mixed_float16' policy a float16 softmax
    # is numerically unstable. With a float32 policy this is equivalent to
    # Dense(classes, activation='softmax').
    outputs = keras.layers.Activation('softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Collect the pipeline settings from the config cell, then build both datasets.
loader_config = dict(
    batch_size=batch_size,
    validation_batch_size=validation_batch_size,
    resizing=resizing,
    augmentations_per_image=augmentations_per_image,
    magnitude=magnitude,
    drop_remainder=drop_remainder,
    num_parallel_calls=num_parallel_calls,
    rand_augment=rand_augment,
)
dataloader = load_cifar100(**loader_config)

In [None]:
# Two builders are defined above: build_cifar_resnet (three stages of `n` blocks)
# and build_old_cifar_resnet (stage depths given by `repeat`). The latter is used;
# alternative: build_cifar_resnet(resizing=resizing, n=n, rate=rate, classes=classes)
model = build_old_cifar_resnet(
    resizing=resizing,
    rate=rate,
    classes=classes,
)
# Inspect the architecture with: model.summary()

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    '''Step-decay rule for keras.callbacks.LearningRateScheduler.

    Returns `lr * gamma` when `epoch` is one of `milestones`, otherwise `lr`
    unchanged. Because the scheduler passes in the optimizer's current rate,
    the decays compound across milestones.
    '''
    return lr * gamma if epoch in milestones else lr

class TimeCallback(keras.callbacks.Callback):
    '''Keras callback that records how long each training epoch takes.

    After `fit`, `history` holds one wall-clock duration in seconds (measured
    with `perf_counter`) per completed epoch. It is reset at the start of every
    `fit` call.
    '''

    def on_train_begin(self, logs=None):
        # Fresh list for each fit() call.
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        self.time_epoch_begin = perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        epoch_seconds = perf_counter() - self.time_epoch_begin
        self.history.append(epoch_seconds)

def _step_decay(epoch, lr):
    # Reads `milestones` and `gamma` from the config cell at call time.
    return lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)


lr_scheduler_callback = keras.callbacks.LearningRateScheduler(_step_decay)
time_callback = TimeCallback()

In [None]:
# SGD with momentum and weight decay; sparse categorical cross-entropy on the
# integer CIFAR labels; optional XLA compilation of the train/eval steps.
# NOTE(review): the Keras (experimental) optimizers appear to apply `weight_decay`
# directly to the variables (decoupled decay, including BatchNorm parameters)
# rather than as an L2 term in the loss — confirm this is the intended recipe.
sgd_optimizer = keras.optimizers.experimental.SGD(
    learning_rate=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
model.compile(
    optimizer=sgd_optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
    jit_compile=JIT_COMPILE_FLAG,
)

In [None]:
# Optional training-time warm-up: run `fit` for a single batch so that one-off
# costs (presumably function tracing / XLA compilation) do not land in the
# per-epoch times recorded by `time_callback` during the real run.
# NOTE: this really trains on that batch, so model and optimizer state change.
RUN_WARMUP = False

if RUN_WARMUP:
    class StopCallBack(keras.callbacks.Callback):
        '''Ask Keras to stop training as soon as the first batch has finished.'''
        def on_train_batch_end(self, batch, logs=None):
            self.model.stop_training = True

    stop_call = StopCallBack()
    model.fit(dataloader['train'], verbose=2, callbacks=[stop_call])

In [None]:
training_callbacks = [lr_scheduler_callback, time_callback]
logs = model.fit(
    dataloader['train'],
    validation_data=dataloader['test'],
    epochs=epochs,
    callbacks=training_callbacks,
    verbose=2,
)
# Attach the per-epoch wall-clock times so every result lives in one History object.
logs.history['t'] = time_callback.history

In [None]:
index = -1  # report the final epoch
run_history = logs.history
# LearningRateScheduler logs the rate as 'lr' in tf.keras, while newer Keras
# versions (e.g. keras_core, hinted at in the import cell) log 'learning_rate'.
lr_key = 'lr' if 'lr' in run_history else 'learning_rate'
metric_rows = [
    ('time', run_history['t']),
    ('learning_rate', run_history[lr_key]),
    ('loss', run_history['loss']),
    ('acc', run_history['accuracy']),
    ('val_loss', run_history['val_loss']),
    ('val_acc', run_history['val_accuracy']),
]

print('----')
print('PRINT RESULTS')
print(f'batch_size: {batch_size}')
print(f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}')
print(f'JIT_COMPILE: {JIT_COMPILE_FLAG}')
for label, values in metric_rows:
    print(f'{label}: {values[index]}')
print('----')

In [None]:
# Recorded output of the "PRINT RESULTS" cell from earlier runs, one section per
# batch size (2048, 512, 128, 32). `time` is the wall-clock duration of the last
# epoch; the other values are last-epoch metrics. The final learning rate of ~1e-3
# is 0.1 decayed twice by gamma = 0.1 (float32 rounding). Kept inside a string
# literal, so executing this cell has no effect.
# NOTE(review): these batch sizes differ from the current config (batch_size = 500),
# and other settings used for these runs are not recorded here — treat them as
# historical reference only.
'''
----
PRINT RESULTS
batch_size: 2048
MIXED_PRECISION: True
JIT_COMPILE: True
time: 7.176191415870562
learning_rate: 0.0009999999310821295
loss: 0.9402134418487549
acc: 0.7233999967575073
val_loss: 2.5140624046325684
val_acc: 0.44749999046325684
----
----
PRINT RESULTS
batch_size: 512
MIXED_PRECISION: True
JIT_COMPILE: True
time: 7.181444549933076
learning_rate: 0.0009999999310821295
loss: 0.45324844121932983
acc: 0.8585600256919861
val_loss: 2.8365235328674316
val_acc: 0.4674000144004822
----
----
PRINT RESULTS
batch_size: 128
MIXED_PRECISION: True
JIT_COMPILE: True
time: 8.163974148919806
learning_rate: 0.0009999999310821295
loss: 0.4783342182636261
acc: 0.8499000072479248
val_loss: 2.6634764671325684
val_acc: 0.5026000142097473
----
----
PRINT RESULTS
batch_size: 32
MIXED_PRECISION: True
JIT_COMPILE: True
time: 18.20342784305103
learning_rate: 0.0009999999310821295
loss: 0.9204937219619751
acc: 0.7246999740600586
val_loss: 2.087597608566284
val_acc: 0.5317999720573425
----
'''