In [None]:
import time
import tensorflow as tf
from tensorflow import keras

In [None]:
# for device
device_index = 0  # index into tf.config.list_physical_devices('GPU')
MIXED_PRECISION_FLAG = True
## TensorFlow 2.6 does not support jit_compile in model.compile() yet,
## so keep this False (it is currently not passed to model.compile()).
JIT_COMPILE_FLAG = False

# for dataset
dataset = 'imagenet'
dir_path = f'/ssd/{dataset}'  # expects `train/` and `val/` sub-directories
## (resolution, batch_size) pairs: [(160, 510), (224, 390), (288, 180)]
resolution = 224
batch_size = 1500

# for model
depth = 18
# dropout rate from 0.1 to 0.3
dropout_rate = 0.2

# for training
learning_rate = 1e-1  # base learning rate for `base_batch_size` below
momentum = 0.9
epochs = 120
## TensorFlow 2.6 does not support weight_decay in keras.optimizers.SGD() yet,
## so it has to be applied inside the model (e.g. as an L2 kernel regularizer).
## NOTE(review): this value is not currently passed to the model build.
weight_decay = 1e-4

# for learning rate scheduler
milestones = [30, 60, 90]  # epochs at which lr_schedule multiplies the LR by gamma
gamma = 0.1

######## for testing: scale the learning rate linearly with the batch size.
# 390 is the reference batch size at resolution 224 (see the pairs above).
base_batch_size = 390
learning_rate *= batch_size / base_batch_size


In [None]:
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
# Fail early with a readable message (instead of a bare IndexError) when no GPU
# is visible or `device_index` does not address one of the listed GPUs.
try:
    selected_device = physical_devices[device_index]
except IndexError as err:
    raise RuntimeError(
        f'device_index={device_index} is out of range: '
        f'{len(physical_devices)} GPU(s) found.'
    ) from err
# TensorFlow requires visibility and memory growth to be configured before the
# GPUs are initialized, so this cell must run before any model/dataset is built.
tf.config.set_visible_devices(selected_device, 'GPU')
tf.config.experimental.set_memory_growth(selected_device, True)
print(f'Using device: {selected_device}')

In [None]:
if MIXED_PRECISION_FLAG:
    # Install 'mixed_float16' as the global Keras policy; layers created after
    # this call use it by default.
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print('\n'.join([
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
    ]))

In [None]:
def load_imagenet(resolution: int, batch_size: int, data_dir=None):
    """Build the ImageNet train/val input pipelines.

    Images are read from ``<data_dir>/train`` and ``<data_dir>/val`` with
    ``tf.keras.utils.image_dataset_from_directory``, resized to
    ``resolution x resolution``, batched, multiplied by 1/255 and normalized
    with the ImageNet per-channel mean/std. Random augmentation is not done
    here; it lives in the model instead.

    Args:
        resolution: Square image side length; one of 160, 224, 288.
        batch_size: Number of images per batch, for both splits.
        data_dir: Dataset root directory. When ``None`` (default) the
            notebook-level ``dir_path`` is used.

    Returns:
        Dict with keys ``'train'`` and ``'val'``, each a prefetched
        ``tf.data.Dataset`` yielding ``(normalized_images, int_labels)``.

    Raises:
        ValueError: If ``resolution`` is not one of the supported values.
    """
    resolution_list = [160, 224, 288]
    if resolution not in resolution_list:
        raise ValueError(f'Invalid resolution "{resolution}", it should be in {resolution_list}.')
    root = dir_path if data_dir is None else data_dir

    # ImageNet per-channel mean/std (the widely used torchvision values).
    # The red std used to be 0.299, a typo of 0.229.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # keras Normalization takes a variance, so derive it instead of hard-coding.
    var = [s ** 2 for s in std]

    def _make_split(split: str, shuffle: bool):
        """Load one split directory and apply rescale + normalization."""
        data = tf.keras.utils.image_dataset_from_directory(
            directory=f'{root}/{split}',
            label_mode='int',  # integer labels for keras.losses.SparseCategoricalCrossentropy()
            batch_size=batch_size,
            image_size=(resolution, resolution),
            shuffle=shuffle,
        )
        # Deterministic preprocessing only. Per the original notes, random
        # preprocessing layers could not be used with this pipeline, so they
        # live in the model. One transform per split, created outside the map fn.
        transform = keras.Sequential([
            keras.layers.Rescaling(1 / 255),
            keras.layers.Normalization(mean=mean, variance=var),
        ])
        # No .cache(): caching full ImageNet caused excessive memory usage.
        data = data.map(lambda x, y: (transform(x), y), num_parallel_calls=tf.data.AUTOTUNE)
        return data.prefetch(buffer_size=tf.data.AUTOTUNE)

    return {
        'train': _make_split('train', shuffle=True),
        # Evaluation does not need a shuffled order.
        'val': _make_split('val', shuffle=False),
    }

