From d46e4c54ae47b730d0805694849f106c41828e97 Mon Sep 17 00:00:00 2001
From: YONG Hongwei <30529442+Yonghongwei@users.noreply.github.com>
Date: Wed, 3 Jun 2020 11:32:25 +0800
Subject: [PATCH] Update Adam.py

---
 GC_code/CIFAR100/algorithm/Adam.py | 115 +++++++++++++++++++++++++----
 1 file changed, 100 insertions(+), 15 deletions(-)

diff --git a/GC_code/CIFAR100/algorithm/Adam.py b/GC_code/CIFAR100/algorithm/Adam.py
index 5999acc..a940f90 100644
--- a/GC_code/CIFAR100/algorithm/Adam.py
+++ b/GC_code/CIFAR100/algorithm/Adam.py
@@ -110,7 +110,6 @@ def __setstate__(self, state):
 
     def step(self, closure=None):
         """Performs a single optimization step.
-
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
@@ -152,12 +151,7 @@ def step(self, closure=None):
 
                 if group['weight_decay'] != 0:
                     grad.add_(group['weight_decay'], p.data)
-                
-                #GC operation for Conv layers
-                if len(list(p.data.size()))==4:
-                    weight_mean=p.data.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True)
-                    grad.add_(-grad.mean(dim = 1, keepdim = True).mean(dim = 2, keepdim = True).mean(dim = 3, keepdim = True))
-                
+
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -170,11 +164,13 @@ def step(self, closure=None):
                     denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                 step_size = group['lr'] / bias_correction1
-                
-                p.data.addcdiv_(-step_size, exp_avg, denom)
-                # keep mean unchanged
-                if len(list(grad.size()))==4:
-                    p.data.add_(-p.data.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True)).add_(weight_mean)
+                #GC operation for Conv layers
+                if len(list(grad.size()))>3:
+                    delta=(step_size*exp_avg/denom).clone()
+                    delta.add_(-delta.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
+                    p.data.add_(-delta)
+                else:
+                    p.data.addcdiv_(-step_size, exp_avg, denom)
         return loss
 
 class Adam_GC(Optimizer):
@@ -287,6 +283,90 @@ def step(self, closure=None):
 
         return loss
 
+class Adam_GC2(Optimizer):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad)
+        super(Adam_GC2, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Adam_GC2, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                if group['weight_decay'] != 0:
+                    grad.add_(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                else:
+                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+
+                step_size = group['lr'] / bias_correction1
+                #GC operation for Conv layers and FC layers
+                if len(list(grad.size()))>1:
+                    delta=(step_size*exp_avg/denom).clone()
+                    delta.add_(-delta.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
+                    p.data.add_(-delta)
+                else:
+                    p.data.addcdiv_(-step_size, exp_avg, denom)
+        return loss
 
 class AdamW(Optimizer):
     """Implements Adam algorithm.
@@ -597,8 +677,13 @@ def step(self, closure=None):
 
                 bias_correction2 = 1 - beta2 ** state['step']
                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
-                # p.data.addcdiv_(-step_size, exp_avg, denom)
-                p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) )
+                # GC operation for Conv layers
                 if len(list(grad.size()))>3:
-                    p.data.add_(-p.data.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True)).add_(weight_mean)
+                    delta=(step_size*torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)).clone()
+                    delta.add_(-delta.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
+                    p.data.add_(-delta)
+                else:
+                    p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom) )
+
+
         return loss
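Note on the change (this note is not part of the patch): each hunk replaces the earlier two-step logic, centralize the raw gradient and then re-add the stored per-filter weight_mean after the update, with a single centering of the Adam update itself. The proposed step, step_size * exp_avg / denom (plus the decoupled weight-decay term in the last hunk), has its mean over every dimension except dimension 0 subtracted before it is applied; the patched Adam class does this for 4-D conv weights, while the new Adam_GC2 also applies it to 2-D FC weights. Because each filter's update then sums to zero, the filter-wise weight mean is preserved without tracking weight_mean at all. The snippet below is a minimal, hypothetical sketch of that centering operation, written against current PyTorch; the names centralize_update, weight and update are illustrative only and do not appear in the repository.

import torch

def centralize_update(update: torch.Tensor) -> torch.Tensor:
    # Subtract the mean over all dims except 0, mirroring the new GC branch:
    # delta.add_(-delta.mean(dim=tuple(range(1, delta.dim())), keepdim=True))
    if update.dim() > 1:
        dims = tuple(range(1, update.dim()))
        update = update - update.mean(dim=dims, keepdim=True)
    return update

# Toy usage: a stand-in conv weight and a stand-in Adam step of the same shape.
weight = torch.randn(8, 3, 3, 3)            # (out_channels, in_channels, kH, kW)
update = 0.01 * torch.randn_like(weight)    # stand-in for step_size * exp_avg / denom
weight -= centralize_update(update)
print(centralize_update(update).mean(dim=(1, 2, 3)))   # approximately zero per filter

One caveat when reusing the patched file directly: it calls the positional overloads add_(scalar, tensor), addcmul_(scalar, t1, t2) and addcdiv_(scalar, t1, t2) that 2020-era PyTorch accepted; later releases deprecate these in favor of the alpha=/value= keyword forms.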