In [None]:
import itertools
import os
import time
import tensorflow as tf
from tensorflow import keras
import tf_data_model

In [None]:
# for device
## Index into the list of physical GPUs; selects the GPU used for training.
device_index = 0
## When True, the global Keras policy is set to 'mixed_float16'.
MIXED_PRECISION_FLAG = True
## TF 2.6 does not support jit_compile in keras.Model.compile() yet,
## so XLA is enabled through the environment variable 'TF_XLA_FLAGS' instead,
## e.g. os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'.
## This flag decides whether that environment variable gets set.
JIT_COMPILE_FLAG = True

# for dataset
dataset = 'cifar100'
## NOTE(review): dir_path is not read anywhere in the visible cells -- confirm it is still needed.
dir_path = '/ssd'
## Reference values, presumably (resolution, batch size) pairs -- TODO confirm:
## r_b = [(160, 510), (224, 390), (288, 180)] without jit_compile
## r_b = [(160, 510), (224, 360), (288, 170)] with jit_compile
batch_size = 500
## Endless cycle of input resolutions; the next value is taken at the start of each training stage.
## The bracketed list in the trailing comment looks like an alternative setting -- TODO confirm.
resolution_iter = itertools.cycle([24, 32]) # [160, 224, 288]

# for model
depth = 18
## Endless cycle of dropout rates, advanced together with resolution_iter.
dropout_rate_iter = itertools.cycle([0.1, 0.2]) # [0.1, 0.2, 0.3]

# for training
learning_rate = 1e-1
momentum = 0.9
## Total number of epochs to train.
epochs = 90 # 120
## Number of epochs per stage before the resolution / dropout rate are changed.
modify_freq = 45 # 10
## TF 2.6 does not support weight_decay in keras.optimizers.SGD() yet.
## So, it might be set in the model.
## NOTE(review): weight_decay is not passed to anything in the visible cells -- confirm where it is applied.
weight_decay = 1e-4

# for learning rate scheduler
## Epoch indices at which the learning rate is multiplied by gamma.
milestones = [30, 60] # [30, 60, 90]
gamma = 0.1
# created here for later use
def get_modify_parameters(resolution_iter: itertools.cycle, dropout_rate_iter: itertools.cycle):
    """Advance both iterators by one step and report the values drawn.

    Args:
        resolution_iter: iterator yielding the next resolution.
        dropout_rate_iter: iterator yielding the next dropout rate.

    Returns:
        A ``(resolution, dropout_rate)`` tuple holding the values just drawn.
        The pair is also printed.
    """
    # The resolution iterator is advanced first, then the dropout-rate one.
    resolution, dropout_rate = next(resolution_iter), next(dropout_rate_iter)
    print(f'resolution: {resolution}, dropout_rate: {dropout_rate}')
    return resolution, dropout_rate
# Training history accumulated over all stages (history key -> list of per-epoch values).
logs = {}
# Number of epochs already recorded, read from the per-epoch timing list 't' (0 when absent).
epoch_index = len(logs.get('t', ()))

######## for testing: BS and LR are proportional
#learning_rate *= batch_size / (32 * 8) # PyTorch uses batch size 32 with 8 GPUs

In [None]:
######## for testing: TF 2.6 jit_compile; TF_XLA_FLAGS must be set before any TensorFlow function is called.
## For unknown reasons, '--tf_xla_cpu_global_jit' only works with the first GPU.
## With any other device it results in an error.
if JIT_COMPILE_FLAG:
    # The GPU count cannot be used as the condition here
    # (e.g. 'len(tf.config.list_physical_devices('GPU')) == 1'),
    # since that would already call a TensorFlow function.
    os.environ['TF_XLA_FLAGS'] = (
        '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
        if device_index == 0
        else '--tf_xla_auto_jit=2'
    )

# Switch Keras to the mixed_float16 policy globally and report its dtypes.
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(f'Policy: {policy.name}')
    print(f'Compute dtype: {policy.compute_dtype}')
    print(f'Variable dtype: {policy.variable_dtype}')

# Summary of the two switches chosen in the configuration cell.
print('----')
print(f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}')
print(f'JIT_COMPILE: {JIT_COMPILE_FLAG}')
print('----')

