-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
test_supporters.py
416 lines (352 loc) · 15.1 KB
/
test_supporters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from typing import Sequence
from unittest import mock
import pytest
import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.trainer.supporters import (
_CombinedDataset,
_MaxSizeCycle,
_MinSize,
_Sequential,
_supported_modes,
CombinedLoader,
)
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.parametrize(
    ("dataset_1", "dataset_2"),
    [
        [list(range(10)), list(range(20))],
        [range(10), range(20)],
        [torch.randn(10, 3, 2), torch.randn(20, 5, 6)],
        [TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))],
    ],
)
def test_combined_dataset(dataset_1, dataset_2):
    """Check the reported length of a `_CombinedDataset` built from a 10-item and a 20-item dataset.

    'max_size_cycle' follows the longer dataset (20), 'min_size' the shorter one (10).
    """
    pair = [dataset_1, dataset_2]
    expected_by_mode = {"max_size_cycle": 20, "min_size": 10}
    for mode, expected_len in expected_by_mode.items():
        assert len(_CombinedDataset(pair, mode)) == expected_len
def test_combined_dataset_length_mode_error():
    """An unknown mode string is rejected with a `ValueError` when the dataset is constructed."""
    bad_mode = "test"
    with pytest.raises(ValueError, match="Unsupported mode 'test'"):
        _CombinedDataset([], mode=bad_mode)
def test_combined_dataset_no_length():
    """Check `len()` of a `_CombinedDataset` whose members do not all report an integer length.

    The first case expects only the map-style member's length (5) to count. The second case expects
    `NotImplementedError` when the only member is an iterable-style dataset instance.
    """

    class Foo:
        # map-style: reports a fixed length
        def __len__(self):
            return 5

    class Bar:
        # iterable style: no `__len__` at all
        ...

    class Baz:
        # `__len__` exists but returns None
        def __len__(self):
            pass

    cd = _CombinedDataset([Foo(), Bar(), Baz()])
    assert len(cd) == 5

    # use an *instance* of the iterable-style dataset, not the class object itself
    cd = _CombinedDataset(Bar())
    with pytest.raises(NotImplementedError, match="All datasets are iterable-style"):
        len(cd)
def test_combined_loader_modes():
    """Test `CombinedLoader` length and iteration in every mode, given dict and then list iterables."""
    iterables = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }
    lengths = [len(v) for v in iterables.values()]

    # min_size with dict
    min_len = min(lengths)
    combined_loader = CombinedLoader(iterables, "min_size")
    assert combined_loader._iterator is None
    assert len(combined_loader) == min_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with dict
    max_len = max(lengths)
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    assert combined_loader._iterator is None
    assert len(combined_loader) == max_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with dict
    sum_len = sum(lengths)
    combined_loader = CombinedLoader(iterables, "sequential")
    assert combined_loader._iterator is None
    assert len(combined_loader) == sum_len
    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(batch_idx, int)
        assert isinstance(item, Tensor)
    # `batch_idx` restarts for each iterable, so after the loop it is the last index of the last iterable.
    # (Checking the stale `idx` from the previous loop would only pass by coincidence.)
    assert batch_idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1
    assert dataloader_idx == len(iterables) - 1

    iterables = list(iterables.values())

    # min_size with list
    combined_loader = CombinedLoader(iterables, "min_size")
    assert combined_loader._iterator is None
    assert len(combined_loader) == min_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with list
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    assert combined_loader._iterator is None
    assert len(combined_loader) == max_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with list
    combined_loader = CombinedLoader(iterables, "sequential")
    assert combined_loader._iterator is None
    assert len(combined_loader) == sum_len
    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(batch_idx, int)
        assert isinstance(item, Tensor)
    assert batch_idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1
    assert dataloader_idx == len(iterables) - 1
def test_combined_loader_raises():
    """An unknown mode is rejected on construction, and `len()` fails for an iterable without a length."""
    with pytest.raises(ValueError, match="Unsupported mode 'testtt'"):
        CombinedLoader([range(10)], "testtt")

    loader_over_none = CombinedLoader(None, "max_size_cycle")
    with pytest.raises(NotImplementedError, match="NoneType` does not define `__len__"):
        len(loader_over_none)
