From be8ddddd3410d910376be7e6df06e18292c5df04 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 05:43:23 +0000 Subject: [PATCH 1/7] add d2h to d accuracy test to make sure there is nothing `magical` going on between d2h and h2d ping pong --- tests/test_g2h_tensor_quality.py | 81 ++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tests/test_g2h_tensor_quality.py diff --git a/tests/test_g2h_tensor_quality.py b/tests/test_g2h_tensor_quality.py new file mode 100644 index 000000000..de18e39a1 --- /dev/null +++ b/tests/test_g2h_tensor_quality.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# test_g2h_tensor_quality.py +# Check whether major activation dtypes round-trip from GPU -> CPU -> GPU without +# accumulating numerical error. We use GPU 6 as requested, exercise both pinned +# and pageable host buffers, and report per-shape metrics so the test doubles as +# a quick experiment. + +import pytest +import torch + + +def _roundtrip_metrics(src: torch.Tensor, device: torch.device, pin_memory: bool): + if pin_memory: + host = torch.empty_like(src, device="cpu", pin_memory=True) + host.copy_(src, non_blocking=True) + # Ensure the async copy completes before reusing the tensor. + torch.cuda.synchronize(device) + roundtrip = host.to(device, copy=True, non_blocking=True) + else: + host = src.to("cpu", copy=True) + roundtrip = host.to(device, copy=True) + + diff = (src.float() - roundtrip.float()).abs() + return { + "max_abs_diff": diff.max().item(), + "mean_abs_diff": diff.mean().item(), + "nonzero_elements": int((diff != 0).sum().item()), + } + + +@pytest.mark.cuda +@pytest.mark.parametrize("pin_memory", (False, True)) +@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16, torch.float32)) +def test_gpu_cpu_gpu_roundtrip_lossless(dtype, pin_memory): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for this test.") + + gpu_index = 6 + if torch.cuda.device_count() <= gpu_index: + pytest.skip(f"Need at least {gpu_index + 1} CUDA devices to exercise GPU {gpu_index}.") + + device = torch.device(f"cuda:{gpu_index}") + + # A mix of common activation shapes, including one large 2D tensor to stress memory copies. + shapes = [ + (1, 12288), # Transformer MLP activation + (16, 1024, 64), # Attention block with batch & heads + (3, 224, 224), # Vision-style activation + (4096, 4096), # Large square matrix to amplify any copy issues + ] + + # Ensure deterministic data so that repeated runs are comparable. + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + + results = [] + + with torch.cuda.device(device): + for shape in shapes: + src = torch.randn(shape, device=device, dtype=dtype) + metrics = _roundtrip_metrics(src, device, pin_memory=pin_memory) + results.append({"shape": shape, "dtype": dtype, "pin_memory": pin_memory, **metrics}) + + # Expect no change; any non-zero indicates data corruption during the round-trip. 
+ assert metrics["max_abs_diff"] == 0.0, ( + f"Max diff {metrics['max_abs_diff']} detected for shape {shape} dtype {dtype} pin_memory={pin_memory}" + ) + assert metrics["nonzero_elements"] == 0, ( + f"Found {metrics['nonzero_elements']} differing elements for shape {shape} dtype {dtype} pin_memory={pin_memory}" + ) + + for r in results: + print( + f"GPU6 round-trip dtype={r['dtype']} shape={r['shape']} pin_memory={r['pin_memory']}: " + f"max_abs_diff={r['max_abs_diff']} mean_abs_diff={r['mean_abs_diff']} " + f"nonzero_elements={r['nonzero_elements']}" + ) From aac8fd7ca8acb0b75dc84f77b5fa9d467e9dd7c5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 08:02:38 +0000 Subject: [PATCH 2/7] disable usage of tf32 toggle --- gptqmodel/looper/awq_processor.py | 93 ++++++++--------- gptqmodel/looper/eora_processor.py | 38 ++++--- gptqmodel/looper/gptq_processor.py | 6 +- gptqmodel/looper/module_looper.py | 8 +- gptqmodel/looper/qqq_processor.py | 6 +- gptqmodel/models/definitions/qwen3_moe.py | 2 +- gptqmodel/nn_modules/hooked_linear.py | 16 +-- gptqmodel/utils/torch.py | 51 +++------ tests/test_tf32_performance.py | 122 ---------------------- 9 files changed, 91 insertions(+), 251 deletions(-) delete mode 100644 tests/test_tf32_performance.py diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 9719d66f7..6ad426586 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -136,8 +136,7 @@ def forward(self, *args, **kwargs): print(f"AWQProcessor: model parameters are on meta device, using {target_device} instead") - with tf32_enable_guard(): - self.model(samples.to(torch.device(target_device)), use_cache=False) + self.model(samples.to(torch.device(target_device)), use_cache=False) except ValueError: # work with early exit pass modules[0] = modules[0].module # restore @@ -283,16 +282,15 @@ def _search_best_scale( # [STEP 3]: Compute output of module module_kwargs = self._sanitize_kwargs(kwargs, module2inspect) - with ctx(torch.inference_mode(), tf32_enable_guard()): + with ctx(torch.inference_mode()): fp16_output = self._module_forward(inp, module2inspect, module_kwargs) - with tf32_disable_guard(): - fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) + fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max) - # [STEP 4]: Compute loss - best_scales, loss = self._compute_best_scale( - inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs - ) + # [STEP 4]: Compute loss + best_scales, loss = self._compute_best_scale( + inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs + ) return ( get_op_name(module, prev_op), @@ -316,10 +314,9 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic # Transformers >= 4.48.0 requires positional embeddings should be computed before forward pass if self.module_kwargs.get("position_embeddings") is None: - with tf32_enable_guard(): - self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb( - self.inps, self.module_kwargs["position_ids"] - ) + self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb( + self.inps, self.module_kwargs["position_ids"] + ) # TODO FIX ME: ??? 
if (self.module_kwargs.get('attention_mask') is None): @@ -346,34 +343,31 @@ def layer_quantize(self, module: Module, device: torch.device, named_childs: Dic input_feat = self._get_input_feat(module, named_linears) # [STEP 2]: Compute and apply scale list - with tf32_disable_guard(): - module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling( - module, input_feat, self.module_kwargs - ) - scales_list = [ - self._search_best_scale(module, **layer) - for layer in module_config - ] - apply_scale(module, scales_list, input_feat_dict=input_feat) + module_config: List[Dict] = self.gptq_model.awq_get_modules_for_scaling( + module, input_feat, self.module_kwargs + ) + scales_list = [ + self._search_best_scale(module, **layer) + for layer in module_config + ] + apply_scale(module, scales_list, input_feat_dict=input_feat) scales_list = append_str_prefix( scales_list, get_op_name(self.model, module) + "." ) # [STEP 3]: Compute and apply clipping list if self.apply_clip: - with tf32_disable_guard(): - clip_list = self._search_best_clip( - module, named_linears, input_feat - ) - apply_clip(module, clip_list) + clip_list = self._search_best_clip( + module, named_linears, input_feat + ) + apply_clip(module, clip_list) clip_list = append_str_prefix( clip_list, get_op_name(self.model, module) + "." ) # [STEP 4]: Quantize weights if not self.export_compatible: - with tf32_disable_guard(): - self._apply_quant(module, named_childs, start, scales_list) + self._apply_quant(module, named_childs, start, scales_list) @torch.inference_mode() def _search_best_clip(self, layer, named_linears, input_feat): @@ -386,10 +380,10 @@ def _search_best_clip(self, layer, named_linears, input_feat): continue named_linears[name].to(get_best_device()) - with tf32_disable_guard(): - max_val = self._compute_best_clip( - named_linears[name].weight, input_feat[name] - ) + + max_val = self._compute_best_clip( + named_linears[name].weight, input_feat[name] + ) clip_list.append((name, max_val)) named_linears[name].cpu() @@ -604,26 +598,25 @@ def _compute_loss( def _module_forward( self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict ) -> torch.Tensor: - with tf32_enable_guard(): - if self.n_parallel_calib_samples is None: - # runs through all samples at once - module_output = module(x, **module_kwargs) - if isinstance(module_output, tuple): - module_output = module_output[0] - else: - # memory efficiently runs through all calibration samples - # but only n_parallel_calib_samples at a time - module_output = [] - partitioned_inputs = torch.split(x, self.n_parallel_calib_samples) - for x_partial in partitioned_inputs: - partial_output = module(x_partial, **module_kwargs) + if self.n_parallel_calib_samples is None: + # runs through all samples at once + module_output = module(x, **module_kwargs) + if isinstance(module_output, tuple): + module_output = module_output[0] + else: + # memory efficiently runs through all calibration samples + # but only n_parallel_calib_samples at a time + module_output = [] + partitioned_inputs = torch.split(x, self.n_parallel_calib_samples) + for x_partial in partitioned_inputs: + partial_output = module(x_partial, **module_kwargs) - if isinstance(partial_output, tuple): - partial_output = partial_output[0] + if isinstance(partial_output, tuple): + partial_output = partial_output[0] - module_output.append(partial_output.cpu()) + module_output.append(partial_output.cpu()) - module_output = torch.cat(module_output, dim=0) + module_output = torch.cat(module_output, dim=0) return 
module_output diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index 2331b1f6a..893b805ef 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -93,14 +93,13 @@ def is_skipped(self, module: NamedModule) -> bool: def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor): - with tf32_disable_guard(): - batch_index = self.current_batch_index() - batch, contribution, scale = self.eora_process_input( - input=input, - name=name, - sample_size=self.num_batches, - device=module.weight.data.device, - ) + batch_index = self.current_batch_index() + batch, contribution, scale = self.eora_process_input( + input=input, + name=name, + sample_size=self.num_batches, + device=module.weight.data.device, + ) self._accumulate_eora_contribution( name=name, @@ -232,20 +231,19 @@ def process(self, module: NamedModule): assert w_wq_delta.dtype == torch.float32, f"w_wq_delta dtype: {w_wq_delta.dtype}" # log.info(f"EoRA: module native dtype = `{module_native_dtype}") - with tf32_disable_guard(): - A, B = self.eora_compute_lora( - w_wq_delta=w_wq_delta, - name=module.name, - eigen_scaling_diag_matrix=eigen_scaling_diag_matrix, - rank=module.adapter_cfg.rank, - dtype=module.module_dtype, - device=module.weight.data.device, - ) + A, B = self.eora_compute_lora( + w_wq_delta=w_wq_delta, + name=module.name, + eigen_scaling_diag_matrix=eigen_scaling_diag_matrix, + rank=module.adapter_cfg.rank, + dtype=module.module_dtype, + device=module.weight.data.device, + ) - del eigen_scaling_diag_matrix + del eigen_scaling_diag_matrix - # wq with A/B applied - computed_wq = (wq_device + (B @ A)).to(dtype=wq.dtype, device=target_device) + # wq with A/B applied + computed_wq = (wq_device + (B @ A)).to(dtype=wq.dtype, device=target_device) if pad_cols: computed_wq_trim = computed_wq[:, :original_cols] diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index dc0bd922e..9e343a0f9 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -113,8 +113,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): g = self.tasks[name] # noqa: F821 batch_idx = self.current_batch_index() - with tf32_disable_guard(): - g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821 + g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821 del inp, out return tmp @@ -173,8 +172,7 @@ def process(self, module: NamedModule): f"while processing '{module.full_name}'." ) - with tf32_disable_guard(): - wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() + wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() module.stream_state_payload_to_cpu( { diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 9eac814b5..24467dda9 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -48,7 +48,7 @@ ) from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, nested_move_to from ..utils.offload import offload_to_disk -from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync) +from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync, tf32_high_precision_guard) from .. 
import DEVICE_THREAD_POOL from .awq_processor import AWQProcessor from .qqq_processor import QQQProcessor @@ -899,8 +899,12 @@ def store_input_hook(module, args, kwargs): return result - @torch.inference_mode() def loop(self, fail_safe: bool = False, **kwargs): + with tf32_high_precision_guard(): + return self._loop_impl(fail_safe=fail_safe, **kwargs) + + @torch.inference_mode() + def _loop_impl(self, fail_safe: bool = False, **kwargs): if self.gptq_model.quantize_config.lm_head: if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"): tied_keys = self.gptq_model.model._tied_weights_keys diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index 9c2bd44cb..365127126 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -97,8 +97,7 @@ def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tenso def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): # gptq is mutable. q = self.tasks[name] # noqa: F821 - with tf32_disable_guard(): - q.add_batch(inp[0].data, out.data) # noqa: F821 + q.add_batch(inp[0].data, out.data) # noqa: F821 return tmp def process(self, module: NamedModule): @@ -108,8 +107,7 @@ def process(self, module: NamedModule): # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") ## Need to return the quantized_weight for offloading q = qqq[module.name] - with tf32_disable_guard(): - wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize() + wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, q_scales_extra, nsamples = q.quantize() q_scales = q_scales.to(CPU) q_zeros = q_zeros.to(CPU) diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py index e0f56602a..b0ca5c4be 100644 --- a/gptqmodel/models/definitions/qwen3_moe.py +++ b/gptqmodel/models/definitions/qwen3_moe.py @@ -3,8 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from ..base import BaseQModel from ...quantization import METHOD +from ..base import BaseQModel class Qwen3MoeQModel(BaseQModel): diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index 395b0c57c..867874a26 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -10,7 +10,6 @@ from torch import nn from ..utils.logger import setup_logger -from ..utils.torch import tf32_enable_guard log = setup_logger() @@ -42,8 +41,7 @@ def from_conv1d(m: transformers.Conv1D): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - with tf32_enable_guard(): - output = super().forward(input) + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) @@ -101,8 +99,7 @@ def from_conv1d(m: torch.nn.Conv1d): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - with tf32_enable_guard(): - output = super().forward(input) + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -160,8 +157,7 @@ def from_conv2d(m: torch.nn.Conv2d): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - with tf32_enable_guard(): - output = super().forward(input) + output = 
super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -187,8 +183,7 @@ def from_conv1d(conv1d: transformers.Conv1D): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - with tf32_enable_guard(): - output = super().forward(input) + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: @@ -215,8 +210,7 @@ def from_linear(linear: torch.nn.Linear): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(device=self.weight.data.device) - with tf32_enable_guard(): - output = super().forward(input) + output = super().forward(input) if self.forward_hook: self.forward_hook(self, (input,), output) if self.forward_hook_last: diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index cddca4d46..0aff7dd78 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -132,6 +132,7 @@ def _restore_tf32_state(state) -> None: torch.backends.cuda.matmul.allow_tf32 = state[0] torch.backends.cudnn.allow_tf32 = state[1] + def torch_compile(module: Union[torch.nn.Module, Callable], backend:str ="inductor", mode: str = None, fullgraph=False): # requires torch >2.8 for proper torch.compile + Python 3.13.3t (freethreading) if has_gil_disabled() and not gte_python_3_13_3(): @@ -331,32 +332,13 @@ def torch_streamCtx(stream: Union[torch.cuda.Stream, torch.xpu.Stream]) -> Strea @contextmanager -def tf32_enable_guard(): +def tf32_high_precision_guard(): if not HAS_CUDA: yield return - if BACKENDS_HAS_FP32_PRECISION: - if torch.backends.fp32_precision == "tf32": - yield - return - - previous_state = _snapshot_tf32_state() - _set_tf32_state(True) - - try: - yield - finally: - _restore_tf32_state(previous_state) - return - previous_state = _snapshot_tf32_state() - if previous_state[0] and previous_state[1]: - yield - return - - _set_tf32_state(True) - + _set_tf32_state(False) try: yield finally: @@ -369,27 +351,22 @@ def tf32_disable_guard(): yield return - if BACKENDS_HAS_FP32_PRECISION: - if torch.backends.fp32_precision == "ieee": - yield - return - - previous_state = _snapshot_tf32_state() - _set_tf32_state(False) + previous_state = _snapshot_tf32_state() + _set_tf32_state(False) + try: + yield + finally: + _restore_tf32_state(previous_state) - try: - yield - finally: - _restore_tf32_state(previous_state) - return - previous_state = _snapshot_tf32_state() - if not previous_state[0] and not previous_state[1]: +@contextmanager +def tf32_enable_guard(): + if not HAS_CUDA: yield return - _set_tf32_state(False) - + previous_state = _snapshot_tf32_state() + _set_tf32_state(True) try: yield finally: diff --git a/tests/test_tf32_performance.py b/tests/test_tf32_performance.py deleted file mode 100644 index 870bafe7c..000000000 --- a/tests/test_tf32_performance.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import pytest -import torch - -from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard - - -try: - from tabulate import tabulate -except ImportError: # pragma: no cover - tabulate = None - - -def _supports_bfloat16() -> bool: - major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) - return major >= 8 # Ampere or newer 
- - -def _measure_linear_time(batch: int, in_features: int, out_features: int, dtype: torch.dtype, *, runs: int = 10) -> float: - linear = torch.nn.Linear(in_features, out_features, bias=False).cuda().to(dtype=dtype) - torch.cuda.manual_seed(0) - inp = torch.randn(batch, in_features, device="cuda", dtype=dtype) - - # Warmup - for _ in range(5): - linear(inp) - - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(runs): - linear(inp) - end_event.record() - - torch.cuda.synchronize() - - elapsed_ms = start_event.elapsed_time(end_event) - return elapsed_ms / runs - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required") -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_tf32_toggle_has_no_large_perf_regression(dtype: torch.dtype): - if dtype is torch.bfloat16 and not _supports_bfloat16(): - pytest.skip("Device does not support bfloat16") - - shapes = [ - # Llama 3 / Mistral style hidden dims - (64, 4096, 4096), - (128, 4096, 11008), - # Qwen 2/3 7B/14B style - (64, 3584, 15360), - (64, 8192, 28672), - # DeepSeek V3 large experts - (32, 7168, 28672), - (32, 12288, 49152), - ] - - results = [] - - for batch, in_features, out_features in shapes: - times_tf32 = [] - times_no_tf32 = [] - max_diff = 0.0 - - for _ in range(10): - with tf32_enable_guard(): - linear = torch.nn.Linear(in_features, out_features, bias=False).cuda().to(dtype=dtype) - inp = torch.randn(batch, in_features, device="cuda", dtype=dtype) - out_tf32 = linear(inp) - times_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) - - with tf32_disable_guard(): - out_no_tf32 = linear(inp) - times_no_tf32.append(_measure_linear_time(batch, in_features, out_features, dtype)) - - max_diff = max(max_diff, float(torch.max(torch.abs(out_tf32 - out_no_tf32)).item())) - - avg_tf32 = sum(times_tf32) / len(times_tf32) - avg_no_tf32 = sum(times_no_tf32) / len(times_no_tf32) - - slower = max(avg_tf32, avg_no_tf32) - faster = min(avg_tf32, avg_no_tf32) - - assert slower <= faster * 1.5, ( - f"TF32 toggle caused >50% slowdown for dtype={dtype}, shape={batch}x{in_features}->{out_features}: " - f"tf32={avg_tf32:.3f}ms, no_tf32={avg_no_tf32:.3f}ms" - ) - - results.append( - { - "dtype": str(dtype).split(".")[-1], - "shape": f"{batch}x{in_features}->{out_features}", - "avg_tf32_ms": avg_tf32, - "avg_no_tf32_ms": avg_no_tf32, - "max_abs_diff": max_diff, - } - ) - - if tabulate: - table = tabulate( - [ - [r["dtype"], r["shape"], f"{r['avg_tf32_ms']:.3f}", f"{r['avg_no_tf32_ms']:.3f}", f"{r['max_abs_diff']:.3e}"] - for r in results - ], - headers=["dtype", "shape", "avg_tf32_ms", "avg_no_tf32_ms", "max_abs_diff"], - ) - print("\nTF32 performance summary:\n" + table) - else: - print("\nTF32 performance summary:") - for r in results: - print( - f"dtype={r['dtype']} shape={r['shape']} avg_tf32={r['avg_tf32_ms']:.3f}ms " - f"avg_no_tf32={r['avg_no_tf32_ms']:.3f}ms max_abs_diff={r['max_abs_diff']:.3e}" - ) From 200c283e9208ba6dfec3cc4671f104650adddd56 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 09:52:19 +0000 Subject: [PATCH 3/7] use fp64 for per-gpu hessian merges --- gptqmodel/quantization/gptq.py | 35 ++++++++++++++++----------- tests/test_hessian_merge.py | 44 +++++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 
21b5c3fb6..9ca1bc96d 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -471,25 +471,30 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No total_samples = sum(self._device_sample_counts.values()) + # Reuse the existing tensor when possible to avoid an extra allocation, + # but always accumulate in float64 for deterministic ordering across devices. reuse_buffer = ( self.H is not None and self.H.shape == (self.columns, self.columns) - and self.H.dtype == torch.float32 and self.H.device == device ) - if reuse_buffer: - result = self.H - result.zero_() + result_fp64: torch.Tensor + # Accumulating in float64 eliminates device-order drift at the cost of + # temporarily holding an FP64 buffer. The extra footprint is roughly + # columns^2 * 4 bytes; for an 8,192-column Llama MLP this is ~268 MB. + if reuse_buffer and self.H.dtype == torch.float64: + result_fp64 = self.H + result_fp64.zero_() else: - result = torch.zeros( + result_fp64 = torch.zeros( (self.columns, self.columns), - dtype=torch.float32, + dtype=torch.float64, device=device, ) if total_samples == 0: - self.H = result + self.H = result_fp64.to(dtype=torch.float32) self.nsamples = 0 self._hessian_dirty = False self._final_hessian_device_hint = device @@ -498,21 +503,23 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No return for partial_device, partial in self._device_hessian_partials.items(): - if partial.device != result.device: - tmp = partial.to(result.device) - result.add_(tmp) + if partial.device != result_fp64.device or partial.dtype != torch.float64: + tmp = partial.to(device=result_fp64.device, dtype=torch.float64) + result_fp64.add_(tmp) del tmp else: - result.add_(partial) + result_fp64.add_(partial) - result.mul_(2.0 / float(total_samples)) + result_fp64.mul_(2.0 / float(total_samples)) - self.H = result + result_fp32 = result_fp64.to(dtype=torch.float32) + self.H = result_fp32 self.nsamples = total_samples self._hessian_dirty = False - self._final_hessian_device_hint = result.device + self._final_hessian_device_hint = result_fp32.device self._device_hessian_partials.clear() self._device_sample_counts.clear() + del result_fp64 def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor: self._materialize_global_hessian(target_device=target_device) diff --git a/tests/test_hessian_merge.py b/tests/test_hessian_merge.py index 51d76e31f..0863bffdd 100644 --- a/tests/test_hessian_merge.py +++ b/tests/test_hessian_merge.py @@ -8,8 +8,8 @@ @pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 4, - reason="requires at least 4 CUDA devices", + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="requires at least two CUDA devices", ) @torch.no_grad() def test_hessian_merge_multi_gpu_matches_serial(): @@ -17,15 +17,16 @@ def test_hessian_merge_multi_gpu_matches_serial(): in_features = 16 out_features = 8 - batch_count = 100 - per_device = batch_count // 4 - devices = [torch.device(f"cuda:{idx}") for idx in range(4)] + batch_count = 64 + device_count = 2 + per_device = batch_count // device_count + devices = [torch.device(f"cuda:{idx}") for idx in range(device_count)] base = torch.nn.Linear(in_features, out_features, bias=False).eval() cfg_serial = QuantizeConfig() cfg_multi = copy.deepcopy(cfg_serial) - serial_module = copy.deepcopy(base) + serial_module = copy.deepcopy(base).to(devices[0]) multi_module = copy.deepcopy(base).to(devices[0]) gptq_serial = 
GPTQ(serial_module, cfg_serial) @@ -34,7 +35,10 @@ def test_hessian_merge_multi_gpu_matches_serial(): samples = [torch.randn(1, 1, in_features) for _ in range(batch_count)] for idx, sample in enumerate(samples): - gptq_serial.add_batch(sample, torch.empty(0), batch_index=idx) + sample_gpu = sample.to(devices[0]) + gptq_serial.add_batch(sample_gpu, torch.empty(0, device=devices[0]), batch_index=idx) + del sample_gpu + torch.cuda.synchronize(device=devices[0]) gptq_serial.finalize_hessian() serial_hessian = gptq_serial.H.detach().cpu() @@ -49,8 +53,32 @@ def test_hessian_merge_multi_gpu_matches_serial(): del sample_gpu torch.cuda.synchronize(device=device) + partials_snapshot = { + dev: tensor.clone() + for dev, tensor in gptq_multi._device_hessian_partials.items() + } + sample_counts_snapshot = dict(gptq_multi._device_sample_counts) + gptq_multi.finalize_hessian() merged_hessian = gptq_multi.H.detach().cpu() assert gptq_multi.nsamples == batch_count - torch.testing.assert_close(merged_hessian, serial_hessian, atol=1e-6, rtol=1e-6) + total_samples = sum(sample_counts_snapshot.values()) + assert total_samples == batch_count + + manual_device = gptq_multi.H.device + manual_accum = torch.zeros( + (gptq_multi.columns, gptq_multi.columns), + dtype=torch.float64, + device=manual_device, + ) + for dev, tensor in partials_snapshot.items(): + manual_accum.add_(tensor.to(device=manual_device, dtype=torch.float64)) + manual_accum.mul_(2.0 / float(total_samples)) + manual_result = manual_accum.to(dtype=torch.float32).cpu() + + # The materialized Hessian should match the explicit fp64 reduction exactly. + assert torch.equal(merged_hessian, manual_result) + + # And the merged Hessian should agree with the serial reference to float32 resolution. + torch.testing.assert_close(merged_hessian, serial_hessian, atol=5e-7, rtol=5e-7) From cad1185750839b5c257ba0aa7783e28c39e206c4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 11:54:46 +0000 Subject: [PATCH 4/7] update tests --- tests/models/test_multi_vs_single_gpu.py | 9 ++- tests/test_hessian_merge.py | 84 ++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index f9eecc11e..d22af06f4 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -232,7 +232,7 @@ def _summarize(stats: Dict[str, List[Dict[str, float]]]) -> Dict[str, Dict[str, }, ) info["batches"] += 1.0 - info["samples"] += item["after"] - item["before"] + info["samples"] += item.get("tokens", 0.0) info["sum_hash"] += item["sum"] info["primary"] = info["primary"] or bool(item.get("is_primary", False)) summary[name] = dict(sorted(per_handle.items(), key=lambda kv: kv[0])) @@ -259,7 +259,6 @@ def _capture_batches(storage: Dict[str, List[Dict[str, float]]], primary_handles def wrapped_add_batch(self, inp, out, batch_index=None): # type: ignore[override] module_name = getattr(self, "name", "") - before = getattr(self, "nsamples", 0) # Summaries calculated before running original implementation try: sum_value = inp.detach().to(dtype=torch.float64).sum().item() @@ -267,13 +266,13 @@ def wrapped_add_batch(self, inp, out, batch_index=None): # type: ignore[overrid sum_value = float("nan") device = str(getattr(inp, "device", "unknown")) + token_count = float(inp.numel() // max(inp.shape[-1], 1)) + original_add_batch(self, inp, out, batch_index=batch_index) - after = getattr(self, "nsamples", 0) storage.setdefault(module_name, 
[]).append( { - "before": float(before), - "after": float(after), + "tokens": token_count, "sum": float(sum_value), "handle": hex(id(self)), "device": device, diff --git a/tests/test_hessian_merge.py b/tests/test_hessian_merge.py index 0863bffdd..7186dcff0 100644 --- a/tests/test_hessian_merge.py +++ b/tests/test_hessian_merge.py @@ -5,6 +5,7 @@ from gptqmodel.quantization.config import QuantizeConfig from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask @pytest.mark.skipif( @@ -82,3 +83,86 @@ def test_hessian_merge_multi_gpu_matches_serial(): # And the merged Hessian should agree with the serial reference to float32 resolution. torch.testing.assert_close(merged_hessian, serial_hessian, atol=5e-7, rtol=5e-7) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 8, + reason="requires CUDA devices >= 8 to exercise GPUs 6 and 7", +) +@torch.no_grad() +def test_hessian_merge_multi_gpu_with_attention_mask(): + torch.manual_seed(123) + + in_features = 32 + out_features = 16 + batch_size = 3 + seq_len = 21 + batch_count = 10 + + device_serial = torch.device("cuda:6") + devices = [torch.device("cuda:6"), torch.device("cuda:7")] + + base = torch.nn.Linear(in_features, out_features, bias=False).eval() + cfg_serial = QuantizeConfig(mock_quantization=True, desc_act=False) + cfg_multi = copy.deepcopy(cfg_serial) + + serial_module = copy.deepcopy(base).to(device_serial) + multi_module = copy.deepcopy(base).to(device_serial) + + gptq_serial = GPTQ(serial_module, cfg_serial) + gptq_multi = GPTQ(multi_module, cfg_multi) + + samples = [] + for _ in range(batch_count): + hidden = torch.randn(batch_size, seq_len, in_features, dtype=torch.float32) + mask = torch.ones(batch_size, seq_len, dtype=torch.int32) + for row in range(batch_size): + # ensure at least one valid token per row, trim a random tail portion + cutoff = torch.randint(1, seq_len + 1, ()).item() + if cutoff < seq_len: + mask[row, cutoff:] = 0 + samples.append((hidden, mask)) + + total_kept_tokens = 0 + for idx, (hidden, mask) in enumerate(samples): + hidden_gpu = hidden.to(device_serial) + mask_gpu = mask.to(device_serial) + keep = normalize_seq_mask(mask_gpu) + trimmed = apply_keep_mask_bt(hidden_gpu, keep) + total_kept_tokens += trimmed.shape[0] + gptq_serial.add_batch(trimmed, torch.empty(0, device=device_serial), batch_index=idx) + torch.cuda.synchronize(device=device_serial) + + gptq_serial.finalize_hessian() + serial_hessian = gptq_serial.H.detach().cpu() + assert gptq_serial.nsamples == total_kept_tokens + + per_device = batch_count // len(devices) + remainder = batch_count % len(devices) + start = 0 + + device_token_counts = {} + for device_idx, device in enumerate(devices): + extra = 1 if device_idx < remainder else 0 + end = start + per_device + extra + for idx in range(start, end): + hidden, mask = samples[idx] + hidden_gpu = hidden.to(device) + mask_gpu = mask.to(device) + keep = normalize_seq_mask(mask_gpu) + trimmed = apply_keep_mask_bt(hidden_gpu, keep) + device_token_counts[device] = device_token_counts.get(device, 0) + trimmed.shape[0] + gptq_multi.add_batch(trimmed, torch.empty(0, device=device), batch_index=idx) + torch.cuda.synchronize(device=device) + start = end + + assert sum(device_token_counts.values()) == total_kept_tokens + + partial_counts_snapshot = dict(gptq_multi._device_sample_counts) + assert partial_counts_snapshot == device_token_counts + + gptq_multi.finalize_hessian() + merged_hessian = 
gptq_multi.H.detach().cpu() + + assert gptq_multi.nsamples == total_kept_tokens + torch.testing.assert_close(merged_hessian, serial_hessian, atol=5e-7, rtol=5e-7) From af140fd7628b554776f3d8dcdc50c8e82ddaed74 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 12:11:37 +0000 Subject: [PATCH 5/7] update tests2 --- tests/models/test_multi_vs_single_gpu.py | 124 ++++++++++++++++------- tests/test_hessian_merge.py | 17 ++++ 2 files changed, 107 insertions(+), 34 deletions(-) diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index d22af06f4..b0e0f94da 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -67,49 +67,80 @@ def test_quantization_first_layer_metrics_match_between_single_and_dual_gpu(self self.skipTest("Requires at least two CUDA devices") if sys.version_info < (3, 13): - self.skipTest("Requires Python 3.13 free-threaded runtime") + self.skipTest("Requires Python 3.13 runtime for multi-GPU regression test") - if not _is_free_threaded(): - self.skipTest("PYTHON_GIL must be disabled (set PYTHON_GIL=0) for multi-threaded quantization") + primary_idx, secondary_idx = self._select_preferred_devices(visible_devices) - single_layer_metrics, single_batch_stats = self._quantize_first_layer(device_indices=[0]) - multi_layer_metrics, multi_batch_stats = self._quantize_first_layer(device_indices=[0, 1]) + single_layer_metrics, single_batch_stats = self._quantize_layers( + device_indices=[primary_idx], + max_layer_index=1, + ) + multi_layer_metrics, multi_batch_stats = self._quantize_layers( + device_indices=[primary_idx, secondary_idx], + max_layer_index=1, + ) - self.assertTrue(single_layer_metrics, "Single-GPU quantization produced no layer-0 metrics") - self.assertTrue(multi_layer_metrics, "Multi-GPU quantization produced no layer-0 metrics") + self.assertTrue(single_layer_metrics, "Single-GPU quantization produced no layer metrics") + self.assertTrue(multi_layer_metrics, "Multi-GPU quantization produced no layer metrics") self.assertEqual( set(single_layer_metrics.keys()), set(multi_layer_metrics.keys()), - "Layer-0 module set differs between single-GPU and multi-GPU quantization", + "Layer set differs between single-GPU and multi-GPU quantization", ) - mismatches: Dict[str, Dict[str, str]] = {} - for module_name in single_layer_metrics: - single = single_layer_metrics[module_name] - multi = multi_layer_metrics[module_name] - if single.samples != multi.samples or single.loss != multi.loss: - mismatches[module_name] = { - "single_samples": str(single.samples), - "multi_samples": str(multi.samples), - "single_loss": str(single.loss), - "multi_loss": str(multi.loss), - } + print("[multi-vs-single] layer metrics summary:") + for layer_idx in sorted(single_layer_metrics): + single_layer = single_layer_metrics[layer_idx] + multi_layer = multi_layer_metrics[layer_idx] + print(f" layer={layer_idx}") + for module_name in sorted(single_layer): + single_val = single_layer[module_name] + multi_val = multi_layer[module_name] + print( + f" {module_name}: " + f"single_loss={single_val.loss} multi_loss={multi_val.loss} " + f"single_samples={single_val.samples} multi_samples={multi_val.samples}" + ) + + mismatches: Dict[Tuple[int, str], Dict[str, str]] = {} + for layer_idx in single_layer_metrics: + single_layer = single_layer_metrics[layer_idx] + multi_layer = multi_layer_metrics[layer_idx] + self.assertEqual( + set(single_layer.keys()), + set(multi_layer.keys()), + f"Layer-{layer_idx} module set 
differs between single-GPU and multi-GPU quantization", + ) + + for module_name in single_layer: + single = single_layer[module_name] + multi = multi_layer[module_name] + if single.samples != multi.samples or single.loss != multi.loss: + mismatches[(layer_idx, module_name)] = { + "single_samples": str(single.samples), + "multi_samples": str(multi.samples), + "single_loss": str(single.loss), + "multi_loss": str(multi.loss), + } if mismatches: debug_details = self._format_batch_debug(single_batch_stats, multi_batch_stats) details = "; ".join( - f"{module}: loss {info['single_loss']} vs {info['multi_loss']}, " + f"layer {layer}: {module}: loss {info['single_loss']} vs {info['multi_loss']}, " f"samples {info['single_samples']} vs {info['multi_samples']}" - for module, info in mismatches.items() + for (layer, module), info in mismatches.items() ) self.fail( - "Layer-0 quantization metrics diverged between device configurations: " + "Quantization metrics diverged between device configurations: " f"{details}; batch-debug: {debug_details}" ) - def _quantize_first_layer( - self, device_indices: Iterable[int] - ) -> Tuple[Dict[str, LayerMetrics], Dict[str, List[Dict[str, float]]]]: + def _quantize_layers( + self, + *, + device_indices: Iterable[int], + max_layer_index: int, + ) -> Tuple[Dict[int, Dict[str, LayerMetrics]], Dict[str, List[Dict[str, float]]]]: target_devices = [torch.device(f"cuda:{idx}") for idx in device_indices] def selection(_base_device): return target_devices @@ -122,7 +153,7 @@ def __init__(self, layer_idx: int): def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): if self._triggered: return None - if submodule_finalized and layer_idx >= self._layer_idx: + if layer_idx > self._layer_idx or (submodule_finalized and layer_idx >= self._layer_idx): self._triggered = True raise StopMainLoop @@ -138,6 +169,7 @@ def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): v2=self.V2, adapter=self.EORA, device=target_devices[0], + mock_quantization=True, ) load_kwargs = {} @@ -154,7 +186,7 @@ def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): ) dataset = self.load_dataset(model.tokenizer, self.DATASET_SIZE) - model.layer_callback = _StopAfterLayer(layer_idx=0) + model.layer_callback = _StopAfterLayer(layer_idx=max_layer_index) batch_debug: Dict[str, List[Dict[str, float]]] = {} primary_handles: Dict[str, str] = {} @@ -175,24 +207,47 @@ def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): batch_size=self.QUANT_BATCH_SIZE, ) - first_layer_stats = self._extract_first_layer_metrics(model.quant_log) + layer_stats = self._extract_layer_metrics( + model.quant_log, + max_layer_index=max_layer_index, + ) # Clear GPU memory before the next run del dataset del model torch_empty_cache() - return first_layer_stats, batch_debug + return layer_stats, batch_debug - def _extract_first_layer_metrics(self, quant_log: List[Dict[str, str]]) -> Dict[str, LayerMetrics]: - layer_metrics: Dict[str, LayerMetrics] = {} + @staticmethod + def _select_preferred_devices(visible_devices: int) -> Tuple[int, int]: + primary = 6 if visible_devices > 6 else 0 + secondary_preferences = [7, 1, 0] + secondary = None + for candidate in secondary_preferences: + if candidate >= visible_devices: + continue + if candidate == primary: + continue + secondary = candidate + break + if secondary is None: + raise RuntimeError("Could not determine a secondary CUDA device for regression test") + return primary, secondary + + def _extract_layer_metrics( + self, + 
quant_log: List[Dict[str, str]], + *, + max_layer_index: int, + ) -> Dict[int, Dict[str, LayerMetrics]]: + layer_metrics: Dict[int, Dict[str, LayerMetrics]] = {} for entry in quant_log: try: layer_index = int(entry.get(PROCESS_LOG_LAYER)) except (TypeError, ValueError): continue - - if layer_index != 0: + if layer_index < 0 or layer_index > max_layer_index: continue module_name = entry.get(PROCESS_LOG_MODULE) @@ -204,7 +259,8 @@ def _extract_first_layer_metrics(self, quant_log: List[Dict[str, str]]) -> Dict[ if loss_value is None or sample_value is None: continue - layer_metrics[module_name] = LayerMetrics( + per_layer = layer_metrics.setdefault(layer_index, {}) + per_layer[module_name] = LayerMetrics( loss=Decimal(loss_value), samples=int(sample_value), ) diff --git a/tests/test_hessian_merge.py b/tests/test_hessian_merge.py index 7186dcff0..05fa5ab74 100644 --- a/tests/test_hessian_merge.py +++ b/tests/test_hessian_merge.py @@ -64,6 +64,14 @@ def test_hessian_merge_multi_gpu_matches_serial(): merged_hessian = gptq_multi.H.detach().cpu() assert gptq_multi.nsamples == batch_count + max_abs_diff = (merged_hessian - serial_hessian).abs().max().item() + print( + "[hessian-no-mask] " + f"serial_nsamples={gptq_serial.nsamples} " + f"multi_nsamples={gptq_multi.nsamples} " + f"max_abs_diff={max_abs_diff:.6e}" + ) + total_samples = sum(sample_counts_snapshot.values()) assert total_samples == batch_count @@ -164,5 +172,14 @@ def test_hessian_merge_multi_gpu_with_attention_mask(): gptq_multi.finalize_hessian() merged_hessian = gptq_multi.H.detach().cpu() + max_abs_diff = (merged_hessian - serial_hessian).abs().max().item() + print( + "[hessian-mask] " + f"serial_tokens={total_kept_tokens} " + f"multi_tokens={gptq_multi.nsamples} " + f"per_device={{{', '.join(f'{str(dev)}:{count}' for dev, count in device_token_counts.items())}}} " + f"max_abs_diff={max_abs_diff:.6e}" + ) + assert gptq_multi.nsamples == total_kept_tokens torch.testing.assert_close(merged_hessian, serial_hessian, atol=5e-7, rtol=5e-7) From fc3b7aaf221d50227927bffc4d4f0a52c27f16c8 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 13:18:59 +0000 Subject: [PATCH 6/7] revert fp64 --- gptqmodel/looper/named_module.py | 16 +++++--------- gptqmodel/quantization/gptq.py | 37 ++++++++++++++------------------ 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index c9c85159c..4211f14f5 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -187,20 +187,14 @@ def _stream_tensor_dict( store_callback(host_map) return host_map - stream = torch.cuda.Stream(device=first.device) - done_event = torch.cuda.Event(enable_timing=False) host_map: Dict[str, torch.Tensor] = {} - with torch.cuda.stream(stream): - for name, tensor in filtered.items(): - src = tensor.detach() - host = host_pool.acquire(src.shape, src.dtype, src.layout) - host.copy_(src, non_blocking=True) - host_map[name] = host - done_event.record(stream) + for name, tensor in filtered.items(): + src = tensor.detach() + host = host_pool.acquire(src.shape, src.dtype, src.layout) + host.copy_(src, non_blocking=False) + host_map[name] = host with self._state_lock: - events = self.state.setdefault("streaming_events", []) - events.append({"event": done_event, "stream": stream}) store_callback(host_map) return host_map diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 9ca1bc96d..9583f0664 100644 --- a/gptqmodel/quantization/gptq.py +++ 
b/gptqmodel/quantization/gptq.py @@ -471,30 +471,26 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No total_samples = sum(self._device_sample_counts.values()) - # Reuse the existing tensor when possible to avoid an extra allocation, - # but always accumulate in float64 for deterministic ordering across devices. + # Reuse the existing tensor when possible to avoid an extra allocation. reuse_buffer = ( self.H is not None and self.H.shape == (self.columns, self.columns) and self.H.device == device ) - result_fp64: torch.Tensor - # Accumulating in float64 eliminates device-order drift at the cost of - # temporarily holding an FP64 buffer. The extra footprint is roughly - # columns^2 * 4 bytes; for an 8,192-column Llama MLP this is ~268 MB. - if reuse_buffer and self.H.dtype == torch.float64: - result_fp64 = self.H - result_fp64.zero_() + result_accum: torch.Tensor + if reuse_buffer and self.H.dtype == torch.float32: + result_accum = self.H + result_accum.zero_() else: - result_fp64 = torch.zeros( + result_accum = torch.zeros( (self.columns, self.columns), - dtype=torch.float64, + dtype=torch.float32, device=device, ) if total_samples == 0: - self.H = result_fp64.to(dtype=torch.float32) + self.H = result_accum self.nsamples = 0 self._hessian_dirty = False self._final_hessian_device_hint = device @@ -503,23 +499,22 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No return for partial_device, partial in self._device_hessian_partials.items(): - if partial.device != result_fp64.device or partial.dtype != torch.float64: - tmp = partial.to(device=result_fp64.device, dtype=torch.float64) - result_fp64.add_(tmp) + if partial.device != result_accum.device or partial.dtype != torch.float32: + tmp = partial.to(device=result_accum.device, dtype=torch.float32) + result_accum.add_(tmp) del tmp else: - result_fp64.add_(partial) + result_accum.add_(partial) - result_fp64.mul_(2.0 / float(total_samples)) + result_accum.mul_(2.0 / float(total_samples)) - result_fp32 = result_fp64.to(dtype=torch.float32) - self.H = result_fp32 + self.H = result_accum self.nsamples = total_samples self._hessian_dirty = False - self._final_hessian_device_hint = result_fp32.device + self._final_hessian_device_hint = result_accum.device self._device_hessian_partials.clear() self._device_sample_counts.clear() - del result_fp64 + del result_accum def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor: self._materialize_global_hessian(target_device=target_device) From 0d7dbfb62e5d2bb2198d3104718c840793b96bb2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 13:39:59 +0000 Subject: [PATCH 7/7] re-enable streaming --- gptqmodel/looper/named_module.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index 4211f14f5..ea766a5a3 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -189,12 +189,22 @@ def _stream_tensor_dict( host_map: Dict[str, torch.Tensor] = {} - for name, tensor in filtered.items(): - src = tensor.detach() - host = host_pool.acquire(src.shape, src.dtype, src.layout) - host.copy_(src, non_blocking=False) - host_map[name] = host + copy_device = first.device + compute_stream = torch.cuda.current_stream(device=copy_device) + copy_stream = torch.cuda.Stream(device=copy_device) + done_event = torch.cuda.Event(enable_timing=False, blocking=False) + + with 
torch.cuda.stream(copy_stream): + copy_stream.wait_stream(compute_stream) + for name, tensor in filtered.items(): + src = tensor.detach() + host = host_pool.acquire(src.shape, src.dtype, src.layout) + host.copy_(src, non_blocking=True) + host_map[name] = host + done_event.record(copy_stream) with self._state_lock: + events = self.state.setdefault("streaming_events", []) + events.append({"event": done_event, "stream": copy_stream}) store_callback(host_map) return host_map
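
PATCH 7/7 records one CUDA event per asynchronous D2H copy into `state["streaming_events"]` but leaves the consumer side out of the diff. Below is a minimal sketch, under the assumption that consumers see the same `state` dict layout, of how those events could be drained before the pinned host buffers are read or returned to the pool; the helper name and hook point are hypothetical and not part of the patches above.

import torch

def drain_streaming_events(state: dict) -> None:
    # Hypothetical helper: wait on every copy event recorded by
    # _stream_tensor_dict() so the pinned host tensors in host_map are safe
    # to read on the CPU or to hand back to the host pool. The key name
    # "streaming_events" mirrors PATCH 7/7.
    for entry in state.pop("streaming_events", []):
        event: torch.cuda.Event = entry["event"]
        event.synchronize()  # blocks the host until copies on the side stream finish

Until such a wait happens, CPU reads of `host_map` can observe partially copied data, which is the hazard the blocking `non_blocking=False` fallback in PATCH 6/7 sidestepped before streaming was re-enabled here.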