Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add end_lr in WarmupCosineSchedule #6662

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions monai/optimizers/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
optimizer: Optimizer,
warmup_steps: int,
t_total: int,
end_lr: float = 0.0,
cycles: float = 0.5,
last_epoch: int = -1,
warmup_multiplier: float = 0,
Expand All @@ -77,6 +78,7 @@ def __init__(
optimizer: wrapped optimizer.
warmup_steps: number of warmup iterations.
t_total: total number of training iterations.
end_lr: the final learning rate. Defaults to 0.0.
cycles: cosine cycles parameter.
last_epoch: the index of last epoch.
warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr.
Expand All @@ -88,6 +90,7 @@ def __init__(
self.warmup_multiplier = warmup_multiplier
self.t_total = t_total
self.cycles = cycles
self.end_lr = end_lr
if warmup_multiplier < 0 or warmup_multiplier > 1:
raise ValueError("warmup_multiplier must be in 0..1 range")
super().__init__(optimizer, self.lr_lambda, last_epoch)
Expand All @@ -98,3 +101,10 @@ def lr_lambda(self, step):
return self.warmup_multiplier + (1 - self.warmup_multiplier) * f
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

def get_lr(self):
current_lr = [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
if self.last_epoch < self.warmup_steps:
return current_lr
else:
return [max(self.end_lr, _current_lr) for _current_lr in current_lr]
4 changes: 4 additions & 0 deletions tests/test_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def forward(self, x):
{"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1},
[0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],
],
[
{"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1, "end_lr": 0.309},
[0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.309, 0.309],
],
]


Expand Down