3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
 
+- Fixed `StochasticWeightAveraging` with a list of learning rates not applying them to each param group ([#8747](https://github.com/PyTorchLightning/pytorch-lightning/issues/8747))
+
+
 - Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))
 
 
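The fix matters for optimizers configured with multiple param groups: previously a list passed as `swa_lrs` was collapsed to its first element and applied to every group. A minimal usage sketch of the fixed behaviour (the model, epoch count, and learning-rate values below are illustrative, not taken from this PR):

```python
# Hedged usage sketch: one SWA learning rate per optimizer param group.
# Only the `swa_lrs` semantics come from this PR; everything else is illustrative.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Assumes `model.configure_optimizers()` returns an optimizer with two
# param groups (e.g. backbone and head).
swa = StochasticWeightAveraging(swa_lrs=[0.05, 0.01])  # one value per param group
trainer = pl.Trainer(max_epochs=20, callbacks=[swa])
# trainer.fit(model)  # each param group anneals toward its own SWA lr
```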
25 changes: 9 additions & 16 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -166,25 +166,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
         # move average model to request device.
         self._average_model = self._average_model.to(self._device or pl_module.device)
 
-        optimizers = trainer.optimizers
+        optimizer = trainer.optimizers[0]
+        if self._swa_lrs is None:
+            self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
+        if isinstance(self._swa_lrs, float):
+            self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
 
-        for param_group in optimizers[0].param_groups:
-            if self._swa_lrs is None:
-                initial_lr = param_group["lr"]
-
-            elif isinstance(self._swa_lrs, float):
-                initial_lr = self._swa_lrs
-
-            else:
-                initial_lr = self._swa_lrs[0]
-
-            param_group["initial_lr"] = initial_lr
-
-        self._swa_lrs = initial_lr
+        for lr, group in zip(self._swa_lrs, optimizer.param_groups):
+            group["initial_lr"] = lr
 
         self._swa_scheduler = SWALR(
-            optimizers[0],
-            swa_lr=initial_lr,
+            optimizer,
+            swa_lr=self._swa_lrs,
             anneal_epochs=self._annealing_epochs,
             anneal_strategy=self._annealing_strategy,
             last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
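In short, the patched hook normalizes `swa_lrs` to one value per param group, seeds each group's `initial_lr`, and passes the whole list to `SWALR`, which supports per-group targets. A standalone sketch of that logic in plain torch (module, optimizer, and numbers are illustrative):

```python
# Standalone sketch of the normalization done by the patched on_train_epoch_start;
# the module, optimizer, and values here are illustrative, not part of the PR.
import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.ModuleList([torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)])
optimizer = torch.optim.SGD(
    [
        {"params": model[0].parameters(), "lr": 0.1},
        {"params": model[1].parameters(), "lr": 0.2},
    ]
)

swa_lrs = 0.05  # user input: None, a float, or a list with one entry per group
if swa_lrs is None:
    swa_lrs = [group["lr"] for group in optimizer.param_groups]
if isinstance(swa_lrs, float):
    swa_lrs = [swa_lrs] * len(optimizer.param_groups)

for lr, group in zip(swa_lrs, optimizer.param_groups):
    group["initial_lr"] = lr  # starting point SWALR anneals from

# SWALR accepts a list and stores one `swa_lr` target per param group.
swa_scheduler = SWALR(optimizer, swa_lr=swa_lrs, anneal_epochs=5, anneal_strategy="cos")
assert [group["swa_lr"] for group in optimizer.param_groups] == swa_lrs
```

Note that the callback additionally passes `last_epoch=trainer.max_epochs` for the cosine strategy (visible in the diff above); the sketch keeps the scheduler default.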
38 changes: 37 additions & 1 deletion tests/callbacks/test_stochastic_weight_avg.py
@@ -211,7 +211,7 @@ def configure_optimizers(self):
     trainer.fit(model)
     if use_callbacks or stochastic_weight_avg:
         assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1
-        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+        assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1]
     else:
         assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
 
@@ -237,3 +237,39 @@ def on_before_accelerator_backend_setup(self, trainer: "Trainer", pl_module: "Li
     trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
     trainer.fit(model, train_dataloader=DataLoader(RandomDataset(32, 2)))
     assert swa.on_before_accelerator_backend_setup_called
+
+
+def test_swa_multiple_lrs(tmpdir):
+    swa_lrs = [0.123, 0.321]
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super(BoringModel, self).__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def configure_optimizers(self):
+            params = [{"params": self.layer1.parameters(), "lr": 0.1}, {"params": self.layer2.parameters(), "lr": 0.2}]
+            return torch.optim.Adam(params)
+
+        def on_train_epoch_start(self):
+            optimizer = trainer.optimizers[0]
+            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
+            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
+            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
+            self.on_train_epoch_start_called = True
+
+    model = TestModel()
+    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=swa_callback,
+        fast_dev_run=1,
+    )
+    trainer.fit(model)
+    assert model.on_train_epoch_start_called