Merged

Machete #2082
1 change: 1 addition & 0 deletions .gitignore
@@ -182,3 +182,4 @@ example.py
/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu
/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu
/gptqmodel_offload/
/gptqmodel_ext/machete/generated/
3 changes: 3 additions & 0 deletions MANIFEST.in
@@ -3,9 +3,12 @@ recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/machete *.h *.hpp *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/cutlass_extensions *.h *.hpp *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp
include gptqmodel_ext/pack_block_cpu.cpp
include gptqmodel_ext/marlin/generate_kernels.py
include gptqmodel_ext/machete/generate.py
recursive-exclude gptqmodel_ext __pycache__ *.pyc
prune tests/
prune format/
19 changes: 10 additions & 9 deletions README.md
@@ -17,8 +17,9 @@
</p>

## Latest News
* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
* 10/21/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and inlined with quantization. `Vram` pressure for large models reduced during quantization.
`Machete` kernel added for Hopper+/Blackwell acceleration of GPTQ and AWQ models.
`act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`,
`gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`.
* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
@@ -196,14 +197,14 @@ Native support for some of the most popular multi-modal models:

GPT-QModel is validated for Linux, MacOS, and Windows 11:

| Platform | Device | | Optimized Arch | Kernels |
|-----------------|---------------| --- | -------------- |-----------------------------------------------|
| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch |
| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch |
| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch |
| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused, Torch |
| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |
| Platform | Device | | Optimized Arch | Kernels |
|-----------------|---------------| --- | -------------- |--------------------------------------------------------|
| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Machete, Marlin, Exllama V2, Exllama V1, Triton, Torch |
| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exllama V1, Torch |
| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Pytorch 2.8+), Torch |
| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused (Pytorch 2.8+), Torch |
| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |


## Install
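Since the table above now lists Machete among the Nvidia kernels, a short usage sketch may help reviewers orient. This is a hypothetical example, not part of the diff: it assumes the existing `GPTQModel.load()` API and the `BACKEND` enum exported by `gptqmodel`, with `BACKEND.MACHETE` being the member added by this PR; the model id is a placeholder.

```python
# Hypothetical usage sketch (not part of this PR's diff): explicitly selecting
# the new Machete kernel when loading a quantized checkpoint on a Hopper+ GPU.
from gptqmodel import BACKEND, GPTQModel

model = GPTQModel.load(
    "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit",  # placeholder model id
    backend=BACKEND.MACHETE,  # requires CUDA compute capability >= 9.0
)
tokens = model.generate("Machete kernels run on")[0]
print(model.tokenizer.decode(tokens))
```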
27 changes: 17 additions & 10 deletions gptqmodel/models/loader.py
@@ -10,7 +10,6 @@
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Union

import accelerate
import torch
import transformers

@@ -38,6 +37,7 @@
from ..utils.backend import BACKEND
from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear
from ..utils.logger import setup_logger
from ..utils.machete import _validate_machete_device_support
from ..utils.marlin import _validate_marlin_device_support
from ..utils.model import (
auto_dtype,
@@ -478,7 +478,6 @@ def skip(*args, **kwargs):

init_contexts = [no_init_weights()]

layer_type = ""
with (ContextManagers(init_contexts)):
cls.before_model_load(cls, load_quantized_model=True)

@@ -507,8 +506,7 @@ def skip(*args, **kwargs):
# Get the first layer to determine layer type
layers, _ = get_module_by_name_prefix(model, cls.extract_layers_node())

layer0 = layers[0]
layer_type = layer0.__class__.__name__
layers[0]

modules = find_modules(model)
ignore_modules = [cls.lm_head] + cls.get_base_modules(model)
@@ -535,7 +533,6 @@
device=device,
)

log.debug(f"Loader1: device_map {device_map}")
if isinstance(device_map, str) and device_map not in [
"auto",
"balanced",
@@ -548,8 +545,8 @@
)



import torch
from typing import Dict, List, Optional

def build_layerwise_device_map(
model,
@@ -643,7 +640,7 @@ def assign(mod, device_id):
if owner:
device_map.setdefault(owner, fallback_device)
else:
log.debug(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")

# -------------------------------------------------------------
# 6. Prune parent assignments that would override child devices
@@ -657,11 +654,11 @@ def assign(mod, device_id):
if child_name != name and child_name.startswith(f"{name}.")
}
if child_devices and (len(child_devices) > 1 or device_id not in child_devices):
log.debug(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
device_map.pop(name, None)

# optional logging for debug
log.debug(f"Loader: Built map across {num_gpus} GPU(s), "
log.info(f"Loader: Built map across {num_gpus} GPU(s), "
f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}")

return device_map
@@ -707,6 +704,16 @@ def assign(mod, device_id):

qcfg.runtime_format = FORMAT.GPTQ_V2

if backend == BACKEND.MACHETE:
if is_sharded:
raise ValueError(
"Format: The loading of sharded checkpoints with Machete is currently not supported."
)
if not _validate_machete_device_support():
raise ValueError(
f"Kernel: Machete kernel requires compute capability >= 9.0. Detected capability: {torch.cuda.get_device_capability()}"
)

if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and (
preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN):
if is_sharded:
@@ -742,7 +749,7 @@ def assign(mod, device_id):

# If we use marlin or bitblas to load the quantized model, the model is already a converted model,
# and we no longer need to call load_checkpoint_in_model()
if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]:
if load_checkpoint_in_model and backend not in [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]:
load_checkpoint_in_model_then_tie_weights(
model,
dtype=dtype,
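For context on the new Machete gate in the loader above: `_validate_machete_device_support()` is imported from `gptqmodel/utils/machete.py`, which is not part of this excerpt. A minimal sketch of such a capability check, offered as an illustration only and assuming it simply inspects the CUDA compute capability (consistent with the `>= 9.0` error message raised by the loader), could look like this:

```python
# Illustrative sketch only -- the real helper lives in gptqmodel/utils/machete.py
# and is not shown in this diff. It is assumed to gate Machete on Hopper (SM 9.0)
# or newer, matching the ValueError raised by the loader above.
import torch

def _validate_machete_device_support() -> bool:
    """Return True if the active CUDA device reports compute capability >= 9.0."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```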
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/hooked_linear.py
@@ -241,8 +241,8 @@ def _replace_module(module, child, name, level: int = 0, debug: bool = False) ->


def replace_module_with_hooked_legacy(module, level: int = 0, quant_lm_head: bool = False):
if level == 0:
log.info("Hooked Modules: Using legacy based config for targeting of modules")
# if level == 0:
# log.info("Hooked Modules: Using legacy based config for targeting of modules")

for name, child in module.named_children():
if not quant_lm_head and hasattr(module, "get_output_embeddings") and child == module.get_output_embeddings():
200 changes: 200 additions & 0 deletions gptqmodel/nn_modules/qlinear/awq_machete.py
@@ -0,0 +1,200 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

from typing import Optional, Tuple

import torch

from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
from ...nn_modules.qlinear import AWQuantLinear
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger
from ...utils.machete import (
_validate_machete_device_support,
machete_import_exception,
machete_mm,
machete_prepack_B,
pack_quantized_values_into_int32,
)
from ...utils.marlin import replace_parameter, unpack_cols
from ...utils.marlin_scalar_type import scalar_types
from ...utils.rocm import IS_ROCM


log = setup_logger()


class AwqMacheteQuantLinear(AWQuantLinear):
SUPPORTS_BITS = [4, 8]
SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128]
SUPPORTS_DESC_ACT = [False] # AWQ kernels do not reorder activations
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128]

SUPPORTS_DEVICES = [DEVICE.CUDA]
SUPPORTS_PLATFORM = [PLATFORM.LINUX]
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]

SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]

REQUIRES_FORMAT_V2 = False

QUANT_TYPE = "awq_machete"

TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}

def __init__(
self,
bits: int,
group_size: int,
desc_act: bool,
sym: bool,
in_features: int,
out_features: int,
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers: bool = False,
**kwargs):
if machete_import_exception is not None:
raise ValueError(
"Trying to use the machete backend, but could not import the "
f"C++/CUDA dependencies with the following error: {machete_import_exception}"
)

if bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {bits}. Supported: {list(self.TYPE_MAP.keys())}")

super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=False,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.MACHETE),
adapter=adapter,
register_buffers=register_buffers,
**kwargs)

self.weight_type = self.TYPE_MAP[self.bits]
self.has_zero_points = True

@classmethod
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if machete_import_exception is not None:
return False, ImportError(machete_import_exception)
return cls._validate(**args)

@classmethod
def validate_device(cls, device: DEVICE):
super().validate_device(device)
if device == DEVICE.CUDA:
if IS_ROCM:
raise NotImplementedError("Machete kernel is not supported on ROCm.")
if not _validate_machete_device_support():
raise NotImplementedError("Machete kernel requires compute capability >= 9.0.")

def post_init(self):
device = self.qweight.device

# Reconstruct integer weights from packed AWQ representation
qweight_int = unpack_cols(
self.qweight,
self.bits,
self.in_features,
self.out_features,
).to(device=device)

packed = pack_quantized_values_into_int32(
qweight_int,
self.weight_type,
packed_dim=0,
)
packed = packed.t().contiguous().t()
prepacked = machete_prepack_B(
packed,
a_type=self.scales.dtype,
b_type=self.weight_type,
group_scales_type=self.scales.dtype,
)
replace_parameter(
self,
"qweight",
torch.nn.Parameter(prepacked.contiguous(), requires_grad=False),
)

# Ensure scales are contiguous and resident on the correct device.
replace_parameter(
self,
"scales",
torch.nn.Parameter(self.scales.contiguous(), requires_grad=False),
)

# Convert zero-points: unpack columns, then pre-apply scales as expected by machete_mm
effective_group_size = self.in_features if self.group_size == -1 else self.group_size
num_groups = self.in_features // effective_group_size

qzeros_unpacked = unpack_cols(
self.qzeros,
self.bits,
num_groups,
self.out_features,
).to(device=device)

scales = self.scales
qzeros_fp = (-1.0 * scales.to(dtype=scales.dtype) * qzeros_unpacked.to(scales.dtype)).contiguous()
replace_parameter(
self,
"qzeros",
torch.nn.Parameter(qzeros_fp, requires_grad=False),
)

if self.bias is not None:
self.bias = self.bias.to(device=device)

super().post_init()

def forward(self, x: torch.Tensor):
if x.shape[0] == 0:
return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)

input_2d = x.reshape(-1, x.shape[-1])
group_scales = self.scales.to(dtype=input_2d.dtype)
group_zeros = self.qzeros.to(dtype=input_2d.dtype)

output = machete_mm(
a=input_2d,
b_q=self.qweight,
b_type=self.weight_type,
b_group_scales=group_scales,
b_group_zeros=group_zeros,
b_group_size=self.group_size,
)

if self.bias is not None:
output.add_(self.bias)

result = output.reshape(x.shape[:-1] + (self.out_features,))

if self.adapter:
result = self.adapter.apply(x=x, out=result)

return result


__all__ = ["AwqMacheteQuantLinear"]
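One detail of `post_init()` worth spelling out is the zero-point conversion: AWQ dequantizes as `w = s * (q - z)`, and the module pre-applies the scales so the kernel receives an additive offset `z_fp = -s * z` via `b_group_zeros`. The toy check below verifies that identity; it assumes (per the in-code comment) that `machete_mm` applies the group zeros additively after scaling, and uses small float32 tensors rather than the real packed path.

```python
# Toy numerical check of the zero-point convention used in post_init():
#   AWQ reference dequantization   w = s * (q - z)
#   pre-scaled form consumed here  w = s * q + z_fp,  with z_fp = -s * z
# (float32 toy tensors only; the real path runs inside the Machete kernel)
import torch

s = torch.rand(2, 4)                                 # per-group scales
q = torch.randint(0, 16, (2, 4)).to(torch.float32)   # unpacked uint4 weights
z = torch.randint(0, 16, (2, 4)).to(torch.float32)   # unpacked zero-points

w_reference = s * (q - z)     # standard AWQ dequantization
z_fp = -1.0 * s * z           # mirrors qzeros_fp in post_init()
w_prescaled = s * q + z_fp    # additive-offset form

assert torch.allclose(w_reference, w_prescaled, atol=1e-6)
```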