This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
/
optimizer.py
1578 lines (1318 loc) · 59.2 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
"""Weight updating functions."""
import logging
import math
import pickle
import warnings
import numpy
from ..base import py_str
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update)
from ..ndarray import sparse
from ..random import normal
# Public API of this module, in the order consumed by ``from ... import *``.
# Note: ``NDabs`` is the ndarray ``abs`` alias imported at the top of the file.
__all__ = [
    'AdaDelta',
    'AdaGrad',
    'Adam',
    'Adamax',
    'DCASGD',
    'FTML',
    'Ftrl',
    'LBSGD',
    'NAG',
    'NDabs',
    'Nadam',
    'Optimizer',
    'RMSProp',
    'SGD',
    'SGLD',
    'Signum',
    'Test',
    'Updater',
    'ccSGD',
    'create',
    'get_updater',
    'register',
]
class Optimizer(object):
    """The base class inherited by all optimizers.

    Parameters
    ----------
    rescale_grad : float, optional
        Multiply the gradient with `rescale_grad` before updating. Often
        chosen to be ``1.0/batch_size``.
    param_idx2name : dict from int to string, optional
        A dictionary that maps int index to string name.
    clip_gradient : float, optional
        Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
    learning_rate : float, optional
        The initial learning rate.
    lr_scheduler : LRScheduler, optional
        The learning rate scheduler. Its ``base_lr`` is overwritten with
        `learning_rate`.
    wd : float, optional
        The weight decay (or L2 regularization) coefficient. Modifies objective
        by adding a penalty for having large weights.
    sym: Symbol, optional
        The Symbol this optimizer is applying to. The ``__lr_mult__`` and
        ``__wd_mult__`` attributes of its arguments seed the per-parameter
        multipliers.
    begin_num_update : int, optional
        The initial number of updates.
    multi_precision : bool, optional
        Flag to control the internal precision of the optimizer.::

            False: results in using the same precision as the weights (default),
            True: makes internal 32-bit copy of the weights and applies gradients
            in 32-bit precision even if actual weights used in the model have lower precision.
            Turning this on can improve convergence and accuracy when training with float16.
    param_dict : dict, optional
        Maps a parameter index to an object exposing ``lr_mult`` and
        ``wd_mult`` attributes; when an index is present here, those
        attributes take precedence over `set_lr_mult` / `set_wd_mult`.
        It is not pickled with the optimizer.

    Properties
    ----------
    learning_rate : float
        The current learning rate of the optimizer. Given an Optimizer object
        optimizer, its learning rate can be accessed as optimizer.learning_rate.
    """
    def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
                 clip_gradient=None, learning_rate=0.01,
                 lr_scheduler=None, sym=None, begin_num_update=0,
                 multi_precision=False, param_dict=None):
        self.rescale_grad = rescale_grad
        self.lr = learning_rate
        self.lr_scheduler = lr_scheduler
        if lr_scheduler is not None:
            # Keep the scheduler's base rate in sync with this optimizer.
            self.lr_scheduler.base_lr = learning_rate

        self.wd = wd
        self.lr_mult = {}
        self.wd_mult = {}
        self.begin_num_update = begin_num_update
        self.num_update = begin_num_update
        # Per-index update counters; num_update tracks their running maximum.
        self._index_update_count = {}
        self.clip_gradient = clip_gradient
        self.multi_precision = multi_precision

        if param_idx2name is None:
            param_idx2name = {}
        assert isinstance(param_idx2name, dict), \
            'param_idx2name should be a dict of param indexes to names.'
        self.idx2name = param_idx2name.copy()
        # (argument attribute dict, argument names) of `sym`, or () without a symbol.
        self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else ()
        self.param_dict = param_dict if param_dict else {}

        self.set_lr_mult({})
        self.set_wd_mult({})

    # Registry of optimizer classes keyed by lower-cased class name.
    opt_registry = {}

    @staticmethod
    def register(klass):
        """Registers a new optimizer.

        Once an optimizer is registered, we can create an instance of this
        optimizer with `create_optimizer` later.

        Examples
        --------

        >>> @mx.optimizer.Optimizer.register
        ... class MyOptimizer(mx.optimizer.Optimizer):
        ...     pass
        >>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
        >>> print(type(optim))
        <class '__main__.MyOptimizer'>
        """
        assert(isinstance(klass, type))
        name = klass.__name__.lower()
        if name in Optimizer.opt_registry:
            warnings.warn('WARNING: New optimizer %s.%s is overriding '
                          'existing optimizer %s.%s' %
                          (klass.__module__, klass.__name__,
                           Optimizer.opt_registry[name].__module__,
                           Optimizer.opt_registry[name].__name__))
        Optimizer.opt_registry[name] = klass
        return klass

    @staticmethod
    def create_optimizer(name, **kwargs):
        """Instantiates an optimizer with a given name and kwargs.

        .. note:: We can use the alias `create` for ``Optimizer.create_optimizer``.

        Parameters
        ----------
        name: str
            Name of the optimizer. Should be the name
            of a subclass of Optimizer. Case insensitive.

        kwargs: dict
            Parameters for the optimizer.

        Returns
        -------
        Optimizer
            An instantiated optimizer.

        Raises
        ------
        ValueError
            If no optimizer is registered under `name`.

        Examples
        --------
        >>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
        >>> type(sgd)
        <class 'mxnet.optimizer.SGD'>
        >>> adam = mx.optimizer.create('adam', learning_rate=.1)
        >>> type(adam)
        <class 'mxnet.optimizer.Adam'>
        """
        if name.lower() in Optimizer.opt_registry:
            return Optimizer.opt_registry[name.lower()](**kwargs)
        else:
            raise ValueError('Cannot find optimizer %s' % name)

    @property
    def learning_rate(self):
        """Current learning rate: ``lr_scheduler(num_update)`` if a scheduler
        is set, otherwise the fixed ``lr``."""
        if self.lr_scheduler is not None:
            return self.lr_scheduler(self.num_update)
        else:
            return self.lr

    def create_state(self, index, weight):
        """Creates auxiliary state for a given weight.

        Some optimizers require additional states, e.g. as momentum, in addition
        to gradients in order to update weights. This function creates state
        for a given weight which will be used in `update`. This function is
        called only once for each weight. The base implementation creates no
        state and returns None.

        Parameters
        ----------
        index : int
            A unique index to identify the weight.
        weight : NDArray
            The weight.

        Returns
        -------
        state : any obj
            The state associated with the weight.
        """

    def create_state_multi_precision(self, index, weight):
        """Creates auxiliary state for a given weight, including FP32 high
        precision copy if original weight is FP16.

        This method is provided to perform automatic mixed precision training
        for optimizers that do not support it themselves.

        Parameters
        ----------
        index : int
            A unique index to identify the weight.
        weight : NDArray
            The weight.

        Returns
        -------
        state : any obj
            The state associated with the weight. With `multi_precision` on
            and a float16 weight this is the pair
            ``(fp32 master copy, create_state(index, master copy))``, which is
            the layout `update_multi_precision` expects.
        """
        weight_master_copy = None
        if self.multi_precision and weight.dtype == numpy.float16:
            weight_master_copy = weight.astype(numpy.float32)
            return (weight_master_copy,) + (self.create_state(index, weight_master_copy),)
        if weight.dtype == numpy.float16 and not self.multi_precision:
            warnings.warn("Accumulating with float16 in optimizer can lead to "
                          "poor accuracy or slow convergence. "
                          "Consider using multi_precision=True option of the "
                          "optimizer")
        return self.create_state(index, weight)

    def update(self, index, weight, grad, state):
        """Updates the given parameter using the corresponding gradient and state.

        Parameters
        ----------
        index : int
            The unique index of the parameter into the individual learning
            rates and weight decays. Learning rates and weight decay
            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
        weight : NDArray
            The parameter to be updated.
        grad : NDArray
            The gradient of the objective with respect to this parameter.
        state : any obj
            The state returned by `create_state()`.
        """
        raise NotImplementedError()

    def update_multi_precision(self, index, weight, grad, state):
        """Updates the given parameter using the corresponding gradient and state.
        Mixed precision version.

        For float16 weights with `multi_precision` on, the update is applied to
        the fp32 master copy (``state[0]``) with an fp32 gradient, and the result
        is cast back into `weight`.

        Parameters
        ----------
        index : int
            The unique index of the parameter into the individual learning
            rates and weight decays. Learning rates and weight decay
            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
        weight : NDArray
            The parameter to be updated.
        grad : NDArray
            The gradient of the objective with respect to this parameter.
        state : any obj
            The state returned by `create_state_multi_precision()`.
        """
        if self.multi_precision and weight.dtype == numpy.float16:
            # Wrapper for mixed precision
            weight_master_copy = state[0]
            original_state = state[1]
            grad32 = grad.astype(numpy.float32)
            self.update(index, weight_master_copy, grad32, original_state)
            cast(weight_master_copy, dtype=weight.dtype, out=weight)
        else:
            self.update(index, weight, grad, state)

    def set_learning_rate(self, lr):
        """Sets a new learning rate of the optimizer.

        Parameters
        ----------
        lr : float
            The new learning rate of the optimizer.

        Raises
        ------
        UserWarning
            If an LRScheduler is defined (the scheduler owns the rate).
        """
        if self.lr_scheduler is not None:
            raise UserWarning("LRScheduler of the optimizer has already been "
                              "defined. Note that set_learning_rate can mutate "
                              "the value of the learning rate of the optimizer "
                              "only when the LRScheduler of the optimizer is "
                              "undefined.")
        else:
            self.lr = lr

    def set_lr_scale(self, args_lrscale): # pylint: disable=unused-argument
        """[DEPRECATED] Sets lr scale. Use set_lr_mult instead."""
        raise DeprecationWarning

    def set_lr_mult(self, args_lr_mult):
        """Sets an individual learning rate multiplier for each parameter.

        If you specify a learning rate multiplier for a parameter, then
        the learning rate for the parameter will be set as the product of
        the global learning rate `self.lr` and its multiplier.

        .. note:: The default learning rate multiplier of a `Variable`
            can be set with `lr_mult` argument in the constructor.

        Parameters
        ----------
        args_lr_mult : dict of str/int to float
            For each of its key-value entries, the learning rate multiplier for the
            parameter specified in the key will be set as the given value.

            You can specify the parameter with either its name or its index.
            If you use the name, you should pass `sym` in the constructor,
            and the name you specified in the key of `args_lr_mult` should match
            the name of the parameter in `sym`. If you use the index, it should
            correspond to the index of the parameter used in the `update` method.

            Specifying a parameter by its index is only supported for backward
            compatibility, and we recommend to use the name instead.
        """
        self.lr_mult = {}
        if self.sym_info:
            attr, arg_names = self.sym_info
            for name in arg_names:
                if name in attr and '__lr_mult__' in attr[name]:
                    self.lr_mult[name] = float(attr[name]['__lr_mult__'])
        # Explicit arguments override the symbol's attributes.
        self.lr_mult.update(args_lr_mult)

    def set_wd_mult(self, args_wd_mult):
        """Sets an individual weight decay multiplier for each parameter.

        By default, if `param_idx2name` was provided in the
        constructor, the weight decay multiplier is set as 0 for all
        parameters whose name don't end with ``_weight`` or
        ``_gamma``.

        .. note:: The default weight decay multiplier for a `Variable`
            can be set with its `wd_mult` argument in the constructor.

        Parameters
        ----------
        args_wd_mult : dict of string/int to float
            For each of its key-value entries, the weight decay multiplier for the
            parameter specified in the key will be set as the given value.

            You can specify the parameter with either its name or its index.
            If you use the name, you should pass `sym` in the constructor,
            and the name you specified in the key of `args_wd_mult` should match
            the name of the parameter in `sym`. If you use the index, it should
            correspond to the index of the parameter used in the `update` method.

            Specifying a parameter by its index is only supported for backward
            compatibility, and we recommend to use the name instead.
        """
        self.wd_mult = {}
        for n in self.idx2name.values():
            # Biases and other non-weight parameters get no weight decay by default.
            if not n.endswith(('_weight', '_gamma')):
                self.wd_mult[n] = 0.0
        if self.sym_info:
            attr, arg_names = self.sym_info
            for name in arg_names:
                if name in attr and '__wd_mult__' in attr[name]:
                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
        # Explicit arguments override both defaults and symbol attributes.
        self.wd_mult.update(args_wd_mult)

    def _update_count(self, index):
        """Updates num_update.

        Parameters
        ----------
        index : int
            The index to be updated.
        """
        if index not in self._index_update_count:
            self._index_update_count[index] = self.begin_num_update
        self._index_update_count[index] += 1
        self.num_update = max(self._index_update_count[index], self.num_update)

    def _get_lr(self, index):
        """Gets the learning rate given the index of the weight.

        Multiplier lookup order: ``param_dict[index].lr_mult``, then
        ``lr_mult[index]``, then ``lr_mult[idx2name[index]]`` (default 1.0).

        Parameters
        ----------
        index : int
            The index corresponding to the weight.

        Returns
        -------
        lr : float
            Learning rate for this index.
        """
        if self.lr_scheduler is not None:
            lr = self.lr_scheduler(self.num_update)
        else:
            lr = self.lr

        if index in self.param_dict:
            lr *= self.param_dict[index].lr_mult
        elif index in self.lr_mult:
            lr *= self.lr_mult[index]
        elif index in self.idx2name:
            lr *= self.lr_mult.get(self.idx2name[index], 1.0)
        return lr

    def _get_wd(self, index):
        """Gets weight decay for index.
        Returns 0 for non-weights if the name of weights are provided for `__init__`.

        Multiplier lookup order: ``param_dict[index].wd_mult``, then
        ``wd_mult[index]``, then ``wd_mult[idx2name[index]]`` (default 1.0).

        Parameters
        ----------
        index : int
            The index for weight.

        Returns
        -------
        wd : float
            Weight decay for this index.
        """
        wd = self.wd
        if index in self.param_dict:
            wd *= self.param_dict[index].wd_mult
        elif index in self.wd_mult:
            wd *= self.wd_mult[index]
        elif index in self.idx2name:
            wd *= self.wd_mult.get(self.idx2name[index], 1.0)
        return wd

    def __getstate__(self):
        """Pickle support: copy of the instance dict without `param_dict`."""
        ret = self.__dict__.copy()
        # do not include param_dict in the state
        ret.pop('param_dict', None)
        return ret

    def __setstate__(self, state):
        """Unpickle support: restore attributes and start with an empty `param_dict`."""
        self.__dict__ = state
        # param_dict needs to be explicitly set by the trainer
        self.param_dict = {}
