/
seed.py
45 lines (35 loc) · 1.37 KB
/
seed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Helper functions to help with reproducibility of models. """
import os
from typing import Optional, Type
import numpy as np
import random
import torch
from pytorch_lightning import _logger as log
def seed_everything(seed: Optional[int] = None) -> int:
    """Seed the pseudo-random number generators used for reproducibility.

    Seeds python's ``random`` module, numpy's global RNG (``np.random.seed``)
    and torch (``torch.manual_seed``), and exports ``PYTHONHASHSEED``.

    Args:
        seed: The seed to apply. Anything accepted by ``int()`` works
            (e.g. ``7``, ``"7"``, ``7.9`` -> ``7``). If it is ``None``, cannot
            be converted to ``int``, or lies outside the unsigned 32-bit range
            numpy accepts, a random seed from that range is chosen instead
            and a warning is logged.

    Returns:
        The seed that was actually applied.
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if seed is None:
        seed = _select_seed_randomly(min_seed_value, max_seed_value)
    else:
        try:
            seed = int(seed)
        except (TypeError, ValueError, OverflowError):
            # Unconvertible input (e.g. "abc", NaN, infinity): fall back to a
            # random seed rather than failing.
            seed = _select_seed_randomly(min_seed_value, max_seed_value)

    if not min_seed_value <= seed <= max_seed_value:
        # numpy's global RNG only accepts seeds in the uint32 range.
        # Implicit concatenation (not a backslash continuation) keeps stray
        # indentation out of the logged message.
        log.warning(
            f"{seed} is not in bounds, "
            f"numpy accepts from {min_seed_value} to {max_seed_value}"
        )
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # PYTHONHASHSEED is read at interpreter startup, so setting it here only
    # affects Python subprocesses launched afterwards (e.g. worker processes).
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    """Pick a fallback seed and announce it.

    Draws an integer uniformly from the closed interval
    ``[min_seed_value, max_seed_value]`` using python's ``random`` module,
    logs a warning naming the chosen value, and returns it.
    """
    chosen = random.randint(min_seed_value, max_seed_value)
    log.warning(f"No correct seed found, seed set to {chosen}")
    return chosen