From 383d76f71f43130ec22a5299d9da336f8dcda882 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 2 Oct 2025 09:06:51 +0000 Subject: [PATCH 1/3] enable deterministic ci tests Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 2 ++ tests/models/model_test.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index d231f0583..a76db8bfe 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -632,6 +632,8 @@ def loop(self, fail_safe: bool = False, **kwargs): layer_modules = self.gptq_model.simple_layer_modules(model_config=self.gptq_model.model.config, quantize_config=self.gptq_model.quantize_config) + # true-sequential will replay the quantized activations after each subset has been quantized to be used for next subset quantization + # this should always be true for gptq unless you want lower but misleading error_loss that will lead to a lower-quality post-quantized model if not self.gptq_model.quantize_config.true_sequential: layer_modules = [sum(layer_modules, [])] diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 62f5602d7..a724c9883 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -13,6 +13,12 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +# the CUBLAS env is required for use_deterministic_algorithms +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +import torch +torch.use_deterministic_algorithms(True) + # -- end do not touch from pathlib import Path # noqa: E402 From 64532d452937445b6167a77bd3d47fada10f1930 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 2 Oct 2025 09:11:17 +0000 Subject: [PATCH 2/3] always capture attention masks or else we process pad tokens, yielding lower but inaccurate quant err_loss Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index a76db8bfe..c718870f5 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -493,7 +493,8 @@ def store_input_hook(module, args, kwargs): layer_inputs.append(layer_input) # Keyword arguments. - if kwargs.get("attention_mask") is not None and self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT: + # Always capture attention_mask so downstream masking can drop padded tokens + if kwargs.get("attention_mask") is not None: attention_masks.append(kwargs["attention_mask"].to(device=data_device)) else: attention_masks.append(None) From 28e796b5cf2bd57769b3a0aaddb597ba67e3d4fb Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 2 Oct 2025 09:31:26 +0000 Subject: [PATCH 3/3] disable slower deterministic tests Signed-off-by: Qubitium --- tests/models/model_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index a724c9883..d85ce2ad2 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -13,11 +13,12 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -# the CUBLAS env is required for use_deterministic_algorithms -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - -import torch -torch.use_deterministic_algorithms(True) +# Following makes test results more deterministic but much slower +# # the CUBLAS env is required for use_deterministic_algorithms
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +# import torch +# torch.use_deterministic_algorithms(True) # -- end do not touch