In [None]:
def build_resnet(
    dataset: str,
    depth: int,
    resolution: int,
    dropout_rate: float = 0.2,
    weight_decay: float = 0.0,
) -> keras.Model:
    """Build a basic-block ResNet-18/34 classifier with the Keras functional API.

    Layout: random horizontal flip + translation augmentation, a
    dataset-specific stem, four stages of basic residual blocks, global average
    pooling, dropout, and a softmax classifier.

    Args:
        dataset: ``'cifar'`` (100 classes, 3x3 stem) or ``'imagenet'``
            (1000 classes, 7x7/2 stem followed by a 3x3/2 max-pool).
        depth: 18 or 34.
        resolution: Side length of the square 3-channel input.
        dropout_rate: Dropout rate applied before the classifier.
        weight_decay: L2 coefficient applied to conv and dense kernels. The
            default ``0`` adds no regularization (same model as before).

    Returns:
        An uncompiled ``keras.Model``.

    Raises:
        ValueError: If ``dataset`` or ``depth`` is not supported.
    """
    classes_by_dataset = {'cifar': 100, 'imagenet': 1000}
    stacks_by_depth = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}

    dataset_list = list(classes_by_dataset)
    depth_list = list(stacks_by_depth)
    if dataset not in dataset_list:
        raise ValueError(f'Invalid dataset "{dataset}", it should be in {dataset_list}.')
    if depth not in depth_list:
        raise ValueError(f'Invalid depth "{depth}", it should be in {depth_list}.')

    classes = classes_by_dataset[dataset]
    stack_list = stacks_by_depth[depth]
    regularizer = keras.regularizers.L2(weight_decay) if weight_decay else None

    def conv(x, filters: int, kernel_size: int, strides: int = 1, padding: str = 'valid'):
        """Bias-free, He-initialized Conv2D (a bias is redundant before BatchNorm)."""
        return keras.layers.Conv2D(
            filters, kernel_size, strides=strides, padding=padding, use_bias=False,
            kernel_initializer='he_normal',
            kernel_regularizer=regularizer,
        )(x)

    def bn(x):
        """BatchNormalization with the momentum/epsilon used throughout this model."""
        return keras.layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5)(x)

    def basic_block(x, filters: int, conv_shortcut: bool = False):
        """Two 3x3 convs with an identity shortcut, or a strided 1x1 projection shortcut."""
        if conv_shortcut:
            # Downsampling block: project the shortcut to the new width and size.
            shortcut = bn(conv(x, filters, 1, strides=2))
            x = conv(x, filters, 3, strides=2, padding='same')
        else:
            shortcut = x
            x = conv(x, filters, 3, padding='same')
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
        x = conv(x, filters, 3, padding='same')
        x = bn(x)
        x = keras.layers.Add()([shortcut, x])
        return keras.layers.Activation('relu')(x)

    def basic_stack(x, filters: int, stack: int, conv_shortcut: bool = False):
        """Apply `stack` basic blocks; with `conv_shortcut`, the first one doubles `filters` and downsamples."""
        for i in range(stack):
            if i == 0 and conv_shortcut:
                filters *= 2
                x = basic_block(x, filters, conv_shortcut=True)
            else:
                x = basic_block(x, filters)
        return x, filters

    inputs = keras.Input(shape=(resolution, resolution, 3))
    filters = 64
    ## random preprocessing layers (kept in the model rather than the input pipeline)
    x = keras.layers.RandomFlip('horizontal')(inputs)
    x = keras.layers.RandomTranslation(0.125, 0.125, fill_mode='constant')(x)
    ## stem
    if dataset == 'cifar':
        x = conv(x, filters, 3, padding='same')
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
    else:  # 'imagenet'
        x = conv(x, filters, 7, strides=2, padding='same')
        x = bn(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
    ## trunk: the first stage keeps its width; later stages double it and downsample
    for i, stack in enumerate(stack_list):
        x, filters = basic_stack(x, filters, stack, conv_shortcut=i > 0)
    ## classifier
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.Dense(classes, kernel_regularizer=regularizer)(x)
    # Keep the softmax in float32: under the 'mixed_float16' policy a float16
    # softmax output is numerically unstable for the cross-entropy loss.
    # This is a no-op when the policy is already float32.
    outputs = keras.layers.Activation('softmax', dtype='float32')(x)

    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the train/val input pipelines at the configured resolution and batch size.
dataloader = load_imagenet(
    batch_size=batch_size,
    resolution=resolution,
)

In [None]:
model = build_resnet(
    dataset=dataset,
    depth=depth,
    resolution=resolution,
    dropout_rate=dropout_rate,
)
# model.summary()  # uncomment to inspect the architecture (it prints itself and returns None)

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Step decay: scale the current learning rate by `gamma` at each milestone.

    Args:
        epoch: Epoch index supplied by the scheduler callback.
        lr: Current learning rate.
        milestones: Collection of epoch indices at which to decay.
        gamma: Multiplicative decay factor.

    Returns:
        ``lr * gamma`` when ``epoch`` is a milestone, otherwise ``lr`` unchanged.
    """
    return lr * gamma if epoch in milestones else lr

class TimeCallback(keras.callbacks.Callback):
    """Record the wall-clock duration (in seconds) of every epoch.

    ``history`` is reset when training begins and receives one entry per
    epoch, in order, measured with ``time.perf_counter``.
    """

    def on_train_begin(self, logs=None):
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

time_callback = TimeCallback()

# Step-decay schedule; the lambda looks up the notebook-level milestones/gamma.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)

In [None]:
# SparseCategoricalCrossentropy matches the integer labels from the input
# pipeline and the model's softmax (probability) output.
# JIT_COMPILE_FLAG is not passed: jit_compile is unsupported in model.compile()
# on TensorFlow 2.6 (see the config cell).
model.compile(
    optimizer=keras.optimizers.SGD(
        learning_rate=learning_rate,
        momentum=momentum,
    ),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
# Train on the train split and validate on the val split.
logs = model.fit(
    x=dataloader['train'],
    validation_data=dataloader['val'],
    epochs=epochs,
    callbacks=[time_callback, lr_scheduler_callback],
    verbose='auto',
)
# Store per-epoch wall-clock times alongside the Keras metrics.
logs.history['t'] = time_callback.history