This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Added in Large-Batch SGD with a warmup, and a LARS strategy. Also add… #8918

Merged: 5 commits merged on Jan 29, 2018
138 changes: 102 additions & 36 deletions example/image-classification/common/fit.py
@@ -15,10 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
""" example train fit utility """
import logging
import os
import time
import re
import math
import mxnet as mx


def _get_lr_scheduler(args, kv):
if 'lr_factor' not in args or args.lr_factor >= 1:
@@ -27,17 +31,26 @@ def _get_lr_scheduler(args, kv):
if 'dist' in args.kv_store:
epoch_size /= kv.num_workers
begin_epoch = args.load_epoch if args.load_epoch else 0
if 'pow' in args.lr_step_epochs:
lr = args.lr
max_up = args.num_epochs * epoch_size
pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
return (lr, poly_sched)
step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
lr = args.lr
for s in step_epochs:
if begin_epoch >= s:
lr *= args.lr_factor
if lr != args.lr:
logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
logging.info('Adjust learning rate to %e for epoch %d',
lr, begin_epoch)

steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
steps = [epoch_size * (x - begin_epoch)
for x in step_epochs if x - begin_epoch > 0]
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
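
As an aside (not part of the patch), here is a minimal sketch of how the two forms of --lr-step-epochs handled above map onto schedulers, using hypothetical numbers (base lr 0.1, 100 epochs, 500 updates per epoch, --lr-factor 0.1) and assuming the PolyScheduler added later in this PR is available:

import re
import mxnet as mx

epoch_size, base_lr, num_epochs = 500, 0.1, 100

# "--lr-step-epochs pow2": polynomial decay over the whole run
pwr = float(re.sub('pow[- ]*', '', 'pow2'))                      # -> 2.0
poly = mx.lr_scheduler.PolyScheduler(num_epochs * epoch_size, base_lr, pwr)

# "--lr-step-epochs 30,60": step decay at epochs 30 and 60
steps = [epoch_size * e for e in (30, 60)]
multi = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=0.1)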


def _load_model(args, rank=0):
if 'load_epoch' not in args or args.load_epoch is None:
return (None, None, None)
@@ -50,6 +63,7 @@ def _load_model(args, rank=0):
logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
if args.model_prefix is None:
return None
@@ -59,6 +73,7 @@ def _save_model(args, rank=0):
return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
args.model_prefix, rank))


def add_fit_args(parser):
"""
parser : argparse.ArgumentParser
@@ -68,7 +83,8 @@ def add_fit_args(parser):
train.add_argument('--network', type=str,
help='the neural network to use')
train.add_argument('--num-layers', type=int,
help='number of layers in the neural network, required by some networks such as resnet')
help='number of layers in the neural network, \
required by some networks such as resnet')
train.add_argument('--gpus', type=str,
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
train.add_argument('--kv-store', type=str, default='device',
@@ -81,6 +97,8 @@ def add_fit_args(parser):
help='the ratio to reduce lr on each step')
train.add_argument('--lr-step-epochs', type=str,
help='the epochs to reduce the lr, e.g. 30,60')
train.add_argument('--initializer', type=str, default='default',
help='the initializer type')
train.add_argument('--optimizer', type=str, default='sgd',
help='the optimizer type')
train.add_argument('--mom', type=float, default=0.9,
@@ -108,8 +126,16 @@ def add_fit_args(parser):
takes `2bit` or `none` for now')
train.add_argument('--gc-threshold', type=float, default=0.5,
help='threshold for 2bit gradient compression')
# additional parameters for large batch sgd
train.add_argument('--macrobatch-size', type=int, default=0,
help='distributed effective batch size')
train.add_argument('--warmup-epochs', type=int, default=5,
help='the epochs to ramp-up lr to scaled large-batch value')
train.add_argument('--warmup-strategy', type=str, default='linear',
help='the ramping-up strategy for large batch sgd')
return train


def fit(args, network, data_loader, **kwargs):
"""
train a model
@@ -135,14 +161,13 @@ def fit(args, network, data_loader, **kwargs):
for i, batch in enumerate(train):
for j in batch.data:
j.wait_to_read()
if (i+1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
i, args.disp_batches*args.batch_size/(time.time()-tic)))
if (i + 1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
tic = time.time()

return


# load model
if 'arg_params' in kwargs and 'aux_params' in kwargs:
arg_params = kwargs['arg_params']
@@ -156,22 +181,22 @@ def fit(args, network, data_loader, **kwargs):
checkpoint = _save_model(args, kv.rank)

# devices for training
devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
mx.gpu(int(i)) for i in args.gpus.split(',')]

# learning rate
lr, lr_scheduler = _get_lr_scheduler(args, kv)

# create model
model = mx.mod.Module(
context = devs,
symbol = network
context=devs,
symbol=network
)

lr_scheduler = lr_scheduler
lr_scheduler = lr_scheduler
optimizer_params = {
'learning_rate': lr,
'wd' : args.wd,
'wd': args.wd,
'lr_scheduler': lr_scheduler,
'multi_precision': True}

@@ -180,40 +205,81 @@ def fit(args, network, data_loader, **kwargs):
if args.optimizer in has_momentum:
optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None
monitor = mx.mon.Monitor(
args.monitor, pattern=".*") if args.monitor > 0 else None

if args.network == 'alexnet':
# AlexNet will not converge using Xavier
initializer = mx.init.Normal()
else:
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# A limited number of optimizers have a warmup period
has_warmup = {'lbsgd', 'lbnag'}

Review comments on this line:

Member: It looks like LBNAG doesn't exist?

Contributor (author): @chaseadams509 can you please address the above comments. Thanks.

chaseadams509 (May 10, 2018): Correct, lbnag was an optimizer we were experimenting with, incorporating the large-batch algorithm with Nesterov accelerated gradient. LBNAG was not showing the desired improvements, so we hadn't pushed it yet.

Member: @chaseadams509 How large were the batch sizes you experimented with?

if args.optimizer in has_warmup:
if 'dist' in args.kv_store:
nworkers = kv.num_workers
else:
nworkers = 1
epoch_size = args.num_examples / args.batch_size / nworkers
if epoch_size < 1:
epoch_size = 1
macrobatch_size = args.macrobatch_size
if macrobatch_size < args.batch_size * nworkers:
macrobatch_size = args.batch_size * nworkers
#batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
batch_scale = math.ceil(
float(macrobatch_size) / args.batch_size / nworkers)
optimizer_params['updates_per_epoch'] = epoch_size
optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
optimizer_params['batch_scale'] = batch_scale
optimizer_params['warmup_strategy'] = args.warmup_strategy
optimizer_params['warmup_epochs'] = args.warmup_epochs
optimizer_params['num_epochs'] = args.num_epochs
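
To make the scaling concrete, a small worked example with hypothetical values (256 images per worker, 4 workers, an ImageNet-sized training set, --macrobatch-size 8192), mirroring the arithmetic above; none of these numbers come from the patch itself:

import math

batch_size, nworkers = 256, 4
num_examples = 1281167      # e.g. the ImageNet-1k training set
macrobatch_size = 8192

# updates per epoch per worker, floored at 1 as in the code above
epoch_size = max(1, num_examples / batch_size / nworkers)                  # ~1251
# how many worker batches make up one macrobatch
batch_scale = math.ceil(float(macrobatch_size) / batch_size / nworkers)    # ceil(8.0) = 8

With --warmup-epochs 5 and --warmup-strategy linear, the lbsgd optimizer would then ramp the learning rate up to the batch_scale-adjusted value over the first five epochs.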

if args.initializer == 'default':
if args.network == 'alexnet':
# AlexNet will not converge using Xavier
initializer = mx.init.Normal()
else:
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
elif args.initializer == 'xavier':
initializer = mx.init.Xavier()
elif args.initializer == 'msra':
initializer = mx.init.MSRAPrelu()
elif args.initializer == 'orthogonal':
initializer = mx.init.Orthogonal()
elif args.initializer == 'normal':
initializer = mx.init.Normal()
elif args.initializer == 'uniform':
initializer = mx.init.Uniform()
elif args.initializer == 'one':
initializer = mx.init.One()
elif args.initializer == 'zero':
initializer = mx.init.Zero()

# evaluation metrics
eval_metrics = ['accuracy']
if args.top_k > 0:
eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))
eval_metrics.append(mx.metric.create(
'top_k_accuracy', top_k=args.top_k))

# callbacks that run after each batch
batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
batch_end_callbacks = [mx.callback.Speedometer(
args.batch_size, args.disp_batches)]
if 'batch_end_callback' in kwargs:
cbs = kwargs['batch_end_callback']
batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

# run
model.fit(train,
begin_epoch = args.load_epoch if args.load_epoch else 0,
num_epoch = args.num_epochs,
eval_data = val,
eval_metric = eval_metrics,
kvstore = kv,
optimizer = args.optimizer,
optimizer_params = optimizer_params,
initializer = initializer,
arg_params = arg_params,
aux_params = aux_params,
batch_end_callback = batch_end_callbacks,
epoch_end_callback = checkpoint,
allow_missing = True,
monitor = monitor)
begin_epoch=args.load_epoch if args.load_epoch else 0,
num_epoch=args.num_epochs,
eval_data=val,
eval_metric=eval_metrics,
kvstore=kv,
optimizer=args.optimizer,
optimizer_params=optimizer_params,
initializer=initializer,
arg_params=arg_params,
aux_params=aux_params,
batch_end_callback=batch_end_callbacks,
epoch_end_callback=checkpoint,
allow_missing=True,
monitor=monitor)
32 changes: 32 additions & 0 deletions python/mxnet/lr_scheduler.py
@@ -136,3 +136,35 @@ def __call__(self, num_update):
else:
return self.base_lr
return self.base_lr

class PolyScheduler(LRScheduler):
    """ Reduce the learning rate according to a polynomial decay schedule.

    Calculate the new learning rate by::

       base_lr * (1 - nup/max_nup)^pwr
       if nup < max_nup, 0 otherwise.

    Parameters
    ----------
    max_update: maximum number of updates before the decay reaches 0.
    base_lr: base learning rate
    pwr: power of the decay term as a function of the current number of updates.

    """

    def __init__(self, max_update, base_lr=0.01, pwr=2):
        super(PolyScheduler, self).__init__(base_lr)
        assert isinstance(max_update, int)
        if max_update < 1:
            raise ValueError("maximum number of updates must be strictly positive")
        self.base_lr_orig = self.base_lr
        self.max_update = max_update
        self.power = pwr
        self.base_lr = self.base_lr_orig

    def __call__(self, num_update):
        if num_update <= self.max_update:
            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
                                                   self.power)
        return self.base_lr
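
A quick usage sketch (illustrative numbers only), showing the decay curve this scheduler produces once the patch is applied:

from mxnet.lr_scheduler import PolyScheduler

sched = PolyScheduler(max_update=1000, base_lr=0.1, pwr=2)
print(sched(0))       # 0.1
print(sched(500))     # 0.1 * (1 - 500/1000)**2 = 0.025
print(sched(1000))    # 0.0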