In [None]:
import argparse
import itertools
import os
import shutil
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras

import tf_data_model

In [None]:
class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.MetavarTypeHelpFormatter):
    pass
parser = argparse.ArgumentParser(
    description='Progressive Training by Using Tensorflow',
    epilog=(
        'The parser only supports high-level control options. '
        'If the user wants to adjust low-level control options, modify the code. '
        'Required sttings [--dataset, --path] or [-d, -p], optional settings [--amp, --xla].'
    ),
    formatter_class=CustomFormatter,
)

# high-level control options
parser.add_argument(
    '--multi-gpu',
    action='store_true',
    help='training using multiple GPUs, not yet completed',
)
parser.add_argument(
    '--mixed-precision', '--amp',
    action='store_true',
    help='train with mixed precision (amp)',
)
parser.add_argument(
    '--jit-compile', '--xla',
    action='store_true',
    help='train with jit compile (xla)',
)
parser.add_argument(
    '--dataset', '--data', '-d',
    type=str,
    help='dataset to train, currently supports ["cifar10", "cifar100", "imagenet"]',
)
parser.add_argument(
    '--dir-path', '--path', '-p',
    type=str,
    help='path to the dataset directory',
)
parser.add_argument(
    '--comments', '-c',
    type=str,
    help='add additional comments on filename',
)
parser.add_argument(
    '--no-cycle',
    dest='cycle',
    action='store_false',
    help='do not use all image resolutions with different learning rates',
)
parser.add_argument(
    '--no-temp',
    dest='temp',
    action='store_false',
    help='do not save the temporary state during training, including "_model" and ".npy"',
)
parser.add_argument(
    '--no-save',
    dest='save',
    action='store_false',
    help='do not save the training results, including "_model" and ".npy"',
)

# check the file type is '.py' or '.ipynb'
## parse args of '.ipynb' from here
## ex. ['--dataset=imagenet', '--path=./dataset', '--amp', '--xla']
ipynb_args = ['-d=cifar100', '-p=~/ssd', '--amp', '--xla']
args = (
    parser.parse_args(ipynb_args)
    if len(sys.argv) > 2 and sys.argv[1] == '-f' and '.json' in sys.argv[2]
    else parser.parse_args()
)
print(args)

In [None]:
# high-level control options
## device
### data parallel is not complete yet, keep it "False"
### there are still many problems to solve; the following are incompatible with each other:
### - tf.distribute.MirroredStrategy()
### - keras.backend.clear_session()
### - jit_compile
MULTI_GPU = args.multi_gpu
MIXED_PRECISION = args.mixed_precision
### TF 2.6 does not support jit_compile in keras.Model.compile() yet
### another way is to use the environment variable 'TF_XLA_FLAGS',
### e.g. os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
JIT_COMPILE = args.jit_compile
## dataset and model
dataset = args.dataset
### argparse keeps '~' literally (e.g. the notebook default '-p=~/ssd'), so expand it here;
### expanduser leaves paths without a leading '~' unchanged
dir_path = os.path.expanduser(args.dir_path) if args.dir_path else args.dir_path
## training
CYCLE = args.cycle
## output file
TEMP = args.temp
SAVE = args.save

# low-level control options
## device
device_index = 0
## dataset and model
depth = 18
### one entry per progressive-training stage; the three lists have equal length per dataset
if dataset in ('cifar10', 'cifar100'):
    batch_size_ls = [1000, 500]
    resolution_ls = [24, 32]
    dropout_rate_ls = [0.1, 0.2]
elif dataset == 'imagenet':
    batch_size_ls = [510, 360, 170]
    resolution_ls = [160, 224, 288]
    dropout_rate_ls = [0.1, 0.2, 0.3]
else:
    raise ValueError(f'Invalid dataset "{dataset}".')
## training
learning_rate = 1e-1
momentum = 0.9
### TF 2.6 does not support weight_decay in keras.optimizers.SGD() yet
### another way is to modify the model, which is set in tf_data_model.py
weight_decay = 1e-4 # [None, 1e-4]
epochs = 90
### the learning rate is multiplied by `gamma` at (step - 1) evenly spaced milestone epochs
step = 3
gamma = 0.1
## output file

