diff --git a/README.md b/README.md index d6e8cce3a..6453ac01c 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@

## Latest News
+* 11/9/2025 [5.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.4.0): ✨New Intel CPU and XPU hw optimized AWQ `TorchFusedAWQ` kernel. Torch Fused kernels now compatible with `torch.compile`. Fixed AWQ MoE model compatibility and reduced vram usage.
 * 11/3/2025 [5.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.2.0): 🎉Minimax M2 support with [ModelCloud BF16 M2 Model](https://huggingface.co/ModelCloud/MiniMax-M2-BF16). New `VramStrategy.Balanced` quantization property for reduced memory usage for large MoE on multi-3090 (24GB) devices. ✨Marin model. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor. 🎉AWQ support out of beta with full feature support including multi-gpu quant and MoE vram saving. ✨Brumby (attention free) model support. ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
 * 10/24/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and inlined with quantization. `Vram` pressure for large models reduced during quantization.
@@ -202,8 +203,8 @@ GPT-QModel is validated for Linux, MacOS, and Windows 11:
 |-----------------|---------------| --- | -------------- |-----------------------------------------------|
 | 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch |
 | 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch |
-| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch |
-| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused (Python 2.8+), Torch |
+| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | TorchFused, TorchFusedAWQ, Torch |
+| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | TorchFused, TorchFusedAWQ, Torch |
 | 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
 | 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |
diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md
new file mode 100644
index 000000000..13636ff5d
--- /dev/null
+++ b/docs/torch_fused_int4_transformations.md
@@ -0,0 +1,288 @@
+```
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+```
+
+# Torch Fused INT4 Transformations
+
+This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu`
+do to GPTQ-format tensors before calling the fused `torch.ops.aten` kernels.
+The goal is to document the exact tensor shapes, the axis permutations, and the
+bit packing order expected by `aten._weight_int4pack_mm_*` so you do not need to
+reverse engineer the loops in `gptqmodel/nn_modules/qlinear/torch_fused.py:175-219`.
+
+## Terminology and starting layout
+
+Let:
+
+* `I` โ€“ number of input features.
+* `O` โ€“ number of output features.
+* `B` โ€“ quantization bits (always 4 here).
+* `W` โ€“ number of bits stored per lane in `pack_dtype` (`W = 32` by default). 
+* `pack_factor = W / B` โ€“ how many quantized values share one lane (8 when `B=4`). +* `group_size` โ€“ number of input channels that share one `(scale, zero)` pair. +* `G = ceil(I / group_size)` โ€“ number of groups (and rows in `scales`/`qzeros`). + +Immediately after loading a GPTQ v2 checkpoint: + +``` +qweight : [I / pack_factor, O] dtype = pack_dtype (int32) +qzeros : [G, O / pack_factor] dtype = pack_dtype (int32) +scales : [G, O] dtype = fp16 +g_idx : [I] dtype = int32 (maps input channel -> group id) +``` + +Each entry of `qweight`/`qzeros` is a 32-bit lane that packs `pack_factor` +4-bit nibbles. Conceptually, a single column of `qweight` (one output channel) +looks like this before unpacking: + +``` +raw lane bits (int32) โ†’ [in_{k+7}] [in_{k+6}] โ€ฆ [in_{k+1}] [in_{k}] +bit positions โ†’ 31..28 27..24 7..4 3..0 +``` + +## `transform_xpu(dtype)` + +The XPU path needs tensors that match +`aten._weight_int4pack_mm_with_scales_and_zeros`. The routine performs five +steps: + +1. **Scales cast** โ€“ `self.scales = self.scales.clone().to(dtype)`. No layout changes. +2. **Unpack `qzeros`** โ€“ expand each 32-bit lane into `pack_factor` nibbles, mask + with `0xF`, then reshape to `[G, O]`. + + ``` + Before unpack (per group g): + qzeros[g] = [ lane_0, lane_1, โ€ฆ ] (each lane holds 8 outputs) + After unpack: + zeros[g] = [ z_{0}, z_{1}, โ€ฆ, z_{O-1} ] + + lane layout + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ 32 bits โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + | z_{b+7} | โ€ฆ | z_{b+1} | z_{b} | + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ† reshaped into consecutive columns + ``` + +3. **Unpack and reorder `qweight`** โ€“ identical nibble extraction produces a + tensor shaped `[I, O]`. It is then re-indexed with `ret_idx` so that input + rows follow the `g_idx` schedule used during quantization, and finally + transposed to `[O, I]`. At this point every row corresponds to one output + channel and every column corresponds to an *unpacked* input channel. + + ``` + weight_full (after transpose): + input columns โ†’ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + out0โ”‚ w00 w01 w02 w03 w04 w05 w06 w07 โ€ฆ w0(I-1) โ”‚ + out1โ”‚ w10 w11 w12 w13 w14 w15 w16 w17 โ€ฆ w1(I-1) โ”‚ + โ”‚ โ‹ฎ โ”‚ + ``` + +4. **Pack rows into XPU layout** โ€“ the double `for` loop rebuilds `int32` + lanes, but now the rows are `O` (output channels) instead of packed input + clusters. The resulting tensor has shape `[O, I / pack_factor]`. + + ``` + packed[row=j, col=k] stores inputs (8 values) = + weight_full[j, 8k + i] for i = 0..7 + + 31..28 27..24 23..20 19..16 15..12 11..8 7..4 3..0 + [in+7] [in+6] [in+5] [in+4] [in+3] [in+2] [in+1] [in+0] + ``` + +5. **Finalize buffers** โ€“ `self.qweight = packed.contiguous()` (int32) and + `self.qzeros = zeros.contiguous()` (float, `[G, O]`). These, together with + `self.scales`, match the signature of + `aten._weight_int4pack_mm_with_scales_and_zeros(x, qweight, group_size, scales, qzeros)`. + +For XPU execution, `_fused_op_forward` also permutes activations before the +matmul: + +``` +x = x[:, ret_idx] +``` + +This applies the inverse of the group-wise reordering performed in step 3, +ensuring that column `i` of `qweight` always multiplies the same logical input +channel the calibration used. 
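To make the lane bookkeeping concrete, the sketch below packs an already-unpacked `[O, I]` matrix of 4-bit values into the `[O, I/8]` int32 layout from step 4 and then unpacks it again. The helper names are illustrative only (they are not the functions used in `torch_fused.py`); what matters is that the bit order follows the `[in+7 ... in+0]` lane layout documented above, with bits `[3:0]` holding the lowest-numbered input channel.

```
import torch

BITS = 4
PACK_FACTOR = 32 // BITS  # 8 four-bit values per int32 lane


def unpack_int32_lanes(packed: torch.Tensor) -> torch.Tensor:
    # packed: [rows, cols] int32 -> [rows * PACK_FACTOR, cols] int8.
    # Right-shift each lane by 0, 4, ..., 28 bits and mask with 0xF so the
    # lowest bits yield the lowest-numbered channel, as in the lane diagrams.
    shifts = torch.arange(0, 32, BITS, dtype=torch.int32).view(1, -1, 1)
    nibbles = (packed.unsqueeze(1) >> shifts) & 0xF           # [rows, 8, cols]
    return nibbles.reshape(-1, packed.shape[-1]).to(torch.int8)


def pack_rows_int32(weight_full: torch.Tensor) -> torch.Tensor:
    # weight_full: [O, I] with values in [0, 15] -> [O, I // PACK_FACTOR] int32.
    # Each output row is packed along its input columns, 8 values per lane.
    out_features, in_features = weight_full.shape
    lanes = weight_full.reshape(out_features, in_features // PACK_FACTOR, PACK_FACTOR).to(torch.int32)
    packed = torch.zeros(out_features, in_features // PACK_FACTOR, dtype=torch.int32)
    for i in range(PACK_FACTOR):
        packed |= lanes[:, :, i] << (i * BITS)
    return packed


if __name__ == "__main__":
    weight_full = torch.randint(0, 16, (4, 16), dtype=torch.int32)  # toy [O=4, I=16]
    packed = pack_rows_int32(weight_full)                           # [4, 2] int32, XPU-style rows
    restored = unpack_int32_lanes(packed.t()).t().to(torch.int32)   # back to [4, 16]
    assert torch.equal(restored, weight_full)
```

The round trip only exercises the bit layout; the real module additionally applies the `ret_idx` permutation described above before packing.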
+ +### Visual summary (XPU) + +``` + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” unpack+permute โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +raw qweight โ†’โ”‚ I/8 ร— O โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ โ”‚ O ร— I โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + pack rows โ†“ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ O ร— (I/8) โ”‚ int32 lanes + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +raw qzeros โ†’ [G ร— O/8] lanes โ”€โ”€unpackโ”€โ”€โ–บ zeros [G ร— O] +scales โ†’ [G ร— O] (cast to `dtype`) +``` + +## `transform_cpu(dtype)` + +The CPU path shares the unpack/reorder logic but delegates the final packing to +PyTorchโ€™s helper so the layout matches +`aten._weight_int4pack_mm_for_cpu`. Steps: + +1. **Scales cast** โ€“ identical to the XPU path. +2. **Unpack + reorder `qweight`** โ€“ same as step 3 above, yielding + `weight_full = [O, I]` with 4-bit integers. +3. **Convert to int4pack** โ€“ `torch.ops.aten._convert_weight_to_int4pack_for_cpu` + repacks that matrix into `torch.uint8` tiles of shape `[O, I * B / 8]` + (i.e., `I/2` columns when `B=4`). Each byte stores two adjacent inputs. + + ``` + byte layout (per output row j): + bits 7..4 โ†’ weight_full[j, 2k+1] + bits 3..0 โ†’ weight_full[j, 2k] + ``` + + The helper currently requires both `O` and `I` to be multiples of 16; the op + raises `_convert_weight_to_int4pack_cpu : expect N to be dividable by 16` + otherwise. + +4. **Merge scales and zeros** โ€“ The fused CPU kernel expects scale and zero + offsets in a single tensor, so `pack_scales_and_zeros` stacks them along the + last dimension: + + ``` + scales_and_zeros[g, o] = [ scale[g, o], zero[g, o] ] shape = [G, O, 2] + + group g + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ out dimension โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [ s, z ] [ s, z ] โ€ฆ [ s, z ] โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + ``` + + The current GPTQ fused path only uses symmetric int4, so `self.qzeros` is + zeroed before packing (`zero[g, o] = 0`). Non-symmetric per-group offsets + would require extending this block. + +5. **Buffers used at runtime** โ€“ `self.qweight` is now the `uint8` + int4pack tensor, `self.scales_and_zeros` stores the merged metadata, and + `_fused_op_forward` calls + `aten._weight_int4pack_mm_for_cpu(x, qweight_uint8, group_size, scales_and_zeros)`. + +### Visual summary (CPU) + +``` +weight_full (O ร— I, ints) โ”€โ”€_convert_weight_to_int4pack_for_cpuโ”€โ”€โ–บ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ O ร— I โ”‚ โ”‚ O ร— (I/2) โ”‚ uint8 +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†‘ โ†‘ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ unpack & transpose from raw qweight โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +scales (G ร— O, dtype `dtype`) +qzeros (G ร— O, zeroed) โ”€โ”€โ–บ scales_and_zeros (G ร— O ร— 2) +``` + +## Activation permutation and fused matmul + +Both device paths rely on the same activation permutation: + +1. `ret_idx` is built once from `g_idx` so that unpacked rows can be restored to + the calibration order. +2. Before calling any fused matmul, `_fused_op_forward` applies `x = x[:, ret_idx]`. +3. The matmul then multiplies `x` with the packed `qweight`: + + * XPU: `aten._weight_int4pack_mm_with_scales_and_zeros` + consumes `qweight[int32][O, I/8]`, `scales[G, O]`, and `qzeros[G, O]`. 
+ * CPU: `aten._weight_int4pack_mm_for_cpu` + consumes `qweight[uint8][O, I/2]` and `scales_and_zeros[G, O, 2]`. + +Because the same `ret_idx` is used for both the unpacked weight (during packing) +and the activation tensor (during inference), every nibble in the packed matrix +aligns with the correct logical input column. + +## Comparing XPU vs CPU transformations + +Although both device paths share the same unpack โ†’ reorder โ†’ transpose steps, +they diverge in how the packed tensors are laid out and what the fused matmul +expects afterward. The table below highlights the key differences for quick +debugging. + +| Aspect | XPU (`transform_xpu`) | CPU (`transform_cpu`) | +|----------------------------|---------------------------------------------------------------|-------------------------------------------------------------------| +| Packed `qweight` shape | `[O, I / 8]`, dtype `int32` | `[O, I / 2]`, dtype `uint8` | +| Bits per storage lane | 32-bit lane packs 8 inputs; nibble order `[in+7 โ€ฆ in+0]` | 8-bit lane packs 2 inputs; high nibble = odd, low nibble = even | +| Packing direction | Manual double-loop packs along **columns** of `weight_full` | `_convert_weight_to_int4pack_for_cpu` packs along **columns** into bytes | +| Per-group zeros | Unpacked to full `[G, O]` tensor and passed separately | Forced to zero and merged with scales via `pack_scales_and_zeros` | +| Scale format | One tensor per group (`scales[G, O]`) | Concatenated `[..., 0] = scale`, `[..., 1] = zero` (`float`) | +| Fused kernel call | `_weight_int4pack_mm_with_scales_and_zeros(x, qW, gsz, s, z)` | `_weight_int4pack_mm_for_cpu(x, qW, gsz, scales_and_zeros)` | +| Alignment requirements | Determined by manual pack loop (only needs `I % 8 == 0`) | Kernel enforces `I % 16 == 0` and `O % 16 == 0` | +| Activation permutation | `x = x[:, ret_idx]` prior to matmul (same code path) | Same permutation reuse | + +Visually, you can think of the difference as *row-major lane packing* (XPU) +versus *byte-tiling* (CPU): + +``` +XPU: | int32 lane | = [w7][w6][w5][w4][w3][w2][w1][w0] +CPU: | uint8 lane | = [w1][w0] +``` + +Both forms originate from the same `[O, I]` intermediate; the divergence is only +in the final storage type, accompanying metadata, and fused operator ABI. + +## AWQ compatibility (`torch_fused_awq.py`) + +`TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`) +reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ +tooling. The module always expects `qweight` to be stored in the AWQ layout +`[in_features, out_features / pack_factor]`, meaning each row corresponds to a +single logical input channel. `transform_cpu_awq` performs a fixed shim before +the standard CPU packing runs: + +1. **Unpack AWQ rows** โ€“ `unpack_awq` expands each column lane into eight + outputs, yielding `iweight[int8][I, O]` and `izeros[int8][G, O]`. Both + tensors are then permuted with `reverse_awq_order` (the inverse of + `quantization.awq.utils.packing_utils.AWQ_ORDER`) so the columns match the + logical transformer layout expected by GPTQ. +2. **Normalize zero codes** โ€“ AWQ stores integer zero points per output channel. + `transform_cpu_awq` converts them into floating offsets compatible with the + fused kernel using + `zeros_fp16 = (2^{bits-1} - izeros) * scales_fp32`, keeping the result in + `float16` so the metadata matches the original AWQ calibration statistics. +3. 
**Repack into GPTQ lanes** โ€“ The unpacked `iweight` matrix is reshaped to
+ `[I / pack_factor, pack_factor, O]` and re-packed along the `pack_factor`
+ dimension so each row once again represents eight inputs inside a 32-bit
+ lane. After this step `self.qweight` is indistinguishable from a GPTQ v2
+ tensor, which means the regular `transform_cpu` logic can run unchanged.
+4. **Delegate to the base CPU transform** โ€“ Calling `super().transform_cpu`
+ converts the temporary GPTQ-formatted `qweight` into the `[O, I/2]` `uint8`
+ int4pack layout and produces `scales_and_zeros` from the (temporarily zeroed)
+ metadata.
+5. **Restore AWQ metadata** โ€“ Immediately afterward, the AWQ shim reinstates
+ the real `float16` scales and the converted zero offsets, then rebuilds
+ `scales_and_zeros = pack_scales_and_zeros(scales, zeros_fp16)`. This ensures
+ `_weight_int4pack_mm_for_cpu` receives the same affine parameters the AWQ
+ calibration solved for.
+
+The same shim also backs the XPU path: `transform_xpu_awq` repacks the unpacked
+AWQ tensors (skipping the float zero offsets) and then delegates to the base
+`transform_xpu`. `TorchFusedAwqQuantLinear.transform` raises `NotImplementedError`
+only for devices other than `cpu` or `xpu`. If the module has not been transformed
+yet (or fused ops are unavailable), inference falls back to the dense AWQ matmul
+computed by `awq_weight_dequantize`, which simply dequantizes the cached AWQ
+tensors on the fly.
+
+## Quick reference
+
+| Stage | Shape / dtype (int4) | Notes |
+|--------------------------------|-----------------------------------------------------------|------------------------------------------------|
+| Raw `qweight` | `[I / 8, O]`, `int32` | 8 nibbles per lane |
+| After unpack + transpose | `[O, I]`, `int8` (values in `[0, 15]`) | Used by both device paths |
+| Packed XPU `qweight` | `[O, I / 8]`, `int32` | Bits `[3:0]` hold the lowest-numbered channel |
+| Packed CPU `qweight` | `[O, I / 2]`, `uint8` | High nibble = odd input, low nibble = even |
+| `qzeros` (post-XPU transform) | `[G, O]`, matches `scales` | Passed separately to the XPU fused op |
+| `scales_and_zeros` (CPU only) | `[G, O, 2]`, float | `[..., 0] = scale`, `[..., 1] = zero` |
+| Raw AWQ `qweight` | `[I, O / 8]`, `int32` | Rows are single inputs packed across outputs |
+| Unpacked AWQ weights/zeros | `iweight[I, O]`, `izeros[G, O]`, `int8` | Produced by `unpack_awq` + `reverse_awq_order` |
+| AWQ zero offsets (final) | `[G, O]`, `float16` | `(2^{bits-1} - izeros) * scales`; merged via `pack_scales_and_zeros` |
+
+These details mirror the expectations of the Intel XPU and CPU fused matmul
+kernels, and the ASCII layouts above describe how rows/columns line up inside
+every packed tensor before the fused matmul executes. 
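The metadata conversions in steps 2 and 5 of the AWQ shim reduce to two small tensor operations. The sketch below is illustrative only: the helper names are invented here, and the real `pack_scales_and_zeros` in `torch_fused.py` may cast or reshape slightly differently, but the arithmetic matches the formulas quoted above.

```
import torch

BITS = 4


def awq_zero_offsets(izeros: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Step 2 of the AWQ shim: turn integer zero codes [G, O] into float offsets.
    # zero_offset = 2^(bits-1), i.e. 8 for int4.
    return ((1 << (BITS - 1)) - izeros.to(scales.dtype)) * scales


def merge_scales_and_zeros(scales: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
    # Stack per-group scales and zero offsets into the [G, O, 2] layout that
    # `_weight_int4pack_mm_for_cpu` consumes: [..., 0] = scale, [..., 1] = zero.
    return torch.stack([scales, zeros], dim=-1).contiguous()


if __name__ == "__main__":
    G, O = 2, 8
    scales = torch.rand(G, O, dtype=torch.float16)
    izeros = torch.randint(0, 16, (G, O), dtype=torch.int8)
    zeros = awq_zero_offsets(izeros, scales)
    scales_and_zeros = merge_scales_and_zeros(scales, zeros)
    assert scales_and_zeros.shape == (G, O, 2)
    assert torch.equal(scales_and_zeros[..., 0], scales)
```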
diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index c17ef2c73..160076411 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -111,6 +111,38 @@ def optimize(self): super().optimize() + def _build_ret_idx(self) -> torch.Tensor: + existing = getattr(self, "ret_idx", None) + total = self.g_idx.shape[0] + if isinstance(existing, torch.Tensor) and existing.numel() == total: + return existing + + device = self.g_idx.device + ret_idx = torch.zeros(total, dtype=torch.int32, device=device) + group_size = max(int(self.group_size), 1) + groups = total // group_size + remainder = total % group_size + g_idx = self.g_idx.to(torch.int32) + g_idx_2 = g_idx * group_size + + if remainder > 0: + mask = g_idx == groups + if mask.any(): + g_idx_2[mask] += torch.arange(remainder, device=device, dtype=torch.int32) + + if groups > 0: + base = torch.arange(group_size, device=device, dtype=torch.int32) + for i in range(groups): + mask = g_idx == i + if not mask.any(): + continue + count = int(mask.sum().item()) + g_idx_2[mask] += base[:count] + + ret_idx[g_idx_2] = torch.arange(total, device=device, dtype=torch.int32) + self.ret_idx = ret_idx + return ret_idx + def train(self, mode: bool = True): old_train = self.training if mode == old_train: @@ -141,7 +173,7 @@ def train(self, mode: bool = True): return super().train(mode=mode) def transform_xpu(self, dtype): - self.scales = self.scales.clone().to(dtype).contiguous() + self.scales = self.scales.to(dtype).contiguous() # Unpack qzeros zeros = torch.bitwise_right_shift( torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), @@ -156,17 +188,8 @@ def transform_xpu(self, dtype): ).to(self.dequant_dtype), self.maxq ) - self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) - groups = self.g_idx.shape[0] // self.group_size - remainder = self.g_idx.shape[0] % self.group_size - g_idx_2 = self.g_idx * self.group_size - if remainder > 0: - g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) - arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) - for i in range(groups): - g_idx_2[self.g_idx == i] += arange_tensor - self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + ret_idx = self._build_ret_idx() + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t() # Pack qweight packed = torch.zeros(weight.shape[0], weight.shape[1] // self.pack_factor, dtype=torch.int32, device=weight.device) for col in range(weight.shape[1] // self.pack_factor): @@ -177,8 +200,8 @@ def transform_xpu(self, dtype): self.qweight = packed.contiguous() self.qzeros = zeros.contiguous() - def transform_cpu(self, dtype): - self.scales = self.scales.clone().to(dtype).contiguous() + def transform_cpu(self, dtype, do_scales_and_zeros: bool = True): + self.scales = self.scales.to(dtype).contiguous() # Unpack and reorder qweight weight = torch.bitwise_and( torch.bitwise_right_shift( @@ -187,20 +210,13 @@ def transform_cpu(self, dtype): ).to(self.dequant_dtype), self.maxq ) - self.ret_idx = torch.zeros(self.g_idx.shape[0], dtype=torch.int32).to(self.g_idx.device) - groups = self.g_idx.shape[0] // self.group_size - remainder = self.g_idx.shape[0] % 
self.group_size - g_idx_2 = self.g_idx * self.group_size - if remainder > 0: - g_idx_2[self.g_idx == groups] += torch.arange(remainder).to(self.g_idx_2.device).to(self.g_idx_2.dtype) - arange_tensor = torch.arange(self.group_size).to(self.g_idx.device).to(self.g_idx.dtype) - for i in range(groups): - g_idx_2[self.g_idx == i] += arange_tensor - self.ret_idx[g_idx_2] = torch.arange(self.g_idx.shape[0]).to(self.ret_idx.device).to(self.ret_idx.dtype) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, self.ret_idx).t() + ret_idx = self._build_ret_idx() + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t() self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() - self.qzeros = torch.zeros_like(self.scales).contiguous() - self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + if do_scales_and_zeros: + self.qzeros = torch.zeros_like(self.scales).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) def transform(self, dtype, device): if device == "xpu": diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py new file mode 100644 index 000000000..aaa075fd2 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import math + +import torch + +from ...adapter.adapter import Adapter +from ...quantization.awq.utils.packing_utils import ( + dequantize_gemm, + reverse_awq_order, + unpack_awq, +) +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from ...utils.torch import TORCH_HAS_FUSED_OPS +from .torch_fused import Int4PackedOp, TorchFusedQuantLinear, pack_scales_and_zeros + + +log = setup_logger() + + +class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): + """Torch fused AWQ variant based on GPTQ fused kernels via CPU int4 packing.""" + + QUANT_TYPE = "torch_fused_awq" + + # inherit from torch fused + SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS + SUPPORTS_GROUP_SIZE = TorchFusedQuantLinear.SUPPORTS_GROUP_SIZE + SUPPORTS_DESC_ACT = TorchFusedQuantLinear.SUPPORTS_DESC_ACT + SUPPORTS_SYM = TorchFusedQuantLinear.SUPPORTS_SYM + SUPPORTS_SHARDS = TorchFusedQuantLinear.SUPPORTS_SHARDS + SUPPORTS_TRAINING = TorchFusedQuantLinear.SUPPORTS_TRAINING + SUPPORTS_AUTO_PADDING = TorchFusedQuantLinear.SUPPORTS_AUTO_PADDING + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY + SUPPORTS_DEVICES = TorchFusedQuantLinear.SUPPORTS_DEVICES + SUPPORTS_PLATFORM = TorchFusedQuantLinear.SUPPORTS_PLATFORM + SUPPORTS_PACK_DTYPES = TorchFusedQuantLinear.SUPPORTS_PACK_DTYPES + SUPPORTS_ADAPTERS = TorchFusedQuantLinear.SUPPORTS_ADAPTERS + REQUIRES_FORMAT_V2 = TorchFusedQuantLinear.REQUIRES_FORMAT_V2 + + # AWQ kernels are only accuracy validate for float16 for now + SUPPORTS_DTYPES = [torch.float16] + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + 
kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + # Skip base buffer init, we need to manually init buffers for awq + register_buffers=False, + **kwargs, + ) + + # Create awq buffers + if register_buffers: + # AWQ packs each input row into pack_factor-wide columns for int4 lanes. + pack_cols = max(1, self.out_features // self.pack_factor) + qweight_shape = (self.in_features, pack_cols) + group_size = max(int(self.group_size), 1) + # Each group holds group_size input rows; ceil ensures remaining rows are included. + group_rows = max(1, math.ceil(self.in_features / group_size)) + + self.register_buffer( + "qweight", + torch.zeros(qweight_shape, dtype=self.pack_dtype), + ) + + self.register_buffer( + "qzeros", + torch.zeros((group_rows, pack_cols), dtype=self.pack_dtype), + ) + + self.register_buffer( + "scales", + torch.zeros((group_rows, self.out_features), dtype=torch.float16), + ) + + self.register_buffer("g_idx", torch.arange(self.in_features, dtype=torch.int32) // group_size) + + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16)) + else: + self.bias = None + + def prepare_awq_fused_tensors(self, need_zeros: bool = True): + self.scales.to(torch.float16).contiguous() + + iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) + iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) + max_val = (1 << self.bits) - 1 + iweight = torch.bitwise_and(iweight, max_val) + if izeros is None: + raise RuntimeError("AWQ fused kernel requires zero points.") + izeros = torch.bitwise_and(izeros, max_val) + + if need_zeros: + zero_offset = 1 << (self.bits - 1) + zeros = (zero_offset - izeros.reshape_as(self.scales)) * self.scales + + gptq_qweight = self.pack_awq_qweight(iweight) + gptq_qzeros = self.pack_awq_qzeros(izeros) + return gptq_qweight, gptq_qzeros, self.scales, zeros if need_zeros else None + + def pack_awq_qweight(self, iweight: torch.Tensor) -> torch.Tensor: + in_features, out_features = iweight.shape + pack_factor = int(self.pack_factor) + if in_features % pack_factor != 0: + raise ValueError( + f"AWQ in_features={in_features} must be divisible by pack_factor={pack_factor}." + ) + rows = iweight.view(in_features // pack_factor, pack_factor, out_features) + packed = torch.zeros( + (rows.shape[0], out_features), + dtype=self.pack_dtype, + device=iweight.device, + ) + shifts = range(0, pack_factor * self.bits, self.bits) + for lane, shift in enumerate(shifts): + packed |= rows[:, lane, :].to(torch.int32) << shift + return packed.contiguous() + + def pack_awq_qzeros(self, izeros: torch.Tensor) -> torch.Tensor: + pack_factor = int(self.pack_factor) + if izeros.shape[1] % pack_factor != 0: + raise ValueError( + f"AWQ qzeros dimension {izeros.shape[1]} must be divisible by pack_factor={pack_factor}." 
+ ) + cols = izeros.view(izeros.shape[0], izeros.shape[1] // pack_factor, pack_factor) + packed = torch.zeros( + (cols.shape[0], cols.shape[1]), + dtype=self.pack_dtype, + device=izeros.device, + ) + shifts = range(0, pack_factor * self.bits, self.bits) + for lane, shift in enumerate(shifts): + packed |= cols[:, :, lane].to(torch.int32) << shift + return packed.contiguous() + + def transform_cpu_awq(self, dtype): + self.qweight, self.qzeros, scales, zeros = self.prepare_awq_fused_tensors() + + super().transform_cpu(dtype, do_scales_and_zeros=False) + + self.scales = scales.to(device=self.qweight.device, dtype=dtype).contiguous() + self.qzeros = zeros.to(device=self.qweight.device, dtype=dtype).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + def transform_xpu_awq(self, dtype): + self.qweight, self.qzeros, scales, _ = self.prepare_awq_fused_tensors(need_zeros=False) + + super().transform_xpu(dtype) + + self.scales = scales.to(device=self.qweight.device, dtype=dtype).contiguous() + + def transform_cpu(self, dtype): + self.transform_cpu_awq(dtype) + + def awq_weight_dequantize(self, device, dtype): + return dequantize_gemm( + self.qweight, + self.qzeros, + self.scales, + self.bits, + self.group_size, + ).to(device=device, dtype=dtype) + + def transform(self, dtype, device): + if device == "cpu": + self.transform_cpu(dtype) + elif device == "xpu": + self.transform_xpu_awq(dtype) + else: + raise NotImplementedError( + "TorchFusedAwqQuantLinear only supports fused transforms on CPU or XPU devices." + ) + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + self.assert_supported_dtype(x_flat.dtype) + if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: + self.transform(x_flat.dtype, x_flat.device.type) + self.transformed = True + if x_flat.device.type == "cpu": + self.torch_fused_op = Int4PackedOp( + self.qweight, self.scales_and_zeros, self.group_size + ).eval() + import torch._inductor.config as config + config.freezing = True + config.max_autotune = True + + if self.transformed: + # log.debug("awq calling fused op") + out = self._fused_op_forward(x_flat) + else: + # log.debug("awq dense path") + weight = self.awq_weight_dequantize(device=x_flat.device, dtype=x_flat.dtype) + out = torch.matmul(x_flat, weight) + + if self.bias is not None: + out.add_(self.bias) + if self.adapter: + out = self.adapter.apply(x=x_flat, out=out) + + return out.reshape(out_shape) + + def assert_supported_dtype(self, dtype: torch.dtype): + if dtype not in self.SUPPORTS_DTYPES: + supported = ", ".join(str(d) for d in self.SUPPORTS_DTYPES) + raise TypeError( + f"{self.__class__.__name__} only supports input dtypes [{supported}], but received {dtype}." 
+ ) + + +__all__ = ["TorchFusedAwqQuantLinear"] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 0744c0511..2d2746bee 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -12,6 +12,7 @@ class BACKEND(str, Enum): # gptq TORCH_FUSED = "torch_fused" # optimized for Intel XPU + TORCH_FUSED_AWQ = "torch_fused_awq" # AWQ variant of torch fused kernel TORCH = "torch" # GOOD: about 80% of triton TRITON = "triton" # VERY GOOD: all-around kernel EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1 diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 0f72d4f92..03805504f 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -30,6 +30,7 @@ from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear +from ..nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear from ..quantization import FORMAT, METHOD from ..utils.logger import setup_logger @@ -70,6 +71,7 @@ BACKEND.GEMM: AwqGEMMQuantLinear, BACKEND.GEMV: AwqGEMVQuantLinear, BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, + BACKEND.TORCH_FUSED_AWQ: TorchFusedAwqQuantLinear, BACKEND.TORCH_AWQ: AwqTorchQuantLinear, }), } @@ -85,7 +87,15 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], + FORMAT.GEMM: [ + BACKEND.MACHETE, + BACKEND.MARLIN, + BACKEND.EXLLAMA_V2, + BACKEND.EXLLAMA_V1, + BACKEND.GEMM, + BACKEND.TORCH_FUSED_AWQ, + BACKEND.TORCH_AWQ, + ], FORMAT.GEMV: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], diff --git a/gptqmodel/version.py b/gptqmodel/version.py index b62cbd7ea..1051a72aa 100644 --- a/gptqmodel/version.py +++ b/gptqmodel/version.py @@ -7,4 +7,4 @@ # even minor versions are release # 5.2.0 => release, 5.1.0 => devel # micro version (5.2.x) denotes patch fix, i.e. 
5.2.1 is a patch fix release -__version__ = "5.3.0" +__version__ = "5.4.0" diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 9e446beeb..22186f6d4 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -22,6 +22,7 @@ marlin_import_exception, ) from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.marlin import marlin_make_workspace_new @@ -30,12 +31,17 @@ log = LogBar.shared() DEVICE = torch.device("cuda:0") +CPU_DEVICE = torch.device("cpu") GREEN = "\033[32m" RED = "\033[31m" RESET = "\033[0m" +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + class TestAwqKernelOutput(unittest.TestCase): MODEL_PATH = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") TARGET = "model.layers.20.self_attn.v_proj" @@ -50,18 +56,20 @@ class TestAwqKernelOutput(unittest.TestCase): (BACKEND.GEMM, torch.float16, 0.004), # (BACKEND.GEMM, torch.bfloat16, 0.05), (BACKEND.MARLIN, torch.float16, 0.006), + (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004), # (BACKEND.MARLIN, torch.bfloat16, 0.05), ] @classmethod def setUpClass(cls) -> None: - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is required for AWQ kernel output checks.") - - cls.device = DEVICE + cls.cuda_available = torch.cuda.is_available() + cls.device = DEVICE if cls.cuda_available else CPU_DEVICE cls.log = log cls._weight_map = cls._load_weight_map() cls.backend_skip_reason: Dict[BACKEND, str] = {} + if not cls.cuda_available: + cls.backend_skip_reason[BACKEND.GEMM] = "CUDA is required for GEMM backend." + cls.backend_skip_reason[BACKEND.MARLIN] = "CUDA is required for AWQ Marlin kernel." 
try: tensors = cls._load_awq_tensors(cls.TARGET) @@ -74,6 +82,10 @@ def setUpClass(cls) -> None: scales_cpu, bias_cpu, ) = tensors + cls.qweight_cpu = qweight_cpu + cls.qzeros_cpu = qzeros_cpu + cls.scales_cpu = scales_cpu + cls.bias_cpu = bias_cpu cls.in_features = qweight_cpu.shape[0] cls.out_features = qweight_cpu.shape[1] * (32 // cls.BITS) @@ -84,14 +96,28 @@ def setUpClass(cls) -> None: qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu ) - cls.modules[BACKEND.GEMM] = cls._build_gemm_module( - qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + cls.modules[BACKEND.GEMM] = ( + cls._build_gemm_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if cls.cuda_available + else None ) - cls.modules[BACKEND.MARLIN] = cls._build_marlin_module( - qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + cls.modules[BACKEND.MARLIN] = ( + cls._build_marlin_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if cls.cuda_available + else None ) + try: + cls.modules[BACKEND.TORCH_FUSED_AWQ] = cls._build_torch_fused_awq_module( + qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu + ) + except Exception as exc: + cls.backend_skip_reason[BACKEND.TORCH_FUSED_AWQ] = ( + f"Torch fused AWQ kernel unavailable: {exc}" + ) + cls.modules[BACKEND.TORCH_FUSED_AWQ] = None + base_inputs = cls._generate_inputs() cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {} cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {} @@ -123,7 +149,8 @@ def tearDownClass(cls) -> None: for module in getattr(cls, "modules", {}).values(): if module is not None: del module - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() @classmethod def _load_weight_map(cls) -> Dict[str, str]: @@ -247,19 +274,52 @@ def _build_torch_awq_module( module.post_init() return module + @classmethod + def _build_torch_fused_awq_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + *, + device: torch.device = CPU_DEVICE, + ) -> TorchFusedAwqQuantLinear: + module = TorchFusedAwqQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(device) + + module.qweight.copy_(qweight_cpu.to(device)) + module.qzeros.copy_(qzeros_cpu.to(device)) + module.scales.copy_(scales_cpu.to(torch.float16).to(device)) + module.bias.copy_(bias_cpu.to(torch.float16).to(device)) + + module.eval() + module.post_init() + return module + @classmethod def _generate_inputs(cls) -> List[torch.Tensor]: large_shapes = [(4, 32), (2, 64), (1, 96)] medium_shapes = [(2, 32), (1, 48), (1, 32)] small_shapes = [(1, 32), (1, 24), (1, 16)] - try: - total_mem_gb = ( - torch.cuda.get_device_properties(cls.device).total_memory - / (1024 ** 3) - ) - except Exception: # pragma: no cover - total_mem_gb = 0.0 + total_mem_gb = 0.0 + if cls.device.type == "cuda": + try: + total_mem_gb = ( + torch.cuda.get_device_properties(cls.device).total_memory + / (1024 ** 3) + ) + except Exception: # pragma: no cover + total_mem_gb = 0.0 if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": shapes = small_shapes @@ -288,19 +348,37 @@ def _forward( *, compute_dtype: Optional[torch.dtype] = None, output_dtype: Optional[torch.dtype] = None, + target_device: Optional[torch.device] = None, ) -> List[torch.Tensor]: + if target_device is None: + target_device = cls._infer_module_device(module) outputs: List[torch.Tensor] = [] with torch.inference_mode(): for tensor in 
inputs: local_tensor = tensor - if compute_dtype is not None and tensor.dtype != compute_dtype: - local_tensor = tensor.to(dtype=compute_dtype) + if local_tensor.device != target_device: + local_tensor = local_tensor.to(device=target_device) + if compute_dtype is not None and local_tensor.dtype != compute_dtype: + local_tensor = local_tensor.to(dtype=compute_dtype) result = module(local_tensor) if output_dtype is not None and result.dtype != output_dtype: result = result.to(dtype=output_dtype) outputs.append(result.detach().cpu()) return outputs + @staticmethod + def _infer_module_device(module: torch.nn.Module) -> torch.device: + try: + tensor = next(module.parameters()) + return tensor.device + except StopIteration: + pass + try: + tensor = next(module.buffers()) + return tensor.device + except StopIteration: + return torch.device("cpu") + def _maybe_skip_backend(self, backend: BACKEND) -> None: reason = self.backend_skip_reason.get(backend) if reason: @@ -315,6 +393,7 @@ def _summarize_results( atol: float, title: str, reference_label: str, + device: Optional[torch.device] = None, ) -> None: failures = [] total = len(actual_outputs) @@ -341,12 +420,14 @@ def _summarize_results( status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" avg_abs_diff = mean_abs_diff / total if total else 0.0 details = "\n\n".join(str(detail) for detail in failures) if failures else "-" + device_label = str(device) if device is not None else "-" table = tabulate( [ [ backend.name, str(dtype), + device_label, total, f"{max_abs_diff:.6f}", f"{avg_abs_diff:.6f}", @@ -358,6 +439,7 @@ def _summarize_results( headers=[ "Backend", "DType", + "Device", "Samples", "MaxAbsDiff", "MeanAbsDiff", @@ -397,3 +479,42 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl title=f"AWQ Kernel Output {dtype}", reference_label="Torch AWQ output", ) + + @parameterized.expand( + [ + ("cpu", "cpu"), + ("xpu", "xpu:0"), + ] + ) + def test_torch_fused_awq_devices(self, _label: str, device_str: str) -> None: + self._maybe_skip_backend(BACKEND.TORCH_FUSED_AWQ) + if device_str.startswith("xpu") and not _xpu_available(): + self.skipTest("Torch fused AWQ XPU test requires Intel XPU runtime.") + + device = torch.device(device_str) + module = self._build_torch_fused_awq_module( + self.qweight_cpu, + self.qzeros_cpu, + self.scales_cpu, + self.bias_cpu, + device=device, + ) + + try: + actual_outputs = self._forward( + module, + self.inputs[torch.float16], + target_device=device, + ) + self._summarize_results( + reference_outputs=self.reference_outputs[torch.float16], + actual_outputs=actual_outputs, + backend=BACKEND.TORCH_FUSED_AWQ, + dtype=torch.float16, + atol=0.004, + title=f"Torch Fused AWQ Device {device_str}", + reference_label="Torch AWQ output", + device=device, + ) + finally: + del module diff --git a/tests/test_kernel_output_torch_fused.py b/tests/test_kernel_output_torch_fused.py index f299046db..aa1ddff02 100644 --- a/tests/test_kernel_output_torch_fused.py +++ b/tests/test_kernel_output_torch_fused.py @@ -9,6 +9,7 @@ from logbar import LogBar from parameterized import parameterized from torch import Tensor +from tabulate import tabulate from gptqmodel import BACKEND, GPTQModel from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear @@ -19,6 +20,10 @@ log = LogBar.shared() +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + class TestKernelOutput(unittest.TestCase): model_path = "sliuau/llama3.2-1b-4bit-group128" # hf 
"sliuau/llama3.2-1b-4bit-group128" target_qliner_map = { @@ -88,3 +93,115 @@ class TestKernelOutputXPU(TestKernelOutput): class TestKernelOutputXPUBFloat16(TestKernelOutputXPU): dtype = torch.bfloat16 + + +class TestTorchFusedKernelDevices(unittest.TestCase): + model_path = TestKernelOutput.model_path + target_qliner_map = TestKernelOutput.target_qliner_map + target = TestKernelOutput.target + dtype = torch.float16 + m = [1, 16, 64, 256] + k = 2048 + input_samples_each_size = 5 + r_tolerance = 0.0 + a_tolerance = 0.01 + reference_backend = BACKEND.TORCH + reference_device = "cpu" + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + cls.inputs = [] + for dim_0 in cls.m: + for _ in range(cls.input_samples_each_size): + cls.inputs.append(torch.rand((dim_0, cls.k), dtype=cls.dtype)) + + reference_model = GPTQModel.load( + cls.model_path, + backend=cls.reference_backend, + device=cls.reference_device, + dtype=cls.dtype, + ) + cls.reference_outputs = [ + cls.forward(reference_model, sample, cls.reference_backend) + for sample in cls.inputs + ] + del reference_model + + @classmethod + def forward(cls, model, x, backend: BACKEND): + target_qlinear_cls = cls.target_qliner_map[backend] + modules = find_modules(model.model, layers=[target_qlinear_cls]) + result = None + for name, module in modules.items(): + if name == cls.target: + result = module(x.to(model.device)) + break + + assert result is not None + return result + + def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005): + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + @parameterized.expand([ + ("cpu", "cpu"), + ("xpu", "xpu:0"), + ]) + def test_torch_fused_matches_cpu_reference(self, _name: str, device: str): + if device.startswith("xpu") and not _xpu_available(): + self.skipTest("Test requires XPU") + + model = GPTQModel.load( + self.model_path, + backend=BACKEND.TORCH_FUSED, + device=device, + dtype=self.dtype, + ) + failures = [] + for idx, sample in enumerate(self.inputs): + model_input = sample.to(model.device) + fused_out = self.forward(model, model_input, BACKEND.TORCH_FUSED) + reference = self.reference_outputs[idx] + try: + self.assert_on_mismatch( + reference.to("cpu"), + fused_out.to("cpu"), + self.r_tolerance, + self.a_tolerance, + ) + except AssertionError as exc: + failures.append(f"Sample {idx}: {str(exc).splitlines()[0]}") + + status = "PASS" if not failures else "FAIL" + table = tabulate( + [ + [ + BACKEND.TORCH_FUSED.name, + str(self.dtype), + device, + len(self.inputs), + f"{self.r_tolerance:.2e}", + f"{self.a_tolerance:.2e}", + status, + len(failures), + "\n\n".join(failures) if failures else "-", + ] + ], + headers=[ + "Backend", + "DType", + "Device", + "Samples", + "RTol", + "ATol", + "Status", + "Failures", + "Details", + ], + tablefmt="github", + ) + log.info("\nTorch Fused vs CPU Reference\n" + table) + + if failures: + raise AssertionError(f"{len(failures)} mismatched samples on device {device}") diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py new file mode 100644 index 000000000..35f102838 --- /dev/null +++ b/tests/test_torch_fused_awq.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import json +import os +from functools import lru_cache +from pathlib import Path + +import pytest +import torch +from safetensors import safe_open +from tabulate import tabulate + +from 
gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear +from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS + + +CHECKPOINT_DIR = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") +CHECKPOINT_MODULE = os.environ.get( + "GPTQMODEL_AWQ_TEST_MODULE", "model.layers.0.mlp.up_proj" +) + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache(maxsize=1) +def _load_awq_checkpoint_module(): + if not CHECKPOINT_DIR.exists(): + pytest.skip(f"AWQ checkpoint not available at {CHECKPOINT_DIR}") + + index_path = CHECKPOINT_DIR / "model.safetensors.index.json" + if not index_path.exists(): + pytest.skip(f"Missing model index at {index_path}") + + with index_path.open("r", encoding="utf-8") as fh: + index_data = json.load(fh) + weight_map = index_data["weight_map"] + + config_path = CHECKPOINT_DIR / "config.json" + with config_path.open("r", encoding="utf-8") as fh: + config = json.load(fh) + quant_cfg = config.get("quantization_config", {}) + bits = int(quant_cfg.get("bits", 4)) + group_size = int(quant_cfg.get("group_size", 128)) + + suffixes = ["qweight", "qzeros", "scales", "bias"] + tensors = {} + file_to_keys = {} + for suffix in suffixes: + full_key = f"{CHECKPOINT_MODULE}.{suffix}" + filename = weight_map.get(full_key) + if filename is None: + if suffix == "bias": + continue + raise KeyError(f"Missing tensor '{full_key}' in checkpoint index.") + file_to_keys.setdefault(filename, []).append(full_key) + + for filename, keys in file_to_keys.items(): + tensor_path = CHECKPOINT_DIR / filename + with safe_open(tensor_path, framework="pt", device="cpu") as handle: + for key in keys: + tensors[key] = handle.get_tensor(key).clone() + + qweight = tensors[f"{CHECKPOINT_MODULE}.qweight"].to(torch.int32).contiguous() + qzeros = tensors[f"{CHECKPOINT_MODULE}.qzeros"].to(torch.int32).contiguous() + scales = tensors[f"{CHECKPOINT_MODULE}.scales"].to(torch.float16).contiguous() + bias_key = f"{CHECKPOINT_MODULE}.bias" + bias = tensors.get(bias_key) + if bias is not None: + bias = bias.to(torch.float16).contiguous() + + pack_factor = 32 // bits + in_features = qweight.shape[0] + out_features = qweight.shape[1] * pack_factor + + return { + "bits": bits, + "group_size": group_size, + "in_features": in_features, + "out_features": out_features, + "qweight": qweight, + "qzeros": qzeros, + "scales": scales, + "bias": bias, + } + + +@pytest.mark.skipif(not TORCH_HAS_FUSED_OPS, reason="Torch fused ops require PyTorch>=2.8") +@pytest.mark.parametrize( + "device_str", + [ + pytest.param("cpu", id="cpu"), + pytest.param( + "xpu:0", + id="xpu", + marks=pytest.mark.skipif( + not _xpu_available(), reason="Torch fused AWQ XPU test requires Intel XPU runtime." 
+ ), + ), + ], +) +def test_torch_fused_awq_matches_checkpoint_module(device_str: str): + module_data = _load_awq_checkpoint_module() + bits = module_data["bits"] + group_size = module_data["group_size"] + in_features = module_data["in_features"] + out_features = module_data["out_features"] + qweight = module_data["qweight"] + qzeros = module_data["qzeros"] + scales = module_data["scales"] + bias = module_data["bias"] + + device = torch.device(device_str) + + awq_module = AwqTorchQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=bias is not None, + register_buffers=True, + ) + fused_module = TorchFusedAwqQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=bias is not None, + register_buffers=True, + ) + + awq_module.qweight.copy_(qweight) + awq_module.qzeros.copy_(qzeros) + awq_module.scales.copy_(scales) + if bias is not None: + awq_module.bias.copy_(bias) + awq_module.post_init() + awq_module.eval() + + fused_module.register_buffer("qweight", qweight.clone(), persistent=True) + fused_module.qzeros.copy_(qzeros) + fused_module.scales.copy_(scales) + if bias is not None: + fused_module.bias.copy_(bias) + fused_module.post_init() + fused_module.eval() + + awq_module.to(device) + fused_module.to(device) + + dtype = torch.float16 + batch = 4 + x = torch.randn(batch, in_features, dtype=dtype, device=device) + baseline = awq_module(x) + fused_out = fused_module(x) + + rtol = 5e-3 + atol = 5e-3 + abs_diff = (fused_out - baseline).abs() + rel_diff = abs_diff / baseline.abs().clamp_min(1e-6) + summary = tabulate( + [ + [ + device_str, + str(dtype), + f"{rtol:.4g}", + f"{atol:.4g}", + f"{abs_diff.max().item():.4e}", + f"{rel_diff.max().item():.4e}", + ] + ], + headers=["Device", "DType", "RTol", "ATol", "AbsMaxDiff", "RelMaxDiff"], + tablefmt="github", + ) + print(summary) + torch.testing.assert_close(fused_out, baseline, rtol=rtol, atol=atol)
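Beyond the unit tests, a minimal usage sketch for the new backend looks like the following. The checkpoint path is a placeholder, and the `GPTQModel.load` arguments simply mirror those used in the tests above.

```
import torch

from gptqmodel import BACKEND, GPTQModel

# Placeholder path: point this at any AWQ GEMM-format checkpoint.
model = GPTQModel.load(
    "/path/to/awq-gemm-checkpoint",
    backend=BACKEND.TORCH_FUSED_AWQ,  # new backend registered in gptqmodel/utils/backend.py
    device="cpu",                     # the fused AWQ transform also supports "xpu"
    dtype=torch.float16,              # the AWQ fused path is validated for float16 only
)
```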