diff --git a/tests/parity_fabric/models.py b/tests/parity_fabric/models.py index 4887a4c7f7dba..be6345c158c9a 100644 --- a/tests/parity_fabric/models.py +++ b/tests/parity_fabric/models.py @@ -76,6 +76,7 @@ def get_dataloader(self): dataset, batch_size=self.batch_size, num_workers=2, + persistent_workers=True, ) def get_loss_function(self): diff --git a/tests/parity_pytorch/models.py b/tests/parity_pytorch/models.py index f55b0d6f1f36e..657db3e512bb2 100644 --- a/tests/parity_pytorch/models.py +++ b/tests/parity_pytorch/models.py @@ -59,4 +59,5 @@ def train_dataloader(self): CIFAR10(root=_PATH_DATASETS, train=True, download=True, transform=self.transform), batch_size=32, num_workers=1, + persistent_workers=True, ) diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 656b9cac3d77e..89f1e4114254f 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -638,9 +638,9 @@ def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size # The dataloader runs a check in `DataLoader.check_worker_number_rationality` with pytest.warns(UserWarning, match="This DataLoader will create"): - DataLoader(range(2), num_workers=(cpu_count + 1)) + DataLoader(range(2), num_workers=(cpu_count + 1), persistent_workers=True) with no_warning_call(): - DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size)) + DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size), persistent_workers=True) def test_state(): diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 02604f5a195fe..3a3ea69fa9701 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -83,7 +83,9 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path): DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() - dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers) + dataloader = DataLoader( + RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=num_workers > 0 + ) model = BoringModel() writer = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer) diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index 4fecf516018c1..c1830cb121eb3 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ b/tests/tests_pytorch/helpers/advanced_models.py @@ -218,4 +218,9 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.02) def train_dataloader(self): - return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1) + return DataLoader( + MNIST(root=_PATH_DATASETS, train=True, download=True), + batch_size=128, + num_workers=1, + persistent_workers=True, + ) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index a820a3d6ee786..54d236b46c000 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -135,7 +135,7 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path barebones=True, ) model = TestSpawnBoringModel(warning_expected=(num_workers > 0)) - dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers) + dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=num_workers > 0) trainer.fit(model, dataloader) @@ -252,7 +252,9 @@ def test_update_dataloader_with_multiprocessing_context(): """This test verifies that `use_distributed_sampler` conserves multiprocessing context.""" train = RandomDataset(32, 64) context = "spawn" - train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) + train = DataLoader( + train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True, persistent_workers=True + ) new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset)) assert new_data_loader.multiprocessing_context == train.multiprocessing_context diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index a2d29baa9fa6f..31ff3dd56c65a 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -658,10 +658,7 @@ def on_train_epoch_end(self): def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch): """Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training.""" dataset = NumpyRandomDataset() - num_workers = 2 - batch_size = 2 - - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + dataloader = DataLoader(dataset, batch_size=2, num_workers=2, persistent_workers=True) seed_everything(0, workers=True) trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn") model = MultiProcessModel()