From 893b65486b17211b93d0e446ea08a4a1389dc9f1 Mon Sep 17 00:00:00 2001
From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
Date: Tue, 31 Oct 2023 13:16:45 +0530
Subject: [PATCH] Fix precision device

---
 .../pytorch/trainer/connectors/accelerator_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index ef570777e75df..1b524d3d15741 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -589,7 +589,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
             rank_zero_info(
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = "cpu" if self._accelerator_flag == "cpu" else "xpu" if _lightning_xpu_available() else "cuda"
             return MixedPrecisionPlugin(self._precision_flag, device)  # type: ignore[arg-type]
         raise RuntimeError("No precision set")
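
Note: the new one-liner chains two conditional expressions. Python groups a chained
conditional right-to-left, i.e. "cpu" if cpu else ("xpu" if xpu-available else "cuda"),
so with this patch XPU is preferred over CUDA whenever _lightning_xpu_available()
returns True, independent of the accelerator flag. Below is a minimal standalone
sketch of the resulting selection logic; _select_amp_device is a hypothetical helper
for illustration, and the xpu_available parameter stands in for the optional
lightning-xpu availability check, which is not reproduced here.

def _select_amp_device(accelerator_flag: str, xpu_available: bool) -> str:
    # Mirrors how the patched one-liner parses:
    # "cpu" if accelerator_flag == "cpu" else ("xpu" if xpu_available else "cuda")
    if accelerator_flag == "cpu":
        return "cpu"
    return "xpu" if xpu_available else "cuda"

# Behaviour implied by the patched expression:
assert _select_amp_device("cpu", xpu_available=True) == "cpu"
assert _select_amp_device("gpu", xpu_available=True) == "xpu"   # XPU wins over CUDA when available
assert _select_amp_device("gpu", xpu_available=False) == "cuda"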