In [None]:
import argparse
import itertools
import os
import shutil
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras

import tf_data_model

In [None]:
class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.MetavarTypeHelpFormatter):
    """Help formatter that mixes in both argparse formatters listed as bases.

    Adds nothing of its own; it only combines the default-value help text of
    ``ArgumentDefaultsHelpFormatter`` with the type-name metavars of
    ``MetavarTypeHelpFormatter``.
    """

# Command-line interface: only the high-level switches are exposed here,
# everything else is configured in the "low-level control options" cell.
parser = argparse.ArgumentParser(
    formatter_class=CustomFormatter,
    description='Progressive Training by Using Tensorflow',
    epilog=(
        'The parser only supports high-level control options. '
        'If the user wants to adjust low-level control options, modify the code. '
        'Required settings [--dataset, --path] or [-d, -p], optional settings [--amp, --xla].'
    ),
)

# high-level control options
## device
### data parallel not complete yet, leave '--multi-gpu' unset ("False")
parser.add_argument(
    '--multi-gpu',
    help='training using multiple GPUs, not yet completed',
    action='store_true',
)
### used as a list index into the detected GPUs, e.g. [0, -1]
parser.add_argument(
    '--device-index',
    help='the index of the GPU used to run the program',
    type=int,
    default=0,
)
parser.add_argument(
    '--mixed-precision', '--amp',
    help='train with mixed precision (amp)',
    dest='amp',
    action='store_true',
)
parser.add_argument(
    '--jit-compile', '--xla',
    help='train with jit compile (xla)',
    dest='xla',
    action='store_true',
)

## dataset and model
parser.add_argument(
    '--dataset', '--data', '-d',
    help='dataset to train, currently supports ["cifar10", "cifar100", "imagenet"]',
    type=str,
)
parser.add_argument(
    '--dir-path', '--path', '-p',
    help='path to the dataset directory',
    type=str,
)
parser.add_argument(
    '--depth',
    help='resnet depth, currently supports [18, 34]',
    type=int,
    default=18,
)

## training
### stored inverted: args.cycle is True unless '--no-cycle' is given
parser.add_argument(
    '--no-cycle',
    help='do not use all image resolutions with different learning rates',
    dest='cycle',
    action='store_false',
)

## output file
parser.add_argument(
    '--comments', '-c',
    help='add additional comments on filename',
    type=str,
)
### stored inverted: args.temp / args.save are True unless the flag is given
parser.add_argument(
    '--no-temp',
    help='do not save the temporary state during training, including "_model" and ".npy"',
    dest='temp',
    action='store_false',
)
parser.add_argument(
    '--no-save',
    help='do not save the training results, including "_model" and ".npy"',
    dest='save',
    action='store_false',
)

# check the file type is '.py' or '.ipynb'
## parse args of '.ipynb' from here
## ex. ['--dataset=imagenet', '--path=./dataset', '--amp', '--xla']
ipynb_args = ['-d=cifar100', '-p=/ssd', '--amp', '--xla', '--no-temp', '--no-save']
## an argv of the form [<prog>, '-f', '<...>.json', ...] is treated as a
## notebook launch, in which case the hard-coded list above is parsed instead
launched_as_notebook = (
    len(sys.argv) > 2
    and sys.argv[1] == '-f'
    and '.json' in sys.argv[2]
)
if launched_as_notebook:
    args = parser.parse_args(ipynb_args)
else:
    args = parser.parse_args()
print('----', args, '----', sep='\n')

In [None]:
# low-level control options
## device
## dataset and model
### the three lists are stepped through together, one entry per training stage
if args.dataset in ('cifar10', 'cifar100'):
    batch_size_ls = [1000, 500]
    resolution_ls = [24, 32]
    dropout_rate_ls = [0.1, 0.2]
elif args.dataset == 'imagenet':
    batch_size_ls = [510, 360, 170]
    resolution_ls = [160, 224, 288]
    dropout_rate_ls = [0.1, 0.2, 0.3]
else:
    raise ValueError(f'Invalid dataset "{args.dataset}".')
## training
learning_rate = 1e-1
momentum = 0.9
### TF 2.6 does not support weight_decay in keras.optimizers.SGD() yet
### another way is to modify the model, which is set in tf_data_model.py
weight_decay = 1e-4 # [None, 1e-4]
epochs = 90
steps = 3
gamma = 0.1
## output file

