@@ -23,7 +23,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.runif import RunIf

 if _TORCH_GREATER_EQUAL_1_6:
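The _TORCH_GREATER_EQUAL_1_6 guard exists because torch.optim.swa_utils (the module behind StochasticWeightAveraging) only ships with PyTorch 1.6+. The guarded lines are collapsed in this diff; a typical shape for such a block, offered as an assumption rather than the verbatim source, is:

if _TORCH_GREATER_EQUAL_1_6:
    # assumed import: the SWA LR scheduler these tests exercise
    from torch.optim.swa_utils import SWALR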
@@ -33,22 +33,27 @@

 class SwaTestModel(BoringModel):

-    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
         self.interval = interval
+        self.iterable_dataset = iterable_dataset

     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}

     def train_dataloader(self):
-        return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+        dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
+        dset = dset_cls(32, 64)
+
+        return DataLoader(dset, batch_size=2)

     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
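The import change above pulls RandomIterableDataset from tests.helpers.boring_model to back the new iterable_dataset=True path. A minimal sketch of what such a helper looks like, assuming it mirrors RandomDataset but as a torch IterableDataset (the real class lives in tests/helpers/boring_model.py):

import torch
from torch.utils.data import IterableDataset

class RandomIterableDataset(IterableDataset):
    """Yield `count` random tensors of shape (size,); deliberately has no __len__."""

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)

The property the tests care about is that len() on a DataLoader wrapping this dataset raises TypeError, forcing the SWA callback down the code path where the number of batches is unknown.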
@@ -107,8 +112,10 @@ def on_train_end(self, trainer, pl_module):


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
-    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
+def train_with_swa(
+    tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False
+):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
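The remainder of train_with_swa is untouched by this diff. For orientation, a helper like this typically ends by wiring the callback into a Trainer and fitting; the following is a sketch under that assumption, not the verbatim elided body:

trainer = Trainer(
    default_root_dir=tmpdir,
    callbacks=[swa_callback],
    max_epochs=max_epochs,
    limit_train_batches=5,  # illustrative value; the real limit is not shown in this hunk
    accelerator=accelerator,
    gpus=gpus,
    num_processes=num_processes,
)
trainer.fit(model)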
@@ -154,8 +161,9 @@ def test_swa_callback_1_gpu(tmpdir):


 @RunIf(min_torch="1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
-def test_swa_callback(tmpdir, batchnorm: bool):
-    train_with_swa(tmpdir, batchnorm=batchnorm)
+@pytest.mark.parametrize("iterable_dataset", (True, False))
+def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
+    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)


 @RunIf(min_torch="1.6.0")
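Parametrizing test_swa_callback over iterable_dataset matters because StochasticWeightAveraging must cope with loaders whose length is unknowable: a DataLoader over an IterableDataset raises TypeError on len(). A sketch of the kind of guard such code needs (the function name here is illustrative, not the callback's actual internals):

from torch.utils.data import DataLoader

def batches_in(loader: DataLoader):
    # Loaders backed by an IterableDataset raise TypeError on len();
    # callers must then count batches as they stream by instead.
    try:
        return len(loader)
    except TypeError:
        return None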