# Convenience alias for Optimizer.register, used below as a class decorator
# (``@register``) to add each optimizer to Optimizer.opt_registry.
register = Optimizer.register # pylint: disable=invalid-name
# pylint: disable=line-too-long
@register
class SGD(Optimizer):
    """The SGD optimizer with momentum and weight decay.

    When the gradient's storage type is ``row_sparse`` and ``lazy_update`` is
    True, **lazy updates** are applied::

        for row in grad.indices:
            rescaled_grad[row] = lr * (rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row])
            state[row] = momentum[row] * state[row] + rescaled_grad[row]
            weight[row] = weight[row] - state[row]

    Only the rows whose indices appear in the current sparse gradient get their
    momentum updated, instead of every row. Compared with the standard update
    this can noticeably improve training throughput for some applications, but
    its semantics differ slightly and may lead to different empirical results.

    Otherwise, **standard updates** are applied::

        rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight)
        state = momentum * state + rescaled_grad
        weight = weight - state

    For details of the update algorithm see
    :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value. No momentum state is allocated when it is 0.
    lazy_update : bool, optional
        Default is True. If True, lazy updates are applied
        if the storage types of weight and grad are both ``row_sparse``.
    multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.::

            False: results in using the same precision as the weights (default),
            True: makes internal 32-bit copy of the weights and applies gradients
            in 32-bit precision even if actual weights used in the model have lower precision.
            Turning this on can improve convergence and accuracy when training with float16.
    """
    def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
        super(SGD, self).__init__(**kwargs)
        self.momentum = momentum
        self.lazy_update = lazy_update

    def create_state_multi_precision(self, index, weight):
        """Create state; for float16 weights with multi_precision on, return
        the pair ``(momentum or None, fp32 master copy)``."""
        is_half = weight.dtype == numpy.float16
        if is_half and self.multi_precision:
            master = weight.astype(numpy.float32)
            return (self.create_state(index, master), master)
        if is_half:
            # Reaching here means multi_precision is off.
            warnings.warn("Accumulating with float16 in optimizer can lead to "
                          "poor accuracy or slow convergence. "
                          "Consider using multi_precision=True option of the "
                          "SGD optimizer")
        return self.create_state(index, weight)

    def create_state(self, index, weight):
        """Return a zero momentum buffer shaped like `weight`, or None without momentum."""
        if self.momentum == 0.0:
            return None
        # Sparse momentum only makes sense for lazy updates.
        storage = weight.stype if self.lazy_update else 'default'
        return zeros(weight.shape, weight.context, dtype=weight.dtype, stype=storage)

    def _update_impl(self, index, weight, grad, state, multi_precision=False):
        """Dispatch to the fused SGD kernel matching the state layout."""
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        extra = dict(rescale_grad=self.rescale_grad)
        if self.momentum > 0:
            extra['momentum'] = self.momentum
        if self.clip_gradient:
            extra['clip_gradient'] = self.clip_gradient

        if multi_precision:
            # state == (momentum or None, fp32 master copy)
            if state[0] is None:
                mp_sgd_update(weight, grad, state[1], out=weight,
                              lr=lr, wd=wd, **extra)
            else:
                mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
                                  lr=lr, wd=wd, **extra)
        elif state is None:
            sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                       lr=lr, wd=wd, **extra)
        else:
            sgd_mom_update(weight, grad, state, out=weight,
                           lazy_update=self.lazy_update, lr=lr, wd=wd, **extra)

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state, multi_precision=False)

    def update_multi_precision(self, index, weight, grad, state):
        fp16_master = self.multi_precision and weight.dtype == numpy.float16
        self._update_impl(index, weight, grad, state, multi_precision=fp16_master)