# adaptive options
## device
## dataset and model
batch_size_iter = itertools.cycle(batch_size_ls)
resolution_iter = itertools.cycle(resolution_ls)
dropout_rate_iter = itertools.cycle(dropout_rate_ls)
## training
### e.g. epochs=90, step=3 -> [30, 60]
milestones = [int(epochs * i / step) for i in range(1, step)]
### epochs per stage: one pass over all resolutions spans the first LR stage when cycling,
### otherwise the whole run. `and milestones` covers step <= 1 (no milestones), and
### max(1, ...) keeps the training loop from stalling when a stage would round down to 0 epochs
modify_freq = max(1, int((milestones[0] if CYCLE and milestones else epochs) / len(resolution_ls)))
## output file
outfile = (
    f'{dataset}_resnet{depth}_{epochs}'
    f'{"_amp" if MIXED_PRECISION else ""}'
    f'{"_xla" if JIT_COMPILE else ""}'
    f'{"" if CYCLE else "_nocycle"}'
    f'{"_" + args.comments if args.comments else ""}'
)
tempfile = f'temp_{outfile}'

# experimental: BS and LR are proportional
#learning_rate *= batch_size / (32 * 8) # PyTorch uses batch size 32 with 8 GPUs

In [None]:
# mixed_precision and jit_compile
## experimental: TF 2.6 jit_compile
### this must run before any other tensorflow function is called
### for unknown reasons, '--tf_xla_cpu_global_jit' only works with the first GPU;
### otherwise an error is raised.
if JIT_COMPILE:
    os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
    tf.config.optimizer.set_jit('autoclustering')
else:
    # clear the XLA flags and turn auto-clustering off
    os.environ['TF_XLA_FLAGS'] = ''
    tf.config.optimizer.set_jit(False)
print(f'Optimizer set_jit: "{tf.config.optimizer.get_jit()}"')

if MIXED_PRECISION:
    # float16 compute with float32 variables for every layer created afterwards
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print('\n'.join([
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
    ]))

print('\n'.join([
    '----',
    f'MIXED_PRECISION: {MIXED_PRECISION}',
    f'JIT_COMPILE: {JIT_COMPILE}',
    '----',
]))

In [None]:
######## data parallel is not complete yet
# GPU initialization
physical_devices = tf.config.list_physical_devices('GPU')
print(f'The Number of Available Physical Devices: {len(physical_devices)}')
if physical_devices:
    # visible devices must be chosen before the runtime initializes the GPUs
    tf.config.set_visible_devices(
        physical_devices if MULTI_GPU else physical_devices[device_index],
        'GPU'
    )
else:
    # indexing physical_devices[device_index] would fail with a bare IndexError here
    print('No GPU found, running on CPU')
for device in tf.config.get_visible_devices('GPU'):
    # allocate GPU memory on demand instead of reserving it all up front
    tf.config.experimental.set_memory_growth(device, True)
