Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and jingxu10 committed Jun 5, 2023
1 parent 15855ee commit 9487579
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if _LIGHTNING_XPU_AVAILABLE:
from lightning_xpu.fabric import XPUAccelerator


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
sets the following environment variables:
Expand Down Expand Up @@ -118,9 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) ->
}
if include_cuda:
states["torch.cuda"] = torch.cuda.get_rng_state_all()
if include_xpu:
if XPUAccelerator.is_available():
states["torch.xpu"] = XPUAccelerator._collect_rng_states()
if include_xpu and XPUAccelerator.is_available():
states["torch.xpu"] = XPUAccelerator._collect_rng_states()
return states


Expand All @@ -131,9 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
# torch.cuda rng_state is only included since v1.8.
if "torch.cuda" in rng_state_dict:
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
if "torch.xpu" in rng_state_dict:
if XPUAccelerator.is_available():
XPUAccelerator._set_rng_states(rng_state_dict)
if "torch.xpu" in rng_state_dict and XPUAccelerator.is_available():
XPUAccelerator._set_rng_states(rng_state_dict)
np.random.set_state(rng_state_dict["numpy"])
version, state, gauss = rng_state_dict["python"]
python_set_rng_state((version, tuple(state), gauss))

0 comments on commit 9487579

Please sign in to comment.