keep old fused* name and rename new optimizers without prefix
FDecaYed committed Aug 12, 2019
1 parent 4d6ed50 commit adad599
Showing 11 changed files with 35 additions and 30 deletions.
apex/optimizers/__init__.py (4 additions, 4 deletions)
@@ -1,6 +1,6 @@
 from .fused_sgd import FusedSGD
-from .novograd import FusedNovoGrad
-from .fused_adam_v1 import FusedAdam_v1
-from .adam import FusedAdam
-#from .sgd import FusedSGD
+from .fused_adam import FusedAdam
 from .fp16_optimizer import FP16_Optimizer
+from .sgd import SGD
+from .adam import Adam
+from .novograd import NovoGrad
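For downstream code, the practical effect of this change is the import surface: the legacy fused-kernel classes keep their Fused* names, while the new multi-tensor optimizers drop the prefix. A minimal sketch of both spellings after this commit (the model is a hypothetical stand-in, and both paths assume apex was built with --cuda_ext --cpp_ext on a CUDA machine):

import torch
import apex

# hypothetical model, just to have parameters to optimize
model = torch.nn.Linear(16, 4).cuda()

# legacy path: the old fused optimizer keeps its Fused* name
legacy_opt = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)

# new path: the multi-tensor optimizers are exposed without the prefix
new_opt = apex.optimizers.Adam(model.parameters(), lr=1e-3)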
apex/optimizers/adam.py (4 additions, 4 deletions)
@@ -2,7 +2,7 @@
 from apex.multi_tensor_apply import multi_tensor_applier
 from amp_C import multi_tensor_adam
 
-class FusedAdam(torch.optim.Optimizer):
+class Adam(torch.optim.Optimizer):
 
     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -45,7 +45,7 @@ def __init__(self, params, lr=1e-3, bias_correction = True,
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay)
-        super(FusedAdam, self).__init__(params, defaults)
+        super(Adam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self.dummy_overflow_buf = torch.cuda.IntTensor([0])
 
@@ -57,8 +57,8 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
         and returns the loss.
         """
         if any(p is not None for p in [grads, output_params, scale, grad_norms]):
-            raise RuntimeError('FusedAdam has been updated, please use with AMP for mixed precision. '
-                               'For legacy code using fp16_optimizer, use FusedAdam_v1.')
+            raise RuntimeError('Adam has been updated, please use with AMP for mixed precision. '
+                               'For legacy code using fp16_optimizer, use FusedAdam.')
         loss = None
         if closure is not None:
             loss = closure()
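The new guard in step() redirects mixed-precision users from the old grads/output_params/scale keywords to AMP. A minimal sketch of that intended flow, assuming the apex.amp initialize/scale_loss API and a hypothetical model, data, and opt_level choice (none of which are part of this diff):

import torch
import apex
from apex import amp

model = torch.nn.Linear(16, 4).cuda()                      # hypothetical model
optimizer = apex.optimizers.Adam(model.parameters(), lr=1e-3)

# AMP patches the model/optimizer pair for mixed precision
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 16, device="cuda")
loss = model(x).pow(2).mean()                              # toy objective

optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                                 # AMP handles loss scaling
optimizer.step()                                           # plain step(); no grads/scale kwargs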
apex/optimizers/fp16_optimizer.py (2 additions, 1 deletion)
@@ -35,7 +35,8 @@ def __init__(self,
                  dynamic_loss_args=None,
                  verbose=True):
 
-        print("\nfp16_optimizer will be removed in future. To update, use fused optimizers with AMP.")
+        print("\nfp16_optimizer is designed to work with apex.optimizers.Fused*, and will be removed in future")
+        print("To update, use updated optimizers without Fused prefix with AMP.")
         # The fused optimizer does all the work. We need this layer for two reason:
         # 1. maintain same user API from apex.fp16_utils
         # 2. keep common stuff here in case we need to add new fused optimizer later
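The new prints spell out the split: this FP16_Optimizer shim stays tied to the legacy Fused* classes, while new code should move to AMP. A minimal sketch of the legacy pairing, mirroring the updated tests later in this commit and assuming the wrapper keeps the apex.fp16_utils-style backward(loss)/zero_grad() API its comments describe (the model, loss scale, and toy loss are illustrative):

import torch
import apex

model = torch.nn.Linear(16, 4).cuda().half()               # hypothetical fp16 model

# legacy path: fused-kernel optimizer wrapped by the thin FP16_Optimizer shim
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)
optimizer = apex.optimizers.FP16_Optimizer(optimizer, static_loss_scale=128.0)

x = torch.randn(8, 16, device="cuda").half()
loss = model(x).float().pow(2).mean()                      # toy objective

optimizer.zero_grad()
optimizer.backward(loss)                                   # wrapper applies the loss scale
optimizer.step()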
apex/optimizers/fused_adam_v1.py renamed to apex/optimizers/fused_adam.py
@@ -2,9 +2,9 @@
 import torch
 import importlib
 
-from ..multi_tensor_apply import multi_tensor_applier
+from apex.multi_tensor_apply import multi_tensor_applier
 
-class FusedAdam_v1(torch.optim.Optimizer):
+class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -40,6 +40,8 @@ def __init__(self, params,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                  weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                  amp_scale_adjustment=1.0):
+        print("\nFusedAdam will be removed in future. To update, use apex.optimizers.Adam with AMP.")
+
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
 
@@ -58,7 +60,7 @@ def __init__(self, params,
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
-        super(FusedAdam_v1, self).__init__(params, defaults)
+        super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
 
     def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
@@ -195,4 +197,3 @@ def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norm
                          group['weight_decay'])
 
         return loss
-
apex/optimizers/fused_sgd.py (3 additions, 0 deletions)
@@ -53,6 +53,9 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
                  weight_decay=0, nesterov=False,
                  wd_after_momentum=False,
                  materialize_master_grads=True):
+
+        print("\nFusedSGD will be removed in future. To update, use apex.optimizers.SGD with AMP.")
+
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
apex/optimizers/novograd.py (6 additions, 6 deletions)
@@ -2,7 +2,7 @@
 from apex.multi_tensor_apply import multi_tensor_applier
 from amp_C import multi_tensor_novograd
 
-class FusedNovoGrad(torch.optim.Optimizer):
+class NovoGrad(torch.optim.Optimizer):
 
     """Implements NovoGrad algorithm. Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
@@ -48,12 +48,12 @@ def __init__(self, params, lr=1e-3, bias_correction=True,
                  grad_averaging=True, norm_type=2, init_zero=False,
                  set_grad_none=True):
         if amsgrad:
-            raise RuntimeError('FusedNovoGrad does not support the AMSGrad variant.')
+            raise RuntimeError('NovoGrad does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging, norm_type=norm_type,
                         init_zero=init_zero)
-        super(FusedNovoGrad, self).__init__(params, defaults)
+        super(NovoGrad, self).__init__(params, defaults)
         self.moment_mode = 0 if reg_inside_moment else 1
         self.dummy_overflow_buf = torch.cuda.IntTensor([0])
         self.set_grad_none = set_grad_none
@@ -64,7 +64,7 @@ def zero_grad(self):
                 for p in group['params']:
                     p.grad = None
         else:
-            super(FusedNovoGrad, self).zero_grad()
+            super(NovoGrad, self).zero_grad()
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -96,7 +96,7 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 if p.grad.data.is_sparse:
-                    raise RuntimeError('FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead')
+                    raise RuntimeError('NovoGrad does not support sparse gradients, please consider SparseAdam instead')
 
                 state = self.state[p]
                 # State initialization
@@ -119,7 +119,7 @@ def step(self, closure=None):
                 elif group['norm_type'] == 2:
                     m2 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_list]
                 else:
-                    raise RuntimeError('FusedNovoGrad only support l2/inf norm now.')
+                    raise RuntimeError('NovoGrad only support l2/inf norm now.')
                 group['exp_avg_sq'] = torch.cuda.FloatTensor(m2)
             else:
                 assert(len(g_list) == group['exp_avg_sq'].numel())
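The NovoGrad rename does not touch its behavior: the constructor options visible above (grad_averaging, norm_type, init_zero, set_grad_none) and the l2/inf norm restriction in step() are unchanged. A minimal construction sketch under those assumptions, with a hypothetical model and toy loss:

import torch
import apex

model = torch.nn.Linear(16, 4).cuda()                      # hypothetical model

optimizer = apex.optimizers.NovoGrad(
    model.parameters(),
    lr=1e-3,
    grad_averaging=True,
    norm_type=2,          # step() accepts only l2 or inf norms
    set_grad_none=True,   # zero_grad() sets p.grad = None instead of zeroing it
)

loss = model(torch.randn(8, 16, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()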
apex/optimizers/sgd.py (3 additions, 3 deletions)
@@ -4,7 +4,7 @@
 from amp_C import multi_tensor_axpby
 from apex.multi_tensor_apply import multi_tensor_applier
 
-class FusedSGD(Optimizer):
+class SGD(Optimizer):
     r"""Implements stochastic gradient descent (optionally with momentum).
     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
@@ -52,10 +52,10 @@ def __init__(self, params, lr=0.1, momentum=0., dampening=0.,
                         weight_decay=weight_decay, nesterov=nesterov)
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-        super(FusedSGD, self).__init__(params, defaults)
+        super(SGD, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(FusedSGD, self).__setstate__(state)
+        super(SGD, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
tests/L0/run_mixed_adam/test_fp16_optimizer.py (4 additions, 4 deletions)
@@ -36,7 +36,7 @@ def test_fp16_optimizer(self):
         ref_optim = torch.optim.Adam(self.ref_model.parameters())
         ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
 
-        tst_optim = apex.optimizers.FusedAdam_v1(self.tst_model.parameters())
+        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters())
         tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)
 
         for i in range(self.iters):
@@ -58,7 +58,7 @@ def test_loss_scaling(self):
         ref_optim = torch.optim.Adam(self.ref_model.parameters())
         ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, static_loss_scale=128.0, verbose=False)
 
-        tst_optim = apex.optimizers.FusedAdam_v1(self.tst_model.parameters())
+        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters())
         tst_optim = apex.optimizers.FP16_Optimizer(tst_optim, static_loss_scale=128.0)
 
         for i in range(self.iters):
@@ -81,7 +81,7 @@ def test_parameter_groups(self):
         ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
 
         tst_groups = [{'params': [self.tst_model.weight]},{'params': [self.tst_model.bias]}]
-        tst_optim = apex.optimizers.FusedAdam_v1(tst_groups)
+        tst_optim = apex.optimizers.FusedAdam(tst_groups)
         tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)
 
         for i in range(self.iters):
@@ -101,7 +101,7 @@ def test_grad_clip(self):
         ref_optim = torch.optim.Adam(self.ref_model.parameters())
         ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
 
-        tst_optim = apex.optimizers.FusedAdam_v1(self.tst_model.parameters(), max_grad_norm=0.01)
+        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters(), max_grad_norm=0.01)
         tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)
 
         for i in range(self.iters):
tests/L0/run_mixed_adam/test_mixed_adam.py (3 additions, 3 deletions)
@@ -24,10 +24,10 @@ def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
 
         ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
         if tst_adam_option:
-            tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **tst_adam_option)
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
         else:
-            tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **ref_adam_option)
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)
 
         return (ref_param, tst_param, ref_optim, tst_optim)
 
     def gen_grad(self, ref_param, tst_param):
File renamed without changes.
Changed test file (path not shown):
@@ -23,7 +23,7 @@ def gen_param_optim(self, tensors, adam_option):
         tst_param.append(torch.nn.Parameter(tensor.clone()))
 
         ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
+        tst_optim = apex.optimizers.Adam(tst_param, **adam_option)
 
         return (ref_param, tst_param, ref_optim, tst_optim)
 