# adaptive options
## device
## dataset and model
### endless iterators, so the stage settings repeat once the lists are exhausted
batch_size_iter = itertools.cycle(batch_size_ls)
resolution_iter = itertools.cycle(resolution_ls)
dropout_rate_iter = itertools.cycle(dropout_rate_ls)
## training
### epoch indices at which the learning rate is multiplied by gamma
milestones = [epochs // steps * i for i in range(1, steps + 1)]
### number of epochs trained with one (batch size, resolution, dropout) setting:
### with cycling, all resolutions fit inside the span up to the first milestone;
### without cycling, they are spread over the whole run
cycle_span = milestones[0] if args.cycle else epochs
modify_freq = cycle_span // len(resolution_ls)
if modify_freq == 0:
    raise ValueError('"modify_freq" is "0"')
## output file
outfile = ''.join([
    f'{args.dataset}_resnet{args.depth}_{epochs}',
    '_amp' if args.amp else '',
    '_xla' if args.xla else '',
    '' if args.cycle else '_noCycle',
    f'_{args.comments}' if args.comments else '',
])
tempfile = 'temp_' + outfile

# experimental: BS and LR are proportional
#learning_rate *= batch_size / (32 * 8) # PyTorch uses batch size 32 with 8 GPUs

In [None]:
# mixed_precision and jit_compile
### for unknown reasons, '--tf_xla_cpu_global_jit' only supports the first GPU
if args.xla:
    # set the XLA flag in the environment, then turn on auto-clustering
    os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
    tf.config.optimizer.set_jit('autoclustering')
    print(f'Optimizer set_jit: "{tf.config.optimizer.get_jit()}"')

if args.amp:
    # install 'mixed_float16' as the global Keras dtype policy and report it
    policy = keras.mixed_precision.Policy('mixed_float16')
    keras.mixed_precision.set_global_policy(policy)
    print(
        f'Policy: {policy.name}',
        f'Compute dtype: {policy.compute_dtype}',
        f'Variable dtype: {policy.variable_dtype}',
        sep='\n',
    )

# summary of the two switches
print(
    '----',
    f'MIXED_PRECISION: {args.amp}',
    f'JIT_COMPILE: {args.xla}',
    '----',
    sep='\n',
)

In [None]:
# GPU initialization, data parallel not complete yet
physical_devices = tf.config.list_physical_devices('GPU')
## expose either every detected GPU or only the one picked by '--device-index'
if args.multi_gpu:
    selected_devices = physical_devices[:]
else:
    selected_devices = physical_devices[args.device_index]
tf.config.set_visible_devices(selected_devices, 'GPU')
## enable memory growth on every GPU that is still visible
visible_devices = tf.config.get_visible_devices('GPU')
for device in visible_devices:
    tf.config.experimental.set_memory_growth(device, True)
print(
    '----',
    f'The Number of Available Physical Devices: {len(physical_devices)}',
    f'Using Devices: {visible_devices}',
    '----',
    sep='\n',
)

In [None]:
def lr_schedule(epoch, lr, milestones, gamma: float = 0.1):
    """Return the learning rate to use for ``epoch``.

    Args:
        epoch: index of the epoch about to start.
        lr: learning rate currently in use.
        milestones: container of epoch indices at which the rate is scaled.
        gamma: factor applied to ``lr`` when ``epoch`` is a milestone.

    Returns:
        ``lr * gamma`` if ``epoch`` is in ``milestones``, otherwise ``lr``.
    """
    return lr * gamma if epoch in milestones else lr

class TimeCallback(keras.callbacks.Callback):
    """Keras callback that records how long each epoch takes.

    ``history`` holds one ``time.perf_counter()`` difference per finished
    epoch and is replaced by an empty list whenever training begins.
    """

    def on_train_begin(self, logs=None):
        """Start a fresh, empty list of epoch durations."""
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the timer value at the start of the epoch."""
        self.time_epoch_begin = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        """Append the time elapsed since the matching ``on_epoch_begin``."""
        elapsed = time.perf_counter() - self.time_epoch_begin
        self.history.append(elapsed)

def get_modified_parameters(
    batch_size_iter: itertools.cycle,
    resolution_iter: itertools.cycle,
    dropout_rate_iter: itertools.cycle,
    learning_rate: float,
    logs: dict,
    modify_freq: int
):
    """Advance to the next training stage and return its settings.

    Takes the next value from each of the three iterators, records every
    value ``modify_freq`` times in ``logs`` (one entry per epoch of the
    stage), and picks the learning rate the stage starts from.

    Args:
        batch_size_iter: iterator yielding the batch size of each stage.
        resolution_iter: iterator yielding the image resolution of each stage.
        dropout_rate_iter: iterator yielding the dropout rate of each stage.
        learning_rate: fallback learning rate, used while ``logs`` holds no
            recorded ``'lr'`` values.
        logs: training log; must already contain the list-valued keys
            ``'batch_size'``, ``'resolution'`` and ``'dropout_rate'``.
            Mutated in place.
        modify_freq: number of epochs the returned settings stay in effect.

    Returns:
        Tuple ``(batch_size, resolution, dropout_rate, learning_rate)``.

    Raises:
        KeyError: if one of the three required keys is missing from ``logs``.
    """
    batch_size = next(batch_size_iter)
    resolution = next(resolution_iter)
    dropout_rate = next(dropout_rate_iter)

    # one log entry per epoch that will run with these settings
    logs['batch_size'] += [batch_size] * modify_freq
    logs['resolution'] += [resolution] * modify_freq
    logs['dropout_rate'] += [dropout_rate] * modify_freq

    # continue from the last recorded learning rate; an absent OR empty 'lr'
    # history keeps the caller-supplied value instead of raising IndexError
    lr_history = logs.get('lr')
    if lr_history is not None and len(lr_history) > 0:
        learning_rate = lr_history[-1]

    print(f'batch_size: {batch_size}')
    print(f'resolution: {resolution}, dropout_rate: {dropout_rate}')
    print(f'learning_rate (previous step): {learning_rate: g}')
    return batch_size, resolution, dropout_rate, learning_rate

In [None]:
# learning-rate callback: hands (epoch, current lr) to lr_schedule together
# with the module-level milestones and gamma
lr_scheduler_callback = keras.callbacks.LearningRateScheduler(
    lambda epoch, lr: lr_schedule(epoch, lr, milestones=milestones, gamma=gamma)
)
time_callback = TimeCallback()
# training log: one list per tracked quantity, all starting empty
logs = {name: [] for name in ('batch_size', 'resolution', 'dropout_rate', 't')}
# number of epochs already recorded (0 for the fresh log above)
epoch_index = len(logs['t'])

In [None]:
# Progressive training loop: each pass trains one stage of up to
# 'modify_freq' epochs with the next (batch size, resolution, dropout rate)
# combination, then merges the stage history into 'logs'. The loop ends once
# 'logs' holds timings for 'epochs' epochs.
while epoch_index < epochs:
    # modify data and model
    ## next stage settings; the learning rate continues from the last logged 'lr'
    batch_size, resolution, dropout_rate, learning_rate = get_modified_parameters(
        batch_size_iter, resolution_iter, dropout_rate_iter, learning_rate, logs, modify_freq
    )
    dataloader = tf_data_model.load_data(
        resolution=resolution,
        batch_size=batch_size,
        dataset=args.dataset,
        dir_path=args.dir_path
    )
    ## on the first pass (epoch_index == 0) no model exists yet, so None is passed;
    ## later passes hand over the model of the previous stage
    ## NOTE(review): presumably modify_resnet carries the weights over — confirm in tf_data_model
    model = tf_data_model.modify_resnet(
        dataset=args.dataset,
        depth=args.depth,
        dropout_rate=dropout_rate,
        resolution=resolution,
        old_model=model if epoch_index else None
    )
    
    # compile model
    ## a new optimizer is created for every stage, starting at 'learning_rate'
    model.compile(
        optimizer=keras.optimizers.experimental.SGD(
            learning_rate=learning_rate,
            momentum=momentum,
            weight_decay=None if tf_data_model.OLD_VERSION else weight_decay
            ## 'decay_steps' in 'keras.optimizers.schedules.LearningRateSchedule()'
            ## means batches instead of epochs, which is a fine grained value
            ## so try to use 'keras.callbacks.LearningRateScheduler()'
            ## to set the learning rate decay value each epoch.
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    
    # training step, record temporary logs
    ## 'initial_epoch' resumes the epoch counter at 'epoch_index'; 'epochs' is the
    ## stage's upper bound, capped at the total number of epochs
    temp_logs = model.fit(
        dataloader['train'],
        epochs=min(epoch_index + modify_freq, epochs),
        verbose='auto',
        callbacks=[time_callback, lr_scheduler_callback],
        validation_data=dataloader['val'],
        initial_epoch=epoch_index
    )
    ## TimeCallback resets its history when training begins, so this list only
    ## holds the epoch durations of the fit() call above
    temp_logs.history['t'] = time_callback.history

    # concatenate temporary logs to the logs
    ## keys already in 'logs' are extended in place; keys seen for the first
    ## time are stored as the stage's list object itself
    for key, value in temp_logs.history.items():
        if key in logs:
            logs[key] += value
        else:
            logs[key] = value
    
    # update epoch index of the training
    ## one 't' entry is appended per finished epoch, so its length is the epoch count
    epoch_index = len(logs['t'])

    # save the temporary state (aka checkpoint)
    ## skipped after the last stage and when '--no-temp' was given
    if args.temp and epoch_index < epochs:
        keras.models.save_model(model, f'{tempfile}_model')
        np.save(f'{tempfile}.npy', logs)
        print(f'Save The Temporary State at Epoch {epoch_index}')

In [None]:
# store the final model and the training logs unless '--no-save' was given
if args.save:
    model_path = f'{outfile}_model'
    logs_path = f'{outfile}.npy'
    keras.models.save_model(model, model_path)
    np.save(logs_path, logs)
    print(f'Save Model: {model_path}')
    print(f'Save Logs: {logs_path}')

# remove the checkpoint written during training unless '--no-temp' was given
if args.temp:
    shutil.rmtree(f'{tempfile}_model', ignore_errors=True)
    temp_logs_path = f'{tempfile}.npy'
    if os.path.isfile(temp_logs_path):
        os.remove(temp_logs_path)
    print('Clean The Temporary State')

print('Training Completed')