In [None]:
import os
import time
import tensorflow as tf
from tensorflow import keras

In [None]:
# for device
## Index into tf.config.list_physical_devices('GPU'); -1 selects the last GPU.
device_index = -1
MIXED_PRECISION_FLAG = True
## jit_compile is not passed to keras.Model.compile() (treated as unsupported on TF 2.6).
## Instead, when this flag is True, a later cell enables XLA auto-clustering through the
## 'TF_XLA_FLAGS' environment variable ('--tf_xla_auto_jit=2', plus
## '--tf_xla_cpu_global_jit' when device_index == 0).
JIT_COMPILE_FLAG = True

# for dataset
dataset = 'cifar100'
dir_path = '/ssd'  # NOTE(review): not read by any cell shown here
resolution = 32
batch_size = 100

# for model
depth = 18
dropout_rate = 0.2  # [0.1, 0.2]

# for training
learning_rate = 1e-1
momentum = 0.9
epochs = 6  # 90
## keras.optimizers.SGD() in TF 2.6 has no weight_decay argument,
## so weight decay is emulated with a kernel regularizer on the model's layers.
weight_decay = 1e-4

# for learning rate scheduler
milestones = [2, 4]  # [30, 60]
gamma = 0.1

######## for testing: batch size and learning rate are proportional
#learning_rate *= batch_size / 100
######## for testing: kernel_regularizer for TF 2.6
KR_FLAG = True
## keras.regularizers.L2(l2) adds l2 * sum(w**2) to the loss, whose gradient is
## 2 * l2 * w. SGD weight decay adds weight_decay * w to the gradient, so the
## equivalent regularizer coefficient is weight_decay / 2 (not weight_decay).
KR_VALUE = keras.regularizers.L2(weight_decay / 2) if KR_FLAG else None

In [None]:
######## for testing: XLA via environment variable. TF_XLA_FLAGS has to be set before any
## TensorFlow op runs, so this cell only touches os.environ and Keras policy objects.
## '--tf_xla_cpu_global_jit' has only been observed to work with the first GPU
## (device_index == 0); any other device errors out, so it gets auto-jit only.
## The GPU count cannot be checked here (e.g. len(tf.config.list_physical_devices('GPU')) == 1)
## because doing so would already call into TensorFlow.
if JIT_COMPILE_FLAG:
    os.environ['TF_XLA_FLAGS'] = (
        '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
        if device_index == 0
        else '--tf_xla_auto_jit=2'
    )

if MIXED_PRECISION_FLAG:
    # Layers created after this call (without their own dtype) use this global policy.
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
        sep='\n',
    )

# Summary banner of the two feature flags.
print(
    '----',
    f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}',
    f'JIT_COMPILE: {JIT_COMPILE_FLAG}',
    '----',
    sep='\n',
)

In [None]:
# Display the XLA flags in effect (None when the variable is unset).
os.environ.get('TF_XLA_FLAGS')

In [None]:
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
## Fail early with a clear message instead of a bare IndexError on the lookup below.
if not physical_devices:
    raise RuntimeError('No GPU found: this notebook pins training to a single GPU.')
if not -len(physical_devices) <= device_index < len(physical_devices):
    raise ValueError(
        f'Invalid device_index "{device_index}", it should be in '
        f'[{-len(physical_devices)}, {len(physical_devices) - 1}].'
    )
selected_device = physical_devices[device_index]
## Restrict TensorFlow to the chosen GPU and let its memory grow on demand.
## Both calls must run before the GPU runtime is initialized.
tf.config.set_visible_devices(selected_device, 'GPU')
tf.config.experimental.set_memory_growth(selected_device, True)
print(f'Using Device: {selected_device}')

