In [None]:
import itertools
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tf_data_model

In [None]:
# for device
MULTI_GPU = False  ######## data parallel is not complete yet, keep "False"
device_index = 0
MIXED_PRECISION_FLAG = True
## TF 2.6 does not support jit_compile in keras.Model.compile() yet.
## So XLA is enabled through the environment variable 'TF_XLA_FLAGS' instead
## (set in the next cell when this flag is True), i.e.
## os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'.
## Set this flag to False to train without XLA.
JIT_COMPILE_FLAG = True

# for dataset
dataset = 'imagenet'
dir_path = '/ssd'
## Training is split into stages: stage k uses the k-th entry of
## `batch_sizes`, `resolutions` and `dropout_rates`; stages are visited cyclically.
## r_b = [(160, 510), (224, 390), (288, 180)] without jit_compile on GTX-1080
## r_b = [(160, 510), (224, 360), (288, 170)] with jit_compile on GTX-1080
batch_sizes = [510, 360, 170]  # [1000, 500], [510, 360, 170]
resolutions = [160, 224, 288]  # [24, 32], [160, 224, 288]
batch_size_iter = itertools.cycle(batch_sizes)
resolution_iter = itertools.cycle(resolutions)

# for model
depth = 18
dropout_rates = [0.1, 0.2, 0.3]  # [0.1, 0.2], [0.1, 0.2, 0.3]
dropout_rate_iter = itertools.cycle(dropout_rates)

# The three stage lists must line up, otherwise the cycles drift apart.
assert len(batch_sizes) == len(resolutions) == len(dropout_rates), \
    'batch_sizes, resolutions and dropout_rates must have the same length'
num_stages = len(resolutions)