@register
class Signum(Optimizer):
    r"""The Signum optimizer that takes the sign of gradient or momentum.

    The optimizer updates the weight by::

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1-momentum)*rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)

    Without momentum (``momentum == 0``) no state is kept and the sign of the
    rescaled gradient is used directly (``signsgd_update``).

    References
    ----------
    Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018).
    signSGD: Compressed Optimisation for Non-Convex Problems. In ICML'18.

    See: https://arxiv.org/abs/1802.04434

    For details of the update algorithm see
    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.
    wd_lh : float, optional
        The amount of decoupled weight decay regularization, see details in the original paper at:\
        https://arxiv.org/abs/1711.05101
    """
    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum
        self.wd_lh = wd_lh

    def create_state(self, index, weight):
        """Return a zero momentum buffer matching `weight`, or None without momentum."""
        if self.momentum == 0.0:
            return None
        return zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)

    def _update_impl(self, index, weight, grad, state):
        """Call signum_update (with state) or signsgd_update (without)."""
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        # Only forward optional hyper-parameters that are actually enabled.
        extra = dict(rescale_grad=self.rescale_grad)
        if self.momentum > 0:
            extra['momentum'] = self.momentum
        if self.clip_gradient:
            extra['clip_gradient'] = self.clip_gradient
        if self.wd_lh:
            extra['wd_lh'] = self.wd_lh

        if state is None:
            signsgd_update(weight, grad, out=weight,
                           lr=lr, wd=wd, **extra)
        else:
            signum_update(weight, grad, state, out=weight,
                          lr=lr, wd=wd, **extra)

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state)
@register
class FTML(Optimizer):
    """The FTML optimizer.

    This class implements the optimizer described in
    *FTML - Follow the Moving Leader in Deep Learning*,
    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

    Denote time step by t. The optimizer updates the weight by::

        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
        v = beta2 * v + (1 - beta2) * square(rescaled_grad)
        d_t = (1 - power(beta1, t)) / lr * (square_root(v / (1 - power(beta2, t))) + epsilon)
        z = beta1 * z + (1 - beta1) * rescaled_grad - (d_t - beta1 * d_(t-1)) * weight
        weight = - z / d_t

    Here t is the per-parameter update count.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    beta1 : float, optional
        0 < beta1 < 1. Generally close to 0.5.
    beta2 : float, optional
        0 < beta2 < 1. Generally close to 1.
    epsilon : float, optional
        Small value to avoid division by 0.
    """
    def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs):
        super(FTML, self).__init__(**kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

    def create_state(self, index, weight):
        """Return the zero-initialized ``(d, v, z)`` buffers for `weight`."""
        return (zeros(weight.shape, weight.context, dtype=weight.dtype), # d_0
                zeros(weight.shape, weight.context, dtype=weight.dtype), # v_0
                zeros(weight.shape, weight.context, dtype=weight.dtype)) # z_0

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        # Time step for the bias corrections: this parameter's own update count.
        t = self._index_update_count[index]

        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                  'rescale_grad': self.rescale_grad, 't': t}
        if self.clip_gradient:
            # ftml_update names this argument `clip_grad`, not `clip_gradient`.
            kwargs['clip_grad'] = self.clip_gradient

        prev_d, prev_v, prev_z = state
        ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                    lr=lr, wd=wd, **kwargs)
