Skip to content

Commit

Permalink
align worker seed (#9041)
Browse files (browse the repository at this point in the history)
* align worker seed

* refine
  • Loading branch information
Flowingsun007 authored Sep 2, 2022
1 parent 6758485 commit c43beda
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
7 changes: 1 addition & 6 deletions python/oneflow/utils/data/_utils/worker.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -271,12 +271,7 @@ def cleanup_shm_at_exit(num, frame):
flow.set_num_threads(1)
seed = base_seed + worker_id
random.seed(seed)
generator.manual_seed(seed)
if HAS_NUMPY:
np_seed = _generate_state(base_seed, worker_id)
import numpy as np

np.random.seed(np_seed)
generator.manual_seed(seed) # PyTorch use torch.manual_seed(seed)

global _worker_info
_worker_info = WorkerInfo(
Expand Down
5 changes: 4 additions & 1 deletion python/oneflow/utils/data/dataloader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,6 +25,7 @@
import oneflow.multiprocessing as multiprocessing
from oneflow._utils import ExceptionWrapper
import oneflow as flow
import numpy as np

string_classes = (str, bytes)

Expand Down Expand Up @@ -502,8 +503,10 @@ def __init__(self, loader: DataLoader) -> None:
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._generator = loader.generator
self._base_seed = flow.tensor([0], dtype=flow.int64).uniform_().numpy().item()
# self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item()
self._base_seed = flow.randint(
0, np.iinfo(np.int64).max, (), generator=loader.generator
).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
Expand Down

0 comments on commit c43beda

Please sign in to comment.