diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index ef570777e75df..1b524d3d15741 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -589,7 +589,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
             rank_zero_info(
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = "cpu" if self._accelerator_flag == "cpu" else "xpu" if _lightning_xpu_available() else "cuda"
             return MixedPrecisionPlugin(self._precision_flag, device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
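
For reviewers, a minimal runnable sketch of how the new chained conditional evaluates in isolation. Python's conditional expression is right-associative, so the added line parses as "cpu" if cpu else ("xpu" if xpu-available else "cuda"). The `lightning_xpu` package name below is a hypothetical stand-in for whatever backs `_lightning_xpu_available()` in the real patch:

# Sketch of the device-selection logic introduced by this diff, outside the Trainer.
# The stub probes a hypothetical package; the real helper may differ.


def _lightning_xpu_available() -> bool:
    try:
        import lightning_xpu  # noqa: F401  -- hypothetical package name
    except ImportError:
        return False
    return True


def _pick_amp_device(accelerator_flag: str) -> str:
    # Mirrors the added line: right-associative chaining means the XPU check
    # only runs when the accelerator flag is not "cpu".
    return "cpu" if accelerator_flag == "cpu" else "xpu" if _lightning_xpu_available() else "cuda"


assert _pick_amp_device("cpu") == "cpu"
# Note: once the XPU extension is importable, every non-CPU accelerator flag
# maps to "xpu", including an explicit CUDA selection -- worth double-checking.
print(_pick_amp_device("gpu"))  # "xpu" if the extension is importable, else "cuda"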