class TestIterableDataset(IterableDataset):
    """Iterable-style dataset yielding ``0 .. size - 1`` in order; every ``iter()`` call starts over."""

    # The name starts with "Test" and the class defines ``__init__``, so pytest would try to collect it
    # and emit a ``PytestCollectionWarning``. Opt out of collection explicitly.
    __test__ = False

    def __init__(self, size: int = 10):
        self.size = size

    def __iter__(self):
        # a fresh sampler iterator per `iter()` call lets the dataset be iterated more than once
        self.sampler = SequentialSampler(range(self.size))
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return next(self.sampler_iter)
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "sequential"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
    """Test `CombinedLoader` in every supported mode given a sequence of DataLoaders over iterable-style datasets."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]
    combined_loader = CombinedLoader(loaders, mode)

    has_break = False
    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        if not use_multiple_dataloaders and idx == 4:
            # stop on the last batch without letting the iterator hit its end
            has_break = True
            break

    if mode == "max_size_cycle":
        # the test expects every `_consumed` flag to be set only when the loop ran to exhaustion
        assert all(combined_loader._iterator._consumed) == (not has_break)

    # 10 samples / batch_size 2 -> 5 batches; 20 samples / batch_size 2 -> 10 batches
    expected = 5
    if use_multiple_dataloaders:
        if mode == "max_size_cycle":
            expected = 10
        elif mode == "sequential":
            expected = 15
    assert idx == expected - 1
@pytest.mark.parametrize(
    ("limits", "expected"),
    [
        (None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]),
        ([1, 0], [("a", 0, 0)]),
        ([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]),
        ([1, 1], [("a", 0, 0), ("d", 0, 1)]),
    ],
)
def test_sequential_mode_limits(limits, expected):
    """`_Sequential` caps each iterable at its limit and yields `(item, index_in_iterable, iterable_index)`."""
    first, second = ["a", "b", "c"], ["d", "e"]
    assert list(_Sequential([first, second], limits)) == expected
def test_sequential_mode_limits_raises():
    """A limits list whose length differs from the number of iterables is rejected."""
    two_iterables, no_limits = [0, 1], []
    with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"):
        _Sequential(two_iterables, no_limits)
@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    """In 'max_size_cycle' mode, an iterable-style plus a map-style loader yield as many batches as the longer one."""

    class _CountingIterable(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class _IndexMap(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    iterable_len, map_len = lengths
    combined = CombinedLoader(
        [DataLoader(_CountingIterable(iterable_len)), DataLoader(_IndexMap(map_len))],
        mode="max_size_cycle",
    )
    n_batches = 0
    for _ in combined:
        n_batches += 1
    assert n_batches == max(iterable_len, map_len)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_validation_test(mps_count_0, cuda_count_2, replace_sampler_ddp):
    """Check that `_prepare_dataloader` swaps in distributed samplers for every loader nested inside a
    CombinedLoader, or leaves the original samplers in place when sampler replacement is disabled."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    def plain_loader():
        return DataLoader(CustomDataset(range(10)))

    sampled_dataset = CustomDataset(range(10))
    combined = CombinedLoader(
        {
            "a": plain_loader(),
            "b": DataLoader(sampled_dataset, sampler=CustomSampler(sampled_dataset, "custom_sampler")),
            "c": {"c": plain_loader(), "d": plain_loader()},
            "d": [plain_loader(), plain_loader()],
        }
    )

    trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp, strategy="ddp", accelerator="gpu", devices=2)
    prepared = trainer._data_connector._prepare_dataloader(combined, shuffle=True)

    samplers = tree_flatten(prepared.sampler)[0]
    assert len(samplers) == 6
    allowed_sampler_types = DistributedSampler if replace_sampler_ddp else (SequentialSampler, CustomSampler)
    assert all(isinstance(s, allowed_sampler_types) for s in samplers)

    datasets = tree_flatten(prepared.dataset.datasets)[0]
    assert len(datasets) == 6
    assert all(isinstance(ds, CustomDataset) for ds in datasets)
@pytest.mark.parametrize("accelerator", ["cpu", pytest.param("gpu", marks=RunIf(min_cuda_gpus=2))])
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
    with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    # The conditional must be parenthesized: `assert x == a if c else b` parses as
    # `assert (x == a) if c else b`, which asserts the truthy constant `b` whenever `c` is False.
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )
        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    # an infinite iterable-style loader has no length, so neither does the combination
    with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
        len(dataloader)
    assert len(dataloader.iterables["b"]) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.iterables["b"]) == (4 if replace_sampler_ddp else 8)
    with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
        len(dataloader)
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
@pytest.mark.parametrize("mode", ("min_size", "max_size_cycle", "sequential"))
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader):
    """When a CombinedLoader (or a plain dict of loaders) is provided as training data, it should correctly
    receive the distributed samplers."""
    dim, n1, n2 = 3, 8, 6
    train_data = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if use_combined_loader:
        train_data = CombinedLoader(train_data, mode=mode)

    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices="auto",
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode=mode,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=train_data, val_dataloaders=None, datamodule=None
    )

    length_fn = _supported_modes[mode]["fn"]
    expected_batches = length_fn([n1, n2])
    if replace_sampler_ddp:
        # expected count once the batches are split across devices, rounded up
        expected_batches = math.ceil(expected_batches / trainer.num_devices)

    trainer.reset_train_dataloader(model=model)
    train_loader = trainer.train_dataloader
    assert train_loader is not None
    assert isinstance(train_loader, CombinedLoader)
    assert train_loader._mode == mode
    assert trainer.num_training_batches == expected_batches