diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py
index 0f17a57e1..4534b47c0 100644
--- a/gptqmodel/utils/torch.py
+++ b/gptqmodel/utils/torch.py
@@ -69,6 +69,51 @@ class BalanceStrategy(str, Enum):
     except BaseException:
         pass

+BACKENDS_HAS_FP32_PRECISION = hasattr(torch.backends, "fp32_precision")
+
+
+def _set_tf32_state(enabled: bool) -> None:
+    if BACKENDS_HAS_FP32_PRECISION:
+        mode = "tf32" if enabled else "ieee"
+        torch.backends.fp32_precision = mode
+        torch.backends.cuda.matmul.fp32_precision = mode
+        torch.backends.cudnn.fp32_precision = mode
+        torch.backends.cudnn.conv.fp32_precision = mode
+        torch.backends.cudnn.rnn.fp32_precision = mode
+        return
+
+    torch.backends.cuda.matmul.allow_tf32 = enabled
+    torch.backends.cudnn.allow_tf32 = enabled
+
+
+def _snapshot_tf32_state():
+    if BACKENDS_HAS_FP32_PRECISION:
+        return (
+            torch.backends.fp32_precision,
+            torch.backends.cuda.matmul.fp32_precision,
+            torch.backends.cudnn.fp32_precision,
+            torch.backends.cudnn.conv.fp32_precision,
+            torch.backends.cudnn.rnn.fp32_precision,
+        )
+
+    return (
+        torch.backends.cuda.matmul.allow_tf32,
+        torch.backends.cudnn.allow_tf32,
+    )
+
+
+def _restore_tf32_state(state) -> None:
+    if BACKENDS_HAS_FP32_PRECISION:
+        torch.backends.fp32_precision = state[0]
+        torch.backends.cuda.matmul.fp32_precision = state[1]
+        torch.backends.cudnn.fp32_precision = state[2]
+        torch.backends.cudnn.conv.fp32_precision = state[3]
+        torch.backends.cudnn.rnn.fp32_precision = state[4]
+        return
+
+    torch.backends.cuda.matmul.allow_tf32 = state[0]
+    torch.backends.cudnn.allow_tf32 = state[1]
+
+
 def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="inductor", mode: str = None, fullgraph=False):
     # requires torch >2.8 for proper torch.compile + Python 3.13.3t (freethreading)
     if has_gil_disabled() and not gte_python_3_13_3():
@@ -248,24 +293,31 @@ def tf32_enable_guard():
         yield
         return

-    if torch.backends.fp32_precision == "tf32":
+    if BACKENDS_HAS_FP32_PRECISION:
+        if torch.backends.fp32_precision == "tf32":
+            yield
+            return
+
+        previous_state = _snapshot_tf32_state()
+        _set_tf32_state(True)
+
+        try:
+            yield
+        finally:
+            _restore_tf32_state(previous_state)
+        return
+
+    previous_state = _snapshot_tf32_state()
+    if previous_state[0] and previous_state[1]:
         yield
         return

-    torch.backends.fp32_precision = "tf32"
-    torch.backends.cuda.matmul.fp32_precision = "tf32"
-    torch.backends.cudnn.fp32_precision = "tf32"
-    torch.backends.cudnn.conv.fp32_precision = "tf32"
-    torch.backends.cudnn.rnn.fp32_precision = "tf32"
+    _set_tf32_state(True)

     try:
         yield
     finally:
-        torch.backends.fp32_precision = "ieee"
-        torch.backends.cuda.matmul.fp32_precision = "ieee"
-        torch.backends.cudnn.fp32_precision = "ieee"
-        torch.backends.cudnn.conv.fp32_precision = "ieee"
-        torch.backends.cudnn.rnn.fp32_precision = "ieee"
+        _restore_tf32_state(previous_state)


 @contextmanager
@@ -274,21 +326,28 @@ def tf32_disable_guard():
         yield
         return

-    if torch.backends.fp32_precision == "ieee":
+    if BACKENDS_HAS_FP32_PRECISION:
+        if torch.backends.fp32_precision == "ieee":
+            yield
+            return
+
+        previous_state = _snapshot_tf32_state()
+        _set_tf32_state(False)
+
+        try:
+            yield
+        finally:
+            _restore_tf32_state(previous_state)
+        return
+
+    previous_state = _snapshot_tf32_state()
+    if not previous_state[0] and not previous_state[1]:
         yield
         return

-    torch.backends.fp32_precision = "ieee"
-    torch.backends.cuda.matmul.fp32_precision = "ieee"
-    torch.backends.cudnn.fp32_precision = "ieee"
-    torch.backends.cudnn.conv.fp32_precision = "ieee"
-    torch.backends.cudnn.rnn.fp32_precision = "ieee"
+    _set_tf32_state(False)

     try:
         yield
     finally:
-        torch.backends.fp32_precision = "tf32"
-        torch.backends.cuda.matmul.fp32_precision = "tf32"
-        torch.backends.cudnn.fp32_precision = "tf32"
-        torch.backends.cudnn.conv.fp32_precision = "tf32"
-        torch.backends.cudnn.rnn.fp32_precision = "tf32"
+        _restore_tf32_state(previous_state)
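
For context, a minimal usage sketch of the two guards after this patch. The import path is the module the diff touches; the CUDA tensors and matmuls below are illustrative assumptions, not part of the change:

    # Illustrative only: assumes a CUDA build of PyTorch and that the guards
    # are importable from the module edited above.
    import torch

    from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard

    a = torch.randn(2048, 2048, device="cuda")
    b = torch.randn(2048, 2048, device="cuda")

    with tf32_enable_guard():
        approx = a @ b  # fp32 matmul may use TF32 tensor cores inside the guard

    with tf32_disable_guard():
        exact = a @ b   # IEEE fp32 math enforced inside the guard

The behavioral fix: the old guards restored a hardcoded opposite mode on exit ("ieee" after enable, "tf32" after disable), clobbering whatever the caller had configured. With `_snapshot_tf32_state()` / `_restore_tf32_state()`, each guard now restores the exact prior state, and the `BACKENDS_HAS_FP32_PRECISION` branch handles both the newer `torch.backends.fp32_precision` API and the legacy `allow_tf32` flags on older PyTorch builds.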