From 948757917fab5e31d63d53967d873d1e2309f2cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 07:21:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/utilities/seed.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index 6ed3b2fbef8ff..fce85d733419e 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -19,6 +19,7 @@ if _LIGHTNING_XPU_AVAILABLE: from lightning_xpu.fabric import XPUAccelerator + def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, sets the following environment variables: @@ -118,9 +119,8 @@ def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> } if include_cuda: states["torch.cuda"] = torch.cuda.get_rng_state_all() - if include_xpu: - if XPUAccelerator.is_available(): - states["torch.xpu"] = XPUAccelerator._collect_rng_states() + if include_xpu and XPUAccelerator.is_available(): + states["torch.xpu"] = XPUAccelerator._collect_rng_states() return states @@ -131,9 +131,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: # torch.cuda rng_state is only included since v1.8. if "torch.cuda" in rng_state_dict: torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) - if "torch.xpu" in rng_state_dict: - if XPUAccelerator.is_available(): - XPUAccelerator._set_rng_states(rng_state_dict) + if "torch.xpu" in rng_state_dict and XPUAccelerator.is_available(): + XPUAccelerator._set_rng_states(rng_state_dict) np.random.set_state(rng_state_dict["numpy"]) version, state, gauss = rng_state_dict["python"] python_set_rng_state((version, tuple(state), gauss))