diff --git a/CHANGELOG.md b/CHANGELOG.md
index 980d2a450f786..12c71941ad57c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -285,6 +285,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed SWA to also work with `IterableDataset` ([#8172](https://github.com/PyTorchLightning/pytorch-lightning/pull/8172))
+
 - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))

diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 8ebfdc5d80cd6..0cd788c8c8647 100644
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -222,7 +222,8 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
             trainer.num_training_batches += 1
             trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
-            trainer.accumulate_grad_batches = len(trainer.train_dataloader)
+
+            trainer.accumulate_grad_batches = trainer.num_training_batches

     def on_train_epoch_end(self, trainer: 'pl.Trainer', *args):
         trainer.fit_loop._skip_backward = False
diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index e92f8e71da086..8518fe16f0359 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -23,7 +23,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.runif import RunIf

 if _TORCH_GREATER_EQUAL_1_6:
@@ -33,7 +33,7 @@

 class SwaTestModel(BoringModel):

-    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
@@ -41,6 +41,7 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
         self.interval = interval
+        self.iterable_dataset = iterable_dataset

     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
@@ -48,7 +49,11 @@ def training_step(self, batch, batch_idx):
         return {"loss": loss}

     def train_dataloader(self):
-        return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+        dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
+        dset = dset_cls(32, 64)
+
+        return DataLoader(dset, batch_size=2)

     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -107,8 +112,10 @@ def on_train_end(self, trainer, pl_module):


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
-    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
+def train_with_swa(
+    tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False
+):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -155,8 +162,9 @@ def test_swa_callback_1_gpu(tmpdir):

 @RunIf(min_torch="1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
-def test_swa_callback(tmpdir, batchnorm: bool):
-    train_with_swa(tmpdir, batchnorm=batchnorm)
+@pytest.mark.parametrize('iterable_dataset', (True, False))
+def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
+    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)


 @RunIf(min_torch="1.6.0")
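
A minimal sketch (outside the patch) of the failure mode this fix targets: `len()` on a `DataLoader` that wraps an `IterableDataset` without `__len__` raises `TypeError`, which is roughly what the old `len(trainer.train_dataloader)` call relied on, while `trainer.num_training_batches` is counted during iteration and so works for both map-style and iterable-style data. The `RandomIterableDataset` below is a local stand-in for the helper imported from `tests/helpers/boring_model.py`, not the helper itself.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomIterableDataset(IterableDataset):
    """Iterable-style dataset with no __len__, mimicking the test helper."""

    def __init__(self, size: int, count: int):
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


loader = DataLoader(RandomIterableDataset(32, 64), batch_size=2)

# DataLoader.__len__ defers to len(dataset) for iterable datasets,
# so asking for the dataloader length fails here:
try:
    len(loader)
except TypeError:
    print("len() is undefined for an IterableDataset without __len__")

# Counting batches while iterating (what num_training_batches reflects)
# works regardless of the dataset style: 64 samples / batch_size 2 = 32.
print(sum(1 for _ in loader))
```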