gptqmodel/utils/torch.py (103 changes: 81 additions, 22 deletions)
@@ -69,6 +69,51 @@ class BalanceStrategy(str, Enum):
 except BaseException:
     pass
 
+BACKENDS_HAS_FP32_PRECISION = hasattr(torch.backends, "fp32_precision")
+
+
+def _set_tf32_state(enabled: bool) -> None:
+    if BACKENDS_HAS_FP32_PRECISION:
+        mode = "tf32" if enabled else "ieee"
+        torch.backends.fp32_precision = mode
+        torch.backends.cuda.matmul.fp32_precision = mode
+        torch.backends.cudnn.fp32_precision = mode
+        torch.backends.cudnn.conv.fp32_precision = mode
+        torch.backends.cudnn.rnn.fp32_precision = mode
+        return
+
+    torch.backends.cuda.matmul.allow_tf32 = enabled
+    torch.backends.cudnn.allow_tf32 = enabled
+
+
+def _snapshot_tf32_state():
+    if BACKENDS_HAS_FP32_PRECISION:
+        return (
+            torch.backends.fp32_precision,
+            torch.backends.cuda.matmul.fp32_precision,
+            torch.backends.cudnn.fp32_precision,
+            torch.backends.cudnn.conv.fp32_precision,
+            torch.backends.cudnn.rnn.fp32_precision,
+        )
+
+    return (
+        torch.backends.cuda.matmul.allow_tf32,
+        torch.backends.cudnn.allow_tf32,
+    )
+
+
+def _restore_tf32_state(state) -> None:
+    if BACKENDS_HAS_FP32_PRECISION:
+        torch.backends.fp32_precision = state[0]
+        torch.backends.cuda.matmul.fp32_precision = state[1]
+        torch.backends.cudnn.fp32_precision = state[2]
+        torch.backends.cudnn.conv.fp32_precision = state[3]
+        torch.backends.cudnn.rnn.fp32_precision = state[4]
+        return
+
+    torch.backends.cuda.matmul.allow_tf32 = state[0]
+    torch.backends.cudnn.allow_tf32 = state[1]
+
 def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="inductor", mode: str = None, fullgraph=False):
     # requires torch >2.8 for proper torch.compile + Python 3.13.3t (freethreading)
     if has_gil_disabled() and not gte_python_3_13_3():
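
Reviewer note: the three helpers above implement a snapshot/mutate/restore pattern over PyTorch's TF32 knobs. Builds that expose torch.backends.fp32_precision take the string-mode path ("tf32" vs. "ieee"); older builds fall back to the boolean allow_tf32 flags, and the same opaque snapshot tuple round-trips through either branch. A minimal sketch of how they compose (hypothetical call site; the underscore-prefixed helpers are module-private and imported here only for illustration):

import torch

from gptqmodel.utils.torch import (
    _restore_tf32_state,
    _set_tf32_state,
    _snapshot_tf32_state,
)

state = _snapshot_tf32_state()   # capture every TF32 knob, whichever API exists
_set_tf32_state(True)            # force TF32 ("tf32" mode or allow_tf32=True)
try:
    pass                         # TF32-tolerant work would run here
finally:
    _restore_tf32_state(state)   # put each knob back exactly as captured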
@@ -248,24 +293,31 @@ def tf32_enable_guard():
         yield
         return
 
-    if torch.backends.fp32_precision == "tf32":
+    if BACKENDS_HAS_FP32_PRECISION:
+        if torch.backends.fp32_precision == "tf32":
+            yield
+            return
+
+        previous_state = _snapshot_tf32_state()
+        _set_tf32_state(True)
+
+        try:
+            yield
+        finally:
+            _restore_tf32_state(previous_state)
+        return
+
+    previous_state = _snapshot_tf32_state()
+    if previous_state[0] and previous_state[1]:
         yield
         return
 
-    torch.backends.fp32_precision = "tf32"
-    torch.backends.cuda.matmul.fp32_precision = "tf32"
-    torch.backends.cudnn.fp32_precision = "tf32"
-    torch.backends.cudnn.conv.fp32_precision = "tf32"
-    torch.backends.cudnn.rnn.fp32_precision = "tf32"
+    _set_tf32_state(True)
 
     try:
         yield
     finally:
-        torch.backends.fp32_precision = "ieee"
-        torch.backends.cuda.matmul.fp32_precision = "ieee"
-        torch.backends.cudnn.fp32_precision = "ieee"
-        torch.backends.cudnn.conv.fp32_precision = "ieee"
-        torch.backends.cudnn.rnn.fp32_precision = "ieee"
+        _restore_tf32_state(previous_state)
 
 
 @contextmanager
@@ -274,21 +326,28 @@ def tf32_disable_guard():
         yield
         return
 
-    if torch.backends.fp32_precision == "ieee":
+    if BACKENDS_HAS_FP32_PRECISION:
+        if torch.backends.fp32_precision == "ieee":
+            yield
+            return
+
+        previous_state = _snapshot_tf32_state()
+        _set_tf32_state(False)
+
+        try:
+            yield
+        finally:
+            _restore_tf32_state(previous_state)
+        return
+
+    previous_state = _snapshot_tf32_state()
+    if not previous_state[0] and not previous_state[1]:
         yield
         return
 
-    torch.backends.fp32_precision = "ieee"
-    torch.backends.cuda.matmul.fp32_precision = "ieee"
-    torch.backends.cudnn.fp32_precision = "ieee"
-    torch.backends.cudnn.conv.fp32_precision = "ieee"
-    torch.backends.cudnn.rnn.fp32_precision = "ieee"
+    _set_tf32_state(False)
 
     try:
         yield
     finally:
-        torch.backends.fp32_precision = "tf32"
-        torch.backends.cuda.matmul.fp32_precision = "tf32"
-        torch.backends.cudnn.fp32_precision = "tf32"
-        torch.backends.cudnn.conv.fp32_precision = "tf32"
-        torch.backends.cudnn.rnn.fp32_precision = "tf32"
+        _restore_tf32_state(previous_state)