In [None]:
def load_cifar(resolution: int, batch_size: int, dataset: str) -> dict:
    '''Load CIFAR-10/100 as batched, normalized tf.data pipelines.

    Each image is resized to `resolution` x `resolution`, rescaled by 1/255 and
    normalized with the ImageNet per-channel mean/std. No random augmentation is
    applied here; build_resnet puts its augmentation layers inside the model.

    Args:
        resolution: output height/width, one of [16, 32].
        batch_size: number of examples per batch.
        dataset: 'cifar10' or 'cifar100'.

    Returns:
        dict with 'train' (map -> cache -> shuffle over the whole set -> batch ->
        prefetch) and 'val' (map -> batch -> cache -> prefetch) datasets yielding
        (image, label) pairs.

    Raises:
        ValueError: if `resolution` or `dataset` is not supported.
    '''
    resolution_list = [16, 32]
    dataset_list = ['cifar10', 'cifar100']

    if resolution not in resolution_list:
        raise ValueError(f'Invalid resolution "{resolution}", it should be in {resolution_list}.')
    if dataset not in dataset_list:
        raise ValueError(f'Invalid dataset "{dataset}", it should be in {dataset_list}.')

    # ImageNet channel statistics. The red-channel std is 0.229
    # (0.299 was a digit transposition).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # keras.layers.Normalization takes a variance, so derive it from std instead of
    # hard-coding a second list that can drift out of sync.
    var = [s ** 2 for s in std]

    def preprocessing_map(image):
        transform = keras.Sequential([
            keras.layers.Resizing(resolution, resolution),
            keras.layers.Rescaling(1/255),
            keras.layers.Normalization(mean=mean, variance=var)
        ])
        return transform(image)

    def preprocess_pair(image, label):
        # Only the image is transformed; the label passes through unchanged.
        return preprocessing_map(image), label

    if dataset == 'cifar10':
        (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    else:
        (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

    dataloader = {
        'train': (
            tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .map(preprocess_pair, num_parallel_calls=tf.data.AUTOTUNE)
            .cache()
            .shuffle(buffer_size=len(y_train))
            .batch(batch_size=batch_size)
            .prefetch(buffer_size=tf.data.AUTOTUNE)
        ),
        'val': (
            tf.data.Dataset.from_tensor_slices((x_test, y_test))
            .map(preprocess_pair, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size=batch_size)
            .cache()
            .prefetch(buffer_size=tf.data.AUTOTUNE)
        )
    }

    return dataloader

In [None]:
def build_resnet(
    dataset: str,
    depth: int,
    resolution: int,
    dropout_rate: float = 0.2,
) -> keras.Model:
    '''Build a ResNet-18/34 (basic blocks) classifier with built-in augmentation.

    Args:
        dataset: 'cifar10', 'cifar100' or 'imagenet'. Selects the class count
            (10 / 100 / 1000) and the stem: a 3x3 conv for CIFAR, or a 7x7
            stride-2 conv followed by 3x3 stride-2 max-pooling for ImageNet.
        depth: 18 or 34. Selects blocks per stage ([2, 2, 2, 2] or [3, 4, 6, 3]).
        resolution: height and width of the square 3-channel input.
        dropout_rate: dropout rate applied before the final dense layer.

    Returns:
        A keras.Model mapping (resolution, resolution, 3) inputs to softmax class
        probabilities. The probabilities are always float32, even under a
        mixed-precision policy. Conv and dense kernels use the global KR_VALUE
        regularizer.

    Raises:
        ValueError: if `dataset` or `depth` is not supported.
    '''
    dataset_list = ['cifar10', 'cifar100', 'imagenet']
    depth_list = [18, 34]

    if dataset not in dataset_list:
        raise ValueError(f'Invalid dataset "{dataset}", it should be in {dataset_list}.')
    if depth not in depth_list:
        raise ValueError(f'Invalid depth "{depth}", it should be in {depth_list}.')

    classes = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}[dataset]
    stack_list = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}[depth]

    def conv(x, filters: int, kernel_size: int, strides: int = 1, padding: str = 'same'):
        # Bias-free convolution; the BatchNormalization that follows supplies the shift.
        return keras.layers.Conv2D(
            filters, kernel_size, strides=strides, padding=padding, use_bias=False,
            kernel_initializer='he_normal',
            kernel_regularizer=KR_VALUE
        )(x)

    def bn(x):
        # Shared BatchNormalization settings used throughout the network.
        return keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)

    def basic_block(x, filters: int, conv_shortcut: bool = False):
        # Two 3x3 convs with a residual connection. With conv_shortcut the block
        # downsamples (stride 2) and projects the shortcut with a 1x1 conv + BN.
        if conv_shortcut:
            shortcut = bn(conv(x, filters, 1, strides=2, padding='valid'))
            x = conv(x, filters, 3, strides=2)
        else:
            shortcut = x
            x = conv(x, filters, 3)
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
        x = conv(x, filters, 3)
        x = bn(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.Activation('relu')(x)
        return x

    def basic_stack(x, filters: int, stack: int, conv_shortcut: bool = False):
        # `stack` basic blocks. In a downsampling stage the first block doubles the
        # width; the (possibly doubled) width is returned for the next stage.
        for i in range(stack):
            if i == 0 and conv_shortcut:
                filters *= 2
                x = basic_block(x, filters, conv_shortcut)
            else:
                x = basic_block(x, filters)
        return x, filters

    inputs = keras.Input(shape=(resolution, resolution, 3))
    filters = 64
    ## simple augmentation pipeline, kept in the model so the tf.data pipeline stays
    ## deterministic (Keras Random* layers only act in training mode)
    simple_aug = keras.Sequential([
        keras.layers.RandomFlip('horizontal'),
        keras.layers.RandomRotation(factor=0.02),
        keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2)
    ])
    x = simple_aug(inputs)
    ## stem
    if 'cifar' in dataset:
        x = conv(x, filters, 3)
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
    elif dataset == 'imagenet':
        x = conv(x, filters, 7, strides=2)
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
    ## trunk: the first stage keeps the width, every later stage downsamples
    for i, stack in enumerate(stack_list):
        x, filters = basic_stack(x, filters, stack, conv_shortcut=(i > 0))
    ## classifier
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.Dense(classes, kernel_regularizer=KR_VALUE)(x)
    # Under the 'mixed_float16' policy the Dense output is float16. Compute the softmax
    # in float32 for numerical stability; under a float32 policy this is unchanged.
    outputs = keras.layers.Activation('softmax', dtype='float32')(x)

    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
