Noam lr sched: do not force min_lr after max_steps (#4472)
Signed-off-by: Adrian Lancucki <alancucki@users.noreply.github.com>

Co-authored-by: Adrian Lancucki <alancucki@users.noreply.github.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
3 people committed Jul 7, 2022
1 parent f21cf36 commit cf95f93
Showing 2 changed files with 56 additions and 4 deletions.
9 changes: 5 additions & 4 deletions nemo/core/optim/lr_scheduler.py
@@ -472,9 +472,6 @@ def get_lr(self):
 
         step = max(1, self.last_epoch)
 
-        if step > self.max_steps:
-            return [self.min_lr for _ in self.base_lrs]
-
         for initial_lr in self.base_lrs:
             if initial_lr < self.min_lr:
                 raise ValueError(
@@ -485,7 +482,11 @@ def get_lr(self):
         return new_lrs
 
     def _noam_annealing(self, initial_lr, step):
-        mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
+        if self.warmup_steps > 0:
+            mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
+        else:
+            mult = self._normalize * step ** (-0.5)
+
         out_lr = initial_lr * mult
         if step > self.warmup_steps:
             out_lr = max(out_lr, self.min_lr)
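For context: NoamAnnealing follows the inverse-square-root schedule from the Transformer paper, and with this change the learning rate keeps following that decay past max_steps instead of being clamped to min_lr. Below is a minimal standalone sketch of the updated schedule; the noam_lr function and its default arguments are illustrative only, and it assumes self._normalize equals d_model ** (-0.5), as the new test's first assertion implies.

    # Illustrative re-implementation of the updated _noam_annealing logic (not the NeMo class itself).
    def noam_lr(initial_lr, step, d_model=16, warmup_steps=5, min_lr=1e-3):
        normalize = d_model ** (-0.5)
        if warmup_steps > 0:
            mult = normalize * min(step ** (-0.5), step * warmup_steps ** (-1.5))
        else:
            mult = normalize * step ** (-0.5)
        lr = initial_lr * mult
        if step > warmup_steps:
            lr = max(lr, min_lr)  # min_lr stays a floor after warmup, but there is no hard cut at max_steps
        return lr

    if __name__ == "__main__":
        # Before this commit, every step beyond max_steps returned min_lr outright;
        # now the rate simply continues its step ** (-0.5) decay.
        for step in (1, 5, 10, 20, 40):
            print(step, noam_lr(initial_lr=0.1, step=step))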
51 changes: 51 additions & 0 deletions tests/core/test_optimizers_schedulers.py
Expand Up @@ -132,6 +132,7 @@ class TestOptimizersSchedulers:
INITIAL_LR = 0.1
MIN_LR = 1e-3
MAX_STEPS = 10
D_MODEL = 16

# fused_adam is looking for CUDA and this test is being run on CPU only tests
@pytest.mark.unit
Expand Down Expand Up @@ -616,6 +617,56 @@ def test_CosineAnnealing(self):

assert final_lr == self.MIN_LR

# Noam scheduler should decay past MAX_STEPS - run two schedulers in parallel to test it
@pytest.mark.unit
def test_NoamAnnealing(self):
model = TempModel()
opt_cls = optim.get_optimizer('novograd')
opt1 = opt_cls(model.parameters(), lr=self.INITIAL_LR)
opt2 = opt_cls(model.parameters(), lr=self.INITIAL_LR)

# No warmup case
policy1 = optim.lr_scheduler.NoamAnnealing(
opt1, d_model=self.D_MODEL, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)
policy2 = optim.lr_scheduler.NoamAnnealing(
opt2, d_model=self.D_MODEL, max_steps=self.MAX_STEPS * 2, min_lr=self.MIN_LR
)
initial_lr = policy1.get_last_lr()[0]

assert initial_lr == self.D_MODEL ** (-0.5) * self.INITIAL_LR

for i in range(self.MAX_STEPS * 2):
assert self.MIN_LR < policy1.get_last_lr()[0] <= self.INITIAL_LR
assert policy1.get_last_lr()[0] == policy2.get_last_lr()[0]
opt1.step()
opt2.step()
policy1.step()
policy2.step()

# Warmup steps available
policy1 = optim.lr_scheduler.NoamAnnealing(
opt1, d_model=self.D_MODEL, warmup_steps=5, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
)
policy2 = optim.lr_scheduler.NoamAnnealing(
opt2, d_model=self.D_MODEL, warmup_steps=5, max_steps=self.MAX_STEPS * 2, min_lr=self.MIN_LR
)
initial_lr = policy1.get_last_lr()[0]

assert initial_lr < self.INITIAL_LR

for i in range(self.MAX_STEPS * 2):
if i <= 5:
assert policy1.get_last_lr()[0] <= self.INITIAL_LR
else:
assert self.MIN_LR < policy1.get_last_lr()[0] < self.INITIAL_LR
assert policy1.get_last_lr()[0] == policy2.get_last_lr()[0]

opt1.step()
opt2.step()
policy1.step()
policy2.step()

@pytest.mark.unit
def test_PolynomialDecayAnnealing(self):
model = TempModel()
Expand Down
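Outside of the unit test, the scheduler can be driven the same way. The sketch below is a rough usage example, assuming the test's optim module resolves to nemo.core.optim and that an ordinary torch module can stand in for the test's TempModel:

    import torch
    from nemo.core import optim  # assumed import path for the optim module used in the test

    model = torch.nn.Linear(16, 16)  # stand-in for the test's TempModel
    opt = optim.get_optimizer('novograd')(model.parameters(), lr=0.1)
    sched = optim.lr_scheduler.NoamAnnealing(
        opt, d_model=16, warmup_steps=5, max_steps=10, min_lr=1e-3
    )

    for step in range(20):  # deliberately run past max_steps=10
        opt.step()
        sched.step()
        print(step, sched.get_last_lr()[0])  # keeps decaying after step 10 rather than snapping to min_lr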
