diff --git a/gptqmodel/utils/linalg_warmup.py b/gptqmodel/utils/linalg_warmup.py index cace622f8..10bfacd00 100644 --- a/gptqmodel/utils/linalg_warmup.py +++ b/gptqmodel/utils/linalg_warmup.py @@ -44,7 +44,11 @@ def run_torch_linalg_warmup(device: torch.device) -> None: still runs once per physical device so backend-specific handles are initialized where needed. """ with _GLOBAL_WARMUP_LOCK: - dtypes = (torch.float32, torch.float64) + if device.type == "mps": + dtypes = (torch.float32,) # MPS backend does not implement float64. + else: + dtypes = (torch.float32, torch.float64) + for dtype in dtypes: _run_cholesky_and_eigh(device, dtype) _run_svd(device, dtype)