Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix pylint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ashokei committed Dec 5, 2017
1 parent 382ecde commit 20deebf
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 57 deletions.
10 changes: 6 additions & 4 deletions python/mxnet/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,24 @@ class PolyScheduler(LRScheduler):
"""

def __init__(self, num_update, max_update, base_lr=0.01, pwr = 2):
def __init__(self, num_update, max_update, base_lr=0.01, pwr=2):
super(PolyScheduler, self).__init__(base_lr)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
raise ValueError("maximum number of updates must be strictly positive")
self.base_lr_orig = self.base_lr
self.max_update = max_update
self.power = pwr
self.count = num_update
if num_update <= max_update:
self.base_lr = self.base_lr_orig*pow(1.0 - float(num_update)/float(self.max_update), self.power)
self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
self.power)
else:
self.base_lr = self.base_lr_orig

def __call__(self, num_update):
if num_update <= self.max_update:
self.base_lr = self.base_lr_orig*pow(1.0 - float(num_update)/float(self.max_update), self.power)
self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
self.power)
self.count += 1
return self.base_lr
102 changes: 49 additions & 53 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,27 +562,30 @@ class LBSGD(Optimizer):
updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
begin_epoch: unsigned, default 0, starting epoch.
"""
def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear', warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60, **kwargs):

def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
**kwargs):
super(LBSGD, self).__init__(**kwargs)
logging.info('Running Large-Batch SGD Algorithm')
logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)', batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
self.momentum = momentum
self.multi_precision = multi_precision
# new user parameters for large batch
self.warmup_strategy = warmup_strategy
self.warmup_epochs = warmup_epochs
self.batch_scale = batch_scale
self.updates_per_epoch = updates_per_epoch
self.init_updates = begin_epoch*updates_per_epoch
self.num_epochs = num_epochs
self.init_updates = begin_epoch * updates_per_epoch
self.num_epochs = num_epochs
# addl internal usage parameters and storage
self.lbmult = 1
self.cumgrads = {}
# for adaptive lr
self.adaptive = False
self.admult = 1 # adaptation constant
self.admult = 1 # adaptation constant


def create_state(self, index, weight):
momentum = None
weight_master_copy = None
Expand All @@ -607,32 +610,31 @@ def _get_lbmult(self, nup):
"""
nwup = self.warmup_epochs * self.updates_per_epoch
strategy = self.warmup_strategy
maxmult = float(self.batch_scale);
if nup >= nwup :
maxmult = float(self.batch_scale)
if nup >= nwup:
mult = maxmult
elif nwup <= 1:
mult = 1.0
else:
if (strategy == 'linear'):
mult = 1.0 + (maxmult-1) * nup/nwup
if (strategy == 'linear'):
mult = 1.0 + (maxmult - 1) * nup / nwup
elif (strategy == 'power2'):
mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
elif (strategy == 'power3'):
mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
mult = 1.0 + (maxmult - 1) * (nup * nup) / (nwup * nwup)
elif (strategy == 'sqrt'):
mult = 1.0 + (maxmult-1) * math.sqrt(float(nup)/nwup)
mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
else:
mult = 1.0
return mult

def _get_lars(self, w, g, wd):
def _get_lars(self, weight, g, wd):
"""Returns a scaling factor for the learning rate for this layer
default is 1
"""
w2=self._l2norm(w)
g2 = self._l2norm(g)
lars = math.sqrt(w2/(g2+wd*w2+1e-18))
#print "W2="+str(w2), "G2="+str(g2), "lars="+str(lars)
weight2 = self._l2norm(weight)
grad2 = self._l2norm(g)
lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
if lars < 0.01:
lars = 0.01
elif lars > 100:
Expand All @@ -643,59 +645,57 @@ def _l2norm(self, v):
"inner product implementation"
norm = multiply(v, v).asnumpy().sum()
return norm


def _reset_cum_gradient(self, index):
"called every macro-batch to reset cumulated gradients to 0 for a given index"
self.cumgrads[index]['cum_grad'] = 0

def _get_cum_gradient(self, index):
"get the cumulated gradient for index"
if index in self.cumgrads:
return self.cumgrads[index]
else:
return {}

def _put_cum_gradient(self, index, cg):
def _put_cum_gradient(self, index, cgrad):
"store cumulated gradient for index"
self.cumgrads[index] = cg
self.cumgrads[index] = cgrad

def _cumulate_gradient(self, grad, index):
"Cumulate gradients for large-batch emulation. Cumulated by index (layer)"
cg = self._get_cum_gradient(index)
if cg:
num_cums = cg['num_cums']
cgrad = self._get_cum_gradient(index)
if cgrad:
num_cums = cgrad['num_cums']
if num_cums > 0:
cum_grad = cg['cum_grad'] + grad
num_cums +=1
cum_grad = cgrad['cum_grad'] + grad
num_cums += 1
else:
cum_grad = grad
num_cums = self.init_updates+1
num_cums = self.init_updates + 1
else:
cum_grad = grad
num_cums = self.init_updates+1
cg = {'cum_grad': cum_grad, 'num_cums':num_cums}
self._put_cum_gradient(index, cg)
return cg

num_cums = self.init_updates + 1
cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
self._put_cum_gradient(index, cgrad)
return cgrad

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
assert (isinstance(weight, NDArray))
assert (isinstance(grad, NDArray))

lr = self._get_lr(index)
wd = self._get_wd(index)
self._update_count(index)

# new stuff for large batch
cg = self._cumulate_gradient(grad, index)
if (cg['num_cums'] % self.batch_scale) == 0:
grad = cg['cum_grad'] / self.batch_scale
cgrad = self._cumulate_gradient(grad, index)
if (cgrad['num_cums'] % self.batch_scale) == 0:
grad = cgrad['cum_grad'] / self.batch_scale
if self.warmup_strategy == 'lars':
lbmult = self._get_lars(weight, grad, wd)
else:
lbmult = self._get_lbmult(cg['num_cums'])
lr = lr*lbmult
lbmult = self._get_lbmult(cgrad['num_cums'])
lr = lr * lbmult
# do the regular sgd update flow
kwargs = {'rescale_grad': self.rescale_grad}
if self.momentum > 0:
Expand All @@ -706,26 +706,22 @@ def update(self, index, weight, grad, state):

if not use_multi_precision:
if state is not None:
sgd_mom_update(weight, grad, state, out=weight,
lr=lr, wd=wd, **kwargs)
sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
else:
sgd_update(weight, grad, out=weight,
lr=lr, wd=wd, **kwargs)
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
else:
if state[0] is not None:
mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
lr=lr, wd=wd, **kwargs)
mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
**kwargs)
else:
mp_sgd_update(weight, grad, state[1], out=weight,
lr=lr, wd=wd, **kwargs)
mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
# reset update count and cumulated gradient per large batch
self._reset_cum_gradient(index)
else:
lr=0.0
kwargs = {}
sgd_update(weight, grad, out=weight,
lr=lr, wd=wd, **kwargs)

lr = 0.0
kwargs = {}
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)

# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
Expand Down

0 comments on commit 20deebf

Please sign in to comment.