Commit: showing 2 changed files with 190 additions and 6 deletions.
@@ -0,0 +1,181 @@
# Source: https://github.com/pytorch/pytorch/pull/3740
import math

import torch
from torch.optim.optimizer import Optimizer

from .a2grad import OptFloat, OptLossClosure, Params

__all__ = ('SGDP',)


class SGDP(Optimizer):
    r"""Implements SGDP algorithm.

    It has been proposed in `Slowing Down the Weight Norm Increase in
    Momentum-based Optimizers`__

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        momentum: momentum factor (default: 0)
        dampening: dampening for momentum (default: 0)
        eps: term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)
        delta: threshold that determines whether a set of parameters is
            scale invariant or not (default: 0.1)
        wd_ratio: relative weight decay applied on scale-invariant
            parameters compared to that applied on scale-variant
            parameters (default: 0.1)
        nesterov: enables Nesterov momentum (default: False)

    __ https://arxiv.org/abs/2006.08217

    Note:
        Reference code: https://github.com/clovaai/AdamP
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        momentum: float = 0,
        dampening: float = 0,
        eps: float = 1e-8,
        weight_decay: float = 0,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if momentum < 0.0:
            raise ValueError('Invalid momentum value: {}'.format(momentum))
        if dampening < 0.0:
            raise ValueError('Invalid dampening value: {}'.format(dampening))
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )
        if delta < 0:
            raise ValueError('Invalid delta value: {}'.format(delta))
        if wd_ratio < 0:
            raise ValueError('Invalid wd_ratio value: {}'.format(wd_ratio))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            eps=eps,
            weight_decay=weight_decay,
            delta=delta,
            wd_ratio=wd_ratio,
            nesterov=nesterov,
        )
        super(SGDP, self).__init__(params, defaults)

    @staticmethod
    def _channel_view(x):
        # Flatten each output channel (first dim) into a row vector.
        return x.view(x.size(0), -1)

    @staticmethod
    def _layer_view(x):
        # Flatten the whole tensor into a single row vector.
        return x.view(1, -1)

    @staticmethod
    def _cosine_similarity(x, y, eps, view_func):
        x = view_func(x)
        y = view_func(y)

        x_norm = x.norm(dim=1).add_(eps)
        y_norm = y.norm(dim=1).add_(eps)
        dot = (x * y).sum(dim=1)

        return dot.abs() / x_norm / y_norm

    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
        # Try the channel-wise view first, then the layer-wise view. If the
        # gradient-weight cosine similarity is below the dimension-scaled
        # delta threshold, treat the tensor as scale invariant: project out
        # the radial component of the update and return early.
        wd = 1
        expand_size = [-1] + [1] * (len(p.shape) - 1)
        for view_func in [self._channel_view, self._layer_view]:

            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
                p_n = p.data / view_func(p.data).norm(dim=1).view(
                    expand_size
                ).add_(eps)
                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
                    expand_size
                )
                wd = wd_ratio

                return perturb, wd

        return perturb, wd

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns
                the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )

                # SGD momentum update
                buf = state['momentum']
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
                if nesterov:
                    d_p = grad + momentum * buf
                else:
                    d_p = buf

                # Projection (only multi-dimensional tensors can be
                # scale invariant)
                wd_ratio = 1
                if len(p.shape) > 1:
                    d_p, wd_ratio = self._projection(
                        p,
                        grad,
                        d_p,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                # Weight decay, reduced by wd_ratio for scale-invariant
                # tensors; the 1 / (1 - momentum) factor follows the
                # reference implementation
                if weight_decay != 0:
                    p.data.mul_(
                        1
                        - group['lr']
                        * group['weight_decay']
                        * wd_ratio
                        / (1 - momentum)
                    )

                # Step
                p.data.add_(d_p, alpha=-group['lr'])

        return loss
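
A minimal usage sketch (not part of the diff), assuming SGDP is importable from the surrounding package; the model, data, and hyperparameters here are illustrative placeholders, not defaults:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    # momentum=0.9 and nesterov=True are example choices for this sketch.
    optimizer = SGDP(model.parameters(), lr=1e-3, momentum=0.9, nesterov=True)

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Since the Linear weight is 2-D, the projection branch runs for it; the 1-D bias skips projection and receives a plain momentum update.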