diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py
index 83412c61ea..cb047f8bc5 100644
--- a/monai/optimizers/lr_scheduler.py
+++ b/monai/optimizers/lr_scheduler.py
@@ -62,7 +62,13 @@ class WarmupCosineSchedule(LambdaLR):
     """

     def __init__(
-        self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
+        self,
+        optimizer: Optimizer,
+        warmup_steps: int,
+        t_total: int,
+        cycles: float = 0.5,
+        last_epoch: int = -1,
+        warmup_multiplier: float = 0,
     ) -> None:
         """
         Args:
@@ -71,16 +77,22 @@ def __init__(
             t_total: total number of training iterations.
             cycles: cosine cycles parameter.
             last_epoch: the index of last epoch.
+            warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr.
+                Must be in 0..1 interval. Defaults to 0.
         Returns:
             None
         """
-        self.warmup_steps = warmup_steps
+        self.warmup_steps = min(max(warmup_steps, 0), t_total)
+        self.warmup_multiplier = warmup_multiplier
         self.t_total = t_total
         self.cycles = cycles
+        if warmup_multiplier < 0 or warmup_multiplier > 1:
+            raise ValueError("warmup_multiplier must be in 0..1 range")
         super().__init__(optimizer, self.lr_lambda, last_epoch)

     def lr_lambda(self, step):
         if step < self.warmup_steps:
-            return float(step) / float(max(1.0, self.warmup_steps))
+            f = float(step) / float(max(1.0, self.warmup_steps))
+            return self.warmup_multiplier + (1 - self.warmup_multiplier) * f
         progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index a3e1ea9dd6..44f4c50c0f 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -28,7 +28,11 @@ def forward(self, x):


 TEST_CASE_LRSCHEDULER = [
-    [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]]
+    [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]],
+    [
+        {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1},
+        [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],
+    ],
 ]


@@ -47,6 +51,13 @@ def test_shape(self, input_param, expected_lr):
         for a, b in zip(lrs_1, expected_lr):
             self.assertEqual(a, b, msg=f"LR is wrong ! expected {b}, got {a}")

+    def test_error(self):
+        """Should fail because warmup_multiplier is outside 0..1"""
+        net = SchedulerTestNet()
+        optimizer = torch.optim.Adam(net.parameters(), lr=1.0)
+        with self.assertRaises(ValueError):
+            WarmupCosineSchedule(optimizer, warmup_steps=2, t_total=10, warmup_multiplier=-1)
+

 if __name__ == "__main__":
     unittest.main()