Auto-set DataLoader.worker_init_fn with seed_everything #6960

Merged · 41 commits · Apr 19, 2021

Changes shown are from 22 of the 41 commits.

Commits
18e61ee
example
awaelchli Apr 11, 2021
588f581
auto add worker fn
awaelchli Apr 11, 2021
bf92793
Revert "example"
awaelchli Apr 11, 2021
49f3200
revert
awaelchli Apr 11, 2021
74893bd
typo
awaelchli Apr 11, 2021
49c1a64
flake
awaelchli Apr 11, 2021
50d35ad
add worker_id
awaelchli Apr 11, 2021
873584e
typo
awaelchli Apr 11, 2021
b71d16c
workers argument in seed_everyting
awaelchli Apr 11, 2021
23f28f3
add global rank for worker_init_fn
awaelchli Apr 11, 2021
59416c9
include torch.manual_seed
awaelchli Apr 11, 2021
3b247b8
fix env var access
awaelchli Apr 12, 2021
5ed6791
ddp test
awaelchli Apr 12, 2021
08e5524
suggestion by r.kern
awaelchli Apr 14, 2021
b7828ec
incorporate seedsequence for torch and stdlib seed
awaelchli Apr 14, 2021
f97997f
Merge branch 'master' into feature/worker_init_fn
awaelchli Apr 14, 2021
b3af3d1
strict assert
awaelchli Apr 15, 2021
e4cda36
strict assert
awaelchli Apr 15, 2021
2f461e7
remove print statement
awaelchli Apr 15, 2021
398ddef
fix global rank issues
awaelchli Apr 15, 2021
becb26b
unused import
awaelchli Apr 15, 2021
c9bfc55
Merge branch 'master' into feature/worker_init_fn
awaelchli Apr 15, 2021
a999b95
ignore coverage for worker function
awaelchli Apr 15, 2021
896bed7
Update tests/trainer/test_dataloaders.py
awaelchli Apr 15, 2021
5b656aa
Update pytorch_lightning/trainer/data_loading.py
awaelchli Apr 15, 2021
69deaf8
Update pytorch_lightning/trainer/data_loading.py
awaelchli Apr 15, 2021
60e0669
update req
awaelchli Apr 15, 2021
2eb979e
Merge remote-tracking branch 'origin/feature/worker_init_fn' into fea…
awaelchli Apr 15, 2021
9e0ac38
32-bit seeding for pytorch prior 1.7
awaelchli Apr 16, 2021
e56d1b0
flake8
awaelchli Apr 16, 2021
e8844b8
test duplicates
awaelchli Apr 16, 2021
cc9f818
update docs
awaelchli Apr 16, 2021
4d6d56b
Update requirements.txt
carmocca Apr 16, 2021
0a62d55
Update requirements.txt
carmocca Apr 16, 2021
1fa020f
Update requirements.txt
carmocca Apr 16, 2021
2124ef5
Update requirements.txt
carmocca Apr 16, 2021
4456bf8
change default
awaelchli Apr 16, 2021
708581a
remove sanity test
awaelchli Apr 16, 2021
fd59239
unused import
awaelchli Apr 16, 2021
99799a0
Revert "remove sanity test"
awaelchli Apr 16, 2021
8514821
better sanity check
awaelchli Apr 16, 2021
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -13,8 +13,11 @@
# limitations under the License.
import inspect
import multiprocessing
import os
from abc import ABC
from copy import deepcopy

from functools import partial
from typing import Iterable, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -30,6 +33,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function


class TrainerDataLoadingMixin(ABC):
@@ -100,6 +104,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
f' in the `DataLoader` init to improve performance.'
)

def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", "0")):
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

# don't do anything if it's not a dataloader
@@ -231,6 +239,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

@@ -329,6 +340,9 @@ def _reset_eval_dataloader(
# add samplers
dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

loader_num_batches = []

# determine number of batches
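For context, here is a minimal usage sketch of the behaviour these hunks add. It is not part of the diff: the dataset and the commented-out `fit` call are placeholders, while `seed_everything(..., workers=True)`, `Trainer`, and the automatic attachment of `pl_worker_init_function` come from the changes above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer, seed_everything

# Sets PL_GLOBAL_SEED and, with workers=True, PL_SEED_WORKERS=1.
seed_everything(42, workers=True)

# A DataLoader without an explicit worker_init_fn (placeholder data).
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=8, num_workers=4)

trainer = Trainer(max_epochs=1)
# During reset_train_dataloader / _reset_eval_dataloader, the trainer calls
# auto_add_worker_init_fn on every DataLoader, attaching
# partial(pl_worker_init_function, rank=trainer.global_rank).
# trainer.fit(model, train_dataloader=train_loader)  # model omitted in this sketch
```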
40 changes: 37 additions & 3 deletions pytorch_lightning/utilities/seed.py
@@ -20,23 +20,30 @@

import numpy as np
import torch
from pytorch_lightning.utilities.distributed import rank_zero_only

from pytorch_lightning.utilities import rank_zero_warn

log = logging.getLogger(__name__)


def seed_everything(seed: Optional[int] = None) -> int:
def seed_everything(seed: Optional[int] = None, workers: bool = True) -> int:
"""
Function that sets seed for pseudo-random number generators in:
pytorch, numpy, python.random
In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
spawned subprocesses (e.g. ddp_spawn backend).
In addition, sets the following environment variables:

- `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
- `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.

Args:
seed: the integer value seed for global random state in Lightning.
If `None`, will read seed from `PL_GLOBAL_SEED` env variable
or select it randomly.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
"""
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
@@ -61,8 +68,35 @@ def seed_everything(seed: Optional[int] = None) -> int:
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

return seed


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
return random.randint(min_seed_value, max_seed_value)


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
"""
The worker_init_fn that Lightning automatically adds to your dataloader if you previously
set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`.
See also the PyTorch documentation on
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
"""
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
global_rank = rank if rank is not None else rank_zero_only.rank
process_seed = torch.initial_seed()
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
torch_ss, stdlib_ss = ss.spawn(2)
# PyTorch takes a 64-bit seed
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
# use 128 bits expressed as an integer
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
random.seed(stdlib_seed)
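The SeedSequence scheme above can be exercised on its own. Below is a minimal sketch (plain NumPy, not Lightning code; `worker_seed_state` is a hypothetical helper used only for this illustration) showing that every (worker_id, rank) pair derives a distinct 128-bit NumPy state, and how 64-bit torch and 128-bit stdlib seeds are spawned from the same sequence.

```python
import numpy as np

def worker_seed_state(base_seed: int, worker_id: int, rank: int) -> tuple:
    # Same entropy inputs as pl_worker_init_function: base seed, worker id, global rank.
    ss = np.random.SeedSequence([base_seed, worker_id, rank])
    return tuple(ss.generate_state(4))  # 4 x 32-bit words = 128 bits for NumPy

base_seed = 42
states = {
    (w, r): worker_seed_state(base_seed, w, r)
    for w in range(4)   # worker ids
    for r in range(2)   # global ranks
}
# Every (worker, rank) combination gets its own seed state.
assert len(set(states.values())) == len(states)

# Distinct child sequences for torch and the stdlib, mirroring ss.spawn(2) above.
torch_ss, stdlib_ss = np.random.SeedSequence([base_seed, 0, 0]).spawn(2)
torch_seed = int(torch_ss.generate_state(1, dtype=np.uint64)[0])  # 64-bit torch seed
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
print(torch_seed, stdlib_seed)
```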
7 changes: 3 additions & 4 deletions tests/trainer/test_data_loading.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -72,7 +71,7 @@ def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders


def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
num_processes = 2
limit_test_batches = 2
trainer_args = {
@@ -100,8 +99,8 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,

@RunIf(min_gpus=2, special=True)
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode)


@pytest.mark.parametrize("num_workers", [0, 1])
92 changes: 89 additions & 3 deletions tests/trainer/test_dataloaders.py
@@ -12,21 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from unittest import mock
from unittest.mock import patch
from unittest.mock import patch, Mock

import numpy
import pytest
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.dataset import IterableDataset, Subset, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler

import tests.helpers.pipelines as tpipes
from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Callback, Trainer, seed_everything
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import pl_worker_init_function
from tests.base import EvalModelTemplate
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@@ -634,6 +638,88 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)


def _user_worker_init_fn(_):
pass


def test_auto_add_worker_init_fn():
""" Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """
dataset = Mock()
dataloader = DataLoader(dataset)
trainer = Trainer()

# without pl.seed_everything()
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is None

# with forcefully avoiding it
seed_everything(0, workers=False)
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is None

# when user already has a worker_init_fn
user_function = _user_worker_init_fn
dataloader.worker_init_fn = user_function
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is user_function
dataloader.worker_init_fn = None

# main use case
seed_everything(0, workers=True)
trainer.auto_add_worker_init_fn(dataloader)
assert dataloader.worker_init_fn is not None


class NumpyRandomDataset(Dataset):
size = 16

def __getitem__(self, index):
return numpy.random.randint(0, 100, 3)

def __len__(self):
return self.size


class MultiProcessModel(BoringModel):

def __init__(self):
super().__init__()
self.batches_seen = []

def training_step(self, batch, batch_idx):
self.batches_seen.append(batch)

def training_epoch_end(self, outputs):
world_size = 2
num_samples = NumpyRandomDataset.size
all_batches = torch.cat(self.batches_seen)
all_batches = self.all_gather(all_batches)
assert all_batches.shape[0] == world_size
all_batches = all_batches.view(-1, 3)
assert len(torch.unique(all_batches, dim=0)) == num_samples


@RunIf(min_gpus=2)
def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
""" Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """
dataset = NumpyRandomDataset()
num_workers = 2
batch_size = 2

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
seed_everything(0, workers=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
gpus=2,
accelerator="ddp_spawn",
)
model = MultiProcessModel()
model.train_dataloader = None
model.val_dataloader = None
trainer.fit(model, train_dataloader=dataloader)


def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = BoringModel()
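The distributed test above asserts that no two processes or workers emit the same "random" NumPy batch. For reference, here is a minimal single-node sketch of the pitfall being guarded against (assuming a fork start method, e.g. Linux; `NumpyNoise` is a throwaway dataset invented for this illustration): without a `worker_init_fn`, forked workers inherit identical NumPy state and can produce duplicate samples, while `pl_worker_init_function` reseeds each worker.

```python
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.utilities.seed import pl_worker_init_function

class NumpyNoise(Dataset):
    """Returns "random" NumPy samples; susceptible to the shared-RNG-state pitfall."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        return np.random.randint(0, 100, 3)

if __name__ == "__main__":
    torch.manual_seed(0)

    # Without a worker_init_fn, both workers start from the same NumPy state,
    # so duplicate rows are expected (on fork-based platforms).
    plain = torch.cat(list(DataLoader(NumpyNoise(), batch_size=2, num_workers=2)))

    # With the Lightning worker_init_fn, each worker reseeds NumPy (and torch/stdlib)
    # from its own SeedSequence, so all rows should be distinct.
    seeded_loader = DataLoader(
        NumpyNoise(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=partial(pl_worker_init_function, rank=0),
    )
    seeded = torch.cat(list(seeded_loader))

    print("unique rows without init fn:", len(torch.unique(plain, dim=0)))
    print("unique rows with init fn:   ", len(torch.unique(seeded, dim=0)))
```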