diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e790c38cfcd6..a80c38bc0ade5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
 
+- Fixed `StochasticWeightAveraging` with a list of learning rates not applying them to each param group ([#8747](https://github.com/PyTorchLightning/pytorch-lightning/issues/8747))
+
+
 - Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))
 
 
diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
index c1d9ee7840829..12a9ac8275adb 100644
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -166,25 +166,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             # move average model to request device.
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
-            optimizers = trainer.optimizers
+            optimizer = trainer.optimizers[0]
+            if self._swa_lrs is None:
+                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
+            if isinstance(self._swa_lrs, float):
+                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
 
-            for param_group in optimizers[0].param_groups:
-                if self._swa_lrs is None:
-                    initial_lr = param_group["lr"]
-
-                elif isinstance(self._swa_lrs, float):
-                    initial_lr = self._swa_lrs
-
-                else:
-                    initial_lr = self._swa_lrs[0]
-
-                param_group["initial_lr"] = initial_lr
-
-            self._swa_lrs = initial_lr
+            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
+                group["initial_lr"] = lr
 
             self._swa_scheduler = SWALR(
-                optimizers[0],
-                swa_lr=initial_lr,
+                optimizer,
+                swa_lr=self._swa_lrs,
                 anneal_epochs=self._annealing_epochs,
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index be4cbb14e7344..0bfaa359bb1a8 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -211,7 +211,7 @@ def configure_optimizers(self):
     trainer.fit(model)
     if use_callbacks or stochastic_weight_avg:
         assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1
-        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+        assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1]
     else:
         assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
 
@@ -237,3 +237,39 @@ def on_before_accelerator_backend_setup(self, trainer: "Trainer", pl_module: "Li
     trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
     trainer.fit(model, train_dataloader=DataLoader(RandomDataset(32, 2)))
     assert swa.on_before_accelerator_backend_setup_called
+
+
+def test_swa_multiple_lrs(tmpdir):
+    swa_lrs = [0.123, 0.321]
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super(BoringModel, self).__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def configure_optimizers(self):
+            params = [{"params": self.layer1.parameters(), "lr": 0.1}, {"params": self.layer2.parameters(), "lr": 0.2}]
+            return torch.optim.Adam(params)
+
+        def on_train_epoch_start(self):
+            optimizer = trainer.optimizers[0]
+            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
+            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
+            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
+            self.on_train_epoch_start_called = True
+
+    model = TestModel()
+    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=swa_callback,
+        fast_dev_run=1,
+    )
+    trainer.fit(model)
+    assert model.on_train_epoch_start_called
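
A minimal usage sketch of the fixed behavior, assuming a LightningModule whose optimizer defines two param groups as in the test above; the learning-rate values and trainer arguments below are illustrative placeholders, not part of the patch:

    # Sketch: one SWA learning rate per optimizer param group.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # `swa_lrs` may be a list with one entry per param group; a single float
    # is still broadcast to every group, and `None` reuses each group's
    # current learning rate.
    swa = StochasticWeightAveraging(swa_lrs=[0.123, 0.321])
    trainer = Trainer(max_epochs=10, callbacks=[swa])
    # trainer.fit(model)  # `model` must configure an optimizer with two param groups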