From 93d16bcbb4487c77c68df27c807d54de1d047552 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 02:03:35 +0000 Subject: [PATCH 1/5] test tf32 Signed-off-by: Qubitium --- tests/test_tf32_performance.py | 106 +++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/test_tf32_performance.py diff --git a/tests/test_tf32_performance.py b/tests/test_tf32_performance.py new file mode 100644 index 000000000..e4b746a63 --- /dev/null +++ b/tests/test_tf32_performance.py @@ -0,0 +1,106 @@ +import pytest +import torch +from tabulate import tabulate + + +def _supports_bfloat16() -> bool: + major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) + return major >= 8 # Ampere or newer + + +def _measure_linear_time(batch: int, in_features: int, out_features: int, dtype: torch.dtype, *, runs: int = 10) -> float: + linear = torch.nn.Linear(in_features, out_features, bias=False).cuda().to(dtype=dtype) + torch.cuda.manual_seed(0) + inp = torch.randn(batch, in_features, device="cuda", dtype=dtype) + + # Warmup + for _ in range(5): + linear(inp) + + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(runs): + linear(inp) + end_event.record() + + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + return elapsed_ms / runs + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): + if dtype is torch.bfloat16 and not _supports_bfloat16(): + pytest.skip("Device does not support bfloat16") + + original_matmul = torch.backends.cuda.matmul.allow_tf32 + original_cudnn = torch.backends.cudnn.allow_tf32 + + try: + shapes = [ + (64, 4096, 4096), + (128, 2048, 8192), + ] + + results = [] + + for batch, in_features, out_features in shapes: + times_tf32 = [] + times_no_tf32 = [] + max_diff = 0.0 + + for _ in range(100): + # TF32 enabled + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + linear = torch.nn.Linear(in_features, out_features, bias=False).cuda().to(dtype=dtype) + inp = torch.randn(batch, in_features, device="cuda", dtype=dtype) + out_tf32 = linear(inp) + times_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) + + # TF32 disabled + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + out_no_tf32 = linear(inp) + times_no_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) + + max_diff = max(max_diff, float(torch.max(torch.abs(out_tf32 - out_no_tf32)).item())) + + avg_tf32 = sum(times_tf32) / len(times_tf32) + avg_no_tf32 = sum(times_no_tf32) / len(times_no_tf32) + + slower = max(avg_tf32, avg_no_tf32) + faster = min(avg_tf32, avg_no_tf32) + + assert slower <= faster * 1.5, ( + f"TF32 toggle caused >50% slowdown for dtype={dtype}, shape={batch}x{in_features}->{out_features}: " + f"tf32={avg_tf32:.3f}ms, no_tf32={avg_no_tf32:.3f}ms" + ) + + results.append( + { + "dtype": str(dtype).split(".")[-1], + "shape": f"{batch}x{in_features}->{out_features}", + "avg_tf32_ms": avg_tf32, + "avg_no_tf32_ms": avg_no_tf32, + "max_abs_diff": max_diff, + } + ) + + table = tabulate( + [ + [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] + for r in results + ], + 
headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], + ) + print("\nTF32 performance summary:\n" + table) + finally: + torch.backends.cuda.matmul.allow_tf32 = original_matmul + torch.backends.cudnn.allow_tf32 = original_cudnn From 588a51ce73b1b95bff16d24045d7a9e5a5a8fac1 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 02:59:32 +0000 Subject: [PATCH 2/5] test tf32 Signed-off-by: Qubitium --- tests/test_tf32_performance.py | 41 +++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/tests/test_tf32_performance.py b/tests/test_tf32_performance.py index e4b746a63..754875b68 100644 --- a/tests/test_tf32_performance.py +++ b/tests/test_tf32_performance.py @@ -1,6 +1,10 @@ import pytest import torch -from tabulate import tabulate + +try: + from tabulate import tabulate +except ImportError: # pragma: no cover + tabulate = None def _supports_bfloat16() -> bool: @@ -44,8 +48,15 @@ def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): try: shapes = [ + # Llama 3 / Mistral style hidden dims (64, 4096, 4096), - (128, 2048, 8192), + (128, 4096, 11008), + # Qwen 2/3 7B/14B style + (64, 3584, 15360), + (64, 8192, 28672), + # DeepSeek V3 large experts + (32, 7168, 28672), + (32, 12288, 49152), ] results = [] @@ -55,7 +66,7 @@ def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): times_no_tf32 = [] max_diff = 0.0 - for _ in range(100): + for _ in range(10): # TF32 enabled torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -93,14 +104,22 @@ def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): } ) - table = tabulate( - [ - [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] - for r in results - ], - headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], - ) - print("\nTF32 performance summary:\n" + table) + if tabulate: + table = tabulate( + [ + [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] + for r in results + ], + headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], + ) + print("\nTF32 performance summary:\n" + table) + else: + print("\nTF32 performance summary:") + for r in results: + print( + f"dtype={r['dtype']} shape={r['shape']} avg_tf32={r['avg_tf32_ms']:.3f}ms " + f"avg_no_tf32={r['avg_no_tf32_ms']:.3f}ms max_abs_diff={r['max_abs_diff']:.3e}" + ) finally: torch.backends.cuda.matmul.allow_tf32 = original_matmul torch.backends.cudnn.allow_tf32 = original_cudnn From 8193bc171840612f5e98eb360ff02588c987871f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 03:10:14 +0000 Subject: [PATCH 3/5] tf32 enable/disable ctx Signed-off-by: Qubitium --- gptqmodel/looper/eora_processor.py | 42 ++++++++++++++------------- gptqmodel/looper/gptq_processor.py | 8 +++-- gptqmodel/looper/qqq_processor.py | 8 +++-- gptqmodel/nn_modules/hooked_linear.py | 23 +++++++++++---- gptqmodel/quantization/gptq.py | 2 -- gptqmodel/quantization/qqq.py | 3 -- gptqmodel/utils/torch.py | 37 +++++++++++++++++++++++ 7 files changed, 87 insertions(+), 36 deletions(-) diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index 7ee752a30..088895013 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -96,13 +96,14 @@ def is_skipped(self, module: NamedModule) -> bool: def pre_process_fwd_hook(self, name: str) -> Callable[[Module, 
Tuple[torch.Tensor, ...], torch.Tensor], None]: def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor): - self.eora_process_input( - input=input, - name=name, - eigen_scaling_diag_matrix=self.eigen_scaling_diag_matrix, - sample_size=self.num_batches, - device=module.target_device, - ) + with tf32_disable_guard(): + self.eora_process_input( + input=input, + name=name, + eigen_scaling_diag_matrix=self.eigen_scaling_diag_matrix, + sample_size=self.num_batches, + device=module.target_device, + ) return tmp def pre_process_streaming(self, module: NamedModule): @@ -131,19 +132,20 @@ def process(self, module: NamedModule): assert w_wq_delta.dtype == torch.float32, f"w_wq_delta dtype: {w_wq_delta.dtype}" # log.info(f"EoRA: module native dtype = `{module_native_dtype}") - A, B = self.eora_compute_lora( - w_wq_delta=w_wq_delta, - name=module.name, - eigen_scaling_diag_matrix=eigen_scaling_diag_matrix, - rank=module.adapter_cfg.rank, - dtype=module.module_dtype, - device=module.target_device, - ) - - del eigen_scaling_diag_matrix - - # wq with A/B applied - computed_wq = wq + (B @ A) + with tf32_disable_guard(): + A, B = self.eora_compute_lora( + w_wq_delta=w_wq_delta, + name=module.name, + eigen_scaling_diag_matrix=eigen_scaling_diag_matrix, + rank=module.adapter_cfg.rank, + dtype=module.module_dtype, + device=module.target_device, + ) + + del eigen_scaling_diag_matrix + + # wq with A/B applied + computed_wq = wq + (B @ A) module.state.update({ "wq": move_to(wq, device=CPU, stream=self.stream), diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index e8084a0f0..409705d93 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -21,7 +21,7 @@ from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync +from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync log = setup_logger() lock = threading.Lock() @@ -103,7 +103,8 @@ def is_skipped(self, module: NamedModule) -> bool: def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): g = self.tasks[name] # noqa: F821 - g.add_batch(inp[0].data, out.data) # noqa: F821 + with tf32_disable_guard(): + g.add_batch(inp[0].data, out.data) # noqa: F821 del inp, out return tmp @@ -125,7 +126,8 @@ def process(self, module: NamedModule): with self.lock: g = self.tasks[module.name] - wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() + with tf32_disable_guard(): + wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() q_scales = q_scales.to(CPU) q_zeros = q_zeros.to(CPU) diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index f2fe748de..d1bc276c3 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -20,7 +20,7 @@ from ..quantization.qqq import QQQ from ..utils.logger import setup_logger from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync +from ..utils.torch import CPU, DEVICE_0, tf32_disable_guard, torch_streamCtx, torch_sync log = setup_logger() @@ -103,7 +103,8 @@ def pre_process_fwd_hook(self, name: 
str) -> Callable[[Module, Tuple[torch.Tenso def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): # gptq is mutable. q = self.tasks[name] # noqa: F821 - q.add_batch(inp[0].data, out.data) # noqa: F821 + with tf32_disable_guard(): + q.add_batch(inp[0].data, out.data) # noqa: F821 return tmp def pre_process_streaming(self, module: NamedModule): @@ -121,7 +122,8 @@ def process(self, module: NamedModule): # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") ## Need to return the quantized_weight for offloading q = qqq[module.name] - wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize() + with tf32_disable_guard(): + wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize() ## Assign the quantized weight to the weight #gptq[name].layer.weight.data = q_full_weight.to(device=gptq[name].device) diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index a7a3cb0ef..395b0c57c 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -10,10 +10,12 @@ from torch import nn from ..utils.logger import setup_logger +from ..utils.torch import tf32_enable_guard log = setup_logger() + class StopForward(Exception): """Signal an intentional early stop of the forward pass.""" pass @@ -37,9 +39,12 @@ def from_conv1d(m: transformers.Conv1D): custom.bias = m.bias return custom + @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - output = super().forward(input) + with tf32_enable_guard(): + output = super().forward(input) + if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -93,9 +98,11 @@ def from_conv1d(m: torch.nn.Conv1d): custom.bias = m.bias return custom + @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - output = super().forward(input) + with tf32_enable_guard(): + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -150,9 +157,11 @@ def from_conv2d(m: torch.nn.Conv2d): custom.bias = m.bias return custom + @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - output = super().forward(input) + with tf32_enable_guard(): + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -175,9 +184,11 @@ def from_conv1d(conv1d: transformers.Conv1D): custom.bias = conv1d.bias return custom + @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - output = super().forward(input) + with tf32_enable_guard(): + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -201,9 +212,11 @@ def from_linear(linear: torch.nn.Linear): custom_linear.bias = linear.bias return custom_linear + @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - output = super().forward(input) + with tf32_enable_guard(): + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py 
index 84fdb76b1..f3ab4951c 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -30,8 +30,6 @@ log = setup_logger() lock = threading.Lock() -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False # TODO: is there a buffer init threading init bug in torch.linalg? # bypass strange threading bug by warming up torch.linalg.cholesky to setup internal setup calls diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py index 6d56fadd7..fe7bcc26e 100644 --- a/gptqmodel/quantization/qqq.py +++ b/gptqmodel/quantization/qqq.py @@ -16,9 +16,6 @@ DEBUG = False -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - log = setup_logger() def quantize(x, scale, zero, maxq, sym, groupsize): diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 580ff10f0..bcda2b2ed 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -5,6 +5,7 @@ import contextlib import gc as py_gc +from contextlib import contextmanager from enum import Enum from typing import Callable, List, Union @@ -236,3 +237,39 @@ def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) -> def torch_streamCtx(stream: Union[torch.cuda.Stream, torch.xpu.Stream]) -> StreamContext: return torch.cuda.stream(stream) if HAS_CUDA else torch.xpu.stream(stream) + + +@contextmanager +def tf32_enable_guard(): + if not HAS_CUDA: + yield + return + + prev_matmul = torch.backends.cuda.matmul.allow_tf32 + prev_cudnn = torch.backends.cudnn.allow_tf32 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_matmul + torch.backends.cudnn.allow_tf32 = prev_cudnn + + +@contextmanager +def tf32_disable_guard(): + if not HAS_CUDA: + yield + return + + prev_matmul = torch.backends.cuda.matmul.allow_tf32 + prev_cudnn = torch.backends.cudnn.allow_tf32 + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_matmul + torch.backends.cudnn.allow_tf32 = prev_cudnn From 920784567838214ae29e8a1820ee7616c961639b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 03:16:49 +0000 Subject: [PATCH 4/5] ues ctx Signed-off-by: Qubitium --- tests/test_tf32_performance.py | 133 +++++++++++++++------------------ 1 file changed, 62 insertions(+), 71 deletions(-) diff --git a/tests/test_tf32_performance.py b/tests/test_tf32_performance.py index 754875b68..48fbcf94a 100644 --- a/tests/test_tf32_performance.py +++ b/tests/test_tf32_performance.py @@ -1,6 +1,8 @@ import pytest import torch +from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard + try: from tabulate import tabulate except ImportError: # pragma: no cover @@ -43,83 +45,72 @@ def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): if dtype is torch.bfloat16 and not _supports_bfloat16(): pytest.skip("Device does not support bfloat16") - original_matmul = torch.backends.cuda.matmul.allow_tf32 - original_cudnn = torch.backends.cudnn.allow_tf32 - - try: - shapes = [ - # Llama 3 / Mistral style hidden dims - (64, 4096, 4096), - (128, 4096, 11008), - # Qwen 2/3 7B/14B style - (64, 3584, 15360), - (64, 8192, 28672), - # DeepSeek V3 large experts - (32, 7168, 28672), - (32, 12288, 49152), - ] - - results = [] - - for batch, in_features, out_features in shapes: - times_tf32 = [] - times_no_tf32 = [] - max_diff = 0.0 - - 
for _ in range(10): - # TF32 enabled - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True + shapes = [ + # Llama 3 / Mistral style hidden dims + (64, 4096, 4096), + (128, 4096, 11008), + # Qwen 2/3 7B/14B style + (64, 3584, 15360), + (64, 8192, 28672), + # DeepSeek V3 large experts + (32, 7168, 28672), + (32, 12288, 49152), + ] + + results = [] + + for batch, in_features, out_features in shapes: + times_tf32 = [] + times_no_tf32 = [] + max_diff = 0.0 + + for _ in range(10): + with tf32_enable_guard(): linear = torch.nn.Linear(in_features, out_features, bias=False).cuda().to(dtype=dtype) inp = torch.randn(batch, in_features, device="cuda", dtype=dtype) out_tf32 = linear(inp) times_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) - # TF32 disabled - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False + with tf32_disable_guard(): out_no_tf32 = linear(inp) times_no_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) - max_diff = max(max_diff, float(torch.max(torch.abs(out_tf32 - out_no_tf32)).item())) - - avg_tf32 = sum(times_tf32) / len(times_tf32) - avg_no_tf32 = sum(times_no_tf32) / len(times_no_tf32) - - slower = max(avg_tf32, avg_no_tf32) - faster = min(avg_tf32, avg_no_tf32) - - assert slower <= faster * 1.5, ( - f"TF32 toggle caused >50% slowdown for dtype={dtype}, shape={batch}x{in_features}->{out_features}: " - f"tf32={avg_tf32:.3f}ms, no_tf32={avg_no_tf32:.3f}ms" - ) - - results.append( - { - "dtype": str(dtype).split(".")[-1], - "shape": f"{batch}x{in_features}->{out_features}", - "avg_tf32_ms": avg_tf32, - "avg_no_tf32_ms": avg_no_tf32, - "max_abs_diff": max_diff, - } - ) - - if tabulate: - table = tabulate( - [ - [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] - for r in results - ], - headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], + max_diff = max(max_diff, float(torch.max(torch.abs(out_tf32 - out_no_tf32)).item())) + + avg_tf32 = sum(times_tf32) / len(times_tf32) + avg_no_tf32 = sum(times_no_tf32) / len(times_no_tf32) + + slower = max(avg_tf32, avg_no_tf32) + faster = min(avg_tf32, avg_no_tf32) + + assert slower <= faster * 1.5, ( + f"TF32 toggle caused >50% slowdown for dtype={dtype}, shape={batch}x{in_features}->{out_features}: " + f"tf32={avg_tf32:.3f}ms, no_tf32={avg_no_tf32:.3f}ms" + ) + + results.append( + { + "dtype": str(dtype).split(".")[-1], + "shape": f"{batch}x{in_features}->{out_features}", + "avg_tf32_ms": avg_tf32, + "avg_no_tf32_ms": avg_no_tf32, + "max_abs_diff": max_diff, + } + ) + + if tabulate: + table = tabulate( + [ + [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] + for r in results + ], + headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], + ) + print("\nTF32 performance summary:\n" + table) + else: + print("\nTF32 performance summary:") + for r in results: + print( + f"dtype={r['dtype']} shape={r['shape']} avg_tf32={r['avg_tf32_ms']:.3f}ms " + f"avg_no_tf32={r['avg_no_tf32_ms']:.3f}ms max_abs_diff={r['max_abs_diff']:.3e}" ) - print("\nTF32 performance summary:\n" + table) - else: - print("\nTF32 performance summary:") - for r in results: - print( - f"dtype={r['dtype']} shape={r['shape']} avg_tf32={r['avg_tf32_ms']:.3f}ms " - f"avg_no_tf32={r['avg_no_tf32_ms']:.3f}ms max_abs_diff={r['max_abs_diff']:.3e}" - ) - finally: - torch.backends.cuda.matmul.allow_tf32 = 
original_matmul - torch.backends.cudnn.allow_tf32 = original_cudnn From d0d34a42537c95e29d765c5cdba1f20a9b3cf416 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 1 Oct 2025 03:55:12 +0000 Subject: [PATCH 5/5] awq use tf32 ctx and fix compat with main Signed-off-by: Qubitium --- gptqmodel/looper/awq_processor.py | 106 ++++++++++++++++-------------- 1 file changed, 58 insertions(+), 48 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 3fe3465b4..5564fdeff 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -29,17 +29,17 @@ from ..quantization.config import FORMAT, METHOD, QuantizeConfig from ..utils.logger import setup_logger from ..utils.model import get_module_by_name_prefix, move_to -from ..utils.torch import CPU, torch_sync +from ..utils.torch import CPU, tf32_disable_guard, tf32_enable_guard, torch_sync log = setup_logger() class AWQProcessor(LoopProcessor): def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], batch_size: int, gptq_model, model, + calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, gptq_model, model, logger_board: str = "", require_fwd: bool = True, calculate_w_wq_diff: bool = False): super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, + calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, logger_board=logger_board, require_fwd=require_fwd) @@ -147,8 +147,9 @@ def forward(self, *args, **kwargs): target_device = model_device print(f"AWQProcessor: model parameters are on meta device, using {target_device} instead") - - self.model(samples.to(torch.device(target_device)), use_cache=False) + + with tf32_enable_guard(): + self.model(samples.to(torch.device(target_device)), use_cache=False) except ValueError: # work with early exit pass modules[0] = modules[0].module # restore @@ -295,15 +296,18 @@ def _search_best_scale( clear_memory(x_sum) # [STEP 3]: Compute output of module + module_kwargs = self._sanitize_kwargs(kwargs, module2inspect) with torch.inference_mode(): - module_kwargs = self._sanitize_kwargs(kwargs, module2inspect) - fp16_output = self._module_forward(inp, module2inspect, module_kwargs) - fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) + with tf32_enable_guard(): + fp16_output = self._module_forward(inp, module2inspect, module_kwargs) - # [STEP 4]: Compute loss - best_scales, loss = self._compute_best_scale( - inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs - ) + with tf32_disable_guard(): + fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) + + # [STEP 4]: Compute loss + best_scales, loss = self._compute_best_scale( + inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs + ) return ( get_op_name(module, prev_op), @@ -326,10 +330,11 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic # self.gptq_model.move_embed(common_device) # Transformers >= 4.48.0 requires positional embeddings should be computed before forward pass - if (self.module_kwargs.get("position_embeddings") is None): - self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb( - self.inps, self.module_kwargs["position_ids"] - ) + if 
self.module_kwargs.get("position_embeddings") is None: + with tf32_enable_guard(): + self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb( + self.inps, self.module_kwargs["position_ids"] + ) # TODO FIX ME: ??? if (self.module_kwargs.get('attention_mask') is None): @@ -358,31 +363,34 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic clear_memory() # [STEP 2]: Compute and apply scale list - module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling( - module, input_feat, self.module_kwargs - ) - scales_list = [ - self._search_best_scale(module, **layer) - for layer in module_config - ] - apply_scale(module, scales_list, input_feat_dict=input_feat) + with tf32_disable_guard(): + module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling( + module, input_feat, self.module_kwargs + ) + scales_list = [ + self._search_best_scale(module, **layer) + for layer in module_config + ] + apply_scale(module, scales_list, input_feat_dict=input_feat) scales_list = append_str_prefix( scales_list, get_op_name(self.model, module) + "." ) # [STEP 3]: Compute and apply clipping list if self.apply_clip: - clip_list = self._search_best_clip( - module, named_linears, input_feat - ) - apply_clip(module, clip_list) + with tf32_disable_guard(): + clip_list = self._search_best_clip( + module, named_linears, input_feat + ) + apply_clip(module, clip_list) clip_list = append_str_prefix( clip_list, get_op_name(self.model, module) + "." ) # [STEP 4]: Quantize weights if not self.export_compatible: - self._apply_quant(module, named_childs, start, scales_list) + with tf32_disable_guard(): + self._apply_quant(module, named_childs, start, scales_list) clear_memory() @@ -397,9 +405,10 @@ def _search_best_clip(self, layer, named_linears, input_feat): continue named_linears[name].to(get_best_device()) - max_val = self._compute_best_clip( - named_linears[name].weight, input_feat[name] - ) + with tf32_disable_guard(): + max_val = self._compute_best_clip( + named_linears[name].weight, input_feat[name] + ) clip_list.append((name, max_val)) named_linears[name].cpu() @@ -615,25 +624,26 @@ def _compute_loss( def _module_forward( self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict ) -> torch.Tensor: - if self.n_parallel_calib_samples is None: - # runs through all samples at once - module_output = module(x, **module_kwargs) - if isinstance(module_output, tuple): - module_output = module_output[0] - else: - # memory efficiently runs through all calibration samples - # but only n_parallel_calib_samples at a time - module_output = [] - partitioned_inputs = torch.split(x, self.n_parallel_calib_samples) - for x_partial in partitioned_inputs: - partial_output = module(x_partial, **module_kwargs) + with tf32_enable_guard(): + if self.n_parallel_calib_samples is None: + # runs through all samples at once + module_output = module(x, **module_kwargs) + if isinstance(module_output, tuple): + module_output = module_output[0] + else: + # memory efficiently runs through all calibration samples + # but only n_parallel_calib_samples at a time + module_output = [] + partitioned_inputs = torch.split(x, self.n_parallel_calib_samples) + for x_partial in partitioned_inputs: + partial_output = module(x_partial, **module_kwargs) - if isinstance(partial_output, tuple): - partial_output = partial_output[0] + if isinstance(partial_output, tuple): + partial_output = partial_output[0] - module_output.append(partial_output.cpu()) + module_output.append(partial_output.cpu()) 
- module_output = torch.cat(module_output, dim=0) + module_output = torch.cat(module_output, dim=0) return module_output
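For reference, a minimal sketch of how the tf32_enable_guard / tf32_disable_guard context managers added in gptqmodel/utils/torch.py are composed by this series: TF32 is enabled around calibration forward passes (hooked_linear.py, awq _module_forward) and disabled around the numerically sensitive add_batch / quantize / EoRA math, with the previous backend flags restored on exit. The toy layer and shapes below are illustrative assumptions, not part of the patch.

# Illustrative sketch only (assumed toy layer/shapes, not part of the series):
# TF32 on for the calibration forward pass, TF32 off for sensitive fp32 math.
import torch

from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard

linear = torch.nn.Linear(4096, 4096, bias=False).cuda()
x = torch.randn(8, 4096, device="cuda")

with tf32_enable_guard():       # fast TF32 matmul; prior flags restored on exit
    y = linear(x)

with tf32_disable_guard():      # strict fp32 matmul, e.g. Hessian-style accumulation
    h = x.t().float() @ x.float()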