In [None]:
# Make only the GPU chosen by `device_index` visible to TensorFlow and
# enable memory growth on it.
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
selected_device = physical_devices[device_index]
tf.config.set_visible_devices(selected_device, 'GPU')
tf.config.experimental.set_memory_growth(selected_device, True)
print(f'Using Device: {selected_device}')

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Step-decay schedule: scale the learning rate at milestone epochs.

    Args:
        epoch: index of the epoch that is about to start.
        lr: learning rate currently in effect.
        milestones: container of epoch indices at which the rate is decayed.
        gamma: multiplicative decay factor applied at a milestone.

    Returns:
        ``lr`` multiplied by ``gamma`` when ``epoch`` is in ``milestones``,
        otherwise ``lr`` unchanged.
    """
    if epoch not in milestones:
        return lr
    lr *= gamma
    return lr

class TimeCallback(keras.callbacks.Callback):
    """Keras callback that records how long each epoch takes.

    ``history`` is replaced by an empty list when training begins and gains
    one ``time.perf_counter()`` difference (seconds) per finished epoch.
    """

    def on_train_begin(self, logs=None):
        """Start a fresh list of epoch durations."""
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the timestamp at which the epoch started."""
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        """Append the time elapsed since the epoch started."""
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

# Records per-epoch wall-clock time.
time_callback = TimeCallback()

# Forwards the (epoch, current lr) pair to lr_schedule together with the
# module-level `milestones` and `gamma`, which are looked up at call time.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)

In [None]:
# Stage-wise training: every `modify_freq` epochs the next
# (resolution, dropout_rate) pair is drawn, the data pipeline and the model
# are rebuilt for it, the model is recompiled, and training continues from
# the epoch where the previous stage stopped.

# A non-positive stage length would make every fit() call train zero epochs,
# so `epoch_index` would never advance and the loop would never end.
if modify_freq <= 0:
    raise ValueError(f'modify_freq must be positive, got {modify_freq}')

while epoch_index < epochs:
    resolution, dropout_rate = get_modify_parameters(resolution_iter, dropout_rate_iter)
    dataloader = tf_data_model.load_cifar(
        resolution=resolution, batch_size=batch_size, dataset=dataset
    )
    model = tf_data_model.modify_resnet(
        # `model` is only read once a previous stage has created it;
        # the first stage passes None.
        old_model=model if epoch_index else None,
        dataset=dataset,
        depth=depth,
        dropout_rate=dropout_rate,
        resolution=resolution
    )

    # A new optimizer is created for every stage, so resume from the last
    # learning rate recorded in the history rather than the initial one.
    if 'lr' in logs:
        learning_rate = logs['lr'][-1]
    model.compile(
        optimizer=keras.optimizers.SGD(
            learning_rate=learning_rate,
            momentum=momentum,
            #weight_decay=weight_decay  # not supported by SGD in TF 2.6
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
        #jit_compile=JIT_COMPILE_FLAG  # not supported by compile() in TF 2.6
    )
    print(f'learning_rate: {learning_rate}')

    # With `initial_epoch`, fit() treats `epochs` as the index of the final
    # epoch, not a count. Clamp it so the last stage stops at the configured
    # total even when `epochs` is not a multiple of `modify_freq`.
    stage_end_epoch = min(epoch_index + modify_freq, epochs)
    temp_logs = model.fit(
        dataloader['train'],
        epochs=stage_end_epoch,
        verbose='auto',
        callbacks=[time_callback, lr_scheduler_callback],
        validation_data=dataloader['val'],
        initial_epoch=epoch_index
    )
    temp_logs.history['t'] = time_callback.history

    # Append this stage's per-epoch values to the accumulated history.
    for key, value in temp_logs.history.items():
        logs.setdefault(key, []).extend(value)

    # The number of recorded epoch timings is the number of epochs trained so far.
    epoch_index = len(logs.get('t', ()))

- cifar100
    - for each lr, train all resolution
        - 90 epoch, resolution = [24, 32], milestones = [30, 60]
        - loss: 0.7437 - accuracy: 0.9930 - val_loss: 2.0937 - val_accuracy: 0.6790
    - separate lr and resolution
        - 90 epoch, resolution = [24, 32], res_milestones = [45], lr_milestones = [30, 60]
        - loss: 0.9103 - accuracy: 0.9633 - val_loss: 2.3247 - val_accuracy: 0.6400