## Build the 'train' / 'val' tf.data pipelines from the config values.
dataloader = load_cifar(dataset=dataset, resolution=resolution, batch_size=batch_size)

In [None]:
## Build the network (augmentation + ResNet trunk + classifier) for the configured dataset.
model = build_resnet(
    dataset=dataset,
    depth=depth,
    resolution=resolution,
    dropout_rate=dropout_rate,
)
## model.summary() prints the architecture itself and returns None; uncomment to inspect.
# model.summary()

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    '''Step decay: scale `lr` by `gamma` when `epoch` is one of `milestones`.

    Intended for keras.callbacks.LearningRateScheduler, which passes the current
    learning rate each epoch. The decay therefore compounds across milestones.
    On any other epoch `lr` is returned unchanged.
    '''
    return lr * gamma if epoch in milestones else lr

class TimeCallback(keras.callbacks.Callback):
    '''Keras callback recording the wall-clock duration of every epoch.

    `history` holds one float per epoch, in seconds. Each value is measured with
    time.perf_counter() from on_epoch_begin to on_epoch_end. The list is reset
    at the start of every fit() call.
    '''

    def on_train_begin(self, logs=None):
        # Start each training run with an empty record.
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when this epoch started.
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

## Per-epoch wall-clock timer; read time_callback.history after fit().
time_callback = TimeCallback()

## Step-decay schedule. milestones and gamma are looked up from the config cell
## each time the scheduler calls the function.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)

In [None]:
## SGD with momentum. Weight decay is not passed to the optimizer; when KR_FLAG is set
## it comes from the layers' kernel regularizers (KR_VALUE). XLA is configured through
## TF_XLA_FLAGS instead of compile()'s jit_compile argument.
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
## NOTE: re-running this cell keeps training the same model. The optimizer's current
## (already decayed) learning rate is also decayed again at each milestone.
logs = model.fit(
    dataloader['train'],
    validation_data=dataloader['val'],
    epochs=epochs,
    callbacks=[time_callback, lr_scheduler_callback],
    verbose='auto',
)
## Store per-epoch wall-clock seconds alongside the Keras metrics.
logs.history['t'] = time_callback.history