-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[BUG] Check environ before selecting a seed to prevent warning message #4743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dd26ea6
73badb4
4528af9
f75d7f7
3c87893
3b92392
94efd3a
7745541
60611d9
a5e788f
038d413
dbcc9c3
6fe921b
7e74e10
34b6570
93c231c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,8 +20,8 @@ | |||||||||||
|
|
||||||||||||
| import numpy as np | ||||||||||||
| import torch | ||||||||||||
|
|
||||||||||||
| from pytorch_lightning import _logger as log | ||||||||||||
| from pytorch_lightning.utilities import rank_zero_warn | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def seed_everything(seed: Optional[int] = None) -> int: | ||||||||||||
|
|
@@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int: | |||||||||||
|
|
||||||||||||
| try: | ||||||||||||
| if seed is None: | ||||||||||||
| seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value)) | ||||||||||||
| seed = os.environ.get("PL_GLOBAL_SEED") | ||||||||||||
| seed = int(seed) | ||||||||||||
| except (TypeError, ValueError): | ||||||||||||
| seed = _select_seed_randomly(min_seed_value, max_seed_value) | ||||||||||||
| rank_zero_warn(f"No correct seed found, seed set to {seed}") | ||||||||||||
|
Comment on lines
47
to
+48
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| if (seed > max_seed_value) or (seed < min_seed_value): | ||||||||||||
| log.warning( | ||||||||||||
| f"{seed} is not in bounds, \ | ||||||||||||
| numpy accepts from {min_seed_value} to {max_seed_value}" | ||||||||||||
| ) | ||||||||||||
| if not (min_seed_value <= seed <= max_seed_value): | ||||||||||||
| rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | ||||||||||||
| seed = _select_seed_randomly(min_seed_value, max_seed_value) | ||||||||||||
|
|
||||||||||||
| log.info(f"Global seed set to {seed}") | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will log on each rank, is this desired?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was going back and forth with myself about this, I think this is a good idea to ensure that the seed is set correctly on all processes, but I can be persuaded otherwise
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine :) |
||||||||||||
| os.environ["PL_GLOBAL_SEED"] = str(seed) | ||||||||||||
SeanNaren marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
| random.seed(seed) | ||||||||||||
| np.random.seed(seed) | ||||||||||||
|
|
@@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int: | |||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: | ||||||||||||
| seed = random.randint(min_seed_value, max_seed_value) | ||||||||||||
| log.warning(f"No correct seed found, seed set to {seed}") | ||||||||||||
| return seed | ||||||||||||
| return random.randint(min_seed_value, max_seed_value) | ||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import os | ||
|
|
||
| from unittest import mock | ||
| import pytest | ||
|
|
||
| import pytorch_lightning.utilities.seed as seed_utils | ||
|
|
||
|
|
||
| @mock.patch.dict(os.environ, {}, clear=True) | ||
| def test_seed_stays_same_with_multiple_seed_everything_calls(): | ||
SeanNaren marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Ensure that after the initial seed everything, | ||
| the seed stays the same for the same run. | ||
| """ | ||
| with pytest.warns(UserWarning, match="No correct seed found"): | ||
| seed_utils.seed_everything() | ||
| initial_seed = os.environ.get("PL_GLOBAL_SEED") | ||
|
|
||
| with pytest.warns(None) as record: | ||
| seed_utils.seed_everything() | ||
| assert not record # does not warn | ||
| seed = os.environ.get("PL_GLOBAL_SEED") | ||
|
|
||
| assert initial_seed == seed | ||
SeanNaren marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) | ||
| def test_correct_seed_with_environment_variable(): | ||
| """ | ||
| Ensure that the PL_GLOBAL_SEED environment is read | ||
| """ | ||
| assert seed_utils.seed_everything() == 2020 | ||
|
|
||
|
|
||
| @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) | ||
| @mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) | ||
| def test_invalid_seed(): | ||
| """ | ||
| Ensure that we still fix the seed even if an invalid seed is given | ||
| """ | ||
| with pytest.warns(UserWarning, match="No correct seed found"): | ||
| seed = seed_utils.seed_everything() | ||
| assert seed == 123 | ||
|
|
||
|
|
||
| @mock.patch.dict(os.environ, {}, clear=True) | ||
| @mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) | ||
| @pytest.mark.parametrize("seed", (10e9, -10e9)) | ||
| def test_out_of_bounds_seed(seed): | ||
| """ | ||
| Ensure that we still fix the seed even if an out-of-bounds seed is given | ||
| """ | ||
| with pytest.warns(UserWarning, match="is not in bounds"): | ||
| actual = seed_utils.seed_everything(seed) | ||
| assert actual == 123 | ||
Uh oh!
There was an error while loading. Please reload this page.