# for training
learning_rate = 1e-1
momentum = 0.9
epochs = 90  # 90
## TF 2.6 does not support weight_decay in keras.optimizers.SGD() yet,
## so it has to be set in the model (tf_data_model.py).
## NOTE(review): this value is not passed to tf_data_model from this notebook;
## keep it in sync with the setting in tf_data_model.py.
weight_decay = 1e-4
## learning rate scheduler: multiply the LR by `gamma` at 1/3 and 2/3 of training
milestones = [epochs * i // 3 for i in range(1, 3)]  # [30, 60]
gamma = 0.1
## cycle or iter
### CYCLE=True : go through all stages within each learning-rate period,
###              i.e. modify_freq = milestones[0] / num_stages (30 / 3 = 10).
### CYCLE=False: go through the stages (roughly) once over the whole run,
###              i.e. modify_freq = epochs / num_stages (90 / 3 = 30).
CYCLE = True
## At least 1 epoch per stage: a value of 0 would make the training loop never advance.
modify_freq = max(1, (milestones[0] if CYCLE else epochs) // num_stages)

# for output file
SAVEFILE = True
outfile = f'{dataset}_resnet{depth}_{"cycle" if CYCLE else "iter"}_{epochs}.npy'

######## testing: BS and LR are proportional (batch_size is only known per stage)
#learning_rate *= batch_size / (32 * 8) # PyTorch uses batch size 32 with 8 GPUs

In [None]:
######## for testing: TF 2.6 jit_compile. This cell must run before any TensorFlow function is called.
## For unknown reasons, '--tf_xla_cpu_global_jit' only supports the first GPU;
## otherwise an error will result.
if JIT_COMPILE_FLAG:
    # The GPU count cannot be used as the condition here
    # (e.g. 'len(tf.config.list_physical_devices('GPU')) == 1'),
    # since querying it already calls a TensorFlow function.
    if not (device_index == 0 or MULTI_GPU):
        os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'
    else:
        os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'

# Optionally switch Keras to the 'mixed_float16' global policy and report its dtypes.
if MIXED_PRECISION_FLAG:
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(f'Policy: {policy.name}')
    print(f'Compute dtype: {policy.compute_dtype}')
    print(f'Variable dtype: {policy.variable_dtype}')

# Summary of the acceleration settings in effect.
print('----')
print(f'MIXED_PRECISION: {MIXED_PRECISION_FLAG}')
print(f'JIT_COMPILE: {JIT_COMPILE_FLAG}')
print('----')

In [None]:
######## data parallel is not complete yet (keep MULTI_GPU = False)
physical_devices = tf.config.list_physical_devices('GPU')
print(f'Numbers of Physical Devices: {len(physical_devices)}')
if not physical_devices:
    # Nothing to select: indexing the empty list would raise a bare IndexError.
    print('No GPU found, running on CPU.')
else:
    if MULTI_GPU:
        selected_devices = physical_devices
    elif -len(physical_devices) <= device_index < len(physical_devices):
        selected_devices = physical_devices[device_index]
    else:
        raise ValueError(
            f'device_index={device_index} is out of range for '
            f'{len(physical_devices)} GPU(s)'
        )
    # Visibility and memory growth must be configured before the GPUs are initialized.
    tf.config.set_visible_devices(selected_devices, 'GPU')
    for device in tf.config.get_visible_devices('GPU'):
        # Allocate GPU memory on demand instead of reserving it all up front.
        tf.config.experimental.set_memory_growth(device, True)
print(f'Using Devices: {tf.config.get_visible_devices("GPU")}')

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Step decay: scale `lr` by `gamma` when `epoch` is one of `milestones`.

    Used (through a wrapper) with keras.callbacks.LearningRateScheduler, which
    supplies the current epoch and learning rate. For any other epoch the
    learning rate is returned unchanged.
    """
    return lr * gamma if epoch in milestones else lr

class TimeCallback(keras.callbacks.Callback):
    """Record how long each epoch of a `fit` call takes, in seconds.

    `history` is reset to an empty list at the start of every `fit` call and
    receives one `time.perf_counter()` difference per finished epoch.
    """

    def on_train_begin(self, logs=None):
        # Fresh list per fit call, so `history` only covers the latest segment.
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

def _scheduled_lr(epoch, lr):
    """Adapter with the (epoch, lr) signature that LearningRateScheduler calls."""
    return lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)


# Multiplies the learning rate by `gamma` at every epoch listed in `milestones`.
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(_scheduled_lr)

# Collects per-epoch durations; read through `time_callback.history` after each fit.
time_callback = TimeCallback()

In [None]:
def get_modified_parameters(
    batch_size_iter: itertools.cycle,
    resolution_iter: itertools.cycle,
    dropout_rate_iter: itertools.cycle,
    learning_rate: float,
    logs: dict,
    modify_freq: int
):
    """Advance to the next training stage and record its hyper-parameters.

    Draws the next value from each iterator and appends it `modify_freq` times
    to `logs['batch_size']`, `logs['resolution']` and `logs['dropout_rate']`
    (one entry per epoch of the upcoming segment; missing keys are created).
    The chosen values are printed.

    Args:
        batch_size_iter, resolution_iter, dropout_rate_iter: iterators yielding
            the per-stage values (e.g. `itertools.cycle` objects).
        learning_rate: fallback learning rate, used when `logs` has no
            recorded 'lr' value yet.
        logs: per-epoch log dict; updated in place.
        modify_freq: number of epochs the new stage will be trained for.

    Returns:
        (batch_size, resolution, dropout_rate, learning_rate), where
        learning_rate is the last value of `logs['lr']` when one exists.
    """
    batch_size = next(batch_size_iter)
    resolution = next(resolution_iter)
    dropout_rate = next(dropout_rate_iter)
    for key, value in (
        ('batch_size', batch_size),
        ('resolution', resolution),
        ('dropout_rate', dropout_rate),
    ):
        logs.setdefault(key, []).extend([value] * modify_freq)
    # Continue from the last recorded learning rate (presumably written by the
    # LR-scheduler callback during the previous segment); an empty list falls back.
    lr_history = logs.get('lr', [])
    if len(lr_history) > 0:
        learning_rate = lr_history[-1]
    print(f'batch_size: {batch_size}')
    print(f'resolution: {resolution}, dropout_rate: {dropout_rate}')
    print(f'learning_rate: {learning_rate}')
    return batch_size, resolution, dropout_rate, learning_rate

# Per-epoch training log. The hyper-parameter lists are filled by
# get_modified_parameters(); 't' and the keys reported by model.fit are
# appended after each training segment.
logs = {key: [] for key in ('batch_size', 'resolution', 'dropout_rate', 't')}
# Number of epochs already trained; the training loop starts from here.
epoch_index = len(logs['t'])

In [None]:
# Train in segments. Before each segment the data pipeline and the model are
# rebuilt for the next (batch_size, resolution, dropout_rate) stage, and the
# model is recompiled with a new SGD optimizer.
while epoch_index < epochs:
    # Number of epochs in this segment. The last segment may be shorter than
    # `modify_freq`; using the same value for the logs and for `fit` keeps one
    # entry per trained epoch in every list of `logs`. `max(1, ...)` guarantees
    # progress even if `modify_freq` is 0.
    segment_epochs = max(1, min(modify_freq, epochs - epoch_index))

    # modify data and model
    batch_size, resolution, dropout_rate, learning_rate = get_modified_parameters(
        batch_size_iter, resolution_iter, dropout_rate_iter, learning_rate, logs, segment_epochs
    )
    dataloader = tf_data_model.load_data(
        resolution=resolution,
        batch_size=batch_size,
        dataset=dataset,
        dir_path=dir_path
    )
    # On the first segment there is no previous model (and `model` is not read);
    # afterwards `model` holds the model trained in the previous iteration.
    model = tf_data_model.modify_resnet(
        old_model=model if epoch_index else None,
        dataset=dataset,
        depth=depth,
        dropout_rate=dropout_rate,
        resolution=resolution
    )

    # compile model
    model.compile(
        optimizer=keras.optimizers.SGD(
            learning_rate=learning_rate,
            momentum=momentum,
            # weight_decay=weight_decay,  # not supported by SGD in TF 2.6
            # `decay_steps` in `keras.optimizers.schedules.LearningRateSchedule`
            # means batches instead of epochs, which is a fine grained value,
            # so `keras.callbacks.LearningRateScheduler` is used instead
            # to set the learning rate decay value each epoch.
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
        # jit_compile=JIT_COMPILE_FLAG,  # not supported in TF 2.6; XLA is enabled via TF_XLA_FLAGS
    )

    # training step, record temporary logs
    temp_logs = model.fit(
        dataloader['train'],
        # with `initial_epoch`, `epochs` is the index to stop at, not a count
        epochs=epoch_index + segment_epochs,
        verbose='auto',
        callbacks=[time_callback, lr_scheduler_callback],
        validation_data=dataloader['val'],
        initial_epoch=epoch_index
    )
    temp_logs.history['t'] = time_callback.history

    # concatenate temporary logs to the logs; `extend` copies the values, so
    # `logs` never aliases lists owned by Keras or by `time_callback`
    for key, value in temp_logs.history.items():
        logs.setdefault(key, []).extend(value)

    # update epoch index of the training
    trained_epochs = len(logs['t'])
    if trained_epochs <= epoch_index:
        # A segment that records no epoch would otherwise repeat forever.
        raise RuntimeError(
            f'No epoch was recorded for the segment starting at epoch {epoch_index}'
        )
    epoch_index = trained_epochs

In [None]:
# Persist the per-epoch logs. np.save stores the dict as a pickled 0-d object
# array; reload it with `np.load(outfile, allow_pickle=True).item()`.
if SAVEFILE:
    np.save(file=outfile, arr=logs, allow_pickle=True)