Skip to content

Commit

Permalink
adding sgdp
Browse files Browse the repository at this point in the history
  • Loading branch information
achaiah committed Nov 22, 2021
1 parent 265a7b8 commit ec2c0f6
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 6 deletions.
15 changes: 9 additions & 6 deletions pywick/models/model_utils.py
Expand Up @@ -103,7 +103,7 @@ def get_model(model_type: ModelType,
torch_hub_names = torch.hub.list(rwightman_repo, force_reload=force_reload)
if model_name in torch_hub_names:
model = torch.hub.load(rwightman_repo, model_name, pretrained=pretrained, num_classes=num_classes)
elif custom_load_fn is not None:
elif custom_load_fn:
model = custom_load_fn(model_name, pretrained, num_classes, **kwargs)
else:
# 1. Load model (pretrained or vanilla)
Expand Down Expand Up @@ -360,7 +360,7 @@ def diff_states(dict_canonical, dict_subset):
yield (name, v1)


def load_checkpoint(checkpoint_path, model=None, device='cpu', strict=True, ignore_chkpt_layers=None):
def load_checkpoint(checkpoint_path: str, model=None, device='cpu', strict: bool = True, ignore_chkpt_layers=None, debug: bool = False):
"""
Loads weights from a checkpoint into memory. If model is not None then the weights are loaded into the model.
Expand Down Expand Up @@ -401,12 +401,14 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
# load data directly from a checkpoint
checkpoint_path = os.path.expanduser(checkpoint_path)
if os.path.isfile(checkpoint_path):
print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device))
if debug:
print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device))
checkpoint = torch.load(checkpoint_path, map_location=device)

pretrained_state = checkpoint['state_dict']
print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
print('INFO: => checkpoint model name: ', checkpoint.get('modelname', checkpoint.get('model_name')), ' Make sure the checkpoint model name matches your model!!!')
if debug:
print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch')))
print('INFO: => checkpoint model name: ', checkpoint.get('modelname', checkpoint.get('model_name')), ' Make sure the checkpoint model name matches your model!!!')
else:
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path)

Expand All @@ -431,7 +433,8 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac

# finally load the model weights
if model:
print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict))
if debug:
print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict))
model.load_state_dict(checkpoint['state_dict'], strict=strict)
return checkpoint

Expand Down
181 changes: 181 additions & 0 deletions pywick/optimizers/sgdp.py
@@ -0,0 +1,181 @@
# Source: https://github.com/pytorch/pytorch/pull/3740
import math

import torch
from torch.optim.optimizer import Optimizer

from .a2grad import OptFloat, OptLossClosure, Params

__all__ = ('SGDP',)


class SGDP(Optimizer):
r"""Implements SGDP algorithm.
It has been proposed in `Slowing Down the Weight Norm Increase in
Momentum-based Optimizers`__
Arguments:
params: iterable of parameters to optimize or dicts defining
parameter groups
lr: learning rate (default: 1e-3)
momentum: momentum factor (default: 0)
dampening: dampening for momentum (default: 0)
eps: term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay: weight decay (L2 penalty) (default: 0)
delta: threhold that determines whether a set of parameters is scale
invariant or not (default: 0.1)
wd_ratio: relative weight decay applied on scale-invariant parameters
compared to that applied on scale-variant parameters (default: 0.1)
nesterov: enables Nesterov momentum (default: False)
__ https://arxiv.org/abs/2006.08217
Note:
Reference code: https://github.com/clovaai/AdamP
"""

def __init__(
self,
params: Params,
lr: float = 1e-3,
momentum: float = 0,
dampening: float = 0,
eps: float = 1e-8,
weight_decay: float = 0,
delta: float = 0.1,
wd_ratio: float = 0.1,
nesterov: bool = False,
) -> None:
if lr <= 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
if eps < 0.0:
raise ValueError('Invalid epsilon value: {}'.format(eps))
if momentum < 0.0:
raise ValueError('Invalid momentum value: {}'.format(momentum))
if dampening < 0.0:
raise ValueError('Invalid dampening value: {}'.format(dampening))
if weight_decay < 0:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay)
)
if delta < 0:
raise ValueError('Invalid delta value: {}'.format(delta))
if wd_ratio < 0:
raise ValueError('Invalid wd_ratio value: {}'.format(wd_ratio))

defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
eps=eps,
weight_decay=weight_decay,
delta=delta,
wd_ratio=wd_ratio,
nesterov=nesterov,
)
super(SGDP, self).__init__(params, defaults)

@staticmethod
def _channel_view(x):
return x.view(x.size(0), -1)

@staticmethod
def _layer_view(x):
return x.view(1, -1)

@staticmethod
def _cosine_similarity(x, y, eps, view_func):
x = view_func(x)
y = view_func(y)

x_norm = x.norm(dim=1).add_(eps)
y_norm = y.norm(dim=1).add_(eps)
dot = (x * y).sum(dim=1)

return dot.abs() / x_norm / y_norm

def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
wd = 1
expand_size = [-1] + [1] * (len(p.shape) - 1)
for view_func in [self._channel_view, self._layer_view]:

cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)

if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
p_n = p.data / view_func(p.data).norm(dim=1).view(
expand_size
).add_(eps)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
expand_size
)
wd = wd_ratio

return perturb, wd

return perturb, wd

def step(self, closure: OptLossClosure = None) -> OptFloat:
r"""Performs a single optimization step.
Arguments:
closure: A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']

for p in group['params']:
if p.grad is None:
continue

grad = p.grad.data
state = self.state[p]

# State initialization
if len(state) == 0:
state['momentum'] = torch.zeros_like(
p.data, memory_format=torch.preserve_format
)

# SGD
buf = state['momentum']
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
d_p = grad + momentum * buf
else:
d_p = buf

# Projection
wd_ratio = 1
if len(p.shape) > 1:
d_p, wd_ratio = self._projection(
p,
grad,
d_p,
group['delta'],
group['wd_ratio'],
group['eps'],
)

# Weight decay
if weight_decay != 0:
p.data.mul_(
1
- group['lr']
* group['weight_decay']
* wd_ratio
/ (1 - momentum)
)

# Step
p.data.add_(d_p, alpha=-group['lr'])

return loss

0 comments on commit ec2c0f6

Please sign in to comment.