From ec2c0f61c32ff77eebf98bb1089a9e456165e52e Mon Sep 17 00:00:00 2001 From: Achaiah Date: Mon, 22 Nov 2021 16:04:48 -0600 Subject: [PATCH] adding sgdp --- pywick/models/model_utils.py | 15 +-- pywick/optimizers/sgdp.py | 181 +++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 pywick/optimizers/sgdp.py diff --git a/pywick/models/model_utils.py b/pywick/models/model_utils.py index 3913431..2c2bf48 100644 --- a/pywick/models/model_utils.py +++ b/pywick/models/model_utils.py @@ -103,7 +103,7 @@ def get_model(model_type: ModelType, torch_hub_names = torch.hub.list(rwightman_repo, force_reload=force_reload) if model_name in torch_hub_names: model = torch.hub.load(rwightman_repo, model_name, pretrained=pretrained, num_classes=num_classes) - elif custom_load_fn is not None: + elif custom_load_fn: model = custom_load_fn(model_name, pretrained, num_classes, **kwargs) else: # 1. Load model (pretrained or vanilla) @@ -360,7 +360,7 @@ def diff_states(dict_canonical, dict_subset): yield (name, v1) -def load_checkpoint(checkpoint_path, model=None, device='cpu', strict=True, ignore_chkpt_layers=None): +def load_checkpoint(checkpoint_path: str, model=None, device='cpu', strict: bool = True, ignore_chkpt_layers=None, debug: bool = False): """ Loads weights from a checkpoint into memory. If model is not None then the weights are loaded into the model. @@ -401,12 +401,14 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac # load data directly from a checkpoint checkpoint_path = os.path.expanduser(checkpoint_path) if os.path.isfile(checkpoint_path): - print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device)) + if debug: + print('=> Loading checkpoint: {} onto device: {}'.format(checkpoint_path, device)) checkpoint = torch.load(checkpoint_path, map_location=device) pretrained_state = checkpoint['state_dict'] - print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch'))) - print('INFO: => checkpoint model name: ', checkpoint.get('modelname', checkpoint.get('model_name')), ' Make sure the checkpoint model name matches your model!!!') + if debug: + print("INFO: => loaded checkpoint {} (epoch {})".format(checkpoint_path, checkpoint.get('epoch'))) + print('INFO: => checkpoint model name: ', checkpoint.get('modelname', checkpoint.get('model_name')), ' Make sure the checkpoint model name matches your model!!!') else: raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path) @@ -431,7 +433,8 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac # finally load the model weights if model: - print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict)) + if debug: + print('INFO: => Attempting to load checkpoint data onto model. Device: {} Strict: {}'.format(device, strict)) model.load_state_dict(checkpoint['state_dict'], strict=strict) return checkpoint diff --git a/pywick/optimizers/sgdp.py b/pywick/optimizers/sgdp.py new file mode 100644 index 0000000..0bb7f94 --- /dev/null +++ b/pywick/optimizers/sgdp.py @@ -0,0 +1,181 @@ +# Source: https://github.com/pytorch/pytorch/pull/3740 +import math + +import torch +from torch.optim.optimizer import Optimizer + +from .a2grad import OptFloat, OptLossClosure, Params + +__all__ = ('SGDP',) + + +class SGDP(Optimizer): + r"""Implements SGDP algorithm. 
+
+    It has been proposed in `Slowing Down the Weight Norm Increase in
+    Momentum-based Optimizers`__
+
+    Arguments:
+        params: iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr: learning rate (default: 1e-3)
+        momentum: momentum factor (default: 0)
+        dampening: dampening for momentum (default: 0)
+        eps: term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay: weight decay (L2 penalty) (default: 0)
+        delta: threshold that determines whether a set of parameters is scale
+            invariant or not (default: 0.1)
+        wd_ratio: relative weight decay applied on scale-invariant parameters
+            compared to that applied on scale-variant parameters (default: 0.1)
+        nesterov: enables Nesterov momentum (default: False)
+
+    __ https://arxiv.org/abs/2006.08217
+
+    Note:
+        Reference code: https://github.com/clovaai/AdamP
+    """
+
+    def __init__(
+        self,
+        params: Params,
+        lr: float = 1e-3,
+        momentum: float = 0,
+        dampening: float = 0,
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        delta: float = 0.1,
+        wd_ratio: float = 0.1,
+        nesterov: bool = False,
+    ) -> None:
+        if lr <= 0.0:
+            raise ValueError('Invalid learning rate: {}'.format(lr))
+        if eps < 0.0:
+            raise ValueError('Invalid epsilon value: {}'.format(eps))
+        if momentum < 0.0:
+            raise ValueError('Invalid momentum value: {}'.format(momentum))
+        if dampening < 0.0:
+            raise ValueError('Invalid dampening value: {}'.format(dampening))
+        if weight_decay < 0:
+            raise ValueError(
+                'Invalid weight_decay value: {}'.format(weight_decay)
+            )
+        if delta < 0:
+            raise ValueError('Invalid delta value: {}'.format(delta))
+        if wd_ratio < 0:
+            raise ValueError('Invalid wd_ratio value: {}'.format(wd_ratio))
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            eps=eps,
+            weight_decay=weight_decay,
+            delta=delta,
+            wd_ratio=wd_ratio,
+            nesterov=nesterov,
+        )
+        super(SGDP, self).__init__(params, defaults)
+
+    @staticmethod
+    def _channel_view(x):
+        return x.view(x.size(0), -1)
+
+    @staticmethod
+    def _layer_view(x):
+        return x.view(1, -1)
+
+    @staticmethod
+    def _cosine_similarity(x, y, eps, view_func):
+        x = view_func(x)
+        y = view_func(y)
+
+        x_norm = x.norm(dim=1).add_(eps)
+        y_norm = y.norm(dim=1).add_(eps)
+        dot = (x * y).sum(dim=1)
+
+        return dot.abs() / x_norm / y_norm
+
+    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
+        wd = 1
+        expand_size = [-1] + [1] * (len(p.shape) - 1)
+        for view_func in [self._channel_view, self._layer_view]:
+
+            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
+
+            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
+                p_n = p.data / view_func(p.data).norm(dim=1).view(
+                    expand_size
+                ).add_(eps)
+                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
+                    expand_size
+                )
+                wd = wd_ratio
+
+                return perturb, wd
+
+        return perturb, wd
+
+    def step(self, closure: OptLossClosure = None) -> OptFloat:
+        r"""Performs a single optimization step.
+
+        Arguments:
+            closure: A closure that reevaluates the model and returns the loss.
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like( + p.data, memory_format=torch.preserve_format + ) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + d_p, wd_ratio = self._projection( + p, + grad, + d_p, + group['delta'], + group['wd_ratio'], + group['eps'], + ) + + # Weight decay + if weight_decay != 0: + p.data.mul_( + 1 + - group['lr'] + * group['weight_decay'] + * wd_ratio + / (1 - momentum) + ) + + # Step + p.data.add_(d_p, alpha=-group['lr']) + + return loss
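
A minimal usage sketch for the new optimizer, assuming the module is importable as pywick.optimizers.sgdp once the patch is applied (and that its sibling module a2grad, imported above, is present); the toy model, data, and hyperparameter values are illustrative only, not taken from the patch:

    import torch
    from torch import nn
    from pywick.optimizers.sgdp import SGDP

    model = nn.Linear(10, 2)                        # any torch.nn.Module
    optimizer = SGDP(model.parameters(), lr=1e-3, momentum=0.9,
                     weight_decay=1e-5, nesterov=True)

    x, y = torch.randn(8, 10), torch.randn(8, 2)    # dummy batch
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                                # projection-aware SGD update

The constructor mirrors torch.optim.SGD's arguments plus eps, delta, and wd_ratio, which control the scale-invariance projection described in the class docstring.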
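
And a short sketch of the quieter load_checkpoint signature introduced above, assuming a model instance and a saved checkpoint file already exist; the path below is a placeholder:

    from pywick.models.model_utils import load_checkpoint

    # Progress messages are now opt-in via debug=True; the default stays silent.
    checkpoint = load_checkpoint('/tmp/example_checkpoint.pth', model=model,
                                 device='cpu', strict=True, debug=True)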