print(f'Using Devices: {tf.config.get_visible_devices("GPU")}')

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Step-decay schedule for ``keras.callbacks.LearningRateScheduler``.

    Args:
        epoch: index of the epoch about to start.
        lr: learning rate currently set on the optimizer.
        milestones: epochs at which the learning rate is decayed.
        gamma: multiplicative decay factor applied at a milestone.

    Returns:
        ``lr * gamma`` when ``epoch`` is one of ``milestones``, otherwise ``lr`` unchanged.
    """
    if epoch not in milestones:
        return lr
    lr *= gamma
    return lr

class TimeCallback(keras.callbacks.Callback):
    """Keras callback recording the wall-clock duration of every epoch.

    ``history`` is replaced by an empty list at the start of each ``fit`` call and
    receives one ``time.perf_counter`` delta (seconds) per finished epoch.
    """

    def on_train_begin(self, logs=None):
        # a new list per fit(), so it only holds the epochs of the current call
        self.history = list()

    def on_epoch_begin(self, epoch, logs=None):
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

def get_modified_parameters(
    batch_size_iter: itertools.cycle,
    resolution_iter: itertools.cycle,
    dropout_rate_iter: itertools.cycle,
    learning_rate: float,
    logs: dict,
    modify_freq: int
):
    """Advance to the next progressive-training stage.

    Draws the next batch size, resolution and dropout rate from their iterators and
    appends ``modify_freq`` copies of each to the matching per-epoch lists in ``logs``.

    Args:
        batch_size_iter: iterator yielding the batch size of each stage.
        resolution_iter: iterator yielding the image resolution of each stage.
        dropout_rate_iter: iterator yielding the dropout rate of each stage.
        learning_rate: fallback learning rate, used only when ``logs`` has no ``'lr'`` entry.
        logs: training log dict holding list values under ``'batch_size'``,
            ``'resolution'`` and ``'dropout_rate'``; mutated in place.
        modify_freq: number of epochs the stage lasts.

    Returns:
        ``(batch_size, resolution, dropout_rate, learning_rate)``, where the learning rate
        is ``logs['lr'][-1]`` when present, otherwise the ``learning_rate`` argument.
    """
    stage = {
        'batch_size': next(batch_size_iter),
        'resolution': next(resolution_iter),
        'dropout_rate': next(dropout_rate_iter),
    }
    for name, value in stage.items():
        logs[name] += [value] * modify_freq

    # continue from the last learning rate recorded by the previous stage, if any
    if 'lr' in logs:
        learning_rate = logs['lr'][-1]

    batch_size = stage['batch_size']
    resolution = stage['resolution']
    dropout_rate = stage['dropout_rate']
    print(f'batch_size: {batch_size}')
    print(f'resolution: {resolution}, dropout_rate: {dropout_rate}')
    print(f'learning_rate (previous step): {learning_rate: g}')
    return batch_size, resolution, dropout_rate, learning_rate

In [None]:
# step-decay the learning rate at every milestone epoch
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)
# per-epoch wall-clock timing of each fit() call
time_callback = TimeCallback()
# per-epoch training records, starting empty
logs = {key: [] for key in ('batch_size', 'resolution', 'dropout_rate', 't')}
# number of epochs already recorded (0 for a fresh run)
epoch_index = len(logs['t'])

In [None]:
while epoch_index < epochs:
    # modify data and model for the next stage
    batch_size, resolution, dropout_rate, learning_rate = get_modified_parameters(
        batch_size_iter, resolution_iter, dropout_rate_iter, learning_rate, logs, modify_freq
    )
    dataloader = tf_data_model.load_data(
        resolution=resolution,
        batch_size=batch_size,
        dataset=dataset,
        dir_path=dir_path
    )
    model = tf_data_model.modify_resnet(
        dataset=dataset,
        depth=depth,
        dropout_rate=dropout_rate,
        resolution=resolution,
        # the first stage passes None; later stages pass the previous stage's model
        old_model=model if epoch_index else None
    )

    # compile model; a new optimizer is built every stage, starting from the
    # learning rate returned by get_modified_parameters()
    model.compile(
        optimizer=keras.optimizers.experimental.SGD(
            learning_rate=learning_rate,
            momentum=momentum,
            weight_decay=None if tf_data_model.OLD_VERSION else weight_decay
            ## 'decay_steps' in 'keras.optimizers.schedules.LearningRateSchedule()'
            ## means batches instead of epochs, which is a fine grained value
            ## so try to use 'keras.callbacks.LearningRateScheduler()'
            ## to set the learning rate decay value each epoch.
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )

    # training step, record temporary logs
    temp_logs = model.fit(
        dataloader['train'],
        epochs=min(epoch_index + modify_freq, epochs),
        verbose='auto',
        callbacks=[time_callback, lr_scheduler_callback],
        validation_data=dataloader['val'],
        initial_epoch=epoch_index
    )
    temp_logs.history['t'] = time_callback.history

    # concatenate temporary logs to the logs; extend a list owned by `logs`
    # instead of aliasing the lists held by the History object
    for key, value in temp_logs.history.items():
        logs.setdefault(key, []).extend(value)

    # update epoch index of the training
    previous_epoch_index = epoch_index
    epoch_index = len(logs['t'])
    if epoch_index <= previous_epoch_index:
        # fit() ran no epoch (e.g. modify_freq == 0); stop instead of looping forever
        raise RuntimeError(
            f'No epoch was trained in this stage (epoch_index={epoch_index}, '
            f'modify_freq={modify_freq}, epochs={epochs}).'
        )

    # save the temporary state (aka checkpoint)
    if TEMP and epoch_index < epochs:
        keras.models.save_model(model, f'{tempfile}_model')
        np.save(f'{tempfile}.npy', logs)
        print(f'Save The Temporary State at Epoch {epoch_index}')

In [None]:
if SAVE:
    # final model plus the logs dict; np.save stores the dict as a pickled object array,
    # so read it back with np.load(..., allow_pickle=True).item()
    keras.models.save_model(model, f'{outfile}_model')
    np.save(f'{outfile}.npy', logs)
    print(f'Save Model: {outfile}_model\nSave Logs: {outfile}.npy')

if TEMP:
    # drop the intermediate checkpoint; either part may never have been written
    shutil.rmtree(f'{tempfile}_model', ignore_errors=True)
    if os.path.isfile(f'{tempfile}.npy'):
        os.remove(f'{tempfile}.npy')
    print('Clean The Temporary State')

print('Training Completed')