From a106c18af4ed5939f78e9fed9a03d386d3092ab5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 08:19:50 +0000 Subject: [PATCH 1/7] add awq kernel test --- tests/test_kernel_output_awq.py | 350 ++++++++++++++++++++++++++++++++ 1 file changed, 350 insertions(+) create mode 100644 tests/test_kernel_output_awq.py diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py new file mode 100644 index 000000000..55f91968f --- /dev/null +++ b/tests/test_kernel_output_awq.py @@ -0,0 +1,350 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os +import tempfile +import unittest +from typing import List + +import torch +from datasets import load_dataset +from logbar import LogBar +from parameterized import parameterized +from tabulate import tabulate +from transformers import AutoTokenizer + +from gptqmodel import BACKEND, FORMAT, GPTQModel, QuantizeConfig +from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.awq_gemv_fast import ( + AwqGEMVFastQuantLinear, + awq_v2_ext, + msg as awq_v2_msg, +) +from gptqmodel.nn_modules.qlinear.awq_marlin import ( + AwqMarlinQuantLinear, + marlin_import_exception, +) +from gptqmodel.quantization import METHOD +from gptqmodel.utils.marlin import marlin_make_workspace_new +from gptqmodel.utils.model import find_modules + + +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + +log = LogBar.shared() + +DEVICE = torch.device("cuda:0") + +GREEN = "\033[32m" +RED = "\033[31m" +RESET = "\033[0m" + + +class Data: + def __init__(self) -> None: + self.inputs: List[torch.Tensor] = [] + self.reference_outputs: List[torch.Tensor] = [] + + +class TestAwqKernelOutput(unittest.TestCase): + pretrained_model_id = "/monster/data/model/Llama-3.2-1B" + dataset_path = "/monster/data/model/dataset/c4-train.00000-of-01024.json.gz" + target = "model.layers.6.self_attn.v_proj" + group_size = 128 + calibration_concat_size = 0 + + target_qliner_map = { + BACKEND.GEMM: AwqGEMMQuantLinear, + BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, + BACKEND.MARLIN: AwqMarlinQuantLinear, + } + + backend_to_format = { + BACKEND.GEMM: FORMAT.GEMM, + BACKEND.MARLIN: FORMAT.GEMM, + BACKEND.GEMV_FAST: FORMAT.GEMV_FAST, + } + + float16_cases = [ + (BACKEND.GEMM, torch.float16, 0.0), + (BACKEND.GEMV_FAST, torch.float16, 0.0005), + (BACKEND.MARLIN, torch.float16, 0.0005), + ] + + @classmethod + def setUpClass(cls) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is required for AWQ kernel output checks.") + + cls.test_dtypes = [torch.float16] + cls.quantized_tempdirs = {} + cls.quantized_model_paths = {} + cls.data = {} + + try: + cls._prepare_calibration_dataset() + cls._quantize_models() + cls._prepare_random_inputs() + except unittest.SkipTest: + raise + except Exception as exc: # pragma: no cover - defensive skip for CI env mismatches + raise unittest.SkipTest(f"Skipping AWQ kernel output tests: {exc}") from exc + + @classmethod + def tearDownClass(cls) -> None: + for tmp_dir in getattr(cls, "quantized_tempdirs", {}).values(): + tmp_dir.cleanup() + + @classmethod + def _prepare_calibration_dataset(cls) -> None: + try: + cls.tokenizer = AutoTokenizer.from_pretrained(cls.pretrained_model_id, use_fast=True) + except Exception as exc: + raise unittest.SkipTest(f"Tokenizer unavailable for AWQ tests: {exc}") from exc + + requested_samples = 
os.getenv("GPTQMODEL_AWQ_KERNEL_SAMPLES") + if requested_samples is not None: + sample_count = max(8, int(requested_samples)) + else: + try: + total_mem_gb = ( + torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory + / (1024 ** 3) + ) + except Exception: # pragma: no cover - fallback on inspect failure + total_mem_gb = 0.0 + + if total_mem_gb >= 80: + sample_count = 256 + elif total_mem_gb >= 48: + sample_count = 128 + else: + sample_count = 48 + + try: + dataset = load_dataset("json", data_files=cls.dataset_path, split="train") + except Exception as exc: + raise unittest.SkipTest(f"Calibration dataset unavailable for AWQ tests: {exc}") from exc + + if len(dataset) < sample_count: + raise unittest.SkipTest( + f"Calibration dataset too small ({len(dataset)} < {sample_count})." + ) + + cls.calibration_dataset = dataset.select(range(sample_count)) + + @classmethod + def _quantize_models(cls) -> None: + quantize_targets = [ + (FORMAT.GEMM, cls.group_size), + (FORMAT.GEMV_FAST, cls.group_size), + ] + + for checkpoint_format, group_size in quantize_targets: + quantize_config = QuantizeConfig( + bits=4, + group_size=group_size, + quant_method=METHOD.AWQ, + format=checkpoint_format, + ) + + model = GPTQModel.load( + cls.pretrained_model_id, + quantize_config=quantize_config, + ) + + model.quantize(cls.calibration_dataset, batch_size=1, calibration_concat_size=cls.calibration_concat_size) + + tmp_dir = tempfile.TemporaryDirectory() + model.save(tmp_dir.name) + + cls.quantized_tempdirs[(checkpoint_format, group_size)] = tmp_dir + cls.quantized_model_paths[(checkpoint_format, group_size)] = tmp_dir.name + + del model + torch.cuda.empty_cache() + + @classmethod + def _prepare_random_inputs(cls) -> None: + model_path = cls.quantized_model_paths[(FORMAT.GEMM, cls.group_size)] + model = GPTQModel.load(model_path, backend=BACKEND.GEMM, dtype=torch.float16) + + modules = find_modules(model.model, layers=[AwqGEMMQuantLinear]) + if cls.target not in modules: + raise unittest.SkipTest(f"Target layer `{cls.target}` missing in quantized model.") + + module = modules[cls.target] + in_features = module.in_features + + large_shapes = [(1, 128), (1, 64), (1, 48)] + medium_shapes = [(1, 64), (1, 48), (1, 32)] + small_shapes = [(1, 32), (1, 24), (1, 16)] + + try: + total_mem_gb = ( + torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory + / (1024 ** 3) + ) + except Exception: # pragma: no cover + total_mem_gb = 0.0 + + if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": + shapes = small_shapes + elif total_mem_gb >= 80: + shapes = large_shapes + elif total_mem_gb >= 48: + shapes = medium_shapes + else: + shapes = small_shapes + + for dtype in cls.test_dtypes: + data = Data() + cls.data[dtype] = data + + with torch.inference_mode(): + for batch_tokens, seq_len in shapes: + inputs = torch.rand( + (batch_tokens, seq_len, in_features), + device=DEVICE, + dtype=dtype, + ) + data.inputs.append(inputs) + + reference_outputs = cls._forward( + model_path=model_path, + backend=BACKEND.GEMM, + dtype=dtype, + inputs=data.inputs, + ) + data.reference_outputs.extend(reference_outputs) + + del module + del model + torch.cuda.empty_cache() + + @classmethod + def _forward( + cls, + model_path: str, + backend: BACKEND, + dtype: torch.dtype, + inputs: List[torch.Tensor], + ) -> List[torch.Tensor]: + model = GPTQModel.load(model_path, backend=backend, dtype=dtype) + + target_qlinear_cls = cls.target_qliner_map[backend] + modules = find_modules(model.model, layers=[target_qlinear_cls]) + if 
cls.target not in modules: + raise unittest.SkipTest(f"Target layer `{cls.target}` missing for backend `{backend}`.") + + module = modules[cls.target] + + outputs: List[torch.Tensor] = [] + with torch.inference_mode(): + for tensor in inputs: + outputs.append(module(tensor)) + + del module + del model + torch.cuda.empty_cache() + + return outputs + + def _maybe_skip_backend(self, backend: BACKEND) -> None: + if backend == BACKEND.GEMV_FAST and awq_v2_ext is None: + self.skipTest(f"AWQ GEMV_FAST kernel unavailable: {awq_v2_msg}") + + if backend == BACKEND.MARLIN: + if marlin_import_exception is not None: + self.skipTest(f"AWQ Marlin kernel unavailable: {marlin_import_exception}") + + # Validate CUDA capability for Marlin kernels. + try: + workspace = marlin_make_workspace_new(DEVICE) + del workspace + torch.cuda.empty_cache() + except Exception as exc: + self.skipTest(f"Unable to allocate Marlin workspace: {exc}") + + def _summarize_results( + self, + reference_outputs: List[torch.Tensor], + actual_outputs: List[torch.Tensor], + backend: BACKEND, + dtype: torch.dtype, + atol: float, + title: str, + reference_label: str, + ) -> None: + failures = [] + total = len(actual_outputs) + + for idx, (reference, actual) in enumerate(zip(reference_outputs, actual_outputs)): + is_close_tensor = torch.isclose(reference, actual, rtol=0.15, atol=atol) + if not bool(torch.all(is_close_tensor)): + failures.append( + "Sample {idx}:\nExpected ({ref_label}) = {expected}\nActual = {actual_val}".format( + idx=idx, + ref_label=reference_label, + expected=reference.detach().cpu().tolist(), + actual_val=actual.detach().cpu().tolist(), + ) + ) + + status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" + details = "\n\n".join(str(detail) for detail in failures) if failures else "-" + + table = tabulate( + [ + [ + backend.name, + str(dtype), + total, + status, + len(failures), + details, + ] + ], + headers=[ + "Backend", + "DType", + "Samples", + "Status", + "Failures", + "Expected vs Actual", + ], + tablefmt="github", + ) + log.info("\n" + title + "\n" + table) + + if failures: + raise AssertionError( + f"{len(failures)} mismatched outputs for backend `{backend}` and dtype `{dtype}`" + ) + + @parameterized.expand(float16_cases) + def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: float) -> None: + self._maybe_skip_backend(backend) + + quant_format = self.backend_to_format[backend] + model_path = self.quantized_model_paths[(quant_format, self.group_size)] + + data = self.data[dtype] + actual_outputs = self._forward( + model_path=model_path, + backend=backend, + dtype=dtype, + inputs=data.inputs, + ) + + self._summarize_results( + reference_outputs=data.reference_outputs, + actual_outputs=actual_outputs, + backend=backend, + dtype=dtype, + atol=atol, + title=f"AWQ Kernel Output {dtype}", + reference_label="AWQ GEMM output", + ) From 7447c61378405320e71412da75c4c26099376bc5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:04:38 +0000 Subject: [PATCH 2/7] add awq kernel test --- tests/test_kernel_output_awq.py | 354 ++++++++++++++------------------ 1 file changed, 157 insertions(+), 197 deletions(-) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 55f91968f..3dbf9abf1 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -3,32 +3,25 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import json import os -import tempfile import unittest -from typing 
import List +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple import torch -from datasets import load_dataset from logbar import LogBar from parameterized import parameterized +from safetensors.torch import safe_open from tabulate import tabulate -from transformers import AutoTokenizer -from gptqmodel import BACKEND, FORMAT, GPTQModel, QuantizeConfig +from gptqmodel import BACKEND from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from gptqmodel.nn_modules.qlinear.awq_gemv_fast import ( - AwqGEMVFastQuantLinear, - awq_v2_ext, - msg as awq_v2_msg, -) from gptqmodel.nn_modules.qlinear.awq_marlin import ( AwqMarlinQuantLinear, marlin_import_exception, ) -from gptqmodel.quantization import METHOD from gptqmodel.utils.marlin import marlin_make_workspace_new -from gptqmodel.utils.model import find_modules os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") @@ -42,35 +35,16 @@ RESET = "\033[0m" -class Data: - def __init__(self) -> None: - self.inputs: List[torch.Tensor] = [] - self.reference_outputs: List[torch.Tensor] = [] - - class TestAwqKernelOutput(unittest.TestCase): - pretrained_model_id = "/monster/data/model/Llama-3.2-1B" - dataset_path = "/monster/data/model/dataset/c4-train.00000-of-01024.json.gz" - target = "model.layers.6.self_attn.v_proj" - group_size = 128 - calibration_concat_size = 0 - - target_qliner_map = { - BACKEND.GEMM: AwqGEMMQuantLinear, - BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, - BACKEND.MARLIN: AwqMarlinQuantLinear, - } - - backend_to_format = { - BACKEND.GEMM: FORMAT.GEMM, - BACKEND.MARLIN: FORMAT.GEMM, - BACKEND.GEMV_FAST: FORMAT.GEMV_FAST, - } + MODEL_PATH = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") + TARGET = "model.layers.20.self_attn.v_proj" + BITS = 4 + GROUP_SIZE = 128 + DTYPE = torch.float16 float16_cases = [ - (BACKEND.GEMM, torch.float16, 0.0), - (BACKEND.GEMV_FAST, torch.float16, 0.0005), - (BACKEND.MARLIN, torch.float16, 0.0005), + (BACKEND.GEMM, 0.0), + (BACKEND.MARLIN, 0.01), ] @classmethod @@ -78,113 +52,148 @@ def setUpClass(cls) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA is required for AWQ kernel output checks.") - cls.test_dtypes = [torch.float16] - cls.quantized_tempdirs = {} - cls.quantized_model_paths = {} - cls.data = {} + cls.device = DEVICE + cls.log = log + cls._weight_map = cls._load_weight_map() + cls.backend_skip_reason: Dict[BACKEND, str] = {} try: - cls._prepare_calibration_dataset() - cls._quantize_models() - cls._prepare_random_inputs() - except unittest.SkipTest: - raise - except Exception as exc: # pragma: no cover - defensive skip for CI env mismatches - raise unittest.SkipTest(f"Skipping AWQ kernel output tests: {exc}") from exc + tensors = cls._load_awq_tensors(cls.TARGET) + except Exception as exc: # pragma: no cover - skip if model unavailable + raise unittest.SkipTest(f"Unable to load AWQ tensors: {exc}") from exc - @classmethod - def tearDownClass(cls) -> None: - for tmp_dir in getattr(cls, "quantized_tempdirs", {}).values(): - tmp_dir.cleanup() + ( + qweight_cpu, + qzeros_cpu, + scales_cpu, + bias_cpu, + ) = tensors - @classmethod - def _prepare_calibration_dataset(cls) -> None: - try: - cls.tokenizer = AutoTokenizer.from_pretrained(cls.pretrained_model_id, use_fast=True) - except Exception as exc: - raise unittest.SkipTest(f"Tokenizer unavailable for AWQ tests: {exc}") from exc + cls.in_features = qweight_cpu.shape[0] + cls.out_features = qweight_cpu.shape[1] * (32 // cls.BITS) - requested_samples = 
os.getenv("GPTQMODEL_AWQ_KERNEL_SAMPLES") - if requested_samples is not None: - sample_count = max(8, int(requested_samples)) - else: - try: - total_mem_gb = ( - torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory - / (1024 ** 3) - ) - except Exception: # pragma: no cover - fallback on inspect failure - total_mem_gb = 0.0 + cls.modules: Dict[BACKEND, Optional[torch.nn.Module]] = {} - if total_mem_gb >= 80: - sample_count = 256 - elif total_mem_gb >= 48: - sample_count = 128 - else: - sample_count = 48 + cls.modules[BACKEND.GEMM] = cls._build_gemm_module( + qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + ) - try: - dataset = load_dataset("json", data_files=cls.dataset_path, split="train") - except Exception as exc: - raise unittest.SkipTest(f"Calibration dataset unavailable for AWQ tests: {exc}") from exc + cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( + qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + ) - if len(dataset) < sample_count: - raise unittest.SkipTest( - f"Calibration dataset too small ({len(dataset)} < {sample_count})." - ) + cls.inputs = cls._generate_inputs() + cls.reference_outputs = cls._forward(cls.modules[BACKEND.GEMM], cls.inputs) - cls.calibration_dataset = dataset.select(range(sample_count)) + @classmethod + def tearDownClass(cls) -> None: + for module in getattr(cls, "modules", {}).values(): + if module is not None: + del module + torch.cuda.empty_cache() @classmethod - def _quantize_models(cls) -> None: - quantize_targets = [ - (FORMAT.GEMM, cls.group_size), - (FORMAT.GEMV_FAST, cls.group_size), - ] - - for checkpoint_format, group_size in quantize_targets: - quantize_config = QuantizeConfig( - bits=4, - group_size=group_size, - quant_method=METHOD.AWQ, - format=checkpoint_format, - ) + def _load_weight_map(cls) -> Dict[str, str]: + index_path = cls.MODEL_PATH / "model.safetensors.index.json" + with open(index_path, "r") as handle: + index = json.load(handle) + return index["weight_map"] - model = GPTQModel.load( - cls.pretrained_model_id, - quantize_config=quantize_config, - ) + @classmethod + def _load_tensor(cls, key: str) -> torch.Tensor: + if key not in cls._weight_map: + raise KeyError(f"Tensor `{key}` not found in weight map.") + filename = cls.MODEL_PATH / cls._weight_map[key] + with safe_open(filename, framework="pt", device="cpu") as f: + return f.get_tensor(key) - model.quantize(cls.calibration_dataset, batch_size=1, calibration_concat_size=cls.calibration_concat_size) + @classmethod + def _load_awq_tensors(cls, target: str) -> Tuple[torch.Tensor, ...]: + qweight = cls._load_tensor(f"{target}.qweight").contiguous() + qzeros = cls._load_tensor(f"{target}.qzeros").contiguous() + scales = cls._load_tensor(f"{target}.scales").contiguous() + bias = cls._load_tensor(f"{target}.bias").contiguous() + return qweight, qzeros, scales, bias - tmp_dir = tempfile.TemporaryDirectory() - model.save(tmp_dir.name) + @classmethod + def _build_gemm_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> AwqGEMMQuantLinear: + module = AwqGEMMQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(cls.device) + + module.qweight.copy_(qweight_cpu.to(cls.device)) + module.qzeros.copy_(qzeros_cpu.to(cls.device)) + module.scales.copy_(scales_cpu.to(torch.float32).to(cls.device)) + 
module.bias.copy_(bias_cpu.to(torch.float32).to(cls.device)) + + module.eval() + module.post_init() + return module - cls.quantized_tempdirs[(checkpoint_format, group_size)] = tmp_dir - cls.quantized_model_paths[(checkpoint_format, group_size)] = tmp_dir.name + @classmethod + def _build_marlin_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> Optional[AwqMarlinQuantLinear]: + if marlin_import_exception is not None: + cls.backend_skip_reason[BACKEND.MARLIN] = f"AWQ Marlin kernel unavailable: {marlin_import_exception}" + return None - del model + try: + workspace = marlin_make_workspace_new(cls.device) + del workspace torch.cuda.empty_cache() + except Exception as exc: + cls.backend_skip_reason[BACKEND.MARLIN] = f"Unable to allocate Marlin workspace: {exc}" + return None + + module = AwqMarlinQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(cls.device) + + module.qweight.data.copy_(qweight_cpu.to(cls.device)) + module.qzeros.data.copy_(qzeros_cpu.to(cls.device)) + module.scales.data.copy_(scales_cpu.to(torch.float16).to(cls.device)) + module.bias.data.copy_(bias_cpu.to(torch.float16).to(cls.device)) + + module.eval() + module.post_init() + return module @classmethod - def _prepare_random_inputs(cls) -> None: - model_path = cls.quantized_model_paths[(FORMAT.GEMM, cls.group_size)] - model = GPTQModel.load(model_path, backend=BACKEND.GEMM, dtype=torch.float16) - - modules = find_modules(model.model, layers=[AwqGEMMQuantLinear]) - if cls.target not in modules: - raise unittest.SkipTest(f"Target layer `{cls.target}` missing in quantized model.") - - module = modules[cls.target] - in_features = module.in_features - - large_shapes = [(1, 128), (1, 64), (1, 48)] - medium_shapes = [(1, 64), (1, 48), (1, 32)] + def _generate_inputs(cls) -> List[torch.Tensor]: + large_shapes = [(4, 32), (2, 64), (1, 96)] + medium_shapes = [(2, 32), (1, 48), (1, 32)] small_shapes = [(1, 32), (1, 24), (1, 16)] try: total_mem_gb = ( - torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory + torch.cuda.get_device_properties(cls.device).total_memory / (1024 ** 3) ) except Exception: # pragma: no cover @@ -199,81 +208,39 @@ def _prepare_random_inputs(cls) -> None: else: shapes = small_shapes - for dtype in cls.test_dtypes: - data = Data() - cls.data[dtype] = data - - with torch.inference_mode(): - for batch_tokens, seq_len in shapes: - inputs = torch.rand( - (batch_tokens, seq_len, in_features), - device=DEVICE, - dtype=dtype, - ) - data.inputs.append(inputs) - - reference_outputs = cls._forward( - model_path=model_path, - backend=BACKEND.GEMM, - dtype=dtype, - inputs=data.inputs, - ) - data.reference_outputs.extend(reference_outputs) - - del module - del model - torch.cuda.empty_cache() + inputs: List[torch.Tensor] = [] + for batch, tokens in shapes: + tensor = torch.rand( + (batch, tokens, cls.in_features), + device=cls.device, + dtype=cls.DTYPE, + ) + inputs.append(tensor) + return inputs @classmethod def _forward( cls, - model_path: str, - backend: BACKEND, - dtype: torch.dtype, - inputs: List[torch.Tensor], + module: torch.nn.Module, + inputs: Iterable[torch.Tensor], ) -> List[torch.Tensor]: - model = GPTQModel.load(model_path, backend=backend, dtype=dtype) - - target_qlinear_cls = cls.target_qliner_map[backend] - modules = find_modules(model.model, 
layers=[target_qlinear_cls]) - if cls.target not in modules: - raise unittest.SkipTest(f"Target layer `{cls.target}` missing for backend `{backend}`.") - - module = modules[cls.target] - outputs: List[torch.Tensor] = [] with torch.inference_mode(): for tensor in inputs: - outputs.append(module(tensor)) - - del module - del model - torch.cuda.empty_cache() - + result = module(tensor) + outputs.append(result.detach().to(torch.float32).cpu()) return outputs def _maybe_skip_backend(self, backend: BACKEND) -> None: - if backend == BACKEND.GEMV_FAST and awq_v2_ext is None: - self.skipTest(f"AWQ GEMV_FAST kernel unavailable: {awq_v2_msg}") - - if backend == BACKEND.MARLIN: - if marlin_import_exception is not None: - self.skipTest(f"AWQ Marlin kernel unavailable: {marlin_import_exception}") - - # Validate CUDA capability for Marlin kernels. - try: - workspace = marlin_make_workspace_new(DEVICE) - del workspace - torch.cuda.empty_cache() - except Exception as exc: - self.skipTest(f"Unable to allocate Marlin workspace: {exc}") + reason = self.backend_skip_reason.get(backend) + if reason: + self.skipTest(reason) def _summarize_results( self, reference_outputs: List[torch.Tensor], actual_outputs: List[torch.Tensor], backend: BACKEND, - dtype: torch.dtype, atol: float, title: str, reference_label: str, @@ -300,7 +267,7 @@ def _summarize_results( [ [ backend.name, - str(dtype), + str(self.DTYPE), total, status, len(failures), @@ -317,34 +284,27 @@ def _summarize_results( ], tablefmt="github", ) - log.info("\n" + title + "\n" + table) + self.log.info("\n" + title + "\n" + table) if failures: raise AssertionError( - f"{len(failures)} mismatched outputs for backend `{backend}` and dtype `{dtype}`" + f"{len(failures)} mismatched outputs for backend `{backend}`" ) @parameterized.expand(float16_cases) - def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: float) -> None: + def test_awq_kernel_outputs(self, backend: BACKEND, atol: float) -> None: self._maybe_skip_backend(backend) - quant_format = self.backend_to_format[backend] - model_path = self.quantized_model_paths[(quant_format, self.group_size)] - - data = self.data[dtype] - actual_outputs = self._forward( - model_path=model_path, - backend=backend, - dtype=dtype, - inputs=data.inputs, - ) + module = self.modules.get(backend) + if module is None: + self.skipTest(f"Backend `{backend}` module unavailable.") + actual_outputs = self._forward(module, self.inputs) self._summarize_results( - reference_outputs=data.reference_outputs, + reference_outputs=self.reference_outputs, actual_outputs=actual_outputs, backend=backend, - dtype=dtype, atol=atol, - title=f"AWQ Kernel Output {dtype}", + title=f"AWQ Kernel Output {self.DTYPE}", reference_label="AWQ GEMM output", ) From 753412f13877a049360720ef465d16b3023c0421 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:10:13 +0000 Subject: [PATCH 3/7] bf16 test --- tests/test_kernel_output_awq.py | 49 ++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 3dbf9abf1..79c74e859 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -40,11 +40,13 @@ class TestAwqKernelOutput(unittest.TestCase): TARGET = "model.layers.20.self_attn.v_proj" BITS = 4 GROUP_SIZE = 128 - DTYPE = torch.float16 + SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) - float16_cases = [ - (BACKEND.GEMM, 0.0), - (BACKEND.MARLIN, 0.01), + backend_cases = [ + (BACKEND.GEMM, 
torch.float16, 0.0), + (BACKEND.GEMM, torch.bfloat16, 0.0005), + (BACKEND.MARLIN, torch.float16, 0.01), + (BACKEND.MARLIN, torch.bfloat16, 0.015), ] @classmethod @@ -82,8 +84,20 @@ def setUpClass(cls) -> None: qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) - cls.inputs = cls._generate_inputs() - cls.reference_outputs = cls._forward(cls.modules[BACKEND.GEMM], cls.inputs) + base_inputs = cls._generate_inputs() + cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {} + cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {} + + for dtype in cls.SUPPORTED_DTYPES: + converted_inputs = [ + tensor.to(dtype=dtype) if tensor.dtype != dtype else tensor.clone() + for tensor in base_inputs + ] + cls.inputs[dtype] = converted_inputs + cls.reference_outputs[dtype] = cls._forward( + cls.modules[BACKEND.GEMM], + converted_inputs, + ) @classmethod def tearDownClass(cls) -> None: @@ -213,7 +227,7 @@ def _generate_inputs(cls) -> List[torch.Tensor]: tensor = torch.rand( (batch, tokens, cls.in_features), device=cls.device, - dtype=cls.DTYPE, + dtype=torch.float16, ) inputs.append(tensor) return inputs @@ -241,6 +255,7 @@ def _summarize_results( reference_outputs: List[torch.Tensor], actual_outputs: List[torch.Tensor], backend: BACKEND, + dtype: torch.dtype, atol: float, title: str, reference_label: str, @@ -267,7 +282,7 @@ def _summarize_results( [ [ backend.name, - str(self.DTYPE), + str(dtype), total, status, len(failures), @@ -291,20 +306,28 @@ def _summarize_results( f"{len(failures)} mismatched outputs for backend `{backend}`" ) - @parameterized.expand(float16_cases) - def test_awq_kernel_outputs(self, backend: BACKEND, atol: float) -> None: + @parameterized.expand(backend_cases) + def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: float) -> None: self._maybe_skip_backend(backend) module = self.modules.get(backend) if module is None: self.skipTest(f"Backend `{backend}` module unavailable.") - actual_outputs = self._forward(module, self.inputs) + inputs = self.inputs[dtype] + reference_outputs = self.reference_outputs[dtype] + try: + actual_outputs = self._forward(module, inputs) + except RuntimeError as exc: + if backend == BACKEND.MARLIN and dtype == torch.bfloat16: + self.skipTest(f"AWQ Marlin bf16 execution unavailable: {exc}") + raise self._summarize_results( - reference_outputs=self.reference_outputs, + reference_outputs=reference_outputs, actual_outputs=actual_outputs, backend=backend, + dtype=dtype, atol=atol, - title=f"AWQ Kernel Output {self.DTYPE}", + title=f"AWQ Kernel Output {dtype}", reference_label="AWQ GEMM output", ) From e611cb71ce601521079d65c9c4015697e6a7f127 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:18:26 +0000 Subject: [PATCH 4/7] bf16 test --- tests/test_kernel_output_awq.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 79c74e859..34d3d4834 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -264,14 +264,16 @@ def _summarize_results( total = len(actual_outputs) for idx, (reference, actual) in enumerate(zip(reference_outputs, actual_outputs)): - is_close_tensor = torch.isclose(reference, actual, rtol=0.15, atol=atol) + reference_fp32 = reference.to(torch.float32) + actual_fp32 = actual.to(torch.float32) + is_close_tensor = torch.isclose(reference_fp32, actual_fp32, rtol=0.15, atol=atol) if not bool(torch.all(is_close_tensor)): failures.append( "Sample {idx}:\nExpected 
({ref_label}) = {expected}\nActual = {actual_val}".format( idx=idx, ref_label=reference_label, - expected=reference.detach().cpu().tolist(), - actual_val=actual.detach().cpu().tolist(), + expected=reference_fp32.detach().cpu().tolist(), + actual_val=actual_fp32.detach().cpu().tolist(), ) ) @@ -316,12 +318,12 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl inputs = self.inputs[dtype] reference_outputs = self.reference_outputs[dtype] - try: + if backend == BACKEND.MARLIN and dtype == torch.bfloat16: + converted_inputs = [tensor.to(torch.float16) for tensor in inputs] + actual_outputs_fp16 = self._forward(module, converted_inputs) + actual_outputs = [tensor.to(dtype) for tensor in actual_outputs_fp16] + else: actual_outputs = self._forward(module, inputs) - except RuntimeError as exc: - if backend == BACKEND.MARLIN and dtype == torch.bfloat16: - self.skipTest(f"AWQ Marlin bf16 execution unavailable: {exc}") - raise self._summarize_results( reference_outputs=reference_outputs, actual_outputs=actual_outputs, From e53c5ee736a470d5b750edd9dd5133acaa3b51bc Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:25:50 +0000 Subject: [PATCH 5/7] bf16 test --- gptqmodel/nn_modules/qlinear/awq_marlin.py | 16 +++++++++++++--- tests/test_kernel_output_awq.py | 7 +------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index fd55fcd5b..9b35c718e 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -242,10 +242,20 @@ def forward(self, x: torch.Tensor): "Use marlin_post_init() on the whole model." ) + input_tensor = x.contiguous() if self.is_lm_head else x + + weight_scale = self.scales + if weight_scale.dtype != input_tensor.dtype: + weight_scale = weight_scale.to(input_tensor.dtype) + + bias = self.bias + if bias is not None and bias.dtype != input_tensor.dtype: + bias = bias.to(input_tensor.dtype) + out = apply_awq_marlin_linear( - input=x.contiguous() if self.is_lm_head else x, + input=input_tensor, weight=self.qweight, - weight_scale=self.scales, + weight_scale=weight_scale, weight_zp=self.qzeros, g_idx=self.g_idx, g_idx_sort_indices=self.g_idx_sort_indices, @@ -253,7 +263,7 @@ def forward(self, x: torch.Tensor): quant_type=self.weight_type, output_size_per_partition=self.out_features, input_size_per_partition=self.in_features, - bias=self.bias, + bias=bias, ) if self.adapter: diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 34d3d4834..95c8bbcda 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -318,12 +318,7 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl inputs = self.inputs[dtype] reference_outputs = self.reference_outputs[dtype] - if backend == BACKEND.MARLIN and dtype == torch.bfloat16: - converted_inputs = [tensor.to(torch.float16) for tensor in inputs] - actual_outputs_fp16 = self._forward(module, converted_inputs) - actual_outputs = [tensor.to(dtype) for tensor in actual_outputs_fp16] - else: - actual_outputs = self._forward(module, inputs) + actual_outputs = self._forward(module, inputs) self._summarize_results( reference_outputs=reference_outputs, actual_outputs=actual_outputs, From 59c83760413e9f52eea1a6410b4cfe9d422f464a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:31:37 +0000 Subject: [PATCH 6/7] fix awq marlin for bf16 --- 
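Note: this replaces the per-call weight_scale / bias temporaries from the previous patch with
one-time in-place casts of the registered buffers, so only the first forward pass with a new
activation dtype pays the conversion cost. A minimal sketch of the bf16 parity check that
tests/test_kernel_output_awq.py runs against this path (gemm_module, marlin_module and
in_features are hypothetical stand-ins for the fixtures the test builds from the AWQ
checkpoint; rtol/atol are the Marlin bf16 tolerances used by that test):

    import torch

    # gemm_module / marlin_module: AwqGEMMQuantLinear and AwqMarlinQuantLinear instances
    # built from the checkpoint's qweight/qzeros/scales/bias (stand-in names here);
    # in_features comes from qweight.shape[0].
    x = torch.rand((1, 32, in_features), device="cuda", dtype=torch.bfloat16)
    ref = gemm_module(x).detach().to(torch.float32).cpu()   # GEMM output is the reference
    out = marlin_module(x).detach().to(torch.float32).cpu()
    assert bool(torch.all(torch.isclose(ref, out, rtol=0.15, atol=0.015)))
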
gptqmodel/nn_modules/qlinear/awq_marlin.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index 9b35c718e..e061f78f0 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -244,18 +244,16 @@ def forward(self, x: torch.Tensor): input_tensor = x.contiguous() if self.is_lm_head else x - weight_scale = self.scales - if weight_scale.dtype != input_tensor.dtype: - weight_scale = weight_scale.to(input_tensor.dtype) + if self.scales.dtype != input_tensor.dtype: + self.scales.data = self.scales.data.to(input_tensor.dtype) - bias = self.bias - if bias is not None and bias.dtype != input_tensor.dtype: - bias = bias.to(input_tensor.dtype) + if self.bias is not None and self.bias.dtype != input_tensor.dtype: + self.bias.data = self.bias.data.to(input_tensor.dtype) out = apply_awq_marlin_linear( input=input_tensor, weight=self.qweight, - weight_scale=weight_scale, + weight_scale=self.scales, weight_zp=self.qzeros, g_idx=self.g_idx, g_idx_sort_indices=self.g_idx_sort_indices, @@ -263,7 +261,7 @@ def forward(self, x: torch.Tensor): quant_type=self.weight_type, output_size_per_partition=self.out_features, input_size_per_partition=self.in_features, - bias=bias, + bias=self.bias, ) if self.adapter: From dad87897aee16039af57b079b7758b352431b60c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 29 Oct 2025 09:33:50 +0000 Subject: [PATCH 7/7] simplify --- gptqmodel/nn_modules/qlinear/awq_marlin.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index e061f78f0..e7d576159 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -242,16 +242,16 @@ def forward(self, x: torch.Tensor): "Use marlin_post_init() on the whole model." ) - input_tensor = x.contiguous() if self.is_lm_head else x + x = x.contiguous() if self.is_lm_head else x - if self.scales.dtype != input_tensor.dtype: - self.scales.data = self.scales.data.to(input_tensor.dtype) + if self.scales.dtype != x.dtype: + self.scales.data = self.scales.data.to(x.dtype) - if self.bias is not None and self.bias.dtype != input_tensor.dtype: - self.bias.data = self.bias.data.to(input_tensor.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) out = apply_awq_marlin_linear( - input=input_tensor, + input=x, weight=self.qweight, weight_scale=self.scales, weight_zp=self.qzeros,