diff --git a/.gitignore b/.gitignore
index 0449bb94d..12e0533b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -182,3 +182,4 @@ example.py
/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu
/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu
/gptqmodel_offload/
+/gptqmodel_ext/machete/generated/
diff --git a/MANIFEST.in b/MANIFEST.in
index 9efddd22b..a18a3e2c1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -3,9 +3,12 @@ recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp
recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp
+recursive-include gptqmodel_ext/machete *.h *.hpp *.cuh *.cu *.cpp *.py
+recursive-include gptqmodel_ext/cutlass_extensions *.h *.hpp *.cuh *.cu *.cpp *.py
recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp
include gptqmodel_ext/pack_block_cpu.cpp
include gptqmodel_ext/marlin/generate_kernels.py
+include gptqmodel_ext/machete/generate.py
recursive-exclude gptqmodel_ext __pycache__ *.pyc
prune tests/
prune format/
diff --git a/README.md b/README.md
index 04a5d6fe1..f008ef433 100644
--- a/README.md
+++ b/README.md
@@ -17,8 +17,9 @@
## Latest News
-* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
+* 10/21/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `Vram` pressure for large models reduced during quantization.
+`Machete` kernel added for Hopper+/Blackwell acceleration of `GPTQ` and `AWQ` models.
`act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`,
`gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`.
* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
@@ -196,14 +197,14 @@ Native support for some of the most popular multi-modal models:
GPT-QModel is validated for Linux, MacOS, and Windows 11:
-| Platform | Device | | Optimized Arch | Kernels |
-|-----------------|---------------| --- | -------------- |-----------------------------------------------|
-| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exallma V1, Triton, Torch |
-| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exallma V1, Torch |
-| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch |
-| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused, Torch |
-| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
-| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |
+| Platform | Device | | Optimized Arch | Kernels |
+|-----------------|---------------| --- | -------------- |--------------------------------------------------------|
+| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Machete, Marlin, Exllama V2, Exllama V1, Triton, Torch |
+| 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exllama V1, Torch |
+| 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | Torch Fused (PyTorch 2.8+), Torch |
+| 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx`, `xmx` | Torch Fused (PyTorch 2.8+), Torch |
+| 🍎 MacOS | GPU (Metal) / CPU | ✅ | `Apple Silicon`, `M1+` | Torch, MLX via conversion |
+| 🪟 Windows | GPU (Nvidia) / CPU | ✅ | `Nvidia` | Torch |
## Install
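A minimal usage sketch for the new kernel listed above (an illustration, not from this diff: it assumes the package's existing top-level `GPTQModel`/`BACKEND` exports and an SM90+ GPU; the checkpoint id is a placeholder):

```python
# Force the Machete backend instead of auto-selection; loading fails fast with a
# capability error on pre-Hopper GPUs (compute capability < 9.0).
from gptqmodel import BACKEND, GPTQModel

model = GPTQModel.load(
    "ModelCloud/Example-GPTQ-4bit",   # placeholder: any 4-bit GPTQ or AWQ checkpoint
    backend=BACKEND.MACHETE,          # per-layer shapes must satisfy in % 64 == 0, out % 128 == 0
)
```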
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index c7fca5666..9a55fcab3 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -10,7 +10,6 @@
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Union
-import accelerate
import torch
import transformers
@@ -38,6 +37,7 @@
from ..utils.backend import BACKEND
from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear
from ..utils.logger import setup_logger
+from ..utils.machete import _validate_machete_device_support
from ..utils.marlin import _validate_marlin_device_support
from ..utils.model import (
auto_dtype,
@@ -478,7 +478,6 @@ def skip(*args, **kwargs):
init_contexts = [no_init_weights()]
- layer_type = ""
with (ContextManagers(init_contexts)):
cls.before_model_load(cls, load_quantized_model=True)
@@ -507,8 +506,7 @@ def skip(*args, **kwargs):
# Get the first layer to determine layer type
layers, _ = get_module_by_name_prefix(model, cls.extract_layers_node())
- layer0 = layers[0]
- layer_type = layer0.__class__.__name__
+ layers[0]
modules = find_modules(model)
ignore_modules = [cls.lm_head] + cls.get_base_modules(model)
@@ -535,7 +533,6 @@ def skip(*args, **kwargs):
device=device,
)
- log.debug(f"Loader1: device_map {device_map}")
if isinstance(device_map, str) and device_map not in [
"auto",
"balanced",
@@ -548,8 +545,8 @@ def skip(*args, **kwargs):
)
+
import torch
- from typing import Dict, List, Optional
def build_layerwise_device_map(
model,
@@ -643,7 +640,7 @@ def assign(mod, device_id):
if owner:
device_map.setdefault(owner, fallback_device)
else:
- log.debug(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
+ log.info(f"Loader: unable to map param '{param_name}' to a module; skipping fallback assignment.")
# -------------------------------------------------------------
# 6. Prune parent assignments that would override child devices
@@ -657,11 +654,11 @@ def assign(mod, device_id):
if child_name != name and child_name.startswith(f"{name}.")
}
if child_devices and (len(child_devices) > 1 or device_id not in child_devices):
- log.debug(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
+ log.info(f"Loader: dropping parent '{name}' from device_map to preserve child placements.")
device_map.pop(name, None)
# optional logging for debug
- log.debug(f"Loader: Built map across {num_gpus} GPU(s), "
+ log.info(f"Loader: Built map across {num_gpus} GPU(s), "
f"{len(device_map)} entries. First 8: {list(device_map.items())[:8]}")
return device_map
@@ -707,6 +704,16 @@ def assign(mod, device_id):
qcfg.runtime_format = FORMAT.GPTQ_V2
+ if backend == BACKEND.MACHETE:
+ if is_sharded:
+ raise ValueError(
+ "Format: The loading of sharded checkpoints with Machete is currently not supported."
+ )
+ if not _validate_machete_device_support():
+ raise ValueError(
+ f"Kernel: Machete kernel requires compute capability >= 9.0. Detected capability: {torch.cuda.get_device_capability()}"
+ )
+
if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and (
preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN):
if is_sharded:
@@ -742,7 +749,7 @@ def assign(mod, device_id):
# If we use marlin or bitblas to load the quantized model, the model is already a converted model,
# and we no longer need to call load_checkpoint_in_model()
- if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]:
+ if load_checkpoint_in_model and backend not in [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]:
load_checkpoint_in_model_then_tie_weights(
model,
dtype=dtype,
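The new loader gate above defers to `_validate_machete_device_support()`. A standalone sketch of the same check, using only public `torch.cuda` calls:

```python
# torch.cuda.get_device_capability() returns a (major, minor) tuple, e.g. (9, 0) on H100.
import torch

def machete_supported() -> bool:
    # Mirrors _validate_machete_device_support(): CUDA present and SM major >= 9.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9

# A100 -> (8, 0) -> False; H100 -> (9, 0) -> True; Blackwell -> (10, x) -> True.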
diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py
index 867874a26..6e002ae7a 100644
--- a/gptqmodel/nn_modules/hooked_linear.py
+++ b/gptqmodel/nn_modules/hooked_linear.py
@@ -241,8 +241,8 @@ def _replace_module(module, child, name, level: int = 0, debug: bool = False) ->
def replace_module_with_hooked_legacy(module, level: int = 0, quant_lm_head: bool = False):
- if level == 0:
- log.info("Hooked Modules: Using legacy based config for targeting of modules")
+ # if level == 0:
+ # log.info("Hooked Modules: Using legacy based config for targeting of modules")
for name, child in module.named_children():
if not quant_lm_head and hasattr(module, "get_output_embeddings") and child == module.get_output_embeddings():
diff --git a/gptqmodel/nn_modules/qlinear/awq_machete.py b/gptqmodel/nn_modules/qlinear/awq_machete.py
new file mode 100644
index 000000000..622d7dd35
--- /dev/null
+++ b/gptqmodel/nn_modules/qlinear/awq_machete.py
@@ -0,0 +1,200 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+
+from ...adapter.adapter import Adapter, Lora
+from ...models._const import DEVICE, PLATFORM
+from ...nn_modules.qlinear import AWQuantLinear
+from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger
+from ...utils.machete import (
+ _validate_machete_device_support,
+ machete_import_exception,
+ machete_mm,
+ machete_prepack_B,
+ pack_quantized_values_into_int32,
+)
+from ...utils.marlin import replace_parameter, unpack_cols
+from ...utils.marlin_scalar_type import scalar_types
+from ...utils.rocm import IS_ROCM
+
+
+log = setup_logger()
+
+
+class AwqMacheteQuantLinear(AWQuantLinear):
+ SUPPORTS_BITS = [4, 8]
+ SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128]
+ SUPPORTS_DESC_ACT = [False] # AWQ kernels do not reorder activations
+ SUPPORTS_SYM = [True, False]
+ SUPPORTS_SHARDS = True
+ SUPPORTS_TRAINING = False
+ SUPPORTS_AUTO_PADDING = False
+ SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64]
+ SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128]
+
+ SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
+ SUPPORTS_PACK_DTYPES = [torch.int32]
+ SUPPORTS_ADAPTERS = [Lora]
+
+ SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+
+ REQUIRES_FORMAT_V2 = False
+
+ QUANT_TYPE = "awq_machete"
+
+ TYPE_MAP = {
+ 4: scalar_types.uint4,
+ 8: scalar_types.uint8,
+ }
+
+ def __init__(
+ self,
+ bits: int,
+ group_size: int,
+ desc_act: bool,
+ sym: bool,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ pack_dtype: torch.dtype = torch.int32,
+ adapter: Adapter = None,
+ register_buffers: bool = False,
+ **kwargs):
+ if machete_import_exception is not None:
+ raise ValueError(
+ "Trying to use the machete backend, but could not import the "
+ f"C++/CUDA dependencies with the following error: {machete_import_exception}"
+ )
+
+ if bits not in self.TYPE_MAP:
+ raise ValueError(f"Unsupported num_bits = {bits}. Supported: {list(self.TYPE_MAP.keys())}")
+
+ super().__init__(
+ bits=bits,
+ group_size=group_size,
+ sym=sym,
+ desc_act=False,
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ pack_dtype=pack_dtype,
+ backend=kwargs.pop("backend", BACKEND.MACHETE),
+ adapter=adapter,
+ register_buffers=register_buffers,
+ **kwargs)
+
+ self.weight_type = self.TYPE_MAP[self.bits]
+ self.has_zero_points = True
+
+ @classmethod
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+ if machete_import_exception is not None:
+ return False, ImportError(machete_import_exception)
+ return cls._validate(**args)
+
+ @classmethod
+ def validate_device(cls, device: DEVICE):
+ super().validate_device(device)
+ if device == DEVICE.CUDA:
+ if IS_ROCM:
+ raise NotImplementedError("Machete kernel is not supported on ROCm.")
+ if not _validate_machete_device_support():
+ raise NotImplementedError("Machete kernel requires compute capability >= 9.0.")
+
+ def post_init(self):
+ device = self.qweight.device
+
+ # Reconstruct integer weights from packed AWQ representation
+ qweight_int = unpack_cols(
+ self.qweight,
+ self.bits,
+ self.in_features,
+ self.out_features,
+ ).to(device=device)
+
+ packed = pack_quantized_values_into_int32(
+ qweight_int,
+ self.weight_type,
+ packed_dim=0,
+ )
+ packed = packed.t().contiguous().t()
+ prepacked = machete_prepack_B(
+ packed,
+ a_type=self.scales.dtype,
+ b_type=self.weight_type,
+ group_scales_type=self.scales.dtype,
+ )
+ replace_parameter(
+ self,
+ "qweight",
+ torch.nn.Parameter(prepacked.contiguous(), requires_grad=False),
+ )
+
+ # Ensure scales are contiguous and resident on the correct device.
+ replace_parameter(
+ self,
+ "scales",
+ torch.nn.Parameter(self.scales.contiguous(), requires_grad=False),
+ )
+
+ # Convert zero-points: unpack columns, then pre-apply scales as expected by machete_mm
+ effective_group_size = self.in_features if self.group_size == -1 else self.group_size
+ num_groups = self.in_features // effective_group_size
+
+ qzeros_unpacked = unpack_cols(
+ self.qzeros,
+ self.bits,
+ num_groups,
+ self.out_features,
+ ).to(device=device)
+
+ scales = self.scales
+ qzeros_fp = (-1.0 * scales.to(dtype=scales.dtype) * qzeros_unpacked.to(scales.dtype)).contiguous()
+ replace_parameter(
+ self,
+ "qzeros",
+ torch.nn.Parameter(qzeros_fp, requires_grad=False),
+ )
+
+ if self.bias is not None:
+ self.bias = self.bias.to(device=device)
+
+ super().post_init()
+
+ def forward(self, x: torch.Tensor):
+ if x.shape[0] == 0:
+ return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
+ input_2d = x.reshape(-1, x.shape[-1])
+ group_scales = self.scales.to(dtype=input_2d.dtype)
+ group_zeros = self.qzeros.to(dtype=input_2d.dtype)
+
+ output = machete_mm(
+ a=input_2d,
+ b_q=self.qweight,
+ b_type=self.weight_type,
+ b_group_scales=group_scales,
+ b_group_zeros=group_zeros,
+ b_group_size=self.group_size,
+ )
+
+ if self.bias is not None:
+ output.add_(self.bias)
+
+ result = output.reshape(x.shape[:-1] + (self.out_features,))
+
+ if self.adapter:
+ result = self.adapter.apply(x=x, out=result)
+
+ return result
+
+
+__all__ = ["AwqMacheteQuantLinear"]
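The zero-point handling in `post_init()` above stores `-scale * zero` rather than the raw AWQ zero-points, matching the in-code comment that scales are pre-applied for `machete_mm`. A small numeric check of that identity (illustrative only, plain PyTorch):

```python
# AWQ dequant is scale * (q - zero). Storing group_zeros = -scale * zero lets the
# kernel evaluate scale * q + group_zeros, which is algebraically identical.
import torch

scale = torch.tensor([0.05, 0.10])
zero = torch.tensor([8.0, 7.0])
q = torch.tensor([11.0, 3.0])

reference = scale * (q - zero)          # classic AWQ dequantization
pre_scaled_zero = -1.0 * scale * zero   # what post_init() stores in self.qzeros
fused = scale * q + pre_scaled_zero     # what the kernel effectively computes
assert torch.allclose(reference, fused)
```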
diff --git a/gptqmodel/nn_modules/qlinear/machete.py b/gptqmodel/nn_modules/qlinear/machete.py
new file mode 100644
index 000000000..177979e29
--- /dev/null
+++ b/gptqmodel/nn_modules/qlinear/machete.py
@@ -0,0 +1,291 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+
+import torch
+
+from ...adapter.adapter import Adapter, Lora
+from ...models._const import DEVICE, PLATFORM
+from ...nn_modules.qlinear import BaseQuantLinear
+from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger
+from ...utils.machete import (
+ _validate_machete_device_support,
+ check_machete_supports_shape,
+ machete_import_exception,
+ machete_mm,
+ machete_prepack_B,
+ pack_quantized_values_into_int32,
+ query_machete_supported_group_sizes,
+ unpack_quantized_values_into_int32,
+)
+from ...utils.marlin import replace_parameter
+from ...utils.marlin_scalar_type import scalar_types
+from ...utils.rocm import IS_ROCM
+
+
+log = setup_logger()
+
+
+class MacheteQuantLinear(BaseQuantLinear):
+ SUPPORTS_BITS = [4, 8]
+ SUPPORTS_GROUP_SIZE = [-1, 64, 128]
+ SUPPORTS_DESC_ACT = [True, False]
+ SUPPORTS_SYM = [True]
+ SUPPORTS_SHARDS = True
+ SUPPORTS_TRAINING = False
+ SUPPORTS_AUTO_PADDING = False
+ SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64]
+ SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128]
+
+ SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
+ SUPPORTS_PACK_DTYPES = [torch.int32]
+ SUPPORTS_ADAPTERS = [Lora]
+ SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+
+ REQUIRES_FORMAT_V2 = False
+
+ QUANT_TYPE = "machete"
+
+ TYPE_MAP = {
+ (4, True): scalar_types.uint4b8,
+ (8, True): scalar_types.uint8b128,
+ }
+
+ def __init__(
+ self,
+ bits: int,
+ group_size: int,
+ desc_act: bool,
+ sym: bool,
+ in_features: int,
+ out_features: int,
+ bias: bool = False,
+ pack_dtype: torch.dtype = torch.int32,
+ register_buffers: bool = False,
+ adapter: Adapter = None,
+ **kwargs):
+ if machete_import_exception is not None:
+ raise ValueError(
+ "Trying to use the machete backend, but could not import the "
+ f"C++/CUDA dependencies with the following error: {machete_import_exception}"
+ )
+
+ if (bits, sym) not in self.TYPE_MAP:
+ raise ValueError(f"Unsupported quantization config: bits={bits}, sym={sym}")
+
+ super().__init__(
+ bits=bits,
+ group_size=group_size,
+ sym=sym,
+ desc_act=desc_act,
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ pack_dtype=pack_dtype,
+ backend=kwargs.pop("backend", BACKEND.MACHETE),
+ adapter=adapter,
+ register_buffers=False,
+ **kwargs)
+
+ # Quantized weights (packed)
+ self.register_parameter(
+ "qweight",
+ torch.nn.Parameter(
+ torch.empty(
+ self.in_features // self.pack_factor,
+ self.out_features,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ ),
+ )
+
+ # Activation order indices
+ self.register_parameter(
+ "g_idx",
+ torch.nn.Parameter(
+ torch.empty(self.in_features, dtype=torch.int32),
+ requires_grad=False,
+ ),
+ )
+
+ # Scales
+ scales_rows = self.in_features if self.group_size == -1 else self.in_features // self.group_size
+ self.register_parameter(
+ "scales",
+ torch.nn.Parameter(
+ torch.empty(
+ scales_rows,
+ self.out_features,
+ dtype=torch.float16,
+ ),
+ requires_grad=False,
+ ),
+ )
+
+ # Zero points unused for symmetric GPTQ
+ self.register_parameter(
+ "qzeros",
+ torch.nn.Parameter(
+ torch.empty(0, dtype=torch.float16),
+ requires_grad=False,
+ ),
+ )
+
+ if bias:
+ self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ self.weight_type = self.TYPE_MAP[(self.bits, sym)]
+ self.has_zero_points = False
+
+ # Buffer storing permutation applied to activations (empty when unused)
+ self.register_buffer("input_perm", torch.empty(0, dtype=torch.int32))
+
+ @classmethod
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+ if machete_import_exception is not None:
+ return False, ImportError(machete_import_exception)
+
+ ok, err = cls._validate(**args)
+ if not ok:
+ return ok, err
+
+ in_features = args.get("in_features")
+ out_features = args.get("out_features")
+ if in_features is not None and out_features is not None:
+ supported, reason = check_machete_supports_shape(in_features, out_features)
+ if not supported:
+ return False, ValueError(reason)
+
+ bits = args.get("bits")
+ sym = args.get("sym", True)
+ quant_type = cls.TYPE_MAP.get((bits, sym))
+ if quant_type is None:
+ return False, ValueError(f"Machete does not support bits={bits}, sym={sym}")
+
+ group_size = args.get("group_size")
+ dtype = args.get("dtype", torch.float16)
+ if group_size not in query_machete_supported_group_sizes(dtype):
+ return False, ValueError(
+ f"Machete does not support group_size={group_size} for dtype={dtype}"
+ )
+
+ return True, None
+
+ @classmethod
+ def validate_device(cls, device: DEVICE):
+ super().validate_device(device)
+ if device == DEVICE.CUDA:
+ if IS_ROCM:
+ raise NotImplementedError("Machete kernel is not supported on ROCm.")
+ if not _validate_machete_device_support():
+ raise NotImplementedError("Machete kernel requires compute capability >= 9.0.")
+
+ def post_init(self):
+ device = self.qweight.device
+
+ perm = None
+ if self.desc_act:
+ perm = torch.argsort(self.g_idx).to(torch.int32)
+ sorted_g_idx = self.g_idx[perm]
+ replace_parameter(
+ self,
+ "g_idx",
+ torch.nn.Parameter(sorted_g_idx.to(device=device), requires_grad=False),
+ )
+ self.input_perm = perm.to(device=device)
+ else:
+ self.input_perm = torch.empty(0, dtype=torch.int32, device=device)
+
+ qweight_unpacked = unpack_quantized_values_into_int32(
+ self.qweight.data, self.weight_type, packed_dim=0)
+ if perm is not None:
+ qweight_unpacked = qweight_unpacked[perm, :]
+
+ qweight_packed = pack_quantized_values_into_int32(
+ qweight_unpacked, self.weight_type, packed_dim=0)
+ qweight_packed = qweight_packed.t().contiguous().t()
+ prepacked = machete_prepack_B(
+ qweight_packed,
+ a_type=self.scales.dtype,
+ b_type=self.weight_type,
+ group_scales_type=self.scales.dtype,
+ )
+ replace_parameter(
+ self,
+ "qweight",
+ torch.nn.Parameter(prepacked.contiguous(), requires_grad=False),
+ )
+
+ replace_parameter(
+ self,
+ "scales",
+ torch.nn.Parameter(self.scales.data.contiguous(), requires_grad=False),
+ )
+
+ replace_parameter(
+ self,
+ "qzeros",
+ torch.nn.Parameter(torch.empty(0, dtype=self.scales.dtype, device=device), requires_grad=False),
+ )
+ self.has_zero_points = False
+
+ if self.bias is not None:
+ self.bias = self.bias.to(device=device)
+
+ super().post_init()
+
+ def list_buffers(self) -> List:
+ buf = super().list_buffers()
+ if hasattr(self, "input_perm") and self.input_perm is not None:
+ buf.append(self.input_perm)
+ return buf
+
+ def forward(self, x: torch.Tensor):
+ if x.shape[0] == 0:
+ return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
+
+ input_2d = x.reshape(-1, x.shape[-1])
+
+ if self.input_perm.numel() > 0:
+ perm = self.input_perm
+ if perm.device != input_2d.device:
+ perm = perm.to(device=input_2d.device)
+ input_2d = input_2d[:, perm]
+
+ group_scales = self.scales
+ if group_scales.dtype != input_2d.dtype:
+ group_scales = group_scales.to(dtype=input_2d.dtype)
+
+ group_zeros = self.qzeros if self.has_zero_points and self.qzeros.numel() > 0 else None
+
+ output = machete_mm(
+ a=input_2d,
+ b_q=self.qweight,
+ b_type=self.weight_type,
+ b_group_scales=group_scales,
+ b_group_zeros=group_zeros,
+ b_group_size=self.group_size,
+ )
+
+ if self.bias is not None:
+ output.add_(self.bias)
+
+ result = output.reshape(x.shape[:-1] + (self.out_features,))
+
+ if self.adapter:
+ result = self.adapter.apply(x=x, out=result)
+
+ return result
+
+
+__all__ = ["MacheteQuantLinear"]
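The `desc_act` path above permutes the unpacked weight rows by `argsort(g_idx)` at pack time and applies the same permutation to activation columns via `input_perm` in `forward()`. A toy check (illustration, not part of the diff) of why that preserves the result while making groups contiguous:

```python
# Permuting the shared contraction axis of both operands leaves the matmul unchanged:
# x[:, perm] @ W[perm, :] == x @ W.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(8, 16)
g_idx = torch.tensor([0, 2, 1, 1, 0, 2, 2, 0])   # toy activation-order group indices
perm = torch.argsort(g_idx)

assert torch.allclose(x[:, perm] @ w[perm, :], x @ w, atol=1e-6)
```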
diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py
index 58af9135e..664e0e40c 100644
--- a/gptqmodel/utils/backend.py
+++ b/gptqmodel/utils/backend.py
@@ -17,6 +17,7 @@ class BACKEND(str, Enum):
EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1
EXLLAMA_V2 = "exllama_v2" # FASTER: optimized for batching > 1
EXLLAMA_EORA = "exllama_eora"
+ MACHETE = "machete" # CUTLASS-based kernel optimized for Hopper (SM90+)
MARLIN = "marlin" # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower)
MARLIN_FP16 = "marlin_fp16" # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster)
BITBLAS = "bitblas" # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache)
diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py
index 2cf2db75d..3bb46f107 100644
--- a/gptqmodel/utils/importer.py
+++ b/gptqmodel/utils/importer.py
@@ -18,11 +18,13 @@
from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear
from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear
from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
+from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear
from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear
from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
+from ..nn_modules.qlinear.machete import MacheteQuantLinear
from ..nn_modules.qlinear.marlin import MarlinQuantLinear
from ..nn_modules.qlinear.qqq import QQQQuantLinear
from ..nn_modules.qlinear.torch import TorchQuantLinear
@@ -45,6 +47,7 @@
AUTO_SELECT_BACKEND_ORDER_MAP = {
METHOD.GPTQ: OrderedDict({
+ BACKEND.MACHETE: MacheteQuantLinear, # optimized for sm90+
BACKEND.MARLIN: MarlinQuantLinear, # optimized for bs > 1
# BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, #
BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, # optimized for bs > 1
@@ -59,6 +62,7 @@
BACKEND.QQQ: QQQQuantLinear, # qqq kernel based on marlin
}),
METHOD.AWQ: OrderedDict({
+ BACKEND.MACHETE: AwqMacheteQuantLinear,
BACKEND.MARLIN: AwqMarlinQuantLinear,
BACKEND.EXLLAMA_V2: AwqExllamaV2QuantLinear,
BACKEND.EXLLAMA_V1: AwqExllamaQuantLinear,
@@ -70,7 +74,7 @@
SUPPORTS_BACKEND_MAP = {
METHOD.GPTQ: {
- FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH_FUSED, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA],
+        FORMAT.GPTQ: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA],
FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH],
FORMAT.MARLIN: [BACKEND.MARLIN, BACKEND.MARLIN_FP16],
FORMAT.BITBLAS: [BACKEND.BITBLAS],
@@ -79,10 +83,10 @@
FORMAT.QQQ: [BACKEND.QQQ],
},
METHOD.AWQ: {
- FORMAT.GEMM: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM],
+ FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM],
FORMAT.GEMV: [BACKEND.GEMV],
FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST],
- FORMAT.MARLIN: [BACKEND.MARLIN],
+ FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN],
}
}
@@ -280,6 +284,11 @@ def select_quant_linear(
qlinear = TritonV2QuantLinear
elif backend == BACKEND.BITBLAS:
qlinear = BitBLASQuantLinear
+ elif backend == BACKEND.MACHETE:
+ if quant_method == METHOD.AWQ:
+ qlinear = AwqMacheteQuantLinear
+ else:
+ qlinear = MacheteQuantLinear
elif backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16]:
if quant_method == METHOD.AWQ:
qlinear = AwqMarlinQuantLinear
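Placing `BACKEND.MACHETE` first in `AUTO_SELECT_BACKEND_ORDER_MAP` only changes behavior on devices where it validates. A hedged sketch of the fall-through idea (not the importer's actual code; it assumes the `validate(**args) -> (ok, err)` contract shown in the new qlinear classes):

```python
# Candidates are tried in insertion order; the first kernel whose validate() passes wins,
# so Machete is preferred on SM90+ and silently skipped elsewhere.
from collections import OrderedDict

def pick_kernel(candidates: OrderedDict, **quant_args):
    for backend, qlinear_cls in candidates.items():
        ok, _err = qlinear_cls.validate(**quant_args)
        if ok:
            return backend, qlinear_cls
    raise ValueError("no compatible kernel for this quant config/device")
```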
diff --git a/gptqmodel/utils/machete.py b/gptqmodel/utils/machete.py
new file mode 100644
index 000000000..57aaee535
--- /dev/null
+++ b/gptqmodel/utils/machete.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+import torch
+
+from ._extension_loader import load_extension_module
+from .logger import setup_logger
+from .marlin_scalar_type import ScalarType, scalar_types
+
+
+log = setup_logger()
+
+machete_import_exception: Optional[str] = None
+try:
+ gptqmodel_machete_kernels = load_extension_module("gptqmodel_machete_kernels")
+except ImportError as e: # pragma: no cover - surfaced at runtime
+ machete_import_exception = str(e)
+ gptqmodel_machete_kernels = None
+
+MACHETE_PREPACKED_BLOCK_SHAPE = (64, 128)
+
+
+def _validate_machete_device_support() -> bool:
+ return (torch.cuda.is_available()
+ and torch.cuda.get_device_capability()[0] >= 9)
+
+
+def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
+ if zero_points:
+ return [scalar_types.uint4, scalar_types.uint8]
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
+
+
+def query_machete_supported_act_types(_zero_points: bool) -> List[torch.dtype]:
+ return [torch.float16, torch.bfloat16]
+
+
+def query_machete_supported_group_sizes(act_type: torch.dtype) -> List[int]:
+ if act_type in (torch.float16, torch.bfloat16):
+ return [-1, 64, 128]
+ return [-1, 128]
+
+
+def check_machete_supports_shape(in_features: int,
+ out_features: int) -> tuple[bool, Optional[str]]:
+ if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
+ return (False,
+ f"Input features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]}")
+ if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
+ return (False,
+ f"Output features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]}")
+ return (True, None)
+
+
+def _ensure_machete_loaded():
+ if machete_import_exception is not None:
+ raise ImportError(
+ f"Trying to use the machete backend, but could not import the C++/CUDA dependencies: {machete_import_exception}"
+ )
+
+
+def _maybe_scalar_type(t: Optional[torch.Tensor]) -> Optional[torch.dtype]:
+ return t.dtype if t is not None else None
+
+
+def machete_prepack_B(
+ weight: torch.Tensor,
+ a_type: torch.dtype,
+ b_type: ScalarType,
+ group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
+ _ensure_machete_loaded()
+ return gptqmodel_machete_kernels.machete_prepack_B(
+ weight,
+ a_type,
+ b_type.id,
+ group_scales_type,
+ )
+
+
+def machete_supported_schedules(
+ a_type: torch.dtype,
+ b_type: ScalarType,
+ group_scales_type: Optional[torch.dtype] = None,
+ group_zeros_type: Optional[torch.dtype] = None,
+ channel_scales_type: Optional[torch.dtype] = None,
+ token_scales_type: Optional[torch.dtype] = None,
+ out_type: Optional[torch.dtype] = None) -> List[str]:
+ _ensure_machete_loaded()
+ return gptqmodel_machete_kernels.machete_supported_schedules(
+ a_type,
+ b_type.id,
+ group_scales_type,
+ group_zeros_type,
+ channel_scales_type,
+ token_scales_type,
+ out_type,
+ )
+
+
+def machete_mm(
+ *,
+ a: torch.Tensor,
+ b_q: torch.Tensor,
+ b_type: ScalarType,
+ b_group_scales: Optional[torch.Tensor] = None,
+ b_group_zeros: Optional[torch.Tensor] = None,
+ b_group_size: Optional[int] = None,
+ b_channel_scales: Optional[torch.Tensor] = None,
+ a_token_scales: Optional[torch.Tensor] = None,
+ out_type: Optional[torch.dtype] = None,
+ schedule: Optional[str] = None) -> torch.Tensor:
+ _ensure_machete_loaded()
+ return gptqmodel_machete_kernels.machete_mm(
+ a,
+ b_q,
+ b_type.id,
+ out_type,
+ b_group_scales,
+ b_group_zeros,
+ b_group_size,
+ b_channel_scales,
+ a_token_scales,
+ schedule,
+ )
+
+
+def pack_quantized_values_into_int32(
+ tensor: torch.Tensor,
+ qtype: ScalarType,
+ packed_dim: int = 0) -> torch.Tensor:
+ perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,)
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+ temp = tensor.permute(perm)
+
+ pack_factor = 32 // qtype.size_bits
+ mask = (1 << qtype.size_bits) - 1
+
+ assert temp.shape[-1] % pack_factor == 0
+ new_shape = list(temp.shape)
+ new_shape[-1] //= pack_factor
+
+ result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device)
+ for i in range(pack_factor):
+ result |= ((temp[..., i::pack_factor] & mask) << (qtype.size_bits * i))
+
+ return result.permute(inv_perm)
+
+
+def unpack_quantized_values_into_int32(
+ tensor: torch.Tensor,
+ qtype: ScalarType,
+ packed_dim: int = 0) -> torch.Tensor:
+ perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,)
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+ temp = tensor.permute(perm)
+
+ pack_factor = 32 // qtype.size_bits
+ mask = (1 << qtype.size_bits) - 1
+
+ new_shape = list(temp.shape)
+ new_shape[-1] *= pack_factor
+
+ result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device)
+ for i in range(pack_factor):
+ result[..., i::pack_factor] = (temp >> (qtype.size_bits * i)) & mask
+
+ return result.permute(inv_perm)
+
+
+__all__ = [
+ "_validate_machete_device_support",
+ "check_machete_supports_shape",
+ "machete_import_exception",
+ "machete_mm",
+ "machete_prepack_B",
+ "machete_supported_schedules",
+ "pack_quantized_values_into_int32",
+ "query_machete_supported_act_types",
+ "query_machete_supported_group_sizes",
+ "query_machete_supported_quant_types",
+ "unpack_quantized_values_into_int32",
+]
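The two packing helpers above are pure PyTorch, so they can be exercised without the compiled extension. A quick round-trip check (assuming the module imports cleanly when the machete kernels are unavailable, which the try/except at the top allows):

```python
import torch
from gptqmodel.utils.machete import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from gptqmodel.utils.marlin_scalar_type import scalar_types

vals = torch.randint(0, 16, (64, 8), dtype=torch.int32)   # 4-bit values, 8 per int32
packed = pack_quantized_values_into_int32(vals, scalar_types.uint4, packed_dim=0)
assert packed.shape == (8, 8)                              # 64 // (32 // 4) along packed_dim
restored = unpack_quantized_values_into_int32(packed, scalar_types.uint4, packed_dim=0)
assert torch.equal(restored, vals)
```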
diff --git a/gptqmodel_ext/cutlass_extensions/__init__.py b/gptqmodel_ext/cutlass_extensions/__init__.py
new file mode 100644
index 000000000..a903f4cee
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/__init__.py
@@ -0,0 +1 @@
+# Cutlass extension helpers for GPTQModel
diff --git a/gptqmodel_ext/cutlass_extensions/common.cpp b/gptqmodel_ext/cutlass_extensions/common.cpp
new file mode 100644
index 000000000..3d2093ab9
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
+#include "cutlass_extensions/common.hpp"
+
+int32_t get_sm_version_num() {
+ int32_t major_capability, minor_capability;
+ cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+ 0);
+ cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+ 0);
+ int32_t version_num = major_capability * 10 + minor_capability;
+ return version_num;
+}
\ No newline at end of file
diff --git a/gptqmodel_ext/cutlass_extensions/common.hpp b/gptqmodel_ext/cutlass_extensions/common.hpp
new file mode 100644
index 000000000..f2c1dcf69
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/common.hpp
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include <climits>
+#include "cuda_runtime.h"
+#include <iostream>
+
+/**
+ * Helper function for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TORCH_CHECK(error == cutlass::Status::kSuccess, \
+ cutlassGetStatusString(error)); \
+ }
+
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+ int max_shared_mem_per_block_opt_in = 0;
+ cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+ return max_shared_mem_per_block_opt_in;
+}
+
+int32_t get_sm_version_num();
+
+/**
+ * A wrapper for a kernel that is used to guard against compilation on
+ * architectures that will never use the kernel. The purpose of this is to
+ * reduce the size of the compiled binary.
+ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+ * into code that will be executed on the device where it is defined.
+ */
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+ }
+};
+
+template <typename Kernel>
+struct enable_sm90_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+ }
+};
+
+template <typename Kernel>
+struct enable_sm100_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+ }
+};
+
+template <typename Kernel>
+struct enable_sm120_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+ }
+};
diff --git a/gptqmodel_ext/cutlass_extensions/cute_utils.cuh b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh
new file mode 100644
index 000000000..f61fe3ceb
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <cute/tensor.hpp>
+#include <torch/all.h>
+namespace cute {
+
+////////////////////////////////////////////////////////////////////
+// layout utils
+////////////////////////////////////////////////////////////////////
+
+// Permute layout based on indices, example:
+// permute_layout<1, 0>(layout) will swap the two dimensions
+// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
+template <size_t... I, class Shape, class Stride>
+CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout<Shape, Stride> l) {
+  static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
+  return cute::make_layout(cute::get<I>(l)...);
+}
+
+// is the layout f(x) = x
+template <typename Layout>
+CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
+  if constexpr (std::is_same_v<Layout, void>) {
+ return true;
+ } else {
+ constexpr auto coalesced_layout = coalesce(Layout{});
+ if constexpr (rank(coalesced_layout) == 1 &&
+ stride<0>(coalesced_layout) == 1) {
+ return true;
+ }
+ return false;
+ }
+}
+
+////////////////////////////////////////////////////////////////////
+// Pointer utils
+////////////////////////////////////////////////////////////////////
+
+template <typename PointerType>
+static constexpr auto get_logical_ptr(PointerType* ptr) {
+  if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
+    return cute::subbyte_iterator<PointerType>(ptr);
+ } else {
+ return ptr;
+ }
+}
+
+////////////////////////////////////////////////////////////////////
+// Misc utils
+////////////////////////////////////////////////////////////////////
+
+template <typename T, typename Elements>
+CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
+  constexpr auto bits = sizeof_bits_v<T> * Elements{};
+ if constexpr (bits % 128 == 0) {
+ return AutoVectorizingCopyWithAssumedAlignment<128>{};
+ } else if constexpr (bits % 64 == 0) {
+ return AutoVectorizingCopyWithAssumedAlignment<64>{};
+ } else if constexpr (bits % 32 == 0) {
+ return AutoVectorizingCopyWithAssumedAlignment<32>{};
+ } else if constexpr (bits % 16 == 0) {
+ return AutoVectorizingCopyWithAssumedAlignment<16>{};
+ } else {
+ return AutoVectorizingCopyWithAssumedAlignment<8>{};
+ }
+}
+
+}; // namespace cute
diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
new file mode 100644
index 000000000..5c1d6e3f4
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
@@ -0,0 +1,457 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+//
+// This file is a modified excerpt of
+// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
+// from https://github.com/NVIDIA/cutlass v3.5.0
+// It has been modified to support either row/column or scalar broadcasting
+// where the tensor being loaded from is always passed in via a device pointer.
+// This lets one compiled kernel handle all cases of per-tensor or
+// per-channel/per-token quantization.
+//
+// This interface also allows the scales to be passed in as tensors that
+// consistently reside on the device, which avoids an issue with a previous
+// implementation where scalars needed to be on the CPU since they
+// were passed in via float values. This created a potential performance hazard
+// if scales were initially on the device, and caused torch.compile graphs
+// breaks when moving scales to the CPU.
+//
+#pragma once
+
+// Turn off clang-format for the entire file to keep it close to upstream
+// clang-format off
+
+#include "cutlass/cutlass.h"
+#include "cutlass/arch/barrier.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
+
+namespace cutlass::epilogue::fusion {
+
+using namespace cute;
+using namespace detail;
+
+// Row vector broadcast
+template<
+ int Stages,
+ class CtaTileShapeMNK,
+ class Element,
+ class StrideMNL = Stride<_0,_1,_0>,
+  int Alignment = 128 / sizeof_bits_v<Element>
+>
+struct Sm90RowOrScalarBroadcastArray {
+ static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+ static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
+
+ struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
+ };
+
+ // This struct has been modified to have a bool indicating that ptr_row is a
+ // scalar that must be broadcast, instead of containing a scalar that is
+ // valid if ptr_row is null.
+ struct Arguments {
+ const Element* const* ptr_row_array = nullptr;
+ bool row_broadcast = true;
+ StrideMNL dRow = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static bool
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+ return true;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+  template <class ProblemShape>
+ static cutlass::Status
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+ CudaHostAdapter* cuda_adapter = nullptr) {
+ return cutlass::Status::kSuccess;
+ }
+
+ CUTLASS_HOST_DEVICE
+ Sm90RowOrScalarBroadcastArray() { }
+
+ CUTLASS_HOST_DEVICE
+ Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
+ : params(params)
+      , smem(const_cast<Element*>(shared_storage.smem.data())) { }
+
+ Params params;
+ Element *smem = nullptr;
+
+ CUTLASS_DEVICE bool
+ is_producer_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_C_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_zero() const {
+ return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
+ }
+
+  template <class... Args>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
+ return EmptyProducerLoadCallbacks{};
+ }
+
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S,
+            class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue,
+            class ThrNum>
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+ CUTLASS_DEVICE
+ ConsumerStoreCallbacks(
+ GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+ GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+ SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+ CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
+ int group, Params const& params_)
+ : tGS_gRow(tGS_gRow_)
+ , tGS_sRow(tGS_sRow_)
+ , tGS_cRow(tGS_cRow_)
+ , tiled_G2S(tiled_g2s_)
+ , tSR_sRow(tSR_sRow_)
+ , tSR_rRow(tSR_rRow_)
+ , tCcRow(tCcRow_)
+ , residue_tCcRow(residue_tCcRow_)
+ , group(group)
+ , params(params_) {}
+
+ GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
+ GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
+ GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
+ Tiled_G2S tiled_G2S;
+
+ SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+ CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ ThrResidue residue_tCcRow; // (m, n)
+ ThrNum thr_num;
+ int group;
+ Params const& params;
+
+ CUTLASS_DEVICE void
+ begin() {
+ if (!params.row_broadcast) {
+ fill(tSR_rRow, *(params.ptr_row_array[group]));
+ return;
+ }
+
+ auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+ Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+ Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+ Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+ for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+ if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+ continue; // OOB of SMEM,
+ }
+ if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+ tGS_sRow_flt(i) = tGS_gRow_flt(i);
+ }
+ else {
+ tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+ }
+ }
+ synchronize();
+ }
+
+ CUTLASS_DEVICE void
+ begin_loop(int epi_m, int epi_n) {
+ if (epi_m == 0) { // Assumes M-major subtile loop
+ if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+ Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+ Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+ copy(tSR_sRow_flt, tSR_rRow_flt);
+ }
+ }
+
+    template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
+    CUTLASS_DEVICE Array<ElementInput, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<ElementInput, FragmentSize> frg_row;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < FragmentSize; ++i) {
+ frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
+ }
+
+ return frg_row;
+ }
+ };
+
+ template <
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+ class... Args
+ >
+ CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+ auto [M, N, K, L] = args.problem_shape_mnkl;
+ auto [m, n, k, l] = args.tile_coord_mnkl;
+ using ThreadCount = decltype(size(args.tiled_copy));
+
+ Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
+ Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+ Tensor sRow = make_tensor(make_smem_ptr(smem),
+ make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
+ //// G2S: Gmem to Smem
+ auto tiled_g2s = make_tiled_copy(Copy_Atom{},
+ Layout< Shape<_1, ThreadCount>,
+ Stride<_0, _1>>{},
+ Layout<_1>{});
+ auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+ Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+ Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+ //// G2S: Coord
+ auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+ Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+ //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+ Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
+
+ return ConsumerStoreCallbacks(
+ tGS_gRow,
+ tGS_sRow,
+ tGS_cRow, tiled_g2s,
+ tSR_sRow,
+ tSR_rRow,
+ args.tCcD,
+ args.residue_cD,
+ ThreadCount{},
+ l,
+ params);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Column vector broadcast
+template<
+ int Stages,
+ class CtaTileShapeMNK,
+ class Element,
+ class StrideMNL = Stride<_1,_0,_0>,
+  int Alignment = 128 / sizeof_bits_v<Element>
+>
+struct Sm90ColOrScalarBroadcastArray {
+ static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
+  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
+ static_assert(
+    (cute::is_same_v<StrideMNL, Stride<Int<1>, Int<0>, Int<0>>>) || // col vector broadcast, e.g. per-row alpha/bias
+    (cute::is_same_v<StrideMNL, Stride<Int<1>, Int<0>, int>>)); // batched col vector broadcast, e.g. batched per-row bias
+
+ // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
+ struct SharedStorage { };
+
+ // This struct has been modified to have a bool indicating that ptr_col is a
+ // scalar that must be broadcast, instead of containing a scalar that is
+ // valid if ptr_col is null.
+ struct Arguments {
+ const Element* const* ptr_col_array = nullptr;
+ bool col_broadcast = true;
+ StrideMNL dCol = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static bool
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+ return true;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+  template <class ProblemShape>
+ static cutlass::Status
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+ CudaHostAdapter* cuda_adapter = nullptr) {
+ return cutlass::Status::kSuccess;
+ }
+
+ CUTLASS_DEVICE bool
+ is_producer_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_C_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_zero() const {
+ return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
+ }
+
+ CUTLASS_HOST_DEVICE
+ Sm90ColOrScalarBroadcastArray() { }
+
+ CUTLASS_HOST_DEVICE
+ Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
+ : params(params) { }
+
+ Params params;
+
+  template <class... Args>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
+ return EmptyProducerLoadCallbacks{};
+ }
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+ CUTLASS_DEVICE
+ ConsumerStoreCallbacks(
+ GTensor&& tCgCol,
+ RTensor&& tCrCol,
+ CTensor&& tCcCol,
+ ProblemShape problem_shape,
+ int group,
+ Params const& params
+ ):
+      tCgCol(cute::forward<GTensor>(tCgCol)),
+      tCrCol(cute::forward<RTensor>(tCrCol)),
+      tCcCol(cute::forward<CTensor>(tCcCol)),
+ m(get<0>(problem_shape)),
+ group(group),
+ params(params) {}
+
+ GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ RTensor tCrCol;
+ CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ Params const& params;
+ int m;
+ int group;
+
+ CUTLASS_DEVICE void
+ begin() {
+      Tensor pred = make_tensor<bool>(shape(tCgCol));
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(pred); ++i) {
+ pred(i) = get<0>(tCcCol(i)) < m;
+ }
+
+ if (!params.col_broadcast) {
+ fill(tCrCol, *(params.ptr_col_array[group]));
+ return;
+ }
+
+ // Filter so we don't issue redundant copies over stride-0 modes
+ // (only works if 0-strides are in same location, which is by construction)
+ copy_if(pred, filter(tCgCol), filter(tCrCol));
+ }
+
+    template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
+    CUTLASS_DEVICE Array<ElementInput, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<ElementInput, FragmentSize> frg_col;
+ Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < FragmentSize; ++i) {
+ frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
+ }
+
+ return frg_col;
+ }
+
+ };
+
+ template <
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+ class... Args
+ >
+ CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+
+ auto [M, N, K, L] = args.problem_shape_mnkl;
+ auto [m, n, k, l] = args.tile_coord_mnkl;
+
+ Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
+    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+ Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+ // Generate an identity tensor matching the shape of the global tensor and
+ // partition the same way, this will be used to generate the predicate
+ // tensor for loading
+ Tensor cCol = make_identity_tensor(mCol.shape());
+    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+
+ return ConsumerStoreCallbacks(
+ cute::move(tCgCol),
+ cute::move(tCrCol),
+ cute::move(tCcCol),
+ args.problem_shape_mnkl,
+ l,
+ params
+ );
+ }
+};
+
+}
diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
new file mode 100644
index 000000000..7aa87feb4
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
@@ -0,0 +1,497 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+//
+// This file is a modified excerpt of
+// include/cutlass/epilogue/fusion/visitor_load.hpp from
+// https://github.com/NVIDIA/cutlass v3.5.0
+// It has been modified to support either
+// row/column or scalar broadcasting where the tensor being loaded from is
+// always passed in via a device pointer. This lets one compiled kernel handle
+// all cases of per-tensor or per-channel/per-token quantization.
+//
+// This interface also allows the scales to be passed in as tensors that
+// consistently reside on the device, which avoids an issue with a previous
+// implementation where scalars needed to be on the CPU since they
+// were passed in via float values. This created a potential performance hazard
+// if scales were initially on the device, and caused torch.compile graph
+// breaks when moving scales to the CPU.
+//
+#pragma once
+
+// Turn off clang-format for the entire file to keep it close to upstream
+// clang-format off
+
+#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
+#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
+#include "cute/tensor.hpp"
+
+namespace cutlass::epilogue::threadblock {
+
+using namespace cute;
+using namespace detail;
+
+template<
+ class ThreadMap,
+ class Element,
+ class StrideMNL
+>
+struct VisitorRowOrScalarBroadcast {
+
+ // This struct has been modified to have a bool indicating that ptr_row is a
+ // scalar that must be broadcast.
+ struct Arguments {
+ Element const* ptr_row = nullptr;
+ bool row_broadcast = true;
+ StrideMNL dRow = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+ struct SharedStorage {};
+
+ // Global load type
+  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
+  using VecType = uint_bit_t<vec_bits>;
+ static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
+
+ CUTLASS_HOST_DEVICE
+ VisitorRowOrScalarBroadcast() { }
+
+ CUTLASS_HOST_DEVICE
+ VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
+ : params_ptr(¶ms) { }
+
+ Params const* params_ptr;
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+ struct Callbacks : EmptyCallbacks {
+ CUTLASS_DEVICE
+ Callbacks(
+ GTensor&& tC_gRow,
+ RTensor&& tC_rRow,
+ CTensor&& tC_cRow,
+ ProblemShape problem_shape,
+ Params const* params_ptr
+ ):
+      tC_gRow(cute::forward<GTensor>(tC_gRow)),
+      tC_rRow(cute::forward<RTensor>(tC_rRow)),
+      tC_cRow(cute::forward<CTensor>(tC_cRow)),
+ n(get<1>(problem_shape)),
+ params_ptr(params_ptr) { }
+
+ GTensor tC_gRow;
+ RTensor tC_rRow;
+ CTensor tC_cRow;
+ Params const* params_ptr;
+ int n;
+
+ // This function is modified from VisitorRowBroadcast
+ CUTLASS_DEVICE void
+ begin_epilogue() {
+ clear(tC_rRow);
+ auto src_v = filter(tC_gRow);
+ auto coord_v = filter(tC_cRow);
+ auto dst_v = filter(tC_rRow);
+
+ if (params_ptr->row_broadcast) {
+ // In this case we are loading from a row vector and broadcasting
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(src_v); ++i) {
+ bool guard = get<1>(coord_v(i)) < n;
+        cutlass::arch::global_load<VecType, sizeof(VecType)>(
+ dst_v(i), (void const*)&src_v(i), guard);
+ }
+ } else {
+ // In this case we are loading from a scalar and broadcasting
+ VecType filled_vec;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < VecLength; i++) {
+          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
+ }
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(src_v); ++i) {
+ if (get<1>(coord_v(i)) < n) {
+ dst_v(i) = filled_vec;
+ }
+ }
+ }
+ }
+
+    template <class ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE auto // returns an Array
+    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
+          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
+      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
+ return rRow_frg(column_idx);
+ }
+ };
+
+  template <class ProblemShape>
+ CUTLASS_DEVICE auto
+ get_callbacks(
+ gemm::GemmCoord threadblock_tile_offset,
+ int thread_idx,
+ ProblemShape problem_shape
+ ) {
+ Tensor mRow = make_tensor(
+ make_gmem_ptr(params_ptr->ptr_row),
+ problem_shape,
+ params_ptr->dRow);
+
+ // VECTOR, FRAGMENT_COLUMN
+    Tensor tC_gRow = recast<VecType>(
+ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
+ )(_,_,_0{},_0{},_0{},_0{});
+ Tensor tC_rRow = make_tensor_like(tC_gRow);
+
+ // Generate the pred tensor
+ Tensor cRow = make_identity_tensor(mRow.shape());
+ Tensor tC_cRow = outer_partition(
+ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
+      Shape<Int<VecLength>>{},
+ (_0{})
+ );
+
+ return Callbacks<
+ decltype(tC_gRow), decltype(tC_rRow),
+ decltype(tC_cRow), ProblemShape>(
+ cute::move(tC_gRow),
+ cute::move(tC_rRow),
+ cute::move(tC_cRow),
+ problem_shape,
+ params_ptr
+ );
+ }
+
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
+template<
+ class ThreadMap,
+ class Element,
+ class StrideMNL
+>
+struct VisitorRowOrZeroBroadcast {
+
+ // This struct has been modified to remove null_default (because it's always 0)
+ struct Arguments {
+ Element const* ptr_row = nullptr;
+ StrideMNL dRow = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+ struct SharedStorage {};
+
+ // Global load type
+  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
+  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
+ static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
+
+ CUTLASS_HOST_DEVICE
+ VisitorRowOrZeroBroadcast() { }
+
+ CUTLASS_HOST_DEVICE
+ VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
+    : params_ptr(&params) { }
+
+ Params const* params_ptr;
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+ struct Callbacks : EmptyCallbacks {
+ CUTLASS_DEVICE
+ Callbacks(
+ GTensor&& tC_gRow,
+ RTensor&& tC_rRow,
+ CTensor&& tC_cRow,
+ ProblemShape problem_shape,
+ Params const* params_ptr
+ ):
+      tC_gRow(cute::forward<GTensor>(tC_gRow)),
+      tC_rRow(cute::forward<RTensor>(tC_rRow)),
+      tC_cRow(cute::forward<CTensor>(tC_cRow)),
+ n(get<1>(problem_shape)),
+ params_ptr(params_ptr) { }
+
+ GTensor tC_gRow;
+ RTensor tC_rRow;
+ CTensor tC_cRow;
+ Params const* params_ptr;
+ int n;
+
+ // This function is modified from VisitorRowBroadcast
+ CUTLASS_DEVICE void
+ begin_epilogue() {
+ clear(tC_rRow);
+ auto src_v = filter(tC_gRow);
+ auto coord_v = filter(tC_cRow);
+ auto dst_v = filter(tC_rRow);
+
+ if (params_ptr->ptr_row != nullptr) {
+ // In this case we are loading from a row vector and broadcasting
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(src_v); ++i) {
+ bool guard = get<1>(coord_v(i)) < n;
+        cutlass::arch::global_load<VecType, sizeof(VecType)>(
+ dst_v(i), (void const*)&src_v(i), guard);
+ }
+ } else {
+ // In this case we are broadcasting 0
+ VecType filled_vec;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < VecLength; i++) {
+          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
+ }
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(src_v); ++i) {
+ if (get<1>(coord_v(i)) < n) {
+ dst_v(i) = filled_vec;
+ }
+ }
+ }
+ }
+
+    template <class ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE auto // returns an Array
+    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
+          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
+      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
+ return rRow_frg(column_idx);
+ }
+ };
+
+  template <class ProblemShape>
+ CUTLASS_DEVICE auto
+ get_callbacks(
+ gemm::GemmCoord threadblock_tile_offset,
+ int thread_idx,
+ ProblemShape problem_shape
+ ) {
+ Tensor mRow = make_tensor(
+ make_gmem_ptr(params_ptr->ptr_row),
+ problem_shape,
+ params_ptr->dRow);
+
+ // VECTOR, FRAGMENT_COLUMN
+    Tensor tC_gRow = recast<VecType>(
+ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
+ )(_,_,_0{},_0{},_0{},_0{});
+ Tensor tC_rRow = make_tensor_like(tC_gRow);
+
+ // Generate the pred tensor
+ Tensor cRow = make_identity_tensor(mRow.shape());
+ Tensor tC_cRow = outer_partition(
+ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
+      Shape<Int<VecLength>>{},
+ (_0{})
+ );
+
+ return Callbacks<
+ decltype(tC_gRow), decltype(tC_rRow),
+ decltype(tC_cRow), ProblemShape>(
+ cute::move(tC_gRow),
+ cute::move(tC_rRow),
+ cute::move(tC_cRow),
+ problem_shape,
+ params_ptr
+ );
+ }
+
+};
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Column vector broadcast
+template<
+ class ThreadMap,
+ class Element,
+ class StrideMNL = Stride<_1,_0,_0>
+>
+struct VisitorColOrScalarBroadcast {
+
+ // This struct has been modified to have a bool indicating that ptr_col is a
+ // scalar that must be broadcast.
+ struct Arguments {
+ Element const* ptr_col = nullptr;
+ bool col_broadcast = true;
+ StrideMNL dCol = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+ struct SharedStorage { };
+
+ CUTLASS_HOST_DEVICE
+ VisitorColOrScalarBroadcast() { }
+
+ CUTLASS_HOST_DEVICE
+ VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
+    : params_ptr(&params) { }
+
+ Params const* params_ptr;
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+ struct Callbacks : EmptyCallbacks {
+ CUTLASS_DEVICE
+ Callbacks(
+ GTensor&& tC_gCol,
+ RTensor&& tC_rCol,
+ CTensor&& tC_cCol,
+ ProblemShape problem_shape,
+ Params const* params_ptr
+ ):
+      tC_gCol(cute::forward<GTensor>(tC_gCol)),
+      tC_rCol(cute::forward<RTensor>(tC_rCol)),
+      tC_cCol(cute::forward<CTensor>(tC_cCol)),
+ m(get<0>(problem_shape)),
+ params_ptr(params_ptr) { }
+
+ GTensor tC_gCol;
+ RTensor tC_rCol;
+ CTensor tC_cCol;
+ Params const* params_ptr;
+ int m;
+
+ // This function is modified from VisitorColBroadcast
+ CUTLASS_DEVICE void
+ begin_epilogue() {
+ clear(tC_rCol);
+
+      Tensor pred = make_tensor<bool>(shape(tC_gCol));
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(pred); ++i) {
+ pred(i) = get<0>(tC_cCol(i)) < m;
+ }
+
+ if (params_ptr->col_broadcast) {
+ // In this case we are loading from a column vector and broadcasting
+ copy_if(pred, tC_gCol, tC_rCol);
+ } else {
+ // In this case we are loading from a scalar and broadcasting
+ auto dst_v = filter(tC_rCol);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(dst_v); ++i) {
+ if (pred(i)) {
+ dst_v(i) = *(params_ptr->ptr_col);
+ }
+ }
+ }
+ }
+
+    template <class ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE auto // returns an Array
+    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
+          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
+      Array<Element, FragmentSize> frg_col;
+ frg_col.fill(tC_rCol(row_idx,iter_idx));
+ return frg_col;
+ }
+ };
+
+  template <class ProblemShape>
+ CUTLASS_DEVICE auto
+ get_callbacks(
+ gemm::GemmCoord threadblock_tile_offset,
+ int thread_idx,
+ ProblemShape problem_shape
+ ) {
+ Tensor mCol = make_tensor(
+ make_gmem_ptr(params_ptr->ptr_col),
+ problem_shape,
+ params_ptr->dCol);
+
+ // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
+ Tensor tC_gCol = group_modes<1,4>(
+ ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
+ Tensor tC_rCol = make_tensor_like(tC_gCol);
+
+ // Generate the pred tensor
+ Tensor cCol = make_identity_tensor(mCol.shape());
+ Tensor tC_cCol = group_modes<1,4>(
+ ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
+
+ return Callbacks<
+ decltype(tC_gCol), decltype(tC_rCol),
+ decltype(tC_cCol), ProblemShape>(
+ cute::move(tC_gCol),
+ cute::move(tC_rCol),
+ cute::move(tC_cCol),
+ problem_shape,
+ params_ptr
+ );
+ }
+};
+
+}
diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
new file mode 100644
index 000000000..58b1e8ff1
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
@@ -0,0 +1,447 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+//
+// This file is a modified excerpt of
+// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
+// from https://github.com/NVIDIA/cutlass v3.5.0
+// It has been modified to support either row/column or scalar broadcasting
+// where the tensor being loaded from is always passed in via a device pointer.
+// This lets one compiled kernel handle all cases of per-tensor or
+// per-channel/per-token quantization.
+//
+// This interface also allows the scales to be passed in as tensors that
+// consistently reside on the device, which avoids an issue with a previous
+// implementation where scalars needed to be on the CPU since they
+// were passed in via float values. This created a potential performance hazard
+// if scales were initially on the device, and caused torch.compile graph
+// breaks when moving scales to the CPU.
+//
+#pragma once
+
+// Turn off clang-format for the entire file to keep it close to upstream
+// clang-format off
+
+#include "cutlass/cutlass.h"
+#include "cutlass/arch/barrier.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
+
+namespace cutlass::epilogue::fusion {
+
+using namespace cute;
+using namespace detail;
+
+// Row vector broadcast
+template<
+ int Stages,
+ class CtaTileShapeMNK,
+ class Element,
+ class StrideMNL = Stride<_0,_1,_0>,
+  int Alignment = 128 / sizeof_bits_v<Element>
+>
+struct Sm90RowOrScalarBroadcast {
+ static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+ static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
+
+ struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
+ };
+
+ // This struct has been modified to have a bool indicating that ptr_row is a
+ // scalar that must be broadcast, instead of containing a scalar that is
+ // valid if ptr_row is null.
+ struct Arguments {
+ Element const* ptr_row = nullptr;
+ bool row_broadcast = true;
+ StrideMNL dRow = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static bool
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+ return true;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+  template <class ProblemShape>
+ static cutlass::Status
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+ CudaHostAdapter* cuda_adapter = nullptr) {
+ return cutlass::Status::kSuccess;
+ }
+
+ CUTLASS_HOST_DEVICE
+ Sm90RowOrScalarBroadcast() { }
+
+ CUTLASS_HOST_DEVICE
+ Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
+ : params(params)
+    , smem(const_cast<Element*>(shared_storage.smem.data())) { }
+
+ Params params;
+ Element *smem = nullptr;
+
+ CUTLASS_DEVICE bool
+ is_producer_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_C_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_zero() const {
+ return (!params.row_broadcast && *(params.ptr_row) == Element(0));
+ }
+
+  template <class... Args>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
+ return EmptyProducerLoadCallbacks{};
+ }
+
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+ CUTLASS_DEVICE
+ ConsumerStoreCallbacks(
+ GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+ GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+ SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+ CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+ : tGS_gRow(tGS_gRow_)
+ , tGS_sRow(tGS_sRow_)
+ , tGS_cRow(tGS_cRow_)
+ , tiled_G2S(tiled_g2s_)
+ , tSR_sRow(tSR_sRow_)
+ , tSR_rRow(tSR_rRow_)
+ , tCcRow(tCcRow_)
+ , residue_tCcRow(residue_tCcRow_)
+ , params(params_) {}
+
+ GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
+ GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
+ GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
+ Tiled_G2S tiled_G2S;
+
+ SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+ CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ ThrResidue residue_tCcRow; // (m, n)
+ ThrNum thr_num;
+ Params const& params;
+
+ CUTLASS_DEVICE void
+ begin() {
+ if (!params.row_broadcast) {
+ fill(tSR_rRow, *(params.ptr_row));
+ return;
+ }
+
+ auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+ Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+ Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+ Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+ for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+ if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+ continue; // OOB of SMEM,
+ }
+ if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+ tGS_sRow_flt(i) = tGS_gRow_flt(i);
+ }
+ else {
+ tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+ }
+ }
+ synchronize();
+ }
+
+ CUTLASS_DEVICE void
+ begin_loop(int epi_m, int epi_n) {
+ if (epi_m == 0) { // Assumes M-major subtile loop
+ if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+ Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+ Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+ copy(tSR_sRow_flt, tSR_rRow_flt);
+ }
+ }
+
+    template <typename ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE Array<Element, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<Element, FragmentSize> frg_row;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < FragmentSize; ++i) {
+ frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
+ }
+
+ return frg_row;
+ }
+ };
+
+ template <
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+ class... Args
+ >
+ CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+ auto [M, N, K, L] = args.problem_shape_mnkl;
+ auto [m, n, k, l] = args.tile_coord_mnkl;
+ using ThreadCount = decltype(size(args.tiled_copy));
+
+ Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+ Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+ Tensor sRow = make_tensor(make_smem_ptr(smem),
+ make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
+ //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<AutoVectorizingCopy, Element>{},
+ Layout< Shape<_1, ThreadCount>,
+ Stride<_0, _1>>{},
+ Layout<_1>{});
+ auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+ Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+ Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+ //// G2S: Coord
+ auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+ Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+ //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+ Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
+
+ return ConsumerStoreCallbacks(
+ tGS_gRow,
+ tGS_sRow,
+ tGS_cRow, tiled_g2s,
+ tSR_sRow,
+ tSR_rRow,
+ args.tCcD,
+ args.residue_cD,
+ ThreadCount{},
+ params);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Column vector broadcast
+template<
+ int Stages,
+ class CtaTileShapeMNK,
+ class Element,
+ class StrideMNL = Stride<_1,_0,_0>,
+  int Alignment = 128 / sizeof_bits_v<Element>
+>
+struct Sm90ColOrScalarBroadcast {
+ static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
+  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
+ static_assert(
+    (cute::is_same_v<StrideMNL, Stride<_1,_0,_0>>) || // col vector broadcast, e.g. per-row alpha/bias
+    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
+
+ // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
+ struct SharedStorage { };
+
+ // This struct has been modified to have a bool indicating that ptr_col is a
+ // scalar that must be broadcast, instead of containing a scalar that is
+ // valid if ptr_col is null.
+ struct Arguments {
+ Element const* ptr_col = nullptr;
+ bool col_broadcast = true;
+ StrideMNL dCol = {};
+ };
+
+ using Params = Arguments;
+
+  template <class ProblemShape>
+ static constexpr Params
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+ return args;
+ }
+
+  template <class ProblemShape>
+ static bool
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+ return true;
+ }
+
+  template <class ProblemShape>
+ static size_t
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+ return 0;
+ }
+
+  template <class ProblemShape>
+ static cutlass::Status
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+ CudaHostAdapter* cuda_adapter = nullptr) {
+ return cutlass::Status::kSuccess;
+ }
+
+ CUTLASS_DEVICE bool
+ is_producer_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_C_load_needed() const {
+ return false;
+ }
+
+ CUTLASS_DEVICE bool
+ is_zero() const {
+ return (!params.col_broadcast && *(params.ptr_col) == Element(0));
+ }
+
+ CUTLASS_HOST_DEVICE
+ Sm90ColOrScalarBroadcast() { }
+
+ CUTLASS_HOST_DEVICE
+ Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
+ : params(params) { }
+
+ Params params;
+
+  template <class... Args>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
+ return EmptyProducerLoadCallbacks{};
+ }
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+ CUTLASS_DEVICE
+ ConsumerStoreCallbacks(
+ GTensor&& tCgCol,
+ RTensor&& tCrCol,
+ CTensor&& tCcCol,
+ ProblemShape problem_shape,
+ Params const& params
+ ):
+      tCgCol(cute::forward<GTensor>(tCgCol)),
+      tCrCol(cute::forward<RTensor>(tCrCol)),
+      tCcCol(cute::forward<CTensor>(tCcCol)),
+ m(get<0>(problem_shape)),
+ params(params) {}
+
+ GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ RTensor tCrCol;
+ CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ Params const& params;
+ int m;
+
+ CUTLASS_DEVICE void
+ begin() {
+      Tensor pred = make_tensor<bool>(shape(tCgCol));
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < size(pred); ++i) {
+ pred(i) = get<0>(tCcCol(i)) < m;
+ }
+
+ if (!params.col_broadcast) {
+ fill(tCrCol, *(params.ptr_col));
+ return;
+ }
+
+ // Filter so we don't issue redundant copies over stride-0 modes
+ // (only works if 0-strides are in same location, which is by construction)
+ copy_if(pred, filter(tCgCol), filter(tCrCol));
+ }
+
+    template <typename ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE Array<Element, FragmentSize>
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
+      Array<Element, FragmentSize> frg_col;
+ Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < FragmentSize; ++i) {
+ frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
+ }
+
+ return frg_col;
+ }
+
+ };
+
+ template <
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+ class... Args
+ >
+ CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+
+ auto [M, N, K, L] = args.problem_shape_mnkl;
+ Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
+    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+ Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+ // Generate an identity tensor matching the shape of the global tensor and
+ // partition the same way, this will be used to generate the predicate
+ // tensor for loading
+ Tensor cCol = make_identity_tensor(mCol.shape());
+    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+ cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
+
+ return ConsumerStoreCallbacks(
+ cute::move(tCgCol),
+ cute::move(tCrCol),
+ cute::move(tCcCol),
+ args.problem_shape_mnkl,
+ params
+ );
+ }
+};
+
+}
diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
new file mode 100644
index 000000000..ad8c0067d
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -0,0 +1,321 @@
+#pragma once
+
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
+
+/*
+ This file defines custom epilogues for fusing channel scales, token scales,
+ bias, and activation zero-points onto a GEMM operation using the
+ CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
+
+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+ as well as a static prepare_args function that constructs an
+ EVTCompute::Arguments struct.
+*/
+
+namespace vllm::c2x {
+
+using namespace cute;
+
+/*
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
+ */
+template <typename ElementAcc, typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBase {
+ protected:
+ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  template <typename T>
+  using ColOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using RowOrZeroLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ // This utility function constructs the arguments for the load descriptors
+ // from a tensor. It can handle both row and column, as well as row/column or
+ // scalar cases.
+  template <typename Descriptor, typename T>
+ static auto args_from_tensor(torch::Tensor const& tensor) {
+ using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+ return Arguments{data_ptr, tensor.numel() != 1};
+ } else {
+ // it would technically work but no use case as data_ptr is never nullptr
+      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+ return Arguments{data_ptr};
+ }
+ }
+
+ // This overload handles the case where there might not be a tensor, in which
+ // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
+    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+ return Arguments{data_ptr};
+ }
+};
+
+/*
+ This epilogue function defines a quantized GEMM operation similar to
+ torch._scaled_mm.
+
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+ per-row. B can be quantized per-tensor or per-column.
+ Any combination of per-tensor and per-row or column is supported.
+ A and B must have symmetric quantization (zero point == 0).
+
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+ scales are applied elementwise with numpy-style broadcasting.
+
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
+ the A and B operands respectively. These scales may be either per-tensor or
+ per row or column.
+*/
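+/*
+  Sketch of the per-element arithmetic the EVT tree below evaluates (with the
+  scales broadcast as described above; an index collapses to 0 when the
+  corresponding scale is per-tensor):
+    D[m, n] = ElementD(a_scale[m] * (b_scale[n] * accum[m, n]))
+  where accum is the raw fp32/int32 accumulator produced by the GEMM mainloop.
+*/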
+template <typename ElementAcc, typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
+ return ArgumentType{a_args, evt0_args, {}};
+ }
+};
+
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
+template <typename ElementAcc, typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : protected ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap> {
+ protected:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD>;
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::homogeneous_multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0, Bias>;
+ using ArgumentType = typename EVTCompute::Arguments;
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+ typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
+ return ArgumentType{a_args, evt0_args, bias_args, {}};
+ }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
+ * term, which should already be multiplied with the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
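+/*
+ * Reasoning sketch (assuming the usual asymmetric scheme A_q = A / a_scale + azp
+ * with a single per-tensor azp): the integer GEMM produces A_q @ B, while the
+ * scaled result needs (A_q - azp) @ B = A_q @ B - azp * (J @ B), J being the
+ * all-ones matrix. azp_adj = azp * J @ B is therefore a per-column correction
+ * that can be precomputed on the host, and the epilogue below only has to
+ * subtract it from the accumulator before applying the scales.
+ */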
+template <typename ElementAcc, typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBiasAzp
+    : protected ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
+
+  // This is the full AZP term, azp * J @ B, shape (1,n)
+  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
+
+ // Compute float(accum - azp_adj), both operands are int32_t
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
+
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
+
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::homogeneous_multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
+
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+                                   std::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
+ b_args, evt_azp_args, {}};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
+ }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying
+ * the correction term using a rank-1 update. If the term were materialized,
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
+ * point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
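+/*
+ * Reasoning sketch: with a per-token (per-row) azp[m], the correction becomes
+ * the rank-1 term azp[m] * (J @ B)[n]. The EVT below forms it on the fly as
+ * Azp * AzpAdj and subtracts it from the accumulator, so the full (m, n)
+ * correction matrix never has to be materialized.
+ */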
+template <typename ElementAcc, typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBiasAzpToken
+    : protected ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
+
+  // Per-token azp term, shape (m,1)
+  using Azp = typename SUPER::template ColLoad<int32_t>;
+
+  // This is the AZP adjustment term, J @ B, shape (1,n)
+  using AzpAdj = typename SUPER::template RowLoad<int32_t>;
+
+ // Compute azp * azp_adj
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, int32_t, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAzp =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
+
+ // Compute float(accum - azp*azp_adj), all operands are int32_t
+ using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::minus, float, int32_t,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeAcc =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
+
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTComputeScaleB =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
+
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::homogeneous_multiply_add, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
+
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ torch::Tensor const& azp_adj,
+ torch::Tensor const& azp,
+                                   std::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
+
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{
+ b_args, evt_acc_args, {}};
+ return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
+ }
+};
+
+}; // namespace vllm::c2x
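+// Illustrative (hypothetical) call site, assuming int8 inputs with an int32
+// accumulator, fp16 output, and whatever OutputTileThreadMap the kernel
+// configuration provides:
+//   using Epilogue = vllm::c2x::ScaledEpilogue<int32_t, cutlass::half_t,
+//                                              OutputTileThreadMap>;
+//   auto evt_args = Epilogue::prepare_args(a_scales, b_scales);
+//   // evt_args is then forwarded into the CUTLASS GEMM's epilogue arguments.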
diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
new file mode 100644
index 000000000..c43eea0a0
--- /dev/null
+++ b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -0,0 +1,450 @@
+#pragma once
+
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
+
+/*
+ This file defines custom epilogues for fusing channel scales, token scales,
+ bias, and activation zero-points onto a GEMM operation using the
+ CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
+
+ Epilogues must contain a public type named EVTCompute of type Sm90EVT,
+ as well as a static prepare_args function that constructs an
+ EVTCompute::Arguments struct.
+*/
+
+namespace vllm::c3x {
+
+using namespace cute;
+
+template <typename T>
+struct identity {
+ CUTLASS_HOST_DEVICE
+ T operator()(T lhs) const { return lhs; }
+};
+
+template <typename ElementAcc, typename ElementD, typename TileShape>
+struct TrivialEpilogue {
+ private:
+ using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+ using Compute = cutlass::epilogue::fusion::Sm90Compute<
+ cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+  template <typename... Args>
+ static ArgumentType prepare_args(Args... args) {
+ return {};
+ }
+};
+
+/*
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
+ */
+template <typename ElementAcc, typename ElementD, typename TileShape>
+struct ScaledEpilogueBase {
+ protected:
+ using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+
+  template <typename T>
+  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
+      0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
+      0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>,
+      128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
+      0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>,
+      128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  template <typename T>
+  using ColOrScalarLoadArray =
+      cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
+          0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoadArray =
+      cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
+          0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ // This utility function constructs the arguments for the load descriptors
+ // from a tensor. It can handle both row and column, as well as row/column or
+ // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+ return Arguments{data_ptr, tensor.numel() != 1};
+ } else {
+      static_assert(!std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> &&
+                    !std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
+ return Arguments{data_ptr};
+ }
+ }
+
+ // This overload handles the case where there might not be a tensor, in which
+ // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
+                  std::is_same_v<Descriptor, RowLoad<T, true>>);
+ return Arguments{data_ptr};
+ }
+
+ template