In [None]:
import argparse
import os
import time

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision

In [None]:
def str2bool(value: str) -> bool:
    """Parse a command-line boolean such as 'true'/'false', '1'/'0', 'yes'/'no'.

    argparse with ``type=bool`` is a trap: ``bool('False')`` is ``True`` because
    any non-empty string is truthy, so ``-m False`` used to *enable* the flag.
    This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in {'1', 'true', 't', 'yes', 'y', 'on'}:
        return True
    if normalized in {'0', 'false', 'f', 'no', 'n', 'off'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(
        description='Too Simple! Sometimes Naive!'
)
# -m/-j accept an explicit value (`-m True`, `-m false`, `-m 0`, ...) or can be
# given bare (`-m`), which means True.
parser.add_argument(
    '-m', '--mixed',
    type=str2bool,
    nargs='?',
    const=True,
    default=False,
    help='MIXED_PRECISION'
)
parser.add_argument(
    '-j', '--jit',
    type=str2bool,
    nargs='?',
    const=True,
    default=False,
    help='JIT_COMPILE'
)
parser.add_argument(
    '-b', '--batch',
    type=int,
    default=500,
    help='batch_size'
)
# In a .py script use parser.parse_args() so the real sys.argv is parsed.
# Inside the notebook, parse an empty argument list so the kernel's own argv is
# ignored and the defaults apply.
args = parser.parse_args([])

In [None]:
# ---- Optimization setting ----
# Index into tf.config.list_physical_devices('GPU') of the GPU to use (-1 = last).
device_index = -1
# Speed-up switches taken from the command line.
MIXED_PRECISION_FLAG, JIT_COMPILE_FLAG = args.mixed, args.jit

# ---- Dataloader setting ----
batch_size = args.batch

# ---- Training setting (SGD with momentum) ----
learning_rate, momentum, epochs = 1e-2, 0.9, 1

In [None]:
# Pin TensorFlow to one GPU; `device_index` indexes the list of physical GPUs,
# so -1 picks the last one. On a machine without a GPU the list is empty: skip
# the call instead of failing with IndexError so the notebook still runs on CPU.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.set_visible_devices(physical_devices[device_index], 'GPU')
    # Optional: allocate GPU memory on demand instead of reserving it up front.
    #tf.config.experimental.set_memory_growth(physical_devices[device_index], True)

In [None]:
if MIXED_PRECISION_FLAG:
    # Must run before the model is built: layers pick up the global policy at
    # creation time.
    policy = mixed_precision.Policy(name='mixed_float16')
    mixed_precision.set_global_policy(policy)

In [None]:
def load_cifar100(label_mode: str = 'fine'):
    """Load CIFAR-100 and return preprocessed images with one-hot labels.

    Images are rescaled from [0, 255] to [0, 1] and then standardized per
    channel with the fixed mean/std constants below (presumably CIFAR-100
    training-set RGB statistics -- TODO confirm their source). Labels are
    one-hot encoded.

    Args:
        label_mode: 'fine' (100 classes; the default and the original
            behavior) or 'coarse' (20 super-classes). Forwarded to
            keras.datasets.cifar100.load_data.

    Returns:
        ((x_train, y_train), (x_test, y_test)): float32 image tensors and
        one-hot label arrays with one column per class.

    Raises:
        ValueError: if label_mode is neither 'fine' nor 'coarse'.
    """
    num_classes_by_mode = {'fine': 100, 'coarse': 20}
    if label_mode not in num_classes_by_mode:
        raise ValueError(
            f"label_mode must be 'fine' or 'coarse', got {label_mode!r}")
    num_classes = num_classes_by_mode[label_mode]

    mean = tf.constant([129.3, 124.1, 112.4]) / 255
    std = tf.constant([68.2, 65.4, 70.4]) / 255

    def data_preprocessing(x, y):
        # dtype='float32' pins these layers so a global 'mixed_float16' policy
        # (set in an earlier cell) does not turn the preprocessed data itself
        # into float16: the data is identical with and without mixed precision.
        pre = keras.Sequential([
            layers.Rescaling(1 / 255, dtype='float32'),
            layers.Normalization(mean=mean, variance=tf.math.square(std),
                                 dtype='float32'),
        ])
        # Explicit num_classes: otherwise to_categorical infers the width from
        # max(y) + 1, which silently shrinks if the top class is absent.
        return pre(x), keras.utils.to_categorical(y, num_classes=num_classes)

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data(
        label_mode=label_mode)
    x_train, y_train = data_preprocessing(x_train, y_train)
    x_test, y_test = data_preprocessing(x_test, y_test)
    return (x_train, y_train), (x_test, y_test)

In [None]:
def make_resnet18(inputs=None,
                  num_classes: int = 100
                 ) -> keras.Model:
    """Build a ResNet-18 classifier with a 3x3 stride-1 stem and no max-pool.

    Layout: a 64-filter 3x3 conv stem (+BN, ReLU), then four stages of two
    basic residual blocks with 64, 128, 256 and 512 filters. The first block of
    stages 2-4 uses stride 2, halving the spatial size. Global average pooling
    and a dense softmax head produce class probabilities.

    Args:
        inputs: Keras input tensor to build the model on. Defaults to a fresh
            keras.Input(shape=(32, 32, 3)) created on every call.
        num_classes: Number of output classes.

    Returns:
        An uncompiled keras.Model mapping `inputs` to float32 softmax
        probabilities with `num_classes` outputs.
    """
    if inputs is None:
        # Created per call: the former default `keras.Input(...)` was evaluated
        # once at definition time and shared by every model built without
        # an explicit `inputs`.
        inputs = keras.Input(shape=(32, 32, 3))

    def conv(filters: int, kernel_size: int, strides: int = 1,
             padding: str = 'same'):
        """Conv2D with the He-normal initializer used throughout the network."""
        return layers.Conv2D(filters, kernel_size, strides=strides,
                             padding=padding, kernel_initializer='he_normal')

    def basicblock(block_input, filters: int, downsample: bool):
        """Two 3x3 conv+BN layers plus a residual connection, then ReLU.

        With `downsample`, the first conv uses stride 2 and the shortcut is a
        strided 1x1 conv + BN so its shape matches the main path. Otherwise the
        input is added back unchanged.
        """
        if downsample:
            identity = conv(filters, 1, strides=2, padding='valid')(block_input)
            identity = layers.BatchNormalization()(identity)
            x = conv(filters, 3, strides=2)(block_input)
        else:
            identity = block_input
            x = conv(filters, 3)(block_input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = conv(filters, 3)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, identity])
        return layers.Activation('relu')(x)

    x = conv(64, 3)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # (filters, downsample) for each of the 8 basic blocks: 4 stages x 2 blocks.
    for filters, downsample in [(64, False), (64, False),
                                (128, True), (128, False),
                                (256, True), (256, False),
                                (512, True), (512, False)]:
        x = basicblock(x, filters, downsample)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes)(x)
    # Softmax kept in float32: under 'mixed_float16' a float16 softmax output is
    # numerically fragile (TF mixed-precision guide). Identical under float32.
    outputs = layers.Activation('softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Load and preprocess CIFAR-100 once: the train split feeds the augmentation
# generator and the test split is used as validation data during fit.
cifar_train, cifar_test = load_cifar100()
x_train, y_train = cifar_train
x_test, y_test = cifar_test

In [None]:
# Training-time augmentation: random horizontal flips plus random shifts of up
# to 10% of the image width/height.
# NOTE(review): ImageDataGenerator is deprecated in recent TensorFlow releases.
datagen = keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1,
)

In [None]:
# Build the ResNet-18 with one output per CIFAR-100 class.
model = make_resnet18(num_classes=100)

In [None]:
# SGD with momentum, built once. Under mixed precision it is wrapped in a
# LossScaleOptimizer, which scales the loss so small float16 gradients do not
# underflow to zero.
optimizer = keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)
if MIXED_PRECISION_FLAG:
    optimizer = keras.mixed_precision.LossScaleOptimizer(optimizer)

model.compile(
    optimizer=optimizer,
    loss=keras.losses.CategoricalCrossentropy(),  # labels are one-hot
    metrics=['accuracy'],
    jit_compile=JIT_COMPILE_FLAG,
)

In [None]:
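# Optional warm-up (disabled): a short fit on a single batch, presumably so
# one-time graph tracing / XLA compilation is not counted in the timed run below.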
#model.fit(
#    datagen.flow(x_train[:batch_size], y_train[:batch_size], batch_size=batch_size),
#    epochs=epochs,
#    workers=tf.data.AUTOTUNE
#)

In [None]:
# Time the full training run (all epochs, including each epoch's validation).
# `workers` is the number of workers Keras uses to pull augmented batches from
# the generator. `tf.data.AUTOTUNE` (-1) is a tf.data sentinel, not a
# meaningful worker count here; with it, loading effectively ran in the calling
# thread (confirm for your Keras version). Use one worker per CPU instead.
num_workers = os.cpu_count() or 1
t = time.monotonic()
logs = model.fit(
    datagen.flow(x_train, y_train, batch_size=batch_size),
    epochs=epochs,
    validation_data=(x_test, y_test),
    validation_batch_size=batch_size,
    workers=num_workers,
)
t = time.monotonic() - t  # elapsed wall-clock seconds

In [None]:
# Summary of the run configuration, training wall time and Keras history,
# framed by '----' separator lines.
print('\n'.join([
    '----',
    f'BATCH_SIZE: {batch_size}',
    f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}',
    f'JIT_COMPILE: {JIT_COMPILE_FLAG}',
    f'TIME: {t}',
    f'LOGS: {logs.history}',
    '----',
]))