@register
class LBSGD(Optimizer):
    """The Large Batch SGD optimizer with momentum and weight decay.

    Gradients are accumulated per parameter index. A weight update happens
    only when the running count of accumulated gradients for that index is a
    multiple of ``batch_scale``. It uses the averaged accumulated gradient and
    a learning rate scaled by a warmup multiplier (or by a LARS factor when
    ``warmup_strategy='lars'``). Other calls only accumulate.

    The optimizer updates the weight by::

        state = momentum * state + lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight)
        weight = weight - state

    For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and
    :class:`~mxnet.ndarray.sgd_mom_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.
    multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.::

            False: results in using the same precision as the weights (default),
            True: makes internal 32-bit copy of the weights and applies gradients
            in 32-bit precision even if actual weights used in the model have lower precision.
            Turning this on can improve convergence and accuracy when training with float16.
    warmup_strategy: str, optional
        Learning-rate warmup schedule: 'linear' (default), 'power2', 'sqrt' or
        'lars'. Any other value gives a multiplier of 1.
    warmup_epochs: int, optional
        Number of warmup epochs. Default 5.
    batch_scale: int, optional
        Number of gradients accumulated per weight update; also the
        learning-rate multiplier reached after warmup. Default 1.
    updates_per_epoch: int, optional
        Updates per epoch, used to convert epochs into update counts for
        warmup. Default 32, which might not reflect the true number of
        batches per epoch.
    begin_epoch: int, optional
        Starting epoch. Default 0.
    num_epochs: int, optional
        Stored as ``num_epochs``; not otherwise used by this class. Default 60.
    """
    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
                 **kwargs):
        super(LBSGD, self).__init__(**kwargs)
        logging.info('Running Large-Batch SGD Algorithm')
        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
        self.momentum = momentum
        self.multi_precision = multi_precision
        # new user parameters for large batch
        self.warmup_strategy = warmup_strategy
        self.warmup_epochs = warmup_epochs
        self.batch_scale = batch_scale
        self.updates_per_epoch = updates_per_epoch
        # Seeds the accumulation counter so warmup resumes at `begin_epoch`.
        self.init_updates = begin_epoch * updates_per_epoch
        self.num_epochs = num_epochs
        # addl internal usage parameters and storage
        self.lbmult = 1
        # index -> {'cum_grad': accumulated gradient, 'num_cums': running count}
        self.cumgrads = {}
        # for adaptive lr
        self.adaptive = False
        self.admult = 1  # adaptation constant

    def create_state(self, index, weight):
        """Create momentum state; returns ``(momentum, fp32 master copy)`` for
        float16 weights when multi_precision is on, else momentum or None."""
        momentum = None
        weight_master_copy = None
        if self.multi_precision and weight.dtype == numpy.float16:
            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
            if self.momentum != 0.0:
                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
                                 stype=weight.stype)
            return (momentum, weight_master_copy)
        if weight.dtype == numpy.float16 and not self.multi_precision:
            warnings.warn("Accumulating with float16 in optimizer can lead to "
                          "poor accuracy or slow convergence. "
                          "Consider using multi_precision=True option of the "
                          "SGD optimizer")
        if self.momentum != 0.0:
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
        return momentum

    def _get_lbmult(self, nup):
        """Returns the large-batch learning-rate multiplier for update count `nup`.

        Ramps from 1 to ``batch_scale`` over ``warmup_epochs * updates_per_epoch``
        updates according to `warmup_strategy`, then stays at ``batch_scale``.
        """
        nwup = self.warmup_epochs * self.updates_per_epoch
        strategy = self.warmup_strategy
        maxmult = float(self.batch_scale)
        if nup >= nwup:
            mult = maxmult
        elif nwup <= 1:
            mult = 1.0
        else:
            if (strategy == 'linear'):
                mult = 1.0 + (maxmult - 1) * nup / nwup
            elif (strategy == 'power2'):
                mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
            elif (strategy == 'sqrt'):
                mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
            else:
                mult = 1.0
        return mult

    def _get_lars(self, weight, g, wd):
        """Returns a LARS learning-rate scaling factor for this layer.

        Computed as ``sqrt(|w|^2 / (|g|^2 + wd * |w|^2 + 1e-18))`` from squared
        norms, clamped to ``[0.01, 100]``.
        """
        weight2 = self._l2norm(weight)
        grad2 = self._l2norm(g)
        lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
        if lars < 0.01:
            lars = 0.01
        elif lars > 100:
            lars = 100
        return lars

    def _l2norm(self, v):
        """Return the squared L2 norm (sum of squares) of `v`, computed on the host."""
        norm = multiply(v, v).asnumpy().sum()
        return norm

    def _reset_cum_gradient(self, index):
        "called every macro-batch to reset cumulated gradients to 0 for a given index"
        self.cumgrads[index]['cum_grad'] = 0

    def _get_cum_gradient(self, index):
        "get the cumulated gradient record for index ({} if none yet)"
        return self.cumgrads.get(index, {})

    def _put_cum_gradient(self, index, cgrad):
        "store cumulated gradient for index"
        self.cumgrads[index] = cgrad

    def _cumulate_gradient(self, grad, index):
        """Cumulate gradients for large-batch emulation. Cumulated by index (layer).

        Returns the updated record ``{'cum_grad': ..., 'num_cums': ...}``.
        """
        cgrad = self._get_cum_gradient(index)
        if cgrad and cgrad['num_cums'] > 0:
            cum_grad = cgrad['cum_grad'] + grad
            num_cums = cgrad['num_cums'] + 1
        else:
            # Copy rather than alias: the caller may reuse `grad`'s buffer for
            # later gradients (presumably the executor's gradient array; verify),
            # which would silently overwrite this first accumulated value.
            cum_grad = grad.copy()
            num_cums = self.init_updates + 1
        cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
        self._put_cum_gradient(index, cgrad)
        return cgrad

    def _sgd_step(self, weight, grad, state, lr, wd):
        """Apply one SGD / momentum / multi-precision step to `weight` in place."""
        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        # A list/tuple state is the (momentum, fp32 master copy) pair from create_state.
        if isinstance(state, (list, tuple)):
            if state[0] is not None:
                mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
                                  **kwargs)
            else:
                mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
        elif state is not None:
            sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
        else:
            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)

    def update(self, index, weight, grad, state):
        assert (isinstance(weight, NDArray))
        assert (isinstance(grad, NDArray))
        # NOTE(review): lr/wd are read before _update_count (SGD does the
        # reverse), so a scheduler sees the pre-increment num_update here.
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        self._update_count(index)

        # new stuff for large batch
        cgrad = self._cumulate_gradient(grad, index)
        if (cgrad['num_cums'] % self.batch_scale) != 0:
            # Mid macro-batch: only accumulate and leave the weight untouched.
            # An lr=0 SGD step would be a no-op for finite gradients, but
            # non-finite gradients would still corrupt the weight (0 * inf = NaN).
            return

        grad = cgrad['cum_grad'] / self.batch_scale
        if self.warmup_strategy == 'lars':
            lbmult = self._get_lars(weight, grad, wd)
        else:
            lbmult = self._get_lbmult(cgrad['num_cums'])
        # do the regular sgd update flow
        self._sgd_step(weight, grad, state, lr * lbmult, wd)
        # reset cumulated gradient per large batch (num_cums keeps counting)
        self._reset_cum_gradient(index)
# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
    """The DCASGD optimizer.

    This class implements the optimizer described in *Asynchronous Stochastic Gradient Descent
    with Delay Compensation for Distributed Deep Learning*,
    available at https://arxiv.org/abs/1609.08326.

    The optimizer updates the weight by::

        grad = clip(grad * rescale_grad, clip_gradient)
        delta = -lr * (grad + wd * weight + lamda * grad * grad * (weight - previous_weight))
        state = momentum * state + delta        (state = delta when momentum is 0)
        previous_weight = weight
        weight = weight + state

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.

    lamda : float, optional
        Scale DC value.
    """
    def __init__(self, momentum=0.0, lamda=0.04, **kwargs):
        super(DCASGD, self).__init__(**kwargs)
        self.momentum = momentum
        # Not used by this class's methods; previous weights live in the per-index state.
        self.weight_previous = {}
        self.lamda = lamda

    def create_state(self, index, weight):
        """Return ``(momentum buffer or None, copy of weight as previous weight)``."""
        if self.momentum == 0.0:
            return (None,
                    weight.copy())  # previous weight
        else:
            return (zeros(weight.shape, weight.context, dtype=weight.dtype), # momentum
                    weight.copy())  # previous weight

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)

        mom, previous_weight = state
        # Compare against None explicitly: the truth value of a multi-element
        # NDArray is ambiguous, and a zero scalar would wrongly read as "no state".
        if mom is not None:
            mom[:] *= self.momentum
            mom[:] += -lr * (grad + wd * weight + self.lamda \
                             * grad * grad * (weight - previous_weight))
        else:
            assert(self.momentum == 0.0)
            mom = -lr * (grad + wd * weight + self.lamda \
                         * grad * grad * (weight - previous_weight))
        previous_weight[:] = weight
        weight[:] += mom
