add choice to degenerated to sgd
add choice to degenerated to sgd or freeze
LiyuanLucasLiu committed Oct 10, 2019
1 parent 62cc2b2 commit 373b3e4
Showing 1 changed file with 25 additions and 18 deletions.
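
The option works as follows: during the early steps, while the approximated SMA length (N_sma) is still below 5, the rectified Adam step is not used; with degenerated_to_sgd=True (the default, which keeps the previous behavior) the optimizer falls back to a bias-corrected SGD-style momentum update, whereas degenerated_to_sgd=False skips the parameter update entirely, i.e. the weights stay frozen for those steps. A minimal usage sketch, assuming radam.py from this repository is importable; the toy model and loss are hypothetical placeholders, not part of the commit:

import torch
from radam import RAdam

model = torch.nn.Linear(10, 1)   # hypothetical placeholder model

# default: degenerate to an SGD-style update while N_sma < 5
opt_sgd = RAdam(model.parameters(), lr=1e-3, degenerated_to_sgd=True)

# new choice in this commit: freeze the parameters while N_sma < 5
opt_freeze = RAdam(model.parameters(), lr=1e-3, degenerated_to_sgd=False)

loss = model(torch.randn(4, 10)).pow(2).mean()   # placeholder loss
loss.backward()
opt_freeze.step()
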
radam.py (43 changes: 25 additions & 18 deletions)
@@ -4,7 +4,7 @@
 
 class RAdam(Optimizer):
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -13,7 +13,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-
+
+        self.degenerated_to_sgd = degenerated_to_sgd
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         self.buffer = [[None, None, None] for ind in range(10)]
         super(RAdam, self).__init__(params, defaults)
@@ -68,27 +69,30 @@ def step(self, closure=None):
                     # more conservative since it's an approximated value
                     if N_sma >= 5:
                         step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
-                    else:
+                    elif self.degenerated_to_sgd:
                         step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
                     buffered[2] = step_size
 
-                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-
                 # more conservative since it's an approximated value
-                if N_sma >= 5:
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
                     p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
-                else:
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                     p_data_fp32.add_(-step_size * group['lr'], exp_avg)
-
-                p.data.copy_(p_data_fp32)
+                    p.data.copy_(p_data_fp32)
 
         return loss
 
 class PlainRAdam(Optimizer):
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
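
The RAdam hunk above is the heart of the change: step_size can now take a negative sentinel value (-1) that the update code interprets as "skip this parameter". Ignoring the per-step buffering, the selection logic reads roughly as in this standalone sketch (illustrative code, not the repository's; names follow the diff):

import math

def choose_step_size(step, beta1, beta2, degenerated_to_sgd):
    # approximated SMA length, computed exactly as in radam.py
    beta2_t = beta2 ** step
    N_sma_max = 2 / (1 - beta2) - 1
    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

    if N_sma >= 5:
        # variance is tractable: rectified Adam step (applied via addcdiv_)
        rect = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4)
                         * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2))
        return rect / (1 - beta1 ** step)
    elif degenerated_to_sgd:
        # degenerate to a bias-corrected momentum (SGD-style) step
        return 1.0 / (1 - beta1 ** step)
    else:
        # freeze: negative sentinel, step() leaves p.data untouched
        return -1

In step() itself the rectified value is applied with addcdiv_ and the Adam denominator, a positive non-rectified value is applied with a plain add_ of exp_avg, and the -1 case falls through both branches so p.data is never written.
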
@@ -97,7 +101,8 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-
+
+        self.degenerated_to_sgd = degenerated_to_sgd
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
 
         super(PlainRAdam, self).__init__(params, defaults)
@@ -143,19 +148,21 @@ def step(self, closure=None):
                 N_sma_max = 2 / (1 - beta2) - 1
                 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
 
-                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-
                 # more conservative since it's an approximated value
-                if N_sma >= 5:
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                     step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
                     p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
-                else:
+                    p.data.copy_(p_data_fp32)
+                elif self.degenerated_to_sgd:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                     step_size = group['lr'] / (1 - beta1 ** state['step'])
                     p_data_fp32.add_(-step_size, exp_avg)
-
-                p.data.copy_(p_data_fp32)
+                    p.data.copy_(p_data_fp32)
 
         return loss
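
A side effect visible in both step() hunks: the weight-decay term moved inside the branches, so with degenerated_to_sgd=False and N_sma < 5 neither decay nor gradient is applied and the parameter is left exactly as it was. A small check of that "freeze" behavior, assuming radam.py is importable; the tensors below are made-up test data:

import torch
from radam import PlainRAdam

p = torch.nn.Parameter(torch.ones(3))
opt = PlainRAdam([p], lr=1e-1, weight_decay=1e-2, degenerated_to_sgd=False)

before = p.detach().clone()
p.grad = torch.ones_like(p)
opt.step()   # first step: N_sma < 5 with beta2=0.999, so the update is skipped

print(torch.equal(p.detach(), before))   # expected: True, the parameter is frozen
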


0 comments on commit 373b3e4
