From dd26ea66522209769ab19c7cdf062445339f5dcc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 18 Nov 2020 16:20:02 +0000 Subject: [PATCH 1/9] Check environment var independently to selecting a seed to prevent unnecessary warning message --- pytorch_lightning/utilities/seed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 1ce782f967ebb..6ccf353a61405 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -40,8 +40,9 @@ def seed_everything(seed: Optional[int] = None) -> int: min_seed_value = np.iinfo(np.uint32).min try: + seed = os.environ.get("PL_GLOBAL_SEED") if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value)) + seed = _select_seed_randomly(min_seed_value, max_seed_value) seed = int(seed) except (TypeError, ValueError): seed = _select_seed_randomly(min_seed_value, max_seed_value) From 73badb478a0c47ce571d7d6d1f196886cacee357 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 18 Nov 2020 16:52:04 +0000 Subject: [PATCH 2/9] Add if statement to check if PL_GLOBAL_SEED has been set --- pytorch_lightning/utilities/seed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 6ccf353a61405..29c093bb7e389 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -40,9 +40,11 @@ def seed_everything(seed: Optional[int] = None) -> int: min_seed_value = np.iinfo(np.uint32).min try: - seed = os.environ.get("PL_GLOBAL_SEED") if seed is None: - seed = _select_seed_randomly(min_seed_value, max_seed_value) + if "PL_GLOBAL_SEED" in os.environ: + seed = os.environ["PL_GLOBAL_SEED"] + else: + seed = _select_seed_randomly(min_seed_value, max_seed_value) seed = int(seed) except (TypeError, ValueError): seed = _select_seed_randomly(min_seed_value, max_seed_value) From 4528af94ba7653289118bd6ca675293f6a45198d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 18 Nov 2020 17:39:33 +0000 Subject: [PATCH 3/9] Added seed test to ensure that the seed stays the same, in case --- tests/utilities/test_seed.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 tests/utilities/test_seed.py diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py new file mode 100644 index 0000000000000..02ea8930935dc --- /dev/null +++ b/tests/utilities/test_seed.py @@ -0,0 +1,17 @@ +import os + +import pytorch_lightning as pl + + +def test_seed_stays_same_with_multiple_seed_everything_calls(): + """ + Test to ensure that after the initial seed everything, the seed stays the same for the same run. + """ + + pl.utilities.seed.seed_everything() + initial_seed = os.environ.get('PL_GLOBAL_SEED') + + pl.utilities.seed.seed_everything() + seed = os.environ.get('PL_GLOBAL_SEED') + + assert initial_seed == seed From f75d7f70a800900c778d01323f463c8137d09110 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 18 Nov 2020 20:17:40 +0100 Subject: [PATCH 4/9] if --- pytorch_lightning/utilities/seed.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 29c093bb7e389..bfcd450ded2dd 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -49,10 +49,9 @@ def seed_everything(seed: Optional[int] = None) -> int: except (TypeError, ValueError): seed = _select_seed_randomly(min_seed_value, max_seed_value) - if (seed > max_seed_value) or (seed < min_seed_value): + if not (seed < min_seed_value <= seed <= max_seed_value): log.warning( - f"{seed} is not in bounds, \ - numpy accepts from {min_seed_value} to {max_seed_value}" + f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}" ) seed = _select_seed_randomly(min_seed_value, max_seed_value) From 3c878938c455d0d4d7e038d68a5a1fb229d01c11 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 18 Nov 2020 21:38:17 +0000 Subject: [PATCH 5/9] Delete global seed after test has finished --- tests/utilities/test_seed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index 02ea8930935dc..d3344e3a58fa3 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -15,3 +15,5 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): seed = os.environ.get('PL_GLOBAL_SEED') assert initial_seed == seed + + del os.environ['PL_GLOBAL_SEED'] From 3b92392ac8c761fb62f39c03e49ad04dc8434170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 19 Nov 2020 01:23:09 +0100 Subject: [PATCH 6/9] Fix code, add tests --- pytorch_lightning/utilities/seed.py | 18 ++++------ tests/utilities/test_seed.py | 55 +++++++++++++++++++++++++---- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index bfcd450ded2dd..b495f1399a581 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -21,7 +21,7 @@ import numpy as np import torch -from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import rank_zero_warn def seed_everything(seed: Optional[int] = None) -> int: @@ -41,18 +41,14 @@ def seed_everything(seed: Optional[int] = None) -> int: try: if seed is None: - if "PL_GLOBAL_SEED" in os.environ: - seed = os.environ["PL_GLOBAL_SEED"] - else: - seed = _select_seed_randomly(min_seed_value, max_seed_value) + seed = os.environ.get("PL_GLOBAL_SEED") seed = int(seed) except (TypeError, ValueError): seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No correct seed found, seed set to {seed}") - if not (seed < min_seed_value <= seed <= max_seed_value): - log.warning( - f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}" - ) + if not (min_seed_value <= seed <= max_seed_value): + rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = _select_seed_randomly(min_seed_value, max_seed_value) os.environ["PL_GLOBAL_SEED"] = str(seed) @@ -64,6 +60,4 @@ def seed_everything(seed: Optional[int] = None) -> int: def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: - seed = random.randint(min_seed_value, max_seed_value) - log.warning(f"No correct seed found, seed set to {seed}") - return seed + return random.randint(min_seed_value, max_seed_value) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index d3344e3a58fa3..51571d0243ad8 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -1,19 +1,60 @@ import os -import pytorch_lightning as pl +import pytest + +import pytorch_lightning.utilities.seed as seed_utils def test_seed_stays_same_with_multiple_seed_everything_calls(): """ - Test to ensure that after the initial seed everything, the seed stays the same for the same run. + Ensure that after the initial seed everything, + the seed stays the same for the same run. """ - pl.utilities.seed.seed_everything() - initial_seed = os.environ.get('PL_GLOBAL_SEED') + with pytest.warns(UserWarning, match="No correct seed found"): + seed_utils.seed_everything() + initial_seed = os.environ.get("PL_GLOBAL_SEED") - pl.utilities.seed.seed_everything() - seed = os.environ.get('PL_GLOBAL_SEED') + with pytest.warns(None) as record: + seed_utils.seed_everything() + assert not record # does not warn + seed = os.environ.get("PL_GLOBAL_SEED") assert initial_seed == seed + del os.environ["PL_GLOBAL_SEED"] + + +def test_correct_seed_with_environment_variable(monkeypatch): + """ + Ensure that the PL_GLOBAL_SEED environment is read + """ + expected = 2020 + monkeypatch.setenv("PL_GLOBAL_SEED", str(expected)) + assert seed_utils.seed_everything() == expected + del os.environ["PL_GLOBAL_SEED"] + - del os.environ['PL_GLOBAL_SEED'] +def test_invalid_seed(monkeypatch): + """ + Ensure that we still fix the seed even if an invalid seed is given + """ + expected = 123 + monkeypatch.setenv("PL_GLOBAL_SEED", "invalid") + monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) + with pytest.warns(UserWarning, match="No correct seed found"): + seed = seed_utils.seed_everything() + assert seed == expected + del os.environ["PL_GLOBAL_SEED"] + + +@pytest.mark.parametrize("seed", (10e9, -10e9)) +def test_out_of_bounds_seed(monkeypatch, seed): + """ + Ensure that we still fix the seed even if an out-of-bounds seed is given + """ + expected = 123 + monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) + with pytest.warns(UserWarning, match="is not in bounds"): + actual = seed_utils.seed_everything(seed) + assert actual == expected + del os.environ["PL_GLOBAL_SEED"] From 7745541380e25b43f702510418227fd761076ece Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 20 Nov 2020 17:00:49 +0000 Subject: [PATCH 7/9] Ensure seed does not exist before tests start --- tests/utilities/test_seed.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index 51571d0243ad8..0e7efc5172dd0 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -10,6 +10,8 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): Ensure that after the initial seed everything, the seed stays the same for the same run. """ + if "PL_GLOBAL_SEED" in os.environ: + del os.environ["PL_GLOBAL_SEED"] with pytest.warns(UserWarning, match="No correct seed found"): seed_utils.seed_everything() @@ -28,6 +30,8 @@ def test_correct_seed_with_environment_variable(monkeypatch): """ Ensure that the PL_GLOBAL_SEED environment is read """ + if "PL_GLOBAL_SEED" in os.environ: + del os.environ["PL_GLOBAL_SEED"] expected = 2020 monkeypatch.setenv("PL_GLOBAL_SEED", str(expected)) assert seed_utils.seed_everything() == expected @@ -38,6 +42,8 @@ def test_invalid_seed(monkeypatch): """ Ensure that we still fix the seed even if an invalid seed is given """ + if "PL_GLOBAL_SEED" in os.environ: + del os.environ["PL_GLOBAL_SEED"] expected = 123 monkeypatch.setenv("PL_GLOBAL_SEED", "invalid") monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) @@ -52,6 +58,8 @@ def test_out_of_bounds_seed(monkeypatch, seed): """ Ensure that we still fix the seed even if an out-of-bounds seed is given """ + if "PL_GLOBAL_SEED" in os.environ: + del os.environ["PL_GLOBAL_SEED"] expected = 123 monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) with pytest.warns(UserWarning, match="is not in bounds"): From 038d41396983c0767872e5739204b82e6458ce94 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 7 Jan 2021 17:00:52 +0000 Subject: [PATCH 8/9] Refactor test based on review, add log call --- pytorch_lightning/utilities/seed.py | 3 ++- tests/utilities/test_seed.py | 39 ++++++++++------------------- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index b495f1399a581..16bc39bd7f142 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -20,7 +20,7 @@ import numpy as np import torch - +from pytorch_lightning import _logger as log from pytorch_lightning.utilities import rank_zero_warn @@ -51,6 +51,7 @@ def seed_everything(seed: Optional[int] = None) -> int: rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = _select_seed_randomly(min_seed_value, max_seed_value) + log.info(f"Global seed set to {seed}") os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) np.random.seed(seed) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index 0e7efc5172dd0..7787757a91c68 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -1,18 +1,17 @@ import os +from unittest import mock import pytest import pytorch_lightning.utilities.seed as seed_utils +@mock.patch.dict(os.environ, {}) def test_seed_stays_same_with_multiple_seed_everything_calls(): """ Ensure that after the initial seed everything, the seed stays the same for the same run. """ - if "PL_GLOBAL_SEED" in os.environ: - del os.environ["PL_GLOBAL_SEED"] - with pytest.warns(UserWarning, match="No correct seed found"): seed_utils.seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") @@ -23,46 +22,34 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): seed = os.environ.get("PL_GLOBAL_SEED") assert initial_seed == seed - del os.environ["PL_GLOBAL_SEED"] -def test_correct_seed_with_environment_variable(monkeypatch): +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}) +def test_correct_seed_with_environment_variable(): """ Ensure that the PL_GLOBAL_SEED environment is read """ - if "PL_GLOBAL_SEED" in os.environ: - del os.environ["PL_GLOBAL_SEED"] - expected = 2020 - monkeypatch.setenv("PL_GLOBAL_SEED", str(expected)) - assert seed_utils.seed_everything() == expected - del os.environ["PL_GLOBAL_SEED"] + assert seed_utils.seed_everything() == 2020 -def test_invalid_seed(monkeypatch): +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) +def test_invalid_seed(): """ Ensure that we still fix the seed even if an invalid seed is given """ - if "PL_GLOBAL_SEED" in os.environ: - del os.environ["PL_GLOBAL_SEED"] - expected = 123 - monkeypatch.setenv("PL_GLOBAL_SEED", "invalid") - monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) with pytest.warns(UserWarning, match="No correct seed found"): seed = seed_utils.seed_everything() - assert seed == expected - del os.environ["PL_GLOBAL_SEED"] + assert seed == 123 +@mock.patch.dict(os.environ, {}) +@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) @pytest.mark.parametrize("seed", (10e9, -10e9)) -def test_out_of_bounds_seed(monkeypatch, seed): +def test_out_of_bounds_seed(seed): """ Ensure that we still fix the seed even if an out-of-bounds seed is given """ - if "PL_GLOBAL_SEED" in os.environ: - del os.environ["PL_GLOBAL_SEED"] - expected = 123 - monkeypatch.setattr(seed_utils, "_select_seed_randomly", lambda *_: expected) with pytest.warns(UserWarning, match="is not in bounds"): actual = seed_utils.seed_everything(seed) - assert actual == expected - del os.environ["PL_GLOBAL_SEED"] + assert actual == 123 From 34b6570fe80e9f077db9250f3417eaf411615ea5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 11 Jan 2021 13:39:21 +0000 Subject: [PATCH 9/9] Ensure we clear the os environ in patched dict --- tests/utilities/test_seed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index 7787757a91c68..7fa6df516c304 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -6,7 +6,7 @@ import pytorch_lightning.utilities.seed as seed_utils -@mock.patch.dict(os.environ, {}) +@mock.patch.dict(os.environ, {}, clear=True) def test_seed_stays_same_with_multiple_seed_everything_calls(): """ Ensure that after the initial seed everything, @@ -24,7 +24,7 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): assert initial_seed == seed -@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}) +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) def test_correct_seed_with_environment_variable(): """ Ensure that the PL_GLOBAL_SEED environment is read @@ -32,7 +32,7 @@ def test_correct_seed_with_environment_variable(): assert seed_utils.seed_everything() == 2020 -@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}) +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) @mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) def test_invalid_seed(): """ @@ -43,7 +43,7 @@ def test_invalid_seed(): assert seed == 123 -@mock.patch.dict(os.environ, {}) +@mock.patch.dict(os.environ, {}, clear=True) @mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123) @pytest.mark.parametrize("seed", (10e9, -10e9)) def test_out_of_bounds_seed(seed):