@register
class NAG(Optimizer):
    """Nesterov accelerated SGD.

    This optimizer updates each weight by::

        grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
        state = momentum * state + grad
        weight = weight - (lr * (grad + momentum * state))

    Without momentum (no state) this reduces to ``weight = weight - lr * grad``.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.
    multi_precision: bool, optional
        Flag to control the internal precision of the optimizer.::

            False: results in using the same precision as the weights (default),
            True: makes internal 32-bit copy of the weights and applies gradients
            in 32-bit precision even if actual weights used in the model have lower precision.
            Turning this on can improve convergence and accuracy when training with float16.
    """
    def __init__(self, momentum=0.0, **kwargs):
        super(NAG, self).__init__(**kwargs)
        self.momentum = momentum

    def create_state(self, index, weight):
        """Return a zero momentum buffer shaped like `weight`, or None without momentum."""
        momentum = None
        if self.momentum != 0.0:
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
        return momentum

    def update(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        # `grad * rescale_grad` yields a new array, so the in-place ops below
        # do not modify the caller's gradient.
        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)

        if state is not None:
            mom = state
            mom[:] *= self.momentum
            grad += wd * weight
            mom[:] += grad
            # Nesterov look-ahead: step along grad plus the decayed new momentum.
            grad[:] += self.momentum * mom
            weight[:] += -lr * grad
        else:
            assert self.momentum == 0.0
            weight[:] += -lr * (grad + wd * weight)
@register
class SGLD(Optimizer):
"""Stochastic Gradient Riemannian Langevin Dynamics.
This class implements the optimizer described in the paper *Stochastic Gradient
Riemannian Langevin Dynamics on the Probability Simplex*, available at
https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.
"""
def __init__(self, **kwargs):
super(SGLD, self).__init__(**kwargs)
def create_state(self, index, weight):
return None
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)